from __future__ import annotations

from typing import TYPE_CHECKING

from polars.datatypes import Boolean, Categorical, Enum, String
from polars.interchange.buffer import PolarsBuffer
from polars.interchange.protocol import (
    Column,
    ColumnNullType,
    CopyNotAllowedError,
    DtypeKind,
    Endianness,
)
from polars.interchange.utils import polars_dtype_to_dtype

if TYPE_CHECKING:
    from collections.abc import Iterator
    from typing import Any

    from polars import Series
    from polars.interchange.protocol import CategoricalDescription, ColumnBuffers, Dtype


class PolarsColumn(Column):
    """
    A column object backed by a Polars Series.

    Parameters
    ----------
    column
        The Polars Series backing the column object.
    allow_copy
        Allow data to be copied during operations on this column. If set to `False`,
        an error (such as `CopyNotAllowedError`) is raised if data would be copied.
    """

    def __init__(self, column: Series, *, allow_copy: bool = True) -> None:
        self._col = column
        self._allow_copy = allow_copy

    def size(self) -> int:
        """Size of the column in elements."""
        return self._col.len()

    @property
    def offset(self) -> int:
        """Offset of the first element with respect to the start of the underlying buffer."""  # noqa: W505
        # Only Boolean columns report a buffer offset; every other dtype reports 0.
        # NOTE(review): presumably because Boolean data is bit-packed and a slice can
        # start mid-byte, while other buffers are already sliced — confirm.
        if self._col.dtype == Boolean:
            return self._col._get_buffer_info()[1]
        else:
            return 0

    @property
    def dtype(self) -> Dtype:
        """Data type of the column."""
        pl_dtype = self._col.dtype
        return polars_dtype_to_dtype(pl_dtype)

    @property
    def describe_categorical(self) -> CategoricalDescription:
        """
        Description of the categorical data type of the column.

        Returns
        -------
        CategoricalDescription
            A dictionary with `is_ordered`, `is_dictionary` (always `True`) and
            `categories` (the categories wrapped in a `PolarsColumn`).

        Raises
        ------
        TypeError
            If the data type of the column is not categorical.
        """
        dtype = self._col.dtype
        if dtype == Categorical:
            categories = self._col.cat.get_categories()
            # Only physical ordering is reported as ordered.
            is_ordered = dtype.ordering == "physical"  # type: ignore[attr-defined]
        elif dtype == Enum:
            # Enum categories are fixed by the dtype itself and always ordered.
            categories = dtype.categories  # type: ignore[attr-defined]
            is_ordered = True
        else:
            msg = "`describe_categorical` only works on categorical columns"
            raise TypeError(msg)

        return {
            "is_ordered": is_ordered,
            "is_dictionary": True,
            "categories": PolarsColumn(categories, allow_copy=self._allow_copy),
        }

    @property
    def describe_null(self) -> tuple[ColumnNullType, int | None]:
        """Description of the null representation the column uses."""
        if self.null_count == 0:
            return ColumnNullType.NON_NULLABLE, None
        else:
            # The second element is the bitmask value that marks a null entry.
            return ColumnNullType.USE_BITMASK, 0

    @property
    def null_count(self) -> int:
        """The number of null elements."""
        return self._col.null_count()

    @property
    def metadata(self) -> dict[str, Any]:
        """The metadata for the column."""
        return {}

    def num_chunks(self) -> int:
        """Return the number of chunks the column consists of."""
        return self._col.n_chunks()

    def get_chunks(self, n_chunks: int | None = None) -> Iterator[PolarsColumn]:
        """
        Return an iterator yielding the column chunks.

        Parameters
        ----------
        n_chunks
            The number of chunks to return. Must be a positive multiple of the number
            of chunks in the column. Defaults to the column's own number of chunks.

        Raises
        ------
        ValueError
            If `n_chunks` is not a positive multiple of the number of chunks. As this
            method is a generator, the error surfaces on the first iteration.

        Notes
        -----
        When `n_chunks` is higher than the number of chunks in the column, a slice
        must be performed that is not on the chunk boundary. This will trigger some
        compute if the column contains null values or if the column is of data type
        boolean. Every original chunk is split into the same number of subchunks;
        trailing subchunks may be empty (e.g. when a chunk is empty or shorter than
        the number of subchunks), so exactly `n_chunks` columns are always yielded.
        """
        total_n_chunks = self.num_chunks()
        chunks = self._col.get_chunks()

        if (n_chunks is None) or (n_chunks == total_n_chunks):
            for chunk in chunks:
                yield PolarsColumn(chunk, allow_copy=self._allow_copy)

        elif (n_chunks <= 0) or (n_chunks % total_n_chunks != 0):
            msg = (
                "`n_chunks` must be a multiple of the number of chunks of this column"
                f" ({total_n_chunks})"
            )
            raise ValueError(msg)

        else:
            subchunks_per_chunk = n_chunks // total_n_chunks
            for chunk in chunks:
                size = len(chunk)
                # Ceiling division so the subchunks cover the whole chunk. Clamp to 1:
                # an empty chunk would otherwise give a zero `range` step (which
                # raises) instead of yielding `subchunks_per_chunk` empty pieces.
                step = max(1, -(-size // subchunks_per_chunk))
                for start in range(0, step * subchunks_per_chunk, step):
                    yield PolarsColumn(
                        chunk[start : start + step], allow_copy=self._allow_copy
                    )

    def get_buffers(self) -> ColumnBuffers:
        """
        Return a dictionary containing the underlying buffers.

        Returns
        -------
        ColumnBuffers
            A dictionary with `data`, `validity` and `offsets` entries; `validity`
            and `offsets` are `None` when the column has no such buffer.

        Raises
        ------
        CopyNotAllowedError
            If the column is of type String, or is a non-local Categorical, and
            `allow_copy` is `False`.
        """
        dtype = self._col.dtype

        if dtype == String and not self._allow_copy:
            msg = "string buffers must be converted"
            raise CopyNotAllowedError(msg)
        elif dtype == Categorical and not self._col.cat.is_local():
            if not self._allow_copy:
                msg = f"column {self._col.name!r} must be converted to a local categorical"
                raise CopyNotAllowedError(msg)
            # The backing Series is replaced in place, so later calls on this object
            # operate on the local categorical.
            # NOTE(review): presumably so `describe_categorical` matches the
            # physical codes exported here — confirm.
            self._col = self._col.cat.to_local()

        buffers = self._col._get_buffers()

        return {
            "data": self._wrap_data_buffer(buffers["values"]),
            "validity": self._wrap_validity_buffer(buffers["validity"]),
            "offsets": self._wrap_offsets_buffer(buffers["offsets"]),
        }

    def _wrap_data_buffer(self, buffer: Series) -> tuple[PolarsBuffer, Dtype]:
        """Wrap the values buffer together with the interchange dtype of its data."""
        interchange_buffer = PolarsBuffer(buffer, allow_copy=self._allow_copy)
        dtype = polars_dtype_to_dtype(buffer.dtype)
        return interchange_buffer, dtype

    def _wrap_validity_buffer(
        self, buffer: Series | None
    ) -> tuple[PolarsBuffer, Dtype] | None:
        """Wrap the validity buffer as a 1-bit boolean buffer, or return `None`."""
        if buffer is None:
            return None

        interchange_buffer = PolarsBuffer(buffer, allow_copy=self._allow_copy)
        dtype = (DtypeKind.BOOL, 1, "b", Endianness.NATIVE)
        return interchange_buffer, dtype

    def _wrap_offsets_buffer(
        self, buffer: Series | None
    ) -> tuple[PolarsBuffer, Dtype] | None:
        """Wrap the offsets buffer as a 64-bit integer buffer, or return `None`."""
        if buffer is None:
            return None

        interchange_buffer = PolarsBuffer(buffer, allow_copy=self._allow_copy)
        dtype = (DtypeKind.INT, 64, "l", Endianness.NATIVE)
        return interchange_buffer, dtype
