from __future__ import annotations

from typing import (
    TYPE_CHECKING,
    Any,
    Iterator,
    Literal,
    Mapping,
    Sequence,
    Sized,
    cast,
    overload,
)

import polars as pl

from narwhals._polars.namespace import PolarsNamespace
from narwhals._polars.series import PolarsSeries
from narwhals._polars.utils import (
    catch_polars_exception,
    extract_args_kwargs,
    native_to_narwhals_dtype,
)
from narwhals._utils import (
    Implementation,
    _into_arrow_table,
    check_columns_exist,
    convert_str_slice_to_int_slice,
    is_compliant_series,
    is_index_selector,
    is_range,
    is_sequence_like,
    is_slice_index,
    is_slice_none,
    parse_columns_to_drop,
    parse_version,
    requires,
    validate_backend_version,
)
from narwhals.dependencies import is_numpy_array_1d
from narwhals.exceptions import ColumnNotFoundError

if TYPE_CHECKING:
    from types import ModuleType
    from typing import Callable, TypeVar

    import pandas as pd
    import pyarrow as pa
    from typing_extensions import Self, TypeAlias, TypeIs

    from narwhals._compliant.typing import CompliantDataFrameAny, CompliantLazyFrameAny
    from narwhals._polars.expr import PolarsExpr
    from narwhals._polars.group_by import PolarsGroupBy, PolarsLazyGroupBy
    from narwhals._translate import IntoArrowTable
    from narwhals._utils import Version, _FullContext
    from narwhals.dataframe import DataFrame, LazyFrame
    from narwhals.dtypes import DType
    from narwhals.schema import Schema
    from narwhals.typing import (
        JoinStrategy,
        MultiColSelector,
        MultiIndexSelector,
        PivotAgg,
        SingleIndexSelector,
        _2DArray,
    )

    T = TypeVar("T")
    R = TypeVar("R")

# The right-hand side is kept as a string: `Callable` and `R` are only
# imported/defined inside the `TYPE_CHECKING` block above, so they cannot be
# evaluated at runtime.
Method: TypeAlias = "Callable[..., R]"
"""Generic alias representing all methods implemented via `__getattr__`.

Where `R` is the return type.
"""

# Names that the Polars wrappers in this module do not implement themselves:
# their `__getattr__` looks the name up here and, when present, forwards the
# call to the same-named method of the wrapped native frame.
INHERITED_METHODS = frozenset(
    {
        "clone",
        "drop_nulls",
        "estimated_size",
        "explode",
        "filter",
        "gather_every",
        "head",
        "is_unique",
        "item",
        "iter_rows",
        "join_asof",
        "rename",
        "row",
        "rows",
        "sample",
        "select",
        "sort",
        "tail",
        "to_arrow",
        "to_pandas",
        "unique",
        "with_columns",
        "write_csv",
        "write_parquet",
    }
)


class PolarsDataFrame:
    """Compliant eager frame wrapping a native `pl.DataFrame`.

    Attributes named in `INHERITED_METHODS` that are not defined on this class
    are resolved by `__getattr__`, which calls the same-named method on the
    native frame and re-wraps whatever it returns.
    """

    # Methods resolved dynamically through `__getattr__`; the annotations below
    # exist only so that type checkers know about them.
    clone: Method[Self]
    # NOTE(review): `collect` is declared here but is neither listed in
    # `INHERITED_METHODS` nor defined on this class, so a runtime lookup ends in
    # `AttributeError` - presumably declared for typing purposes only; confirm.
    collect: Method[CompliantDataFrameAny]
    drop_nulls: Method[Self]
    estimated_size: Method[int | float]
    explode: Method[Self]
    filter: Method[Self]
    gather_every: Method[Self]
    item: Method[Any]
    iter_rows: Method[Iterator[tuple[Any, ...]] | Iterator[Mapping[str, Any]]]
    is_unique: Method[PolarsSeries]
    join_asof: Method[Self]
    rename: Method[Self]
    row: Method[tuple[Any, ...]]
    rows: Method[Sequence[tuple[Any, ...]] | Sequence[Mapping[str, Any]]]
    sample: Method[Self]
    select: Method[Self]
    sort: Method[Self]
    to_arrow: Method[pa.Table]
    to_pandas: Method[pd.DataFrame]
    unique: Method[Self]
    with_columns: Method[Self]
    # NOTE: `write_csv` requires an `@overload` for `str | None`
    # Can't do that here 😟
    write_csv: Method[Any]
    write_parquet: Method[None]

    # CompliantDataFrame
    _evaluate_aliases: Any

    def __init__(
        self, df: pl.DataFrame, *, backend_version: tuple[int, ...], version: Version
    ) -> None:
        """Store the native frame and validate the backend version.

        Arguments:
            df: Native polars frame to wrap.
            backend_version: Version tuple of the polars backend.
            version: Narwhals API version this wrapper belongs to.
        """
        self._native_frame = df
        self._backend_version = backend_version
        self._implementation = Implementation.POLARS
        self._version = version
        validate_backend_version(self._implementation, self._backend_version)

    @classmethod
    def from_arrow(cls, data: IntoArrowTable, /, *, context: _FullContext) -> Self:
        """Build a frame from Arrow-compatible `data`.

        When `context._backend_version >= (1, 3)` the data is handed straight to
        `pl.DataFrame`; otherwise it is first converted with `_into_arrow_table`
        and then loaded through `pl.from_arrow`.
        """
        if context._backend_version >= (1, 3):
            native = pl.DataFrame(data)
        else:
            native = cast("pl.DataFrame", pl.from_arrow(_into_arrow_table(data, context)))
        return cls.from_native(native, context=context)

    @classmethod
    def from_dict(
        cls,
        data: Mapping[str, Any],
        /,
        *,
        context: _FullContext,
        schema: Mapping[str, DType] | Schema | None,
    ) -> Self:
        """Build a frame from a column mapping, optionally applying `schema`.

        A non-`None` schema is converted with `Schema(schema).to_polars()` before
        being passed to `pl.from_dict`.
        """
        from narwhals.schema import Schema

        pl_schema = Schema(schema).to_polars() if schema is not None else schema
        return cls.from_native(pl.from_dict(data, pl_schema), context=context)

    @staticmethod
    def _is_native(obj: pl.DataFrame | Any) -> TypeIs[pl.DataFrame]:
        """Return whether `obj` is a `pl.DataFrame`."""
        return isinstance(obj, pl.DataFrame)

    @classmethod
    def from_native(cls, data: pl.DataFrame, /, *, context: _FullContext) -> Self:
        """Wrap `data`, taking backend and API versions from `context`."""
        return cls(
            data, backend_version=context._backend_version, version=context._version
        )

    @classmethod
    def from_numpy(
        cls,
        data: _2DArray,
        /,
        *,
        context: _FullContext,  # NOTE: Maybe only `Implementation`?
        schema: Mapping[str, DType] | Schema | Sequence[str] | None,
    ) -> Self:
        """Build a frame from a 2D array.

        A mapping-like `schema` is converted with `Schema(schema).to_polars()`;
        anything else (a sequence of names or `None`) is passed to
        `pl.from_numpy` unchanged.
        """
        from narwhals.schema import Schema

        pl_schema = (
            Schema(schema).to_polars()
            if isinstance(schema, (Mapping, Schema))
            else schema
        )
        return cls.from_native(pl.from_numpy(data, pl_schema), context=context)

    def to_narwhals(self) -> DataFrame[pl.DataFrame]:
        """Wrap this compliant frame in the narwhals-level `DataFrame`."""
        return self._version.dataframe(self, level="full")

    @property
    def native(self) -> pl.DataFrame:
        """The wrapped native polars frame."""
        return self._native_frame

    def __repr__(self) -> str:  # pragma: no cover
        return "PolarsDataFrame"

    def __narwhals_dataframe__(self) -> Self:
        """Return `self`."""
        return self

    def __narwhals_namespace__(self) -> PolarsNamespace:
        """Return a `PolarsNamespace` carrying this frame's versions."""
        return PolarsNamespace(
            backend_version=self._backend_version, version=self._version
        )

    def __native_namespace__(self) -> ModuleType:
        """Return the native namespace of the implementation.

        Raises:
            AssertionError: If the stored implementation is not polars.
        """
        if self._implementation is Implementation.POLARS:
            return self._implementation.to_native_namespace()

        msg = f"Expected polars, got: {type(self._implementation)}"  # pragma: no cover
        raise AssertionError(msg)

    def _with_version(self, version: Version) -> Self:
        """Return a new wrapper around the same native frame with `version`."""
        return self.__class__(
            self.native, backend_version=self._backend_version, version=version
        )

    def _with_native(self, df: pl.DataFrame) -> Self:
        """Return a new wrapper around `df`, keeping this frame's versions."""
        return self.__class__(
            df, backend_version=self._backend_version, version=self._version
        )

    @overload
    def _from_native_object(self, obj: pl.Series) -> PolarsSeries: ...

    @overload
    def _from_native_object(self, obj: pl.DataFrame) -> Self: ...

    @overload
    def _from_native_object(self, obj: T) -> T: ...

    def _from_native_object(
        self, obj: pl.Series | pl.DataFrame | T
    ) -> Self | PolarsSeries | T:
        """Re-wrap a native result.

        A `pl.Series` becomes a `PolarsSeries`, a `pl.DataFrame` becomes a new
        instance of this class, and anything else is returned untouched.
        """
        if isinstance(obj, pl.Series):
            return PolarsSeries.from_native(obj, context=self)
        if self._is_native(obj):
            return self._with_native(obj)
        # scalar
        return obj

    def __len__(self) -> int:
        return len(self.native)

    def head(self, n: int) -> Self:
        """Wrap the result of `head(n)` on the native frame."""
        return self._with_native(self.native.head(n))

    def tail(self, n: int) -> Self:
        """Wrap the result of `tail(n)` on the native frame."""
        return self._with_native(self.native.tail(n))

    def __getattr__(self, attr: str) -> Any:
        """Resolve names from `INHERITED_METHODS` against the native frame.

        Returns a function that forwards its arguments (after they pass through
        `extract_args_kwargs`) to the native method called `attr` and re-wraps
        the result with `_from_native_object`.

        Raises:
            AttributeError: If `attr` is not in `INHERITED_METHODS`.
        """
        if attr not in INHERITED_METHODS:  # pragma: no cover
            msg = f"{self.__class__.__name__} has no attribute '{attr}'."
            raise AttributeError(msg)

        def func(*args: Any, **kwargs: Any) -> Any:
            pos, kwds = extract_args_kwargs(args, kwargs)
            try:
                return self._from_native_object(getattr(self.native, attr)(*pos, **kwds))
            except pl.exceptions.ColumnNotFoundError as e:  # pragma: no cover
                # Re-raise as the narwhals error, listing the available columns.
                msg = f"{e!s}\n\nHint: Did you mean one of these columns: {self.columns}?"
                raise ColumnNotFoundError(msg) from e
            except Exception as e:  # noqa: BLE001
                raise catch_polars_exception(e, self._backend_version) from None

        return func

    def __array__(
        self, dtype: Any | None = None, *, copy: bool | None = None
    ) -> _2DArray:
        """Convert the frame to a 2D array via the native `__array__`.

        Arguments:
            dtype: Passed positionally to the native `__array__`.
            copy: Forwarded to the native `__array__` when it is not `None`.

        Raises:
            NotImplementedError: If `copy` is given and polars is older than
                0.20.28.
        """
        if self._backend_version < (0, 20, 28):
            if copy is not None:
                msg = "`copy` in `__array__` is only supported for 'polars>=0.20.28'"
                raise NotImplementedError(msg)
            return self.native.__array__(dtype)
        if copy is None:
            # Default request: keep the call identical to the pre-0.20.28 one.
            return self.native.__array__(dtype)
        # An explicit `copy` must reach polars, otherwise it is silently ignored.
        return self.native.__array__(dtype, copy=copy)

    def to_numpy(self, dtype: Any = None, *, copy: bool | None = None) -> _2DArray:
        """Return the native frame's `to_numpy()` result.

        NOTE(review): `dtype` and `copy` are accepted but not forwarded.
        """
        return self.native.to_numpy()

    def collect_schema(self) -> dict[str, DType]:
        """Return a mapping of column name to narwhals dtype.

        For backend versions below `(1,)` the native `schema` attribute is read;
        otherwise the native `collect_schema()` is called.
        """
        if self._backend_version < (1,):
            native_schema = self.native.schema
        else:
            native_schema = self.native.collect_schema()
        return {
            name: native_to_narwhals_dtype(dtype, self._version, self._backend_version)
            for name, dtype in native_schema.items()
        }

    @property
    def shape(self) -> tuple[int, int]:
        """The native frame's `shape`."""
        return self.native.shape

    def __getitem__(  # noqa: C901, PLR0912
        self,
        item: tuple[
            SingleIndexSelector | MultiIndexSelector[PolarsSeries],
            MultiColSelector[PolarsSeries],
        ],
    ) -> Any:
        """Select by a `(rows, columns)` pair.

        For backend versions above `(0, 20, 30)` both selectors are unwrapped to
        native objects where needed and handed to the native `__getitem__`.
        Older versions go through the manual selection logic below.
        """
        rows, columns = item
        if self._backend_version > (0, 20, 30):
            rows_native = rows.native if is_compliant_series(rows) else rows
            columns_native = columns.native if is_compliant_series(columns) else columns
            selector = rows_native, columns_native
            selected = self.native.__getitem__(selector)  # type: ignore[index]
            return self._from_native_object(selected)
        else:  # pragma: no cover
            # TODO(marco): we can delete this branch after Polars==0.20.30 becomes the minimum
            # Polars version we support
            # This mostly mirrors the logic in `EagerDataFrame.__getitem__`.
            rows = list(rows) if isinstance(rows, tuple) else rows
            columns = list(columns) if isinstance(columns, tuple) else columns
            if is_numpy_array_1d(columns):
                columns = columns.tolist()

            native = self.native
            # Columns are narrowed first, then rows are taken from that result.
            if not is_slice_none(columns):
                if isinstance(columns, Sized) and len(columns) == 0:
                    return self.select()
                if is_index_selector(columns):
                    if is_slice_index(columns) or is_range(columns):
                        native = native.select(
                            self.columns[slice(columns.start, columns.stop, columns.step)]
                        )
                    elif is_compliant_series(columns):
                        native = native[:, columns.native.to_list()]
                    else:
                        native = native[:, columns]
                elif isinstance(columns, slice):
                    native = native.select(
                        self.columns[
                            slice(*convert_str_slice_to_int_slice(columns, self.columns))
                        ]
                    )
                elif is_compliant_series(columns):
                    native = native.select(columns.native.to_list())
                elif is_sequence_like(columns):
                    native = native.select(columns)
                else:
                    msg = f"Unreachable code, got unexpected type: {type(columns)}"
                    raise AssertionError(msg)

            if not is_slice_none(rows):
                if isinstance(rows, int):
                    # Wrapped in a list so that a frame is selected.
                    native = native[[rows], :]
                elif isinstance(rows, (slice, range)):
                    native = native[rows, :]
                elif is_compliant_series(rows):
                    native = native[rows.native, :]
                elif is_sequence_like(rows):
                    native = native[rows, :]
                else:
                    msg = f"Unreachable code, got unexpected type: {type(rows)}"
                    raise AssertionError(msg)

            return self._with_native(native)

    def simple_select(self, *column_names: str) -> Self:
        """Select columns by name directly on the native frame."""
        return self._with_native(self.native.select(*column_names))

    def aggregate(self, *exprs: Any) -> Self:
        """Delegate to `select` with the same expressions."""
        return self.select(*exprs)

    def get_column(self, name: str) -> PolarsSeries:
        """Return the column `name` wrapped as a `PolarsSeries`."""
        return PolarsSeries.from_native(self.native.get_column(name), context=self)

    def iter_columns(self) -> Iterator[PolarsSeries]:
        """Yield every native column wrapped as a `PolarsSeries`."""
        for series in self.native.iter_columns():
            yield PolarsSeries.from_native(series, context=self)

    @property
    def columns(self) -> list[str]:
        """The native frame's column names."""
        return self.native.columns

    @property
    def schema(self) -> dict[str, DType]:
        """Mapping of column name to narwhals dtype, read from the native `schema`."""
        return {
            name: native_to_narwhals_dtype(dtype, self._version, self._backend_version)
            for name, dtype in self.native.schema.items()
        }

    def lazy(self, *, backend: Implementation | None = None) -> CompliantLazyFrameAny:
        """Convert to a compliant lazy frame for `backend`.

        `None` and polars give a `PolarsLazyFrame`; DuckDB and Dask are handled
        by importing the respective library and wrapper on demand.

        Raises:
            AssertionError: For any other `backend`.
        """
        if backend is None or backend is Implementation.POLARS:
            return PolarsLazyFrame.from_native(self.native.lazy(), context=self)
        elif backend is Implementation.DUCKDB:
            import duckdb  # ignore-banned-import

            from narwhals._duckdb.dataframe import DuckDBLazyFrame

            # NOTE: (F841) is a false positive
            # (presumably `duckdb.table("df")` picks the local up by name).
            df = self.native  # noqa: F841
            return DuckDBLazyFrame(
                duckdb.table("df"),
                backend_version=parse_version(duckdb),
                version=self._version,
            )
        elif backend is Implementation.DASK:
            import dask  # ignore-banned-import
            import dask.dataframe as dd  # ignore-banned-import

            from narwhals._dask.dataframe import DaskLazyFrame

            return DaskLazyFrame(
                dd.from_pandas(self.native.to_pandas()),
                backend_version=parse_version(dask),
                version=self._version,
            )
        raise AssertionError  # pragma: no cover

    @overload
    def to_dict(self, *, as_series: Literal[True]) -> dict[str, PolarsSeries]: ...

    @overload
    def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ...

    def to_dict(
        self, *, as_series: bool
    ) -> dict[str, PolarsSeries] | dict[str, list[Any]]:
        """Convert to a dictionary keyed by column name.

        With `as_series=True` the values are `PolarsSeries`; otherwise the
        native `to_dict(as_series=False)` result is returned as is.
        """
        if as_series:
            return {
                name: PolarsSeries.from_native(col, context=self)
                for name, col in self.native.to_dict().items()
            }
        else:
            return self.native.to_dict(as_series=False)

    def group_by(
        self, keys: Sequence[str] | Sequence[PolarsExpr], *, drop_null_keys: bool
    ) -> PolarsGroupBy:
        """Create a `PolarsGroupBy` over `keys`."""
        from narwhals._polars.group_by import PolarsGroupBy

        return PolarsGroupBy(self, keys, drop_null_keys=drop_null_keys)

    def with_row_index(self, name: str) -> Self:
        """Add a row-index column called `name`.

        Uses the native `with_row_count` for backend versions below
        `(0, 20, 4)` and `with_row_index` otherwise.
        """
        if self._backend_version < (0, 20, 4):
            return self._with_native(self.native.with_row_count(name))
        return self._with_native(self.native.with_row_index(name))

    def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
        """Drop columns, resolving the final list via `parse_columns_to_drop`."""
        to_drop = parse_columns_to_drop(self, columns, strict=strict)
        return self._with_native(self.native.drop(to_drop))

    def unpivot(
        self,
        on: Sequence[str] | None,
        index: Sequence[str] | None,
        variable_name: str,
        value_name: str,
    ) -> Self:
        """Unpivot the frame from wide to long format.

        Backend versions below `(1, 0, 0)` use the native `melt` (with `index`
        as `id_vars` and `on` as `value_vars`); newer ones use `unpivot`.
        """
        if self._backend_version < (1, 0, 0):
            return self._with_native(
                self.native.melt(
                    id_vars=index,
                    value_vars=on,
                    variable_name=variable_name,
                    value_name=value_name,
                )
            )
        return self._with_native(
            self.native.unpivot(
                on=on, index=index, variable_name=variable_name, value_name=value_name
            )
        )

    @requires.backend_version((1,))
    def pivot(
        self,
        on: Sequence[str],
        *,
        index: Sequence[str] | None,
        values: Sequence[str] | None,
        aggregate_function: PivotAgg | None,
        sort_columns: bool,
        separator: str,
    ) -> Self:
        """Pivot the frame by forwarding every argument to the native `pivot`.

        Exceptions raised by the native call are passed through
        `catch_polars_exception` before being re-raised.
        """
        try:
            result = self.native.pivot(
                on,
                index=index,
                values=values,
                aggregate_function=aggregate_function,
                sort_columns=sort_columns,
                separator=separator,
            )
        except Exception as e:  # noqa: BLE001
            raise catch_polars_exception(e, self._backend_version) from None
        return self._from_native_object(result)

    def to_polars(self) -> pl.DataFrame:
        """Return the wrapped native frame."""
        return self.native

    def join(
        self,
        other: Self,
        *,
        how: JoinStrategy,
        left_on: Sequence[str] | None,
        right_on: Sequence[str] | None,
        suffix: str,
    ) -> Self:
        """Join with `other` using the native `join`.

        For backend versions below `(0, 20, 29)` a `"full"` join is requested
        from polars as `"outer"`. Exceptions raised by the native call are
        passed through `catch_polars_exception` before being re-raised.
        """
        how_native = (
            "outer" if (self._backend_version < (0, 20, 29) and how == "full") else how
        )
        try:
            return self._with_native(
                self.native.join(
                    other=other.native,
                    how=how_native,  # type: ignore[arg-type]
                    left_on=left_on,
                    right_on=right_on,
                    suffix=suffix,
                )
            )
        except Exception as e:  # noqa: BLE001
            raise catch_polars_exception(e, self._backend_version) from None

    def _check_columns_exist(self, subset: Sequence[str]) -> ColumnNotFoundError | None:
        """Check `subset` against this frame's columns via `check_columns_exist`."""
        return check_columns_exist(subset, available=self.columns)


class PolarsLazyFrame:
    """Compliant lazy frame wrapping a native `pl.LazyFrame`.

    Attributes named in `INHERITED_METHODS` that are not defined on this class
    are resolved by `__getattr__`, which calls the same-named method on the
    native frame and wraps the result in a new instance of this class.
    """

    # Methods resolved dynamically through `__getattr__`; the annotations below
    # exist only so that type checkers know about them.
    drop_nulls: Method[Self]
    explode: Method[Self]
    filter: Method[Self]
    gather_every: Method[Self]
    head: Method[Self]
    join_asof: Method[Self]
    rename: Method[Self]
    select: Method[Self]
    sort: Method[Self]
    tail: Method[Self]
    unique: Method[Self]
    with_columns: Method[Self]

    # CompliantLazyFrame
    _evaluate_expr: Any
    _evaluate_window_expr: Any
    _evaluate_aliases: Any

    def __init__(
        self, df: pl.LazyFrame, *, backend_version: tuple[int, ...], version: Version
    ) -> None:
        """Store the native frame and validate the backend version.

        Arguments:
            df: Native polars lazy frame to wrap.
            backend_version: Version tuple of the polars backend.
            version: Narwhals API version this wrapper belongs to.
        """
        self._native_frame = df
        self._backend_version = backend_version
        self._implementation = Implementation.POLARS
        self._version = version
        validate_backend_version(self._implementation, self._backend_version)

    @staticmethod
    def _is_native(obj: pl.LazyFrame | Any) -> TypeIs[pl.LazyFrame]:
        """Return whether `obj` is a `pl.LazyFrame`."""
        return isinstance(obj, pl.LazyFrame)

    @classmethod
    def from_native(cls, data: pl.LazyFrame, /, *, context: _FullContext) -> Self:
        """Wrap `data`, taking backend and API versions from `context`."""
        return cls(
            data, backend_version=context._backend_version, version=context._version
        )

    def to_narwhals(self) -> LazyFrame[pl.LazyFrame]:
        """Wrap this compliant frame in the narwhals-level `LazyFrame`."""
        return self._version.lazyframe(self, level="lazy")

    def __repr__(self) -> str:  # pragma: no cover
        return "PolarsLazyFrame"

    def __narwhals_lazyframe__(self) -> Self:
        """Return `self`."""
        return self

    def __narwhals_namespace__(self) -> PolarsNamespace:
        """Return a `PolarsNamespace` carrying this frame's versions."""
        return PolarsNamespace(
            backend_version=self._backend_version, version=self._version
        )

    def __native_namespace__(self) -> ModuleType:
        """Return the native namespace of the implementation.

        Raises:
            AssertionError: If the stored implementation is not polars.
        """
        if self._implementation is Implementation.POLARS:
            return self._implementation.to_native_namespace()

        msg = f"Expected polars, got: {type(self._implementation)}"  # pragma: no cover
        raise AssertionError(msg)

    def _with_native(self, df: pl.LazyFrame) -> Self:
        """Return a new wrapper around `df`, keeping this frame's versions."""
        return self.__class__(
            df, backend_version=self._backend_version, version=self._version
        )

    def _with_version(self, version: Version) -> Self:
        """Return a new wrapper around the same native frame with `version`."""
        return self.__class__(
            self.native, backend_version=self._backend_version, version=version
        )

    def __getattr__(self, attr: str) -> Any:
        """Resolve names from `INHERITED_METHODS` against the native frame.

        Returns a function that forwards its arguments (after they pass through
        `extract_args_kwargs`) to the native method called `attr` and wraps the
        result with `_with_native`.

        Raises:
            AttributeError: If `attr` is not in `INHERITED_METHODS`.
        """
        if attr not in INHERITED_METHODS:  # pragma: no cover
            msg = f"{self.__class__.__name__} has no attribute '{attr}'."
            raise AttributeError(msg)

        def func(*args: Any, **kwargs: Any) -> Any:
            pos, kwds = extract_args_kwargs(args, kwargs)
            try:
                return self._with_native(getattr(self.native, attr)(*pos, **kwds))
            except pl.exceptions.ColumnNotFoundError as e:  # pragma: no cover
                # Re-raise as the narwhals error type.
                raise ColumnNotFoundError(str(e)) from e

        return func

    def _iter_columns(self) -> Iterator[PolarsSeries]:  # pragma: no cover
        """Collect the frame and yield the columns of the result."""
        yield from self.collect(self._implementation).iter_columns()

    @property
    def native(self) -> pl.LazyFrame:
        """The wrapped native polars lazy frame."""
        return self._native_frame

    @property
    def columns(self) -> list[str]:
        """The native frame's `columns`."""
        return self.native.columns

    @property
    def schema(self) -> dict[str, DType]:
        """Mapping of column name to narwhals dtype, read from the native `schema`."""
        schema = self.native.schema
        return {
            name: native_to_narwhals_dtype(dtype, self._version, self._backend_version)
            for name, dtype in schema.items()
        }

    def collect_schema(self) -> dict[str, DType]:
        """Return a mapping of column name to narwhals dtype.

        For backend versions below `(1,)` the native `schema` attribute is read;
        otherwise the native `collect_schema()` is called, with any exception it
        raises passed through `catch_polars_exception`.
        """
        if self._backend_version < (1,):
            native_schema = self.native.schema
        else:
            try:
                native_schema = self.native.collect_schema()
            except Exception as e:  # noqa: BLE001
                raise catch_polars_exception(e, self._backend_version) from None
        return {
            name: native_to_narwhals_dtype(dtype, self._version, self._backend_version)
            for name, dtype in native_schema.items()
        }

    def collect(
        self, backend: Implementation | None, **kwargs: Any
    ) -> CompliantDataFrameAny:
        """Collect the lazy frame and wrap the result for `backend`.

        Arguments:
            backend: `None` or polars keeps the result as a `PolarsDataFrame`;
                pandas and pyarrow convert it with `to_pandas()` / `to_arrow()`.
            **kwargs: Forwarded to the native `collect`.

        Raises:
            ValueError: If `backend` is none of the values handled above.
        """
        try:
            result = self.native.collect(**kwargs)
        except Exception as e:  # noqa: BLE001
            raise catch_polars_exception(e, self._backend_version) from None

        if backend is None or backend is Implementation.POLARS:
            return PolarsDataFrame.from_native(result, context=self)

        if backend is Implementation.PANDAS:
            import pandas as pd  # ignore-banned-import

            from narwhals._pandas_like.dataframe import PandasLikeDataFrame

            return PandasLikeDataFrame(
                result.to_pandas(),
                implementation=Implementation.PANDAS,
                backend_version=parse_version(pd),
                version=self._version,
                validate_column_names=False,
            )

        if backend is Implementation.PYARROW:
            import pyarrow as pa  # ignore-banned-import

            from narwhals._arrow.dataframe import ArrowDataFrame

            return ArrowDataFrame(
                result.to_arrow(),
                backend_version=parse_version(pa),
                version=self._version,
                validate_column_names=False,
            )

        msg = f"Unsupported `backend` value: {backend}"  # pragma: no cover
        raise ValueError(msg)  # pragma: no cover

    def group_by(
        self, keys: Sequence[str] | Sequence[PolarsExpr], *, drop_null_keys: bool
    ) -> PolarsLazyGroupBy:
        """Create a `PolarsLazyGroupBy` over `keys`."""
        from narwhals._polars.group_by import PolarsLazyGroupBy

        return PolarsLazyGroupBy(self, keys, drop_null_keys=drop_null_keys)

    def with_row_index(self, name: str) -> Self:
        """Add a row-index column called `name`.

        Uses the native `with_row_count` for backend versions below
        `(0, 20, 4)` and `with_row_index` otherwise.
        """
        if self._backend_version < (0, 20, 4):
            return self._with_native(self.native.with_row_count(name))
        return self._with_native(self.native.with_row_index(name))

    def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
        """Drop `columns` on the native frame.

        `strict` is only forwarded for backend versions `(1, 0, 0)` and above;
        older versions call the native `drop` without it.
        """
        if self._backend_version < (1, 0, 0):
            return self._with_native(self.native.drop(columns))
        return self._with_native(self.native.drop(columns, strict=strict))

    def unpivot(
        self,
        on: Sequence[str] | None,
        index: Sequence[str] | None,
        variable_name: str,
        value_name: str,
    ) -> Self:
        """Unpivot the frame from wide to long format.

        Backend versions below `(1, 0, 0)` use the native `melt` (with `index`
        as `id_vars` and `on` as `value_vars`); newer ones use `unpivot`.
        """
        if self._backend_version < (1, 0, 0):
            return self._with_native(
                self.native.melt(
                    id_vars=index,
                    value_vars=on,
                    variable_name=variable_name,
                    value_name=value_name,
                )
            )
        return self._with_native(
            self.native.unpivot(
                on=on, index=index, variable_name=variable_name, value_name=value_name
            )
        )

    def simple_select(self, *column_names: str) -> Self:
        """Select columns by name directly on the native frame."""
        return self._with_native(self.native.select(*column_names))

    def aggregate(self, *exprs: Any) -> Self:
        """Delegate to `select` with the same expressions."""
        return self.select(*exprs)

    def join(
        self,
        other: Self,
        *,
        how: JoinStrategy,
        left_on: Sequence[str] | None,
        right_on: Sequence[str] | None,
        suffix: str,
    ) -> Self:
        """Join with `other` using the native `join`.

        For backend versions below `(0, 20, 29)` a `"full"` join is requested
        from polars as `"outer"`.
        """
        how_native = (
            "outer" if (self._backend_version < (0, 20, 29) and how == "full") else how
        )
        return self._with_native(
            self.native.join(
                other=other.native,
                how=how_native,  # type: ignore[arg-type]
                left_on=left_on,
                right_on=right_on,
                suffix=suffix,
            )
        )

    def _check_columns_exist(self, subset: Sequence[str]) -> ColumnNotFoundError | None:
        """Check `subset` against this frame's columns via `check_columns_exist`."""
        return check_columns_exist(  # pragma: no cover
            subset, available=self.columns
        )
