from __future__ import annotations

from enum import Enum

from mypy import checker, errorcodes
from mypy.messages import MessageBuilder
from mypy.nodes import (
    AssertStmt,
    AssignmentExpr,
    AssignmentStmt,
    BreakStmt,
    ClassDef,
    Context,
    ContinueStmt,
    DictionaryComprehension,
    Expression,
    ExpressionStmt,
    ForStmt,
    FuncDef,
    FuncItem,
    GeneratorExpr,
    GlobalDecl,
    IfStmt,
    Import,
    ImportFrom,
    LambdaExpr,
    ListExpr,
    Lvalue,
    MatchStmt,
    MypyFile,
    NameExpr,
    NonlocalDecl,
    RaiseStmt,
    ReturnStmt,
    StarExpr,
    SymbolTable,
    TryStmt,
    TupleExpr,
    TypeAliasStmt,
    WhileStmt,
    WithStmt,
    implicit_module_attrs,
)
from mypy.options import Options
from mypy.patterns import AsPattern, StarredPattern
from mypy.reachability import ALWAYS_TRUE, infer_pattern_value
from mypy.traverser import ExtendedTraverserVisitor
from mypy.types import Type, UninhabitedType, get_proper_type


class BranchState:
    """Variable-definition state at one point of a branching statement such as `if` or `match`.

    `must_be_defined` holds variables defined on every path reaching this point, while
    `may_be_defined` holds variables defined on only some of those paths. `skipped` marks a
    branch whose end is never reached (for example after `return`, `raise`, or in
    unreachable code).

    The constructor copies the sets it is given, so callers keep ownership of their arguments.
    """

    def __init__(
        self,
        must_be_defined: set[str] | None = None,
        may_be_defined: set[str] | None = None,
        skipped: bool = False,
    ) -> None:
        self.may_be_defined = set() if may_be_defined is None else set(may_be_defined)
        self.must_be_defined = set() if must_be_defined is None else set(must_be_defined)
        self.skipped = skipped

    def copy(self) -> BranchState:
        """Return an independent copy of this state."""
        # The constructor already copies both sets, so they can be passed straight through.
        return BranchState(
            must_be_defined=self.must_be_defined,
            may_be_defined=self.may_be_defined,
            skipped=self.skipped,
        )


class BranchStatement:
    """Definition state for one branching construct (e.g. `if`, `match`, `try`, a loop).

    `branches` holds one BranchState per branch seen so far. Definitions and deletions are
    recorded in the last (current) branch. Every branch starts from `initial_state`, the
    state that was current when the construct began.
    """

    def __init__(self, initial_state: BranchState | None = None) -> None:
        if initial_state is None:
            initial_state = BranchState()
        self.initial_state = initial_state
        self.branches: list[BranchState] = [self._fresh_branch()]

    def _fresh_branch(self) -> BranchState:
        """Build a new branch whose definitions start out equal to `initial_state`."""
        return BranchState(
            must_be_defined=self.initial_state.must_be_defined,
            may_be_defined=self.initial_state.may_be_defined,
        )

    def copy(self) -> BranchStatement:
        """Return a copy that shares no mutable state with this statement."""
        # Copy the initial state too, so the copy never aliases the original's parent branch.
        result = BranchStatement(self.initial_state.copy())
        result.branches = [b.copy() for b in self.branches]
        return result

    def next_branch(self) -> None:
        """Start a new branch, which becomes the current one."""
        self.branches.append(self._fresh_branch())

    def record_definition(self, name: str) -> None:
        """Mark `name` as definitely defined in the current branch."""
        assert len(self.branches) > 0
        self.branches[-1].must_be_defined.add(name)
        self.branches[-1].may_be_defined.discard(name)

    def delete_var(self, name: str) -> None:
        """Forget any definition of `name` in the current branch."""
        assert len(self.branches) > 0
        self.branches[-1].must_be_defined.discard(name)
        self.branches[-1].may_be_defined.discard(name)

    def record_nested_branch(self, state: BranchState) -> None:
        """Merge the final state of a nested branching statement into the current branch.

        If the nested statement was skipped entirely, the current branch is skipped as well.
        """
        assert len(self.branches) > 0
        current_branch = self.branches[-1]
        if state.skipped:
            current_branch.skipped = True
            return
        current_branch.must_be_defined.update(state.must_be_defined)
        current_branch.may_be_defined.update(state.may_be_defined)
        current_branch.may_be_defined.difference_update(current_branch.must_be_defined)

    def skip_branch(self) -> None:
        """Mark the current branch as one whose end is not reached."""
        assert len(self.branches) > 0
        self.branches[-1].skipped = True

    def is_possibly_undefined(self, name: str) -> bool:
        """True if `name` is defined on only some paths into the current branch."""
        assert len(self.branches) > 0
        return name in self.branches[-1].may_be_defined

    def is_undefined(self, name: str) -> bool:
        """True if the current branch has no definition of `name` at all."""
        assert len(self.branches) > 0
        branch = self.branches[-1]
        return name not in branch.may_be_defined and name not in branch.must_be_defined

    def is_defined_in_a_branch(self, name: str) -> bool:
        """True if any branch of this statement (definitely or possibly) defines `name`."""
        assert len(self.branches) > 0
        return any(
            name in b.must_be_defined or name in b.may_be_defined for b in self.branches
        )

    def done(self) -> BranchState:
        """Merge all branches into the state that holds after the statement.

        Variables defined in every non-skipped branch become `must_be_defined`. Every other
        variable defined in some branch, skipped or not, becomes `may_be_defined`. The result
        is marked skipped when all branches were skipped. The branches themselves are not
        modified, so repeated calls return equal results.
        """
        # First, collect all vars, including those from skipped branches: the goal is to
        # capture every variable that the semantic analyzer would consider defined.
        all_vars: set[str] = set()
        for b in self.branches:
            all_vars.update(b.may_be_defined)
            all_vars.update(b.must_be_defined)
        # For everything else, only branches that weren't skipped matter.
        non_skipped_branches = [b for b in self.branches if not b.skipped]
        if non_skipped_branches:
            # Start from a copy: intersecting in place would overwrite the first branch's set.
            must_be_defined = set(non_skipped_branches[0].must_be_defined)
            for b in non_skipped_branches[1:]:
                must_be_defined.intersection_update(b.must_be_defined)
        else:
            must_be_defined = set()
        # Everything defined in at least one branch but not in all of them may be undefined.
        may_be_defined = all_vars.difference(must_be_defined)
        return BranchState(
            must_be_defined=must_be_defined,
            may_be_defined=may_be_defined,
            skipped=len(non_skipped_branches) == 0,
        )


class ScopeType(Enum):
    """Kind of scope the definition tracker is currently inside."""

    # Module top level.
    Global = 1
    # Class body.
    Class = 2
    # Function or lambda body.
    Func = 3
    # Generator expression / dictionary comprehension; starts from the enclosing state.
    Generator = 4


class Scope:
    """Definition-tracking state for one scope (module, class, function or generator).

    `branch_stmts` is the stack of branch statements open in this scope; the first entry is
    the scope's root statement. `undefined_refs` maps a name to the NameExprs that referred
    to it before any definition was seen, so that a later definition can report them.
    """

    def __init__(self, stmts: list[BranchStatement], scope_type: ScopeType) -> None:
        self.branch_stmts: list[BranchStatement] = stmts
        self.scope_type = scope_type
        self.undefined_refs: dict[str, set[NameExpr]] = {}

    def copy(self) -> Scope:
        """Return a copy of this scope that shares no mutable state with it."""
        result = Scope([s.copy() for s in self.branch_stmts], self.scope_type)
        # Copy every reference set too. A shallow dict copy would share them, so references
        # recorded on one copy would silently appear in the other.
        result.undefined_refs = {name: set(refs) for name, refs in self.undefined_refs.items()}
        return result

    def record_undefined_ref(self, o: NameExpr) -> None:
        """Remember that `o` referred to its name while that name was still undefined."""
        self.undefined_refs.setdefault(o.name, set()).add(o)

    def pop_undefined_ref(self, name: str) -> set[NameExpr]:
        """Remove and return the recorded references to `name` (an empty set if none)."""
        return self.undefined_refs.pop(name, set())


class DefinedVariableTracker:
    """Scope and branch bookkeeping used by the possibly-undefined-variable visitor.

    `scopes` is a stack holding one Scope per enclosing scope; the module-level scope at the
    bottom is always present, and every scope owns at least one "root" BranchStatement.
    """

    def __init__(self) -> None:
        self.scopes: list[Scope] = [Scope([BranchStatement()], ScopeType.Global)]
        # When True, jump statements (return/raise/...) do not mark the current branch as
        # skipped. This is needed for things like try/except/finally.
        self.disable_branch_skip = False

    def copy(self) -> DefinedVariableTracker:
        """Return a tracker with copied scopes and the same skip setting."""
        clone = DefinedVariableTracker()
        clone.scopes = [scope.copy() for scope in self.scopes]
        clone.disable_branch_skip = self.disable_branch_skip
        return clone

    def _scope(self) -> Scope:
        """Innermost scope."""
        assert len(self.scopes) > 0
        return self.scopes[-1]

    def _stmts(self) -> list[BranchStatement]:
        """Stack of open branch statements in the innermost scope."""
        return self._scope().branch_stmts

    def enter_scope(self, scope_type: ScopeType) -> None:
        """Push a new scope of the given kind."""
        outer = self._stmts()
        assert len(outer) > 0
        # A generator starts from the enclosing scope's current branch; other scopes start empty.
        inherited = outer[-1].branches[-1] if scope_type == ScopeType.Generator else None
        self.scopes.append(Scope([BranchStatement(inherited)], scope_type))

    def exit_scope(self) -> None:
        """Pop the innermost scope."""
        self.scopes.pop()

    def in_scope(self, scope_type: ScopeType) -> bool:
        """True if the innermost scope is of the given kind."""
        return self._scope().scope_type == scope_type

    def start_branch_statement(self) -> None:
        """Open a nested branch statement seeded from the current branch."""
        stmts = self._stmts()
        assert len(stmts) > 0
        stmts.append(BranchStatement(stmts[-1].branches[-1]))

    def next_branch(self) -> None:
        """Begin the next branch of the innermost (non-root) branch statement."""
        stmts = self._stmts()
        assert len(stmts) > 1
        stmts[-1].next_branch()

    def end_branch_statement(self) -> None:
        """Close the innermost branch statement and merge its result into its parent."""
        stmts = self._stmts()
        assert len(stmts) > 1
        finished = stmts.pop().done()
        stmts[-1].record_nested_branch(finished)

    def skip_branch(self) -> None:
        """Mark the current branch as skipped, unless skipping is disabled.

        The scope's root branch statement is never skipped.
        """
        stmts = self._stmts()
        if self.disable_branch_skip or len(stmts) <= 1:
            return
        stmts[-1].skip_branch()

    def record_definition(self, name: str) -> None:
        """Mark `name` as defined in the current branch."""
        stmts = self._stmts()
        assert len(stmts) > 0
        stmts[-1].record_definition(name)

    def delete_var(self, name: str) -> None:
        """Remove any definition of `name` from the current branch."""
        stmts = self._stmts()
        assert len(stmts) > 0
        stmts[-1].delete_var(name)

    def record_undefined_ref(self, o: NameExpr) -> None:
        """Record a reference to a not-yet-defined name; see `pop_undefined_ref`."""
        self._scope().record_undefined_ref(o)

    def pop_undefined_ref(self, name: str) -> set[NameExpr]:
        """Remove and return the references previously recorded as undefined for `name`."""
        return self._scope().pop_undefined_ref(name)

    def is_possibly_undefined(self, name: str) -> bool:
        """True if `name` is in `may_be_defined` (and not `must_be_defined`) here."""
        stmts = self._stmts()
        assert len(stmts) > 0
        return stmts[-1].is_possibly_undefined(name)

    def is_defined_in_different_branch(self, name: str) -> bool:
        """True if `name` is undefined here but defined in some other branch of this scope."""
        stmts = self._stmts()
        assert len(stmts) > 0
        if not stmts[-1].is_undefined(name):
            return False
        return any(stmt.is_defined_in_a_branch(name) for stmt in stmts)

    def is_undefined(self, name: str) -> bool:
        """True if the current branch has no definition of `name`."""
        stmts = self._stmts()
        assert len(stmts) > 0
        return stmts[-1].is_undefined(name)


class Loop:
    """Bookkeeping for one loop that is currently being traversed."""

    def __init__(self) -> None:
        # Becomes True once a `break` belonging to this loop has been seen.
        self.has_break: bool = False


class PossiblyUndefinedVariableVisitor(ExtendedTraverserVisitor):
    """Detects the following cases:
    - A variable that's defined only part of the time.
    - A variable that's used before its definition.

    An example of a partial definition:
    if foo():
        x = 1
    print(x)  # Error: "x" may be undefined.

    Example of a used before definition:
    x = y
    y: int = 2

    Note that this code does not detect variables not defined in any of the branches -- that is
    handled by the semantic analyzer.

    Definitions are tracked with a DefinedVariableTracker. Each branching construct (if, match,
    try, for, while) is modelled as a branch statement, and jump statements (return, raise,
    break, continue, ...) mark the current branch as skipped.
    """

    def __init__(
        self,
        msg: MessageBuilder,
        type_map: dict[Expression, Type],
        options: Options,
        names: SymbolTable,
    ) -> None:
        """Create a visitor for one module.

        Args:
            msg: used to report errors.
            type_map: inferred expression types, used to find expression statements that
                end the current branch.
            options: checker options; `check_untyped_defs` decides whether untyped
                functions are analysed.
            names: the module's symbol table; its `__builtins__` entry, when present,
                supplies the builtin names.
        """
        self.msg = msg
        self.type_map = type_map
        self.options = options
        self.builtins = SymbolTable()
        builtins_mod = names.get("__builtins__", None)
        if builtins_mod:
            assert isinstance(builtins_mod.node, MypyFile)
            self.builtins = builtins_mod.node.names
        # Loops whose body is currently being traversed, innermost last.
        self.loops: list[Loop] = []
        # Number of enclosing `try` statements currently being traversed.
        self.try_depth = 0
        self.tracker = DefinedVariableTracker()
        # Implicit module attributes are always considered defined.
        for name in implicit_module_attrs:
            self.tracker.record_definition(name)

    def var_used_before_def(self, name: str, context: Context) -> None:
        """Report a used-before-definition error if that error code is enabled."""
        if self.msg.errors.is_error_code_enabled(errorcodes.USED_BEFORE_DEF):
            self.msg.var_used_before_def(name, context)

    def variable_may_be_undefined(self, name: str, context: Context) -> None:
        """Report a possibly-undefined error if that error code is enabled."""
        if self.msg.errors.is_error_code_enabled(errorcodes.POSSIBLY_UNDEFINED):
            self.msg.variable_may_be_undefined(name, context)

    def process_definition(self, name: str) -> None:
        """Record a definition of `name`, first reporting any earlier references to it."""
        # Was this name previously used? If yes, it's a used-before-definition error.
        if not self.tracker.in_scope(ScopeType.Class):
            refs = self.tracker.pop_undefined_ref(name)
            for ref in refs:
                if self.loops:
                    # Inside a loop, an earlier iteration may already have defined the name.
                    self.variable_may_be_undefined(name, ref)
                else:
                    self.var_used_before_def(name, ref)
        else:
            # Errors in class scopes are caught by the semantic analyzer.
            pass
        self.tracker.record_definition(name)

    def visit_global_decl(self, o: GlobalDecl) -> None:
        """Treat names declared `global` as defined in the current scope."""
        for name in o.names:
            self.process_definition(name)
        super().visit_global_decl(o)

    def visit_nonlocal_decl(self, o: NonlocalDecl) -> None:
        """Treat names declared `nonlocal` as defined in the current scope."""
        for name in o.names:
            self.process_definition(name)
        super().visit_nonlocal_decl(o)

    def process_lvalue(self, lvalue: Lvalue | None) -> None:
        """Define every name bound by an assignment target.

        Starred, list and tuple targets are unpacked recursively. Other targets (attributes,
        subscripts) and `None` define nothing.
        """
        if isinstance(lvalue, NameExpr):
            self.process_definition(lvalue.name)
        elif isinstance(lvalue, StarExpr):
            self.process_lvalue(lvalue.expr)
        elif isinstance(lvalue, (ListExpr, TupleExpr)):
            for item in lvalue.items:
                self.process_lvalue(item)

    def visit_assignment_stmt(self, o: AssignmentStmt) -> None:
        # NOTE(review): the targets are recorded as defined *before* the right-hand side is
        # visited, so in `x = x + 1` the `x` on the right is not reported even when `x` was
        # only conditionally defined earlier -- confirm this is intended.
        for lvalue in o.lvalues:
            self.process_lvalue(lvalue)
        super().visit_assignment_stmt(o)

    def visit_assignment_expr(self, o: AssignmentExpr) -> None:
        # The value is visited before the target is defined.
        o.value.accept(self)
        self.process_lvalue(o.target)

    def visit_if_stmt(self, o: IfStmt) -> None:
        """Model an if/elif/else chain as one branch statement.

        Each reachable `if`/`elif` body gets its own branch. A final branch holds the `else`
        body, or stays empty (no condition matched) when there is none. An unreachable `else`
        marks that final branch as skipped.
        """
        for e in o.expr:
            e.accept(self)
        self.tracker.start_branch_statement()
        for b in o.body:
            if b.is_unreachable:
                continue
            b.accept(self)
            self.tracker.next_branch()
        if o.else_body:
            if not o.else_body.is_unreachable:
                o.else_body.accept(self)
            else:
                self.tracker.skip_branch()
        self.tracker.end_branch_statement()

    def visit_match_stmt(self, o: MatchStmt) -> None:
        """Model a match statement as one branch statement with a branch per case.

        After every case whose pattern is not always-true, a new branch is started. The
        trailing branch therefore represents "no case matched".
        """
        o.subject.accept(self)
        self.tracker.start_branch_statement()
        for i in range(len(o.patterns)):
            pattern = o.patterns[i]
            pattern.accept(self)
            guard = o.guards[i]
            if guard is not None:
                guard.accept(self)
            if not o.bodies[i].is_unreachable:
                o.bodies[i].accept(self)
            else:
                self.tracker.skip_branch()
            is_catchall = infer_pattern_value(pattern) == ALWAYS_TRUE
            if not is_catchall:
                self.tracker.next_branch()
        self.tracker.end_branch_statement()

    def visit_func_def(self, o: FuncDef) -> None:
        """Define the function's name in the enclosing scope, then visit the function."""
        self.process_definition(o.name)
        super().visit_func_def(o)

    def visit_func(self, o: FuncItem) -> None:
        """Visit a function body in a new function scope.

        Untyped (dynamic) functions are skipped unless `check_untyped_defs` is enabled.
        """
        if o.is_dynamic() and not self.options.check_untyped_defs:
            return

        args = o.arguments or []
        # Process initializers (defaults) outside the function scope.
        for arg in args:
            if arg.initializer is not None:
                arg.initializer.accept(self)

        self.tracker.enter_scope(ScopeType.Func)
        for arg in args:
            self.process_definition(arg.variable.name)
            super().visit_var(arg.variable)
        o.body.accept(self)
        self.tracker.exit_scope()

    def visit_generator_expr(self, o: GeneratorExpr) -> None:
        """Visit a generator in a Generator scope, with its loop targets defined up front."""
        self.tracker.enter_scope(ScopeType.Generator)
        for idx in o.indices:
            self.process_lvalue(idx)
        super().visit_generator_expr(o)
        self.tracker.exit_scope()

    def visit_dictionary_comprehension(self, o: DictionaryComprehension) -> None:
        """Visit a dict comprehension in a Generator scope, with its targets defined up front."""
        self.tracker.enter_scope(ScopeType.Generator)
        for idx in o.indices:
            self.process_lvalue(idx)
        super().visit_dictionary_comprehension(o)
        self.tracker.exit_scope()

    def visit_for_stmt(self, o: ForStmt) -> None:
        """Model a for loop.

        The body is a branch next to an empty one (the loop may run zero times), so
        definitions made in the body are only conditional after the loop.
        """
        o.expr.accept(self)
        self.process_lvalue(o.index)
        o.index.accept(self)
        self.tracker.start_branch_statement()
        loop = Loop()
        self.loops.append(loop)
        o.body.accept(self)
        self.tracker.next_branch()
        self.tracker.end_branch_statement()
        # The `else` clause is not part of this loop: a `break` or `continue` inside it
        # targets an enclosing loop. Stop tracking this loop before visiting the `else`,
        # so that such a `break` is attributed to the correct loop.
        self.loops.pop()
        if o.else_body is not None:
            # If the loop has a `break` inside, `else` is executed conditionally.
            # If the loop doesn't have a `break` either the function will return or
            # execute the `else`.
            has_break = loop.has_break
            if has_break:
                self.tracker.start_branch_statement()
                self.tracker.next_branch()
            o.else_body.accept(self)
            if has_break:
                self.tracker.end_branch_statement()

    def visit_return_stmt(self, o: ReturnStmt) -> None:
        super().visit_return_stmt(o)
        # Nothing after a `return` in this branch is reached.
        self.tracker.skip_branch()

    def visit_lambda_expr(self, o: LambdaExpr) -> None:
        self.tracker.enter_scope(ScopeType.Func)
        super().visit_lambda_expr(o)
        self.tracker.exit_scope()

    def visit_assert_stmt(self, o: AssertStmt) -> None:
        super().visit_assert_stmt(o)
        # An assert on a false literal (per checker.is_false_literal) always fails.
        if checker.is_false_literal(o.expr):
            self.tracker.skip_branch()

    def visit_raise_stmt(self, o: RaiseStmt) -> None:
        super().visit_raise_stmt(o)
        self.tracker.skip_branch()

    def visit_continue_stmt(self, o: ContinueStmt) -> None:
        super().visit_continue_stmt(o)
        self.tracker.skip_branch()

    def visit_break_stmt(self, o: BreakStmt) -> None:
        super().visit_break_stmt(o)
        if self.loops:
            self.loops[-1].has_break = True
        self.tracker.skip_branch()

    def visit_expression_stmt(self, o: ExpressionStmt) -> None:
        # An expression whose inferred type is UninhabitedType (presumably a call to a
        # NoReturn function), or that has no recorded type at all, ends the current branch.
        # NOTE(review): the missing-type case presumably means the expression was never
        # type-checked -- confirm.
        typ = self.type_map.get(o.expr)
        if typ is None or isinstance(get_proper_type(typ), UninhabitedType):
            self.tracker.skip_branch()
        super().visit_expression_stmt(o)

    def visit_try_stmt(self, o: TryStmt) -> None:
        """
        Note that finding undefined vars in `finally` requires different handling from
        the rest of the code. In particular, we want to disallow skipping branches due to jump
        statements in except/else clauses for finally but not for other cases. Imagine a case like:
        def f() -> int:
            try:
                x = 1
            except:
                # This jump statement needs to be handled differently depending on whether or
                # not we're trying to process `finally` or not.
                return 0
            finally:
                # `x` may be undefined here.
                pass
            # `x` is always defined here.
            return x
        """
        self.try_depth += 1
        if o.finally_body is not None:
            # In order to find undefined vars in `finally`, we need to
            # process try/except with branch skipping disabled. However, for the rest of the code
            # after finally, we need to process try/except with branch skipping enabled.
            # Therefore, we need to process try/finally twice.
            # Because processing is not idempotent, we should make a copy of the tracker.
            old_tracker = self.tracker.copy()
            self.tracker.disable_branch_skip = True
            self.process_try_stmt(o)
            self.tracker = old_tracker
        self.process_try_stmt(o)
        self.try_depth -= 1

    def process_try_stmt(self, o: TryStmt) -> None:
        """
        Processes try statement decomposing it into the following:
        if ...:
            body
            else_body
        elif ...:
            except 1
        elif ...:
            except 2
        else:
            except n
        finally
        """
        self.tracker.start_branch_statement()
        o.body.accept(self)
        if o.else_body is not None:
            o.else_body.accept(self)
        if len(o.handlers) > 0:
            assert len(o.handlers) == len(o.vars) == len(o.types)
            for i in range(len(o.handlers)):
                self.tracker.next_branch()
                exc_type = o.types[i]
                if exc_type is not None:
                    exc_type.accept(self)
                var = o.vars[i]
                if var is not None:
                    self.process_definition(var.name)
                    var.accept(self)
                o.handlers[i].accept(self)
                if var is not None:
                    # Python clears the `except ... as var` target at the end of the handler.
                    self.tracker.delete_var(var.name)
        self.tracker.end_branch_statement()

        if o.finally_body is not None:
            o.finally_body.accept(self)

    def visit_while_stmt(self, o: WhileStmt) -> None:
        """Model a while loop; see visit_for_stmt for the treatment of `else`."""
        o.expr.accept(self)
        self.tracker.start_branch_statement()
        loop = Loop()
        self.loops.append(loop)
        o.body.accept(self)
        # The `else` clause is not part of this loop: a `break` or `continue` inside it
        # targets an enclosing loop, so stop tracking this loop once the body is done.
        self.loops.pop()
        has_break = loop.has_break
        if not checker.is_true_literal(o.expr):
            # If this is a loop like `while True`, we can consider the body to be
            # a single branch statement (we're guaranteed that the body is executed at least once).
            # If not, call next_branch() to make all variables defined there conditional.
            self.tracker.next_branch()
        self.tracker.end_branch_statement()
        if o.else_body is not None:
            # If the loop has a `break` inside, `else` is executed conditionally.
            # If the loop doesn't have a `break` either the function will return or
            # execute the `else`.
            if has_break:
                self.tracker.start_branch_statement()
                self.tracker.next_branch()
            if o.else_body:
                o.else_body.accept(self)
            if has_break:
                self.tracker.end_branch_statement()

    def visit_as_pattern(self, o: AsPattern) -> None:
        # `case ... as name` binds `name`.
        if o.name is not None:
            self.process_lvalue(o.name)
        super().visit_as_pattern(o)

    def visit_starred_pattern(self, o: StarredPattern) -> None:
        # `case [*rest]` binds `rest`.
        if o.capture is not None:
            self.process_lvalue(o.capture)
        super().visit_starred_pattern(o)

    def visit_name_expr(self, o: NameExpr) -> None:
        """Check a name reference against the current definition state."""
        # References to builtins are ignored in the global scope.
        if o.name in self.builtins and self.tracker.in_scope(ScopeType.Global):
            return
        if self.tracker.is_possibly_undefined(o.name):
            # A variable is only defined in some branches.
            self.variable_may_be_undefined(o.name, o)
            # We don't want to report the error on the same variable multiple times.
            self.tracker.record_definition(o.name)
        elif self.tracker.is_defined_in_different_branch(o.name):
            # A variable is defined in one branch but used in a different branch.
            if self.loops or self.try_depth > 0:
                # If we're in a loop or in a try, we can't be sure that this variable
                # is undefined. Report it as "may be undefined".
                self.variable_may_be_undefined(o.name, o)
            else:
                self.var_used_before_def(o.name, o)
        elif self.tracker.is_undefined(o.name):
            # A variable is undefined. It could be due to two things:
            # 1. A variable is just totally undefined
            # 2. The variable is defined later in the code.
            # Case (1) will be caught by semantic analyzer. Case (2) is a forward ref that should
            # be caught by this visitor. Save the ref for later, so that if we see a definition,
            # we know it's a used-before-definition scenario.
            self.tracker.record_undefined_ref(o)
        super().visit_name_expr(o)

    def visit_with_stmt(self, o: WithStmt) -> None:
        # Each `as` target is defined right after its context expression is visited.
        for expr, idx in zip(o.expr, o.target):
            expr.accept(self)
            self.process_lvalue(idx)
        o.body.accept(self)

    def visit_class_def(self, o: ClassDef) -> None:
        """Define the class name in the enclosing scope; visit the body in a Class scope."""
        self.process_definition(o.name)
        self.tracker.enter_scope(ScopeType.Class)
        super().visit_class_def(o)
        self.tracker.exit_scope()

    def visit_import(self, o: Import) -> None:
        for mod, alias in o.ids:
            if alias is not None:
                self.tracker.record_definition(alias)
            else:
                # When you do `import x.y`, only `x` becomes defined.
                names = mod.split(".")
                if names:
                    # `names` should always be nonempty, but we don't want mypy
                    # to crash on invalid code.
                    self.tracker.record_definition(names[0])
        super().visit_import(o)

    def visit_import_from(self, o: ImportFrom) -> None:
        # `from m import a as b` defines `b`; `from m import a` defines `a`.
        for mod, alias in o.names:
            name = alias
            if name is None:
                name = mod
            self.tracker.record_definition(name)
        super().visit_import_from(o)

    def visit_type_alias_stmt(self, o: TypeAliasStmt) -> None:
        # Type alias target may contain forward references, so only the alias name is
        # recorded and the target is deliberately not visited.
        self.tracker.record_definition(o.name.name)
