from __future__ import annotations

from mypy.nodes import (
    AssertTypeExpr,
    AssignmentStmt,
    CastExpr,
    ClassDef,
    ForStmt,
    FuncItem,
    NamedTupleExpr,
    NewTypeExpr,
    PromoteExpr,
    TypeAlias,
    TypeAliasExpr,
    TypeAliasStmt,
    TypeApplication,
    TypedDictExpr,
    TypeVarExpr,
    Var,
    WithStmt,
)
from mypy.traverser import TraverserVisitor
from mypy.types import Type
from mypy.typetraverser import TypeTraverserVisitor


class MixedTraverserVisitor(TraverserVisitor, TypeTraverserVisitor):
    """Recursive traversal of both Node and Type objects.

    Most overrides run the inherited node traversal and then make this visitor
    accept the Type objects attached to the visited node.
    """

    def __init__(self) -> None:
        # NOTE(review): the base-class initializers are not invoked here; confirm
        #               that neither base class relies on its __init__ running.
        # Set to True while visit_type_alias() traverses an alias target.
        self.in_type_alias_expr = False

    # Symbol nodes

    def visit_var(self, var: Var, /) -> None:
        """Visit the type of a variable, if it has one."""
        self.visit_optional_type(var.type)

    def visit_func(self, o: FuncItem, /) -> None:
        """Traverse a function item, then its type (if any)."""
        super().visit_func(o)
        self.visit_optional_type(o.type)

    def visit_class_def(self, o: ClassDef, /) -> None:
        """Traverse a class definition, then the base types recorded on its info."""
        # TODO: Should we visit generated methods/variables as well, either here or in
        #       TraverserVisitor?
        super().visit_class_def(o)
        info = o.info
        # Deliberately a truthiness test rather than an "is not None" check.
        if info:
            for base in info.bases:
                base.accept(self)

    def visit_type_alias_expr(self, o: TypeAliasExpr, /) -> None:
        """Traverse a type alias expression, then the alias node it refers to."""
        super().visit_type_alias_expr(o)
        o.node.accept(self)

    def visit_type_var_expr(self, o: TypeVarExpr, /) -> None:
        """Traverse a type variable expression, then its upper bound and values."""
        super().visit_type_var_expr(o)
        o.upper_bound.accept(self)
        for value in o.values:
            value.accept(self)

    def visit_typeddict_expr(self, o: TypedDictExpr, /) -> None:
        """Traverse a TypedDict expression, then its TypedDict type (if any)."""
        super().visit_typeddict_expr(o)
        self.visit_optional_type(o.info.typeddict_type)

    def visit_namedtuple_expr(self, o: NamedTupleExpr, /) -> None:
        """Traverse a named tuple expression, then its tuple type.

        The tuple type is asserted to be present.
        """
        super().visit_namedtuple_expr(o)
        assert o.info.tuple_type
        o.info.tuple_type.accept(self)

    def visit__promote_expr(self, o: PromoteExpr, /) -> None:
        """Traverse a promote expression, then its type."""
        super().visit__promote_expr(o)
        o.type.accept(self)

    def visit_newtype_expr(self, o: NewTypeExpr, /) -> None:
        """Traverse a NewType expression, then its old type (if any)."""
        super().visit_newtype_expr(o)
        self.visit_optional_type(o.old_type)

    # Statements

    def visit_assignment_stmt(self, o: AssignmentStmt, /) -> None:
        """Traverse an assignment statement, then its type (if any)."""
        super().visit_assignment_stmt(o)
        self.visit_optional_type(o.type)

    def visit_type_alias_stmt(self, o: TypeAliasStmt, /) -> None:
        """Traverse a type alias statement, then its alias node (if set)."""
        super().visit_type_alias_stmt(o)
        if o.alias_node is not None:
            o.alias_node.accept(self)

    def visit_type_alias(self, o: TypeAlias, /) -> None:
        """Traverse a type alias, then its target with in_type_alias_expr set.

        The previous value of the flag is restored afterwards, even if visiting
        the target raises, so a nested or failed traversal cannot leave the
        visitor in the wrong state.
        """
        super().visit_type_alias(o)
        previous = self.in_type_alias_expr
        self.in_type_alias_expr = True
        try:
            o.target.accept(self)
        finally:
            # Restore rather than force False, so re-entrant calls keep the
            # outer state intact.
            self.in_type_alias_expr = previous

    def visit_for_stmt(self, o: ForStmt, /) -> None:
        """Traverse a for statement, then its index type (if any)."""
        super().visit_for_stmt(o)
        self.visit_optional_type(o.index_type)

    def visit_with_stmt(self, o: WithStmt, /) -> None:
        """Traverse a with statement, then each of its analyzed types."""
        super().visit_with_stmt(o)
        for typ in o.analyzed_types:
            typ.accept(self)

    # Expressions

    def visit_cast_expr(self, o: CastExpr, /) -> None:
        """Traverse a cast expression, then its type."""
        super().visit_cast_expr(o)
        o.type.accept(self)

    def visit_assert_type_expr(self, o: AssertTypeExpr, /) -> None:
        """Traverse an assert_type expression, then its type."""
        super().visit_assert_type_expr(o)
        o.type.accept(self)

    def visit_type_application(self, o: TypeApplication, /) -> None:
        """Traverse a type application, then each of its types."""
        super().visit_type_application(o)
        for t in o.types:
            t.accept(self)

    # Helpers

    def visit_optional_type(self, t: Type | None, /) -> None:
        """Make this visitor accept t; do nothing when t is falsy (e.g. None)."""
        if t:
            t.accept(self)
