"""Plugin for supporting the functools standard library module."""

from __future__ import annotations

from typing import Final, NamedTuple

import mypy.checker
import mypy.plugin
import mypy.semanal
from mypy.argmap import map_actuals_to_formals
from mypy.nodes import (
    ARG_POS,
    ARG_STAR2,
    SYMBOL_FUNCBASE_TYPES,
    ArgKind,
    Argument,
    CallExpr,
    NameExpr,
    Var,
)
from mypy.plugins.common import add_method_to_class
from mypy.typeops import get_all_type_vars
from mypy.types import (
    AnyType,
    CallableType,
    Instance,
    Overloaded,
    ParamSpecFlavor,
    ParamSpecType,
    Type,
    TypeOfAny,
    TypeVarType,
    UnboundType,
    UnionType,
    get_proper_type,
)

# Fully qualified names of the decorators treated as functools.total_ordering.
# NOTE(review): not referenced elsewhere in this module; presumably consulted by
# the plugin registry to dispatch to functools_total_ordering_maker_callback -- confirm.
functools_total_ordering_makers: Final = {"functools.total_ordering"}

# Names of the four ordering dunder methods that are looked up on the decorated
# class (see _analyze_class) and synthesized when missing.
_ORDERING_METHODS: Final = {"__lt__", "__le__", "__gt__", "__ge__"}

# Fully qualified name of functools.partial, used both to build partial[...]
# instances and to recognise them again.
PARTIAL: Final = "functools.partial"


class _MethodInfo(NamedTuple):
    """Signature information about an ordering method found on a class."""

    # True if the method is static; _find_other_type then treats the first
    # positional parameter (rather than the second) as ``other``.
    is_static: bool
    # Callable signature of the method.
    type: CallableType


def functools_total_ordering_maker_callback(
    ctx: mypy.plugin.ClassDefContext, auto_attribs_default: bool = False
) -> bool:
    """Add dunder methods to classes decorated with functools.total_ordering.

    One of the ordering methods found on the class (or its bases) is picked as
    the "root". Every ordering method that is missing, or whose signature could
    not be analysed, is then added to the class with the root's ``other``
    parameter type. The synthesized return type is ``bool`` if the root returns
    ``bool`` and ``Any`` otherwise.

    Args:
        ctx: Class definition context supplied by the plugin machinery.
        auto_attribs_default: Unused by this callback; kept so the signature
            stays compatible with existing callers.

    Returns:
        Always True.
    """
    comparison_methods = _analyze_class(ctx)
    if not comparison_methods:
        ctx.api.fail(
            'No ordering operation defined when using "functools.total_ordering": < > <= >=',
            ctx.reason,
        )
        return True

    # Prefer methods whose signature could be analysed over those that could
    # not; among equals prefer __lt__ to __le__ to __gt__ to __ge__ (this is
    # reverse alphabetical order, hence max()).
    root = max(comparison_methods, key=lambda k: (comparison_methods[k] is not None, k))
    root_method = comparison_methods[root]
    if root_method is None:
        # None of the defined comparison methods can be analysed
        return True

    other_type = _find_other_type(root_method)
    bool_type = ctx.api.named_type("builtins.bool")
    ret_type: Type = bool_type
    if root_method.type.ret_type != bool_type:
        proper_ret_type = get_proper_type(root_method.type.ret_type)
        # An annotation that is still unanalysed but spelled "bool" (possibly
        # qualified) is also accepted as bool; anything else degrades to Any.
        if not (
            isinstance(proper_ret_type, UnboundType)
            and proper_ret_type.name.split(".")[-1] == "bool"
        ):
            ret_type = AnyType(TypeOfAny.implementation_artifact)
    for additional_op in _ORDERING_METHODS:
        # Either the method is not implemented
        # or has an unknown signature that we can now extrapolate.
        if comparison_methods.get(additional_op) is None:
            args = [Argument(Var("other", other_type), other_type, None, ARG_POS)]
            add_method_to_class(ctx.api, ctx.cls, additional_op, args, ret_type)

    return True


def _find_other_type(method: _MethodInfo) -> Type:
    """Return the type of the ``other`` parameter of a comparison method.

    For a static method this is the first positional parameter; otherwise one
    leading positional parameter is skipped first. If a parameter that is
    neither positional nor ``**kwargs`` is reached before that, its type is
    used instead. When nothing suitable is found, ``Any`` is returned.
    """
    # Number of leading positional parameters still to be skipped.
    remaining_to_skip = 0 if method.is_static else 1
    signature = method.type

    for kind, param_type in zip(signature.arg_kinds, signature.arg_types):
        if kind.is_positional():
            if remaining_to_skip == 0:
                return param_type
            remaining_to_skip -= 1
        elif kind != ARG_STAR2:
            return param_type

    return AnyType(TypeOfAny.implementation_artifact)


def _analyze_class(ctx: mypy.plugin.ClassDefContext) -> dict[str, _MethodInfo | None]:
    """Collect the ordering methods defined on the class and its bases.

    The result maps each ordering method name that was found to its
    ``_MethodInfo``, or to ``None`` when the definition exists but no callable
    signature could be obtained for it. The first definition along the MRO wins.
    """
    found: dict[str, _MethodInfo | None] = {}

    # The last MRO entry (object) is skipped because total_ordering does not
    # use methods from object.
    for base in ctx.cls.info.mro[:-1]:
        for method_name in _ORDERING_METHODS:
            if method_name in found or method_name not in base.names:
                continue

            node = base.names[method_name].node
            method_info: _MethodInfo | None = None
            if isinstance(node, SYMBOL_FUNCBASE_TYPES) and isinstance(node.type, CallableType):
                method_info = _MethodInfo(node.is_static, node.type)
            elif isinstance(node, Var):
                var_type = get_proper_type(node.type)
                if isinstance(var_type, CallableType):
                    method_info = _MethodInfo(node.is_staticmethod, var_type)

            found[method_name] = method_info

    return found


def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type:
    """Infer a more precise return type for functools.partial.

    Falls back to the default return type when the checker internals are not
    available, when the call does not have the expected argument layout, or
    when the callee is overloaded.
    """
    fallback = ctx.default_return_type
    if not isinstance(ctx.api, mypy.checker.TypeChecker):  # use internals
        return fallback

    # Expect three formals (fn, *args, **kwargs) with exactly one actual for fn.
    if len(ctx.arg_types) != 3 or len(ctx.arg_types[0]) != 1:
        return fallback

    callee_type = ctx.arg_types[0][0]
    if isinstance(get_proper_type(callee_type), Overloaded):
        # TODO: handle overloads, just fall back to whatever the non-plugin code does
        return fallback

    return handle_partial_with_callee(ctx, callee=callee_type)


def handle_partial_with_callee(ctx: mypy.plugin.FunctionContext, callee: Type) -> Type:
    """Build the type of ``functools.partial(callee, ...)`` for one callee type.

    Union callees are handled by recursing into every item and making a union
    of the results. For any other callee, the arguments supplied to
    ``partial`` are checked against a copy of the callee signature in which
    required parameters have been made optional. The signature that remains
    after dropping positionally supplied parameters is then attached to the
    returned ``functools.partial`` instance as the ``"__mypy_partial"`` extra
    attribute, where ``partial_call_callback`` picks it up.

    Returns ``ctx.default_return_type`` whenever no more precise type can be
    produced: the API is not a ``TypeChecker``, no callable type can be
    extracted from the callee, or the checked call does not yield a result of
    the expected shape.
    """
    if not isinstance(ctx.api, mypy.checker.TypeChecker):  # use internals
        return ctx.default_return_type

    if isinstance(callee_proper := get_proper_type(callee), UnionType):
        return UnionType.make_union(
            [handle_partial_with_callee(ctx, item) for item in callee_proper.items]
        )

    fn_type = ctx.api.extract_callable_type(callee, ctx=ctx.default_return_type)
    if fn_type is None:
        return ctx.default_return_type

    # We must normalize from the start to have coherent view together with TypeChecker.
    fn_type = fn_type.with_unpacked_kwargs().with_normalized_var_args()

    # Remember the innermost type context so it can be restored if it is
    # cleared for class objects below.
    last_context = ctx.api.type_context[-1]
    if not fn_type.is_type_obj():
        # We wrap the return type to get use of a possible type context provided by caller.
        # We cannot do this in case of class objects, since otherwise the plugin may get
        # falsely triggered when evaluating the constructed call itself.
        ret_type: Type = ctx.api.named_generic_type(PARTIAL, [fn_type.ret_type])
        wrapped_return = True
    else:
        ret_type = fn_type.ret_type
        # Instead, for class objects we ignore any type context to avoid spurious errors,
        # since the type context will be partial[X] etc., not X.
        ctx.api.type_context[-1] = None
        wrapped_return = False

    # Flatten actual to formal mapping, since this is what check_call() expects.
    # Index 0 of ctx.args is the callee itself, so only the remaining formals
    # (those of partial's *args and **kwargs) contribute actuals.
    actual_args = []
    actual_arg_kinds = []
    actual_arg_names = []
    actual_types = []
    seen_args = set()
    for i, param in enumerate(ctx.args[1:], start=1):
        for j, a in enumerate(param):
            if a in seen_args:
                # Same actual arg can map to multiple formals, but we need to include
                # each one only once.
                continue
            # Here we rely on the fact that expressions are essentially immutable, so
            # they can be compared by identity.
            seen_args.add(a)
            actual_args.append(a)
            actual_arg_kinds.append(ctx.arg_kinds[i][j])
            actual_arg_names.append(ctx.arg_names[i][j])
            actual_types.append(ctx.arg_types[i][j])

    # For every formal of the callee, the indices (into the flattened actual
    # lists above) of the actuals that map to it.
    formal_to_actual = map_actuals_to_formals(
        actual_kinds=actual_arg_kinds,
        actual_names=actual_arg_names,
        formal_kinds=fn_type.arg_kinds,
        formal_names=fn_type.arg_names,
        actual_arg_type=lambda i: actual_types[i],
    )

    # We need to remove any type variables that appear only in formals that have
    # no actuals, to avoid eagerly binding them in check_call() below.
    can_infer_ids = set()
    for i, arg_type in enumerate(fn_type.arg_types):
        if not formal_to_actual[i]:
            continue
        can_infer_ids.update({tv.id for tv in get_all_type_vars(arg_type)})

    # Copy of the callee in which required positional/named parameters become
    # optional, so that supplying only some of the arguments is accepted.
    # special_sig="partial" allows omission of args/kwargs typed with ParamSpec
    defaulted = fn_type.copy_modified(
        arg_kinds=[
            (
                ArgKind.ARG_OPT
                if k == ArgKind.ARG_POS
                else (ArgKind.ARG_NAMED_OPT if k == ArgKind.ARG_NAMED else k)
            )
            for k in fn_type.arg_kinds
        ],
        ret_type=ret_type,
        variables=[
            tv
            for tv in fn_type.variables
            # Keep TypeVarTuple/ParamSpec to avoid spurious errors on empty args.
            if tv.id in can_infer_ids or not isinstance(tv, TypeVarType)
        ],
        special_sig="partial",
    )
    if defaulted.line < 0:
        # Make up a line number if we don't have one
        defaulted.set_line(ctx.default_return_type)

    # Create a valid context for various ad-hoc inspections in check_call().
    call_expr = CallExpr(
        callee=ctx.args[0][0],
        args=actual_args,
        arg_kinds=actual_arg_kinds,
        arg_names=actual_arg_names,
        analyzed=ctx.context.analyzed if isinstance(ctx.context, CallExpr) else None,
    )
    call_expr.set_line(ctx.context)

    # Only the second item returned by check_call() is used; the code below
    # treats it as the callee signature after checking the supplied arguments.
    _, bound = ctx.api.expr_checker.check_call(
        callee=defaulted,
        args=actual_args,
        arg_kinds=actual_arg_kinds,
        arg_names=actual_arg_names,
        context=call_expr,
    )
    if not wrapped_return:
        # Restore previously ignored context.
        ctx.api.type_context[-1] = last_context

    bound = get_proper_type(bound)
    if not isinstance(bound, CallableType):
        return ctx.default_return_type

    if wrapped_return:
        # Reverse the wrapping we did above.
        ret_type = get_proper_type(bound.ret_type)
        if not isinstance(ret_type, Instance) or ret_type.type.fullname != PARTIAL:
            return ctx.default_return_type
        bound = bound.copy_modified(ret_type=ret_type.args[0])

    partial_kinds = []
    partial_types = []
    partial_names = []
    # We need to fully apply any positional arguments (they cannot be respecified)
    # However, keyword arguments can be respecified, so just give them a default
    for i, actuals in enumerate(formal_to_actual):
        if len(bound.arg_types) == len(fn_type.arg_types):
            arg_type = bound.arg_types[i]
            if not mypy.checker.is_valid_inferred_type(arg_type, ctx.api.options):
                arg_type = fn_type.arg_types[i]  # bit of a hack
        else:
            # TODO: I assume that bound and fn_type have the same arguments. It appears this isn't
            # true when PEP 646 things are happening. See testFunctoolsPartialTypeVarTuple
            arg_type = fn_type.arg_types[i]

        # Formals that received no actual, and *args / **kwargs formals, are
        # kept in the resulting signature unchanged.
        if not actuals or fn_type.arg_kinds[i] in (ArgKind.ARG_STAR, ArgKind.ARG_STAR2):
            partial_kinds.append(fn_type.arg_kinds[i])
            partial_types.append(arg_type)
            partial_names.append(fn_type.arg_names[i])
        else:
            assert actuals
            if any(actual_arg_kinds[j] in (ArgKind.ARG_POS, ArgKind.ARG_STAR) for j in actuals):
                # Don't add params for arguments passed positionally
                continue
            # Add defaulted params for arguments passed via keyword
            kind = actual_arg_kinds[actuals[0]]
            if kind == ArgKind.ARG_NAMED or kind == ArgKind.ARG_STAR2:
                kind = ArgKind.ARG_NAMED_OPT
            partial_kinds.append(kind)
            partial_types.append(arg_type)
            partial_names.append(fn_type.arg_names[i])

    ret_type = bound.ret_type
    if not mypy.checker.is_valid_inferred_type(ret_type, ctx.api.options):
        ret_type = fn_type.ret_type  # same kind of hack as above

    partially_applied = fn_type.copy_modified(
        arg_types=partial_types,
        arg_kinds=partial_kinds,
        arg_names=partial_names,
        ret_type=ret_type,
        special_sig="partial",
    )

    # Attach the remaining signature to the partial[...] instance;
    # partial_call_callback reads it back from the "__mypy_partial" attribute.
    ret = ctx.api.named_generic_type(PARTIAL, [ret_type])
    ret = ret.copy_with_extra_attr("__mypy_partial", partially_applied)
    if partially_applied.param_spec():
        # Record whether *args / **kwargs actuals were supplied here, so that
        # partial_call_callback can check the ParamSpec parts at call time.
        assert ret.extra_attrs is not None  # copy_with_extra_attr above ensures this
        attrs = ret.extra_attrs.copy()
        if ArgKind.ARG_STAR in actual_arg_kinds:
            attrs.immutable.add("__mypy_partial_paramspec_args_bound")
        if ArgKind.ARG_STAR2 in actual_arg_kinds:
            attrs.immutable.add("__mypy_partial_paramspec_kwargs_bound")
        ret.extra_attrs = attrs
    return ret


def partial_call_callback(ctx: mypy.plugin.MethodContext) -> Type:
    """Infer a more precise return type for functools.partial.__call__.

    The call is checked against the signature stored in the
    ``"__mypy_partial"`` extra attribute of the ``functools.partial`` instance.
    If that signature contains a ParamSpec, the call is additionally checked
    for the ``*args`` / ``**kwargs`` ParamSpec parts being supplied.
    """
    fallback = ctx.default_return_type
    if not isinstance(ctx.api, mypy.checker.TypeChecker):  # use internals
        return fallback
    if not isinstance(ctx.type, Instance) or ctx.type.type.fullname != PARTIAL:
        return fallback
    extra_attrs = ctx.type.extra_attrs
    if not extra_attrs or "__mypy_partial" not in extra_attrs.attrs:
        return fallback

    partial_type = get_proper_type(extra_attrs.attrs["__mypy_partial"])
    if len(ctx.arg_types) != 2:  # *args, **kwargs
        return fallback

    # Flatten the per-formal actuals into parallel lists for check_call(),
    # including each actual expression only once even if it maps to several formals.
    actual_args = []
    actual_arg_kinds = []
    actual_arg_names = []
    seen_args = set()
    for exprs, kinds, names in zip(ctx.args, ctx.arg_kinds, ctx.arg_names):
        for expr, kind, name in zip(exprs, kinds, names):
            if expr not in seen_args:
                seen_args.add(expr)
                actual_args.append(expr)
                actual_arg_kinds.append(kind)
                actual_arg_names.append(name)

    result, _ = ctx.api.expr_checker.check_call(
        callee=partial_type,
        args=actual_args,
        arg_kinds=actual_arg_kinds,
        arg_names=actual_arg_names,
        context=ctx.context,
    )
    if not isinstance(partial_type, CallableType) or partial_type.param_spec() is None:
        return result

    args_bound = "__mypy_partial_paramspec_args_bound" in extra_attrs.immutable
    kwargs_bound = "__mypy_partial_paramspec_kwargs_bound" in extra_attrs.immutable

    # Flavors of the ParamSpec-typed variables passed directly as arguments.
    passed_flavors = set()
    for arg in actual_args:
        if isinstance(arg, NameExpr) and isinstance(arg.node, Var):
            var_type = arg.node.type
            if isinstance(var_type, ParamSpecType):
                passed_flavors.add(var_type.flavor)

    msg = ctx.api.expr_checker.msg

    # ensure *args: P.args
    args_passed = ParamSpecFlavor.ARGS in passed_flavors
    if args_bound:
        if args_passed:
            msg.too_many_arguments(partial_type, ctx.context)
    elif not args_passed:
        msg.too_few_arguments(partial_type, ctx.context, actual_arg_names)

    # ensure **kwargs: P.kwargs
    kwargs_passed = ParamSpecFlavor.KWARGS in passed_flavors
    if not (kwargs_bound or kwargs_passed):
        msg.too_few_arguments(partial_type, ctx.context, actual_arg_names)

    return result
