from __future__ import annotations

from functools import lru_cache
from typing import (
    TYPE_CHECKING,
    Any,
    Iterable,
    Iterator,
    Mapping,
    TypeVar,
    cast,
    overload,
)

import polars as pl

from narwhals._utils import Version, _DeferredIterable, isinstance_or_issubclass
from narwhals.exceptions import (
    ColumnNotFoundError,
    ComputeError,
    DuplicateError,
    InvalidOperationError,
    NarwhalsError,
    ShapeError,
)

if TYPE_CHECKING:
    from typing_extensions import TypeIs

    from narwhals._utils import _StoresNative
    from narwhals.dtypes import DType
    from narwhals.typing import IntoDType

    T = TypeVar("T")
    NativeT = TypeVar(
        "NativeT", bound="pl.DataFrame | pl.LazyFrame | pl.Series | pl.Expr"
    )


@overload
def extract_native(obj: _StoresNative[NativeT]) -> NativeT: ...
@overload
def extract_native(obj: T) -> T: ...
def extract_native(obj: _StoresNative[NativeT] | T) -> NativeT | T:
    """Unwrap a Polars compliant object to the native object it stores.

    Args:
        obj: Any value, wrapped or not.

    Returns:
        ``obj.native`` when ``_is_compliant_polars`` recognises ``obj``;
        otherwise ``obj`` itself, untouched.
    """
    if _is_compliant_polars(obj):
        return obj.native
    return obj


def _is_compliant_polars(
    obj: _StoresNative[NativeT] | Any,
) -> TypeIs[_StoresNative[NativeT]]:
    """Report whether ``obj`` is one of the Polars compliant wrapper types.

    The wrapper classes are imported inside the function rather than at
    module level.
    NOTE(review): presumably this sidesteps a circular import - confirm
    before hoisting the imports.
    """
    from narwhals._polars.dataframe import PolarsDataFrame, PolarsLazyFrame
    from narwhals._polars.expr import PolarsExpr
    from narwhals._polars.series import PolarsSeries

    wrapper_types = (PolarsDataFrame, PolarsLazyFrame, PolarsSeries, PolarsExpr)
    return isinstance(obj, wrapper_types)


def extract_args_kwargs(
    args: Iterable[Any], kwds: Mapping[str, Any], /
) -> tuple[Iterator[Any], dict[str, Any]]:
    """Apply ``extract_native`` to every positional and keyword argument.

    Args:
        args: Positional arguments to unwrap.
        kwds: Keyword arguments to unwrap.

    Returns:
        A pair ``(native_args, native_kwds)``. ``native_args`` is a lazy
        generator, so each positional argument is unwrapped only when it is
        consumed; ``native_kwds`` is a fully built dict with the same keys as
        ``kwds``.
    """
    native_args = (extract_native(positional) for positional in args)
    native_kwds = {name: extract_native(value) for name, value in kwds.items()}
    return native_args, native_kwds


@lru_cache(maxsize=16)
def native_to_narwhals_dtype(  # noqa: C901, PLR0912
    dtype: pl.DataType, version: Version, backend_version: tuple[int, ...]
) -> DType:
    """Translate a Polars dtype into the matching Narwhals dtype.

    Results are memoised by ``lru_cache`` (up to 16 entries), so all three
    arguments have to be hashable.

    Args:
        dtype: Native Polars dtype to translate.
        version: Narwhals API version; its ``dtypes`` namespace supplies the
            classes that are returned.
        backend_version: Polars version tuple, used to choose between
            attribute names that differ across Polars releases.

    Returns:
        The equivalent Narwhals dtype. Nested dtypes (struct, list, array)
        are converted recursively. ``Unknown`` is returned when no branch
        matches.
    """
    dtypes = version.dtypes

    # Parameter-free dtypes whose class name is the same in `pl` and in the
    # Narwhals namespace, in the order they are tested.
    leading_names = (
        "Float64",
        "Float32",
        "Int128",
        "Int64",
        "Int32",
        "Int16",
        "Int8",
        "UInt128",
        "UInt64",
        "UInt32",
        "UInt16",
        "UInt8",
        "String",
        "Boolean",
        "Object",
        "Categorical",
    )
    for name in leading_names:
        if name in {"Int128", "UInt128"} and not hasattr(pl, name):
            # The 128-bit integer dtypes are missing from older Polars
            # releases; skip them when the attribute is absent.
            continue  # pragma: no cover
        if dtype == getattr(pl, name):
            return getattr(dtypes, name)()

    if isinstance_or_issubclass(dtype, pl.Enum):
        if version is not Version.V1:
            # From Polars 0.20.4 on, the bound `to_list` method is handed
            # over; before that a lambda returning `dtype.categories` is used
            # (presumably it is already list-like there - TODO confirm).
            categories = _DeferredIterable(
                dtype.categories.to_list
                if backend_version >= (0, 20, 4)
                else lambda: cast("list[str]", dtype.categories)
            )
            return dtypes.Enum(categories)
        # The V1 namespace builds its Enum without categories.
        return dtypes.Enum()  # type: ignore[call-arg]

    if dtype == pl.Date:
        return dtypes.Date()

    if isinstance_or_issubclass(dtype, pl.Datetime):
        if dtype is pl.Datetime:
            # Bare class rather than an instance: nothing to forward.
            return dtypes.Datetime()
        return dtypes.Datetime(dtype.time_unit, dtype.time_zone)

    if isinstance_or_issubclass(dtype, pl.Duration):
        if dtype is pl.Duration:
            # Bare class rather than an instance: nothing to forward.
            return dtypes.Duration()
        return dtypes.Duration(dtype.time_unit)

    if isinstance_or_issubclass(dtype, pl.Struct):
        return dtypes.Struct(
            [
                dtypes.Field(
                    field_name,
                    native_to_narwhals_dtype(field_dtype, version, backend_version),
                )
                for field_name, field_dtype in dtype
            ]
        )

    if isinstance_or_issubclass(dtype, pl.List):
        element = native_to_narwhals_dtype(dtype.inner, version, backend_version)
        return dtypes.List(element)

    if isinstance_or_issubclass(dtype, pl.Array):
        # The attribute is read as `width` before Polars 0.20.30 and as
        # `size` from then on.
        if backend_version < (0, 20, 30):
            outer_shape = dtype.width
        else:
            outer_shape = dtype.size
        element = native_to_narwhals_dtype(dtype.inner, version, backend_version)
        return dtypes.Array(element, outer_shape)

    # Remaining parameter-free dtypes, tested after the nested ones.
    for name in ("Decimal", "Time", "Binary"):
        if dtype == getattr(pl, name):
            return getattr(dtypes, name)()
    return dtypes.Unknown()


def narwhals_to_native_dtype(  # noqa: C901, PLR0912
    dtype: IntoDType, version: Version, backend_version: tuple[int, ...]
) -> pl.DataType:
    """Translate a Narwhals dtype into the matching Polars dtype.

    Args:
        dtype: Narwhals dtype (class or instance) to translate.
        version: Narwhals API version; its ``dtypes`` namespace is what
            ``dtype`` is compared against.
        backend_version: Polars version tuple, used to choose the keyword
            that ``pl.Array`` is called with.

    Returns:
        The equivalent Polars dtype. Nested dtypes (list, struct, array) are
        converted recursively. ``pl.Unknown()`` is returned when no branch
        matches.

    Raises:
        NotImplementedError: If ``dtype`` is an Enum under ``Version.V1``, or
            if ``dtype`` is Decimal.
        ValueError: If ``dtype`` is the Enum class itself rather than an
            instance, so no categories are available.
    """
    dtypes = version.dtypes
    if dtype == dtypes.Float64:
        return pl.Float64()
    if dtype == dtypes.Float32:
        return pl.Float32()
    if dtype == dtypes.Int128 and hasattr(pl, "Int128"):
        # Not available for Polars pre 1.8.0
        return pl.Int128()
    if dtype == dtypes.Int64:
        return pl.Int64()
    if dtype == dtypes.Int32:
        return pl.Int32()
    if dtype == dtypes.Int16:
        return pl.Int16()
    if dtype == dtypes.Int8:
        return pl.Int8()
    if hasattr(pl, "UInt128") and dtype == dtypes.UInt128:  # pragma: no cover
        # Only map when the installed Polars exposes `pl.UInt128`; without
        # this branch a Narwhals UInt128 would fall through to `pl.Unknown()`.
        return pl.UInt128()
    if dtype == dtypes.UInt64:
        return pl.UInt64()
    if dtype == dtypes.UInt32:
        return pl.UInt32()
    if dtype == dtypes.UInt16:
        return pl.UInt16()
    if dtype == dtypes.UInt8:
        return pl.UInt8()
    if dtype == dtypes.String:
        return pl.String()
    if dtype == dtypes.Boolean:
        return pl.Boolean()
    if dtype == dtypes.Object:  # pragma: no cover
        return pl.Object()
    if dtype == dtypes.Categorical:
        return pl.Categorical()
    if isinstance_or_issubclass(dtype, dtypes.Enum):
        if version is Version.V1:
            msg = "Converting to Enum is not supported in narwhals.stable.v1"
            raise NotImplementedError(msg)
        if isinstance(dtype, dtypes.Enum):
            return pl.Enum(dtype.categories)
        # Reached only for the bare Enum class, which has no categories.
        msg = "Can not cast / initialize Enum without categories present"
        raise ValueError(msg)
    if dtype == dtypes.Date:
        return pl.Date()
    if dtype == dtypes.Time:
        return pl.Time()
    if dtype == dtypes.Binary:
        return pl.Binary()
    if dtype == dtypes.Decimal:
        msg = "Casting to Decimal is not supported yet."
        raise NotImplementedError(msg)
    if isinstance_or_issubclass(dtype, dtypes.Datetime):
        return pl.Datetime(dtype.time_unit, dtype.time_zone)  # type: ignore[arg-type]
    if isinstance_or_issubclass(dtype, dtypes.Duration):
        return pl.Duration(dtype.time_unit)  # type: ignore[arg-type]
    if isinstance_or_issubclass(dtype, dtypes.List):
        return pl.List(narwhals_to_native_dtype(dtype.inner, version, backend_version))
    if isinstance_or_issubclass(dtype, dtypes.Struct):
        fields = [
            pl.Field(
                field.name,
                narwhals_to_native_dtype(field.dtype, version, backend_version),
            )
            for field in dtype.fields
        ]
        return pl.Struct(fields)
    if isinstance_or_issubclass(dtype, dtypes.Array):  # pragma: no cover
        size = dtype.size
        # `pl.Array` is called with `width=` before Polars 0.20.30 and with
        # `shape=` from then on.
        kwargs = {"width": size} if backend_version < (0, 20, 30) else {"shape": size}
        return pl.Array(
            narwhals_to_native_dtype(dtype.inner, version, backend_version), **kwargs
        )
    return pl.Unknown()  # pragma: no cover


def catch_polars_exception(
    exception: Exception, backend_version: tuple[int, ...]
) -> NarwhalsError | Exception:
    """Map a Polars exception onto its Narwhals counterpart where one exists.

    The exception is returned, not raised; raising is left to the caller.

    Args:
        exception: The exception to translate.
        backend_version: Polars version tuple; decides how Polars errors
            without a dedicated Narwhals class are detected.

    Returns:
        A new Narwhals exception whose message is ``str(exception)``, or
        ``exception`` itself when it is not recognised as coming from Polars.
    """
    # (name in `pl.exceptions`, Narwhals replacement), tested in this order.
    # Each Polars class is looked up only when its turn comes.
    dedicated = (
        ("ColumnNotFoundError", ColumnNotFoundError),
        ("ShapeError", ShapeError),
        ("InvalidOperationError", InvalidOperationError),
        ("DuplicateError", DuplicateError),
        ("ComputeError", ComputeError),
    )
    for polars_name, narwhals_error in dedicated:
        if isinstance(exception, getattr(pl.exceptions, polars_name)):
            return narwhals_error(str(exception))

    if backend_version >= (1,):
        # Old versions of Polars didn't have PolarsError.
        from_polars = isinstance(exception, pl.exceptions.PolarsError)
    else:  # pragma: no cover
        # Last attempt, for old Polars versions.
        from_polars = "polars.exceptions" in str(type(exception))
    if from_polars:  # pragma: no cover
        return NarwhalsError(str(exception))
    # Just return exception as-is.
    return exception
