"""Calculation of the least upper bound types (joins)."""

from __future__ import annotations

from collections.abc import Sequence
from typing import overload

import mypy.typeops
from mypy.expandtype import expand_type
from mypy.maptype import map_instance_to_supertype
from mypy.nodes import CONTRAVARIANT, COVARIANT, INVARIANT, VARIANCE_NOT_READY
from mypy.state import state
from mypy.subtypes import (
    SubtypeContext,
    find_member,
    is_equivalent,
    is_proper_subtype,
    is_protocol_implementation,
    is_subtype,
)
from mypy.types import (
    AnyType,
    CallableType,
    DeletedType,
    ErasedType,
    FunctionLike,
    Instance,
    LiteralType,
    NoneType,
    Overloaded,
    Parameters,
    ParamSpecType,
    PartialType,
    ProperType,
    TupleType,
    Type,
    TypeAliasType,
    TypedDictType,
    TypeOfAny,
    TypeType,
    TypeVarId,
    TypeVarLikeType,
    TypeVarTupleType,
    TypeVarType,
    TypeVisitor,
    UnboundType,
    UninhabitedType,
    UnionType,
    UnpackType,
    find_unpack_in_list,
    get_proper_type,
    get_proper_types,
    split_with_prefix_and_suffix,
)


class InstanceJoiner:
    """Compute joins of Instance types while guarding against runaway recursion.

    Attributes:
      seen_instances: Stack of (t, s) pairs whose join is currently being
        computed. If the same pair (in either order) is requested again while
        it is still on the stack, the join is short-circuited with
        object_from_instance() instead of recursing further.
    """

    def __init__(self) -> None:
        self.seen_instances: list[tuple[Instance, Instance]] = []

    def join_instances(self, t: Instance, s: Instance) -> ProperType:
        """Return the join of two Instance types.

        Every exit path taken after the pair has been pushed onto
        seen_instances pops it again, so the stack stays balanced for the
        enclosing (possibly recursive) calls.
        """
        if (t, s) in self.seen_instances or (s, t) in self.seen_instances:
            return object_from_instance(t)

        self.seen_instances.append((t, s))

        # Calculate the join of two instance types
        if t.type == s.type:
            # Simplest case: join two types with the same base type (but
            # potentially different arguments).

            # Combine type arguments.
            args: list[Type] = []
            # N.B: We use zip instead of indexing because the lengths might have
            # mismatches during daemon reprocessing.
            if t.type.has_type_var_tuple_type:
                # We handle joins of variadic instances by simply creating correct mapping
                # for type arguments and compute the individual joins same as for regular
                # instances. All the heavy lifting is done in the join of tuple types.
                assert s.type.type_var_tuple_prefix is not None
                assert s.type.type_var_tuple_suffix is not None
                prefix = s.type.type_var_tuple_prefix
                suffix = s.type.type_var_tuple_suffix
                tvt = s.type.defn.type_vars[prefix]
                assert isinstance(tvt, TypeVarTupleType)
                fallback = tvt.tuple_fallback
                s_prefix, s_middle, s_suffix = split_with_prefix_and_suffix(s.args, prefix, suffix)
                t_prefix, t_middle, t_suffix = split_with_prefix_and_suffix(t.args, prefix, suffix)
                s_args = s_prefix + (TupleType(list(s_middle), fallback),) + s_suffix
                t_args = t_prefix + (TupleType(list(t_middle), fallback),) + t_suffix
            else:
                t_args = t.args
                s_args = s.args
            for ta, sa, type_var in zip(t_args, s_args, t.type.defn.type_vars):
                ta_proper = get_proper_type(ta)
                sa_proper = get_proper_type(sa)
                new_type: Type | None = None
                if isinstance(ta_proper, AnyType):
                    new_type = AnyType(TypeOfAny.from_another_any, ta_proper)
                elif isinstance(sa_proper, AnyType):
                    new_type = AnyType(TypeOfAny.from_another_any, sa_proper)
                elif isinstance(type_var, TypeVarType):
                    if type_var.variance in (COVARIANT, VARIANCE_NOT_READY):
                        new_type = join_types(ta, sa, self)
                        # The joined argument must still satisfy the type variable's
                        # value restriction and upper bound; otherwise give up.
                        if len(type_var.values) != 0 and new_type not in type_var.values:
                            self.seen_instances.pop()
                            return object_from_instance(t)
                        if not is_subtype(new_type, type_var.upper_bound):
                            self.seen_instances.pop()
                            return object_from_instance(t)
                    # TODO: contravariant case should use meet but pass seen instances as
                    # an argument to keep track of recursive checks.
                    elif type_var.variance in (INVARIANT, CONTRAVARIANT):
                        if isinstance(ta_proper, UninhabitedType) and ta_proper.ambiguous:
                            new_type = sa
                        elif isinstance(sa_proper, UninhabitedType) and sa_proper.ambiguous:
                            new_type = ta
                        elif not is_equivalent(ta, sa):
                            self.seen_instances.pop()
                            return object_from_instance(t)
                        else:
                            # If the types are different but equivalent, then an Any is involved
                            # so using a join in the contravariant case is also OK.
                            new_type = join_types(ta, sa, self)
                elif isinstance(type_var, TypeVarTupleType):
                    new_type = get_proper_type(join_types(ta, sa, self))
                    # Put the joined arguments back into instance in the normal form:
                    #   a) Tuple[X, Y, Z] -> [X, Y, Z]
                    #   b) tuple[X, ...] -> [*tuple[X, ...]]
                    if isinstance(new_type, Instance):
                        assert new_type.type.fullname == "builtins.tuple"
                        new_type = UnpackType(new_type)
                    else:
                        assert isinstance(new_type, TupleType)
                        args.extend(new_type.items)
                        continue
                else:
                    # ParamSpec type variables behave the same, independent of variance
                    if not is_equivalent(ta, sa):
                        # Pop before bailing out, like every other early return above;
                        # leaving the pair behind would unbalance the stack for callers.
                        self.seen_instances.pop()
                        return get_proper_type(type_var.upper_bound)
                    new_type = join_types(ta, sa, self)
                assert new_type is not None
                args.append(new_type)
            result: ProperType = Instance(t.type, args)
        elif t.type.bases and is_proper_subtype(
            t, s, subtype_context=SubtypeContext(ignore_type_params=True)
        ):
            result = self.join_instances_via_supertype(t, s)
        else:
            # Now t is not a subtype of s, and t != s. Now s could be a subtype
            # of t; alternatively, we need to find a common supertype. This works
            # in both cases.
            result = self.join_instances_via_supertype(s, t)

        self.seen_instances.pop()
        return result

    def join_instances_via_supertype(self, t: Instance, s: Instance) -> ProperType:
        """Join t and s by joining s with each base (and promotion target) of t.

        The candidate preferred by is_better() is returned; when candidates
        tie, the earliest one found is kept.
        """
        # Give preference to joins via duck typing relationship, so that
        # join(int, float) == float, for example.
        for p in t.type._promote:
            if is_subtype(p, s):
                return join_types(p, s, self)
        for p in s.type._promote:
            if is_subtype(p, t):
                return join_types(t, p, self)

        # Compute the "best" supertype of t when joined with s.
        # The definition of "best" may evolve; for now it is the one with
        # the longest MRO.  Ties are broken by using the earlier base.
        best: ProperType | None = None
        for base in t.type.bases:
            mapped = map_instance_to_supertype(t, base.type)
            res = self.join_instances(mapped, s)
            if best is None or is_better(res, best):
                best = res
        assert best is not None
        for promote in t.type._promote:
            if isinstance(promote, Instance):
                res = self.join_instances(promote, s)
                if is_better(res, best):
                    best = res
        return best


def trivial_join(s: Type, t: Type) -> Type:
    """Join two types without any structural analysis.

    If one operand is a subtype of the other, the wider operand is returned
    unchanged (``t`` is checked first). Otherwise the result is whatever
    object_or_any_from_type() produces for the proper form of ``t``.
    """
    if is_subtype(s, t):
        return t
    if is_subtype(t, s):
        return s
    # Neither operand subsumes the other: fall back to the top type.
    return object_or_any_from_type(get_proper_type(t))


@overload
def join_types(
    s: ProperType, t: ProperType, instance_joiner: InstanceJoiner | None = None
) -> ProperType: ...


@overload
def join_types(s: Type, t: Type, instance_joiner: InstanceJoiner | None = None) -> Type: ...


def join_types(s: Type, t: Type, instance_joiner: InstanceJoiner | None = None) -> Type:
    """Return the least upper bound of s and t.

    For example, the join of 'int' and 'object' is 'object'.

    Args:
      s: The left operand.
      t: The right operand.
      instance_joiner: Optional InstanceJoiner that is handed on to the
        TypeJoinVisitor created at the end of this function.

    The operands may be swapped internally before the visitor is dispatched
    on ``t``; the order of the checks below therefore matters.
    """
    if mypy.typeops.is_recursive_pair(s, t):
        # This case can trigger an infinite recursion, general support for this will be
        # tricky so we use a trivial join (like for protocols).
        return trivial_join(s, t)
    s = get_proper_type(s)
    t = get_proper_type(t)

    if (s.can_be_true, s.can_be_false) != (t.can_be_true, t.can_be_false):
        # if types are restricted in different ways, use the more general versions
        s = mypy.typeops.true_or_false(s)
        t = mypy.typeops.true_or_false(t)

    # If exactly one operand is a union, make it the right operand so the
    # visitor below dispatches on the union.
    if isinstance(s, UnionType) and not isinstance(t, UnionType):
        s, t = t, s

    # A left operand of Any is returned as the result.
    if isinstance(s, AnyType):
        return s

    # A left operand of ErasedType yields the right operand.
    if isinstance(s, ErasedType):
        return t

    # Move a lone NoneType to the right so the visitor dispatches on it.
    if isinstance(s, NoneType) and not isinstance(t, NoneType):
        s, t = t, s

    # Likewise move a lone UninhabitedType to the right.
    if isinstance(s, UninhabitedType) and not isinstance(t, UninhabitedType):
        s, t = t, s

    # Meets/joins require callable type normalization.
    s, t = normalize_callables(s, t)

    # Use a visitor to handle non-trivial cases.
    return t.accept(TypeJoinVisitor(s, instance_joiner))


class TypeJoinVisitor(TypeVisitor[ProperType]):
    """Implementation of the least upper bound algorithm.

    The visitor is applied to the right operand ``t`` (via ``t.accept(...)``)
    and each ``visit_*`` method joins that operand with the stored left
    operand ``s``.

    Attributes:
      s: The other (left) type operand.
      instance_joiner: InstanceJoiner used for Instance joins; created lazily
        the first time it is needed if none was supplied.
    """

    def __init__(self, s: ProperType, instance_joiner: InstanceJoiner | None = None) -> None:
        self.s = s
        self.instance_joiner = instance_joiner

    def visit_unbound_type(self, t: UnboundType) -> ProperType:
        """An unbound type joins with anything to Any."""
        return AnyType(TypeOfAny.special_form)

    def visit_union_type(self, t: UnionType) -> ProperType:
        """Return t if s is already a proper subtype of it, else the simplified union."""
        if is_proper_subtype(self.s, t):
            return t
        else:
            return mypy.typeops.make_simplified_union([self.s, t])

    def visit_any(self, t: AnyType) -> ProperType:
        """Any on the right is returned as the join."""
        return t

    def visit_none_type(self, t: NoneType) -> ProperType:
        """Join None with s.

        Under strict optional: None/Uninhabited give None, Unbound/Any give
        Any, and everything else gives a simplified union with None. Without
        strict optional the result is simply s.
        """
        if state.strict_optional:
            if isinstance(self.s, (NoneType, UninhabitedType)):
                return t
            elif isinstance(self.s, (UnboundType, AnyType)):
                return AnyType(TypeOfAny.special_form)
            else:
                return mypy.typeops.make_simplified_union([self.s, t])
        else:
            return self.s

    def visit_uninhabited_type(self, t: UninhabitedType) -> ProperType:
        """The join with an uninhabited type is the other operand."""
        return self.s

    def visit_deleted_type(self, t: DeletedType) -> ProperType:
        """The join with a deleted type is the other operand."""
        return self.s

    def visit_erased_type(self, t: ErasedType) -> ProperType:
        """The join with an erased type is the other operand."""
        return self.s

    def visit_type_var(self, t: TypeVarType) -> ProperType:
        """Identical type variables (by id) join to themselves; otherwise use default()."""
        if isinstance(self.s, TypeVarType) and self.s.id == t.id:
            return self.s
        else:
            return self.default(self.s)

    def visit_param_spec(self, t: ParamSpecType) -> ProperType:
        """Equal ParamSpecs join to themselves; otherwise use default()."""
        if self.s == t:
            return t
        return self.default(self.s)

    def visit_type_var_tuple(self, t: TypeVarTupleType) -> ProperType:
        """Join a TypeVarTuple with s, using s itself when it is above t's upper bound."""
        if self.s == t:
            return t
        if isinstance(self.s, Instance) and is_subtype(t.upper_bound, self.s):
            # TODO: should we do this more generally and for all TypeVarLikeTypes?
            return self.s
        return self.default(self.s)

    def visit_unpack_type(self, t: UnpackType) -> UnpackType:
        """Bare Unpack types are not joined directly; always raises NotImplementedError."""
        raise NotImplementedError

    def visit_parameters(self, t: Parameters) -> ProperType:
        """Join two similar Parameters by taking the meet of each argument type."""
        if isinstance(self.s, Parameters):
            if not is_similar_params(t, self.s):
                # TODO: it would be prudent to return [*object, **object] instead of Any.
                return self.default(self.s)
            # Imported locally rather than at module top level.
            from mypy.meet import meet_types

            return t.copy_modified(
                arg_types=[
                    meet_types(s_a, t_a) for s_a, t_a in zip(self.s.arg_types, t.arg_types)
                ],
                arg_names=combine_arg_names(self.s, t),
            )
        else:
            return self.default(self.s)

    def visit_instance(self, t: Instance) -> ProperType:
        """Join an Instance with s, considering both nominal and structural joins."""
        if isinstance(self.s, Instance):
            if self.instance_joiner is None:
                self.instance_joiner = InstanceJoiner()
            nominal = self.instance_joiner.join_instances(t, self.s)
            structural: Instance | None = None
            if t.type.is_protocol and is_protocol_implementation(self.s, t):
                structural = t
            elif self.s.type.is_protocol and is_protocol_implementation(t, self.s):
                structural = self.s
            # Structural join is preferred in the case where we have found both
            # structural and nominal and they have same MRO length (see two comments
            # in join_instances_via_supertype). Otherwise, just return the nominal join.
            if not structural or is_better(nominal, structural):
                return nominal
            return structural
        elif isinstance(self.s, FunctionLike):
            if t.type.is_protocol:
                call = unpack_callback_protocol(t)
                if call:
                    return join_types(call, self.s)
            return join_types(t, self.s.fallback)
        elif isinstance(self.s, (TypeType, TypedDictType, TupleType, LiteralType)):
            # Swap the operands so that the visit method for the more specific
            # kind of type (now on the right) handles the join.
            return join_types(t, self.s)
        elif isinstance(self.s, TypeVarTupleType) and is_subtype(self.s.upper_bound, t):
            return t
        else:
            return self.default(self.s)

    def visit_callable_type(self, t: CallableType) -> ProperType:
        """Join a callable with s, falling back to t's fallback type when shapes differ."""
        if isinstance(self.s, CallableType) and is_similar_callables(t, self.s):
            if is_equivalent(t, self.s):
                return combine_similar_callables(t, self.s)
            result = join_similar_callables(t, self.s)
            # We set the from_type_type flag to suppress error when a collection of
            # concrete class objects gets inferred as their common abstract superclass.
            if not (
                (t.is_type_obj() and t.type_object().is_abstract)
                or (self.s.is_type_obj() and self.s.type_object().is_abstract)
            ):
                result.from_type_type = True
            if any(
                isinstance(tp, (NoneType, UninhabitedType))
                for tp in get_proper_types(result.arg_types)
            ):
                # We don't want to return unusable Callable, attempt fallback instead.
                return join_types(t.fallback, self.s)
            return result
        elif isinstance(self.s, Overloaded):
            # Switch the order of arguments so that we'll get to visit_overloaded.
            return join_types(t, self.s)
        elif isinstance(self.s, Instance) and self.s.type.is_protocol:
            call = unpack_callback_protocol(self.s)
            if call:
                return join_types(t, call)
        return join_types(t.fallback, self.s)

    def visit_overloaded(self, t: Overloaded) -> ProperType:
        """Join an overloaded function type with s."""
        # This is more complex than most other cases. Here are some
        # examples that illustrate how this works.
        #
        # First let's define a concise notation:
        #  - Cn are callable types (for n in 1, 2, ...)
        #  - Ov(C1, C2, ...) is an overloaded type with items C1, C2, ...
        #  - Callable[[T, ...], S] is written as [T, ...] -> S.
        #
        # We want some basic properties to hold (assume Cn are all
        # unrelated via Any-similarity):
        #
        #   join(Ov(C1, C2), C1) == C1
        #   join(Ov(C1, C2), Ov(C1, C2)) == Ov(C1, C2)
        #   join(Ov(C1, C2), Ov(C1, C3)) == C1
        #   join(Ov(C1, C2), C3) == join of fallback types
        #
        # The presence of Any types makes things more interesting. The join is the
        # most general type we can get with respect to Any:
        #
        #   join(Ov([int] -> int, [str] -> str), [Any] -> str) == [Any] -> str
        #
        # We could use a simplification step that removes redundancies, but that's not
        # implemented right now. Consider this example, where we get a redundancy:
        #
        #   join(Ov([int, Any] -> Any, [str, Any] -> Any), [Any, int] -> Any) ==
        #       Ov([Any, int] -> Any, [Any, int] -> Any)
        #
        # TODO: Consider more cases of callable subtyping.
        result: list[CallableType] = []
        s = self.s
        if isinstance(s, FunctionLike):
            # The interesting case where both types are function types.
            for t_item in t.items:
                for s_item in s.items:
                    if is_similar_callables(t_item, s_item):
                        if is_equivalent(t_item, s_item):
                            result.append(combine_similar_callables(t_item, s_item))
                        elif is_subtype(t_item, s_item):
                            result.append(s_item)
            if result:
                # TODO: Simplify redundancies from the result.
                if len(result) == 1:
                    return result[0]
                else:
                    return Overloaded(result)
            return join_types(t.fallback, s.fallback)
        elif isinstance(s, Instance) and s.type.is_protocol:
            call = unpack_callback_protocol(s)
            if call:
                return join_types(t, call)
        return join_types(t.fallback, s)

    def join_tuples(self, s: TupleType, t: TupleType) -> list[Type] | None:
        """Join two tuple types while handling variadic entries.

        This is surprisingly tricky, and we don't handle some tricky corner cases.
        Most of the trickiness comes from the variadic tuple items like *tuple[X, ...]
        since they can have arbitrary partial overlaps (while *Ts can't be split).

        Returns the list of joined items, or None when no item-wise join
        could be computed for the given shapes.
        """
        s_unpack_index = find_unpack_in_list(s.items)
        t_unpack_index = find_unpack_in_list(t.items)
        if s_unpack_index is None and t_unpack_index is None:
            if s.length() == t.length():
                items: list[Type] = []
                for i in range(t.length()):
                    items.append(join_types(t.items[i], s.items[i]))
                return items
            return None
        if s_unpack_index is not None and t_unpack_index is not None:
            # The most complex case: both tuples have an unpack item.
            s_unpack = s.items[s_unpack_index]
            assert isinstance(s_unpack, UnpackType)
            s_unpacked = get_proper_type(s_unpack.type)
            t_unpack = t.items[t_unpack_index]
            assert isinstance(t_unpack, UnpackType)
            t_unpacked = get_proper_type(t_unpack.type)
            if s.length() == t.length() and s_unpack_index == t_unpack_index:
                # We can handle a case where arity is perfectly aligned, e.g.
                # join(Tuple[X1, *tuple[Y1, ...], Z1], Tuple[X2, *tuple[Y2, ...], Z2]).
                # We can essentially perform the join elementwise.
                prefix_len = t_unpack_index
                suffix_len = t.length() - t_unpack_index - 1
                items = []
                for si, ti in zip(s.items[:prefix_len], t.items[:prefix_len]):
                    items.append(join_types(si, ti))
                joined = join_types(s_unpacked, t_unpacked)
                if isinstance(joined, TypeVarTupleType):
                    items.append(UnpackType(joined))
                elif isinstance(joined, Instance) and joined.type.fullname == "builtins.tuple":
                    items.append(UnpackType(joined))
                else:
                    # The unpacked parts have no usable join of their own: use a
                    # variadic tuple whose item type comes from object_from_instance().
                    if isinstance(t_unpacked, Instance):
                        assert t_unpacked.type.fullname == "builtins.tuple"
                        tuple_instance = t_unpacked
                    else:
                        assert isinstance(t_unpacked, TypeVarTupleType)
                        tuple_instance = t_unpacked.tuple_fallback
                    items.append(
                        UnpackType(
                            tuple_instance.copy_modified(
                                args=[object_from_instance(tuple_instance)]
                            )
                        )
                    )
                if suffix_len:
                    for si, ti in zip(s.items[-suffix_len:], t.items[-suffix_len:]):
                        items.append(join_types(si, ti))
                return items
            if s.length() == 1 or t.length() == 1:
                # Another case we can handle is when one of tuple is purely variadic
                # (i.e. a non-normalized form of tuple[X, ...]), in this case the join
                # will be again purely variadic.
                if not (isinstance(s_unpacked, Instance) and isinstance(t_unpacked, Instance)):
                    return None
                assert s_unpacked.type.fullname == "builtins.tuple"
                assert t_unpacked.type.fullname == "builtins.tuple"
                mid_joined = join_types(s_unpacked.args[0], t_unpacked.args[0])
                t_other = [a for i, a in enumerate(t.items) if i != t_unpack_index]
                s_other = [a for i, a in enumerate(s.items) if i != s_unpack_index]
                other_joined = join_type_list(s_other + t_other)
                mid_joined = join_types(mid_joined, other_joined)
                return [UnpackType(s_unpacked.copy_modified(args=[mid_joined]))]
            # TODO: are there other cases we can handle (e.g. both prefix/suffix are shorter)?
            return None
        if s_unpack_index is not None:
            variadic = s
            unpack_index = s_unpack_index
            fixed = t
        else:
            assert t_unpack_index is not None
            variadic = t
            unpack_index = t_unpack_index
            fixed = s
        # Case where one tuple has variadic item and the other one doesn't. The join will
        # be variadic, since fixed tuple is a subtype of variadic, but not vice versa.
        unpack = variadic.items[unpack_index]
        assert isinstance(unpack, UnpackType)
        unpacked = get_proper_type(unpack.type)
        if not isinstance(unpacked, Instance):
            return None
        if fixed.length() < variadic.length() - 1:
            # There are no non-trivial types that are supertype of both.
            return None
        prefix_len = unpack_index
        suffix_len = variadic.length() - prefix_len - 1
        prefix, middle, suffix = split_with_prefix_and_suffix(
            tuple(fixed.items), prefix_len, suffix_len
        )
        items = []
        for fi, vi in zip(prefix, variadic.items[:prefix_len]):
            items.append(join_types(fi, vi))
        mid_joined = join_type_list(list(middle))
        mid_joined = join_types(mid_joined, unpacked.args[0])
        items.append(UnpackType(unpacked.copy_modified(args=[mid_joined])))
        if suffix_len:
            for fi, vi in zip(suffix, variadic.items[-suffix_len:]):
                items.append(join_types(fi, vi))
        return items

    def visit_tuple_type(self, t: TupleType) -> ProperType:
        """Join a tuple type with s."""
        # When given two fixed-length tuples:
        # * If they have the same length, join their subtypes item-wise:
        #   Tuple[int, bool] + Tuple[bool, bool] becomes Tuple[int, bool]
        # * If lengths do not match, return a variadic tuple:
        #   Tuple[bool, int] + Tuple[bool] becomes Tuple[int, ...]
        #
        # Otherwise, `t` is a fixed-length tuple but `self.s` is NOT:
        # * Joining with a variadic tuple returns variadic tuple:
        #   Tuple[int, bool] + Tuple[bool, ...] becomes Tuple[int, ...]
        # * Joining with any Sequence also returns a Sequence:
        #   Tuple[int, bool] + List[bool] becomes Sequence[int]
        if isinstance(self.s, TupleType):
            if self.instance_joiner is None:
                self.instance_joiner = InstanceJoiner()
            fallback = self.instance_joiner.join_instances(
                mypy.typeops.tuple_fallback(self.s), mypy.typeops.tuple_fallback(t)
            )
            assert isinstance(fallback, Instance)
            items = self.join_tuples(self.s, t)
            if items is not None:
                if len(items) == 1 and isinstance(item := items[0], UnpackType):
                    if isinstance(unpacked := get_proper_type(item.type), Instance):
                        # Avoid double-wrapping tuple[*tuple[X, ...]]
                        return unpacked
                return TupleType(items, fallback)
            else:
                # TODO: should this be a default fallback behaviour like for meet?
                if is_proper_subtype(self.s, t):
                    return t
                if is_proper_subtype(t, self.s):
                    return self.s
                return fallback
        else:
            return join_types(self.s, mypy.typeops.tuple_fallback(t))

    def visit_typeddict_type(self, t: TypedDictType) -> ProperType:
        """Join two TypedDicts by keeping only the items on which they agree.

        An item survives when both operands give it equivalent types and agree
        on whether it is required. A surviving key is read-only in the result
        if it is read-only in either operand.
        """
        if isinstance(self.s, TypedDictType):
            items = {
                item_name: s_item_type
                for (item_name, s_item_type, t_item_type) in self.s.zip(t)
                if (
                    is_equivalent(s_item_type, t_item_type)
                    and (item_name in t.required_keys) == (item_name in self.s.required_keys)
                )
            }
            fallback = self.s.create_anonymous_fallback()
            all_keys = set(items.keys())
            # We need to filter by items.keys() since some required keys present in both t and
            # self.s might be missing from the join if the types are incompatible.
            required_keys = all_keys & t.required_keys & self.s.required_keys
            # If one type has a key as readonly, we mark it as readonly for both.
            # Both operands must contribute here: using only t's keys would make a
            # key writable in the join even though self.s declares it read-only.
            readonly_keys = (t.readonly_keys | self.s.readonly_keys) & all_keys
            return TypedDictType(items, required_keys, readonly_keys, fallback)
        elif isinstance(self.s, Instance):
            return join_types(self.s, t.fallback)
        else:
            return self.default(self.s)

    def visit_literal_type(self, t: LiteralType) -> ProperType:
        """Join a literal type with s.

        Equal literals join to themselves, two enum literals form a simplified
        union, and any other combination is joined through the fallback types.
        """
        if isinstance(self.s, LiteralType):
            if t == self.s:
                return t
            if self.s.fallback.type.is_enum and t.fallback.type.is_enum:
                return mypy.typeops.make_simplified_union([self.s, t])
            return join_types(self.s.fallback, t.fallback)
        else:
            return join_types(self.s, t.fallback)

    def visit_partial_type(self, t: PartialType) -> ProperType:
        # We only have partial information so we can't decide the join result. We should
        # never get here.
        assert False, "Internal error"

    def visit_type_type(self, t: TypeType) -> ProperType:
        """Join Type[X] with s: item-wise for another Type[Y], s itself for builtins.type."""
        if isinstance(self.s, TypeType):
            return TypeType.make_normalized(join_types(t.item, self.s.item), line=t.line)
        elif isinstance(self.s, Instance) and self.s.type.fullname == "builtins.type":
            return self.s
        else:
            return self.default(self.s)

    def visit_type_alias_type(self, t: TypeAliasType) -> ProperType:
        assert False, f"This should be never called, got {t}"

    def default(self, typ: Type) -> ProperType:
        """Return the catch-all join result derived from typ.

        Instances map through object_from_instance(); tuples, TypedDicts and
        function-like types recurse on their fallback; TypeVars and ParamSpecs
        recurse on their upper bound; everything else gives Any.
        """
        typ = get_proper_type(typ)
        if isinstance(typ, Instance):
            return object_from_instance(typ)
        elif isinstance(typ, UnboundType):
            return AnyType(TypeOfAny.special_form)
        elif isinstance(typ, TupleType):
            return self.default(mypy.typeops.tuple_fallback(typ))
        elif isinstance(typ, TypedDictType):
            return self.default(typ.fallback)
        elif isinstance(typ, FunctionLike):
            return self.default(typ.fallback)
        elif isinstance(typ, TypeVarType):
            return self.default(typ.upper_bound)
        elif isinstance(typ, ParamSpecType):
            return self.default(typ.upper_bound)
        else:
            return AnyType(TypeOfAny.special_form)


def is_better(t: Type, s: Type) -> bool:
    """Tell whether join candidate ``t`` should be preferred over ``s``.

    Both arguments are possible results of join_instances_via_supertype().
    An Instance always beats a non-Instance; between two Instances the one
    with the strictly longer MRO wins. In every other case ``s`` is kept.
    """
    candidate = get_proper_type(t)
    incumbent = get_proper_type(s)

    if not isinstance(candidate, Instance):
        return False
    if not isinstance(incumbent, Instance):
        return True
    # Use len(mro) as a proxy for the better choice.
    return len(candidate.type.mro) > len(incumbent.type.mro)


def normalize_callables(s: ProperType, t: ProperType) -> tuple[ProperType, ProperType]:
    """Return both operands with with_unpacked_kwargs() applied to any callable one.

    Operands that are neither CallableType nor Overloaded pass through unchanged.
    """
    callable_kinds = (CallableType, Overloaded)
    left: ProperType = s.with_unpacked_kwargs() if isinstance(s, callable_kinds) else s
    right: ProperType = t.with_unpacked_kwargs() if isinstance(t, callable_kinds) else t
    return left, right


def is_similar_callables(t: CallableType, s: CallableType) -> bool:
    """Check whether two callables have the same overall shape.

    They are similar when the number of arguments, the minimum number of
    arguments, and the presence of varargs all agree.
    """
    if len(t.arg_types) != len(s.arg_types):
        return False
    if t.min_args != s.min_args:
        return False
    return t.is_var_arg == s.is_var_arg


def is_similar_params(t: Parameters, s: Parameters) -> bool:
    """Check whether two Parameters have the same overall shape.

    This matches the logic in is_similar_callables(): argument count,
    minimum argument count and presence of a var-arg must all agree.
    """
    if len(t.arg_types) != len(s.arg_types) or t.min_args != s.min_args:
        return False
    t_has_var_arg = t.var_arg() is not None
    s_has_var_arg = s.var_arg() is not None
    return t_has_var_arg == s_has_var_arg


def update_callable_ids(c: CallableType, ids: list[TypeVarId]) -> CallableType:
    """Return a copy of ``c`` whose type variables carry the given ids.

    Variables and ids are paired positionally (extra entries on either side
    are ignored). The callable is expanded with the old-id -> new-variable
    mapping and then given the renamed variables.
    """
    renamed = [tv.copy_modified(id=new_id) for tv, new_id in zip(c.variables, ids)]
    tv_map = {old.id: new for old, new in zip(c.variables, renamed)}
    return expand_type(c, tv_map).copy_modified(variables=renamed)


def match_generic_callables(t: CallableType, s: CallableType) -> tuple[CallableType, CallableType]:
    """Give two generic callables a shared set of fresh type variable ids.

    If either callable has no type variables, both are returned untouched.
    """
    # The case where we combine/join/meet similar callables, situation where both are generic
    # requires special care. A more principled solution may involve unify_generic_callable(),
    # but it would have two problems:
    #   * This adds risk of infinite recursion: e.g. join -> unification -> solver -> join
    #   * Using unification is an incorrect thing for meets, as it "widens" the types
    # Finally, this effectively falls back to an old behaviour before namespaces were added to
    # type variables, and it worked relatively well.
    t_count = len(t.variables)
    s_count = len(s.variables)
    if t_count == 0 or s_count == 0:
        # At least one side is not generic: nothing to align.
        return t, s
    fresh_ids = [TypeVarId.new(meta_level=0) for _ in range(max(t_count, s_count))]
    # Note: this relies on variables being in order they appear in function definition.
    return update_callable_ids(t, fresh_ids), update_callable_ids(s, fresh_ids)


def join_similar_callables(t: CallableType, s: CallableType) -> CallableType:
    """Combine two similar callables, meeting argument types and joining return types.

    The type variables of both callables are first aligned with
    match_generic_callables(). The result is a modified copy of 't' with
    pairwise safe_meet() argument types, argument names from
    combine_arg_names(), the join of both return types, and no name.
    """
    t, s = match_generic_callables(t, s)
    arg_types: list[Type] = [
        safe_meet(t_arg, s.arg_types[pos]) for pos, t_arg in enumerate(t.arg_types)
    ]
    # TODO in combine_similar_callables also applies here (names and kinds; user metaclasses)
    # The fallback type can be either 'function', 'type', or some user-provided metaclass.
    # The result should always use 'function' as a fallback if either operands are using it.
    t_uses_function = t.fallback.type.fullname == "builtins.function"
    fallback = t.fallback if t_uses_function else s.fallback
    return t.copy_modified(
        arg_types=arg_types,
        arg_names=combine_arg_names(t, s),
        ret_type=join_types(t.ret_type, s.ret_type),
        fallback=fallback,
        name=None,
    )


def safe_join(t: Type, s: Type) -> Type:
    """Join two types, tolerating UnpackType on either side.

    Two non-unpack types are joined normally. Two unpack types produce an
    unpack of the join of their inner types. If only one side is an unpack,
    the result comes from object_or_any_from_type() applied to 't'.
    """
    # This is a temporary solution to prevent crashes in combine_similar_callables() etc.,
    # until relevant TODOs on handling arg_kinds will be addressed there.
    if isinstance(t, UnpackType):
        if isinstance(s, UnpackType):
            return UnpackType(join_types(t.type, s.type))
    elif not isinstance(s, UnpackType):
        return join_types(t, s)
    # Exactly one of the two types is an unpack: there is no meaningful join.
    return object_or_any_from_type(get_proper_type(t))


def safe_meet(t: Type, s: Type) -> Type:
    """Meet two types, tolerating UnpackType on either side.

    This is the meet_types() counterpart of safe_join(). Two non-unpack types
    are met normally. Two unpack types produce an unpack of the meet of their
    inner types; when that meet is uninhabited, it is wrapped as the single
    argument of an instance of the tuple fallback of 't'. If only one side is
    an unpack, the result is UninhabitedType.
    """
    from mypy.meet import meet_types

    if isinstance(t, UnpackType) and isinstance(s, UnpackType):
        # Work out which tuple class to use should the meet turn out uninhabited.
        inner = get_proper_type(t.type)
        if isinstance(inner, TupleType):
            tuple_info = inner.partial_fallback.type
        elif isinstance(inner, TypeVarTupleType):
            tuple_info = inner.tuple_fallback.type
        else:
            assert isinstance(inner, Instance) and inner.type.fullname == "builtins.tuple"
            tuple_info = inner.type
        met = meet_types(t.type, s.type)
        if isinstance(met, UninhabitedType):
            return UnpackType(Instance(tuple_info, [met]))
        return UnpackType(met)
    if isinstance(t, UnpackType) or isinstance(s, UnpackType):
        # Exactly one of the two types is an unpack.
        return UninhabitedType()
    return meet_types(t, s)


def combine_similar_callables(t: CallableType, s: CallableType) -> CallableType:
    """Combine two similar callables, joining both argument and return types.

    The type variables of both callables are first aligned with
    match_generic_callables(). The result is a modified copy of 't' with
    pairwise safe_join() argument types, argument names from
    combine_arg_names(), the join of both return types, and no name.
    """
    t, s = match_generic_callables(t, s)
    arg_types: list[Type] = [
        safe_join(t_arg, s.arg_types[pos]) for pos, t_arg in enumerate(t.arg_types)
    ]
    # TODO kinds and argument names
    # TODO what should happen if one fallback is 'type' and the other is a user-provided metaclass?
    # The fallback type can be either 'function', 'type', or some user-provided metaclass.
    # The result should always use 'function' as a fallback if either operands are using it.
    t_uses_function = t.fallback.type.fullname == "builtins.function"
    fallback = t.fallback if t_uses_function else s.fallback
    return t.copy_modified(
        arg_types=arg_types,
        arg_names=combine_arg_names(t, s),
        ret_type=join_types(t.ret_type, s.ret_type),
        fallback=fallback,
        name=None,
    )


def combine_arg_names(
    t: CallableType | Parameters, s: CallableType | Parameters
) -> list[str | None]:
    """Compute argument names that are compatible with both signatures.

    As an illustration, take 't' and 's' with these signatures:

    - t: (a: int, b: str, X: str) -> None
    - s: (a: int, b: str, Y: str) -> None

    The result is ["a", "b", None]. Callers use it when joining 't' and 's',
    which gives a signature of (a: int, b: str, str) -> None.

    The name of the third argument is dropped, so both 't' and 's' remain
    valid subtypes of the resulting signature.

    For every position the name from 't' is kept when both names agree, or
    when the argument is named (per its kind) in either signature; otherwise
    the name is None.

    Precondition: 't' and 's' are similar, in the sense of
    is_similar_callables() / is_similar_params().
    """
    combined: list[str | None] = []
    for pos in range(len(t.arg_types)):
        name_in_t, name_in_s = t.arg_names[pos], s.arg_names[pos]
        keep_name = (
            name_in_t == name_in_s or t.arg_kinds[pos].is_named() or s.arg_kinds[pos].is_named()
        )
        combined.append(name_in_t if keep_name else None)
    return combined


def object_from_instance(instance: Instance) -> Instance:
    """Build the 'builtins.object' type starting from any instance type."""
    # 'object' always sits at the very end of the mro, so take the last entry.
    object_info = instance.type.mro[-1]
    return Instance(object_info, [])


def object_or_any_from_type(typ: ProperType) -> ProperType:
    """Try hard to construct 'builtins.object' from an arbitrary proper type.

    Similar to object_from_instance(), but works for more kinds of types by
    looking for an Instance reachable from 'typ':

    - instances are used directly;
    - callables, TypedDicts and literals use their 'fallback';
    - tuples use their 'partial_fallback';
    - Type[...] recurses into its item;
    - type-variable-like types recurse into a proper upper bound;
    - unions use the first proper item that yields an Instance;
    - unpacks recurse into the unpacked type.

    Returns an Any marked as an implementation artifact when no instance
    could be found.
    """
    # TODO: find a better way to get object, or make this more reliable.
    if isinstance(typ, Instance):
        return object_from_instance(typ)
    elif isinstance(typ, (CallableType, TypedDictType, LiteralType)):
        return object_from_instance(typ.fallback)
    elif isinstance(typ, TupleType):
        return object_from_instance(typ.partial_fallback)
    elif isinstance(typ, TypeType):
        return object_or_any_from_type(typ.item)
    elif isinstance(typ, TypeVarLikeType) and isinstance(typ.upper_bound, ProperType):
        return object_or_any_from_type(typ.upper_bound)
    elif isinstance(typ, UnionType):
        for item in typ.items:
            if isinstance(item, ProperType):
                candidate = object_or_any_from_type(item)
                if isinstance(candidate, Instance):
                    return candidate
        # No union item produced an Instance: fall through to Any below.
    elif isinstance(typ, UnpackType):
        # The result of the recursive call must be returned; previously it was
        # computed and then discarded, so unpacks always ended up as Any.
        return object_or_any_from_type(get_proper_type(typ.type))
    return AnyType(TypeOfAny.implementation_artifact)


def join_type_list(types: Sequence[Type]) -> Type:
    """Fold a sequence of types into a single type using join_types().

    An empty sequence gives UninhabitedType; a single item is returned as is.
    """
    if not types:
        # This is a little arbitrary but reasonable. Any empty tuple should be compatible
        # with all variable length tuples, and this makes it possible.
        return UninhabitedType()
    remaining = iter(types)
    # Start from the first item and join the rest into it, left to right.
    joined = next(remaining)
    for item in remaining:
        joined = join_types(joined, item)
    return joined


def unpack_callback_protocol(t: Instance) -> ProperType | None:
    """Return the type of '__call__' for a protocol whose only member is '__call__'.

    The given instance must be of a protocol type. Protocols with any other
    list of members give None.
    """
    assert t.type.is_protocol
    if t.type.protocol_members != ["__call__"]:
        return None
    call_member = find_member("__call__", t, t, is_operator=True)
    return get_proper_type(call_member)
