"""Utilities related to determining the reachability of code (in semantic analysis)."""

from __future__ import annotations

from typing import Final, TypeVar

from mypy.literals import literal
from mypy.nodes import (
    LITERAL_YES,
    AssertStmt,
    Block,
    CallExpr,
    ComparisonExpr,
    Expression,
    FuncDef,
    IfStmt,
    Import,
    ImportAll,
    ImportFrom,
    IndexExpr,
    IntExpr,
    MatchStmt,
    MemberExpr,
    NameExpr,
    OpExpr,
    SliceExpr,
    StrExpr,
    TupleExpr,
    UnaryExpr,
)
from mypy.options import Options
from mypy.patterns import AsPattern, OrPattern, Pattern
from mypy.traverser import TraverserVisitor

# Possible outcomes when statically inferring the truth value of a condition.
# The MYPY_* variants describe conditions whose value under mypy differs from
# the value at runtime.
ALWAYS_TRUE: Final = 1
MYPY_TRUE: Final = 2  # True in mypy, False at runtime
ALWAYS_FALSE: Final = 3
MYPY_FALSE: Final = 4  # False in mypy, True at runtime
TRUTH_VALUE_UNKNOWN: Final = 5

# Truth value of `not <cond>`, keyed by the truth value of `<cond>`.
inverted_truth_mapping: Final = {
    ALWAYS_TRUE: ALWAYS_FALSE,
    ALWAYS_FALSE: ALWAYS_TRUE,
    TRUTH_VALUE_UNKNOWN: TRUTH_VALUE_UNKNOWN,
    MYPY_TRUE: MYPY_FALSE,
    MYPY_FALSE: MYPY_TRUE,
}

# Comparison operator to use once the two operands have been swapped
# (`a < b` is the same test as `b > a`; equality tests are symmetric).
reverse_op: Final = {
    "==": "==",
    "!=": "!=",
    "<": ">",
    ">": "<",
    "<=": ">=",
    ">=": "<=",
}


def infer_reachability_of_if_statement(s: IfStmt, options: Options) -> None:
    """Mark the bodies of an if/elif/else statement that can never run.

    Each condition of `s` is evaluated statically with infer_condition_value().
    Bodies guarded by an always-false condition are marked unreachable. At the
    first always-true condition, every later elif body and the else body are
    marked unreachable and processing stops.
    """
    for idx, condition in enumerate(s.expr):
        truth = infer_condition_value(condition, options)

        if truth in (ALWAYS_FALSE, MYPY_FALSE):
            # The condition is considered always false, so we skip the if/elif body.
            mark_block_unreachable(s.body[idx])
            continue

        if truth not in (ALWAYS_TRUE, MYPY_TRUE):
            # Nothing is known about this condition; look at the next one.
            continue

        # This condition is considered always true, so none of the remaining
        # elif/else bodies should be checked.
        if truth == MYPY_TRUE:
            # This condition is false at runtime; this will affect
            # import priorities.
            mark_block_mypy_only(s.body[idx])
        for later_body in s.body[idx + 1 :]:
            mark_block_unreachable(later_body)

        # Make sure an else body always exists and is marked as unreachable
        # so the type checker always knows that all control flow paths will
        # flow through the if statement body.
        if not s.else_body:
            s.else_body = Block([])
        mark_block_unreachable(s.else_body)
        return


def infer_reachability_of_match_statement(s: MatchStmt, options: Options) -> None:
    """Mark the case bodies of a match statement that can never run.

    For every case, the pattern is classified with infer_pattern_value() and
    the guard (if any) with infer_condition_value(); a missing guard counts as
    always true.

    - If either the pattern or the guard is always false, the case body is
      marked unreachable.
    - If the pattern always matches and the guard is always true, every
      later case body is marked unreachable.
    - If the guard is true only under mypy, imports in the case body are
      marked as mypy-only.
    """
    for i, guard in enumerate(s.guards):
        pattern_value = infer_pattern_value(s.patterns[i])

        if guard is not None:
            guard_value = infer_condition_value(guard, options)
        else:
            guard_value = ALWAYS_TRUE

        if pattern_value in (ALWAYS_FALSE, MYPY_FALSE) or guard_value in (
            ALWAYS_FALSE,
            MYPY_FALSE,
        ):
            # The case is considered always false, so we skip the case body.
            mark_block_unreachable(s.bodies[i])
        elif pattern_value in (ALWAYS_TRUE, MYPY_TRUE) and guard_value in (
            ALWAYS_TRUE,
            MYPY_TRUE,
        ):
            # The case is considered always taken, so none of the following
            # cases can be reached. (The pattern test must look for the
            # "true" values here: the "false" ones were handled just above.)
            for body in s.bodies[i + 1 :]:
                mark_block_unreachable(body)

        if guard_value == MYPY_TRUE:
            # This condition is false at runtime; this will affect
            # import priorities.
            mark_block_mypy_only(s.bodies[i])


def assert_will_always_fail(s: AssertStmt, options: Options) -> bool:
    """Return True if the condition of assert statement `s` is inferred to be always false."""
    condition_value = infer_condition_value(s.expr, options)
    return condition_value in (ALWAYS_FALSE, MYPY_FALSE)


def infer_condition_value(expr: Expression, options: Options) -> int:
    """Infer whether the given condition is always true/false.

    Return ALWAYS_TRUE if always true, ALWAYS_FALSE if always false,
    MYPY_TRUE if true under mypy and false at runtime, MYPY_FALSE if
    false under mypy and true at runtime, else TRUTH_VALUE_UNKNOWN.

    Supported conditions:
    - `not <cond>`, for any supported `<cond>`
    - `<cond> and <cond>` / `<cond> or <cond>`
    - the names (or attribute names) PY2, PY3, MYPY and TYPE_CHECKING, and
      any name listed in options.always_true / options.always_false
    - the sys.version_info and sys.platform checks understood by
      consider_sys_version_info() and consider_sys_platform()
    """
    if isinstance(expr, UnaryExpr) and expr.op == "not":
        # Evaluate the operand on its own and invert the outcome. Doing this
        # up front guarantees the negation is applied whatever the operand
        # is, including `and`/`or` expressions and nested `not`.
        return inverted_truth_mapping[infer_condition_value(expr.expr, options)]

    if isinstance(expr, OpExpr) and expr.op in ("and", "or"):
        left = infer_condition_value(expr.left, options)
        if (left in (ALWAYS_TRUE, MYPY_TRUE) and expr.op == "and") or (
            left in (ALWAYS_FALSE, MYPY_FALSE) and expr.op == "or"
        ):
            # Either `True and <other>` or `False or <other>`: the result will
            # always be the right-hand-side.
            return infer_condition_value(expr.right, options)
        # The result will always be the left-hand-side (e.g. ALWAYS_* or
        # TRUTH_VALUE_UNKNOWN).
        return left

    name = ""
    result = TRUTH_VALUE_UNKNOWN
    if isinstance(expr, (NameExpr, MemberExpr)):
        # For `a.b.NAME` only the final attribute name is considered.
        name = expr.name
    else:
        result = consider_sys_version_info(expr, options.python_version)
        if result == TRUTH_VALUE_UNKNOWN:
            result = consider_sys_platform(expr, options.platform)

    if result == TRUTH_VALUE_UNKNOWN:
        if name == "PY2":
            result = ALWAYS_FALSE
        elif name == "PY3":
            result = ALWAYS_TRUE
        elif name == "MYPY" or name == "TYPE_CHECKING":
            result = MYPY_TRUE
        elif name in options.always_true:
            result = ALWAYS_TRUE
        elif name in options.always_false:
            result = ALWAYS_FALSE
    return result


def infer_pattern_value(pattern: Pattern) -> int:
    """Infer whether a match pattern is known to always match.

    Return ALWAYS_TRUE for an AsPattern without a sub-pattern and for an
    OrPattern with at least one alternative that is itself ALWAYS_TRUE.
    Return TRUTH_VALUE_UNKNOWN for everything else.
    """
    if isinstance(pattern, AsPattern) and pattern.pattern is None:
        return ALWAYS_TRUE
    if isinstance(pattern, OrPattern):
        for alternative in pattern.patterns:
            if infer_pattern_value(alternative) == ALWAYS_TRUE:
                return ALWAYS_TRUE
    return TRUTH_VALUE_UNKNOWN


def consider_sys_version_info(expr: Expression, pyversion: tuple[int, ...]) -> int:
    """Consider whether expr is a comparison involving sys.version_info.

    Return ALWAYS_TRUE, ALWAYS_FALSE, or TRUTH_VALUE_UNKNOWN.

    The sys.version_info operand may appear on either side of the comparison.
    """
    # Cases supported:
    # - sys.version_info[<int>] <compare_op> <int>
    # - sys.version_info[:<int>] <compare_op> <tuple_of_n_ints>
    # - sys.version_info <compare_op> <tuple_of_1_or_2_ints>
    #   (in this case <compare_op> must be >, >=, <, <=, but cannot be ==, !=)

    # Let's not yet support chained comparisons.
    if not isinstance(expr, ComparisonExpr) or len(expr.operators) > 1:
        return TRUTH_VALUE_UNKNOWN
    operator = expr.operators[0]
    if operator not in ("==", "!=", "<=", ">=", "<", ">"):
        return TRUTH_VALUE_UNKNOWN

    lhs, rhs = expr.operands[0], expr.operands[1]
    version_part = contains_sys_version_info(lhs)
    literal_part = contains_int_or_tuple_of_ints(rhs)
    if version_part is None or literal_part is None:
        # Try the mirrored form `<literal> <op> sys.version_info...`; the
        # operator has to be flipped to keep the meaning.
        version_part = contains_sys_version_info(rhs)
        literal_part = contains_int_or_tuple_of_ints(lhs)
        operator = reverse_op[operator]

    if isinstance(version_part, int) and isinstance(literal_part, int):
        # sys.version_info[i] <compare_op> k
        if not 0 <= version_part <= 1:
            return TRUTH_VALUE_UNKNOWN
        return fixed_comparison(pyversion[version_part], operator, literal_part)

    if isinstance(version_part, tuple) and isinstance(literal_part, tuple):
        # sys.version_info[lo:hi] <compare_op> (k, ...)
        start, stop = version_part
        start = 0 if start is None else start
        stop = 2 if stop is None else stop
        if 0 <= start < stop <= 2:
            known = pyversion[start:stop]
            same_length = len(known) == len(literal_part)
            longer_with_ordering_op = len(known) > len(literal_part) and operator not in (
                "==",
                "!=",
            )
            if same_length or longer_with_ordering_op:
                return fixed_comparison(known, operator, literal_part)

    return TRUTH_VALUE_UNKNOWN


def consider_sys_platform(expr: Expression, platform: str) -> int:
    """Consider whether expr is a comparison involving sys.platform.

    Return ALWAYS_TRUE, ALWAYS_FALSE, or TRUTH_VALUE_UNKNOWN.
    """
    # Cases supported:
    # - sys.platform == 'linux'
    # - sys.platform != 'win32'
    # - sys.platform.startswith('win')
    if isinstance(expr, ComparisonExpr):
        # Let's not yet support chained comparisons.
        if len(expr.operators) > 1:
            return TRUTH_VALUE_UNKNOWN
        operator = expr.operators[0]
        if operator not in ("==", "!=") or not is_sys_attr(expr.operands[0], "platform"):
            return TRUTH_VALUE_UNKNOWN
        rhs = expr.operands[1]
        if isinstance(rhs, StrExpr):
            return fixed_comparison(platform, operator, rhs.value)
        return TRUTH_VALUE_UNKNOWN

    if isinstance(expr, CallExpr):
        callee = expr.callee
        if not isinstance(callee, MemberExpr):
            return TRUTH_VALUE_UNKNOWN
        if len(expr.args) != 1:
            return TRUTH_VALUE_UNKNOWN
        prefix = expr.args[0]
        if not isinstance(prefix, StrExpr):
            return TRUTH_VALUE_UNKNOWN
        if not is_sys_attr(callee.expr, "platform") or callee.name != "startswith":
            return TRUTH_VALUE_UNKNOWN
        return ALWAYS_TRUE if platform.startswith(prefix.value) else ALWAYS_FALSE

    return TRUTH_VALUE_UNKNOWN


# Operand types that fixed_comparison() can compare; both sides share one type.
Targ = TypeVar("Targ", int, str, tuple[int, ...])


def fixed_comparison(left: Targ, op: str, right: Targ) -> int:
    """Evaluate `left <op> right` for two statically known values.

    Return ALWAYS_TRUE or ALWAYS_FALSE according to the outcome, or
    TRUTH_VALUE_UNKNOWN if `op` is not one of ==, !=, <=, >=, <, >.
    """
    if op == "==":
        holds = left == right
    elif op == "!=":
        holds = left != right
    elif op == "<=":
        holds = left <= right
    elif op == ">=":
        holds = left >= right
    elif op == "<":
        holds = left < right
    elif op == ">":
        holds = left > right
    else:
        return TRUTH_VALUE_UNKNOWN
    return ALWAYS_TRUE if holds else ALWAYS_FALSE


def contains_int_or_tuple_of_ints(expr: Expression) -> None | int | tuple[int, ...]:
    """Extract the value of an int literal or of a tuple of int literals.

    Return the int for an IntExpr, a tuple of ints for a literal TupleExpr
    whose items are all IntExpr, and None for anything else.
    """
    if isinstance(expr, IntExpr):
        return expr.value
    if not isinstance(expr, TupleExpr) or literal(expr) != LITERAL_YES:
        return None

    values = []
    for item in expr.items:
        if not isinstance(item, IntExpr):
            return None
        values.append(item.value)
    return tuple(values)


def contains_sys_version_info(expr: Expression) -> None | int | tuple[int | None, int | None]:
    """Recognize `sys.version_info`, optionally indexed or sliced with int literals.

    Return:
    - (None, None) for plain `sys.version_info`
    - the index for `sys.version_info[<int>]`
    - (begin, end) for `sys.version_info[<begin>:<end>]`, where an omitted
      bound is None; a stride is accepted only if it is the literal 1
    - None for any other expression
    """
    if is_sys_attr(expr, "version_info"):
        return (None, None)  # Same as sys.version_info[:]
    if not (isinstance(expr, IndexExpr) and is_sys_attr(expr.base, "version_info")):
        return None

    subscript = expr.index
    if isinstance(subscript, IntExpr):
        return subscript.value
    if not isinstance(subscript, SliceExpr):
        return None

    step = subscript.stride
    if step is not None and not (isinstance(step, IntExpr) and step.value == 1):
        return None

    lower = upper = None
    lower_expr = subscript.begin_index
    if lower_expr is not None:
        if not isinstance(lower_expr, IntExpr):
            return None
        lower = lower_expr.value
    upper_expr = subscript.end_index
    if upper_expr is not None:
        if not isinstance(upper_expr, IntExpr):
            return None
        upper = upper_expr.value
    return (lower, upper)


def is_sys_attr(expr: Expression, name: str) -> bool:
    """Return True if expr has the syntactic form `sys.<name>`."""
    # TODO: This currently doesn't work with code like this:
    # - import sys as _sys
    # - from sys import version_info
    if not isinstance(expr, MemberExpr) or expr.name != name:
        return False
    base = expr.expr
    # TODO: Guard against a local named sys, etc.
    # (Though later passes will still do most checking.)
    return isinstance(base, NameExpr) and base.name == "sys"


def mark_block_unreachable(block: Block) -> None:
    """Flag `block` as unreachable, along with the import statements inside it."""
    import_marker = MarkImportsUnreachableVisitor()
    block.is_unreachable = True
    block.accept(import_marker)


class MarkImportsUnreachableVisitor(TraverserVisitor):
    """Visitor that flags all imports nested within a node as unreachable.

    Every handler sets `is_unreachable` on the import statement it receives.
    mark_block_unreachable() runs this visitor over a whole block.
    """

    # `import a.b`
    def visit_import(self, node: Import) -> None:
        node.is_unreachable = True

    # `from a import b`
    def visit_import_from(self, node: ImportFrom) -> None:
        node.is_unreachable = True

    # `from a import *`
    def visit_import_all(self, node: ImportAll) -> None:
        node.is_unreachable = True


def mark_block_mypy_only(block: Block) -> None:
    """Flag the imports and function definitions inside `block` as mypy-only."""
    mypy_only_marker = MarkImportsMypyOnlyVisitor()
    block.accept(mypy_only_marker)


class MarkImportsMypyOnlyVisitor(TraverserVisitor):
    """Visitor that sets is_mypy_only (which affects priority).

    The flag is set on import statements and on function definitions.
    mark_block_mypy_only() runs this visitor over a whole block.
    """

    # `import a.b`
    def visit_import(self, node: Import) -> None:
        node.is_mypy_only = True

    # `from a import b`
    def visit_import_from(self, node: ImportFrom) -> None:
        node.is_mypy_only = True

    # `from a import *`
    def visit_import_all(self, node: ImportAll) -> None:
        node.is_mypy_only = True

    # NOTE(review): this override does not call the base implementation, so
    # the function body is presumably not traversed -- confirm against
    # TraverserVisitor.visit_func_def.
    def visit_func_def(self, node: FuncDef) -> None:
        node.is_mypy_only = True
