from __future__ import annotations

from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, NoReturn, overload

import polars._reexport as pl
import polars.functions as F
from polars._utils.constants import U32_MAX
from polars._utils.slice import PolarsSlice
from polars._utils.various import qualified_type_name, range_to_slice
from polars.datatypes.classes import (
    Boolean,
    Int8,
    Int16,
    Int32,
    Int64,
    String,
    UInt32,
    UInt64,
)
from polars.dependencies import _check_for_numpy
from polars.dependencies import numpy as np
from polars.meta.index_type import get_index_type

if TYPE_CHECKING:
    from collections.abc import Iterable

    from polars import DataFrame, Series
    from polars._typing import (
        MultiColSelector,
        MultiIndexSelector,
        SingleColSelector,
        SingleIndexSelector,
    )

# Public entry points of this module; the underscore-prefixed functions
# defined further down are internal helpers used by these two.
__all__ = [
    "get_df_item_by_key",
    "get_series_item_by_key",
]


@overload
def get_series_item_by_key(s: Series, key: SingleIndexSelector) -> Any: ...


@overload
def get_series_item_by_key(s: Series, key: MultiIndexSelector) -> Series: ...


def get_series_item_by_key(
    s: Series, key: SingleIndexSelector | MultiIndexSelector
) -> Any | Series:
    """
    Pick a single element or a sub-Series out of `s`.

    Parameters
    ----------
    s
        Series to index into.
    key
        An integer returns one element. A slice, range, sequence, Series or
        NumPy array of positions returns a new Series.

    Raises
    ------
    TypeError
        If `key` is a sequence whose first element is a boolean, a sequence
        that cannot be converted to an Int64 Series, or a key of any other
        unsupported type.
    """
    # Scalar position: delegate straight to the underlying PySeries.
    if isinstance(key, int):
        return s._s.get_index_signed(key)

    # Slices and ranges share a code path; a range is normalised to a slice.
    if isinstance(key, (slice, range)):
        as_slice = range_to_slice(key) if isinstance(key, range) else key
        return _select_elements_by_slice(s, as_slice)

    if isinstance(key, Sequence):
        if not key:
            return s.clear()

        head = key[0]
        if isinstance(head, bool):
            _raise_on_boolean_mask()

        try:
            positions = pl.Series("", key, dtype=Int64)
        except TypeError:
            # Re-raise with a message that names the offending element type.
            msg = f"cannot select elements using Sequence with elements of type {qualified_type_name(head)!r}"
            raise TypeError(msg) from None

        return _select_elements_by_index(
            s, _convert_series_to_indices(positions, s.len())
        )

    if isinstance(key, pl.Series):
        return _select_elements_by_index(s, _convert_series_to_indices(key, s.len()))

    if _check_for_numpy(key) and isinstance(key, np.ndarray):
        return _select_elements_by_index(
            s, _convert_np_ndarray_to_indices(key, s.len())
        )

    msg = f"cannot select elements using key of type {qualified_type_name(key)!r}: {key!r}"
    raise TypeError(msg)


def _select_elements_by_slice(s: Series, key: slice) -> Series:
    """Apply the Python slice `key` to `s` through `PolarsSlice`."""
    slicer = PolarsSlice(s)
    return slicer.apply(key)  # type: ignore[return-value]


def _select_elements_by_index(s: Series, key: Series) -> Series:
    """Gather elements of `s` using the index Series `key`."""
    gathered = s._s.gather_with_series(key._s)
    return s._from_pyseries(gathered)


# `str` overlaps with `Sequence[str]`
# We can ignore this but we must keep this overload ordering
@overload
def get_df_item_by_key(
    df: DataFrame, key: tuple[SingleIndexSelector, SingleColSelector]
) -> Any: ...


@overload
def get_df_item_by_key(  # type: ignore[overload-overlap]
    df: DataFrame, key: str | tuple[MultiIndexSelector, SingleColSelector]
) -> Series: ...


@overload
def get_df_item_by_key(
    df: DataFrame,
    key: (
        SingleIndexSelector
        | MultiIndexSelector
        | MultiColSelector
        | tuple[SingleIndexSelector, MultiColSelector]
        | tuple[MultiIndexSelector, MultiColSelector]
    ),
) -> DataFrame: ...


def get_df_item_by_key(
    df: DataFrame,
    key: (
        SingleIndexSelector
        | SingleColSelector
        | MultiColSelector
        | MultiIndexSelector
        | tuple[SingleIndexSelector, SingleColSelector]
        | tuple[SingleIndexSelector, MultiColSelector]
        | tuple[MultiIndexSelector, SingleColSelector]
        | tuple[MultiIndexSelector, MultiColSelector]
    ),
) -> DataFrame | Series | Any:
    """
    Extract a DataFrame, Series, or scalar from `df` according to `key`.

    Parameters
    ----------
    df
        Frame to index into.
    key
        Either a single selector or a two-element tuple. A two-element tuple
        is read as (rows, columns) unless its first element is a bool or a
        string, in which case the whole tuple selects columns. A lone string
        selects one column. Any other key is tried as a row selector first
        and as a column selector if that raises `TypeError`.
    """
    # A lone string, e.g. df["a"]. Handled explicitly because an empty string
    # would otherwise look like an empty Sequence to `_select_rows`.
    if isinstance(key, str):
        return df.get_column(key)

    # A pair, e.g. df[1, 2:5]
    if isinstance(key, tuple) and len(key) == 2:
        rows, cols = key

        # df[True, False] and df["a", "b"] are unambiguous column selections
        if isinstance(rows, (bool, str)):
            return _select_columns(df, key)  # type: ignore[arg-type]

        picked = _select_columns(df, cols)
        if picked.is_empty():
            return picked
        if isinstance(picked, pl.Series):
            return get_series_item_by_key(picked, rows)
        return _select_rows(picked, rows)

    # Anything else: df[1], or several columns such as df["a", "b", "c"].
    # Row selection gets the first attempt; a TypeError means "not a row key".
    try:
        row_result = _select_rows(df, key)  # type: ignore[arg-type]
    except TypeError:
        return _select_columns(df, key)
    return row_result


# `str` overlaps with `Sequence[str]`
# We can ignore this but we must keep this overload ordering
@overload
def _select_columns(df: DataFrame, key: SingleColSelector) -> Series: ...  # type: ignore[overload-overlap]


@overload
def _select_columns(df: DataFrame, key: MultiColSelector) -> DataFrame: ...


def _select_columns(
    df: DataFrame, key: SingleColSelector | MultiColSelector
) -> DataFrame | Series:
    """
    Pick one column (as a Series) or several columns (as a DataFrame).

    Parameters
    ----------
    df
        Frame to take columns from.
    key
        A column position or name, a slice (bounds may be positions or column
        names), a range, or a sequence / Series / 1D NumPy array of positions,
        names, or booleans.

    Raises
    ------
    TypeError
        If the key, or the element type it carries, is not supported.
    """
    # Scalars select exactly one column.
    if isinstance(key, int):
        return df.to_series(key)
    if isinstance(key, str):
        return df.get_column(key)

    if isinstance(key, slice):
        lower, upper, stride = key.start, key.stop, key.step
        # Fast path for the common case df[x, :]
        if lower is None and upper is None and stride is None:
            return df
        # Column names are accepted as bounds; a name given as `stop` is
        # included in the selection, hence the + 1.
        if isinstance(lower, str):
            lower = df.get_column_index(lower)
        if isinstance(upper, str):
            upper = df.get_column_index(upper) + 1
        positions = range(df.width)[slice(lower, upper, stride)]
        return _select_columns_by_index(df, positions)

    if isinstance(key, range):
        return _select_columns_by_index(df, key)

    if isinstance(key, Sequence):
        if not key:
            return df.__class__()
        head = key[0]
        # The first element decides how the whole sequence is interpreted.
        # bool is tested before int because bool is a subclass of int.
        if isinstance(head, bool):
            return _select_columns_by_mask(df, key)  # type: ignore[arg-type]
        if isinstance(head, int):
            return _select_columns_by_index(df, key)  # type: ignore[arg-type]
        if isinstance(head, str):
            return _select_columns_by_name(df, key)  # type: ignore[arg-type]
        msg = f"cannot select columns using Sequence with elements of type {qualified_type_name(head)!r}"
        raise TypeError(msg)

    if isinstance(key, pl.Series):
        if key.is_empty():
            return df.__class__()
        series_dtype = key.dtype
        if series_dtype == String:
            return _select_columns_by_name(df, key)
        if series_dtype.is_integer():
            return _select_columns_by_index(df, key)
        if series_dtype == Boolean:
            return _select_columns_by_mask(df, key)
        msg = f"cannot select columns using Series of type {series_dtype}"
        raise TypeError(msg)

    if _check_for_numpy(key) and isinstance(key, np.ndarray):
        # A 0-d array is promoted to a one-element 1D array.
        if key.ndim == 0:
            key = np.atleast_1d(key)
        elif key.ndim != 1:
            msg = "multi-dimensional NumPy arrays not supported as index"
            raise TypeError(msg)

        if len(key) == 0:
            return df.__class__()

        kind = key.dtype.kind
        if kind in ("i", "u"):
            return _select_columns_by_index(df, key)
        if kind == "b":
            return _select_columns_by_mask(df, key)
        if isinstance(key[0], str):
            return _select_columns_by_name(df, key)
        msg = f"cannot select columns using NumPy array of type {key.dtype}"
        raise TypeError(msg)

    msg = (
        f"cannot select columns using key of type {qualified_type_name(key)!r}: {key!r}"
    )
    raise TypeError(msg)


def _select_columns_by_index(df: DataFrame, key: Iterable[int]) -> DataFrame:
    """Build a new frame from the columns of `df` at the given positions."""
    columns = list(map(df.to_series, key))
    return df.__class__(columns)


def _select_columns_by_name(df: DataFrame, key: Iterable[str]) -> DataFrame:
    """Select the named columns through the underlying PyDataFrame."""
    selected = df._df.select(key)
    return df._from_pydf(selected)


def _select_columns_by_mask(
    df: DataFrame, key: Sequence[bool] | Series | np.ndarray[Any, Any]
) -> DataFrame:
    """
    Keep the columns of `df` whose mask entry is truthy.

    Raises
    ------
    ValueError
        If the mask length differs from the number of columns.
    """
    if df.width != len(key):
        msg = f"expected {df.width} values when selecting columns by boolean mask, got {len(key)}"
        raise ValueError(msg)

    # Translate the mask into the positions that are switched on.
    kept = [position for position, flag in enumerate(key) if flag]
    return _select_columns_by_index(df, kept)


@overload
def _select_rows(df: DataFrame, key: SingleIndexSelector) -> Series: ...


@overload
def _select_rows(df: DataFrame, key: MultiIndexSelector) -> DataFrame: ...


def _select_rows(
    df: DataFrame, key: SingleIndexSelector | MultiIndexSelector
) -> DataFrame | Series:
    """
    Pick rows of `df` by position.

    Parameters
    ----------
    df
        Frame to take rows from.
    key
        An integer (bounds-checked, then passed to `df.slice(key, 1)`), a
        slice, a range, or a sequence / Series / NumPy array of positions.

    Raises
    ------
    IndexError
        If an integer key lies outside `[-height, height)`.
    TypeError
        If `key` is a sequence starting with a boolean or is of an
        unsupported type.
    """
    if isinstance(key, int):
        height = df.height
        if not (-height <= key < height):
            msg = f"index {key} is out of bounds for DataFrame of height {height}"
            raise IndexError(msg)
        return df.slice(key, 1)

    # Slices and ranges share a code path; a range is normalised to a slice.
    if isinstance(key, (slice, range)):
        as_slice = range_to_slice(key) if isinstance(key, range) else key
        return _select_rows_by_slice(df, as_slice)

    if isinstance(key, Sequence):
        if not key:
            return df.clear()
        if isinstance(key[0], bool):
            _raise_on_boolean_mask()
        positions = pl.Series("", key, dtype=Int64)
        return _select_rows_by_index(
            df, _convert_series_to_indices(positions, df.height)
        )

    if isinstance(key, pl.Series):
        return _select_rows_by_index(df, _convert_series_to_indices(key, df.height))

    if _check_for_numpy(key) and isinstance(key, np.ndarray):
        return _select_rows_by_index(
            df, _convert_np_ndarray_to_indices(key, df.height)
        )

    msg = f"cannot select rows using key of type {qualified_type_name(key)!r}: {key!r}"
    raise TypeError(msg)


def _select_rows_by_slice(df: DataFrame, key: slice) -> DataFrame:
    """Apply the Python slice `key` to the rows of `df` through `PolarsSlice`."""
    slicer = PolarsSlice(df)
    return slicer.apply(key)  # type: ignore[return-value]


def _select_rows_by_index(df: DataFrame, key: Series) -> DataFrame:
    """Gather rows of `df` using the index Series `key`."""
    gathered = df._df.gather_with_series(key._s)
    return df._from_pydf(gathered)


# UTILS


def _convert_series_to_indices(s: Series, size: int) -> Series:
    """
    Turn an integer Series into a Series of the index dtype.

    Parameters
    ----------
    s
        Series of positions; negative values count from the end.
    size
        Length of the object being indexed; added to negative positions.

    Raises
    ------
    TypeError
        If `s` is boolean or of any other non-integer dtype.
    ValueError
        If the index type is UInt32 and a 64-bit position fails the
        `U32_MAX` bounds checks.
    """
    # The index dtype is pl.UInt32 (polars) or pl.UInt64 (polars_u64_idx).
    # Cases, cheapest first:
    #   - already the index dtype: returned untouched;
    #   - other integer Series without negative values: a plain cast;
    #   - signed Series with negative values: negatives are rewritten to
    #     absolute positions before the cast.
    idx_type = get_index_type()
    dtype = s.dtype

    if dtype == idx_type:
        return s

    if not dtype.is_integer():
        if dtype == Boolean:
            _raise_on_boolean_mask()
        msg = f"cannot treat Series of type {dtype} as indices"
        raise TypeError(msg)

    if s.len() == 0:
        return pl.Series(s.name, [], dtype=idx_type)

    uses_u32_index = idx_type == UInt32

    # Only 64-bit inputs can fall outside what a 32-bit index can hold.
    if uses_u32_index:
        if dtype in {Int64, UInt64} and s.max() >= U32_MAX:  # type: ignore[operator]
            msg = "index positions should be smaller than 2^32"
            raise ValueError(msg)
        if dtype == Int64 and s.min() < -U32_MAX:  # type: ignore[operator]
            msg = "index positions should be greater than or equal to -2^32"
            raise ValueError(msg)

    # No negative positions to resolve: a direct cast is enough.
    if not (dtype.is_signed_integer() and s.min() < 0):  # type: ignore[operator]
        return s.cast(idx_type)

    # Narrow signed types are widened first so `size + value` has room.
    narrow_types = {Int8, Int16} if uses_u32_index else {Int8, Int16, Int32}
    wide_type = Int32 if uses_u32_index else Int64
    idxs = s.cast(wide_type) if dtype in narrow_types else s

    # Rewrite negative positions as absolute ones, then cast.
    position = F.col(idxs.name)
    absolute = (
        F.when(position < 0)
        .then(size + position)
        .otherwise(position)
        .cast(idx_type)
    )
    return idxs.to_frame().select(absolute).to_series(0)


def _convert_np_ndarray_to_indices(arr: np.ndarray[Any, Any], size: int) -> Series:
    """
    Convert a NumPy ndarray to indices, taking into account negative values.

    Parameters
    ----------
    arr
        Integer array of positions. A 0-d array is treated as a one-element
        1D array. Negative values count from the end.
    size
        Length of the object being indexed; added to negative positions.

    Returns
    -------
    Series
        Series named `""` with the dtype returned by `get_index_type()`.

    Raises
    ------
    TypeError
        If `arr` has more than one dimension, is a boolean mask, or has any
        other non-integer dtype.
    ValueError
        If the index type is UInt32 and a 64-bit position fails the
        `U32_MAX` bounds checks.
    """
    # Unsigned or signed NumPy array (ordered from fastest to slowest).
    #   - np.uint32 (polars) or np.uint64 (polars_u64_idx) arrays.
    #   - Other unsigned arrays are converted to pl.UInt32 (polars) or
    #     pl.UInt64 (polars_u64_idx).
    #   - Signed arrays are converted to pl.UInt32 (polars) or pl.UInt64
    #     (polars_u64_idx) after negative indexes are converted to absolute
    #     indexes.
    if arr.ndim == 0:
        arr = np.atleast_1d(arr)
    if arr.ndim != 1:
        msg = "only 1D NumPy arrays can be treated as indices"
        raise TypeError(msg)

    idx_type = get_index_type()

    if len(arr) == 0:
        return pl.Series("", [], dtype=idx_type)

    # Only signed or unsigned integer arrays are valid indices.
    if arr.dtype.kind not in ("i", "u"):
        if arr.dtype.kind == "b":
            _raise_on_boolean_mask()
        else:
            msg = f"cannot treat NumPy array of type {arr.dtype} as indices"
            raise TypeError(msg)

    if idx_type == UInt32:
        # Membership is tested against a tuple (which uses `==`), not a set:
        # a `np.dtype` instance compares equal to the scalar type but hashes
        # differently, so `arr.dtype in {np.int64, np.uint64}` never matches.
        # Without this guard, the `astype(np.uint32)` below would silently
        # wrap out-of-range positions instead of rejecting them.
        if arr.dtype in (np.int64, np.uint64) and arr.max() >= U32_MAX:
            msg = "index positions should be smaller than 2^32"
            raise ValueError(msg)
        if arr.dtype == np.int64 and arr.min() < -U32_MAX:
            msg = "index positions should be greater than or equal to -2^32"
            raise ValueError(msg)

    if arr.dtype.kind == "i" and arr.min() < 0:
        # Widen narrow signed types first so `size + arr` has room.
        if idx_type == UInt32:
            if arr.dtype in (np.int8, np.int16):
                arr = arr.astype(np.int32)
        else:
            if arr.dtype in (np.int8, np.int16, np.int32):
                arr = arr.astype(np.int64)

        # Update negative indexes to absolute indexes.
        arr = np.where(arr < 0, size + arr, arr)

    # numpy conversion is much faster
    arr = arr.astype(np.uint32) if idx_type == UInt32 else arr.astype(np.uint64)

    return pl.Series("", arr, dtype=idx_type)


def _raise_on_boolean_mask() -> NoReturn:
    """Raise the `TypeError` used whenever a boolean mask is passed as a key."""
    reason = "selecting rows by passing a boolean mask to `__getitem__` is not supported"
    hint = "Hint: Use the `filter` method instead."
    msg = f"{reason}\n\n{hint}"
    raise TypeError(msg)
