from __future__ import annotations

from collections.abc import Container
from typing import Callable, cast

from mypy.nodes import ARG_STAR, ARG_STAR2
from mypy.types import (
    AnyType,
    CallableType,
    DeletedType,
    ErasedType,
    Instance,
    LiteralType,
    NoneType,
    Overloaded,
    Parameters,
    ParamSpecType,
    PartialType,
    ProperType,
    TupleType,
    Type,
    TypeAliasType,
    TypedDictType,
    TypeOfAny,
    TypeTranslator,
    TypeType,
    TypeVarId,
    TypeVarTupleType,
    TypeVarType,
    TypeVisitor,
    UnboundType,
    UninhabitedType,
    UnionType,
    UnpackType,
    get_proper_type,
    get_proper_types,
)
from mypy.typevartuples import erased_vars


def erase_type(typ: Type) -> ProperType:
    """Return an erased copy of *typ* with every type variable removed.

    Type variables become ``Any``, generic instances are re-parameterized with
    ``Any`` arguments, and tuple-like types collapse to their concrete class.

    Examples:
      A -> A
      B[X] -> B[Any]
      Tuple[A, B] -> tuple
      Callable[[A1, A2, ...], R] -> Callable[..., Any]
      Type[X] -> Type[Any]
    """
    proper_typ = get_proper_type(typ)
    return proper_typ.accept(EraseTypeVisitor())


class EraseTypeVisitor(TypeVisitor[ProperType]):
    """Visitor that maps a proper type to its erased counterpart.

    Type variables, ParamSpecs and unpacks become ``Any``; instances are rebuilt
    with arguments produced by ``erased_vars()``; tuples, TypedDicts and
    overloads collapse to their (erased) fallback instance; every callable
    becomes ``Callable[..., Any]`` with its original fallback.
    """

    def visit_unbound_type(self, t: UnboundType) -> ProperType:
        # TODO: replace with an assert after UnboundType can't leak from semantic analysis.
        return AnyType(TypeOfAny.from_error)

    # --- Types with nothing to erase are returned unchanged. ---

    def visit_any(self, t: AnyType) -> ProperType:
        return t

    def visit_none_type(self, t: NoneType) -> ProperType:
        return t

    def visit_uninhabited_type(self, t: UninhabitedType) -> ProperType:
        return t

    def visit_erased_type(self, t: ErasedType) -> ProperType:
        return t

    def visit_deleted_type(self, t: DeletedType) -> ProperType:
        return t

    def visit_literal_type(self, t: LiteralType) -> ProperType:
        # A literal's fallback is int, str, an enum class or similar; such
        # types carry no type variables, so the fallback is not visited.
        return t

    # --- Types that must never reach this visitor. ---

    def visit_partial_type(self, t: PartialType) -> ProperType:
        # Partial types are expected to be resolved before erasure happens.
        raise RuntimeError("Cannot erase partial types")

    def visit_parameters(self, t: Parameters) -> ProperType:
        raise RuntimeError("Parameters should have been bound to a class")

    def visit_type_alias_type(self, t: TypeAliasType) -> ProperType:
        raise RuntimeError("Type aliases should be expanded before accepting this visitor")

    # --- Type variables and their relatives erase to Any. ---

    def visit_type_var(self, t: TypeVarType) -> ProperType:
        return AnyType(TypeOfAny.special_form)

    def visit_param_spec(self, t: ParamSpecType) -> ProperType:
        return AnyType(TypeOfAny.special_form)

    def visit_unpack_type(self, t: UnpackType) -> ProperType:
        return AnyType(TypeOfAny.special_form)

    def visit_type_var_tuple(self, t: TypeVarTupleType) -> ProperType:
        # Erasing the enclosing types usually prevents reaching this point, but
        # a well-formed replacement is still produced: the tuple fallback
        # re-parameterized with a single Any argument.
        any_arg = AnyType(TypeOfAny.special_form)
        return t.tuple_fallback.copy_modified(args=[any_arg])

    # --- Structured types. ---

    def visit_instance(self, t: Instance) -> ProperType:
        # Rebuild the instance with arguments derived from the class's declared
        # type variables, discarding whatever arguments it had.
        erased_args = erased_vars(t.type.defn.type_vars, TypeOfAny.special_form)
        return Instance(t.type, erased_args, t.line)

    def visit_callable_type(self, t: CallableType) -> ProperType:
        # Every callable erases to Callable[..., Any]. The fallback is kept
        # because overload resolution depends on it.
        wildcard = AnyType(TypeOfAny.special_form)
        return CallableType(
            arg_types=[wildcard, wildcard],
            arg_kinds=[ARG_STAR, ARG_STAR2],
            arg_names=[None, None],
            ret_type=wildcard,
            fallback=t.fallback,
            is_ellipsis_args=True,
            implicit=True,
        )

    def visit_overloaded(self, t: Overloaded) -> ProperType:
        return t.fallback.accept(self)

    def visit_tuple_type(self, t: TupleType) -> ProperType:
        return t.partial_fallback.accept(self)

    def visit_typeddict_type(self, t: TypedDictType) -> ProperType:
        return t.fallback.accept(self)

    def visit_union_type(self, t: UnionType) -> ProperType:
        erased = list(map(erase_type, t.items))
        from mypy.typeops import make_simplified_union

        return make_simplified_union(erased)

    def visit_type_type(self, t: TypeType) -> ProperType:
        erased_item = t.item.accept(self)
        return TypeType.make_normalized(erased_item, line=t.line)


def erase_typevars(t: Type, ids_to_erase: Container[TypeVarId] | None = None) -> Type:
    """Substitute ``Any`` for type variables in *t*.

    When *ids_to_erase* is ``None`` every type variable is replaced; otherwise
    only those whose id is contained in *ids_to_erase*.
    """

    def should_erase(var_id: TypeVarId) -> bool:
        return ids_to_erase is None or var_id in ids_to_erase

    eraser = TypeVarEraser(should_erase, AnyType(TypeOfAny.special_form))
    return t.accept(eraser)


def replace_meta_vars(t: Type, target_type: Type) -> Type:
    """Substitute *target_type* for every unification (meta) variable in *t*.

    Only type variables whose id reports ``is_meta_var()`` are replaced; all
    other type variables are left as they are.
    """

    def is_meta(var_id: TypeVarId) -> bool:
        return var_id.is_meta_var()

    return t.accept(TypeVarEraser(is_meta, target_type))


class TypeVarEraser(TypeTranslator):
    """Translator that substitutes selected type variables with a fixed type.

    Every TypeVar, ParamSpec and TypeVarTuple whose id satisfies ``erase_id`` is
    replaced by ``replacement``. A TypeVarTuple becomes its tuple fallback
    re-parameterized with the replacement. Non-matching variables are kept.
    """

    def __init__(self, erase_id: Callable[[TypeVarId], bool], replacement: Type) -> None:
        super().__init__()
        # Predicate deciding which type variable ids are replaced.
        self.erase_id = erase_id
        # Type substituted for each matching type variable.
        self.replacement = replacement

    def visit_type_var(self, t: TypeVarType) -> Type:
        return self.replacement if self.erase_id(t.id) else t

    def visit_param_spec(self, t: ParamSpecType) -> Type:
        return self.replacement if self.erase_id(t.id) else t

    def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type:
        if not self.erase_id(t.id):
            return t
        return t.tuple_fallback.copy_modified(args=[self.replacement])

    # TODO: below two methods duplicate some logic with expand_type().
    # In fact, we may want to refactor this whole visitor to use expand_type().
    def visit_instance(self, t: Instance) -> Type:
        translated = super().visit_instance(t)
        assert isinstance(translated, ProperType) and isinstance(translated, Instance)
        if t.type.fullname != "builtins.tuple":
            return translated
        # Normalize Tuple[*Tuple[X, ...], ...] -> Tuple[X, ...]
        first_arg = translated.args[0]
        if not isinstance(first_arg, UnpackType):
            return translated
        inner = get_proper_type(first_arg.type)
        if not isinstance(inner, Instance):
            return translated
        assert inner.type.fullname == "builtins.tuple"
        return inner

    def visit_tuple_type(self, t: TupleType) -> Type:
        translated = super().visit_tuple_type(t)
        assert isinstance(translated, ProperType) and isinstance(translated, TupleType)
        if len(translated.items) != 1:
            return translated
        # Normalize Tuple[*Tuple[X, ...]] -> Tuple[X, ...]
        only_item = translated.items[0]
        if not isinstance(only_item, UnpackType):
            return translated
        inner = get_proper_type(only_item.type)
        if not isinstance(inner, Instance):
            return translated
        assert inner.type.fullname == "builtins.tuple"
        if translated.partial_fallback.type.fullname == "builtins.tuple":
            return inner
        # A tuple subtype (such as a named tuple) must be preserved; this
        # mirrors the logic in tuple_fallback().
        return translated.partial_fallback.accept(self)

    def visit_callable_type(self, t: CallableType) -> Type:
        translated = super().visit_callable_type(t)
        assert isinstance(translated, ProperType) and isinstance(translated, CallableType)
        # semanal_typeargs.py normally performs this normalization, but
        # substituting type variables can turn a normal callable into a
        # non-normal one, so it is redone here.
        translated.normalize_trivial_unpack()
        return translated

    def visit_type_alias_type(self, t: TypeAliasType) -> Type:
        # An alias target cannot contain type variables other than those bound
        # by the alias itself, so translating the alias arguments suffices.
        new_args = [arg.accept(self) for arg in t.args]
        return t.copy_modified(args=new_args)


def remove_instance_last_known_values(t: Type) -> Type:
    """Return *t* with the last known (Literal[...]) value dropped from every Instance."""
    eraser = LastKnownValueEraser()
    return t.accept(eraser)


class LastKnownValueEraser(TypeTranslator):
    """Removes the Literal[...] type that may be associated with any
    Instance types.

    ``last_known_value`` is cleared on every Instance, including Instances
    nested in type arguments. Unions are post-processed so that argument-less
    Instances sharing a fullname (which may have become duplicates) are merged.
    """

    def visit_instance(self, t: Instance) -> Type:
        # Fast path: no last known value to clear and no arguments to recurse into.
        if not t.last_known_value and not t.args:
            return t
        # Translate the type arguments and drop this Instance's last known value.
        return t.copy_modified(args=[a.accept(self) for a in t.args], last_known_value=None)

    def visit_type_alias_type(self, t: TypeAliasType) -> Type:
        # Type aliases can't contain literal values, because they are
        # always constructed as explicit types.
        return t

    def visit_union_type(self, t: UnionType) -> Type:
        # Delegate to the base translator first; the cast assumes it returns
        # a UnionType for a UnionType input.
        new = cast(UnionType, super().visit_union_type(t))
        # Erasure can result in many duplicate items; merge them.
        # Call make_simplified_union only on lists of instance types
        # that all have the same fullname, to avoid simplifying too
        # much.
        instances = [item for item in new.items if isinstance(get_proper_type(item), Instance)]
        # Avoid merge in simple cases such as optional types.
        if len(instances) > 1:
            # Group argument-less Instance items by fullname, in order of
            # appearance. Instances with type arguments are not grouped.
            instances_by_name: dict[str, list[Instance]] = {}
            p_new_items = get_proper_types(new.items)
            for p_item in p_new_items:
                if isinstance(p_item, Instance) and not p_item.args:
                    instances_by_name.setdefault(p_item.type.fullname, []).append(p_item)
            merged: list[Type] = []
            for item in new.items:
                orig_item = item
                item = get_proper_type(item)
                if isinstance(item, Instance) and not item.args:
                    types = instances_by_name.get(item.type.fullname)
                    # None means this name's group was already merged at its
                    # first occurrence, so this duplicate item is dropped.
                    if types is not None:
                        if len(types) == 1:
                            # Unique name: keep it. Note that the
                            # get_proper_type() result is appended here,
                            # not orig_item.
                            merged.append(item)
                        else:
                            # Local import; presumably avoids an import cycle
                            # with mypy.typeops.
                            from mypy.typeops import make_simplified_union

                            merged.append(make_simplified_union(types))
                            # Remove the group so later items with the same
                            # name are skipped.
                            del instances_by_name[item.type.fullname]
                else:
                    # Non-Instance items and Instances with arguments are
                    # kept in their original (unexpanded) form.
                    merged.append(orig_item)
            # NOTE(review): no line/column is passed to make_union here, so the
            # rebuilt union may not carry `new`'s position — confirm this is fine.
            return UnionType.make_union(merged)
        return new
