from __future__ import annotations

from itertools import chain
from typing import (
    TYPE_CHECKING,
    Any,
    Iterator,
    Literal,
    Mapping,
    Protocol,
    Sequence,
    Sized,
    TypeVar,
    overload,
)

from narwhals._compliant.typing import (
    CompliantDataFrameAny,
    CompliantExprT_contra,
    CompliantLazyFrameAny,
    CompliantSeriesT,
    EagerExprT,
    EagerSeriesT,
    NativeExprT,
    NativeFrameT,
)
from narwhals._translate import (
    ArrowConvertible,
    DictConvertible,
    FromNative,
    NumpyConvertible,
    ToNarwhals,
    ToNarwhalsT_co,
)
from narwhals._typing_compat import deprecated
from narwhals._utils import (
    Version,
    _StoresNative,
    check_columns_exist,
    is_compliant_series,
    is_index_selector,
    is_range,
    is_sequence_like,
    is_sized_multi_index_selector,
    is_slice_index,
    is_slice_none,
)

if TYPE_CHECKING:
    from io import BytesIO
    from pathlib import Path

    import pandas as pd
    import polars as pl
    import pyarrow as pa
    from typing_extensions import Self, TypeAlias

    from narwhals._compliant.expr import LazyExpr
    from narwhals._compliant.group_by import CompliantGroupBy, DataFrameGroupBy
    from narwhals._compliant.namespace import EagerNamespace
    from narwhals._compliant.window import WindowInputs
    from narwhals._translate import IntoArrowTable
    from narwhals._utils import Implementation, _FullContext
    from narwhals.dataframe import DataFrame
    from narwhals.dtypes import DType
    from narwhals.exceptions import ColumnNotFoundError
    from narwhals.schema import Schema
    from narwhals.typing import (
        AsofJoinStrategy,
        JoinStrategy,
        LazyUniqueKeepStrategy,
        MultiColSelector,
        MultiIndexSelector,
        PivotAgg,
        SingleIndexSelector,
        SizedMultiIndexSelector,
        SizedMultiNameSelector,
        SizeUnit,
        UniqueKeepStrategy,
        _2DArray,
        _SliceIndex,
        _SliceName,
    )

    Incomplete: TypeAlias = Any

__all__ = ["CompliantDataFrame", "CompliantLazyFrame", "EagerDataFrame"]

T = TypeVar("T")

_ToDict: TypeAlias = "dict[str, CompliantSeriesT] | dict[str, list[Any]]"  # noqa: PYI047


class CompliantDataFrame(
    NumpyConvertible["_2DArray", "_2DArray"],
    DictConvertible["_ToDict[CompliantSeriesT]", Mapping[str, Any]],
    ArrowConvertible["pa.Table", "IntoArrowTable"],
    _StoresNative[NativeFrameT],
    FromNative[NativeFrameT],
    ToNarwhals[ToNarwhalsT_co],
    Sized,
    Protocol[CompliantSeriesT, CompliantExprT_contra, NativeFrameT, ToNarwhalsT_co],
):
    _native_frame: NativeFrameT
    _implementation: Implementation
    _backend_version: tuple[int, ...]
    _version: Version

    def __narwhals_dataframe__(self) -> Self: ...
    def __narwhals_namespace__(self) -> Any: ...
    @classmethod
    def from_arrow(cls, data: IntoArrowTable, /, *, context: _FullContext) -> Self: ...
    @classmethod
    def from_dict(
        cls,
        data: Mapping[str, Any],
        /,
        *,
        context: _FullContext,
        schema: Mapping[str, DType] | Schema | None,
    ) -> Self: ...
    @classmethod
    def from_native(cls, data: NativeFrameT, /, *, context: _FullContext) -> Self: ...
    @classmethod
    def from_numpy(
        cls,
        data: _2DArray,
        /,
        *,
        context: _FullContext,
        schema: Mapping[str, DType] | Schema | Sequence[str] | None,
    ) -> Self: ...

    def __array__(self, dtype: Any, *, copy: bool | None) -> _2DArray: ...
    def __getitem__(
        self,
        item: tuple[
            SingleIndexSelector | MultiIndexSelector[CompliantSeriesT],
            MultiColSelector[CompliantSeriesT],
        ],
    ) -> Self: ...
    def simple_select(self, *column_names: str) -> Self:
        """`select` where all args are column names."""
        ...

    def aggregate(self, *exprs: CompliantExprT_contra) -> Self:
        """`select` where all args are aggregations or literals.

        (so, no broadcasting is necessary).
        """
        # NOTE: Ignore is to avoid an intermittent false positive
        return self.select(*exprs)  # pyright: ignore[reportArgumentType]

    def _with_version(self, version: Version) -> Self: ...

    @property
    def native(self) -> NativeFrameT:
        return self._native_frame

    @property
    def columns(self) -> Sequence[str]: ...
    @property
    def schema(self) -> Mapping[str, DType]: ...
    @property
    def shape(self) -> tuple[int, int]: ...
    def clone(self) -> Self: ...
    def collect(
        self, backend: Implementation | None, **kwargs: Any
    ) -> CompliantDataFrameAny: ...
    def collect_schema(self) -> Mapping[str, DType]: ...
    def drop(self, columns: Sequence[str], *, strict: bool) -> Self: ...
    def drop_nulls(self, subset: Sequence[str] | None) -> Self: ...
    def estimated_size(self, unit: SizeUnit) -> int | float: ...
    def explode(self, columns: Sequence[str]) -> Self: ...
    def filter(self, predicate: CompliantExprT_contra | Incomplete) -> Self: ...
    def gather_every(self, n: int, offset: int) -> Self: ...
    def get_column(self, name: str) -> CompliantSeriesT: ...
    def group_by(
        self,
        keys: Sequence[str] | Sequence[CompliantExprT_contra],
        *,
        drop_null_keys: bool,
    ) -> DataFrameGroupBy[Self, Any]: ...
    def head(self, n: int) -> Self: ...
    def item(self, row: int | None, column: int | str | None) -> Any: ...
    def iter_columns(self) -> Iterator[CompliantSeriesT]: ...
    def iter_rows(
        self, *, named: bool, buffer_size: int
    ) -> Iterator[tuple[Any, ...]] | Iterator[Mapping[str, Any]]: ...
    def is_unique(self) -> CompliantSeriesT: ...
    def join(
        self,
        other: Self,
        *,
        how: JoinStrategy,
        left_on: Sequence[str] | None,
        right_on: Sequence[str] | None,
        suffix: str,
    ) -> Self: ...
    def join_asof(
        self,
        other: Self,
        *,
        left_on: str,
        right_on: str,
        by_left: Sequence[str] | None,
        by_right: Sequence[str] | None,
        strategy: AsofJoinStrategy,
        suffix: str,
    ) -> Self: ...
    def lazy(self, *, backend: Implementation | None) -> CompliantLazyFrameAny: ...
    def pivot(
        self,
        on: Sequence[str],
        *,
        index: Sequence[str] | None,
        values: Sequence[str] | None,
        aggregate_function: PivotAgg | None,
        sort_columns: bool,
        separator: str,
    ) -> Self: ...
    def rename(self, mapping: Mapping[str, str]) -> Self: ...
    def row(self, index: int) -> tuple[Any, ...]: ...
    def rows(
        self, *, named: bool
    ) -> Sequence[tuple[Any, ...]] | Sequence[Mapping[str, Any]]: ...
    def sample(
        self,
        n: int | None,
        *,
        fraction: float | None,
        with_replacement: bool,
        seed: int | None,
    ) -> Self: ...
    def select(self, *exprs: CompliantExprT_contra) -> Self: ...
    def sort(
        self, *by: str, descending: bool | Sequence[bool], nulls_last: bool
    ) -> Self: ...
    def tail(self, n: int) -> Self: ...
    def to_arrow(self) -> pa.Table: ...
    def to_pandas(self) -> pd.DataFrame: ...
    def to_polars(self) -> pl.DataFrame: ...
    @overload
    def to_dict(self, *, as_series: Literal[True]) -> dict[str, CompliantSeriesT]: ...
    @overload
    def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ...
    def to_dict(
        self, *, as_series: bool
    ) -> dict[str, CompliantSeriesT] | dict[str, list[Any]]: ...
    def unique(
        self,
        subset: Sequence[str] | None,
        *,
        keep: UniqueKeepStrategy,
        maintain_order: bool | None = None,
    ) -> Self: ...
    def unpivot(
        self,
        on: Sequence[str] | None,
        index: Sequence[str] | None,
        variable_name: str,
        value_name: str,
    ) -> Self: ...
    def with_columns(self, *exprs: CompliantExprT_contra) -> Self: ...
    def with_row_index(self, name: str) -> Self: ...
    @overload
    def write_csv(self, file: None) -> str: ...
    @overload
    def write_csv(self, file: str | Path | BytesIO) -> None: ...
    def write_csv(self, file: str | Path | BytesIO | None) -> str | None: ...
    def write_parquet(self, file: str | Path | BytesIO) -> None: ...

    def _evaluate_aliases(self, *exprs: CompliantExprT_contra) -> list[str]:
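        """Return the flattened output aliases of `exprs`, in order."""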
        it = (expr._evaluate_aliases(self) for expr in exprs)
        return list(chain.from_iterable(it))

    def _check_columns_exist(self, subset: Sequence[str]) -> ColumnNotFoundError | None:
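        """Return a `ColumnNotFoundError` if any of `subset` is missing, else `None`."""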
        return check_columns_exist(subset, available=self.columns)


class CompliantLazyFrame(
    _StoresNative[NativeFrameT],
    FromNative[NativeFrameT],
    ToNarwhals[ToNarwhalsT_co],
    Protocol[CompliantExprT_contra, NativeFrameT, ToNarwhalsT_co],
):
    _native_frame: NativeFrameT
    _implementation: Implementation
    _backend_version: tuple[int, ...]
    _version: Version

    def __narwhals_lazyframe__(self) -> Self: ...
    def __narwhals_namespace__(self) -> Any: ...

    @classmethod
    def from_native(cls, data: NativeFrameT, /, *, context: _FullContext) -> Self: ...

    def simple_select(self, *column_names: str) -> Self:
        """`select` where all args are column names."""
        ...

    def aggregate(self, *exprs: CompliantExprT_contra) -> Self:
        """`select` where all args are aggregations or literals.

        (so, no broadcasting is necessary).
        """
        ...

    def _with_version(self, version: Version) -> Self: ...

    @property
    def native(self) -> NativeFrameT:
        return self._native_frame

    @property
    def columns(self) -> Sequence[str]: ...
    @property
    def schema(self) -> Mapping[str, DType]: ...
    def _iter_columns(self) -> Iterator[Any]: ...
    def collect(
        self, backend: Implementation | None, **kwargs: Any
    ) -> CompliantDataFrameAny: ...
    def collect_schema(self) -> Mapping[str, DType]: ...
    def drop(self, columns: Sequence[str], *, strict: bool) -> Self: ...
    def drop_nulls(self, subset: Sequence[str] | None) -> Self: ...
    def explode(self, columns: Sequence[str]) -> Self: ...
    def filter(self, predicate: CompliantExprT_contra | Incomplete) -> Self: ...
    @deprecated(
        "`LazyFrame.gather_every` is deprecated and will be removed in a future version."
    )
    def gather_every(self, n: int, offset: int) -> Self: ...
    def group_by(
        self,
        keys: Sequence[str] | Sequence[CompliantExprT_contra],
        *,
        drop_null_keys: bool,
    ) -> CompliantGroupBy[Self, CompliantExprT_contra]: ...
    def head(self, n: int) -> Self: ...
    def join(
        self,
        other: Self,
        *,
        how: Literal["left", "inner", "cross", "anti", "semi"],
        left_on: Sequence[str] | None,
        right_on: Sequence[str] | None,
        suffix: str,
    ) -> Self: ...
    def join_asof(
        self,
        other: Self,
        *,
        left_on: str,
        right_on: str,
        by_left: Sequence[str] | None,
        by_right: Sequence[str] | None,
        strategy: AsofJoinStrategy,
        suffix: str,
    ) -> Self: ...
    def rename(self, mapping: Mapping[str, str]) -> Self: ...
    def select(self, *exprs: CompliantExprT_contra) -> Self: ...
    def sort(
        self, *by: str, descending: bool | Sequence[bool], nulls_last: bool
    ) -> Self: ...
    @deprecated("`LazyFrame.tail` is deprecated and will be removed in a future version.")
    def tail(self, n: int) -> Self: ...
    def unique(
        self, subset: Sequence[str] | None, *, keep: LazyUniqueKeepStrategy
    ) -> Self: ...
    def unpivot(
        self,
        on: Sequence[str] | None,
        index: Sequence[str] | None,
        variable_name: str,
        value_name: str,
    ) -> Self: ...
    def with_columns(self, *exprs: CompliantExprT_contra) -> Self: ...
    def with_row_index(self, name: str) -> Self: ...
    def _evaluate_expr(self, expr: CompliantExprT_contra, /) -> Any:
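        """Evaluate `expr` and ensure it has a **single** output."""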
        result = expr(self)
        assert len(result) == 1  # debug assertion  # noqa: S101
        return result[0]

    def _evaluate_window_expr(
        self,
        expr: LazyExpr[Self, NativeExprT],
        /,
        window_inputs: WindowInputs[NativeExprT],
    ) -> NativeExprT:
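        """Evaluate `expr`'s window function, ensuring a single output."""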
        result = expr.window_function(self, window_inputs)
        assert len(result) == 1  # debug assertion  # noqa: S101
        return result[0]

    def _evaluate_aliases(self, *exprs: CompliantExprT_contra) -> list[str]:
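        """Return the flattened output aliases of `exprs`, in order."""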
        it = (expr._evaluate_aliases(self) for expr in exprs)
        return list(chain.from_iterable(it))

    def _check_columns_exist(self, subset: Sequence[str]) -> ColumnNotFoundError | None:
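        """Return a `ColumnNotFoundError` if any of `subset` is missing, else `None`."""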
        return check_columns_exist(subset, available=self.columns)


class EagerDataFrame(
    CompliantDataFrame[EagerSeriesT, EagerExprT, NativeFrameT, "DataFrame[NativeFrameT]"],
    CompliantLazyFrame[EagerExprT, NativeFrameT, "DataFrame[NativeFrameT]"],
    Protocol[EagerSeriesT, EagerExprT, NativeFrameT],
):
    def __narwhals_namespace__(
        self,
    ) -> EagerNamespace[Self, EagerSeriesT, EagerExprT, NativeFrameT]: ...

    def to_narwhals(self) -> DataFrame[NativeFrameT]:
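        """Wrap `self` in a user-facing `narwhals.DataFrame`."""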
        return self._version.dataframe(self, level="full")

    def _evaluate_expr(self, expr: EagerExprT, /) -> EagerSeriesT:
        """Evaluate `expr` and ensure it has a **single** output."""
        result: Sequence[EagerSeriesT] = expr(self)
        assert len(result) == 1  # debug assertion  # noqa: S101
        return result[0]

    def _evaluate_into_exprs(self, *exprs: EagerExprT) -> Sequence[EagerSeriesT]:
        # NOTE: Ignore is to avoid an intermittent false positive
        return list(chain.from_iterable(self._evaluate_into_expr(expr) for expr in exprs))  # pyright: ignore[reportArgumentType]

    def _evaluate_into_expr(self, expr: EagerExprT, /) -> Sequence[EagerSeriesT]:
        """Return list of raw columns.

        For eager backends we alias operations at each step.

        As a safety precaution, here we can check that the expected result names match those
        we were expecting from the various `evaluate_output_names` / `alias_output_names` calls.

        Note that for PySpark / DuckDB, we are less free to liberally set aliases whenever we want.
        """
        aliases = expr._evaluate_aliases(self)
        result = expr(self)
        if list(aliases) != (
            result_aliases := [s.name for s in result]
        ):  # pragma: no cover
            msg = f"Safety assertion failed, expected {aliases}, got {result_aliases}"
            raise AssertionError(msg)
        return result

    def _extract_comparand(self, other: EagerSeriesT, /) -> Any:
        """Extract native Series, broadcasting to `len(self)` if necessary."""
        ...

    @staticmethod
    def _numpy_column_names(
        data: _2DArray, columns: Sequence[str] | None, /
    ) -> list[str]:
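        """Return `columns` if given, else generate `column_0`, `column_1`, ..."""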
        return list(columns or (f"column_{x}" for x in range(data.shape[1])))

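    # Selection primitives implemented by each backend; used by `__getitem__` below.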
    def _gather(self, rows: SizedMultiIndexSelector[Any]) -> Self: ...
    def _gather_slice(self, rows: _SliceIndex | range) -> Self: ...
    def _select_multi_index(self, columns: SizedMultiIndexSelector[Any]) -> Self: ...
    def _select_multi_name(self, columns: SizedMultiNameSelector[Any]) -> Self: ...
    def _select_slice_index(self, columns: _SliceIndex | range) -> Self: ...
    def _select_slice_name(self, columns: _SliceName) -> Self: ...
    def __getitem__(  # noqa: C901, PLR0912
        self,
        item: tuple[
            SingleIndexSelector | MultiIndexSelector[EagerSeriesT],
            MultiColSelector[EagerSeriesT],
        ],
    ) -> Self:
        rows, columns = item
        compliant = self
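        # Column selection runs first, dispatching on the selector type
        # (positional indices, names, slices/ranges, or a compliant series).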
        if not is_slice_none(columns):
            if isinstance(columns, Sized) and len(columns) == 0:
                return compliant.select()
            if is_index_selector(columns):
                if is_slice_index(columns) or is_range(columns):
                    compliant = compliant._select_slice_index(columns)
                elif is_compliant_series(columns):
                    compliant = compliant._select_multi_index(columns.native)
                else:
                    compliant = compliant._select_multi_index(columns)
            elif isinstance(columns, slice):
                compliant = compliant._select_slice_name(columns)
            elif is_compliant_series(columns):
                compliant = compliant._select_multi_name(columns.native)
            elif is_sequence_like(columns):
                compliant = compliant._select_multi_name(columns)
            else:  # pragma: no cover
                msg = f"Unreachable code, got unexpected type: {type(columns)}"
                raise AssertionError(msg)

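        # Row selection runs second, again dispatching on the selector type.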
        if not is_slice_none(rows):
            if isinstance(rows, int):
                compliant = compliant._gather([rows])
            elif isinstance(rows, (slice, range)):
                compliant = compliant._gather_slice(rows)
            elif is_compliant_series(rows):
                compliant = compliant._gather(rows.native)
            elif is_sized_multi_index_selector(rows):
                compliant = compliant._gather(rows)
            else:  # pragma: no cover
                msg = f"Unreachable code, got unexpected type: {type(rows)}"
                raise AssertionError(msg)

        return compliant
