from __future__ import annotations

from collections.abc import Iterable, Mapping, Sequence
from typing import Final, TypeVar, cast, overload

from mypy.nodes import ARG_STAR, FakeInfo, Var
from mypy.state import state
from mypy.types import (
    ANY_STRATEGY,
    AnyType,
    BoolTypeQuery,
    CallableType,
    DeletedType,
    ErasedType,
    FunctionLike,
    Instance,
    LiteralType,
    NoneType,
    Overloaded,
    Parameters,
    ParamSpecFlavor,
    ParamSpecType,
    PartialType,
    ProperType,
    TrivialSyntheticTypeTranslator,
    TupleType,
    Type,
    TypeAliasType,
    TypedDictType,
    TypeOfAny,
    TypeType,
    TypeVarId,
    TypeVarLikeType,
    TypeVarTupleType,
    TypeVarType,
    UnboundType,
    UninhabitedType,
    UnionType,
    UnpackType,
    flatten_nested_unions,
    get_proper_type,
    split_with_prefix_and_suffix,
)
from mypy.typevartuples import split_with_instance

# Solving the import cycle:
import mypy.type_visitor  # ruff: isort: skip

# WARNING: these functions should never (directly or indirectly) depend on
# is_subtype(), meet_types(), join_types() etc.
# TODO: add a static dependency test for this.


@overload
def expand_type(typ: CallableType, env: Mapping[TypeVarId, Type]) -> CallableType: ...


@overload
def expand_type(typ: ProperType, env: Mapping[TypeVarId, Type]) -> ProperType: ...


@overload
def expand_type(typ: Type, env: Mapping[TypeVarId, Type]) -> Type: ...


def expand_type(typ: Type, env: Mapping[TypeVarId, Type]) -> Type:
    """Return `typ` with type variable references replaced according to `env`.

    `env` maps type variable ids to the types to substitute for them. The
    substitution is performed by running an ExpandTypeVisitor over `typ`.
    """
    substituter = ExpandTypeVisitor(env)
    return typ.accept(substituter)


@overload
def expand_type_by_instance(typ: CallableType, instance: Instance) -> CallableType: ...


@overload
def expand_type_by_instance(typ: ProperType, instance: Instance) -> ProperType: ...


@overload
def expand_type_by_instance(typ: Type, instance: Instance) -> Type: ...


def expand_type_by_instance(typ: Type, instance: Instance) -> Type:
    """Expand `typ` using the type arguments of `instance`.

    The type variables declared by the class of `instance` are treated as
    bound to the corresponding arguments of `instance`. If the instance has
    no arguments (and its class has no TypeVarTuple), `typ` is returned as is.
    """
    if not instance.args and not instance.type.has_type_var_tuple_type:
        # Nothing to substitute.
        return typ

    info = instance.type
    env: dict[TypeVarId, Type] = {}
    if not info.has_type_var_tuple_type:
        # Only regular type variables: pair them with the arguments directly.
        binders = tuple(info.defn.type_vars)
        values = instance.args
    else:
        assert info.type_var_tuple_prefix is not None
        assert info.type_var_tuple_suffix is not None

        # Split both the arguments and the declared type variables around
        # the variadic one.
        args_prefix, args_middle, args_suffix = split_with_instance(instance)
        tvars_prefix, tvars_middle, tvars_suffix = split_with_prefix_and_suffix(
            tuple(info.defn.type_vars), info.type_var_tuple_prefix, info.type_var_tuple_suffix
        )
        variadic = tvars_middle[0]
        assert isinstance(variadic, TypeVarTupleType)
        # The TypeVarTuple is bound to a tuple of all the "middle" arguments.
        env[variadic.id] = TupleType(list(args_middle), variadic.tuple_fallback)
        values = args_prefix + args_suffix
        binders = tvars_prefix + tvars_suffix

    for binder, value in zip(binders, values):
        assert isinstance(binder, TypeVarLikeType)
        env[binder.id] = value

    return expand_type(typ, env)


F = TypeVar("F", bound=FunctionLike)


def freshen_function_type_vars(callee: F) -> F:
    """Return `callee` with its own type variables replaced by fresh ones.

    A non-generic CallableType is returned unchanged. For an Overloaded,
    every item is freshened separately.
    """
    if not isinstance(callee, CallableType):
        assert isinstance(callee, Overloaded)
        fresh_items = [freshen_function_type_vars(item) for item in callee.items]
        return cast(F, Overloaded(fresh_items))

    if not callee.is_generic():
        return cast(F, callee)

    # Create one new unification variable per declared variable, then
    # substitute old -> new throughout the signature.
    fresh_vars = [old.new_unification_variable(old) for old in callee.variables]
    substitution: dict[TypeVarId, Type] = {
        old.id: new for old, new in zip(callee.variables, fresh_vars)
    }
    freshened = expand_type(callee, substitution).copy_modified(variables=fresh_vars)
    return cast(F, freshened)


class HasGenericCallable(BoolTypeQuery):
    """Boolean type query (ANY_STRATEGY) that answers True for generic callables."""

    def __init__(self) -> None:
        super().__init__(ANY_STRATEGY)

    def visit_callable_type(self, t: CallableType) -> bool:
        # A generic callable settles the question; otherwise defer to the
        # base query for this callable.
        if t.is_generic():
            return True
        return super().visit_callable_type(t)


# A single shared instance is used because this query is performance sensitive.
has_generic_callable: Final = HasGenericCallable()


T = TypeVar("T", bound=Type)


def freshen_all_functions_type_vars(t: T) -> T:
    """Freshen the type variables of the generic callables found in `t`.

    The shared `has_generic_callable` query is reset and run first; when it
    reports no generic callable, `t` itself is returned.
    """
    has_generic_callable.reset()
    if t.accept(has_generic_callable):
        freshened: Type = t.accept(FreshenCallableVisitor())
        assert isinstance(freshened, type(t))
        return freshened
    # Fast path to avoid expensive freshening
    return t


class FreshenCallableVisitor(mypy.type_visitor.TypeTranslator):
    """Type translator that freshens the type variables of each callable it visits."""

    def visit_callable_type(self, t: CallableType) -> Type:
        # Translate the callable with the base translator first, then
        # freshen the variables of the translated callable.
        translated = super().visit_callable_type(t)
        assert isinstance(translated, ProperType) and isinstance(translated, CallableType)
        return freshen_function_type_vars(translated)

    def visit_type_alias_type(self, t: TypeAliasType) -> Type:
        # Same as for ExpandTypeVisitor: only the alias arguments are visited.
        translated_args = [arg.accept(self) for arg in t.args]
        return t.copy_modified(args=translated_args)


class ExpandTypeVisitor(TrivialSyntheticTypeTranslator):
    """Visitor that substitutes type variables with values.

    The substitution is given by `variables`, a mapping from type variable id
    to the replacement type. Each `visit_*` method returns the expanded form
    of the visited type; types that have nothing to expand are returned as is.
    """

    variables: Mapping[TypeVarId, Type]  # TypeVar id -> TypeVar value

    def __init__(self, variables: Mapping[TypeVarId, Type]) -> None:
        super().__init__()
        self.variables = variables
        # Used by visit_type_var() when the replacement is a TypeVar with a default:
        # an id maps to None while its expansion is in progress, and to the expanded
        # result once it is finished.
        self.recursive_tvar_guard: dict[TypeVarId, Type | None] = {}

    # The visitors below return their argument unchanged.

    def visit_unbound_type(self, t: UnboundType) -> Type:
        return t

    def visit_any(self, t: AnyType) -> Type:
        return t

    def visit_none_type(self, t: NoneType) -> Type:
        return t

    def visit_uninhabited_type(self, t: UninhabitedType) -> Type:
        return t

    def visit_deleted_type(self, t: DeletedType) -> Type:
        return t

    def visit_erased_type(self, t: ErasedType) -> Type:
        # This may happen during type inference if some function argument
        # type is a generic callable, and its erased form will appear in inferred
        # constraints, then solver may check subtyping between them, which will trigger
        # unify_generic_callables(), this is why we can get here. Another example is
        # when inferring type of lambda in generic context, the lambda body contains
        # a generic method in generic class.
        return t

    def visit_instance(self, t: Instance) -> Type:
        """Expand the type arguments of an instance type.

        Arguments are expanded with expand_types_with_unpack(), so a `*Ts`
        argument may turn into several arguments. For `builtins.tuple`, if the
        first expanded argument is an Unpack of a `builtins.tuple` instance,
        the arguments are replaced by those of the unpacked tuple.
        """
        args = self.expand_types_with_unpack(list(t.args))

        if isinstance(t.type, FakeInfo):
            # The type checker expands function definitions and bodies
            # if they depend on constrained type variables but the body
            # might contain a tuple type comment (e.g., # type: (int, float)),
            # in which case 't.type' is not yet available.
            #
            # See: https://github.com/python/mypy/issues/16649
            return t.copy_modified(args=args)

        if t.type.fullname == "builtins.tuple":
            # Normalize Tuple[*Tuple[X, ...], ...] -> Tuple[X, ...]
            arg = args[0]
            if isinstance(arg, UnpackType):
                unpacked = get_proper_type(arg.type)
                if isinstance(unpacked, Instance):
                    # TODO: this and similar asserts below may be unsafe because get_proper_type()
                    # may be called during semantic analysis before all invalid types are removed.
                    assert unpacked.type.fullname == "builtins.tuple"
                    args = list(unpacked.args)
        return t.copy_modified(args=args)

    def visit_type_var(self, t: TypeVarType) -> Type:
        """Replace a type variable with its mapped value; keep it if unmapped.

        An Instance replacement is returned with `last_known_value` cleared.
        A replacement that is itself a TypeVar with a default is expanded
        again, together with its default; `recursive_tvar_guard` stops this
        from recursing forever on the same id.
        """
        # Normally upper bounds can't contain other type variables, the only exception is
        # special type variable Self`0 <: C[T, S], where C is the class where Self is used.
        if t.id.is_self():
            t = t.copy_modified(upper_bound=t.upper_bound.accept(self))
        repl = self.variables.get(t.id, t)
        if isinstance(repl, ProperType) and isinstance(repl, Instance):
            # TODO: do we really need to do this?
            # If I try to remove this special-casing ~40 tests fail on reveal_type().
            return repl.copy_modified(last_known_value=None)
        if isinstance(repl, TypeVarType) and repl.has_default():
            if (tvar_id := repl.id) in self.recursive_tvar_guard:
                # Either already expanded (use the stored result), or expansion of this
                # id is still in progress (stored value is None), so return it as is.
                return self.recursive_tvar_guard[tvar_id] or repl
            self.recursive_tvar_guard[tvar_id] = None
            repl = repl.accept(self)
            if isinstance(repl, TypeVarType):
                # NOTE: this assigns to the `default` attribute of `repl` in place.
                repl.default = repl.default.accept(self)
            self.recursive_tvar_guard[tvar_id] = repl
        return repl

    def visit_param_spec(self, t: ParamSpecType) -> Type:
        """Expand a ParamSpec.

        The prefix of `t` (with its argument types expanded) is put in front
        of the replacement's own prefix (ParamSpecType replacement) or
        arguments (Parameters replacement). Any other replacement is returned
        as is.
        """
        # Set prefix to something empty, so we don't duplicate it below.
        repl = self.variables.get(t.id, t.copy_modified(prefix=Parameters([], [], [])))
        if isinstance(repl, ParamSpecType):
            return repl.copy_modified(
                flavor=t.flavor,
                prefix=t.prefix.copy_modified(
                    arg_types=self.expand_types(t.prefix.arg_types) + repl.prefix.arg_types,
                    arg_kinds=t.prefix.arg_kinds + repl.prefix.arg_kinds,
                    arg_names=t.prefix.arg_names + repl.prefix.arg_names,
                ),
            )
        elif isinstance(repl, Parameters):
            assert t.flavor == ParamSpecFlavor.BARE
            return Parameters(
                self.expand_types(t.prefix.arg_types) + repl.arg_types,
                t.prefix.arg_kinds + repl.arg_kinds,
                t.prefix.arg_names + repl.arg_names,
                variables=[*t.prefix.variables, *repl.variables],
                imprecise_arg_kinds=repl.imprecise_arg_kinds,
            )
        else:
            # We could encode Any as trivial parameters etc., but it would be too verbose.
            # TODO: assert this is a trivial type, like Any, Never, or object.
            return repl

    def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type:
        """Expand a TypeVarTuple that is visited directly.

        Supported replacements are another TypeVarTupleType (returned as is)
        and AnyType/UninhabitedType (wrapped as the single argument of the
        tuple fallback of `t`).

        Raises:
            NotImplementedError: for any other replacement.
        """
        # Sometimes solver may need to expand a type variable with (a copy of) itself
        # (usually together with other TypeVars, but it is hard to filter out TypeVarTuples).
        repl = self.variables.get(t.id, t)
        if isinstance(repl, TypeVarTupleType):
            return repl
        elif isinstance(repl, ProperType) and isinstance(repl, (AnyType, UninhabitedType)):
            # Some failed inference scenarios will try to set all type variables to Never.
            # Instead of being picky and require all the callers to wrap them,
            # do this here instead.
            # Note: most cases when this happens are handled in expand unpack below, but
            # in rare cases (e.g. ParamSpec containing Unpack star args) it may be skipped.
            return t.tuple_fallback.copy_modified(args=[repl])
        raise NotImplementedError

    def visit_unpack_type(self, t: UnpackType) -> Type:
        """Return a new UnpackType wrapping the expanded inner type."""
        # It is impossible to reasonably implement visit_unpack_type, because
        # unpacking inherently expands to something more like a list of types.
        #
        # Relevant sections that can call unpack should call expand_unpack()
        # instead.
        # However, if the item is a variadic tuple, we can simply carry it over.
        # In particular, if we expand A[*tuple[T, ...]] with substitutions {T: str},
        # it is hard to assert this without getting proper type. Another important
        # example is non-normalized types when called from semanal.py.
        return UnpackType(t.type.accept(self))

    def expand_unpack(self, t: UnpackType) -> list[Type]:
        """Return the list of types that `*Ts` expands to.

        `t` must wrap a TypeVarTupleType directly (asserted). If the variable
        is unmapped, the result is a one-item list unpacking the variable
        itself.

        Raises:
            RuntimeError: if the replacement is not a tuple type, a
                `builtins.tuple` instance, a TypeVarTuple, Any or Never.
        """
        assert isinstance(t.type, TypeVarTupleType)
        repl = get_proper_type(self.variables.get(t.type.id, t.type))
        if isinstance(repl, UnpackType):
            # The replacement is already wrapped in Unpack: look at what it wraps.
            repl = get_proper_type(repl.type)
        if isinstance(repl, TupleType):
            # NOTE: this returns the `items` list of the replacement itself, not a copy.
            return repl.items
        elif (
            isinstance(repl, Instance)
            and repl.type.fullname == "builtins.tuple"
            or isinstance(repl, TypeVarTupleType)
        ):
            return [UnpackType(typ=repl)]
        elif isinstance(repl, (AnyType, UninhabitedType)):
            # Replace *Ts = Any with *Ts = *tuple[Any, ...] and same for Never.
            # These types may appear here as a result of user error or failed inference.
            return [UnpackType(t.type.tuple_fallback.copy_modified(args=[repl]))]
        else:
            raise RuntimeError(f"Invalid type replacement to expand: {repl}")

    def visit_parameters(self, t: Parameters) -> Type:
        """Expand the argument types of a Parameters object."""
        return t.copy_modified(arg_types=self.expand_types(t.arg_types))

    def interpolate_args_for_unpack(self, t: CallableType, var_arg: UnpackType) -> list[Type]:
        """Return the expanded argument types of a callable with an unpacked `*args`.

        `var_arg` is the Unpack type of the star argument of `t`. The result
        has the expanded arguments before the star argument, one item for the
        star argument itself, and the expanded arguments after it.
        """
        star_index = t.arg_kinds.index(ARG_STAR)
        prefix = self.expand_types(t.arg_types[:star_index])
        suffix = self.expand_types(t.arg_types[star_index + 1 :])

        var_arg_type = get_proper_type(var_arg.type)
        new_unpack: Type
        if isinstance(var_arg_type, TupleType):
            # We have something like Unpack[Tuple[Unpack[Ts], X1, X2]]
            expanded_tuple = var_arg_type.accept(self)
            assert isinstance(expanded_tuple, ProperType) and isinstance(expanded_tuple, TupleType)
            expanded_items = expanded_tuple.items
            # The fallback of the original (unexpanded) tuple is reused.
            fallback = var_arg_type.partial_fallback
            new_unpack = UnpackType(TupleType(expanded_items, fallback))
        elif isinstance(var_arg_type, TypeVarTupleType):
            # We have plain Unpack[Ts]
            fallback = var_arg_type.tuple_fallback
            expanded_items = self.expand_unpack(var_arg)
            new_unpack = UnpackType(TupleType(expanded_items, fallback))
        # Since get_proper_type() may be called in semanal.py before callable
        # normalization happens, we need to also handle non-normal cases here.
        elif isinstance(var_arg_type, Instance):
            # we have something like Unpack[Tuple[Any, ...]]
            new_unpack = UnpackType(var_arg.type.accept(self))
        else:
            # We have invalid type in Unpack. This can happen when expanding aliases
            # to Callable[[*Invalid], Ret]
            new_unpack = AnyType(TypeOfAny.from_error, line=var_arg.line, column=var_arg.column)
        return prefix + [new_unpack] + suffix

    def visit_callable_type(self, t: CallableType) -> CallableType:
        """Expand a callable type.

        If the callable uses a ParamSpec that is mapped to a Parameters or a
        ParamSpecType, the last two arguments of `t` (`*P.args` and
        `**P.kwargs`) are replaced accordingly. Otherwise the argument types
        (with special handling of an unpacked `*args`), the return type and
        the type guards are expanded.
        """
        param_spec = t.param_spec()
        if param_spec is not None:
            repl = self.variables.get(param_spec.id)
            # If a ParamSpec in a callable type is substituted with a
            # callable type, we can't use normal substitution logic,
            # since ParamSpec is actually split into two components
            # *P.args and **P.kwargs in the original type. Instead, we
            # must expand both of them with all the argument types,
            # kinds and names in the replacement. The return type in
            # the replacement is ignored.
            if isinstance(repl, Parameters):
                # We need to expand both the types in the prefix and the ParamSpec itself
                expanded = t.copy_modified(
                    arg_types=self.expand_types(t.arg_types[:-2]) + repl.arg_types,
                    arg_kinds=t.arg_kinds[:-2] + repl.arg_kinds,
                    arg_names=t.arg_names[:-2] + repl.arg_names,
                    ret_type=t.ret_type.accept(self),
                    type_guard=(t.type_guard.accept(self) if t.type_guard is not None else None),
                    type_is=(t.type_is.accept(self) if t.type_is is not None else None),
                    imprecise_arg_kinds=(t.imprecise_arg_kinds or repl.imprecise_arg_kinds),
                    variables=[*repl.variables, *t.variables],
                )
                var_arg = expanded.var_arg()
                if var_arg is not None and isinstance(var_arg.typ, UnpackType):
                    # Sometimes we get new unpacks after expanding ParamSpec.
                    expanded.normalize_trivial_unpack()
                return expanded
            elif isinstance(repl, ParamSpecType):
                # We're substituting one ParamSpec for another; this can mean that the prefix
                # changes, e.g. substitute Concatenate[int, P] in place of Q.
                prefix = repl.prefix
                clean_repl = repl.copy_modified(prefix=Parameters([], [], []))
                # NOTE(review): unlike the Parameters branch above, type_guard and type_is
                # are not expanded here -- confirm whether that is intentional.
                return t.copy_modified(
                    arg_types=self.expand_types(t.arg_types[:-2])
                    + prefix.arg_types
                    + [
                        clean_repl.with_flavor(ParamSpecFlavor.ARGS),
                        clean_repl.with_flavor(ParamSpecFlavor.KWARGS),
                    ],
                    arg_kinds=t.arg_kinds[:-2] + prefix.arg_kinds + t.arg_kinds[-2:],
                    arg_names=t.arg_names[:-2] + prefix.arg_names + t.arg_names[-2:],
                    ret_type=t.ret_type.accept(self),
                    from_concatenate=t.from_concatenate or bool(repl.prefix.arg_types),
                    imprecise_arg_kinds=(t.imprecise_arg_kinds or prefix.imprecise_arg_kinds),
                )

        # No ParamSpec, or a ParamSpec whose replacement is neither of the above
        # (including no replacement at all): expand component-wise.
        var_arg = t.var_arg()
        needs_normalization = False
        if var_arg is not None and isinstance(var_arg.typ, UnpackType):
            needs_normalization = True
            arg_types = self.interpolate_args_for_unpack(t, var_arg.typ)
        else:
            arg_types = self.expand_types(t.arg_types)
        expanded = t.copy_modified(
            arg_types=arg_types,
            ret_type=t.ret_type.accept(self),
            type_guard=(t.type_guard.accept(self) if t.type_guard is not None else None),
            type_is=(t.type_is.accept(self) if t.type_is is not None else None),
        )
        if needs_normalization:
            return expanded.with_normalized_var_args()
        return expanded

    def visit_overloaded(self, t: Overloaded) -> Type:
        """Expand every item of an overloaded callable."""
        items: list[CallableType] = []
        for item in t.items:
            new_item = item.accept(self)
            assert isinstance(new_item, ProperType)
            assert isinstance(new_item, CallableType)
            items.append(new_item)
        return Overloaded(items)

    def expand_types_with_unpack(self, typs: Sequence[Type]) -> list[Type]:
        """Expand a list of types that may contain `*Ts` items.

        An item that is an Unpack of a TypeVarTuple is replaced by the (zero
        or more) types returned by expand_unpack(); every other item is
        expanded normally.
        """
        items: list[Type] = []
        for item in typs:
            if isinstance(item, UnpackType) and isinstance(item.type, TypeVarTupleType):
                items.extend(self.expand_unpack(item))
            else:
                items.append(item.accept(self))
        return items

    def visit_tuple_type(self, t: TupleType) -> Type:
        """Expand the items and the fallback of a tuple type.

        If the expansion leaves a single item that unpacks a `builtins.tuple`
        instance, the result is that instance, or the expanded fallback of
        `t` when the fallback is not `builtins.tuple` itself.
        """
        items = self.expand_types_with_unpack(t.items)
        if len(items) == 1:
            # Normalize Tuple[*Tuple[X, ...]] -> Tuple[X, ...]
            item = items[0]
            if isinstance(item, UnpackType):
                unpacked = get_proper_type(item.type)
                if isinstance(unpacked, Instance):
                    assert unpacked.type.fullname == "builtins.tuple"
                    if t.partial_fallback.type.fullname != "builtins.tuple":
                        # If it is a subtype (like named tuple) we need to preserve it,
                        # this essentially mimics the logic in tuple_fallback().
                        return t.partial_fallback.accept(self)
                    return unpacked
        fallback = t.partial_fallback.accept(self)
        assert isinstance(fallback, ProperType) and isinstance(fallback, Instance)
        return t.copy_modified(items=items, fallback=fallback)

    def visit_typeddict_type(self, t: TypedDictType) -> Type:
        """Expand the item types and the fallback of a TypedDict type.

        Results go through get_cached()/set_cached(), which are not defined in
        this class (presumably inherited from the base translator).
        """
        if cached := self.get_cached(t):
            return cached
        fallback = t.fallback.accept(self)
        assert isinstance(fallback, ProperType) and isinstance(fallback, Instance)
        result = t.copy_modified(item_types=self.expand_types(t.items.values()), fallback=fallback)
        self.set_cached(t, result)
        return result

    def visit_literal_type(self, t: LiteralType) -> Type:
        # TODO: Verify this implementation is correct
        return t

    def visit_union_type(self, t: UnionType) -> Type:
        """Expand the items of a union and apply trivial simplifications.

        The expanded items are flattened and passed through remove_trivial()
        before the union is rebuilt; the result is always a proper type.
        """
        # Use cache to avoid O(n**2) or worse expansion of types during translation
        # (only for large unions, since caching adds overhead)
        use_cache = len(t.items) > 3
        if use_cache and (cached := self.get_cached(t)):
            return cached

        expanded = self.expand_types(t.items)
        # After substituting for type variables in t.items, some resulting types
        # might be subtypes of others, however calling make_simplified_union()
        # can cause recursion, so we just remove strict duplicates.
        simplified = UnionType.make_union(
            remove_trivial(flatten_nested_unions(expanded)), t.line, t.column
        )
        # This call to get_proper_type() is unfortunate but is required to preserve
        # the invariant that ProperType will stay ProperType after applying expand_type(),
        # otherwise a single item union of a type alias will break it. Note this should not
        # cause infinite recursion since pathological aliases like A = Union[A, B] are
        # banned at the semantic analysis level.
        result = get_proper_type(simplified)

        if use_cache:
            self.set_cached(t, result)
        return result

    def visit_partial_type(self, t: PartialType) -> Type:
        return t

    def visit_type_type(self, t: TypeType) -> Type:
        """Expand the item of a `Type[...]` type and re-normalize the result."""
        # TODO: Verify that the new item type is valid (instance or
        # union of instances or Any).  Sadly we can't report errors
        # here yet.
        item = t.item.accept(self)
        return TypeType.make_normalized(item)

    def visit_type_alias_type(self, t: TypeAliasType) -> Type:
        """Expand the arguments of a type alias type (the target is not touched)."""
        # Target of the type alias cannot contain type variables (not bound by the type
        # alias itself), so we just expand the arguments.
        args = self.expand_types_with_unpack(t.args)
        # TODO: normalize if target is Tuple, and args are [*tuple[X, ...]]?
        return t.copy_modified(args=args)

    def expand_types(self, types: Iterable[Type]) -> list[Type]:
        """Expand each type in `types` and return the results as a new list."""
        a: list[Type] = []
        for t in types:
            a.append(t.accept(self))
        return a


@overload
def expand_self_type(var: Var, typ: ProperType, replacement: ProperType) -> ProperType: ...


@overload
def expand_self_type(var: Var, typ: Type, replacement: Type) -> Type: ...


def expand_self_type(var: Var, typ: Type, replacement: Type) -> Type:
    """Replace occurrences of the Self type in the type of a variable.

    `typ` is returned unchanged when the class of `var` has no Self type
    variable, or when `var` is a property.
    """
    if var.info.self_type is None or var.is_property:
        return typ
    self_id = var.info.self_type.id
    return expand_type(typ, {self_id: replacement})


def remove_trivial(types: Iterable[Type]) -> list[Type]:
    """Apply trivial simplifications to a list of types, never calling is_subtype().

    The simplifications are:
        * Bottom types are dropped (taking the strict optional setting into account)
        * If an `object` instance is present, it alone is the result
        * Strict duplicates are dropped
    """
    dropped_none = False
    kept: list[Type] = []
    seen: set[ProperType] = set()
    for item in types:
        proper = get_proper_type(item)
        if isinstance(proper, UninhabitedType):
            continue
        if isinstance(proper, NoneType) and not state.strict_optional:
            # Without strict optional None is a bottom type too, but remember
            # that we saw it in case nothing else is left.
            dropped_none = True
            continue
        if isinstance(proper, Instance) and proper.type.fullname == "builtins.object":
            return [proper]
        if proper in seen:
            continue
        # Duplicates are detected on the proper type, but the original item is kept.
        kept.append(item)
        seen.add(proper)

    if not kept:
        # Everything was removed: fall back to None if one was dropped, else to Never.
        if dropped_none:
            return [NoneType()]
        return [UninhabitedType()]
    return kept
