from __future__ import annotations

from collections.abc import Mapping
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, overload

import hypothesis.strategies as st
from hypothesis.errors import InvalidArgument

from polars import select, when
from polars._utils.deprecation import issue_deprecation_warning
from polars.dataframe import DataFrame
from polars.datatypes import Array, Boolean, DataType, DataTypeClass, List, Null, Struct
from polars.series import Series
from polars.string_cache import StringCache
from polars.testing.parametric.strategies._utils import flexhash
from polars.testing.parametric.strategies.data import data
from polars.testing.parametric.strategies.dtype import _instantiate_dtype, dtypes

if TYPE_CHECKING:
    from collections.abc import Collection, Sequence
    from typing import Literal

    from hypothesis.strategies import DrawFn, SearchStrategy

    from polars import LazyFrame
    from polars._typing import PolarsDataType


# Default upper bounds for the shape of generated data.
_ROW_LIMIT: int = 5  # max generated frame/series length
_COL_LIMIT: int = 5  # max number of generated cols


@st.composite
def series(
    draw: DrawFn,
    /,
    *,
    name: str | SearchStrategy[str] | None = None,
    dtype: PolarsDataType | None = None,
    min_size: int = 0,
    max_size: int = _ROW_LIMIT,
    strategy: SearchStrategy[Any] | None = None,
    allow_null: bool = True,
    allow_chunks: bool = True,
    allow_masked_out: bool = True,
    unique: bool = False,
    allowed_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None,
    excluded_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None,
    allow_time_zones: bool = True,
    **kwargs: Any,
) -> Series:
    """
    Hypothesis strategy for producing Polars Series.

    .. warning::
        This functionality is currently considered **unstable**. It may be
        changed at any point without it being considered a breaking change.

    Parameters
    ----------
    name : {str, strategy}, optional
        literal string or a strategy for strings (or None), passed to the Series
        constructor name-param.
    dtype : PolarsDataType, optional
        a valid polars DataType for the resulting series.
    min_size : int
        minimum length of the generated Series (defaults to 0). Ignored if the
        deprecated `size` keyword is passed.
    max_size : int
        maximum length of the generated Series (defaults to 5). Ignored if the
        deprecated `size` keyword is passed.
    strategy : strategy, optional
        supports overriding the default strategy for the given dtype.
    allow_null : bool
        Allow nulls as possible values and allow the `Null` data type by default.
    allow_chunks : bool
        Allow the Series to contain multiple chunks.
    allow_masked_out : bool
        Allow nested (List, Array, Struct) Series to have entries set to null
        after construction ("masked out"). Only applies when `allow_null` is set.
    unique : bool, optional
        indicate whether Series values should all be distinct. When masking out
        is applied to a unique Series, at most one entry is masked out so that no
        duplicate nulls are introduced.
    allowed_dtypes : {list,set}, optional
        when automatically generating Series data, allow only these dtypes.
    excluded_dtypes : {list,set}, optional
        when automatically generating Series data, exclude these dtypes.
    allow_time_zones
        Allow generating `Datetime` Series with a time zone.
    **kwargs
        Additional keyword arguments that are passed to the underlying data generation
        strategies.

    size : int, optional
        if set, creates a Series of exactly this size (ignoring min_size/max_size
        params).

        .. deprecated:: 1.0.0
            Use `min_size` and `max_size` instead.

    null_probability : float
        Percentage chance (expressed between 0.0 => 1.0) that any Series value is null.
        This is applied independently of any None values generated by the underlying
        strategy.

        .. deprecated:: 0.20.26
            Use `allow_null` instead.

    allow_infinities : bool, optional
        Allow generation of +/-inf values for floating-point dtypes.

        .. deprecated:: 0.20.26
            Use `allow_infinity` instead.

    Notes
    -----
    In actual usage this is deployed as a unit test decorator, providing a strategy
    that generates multiple Series with the given dtype/size characteristics for the
    unit test. While developing a strategy/test, it can also be useful to call
    `.example()` directly on a given strategy to see concrete instances of the
    generated data.

    Examples
    --------
    The strategy is generally used to generate series in a unit test:

    >>> from polars.testing.parametric import series
    >>> from hypothesis import given
    >>> @given(s=series(min_size=3, max_size=5))
    ... def test_series_len(s: pl.Series) -> None:
    ...     assert 3 <= s.len() <= 5

    Drawing examples interactively is also possible with the `.example()` method.
    This should be avoided while running tests.

    >>> from polars.testing.parametric import lists
    >>> s = series(strategy=lists(pl.String, select_from=["xx", "yy", "zz"]))
    >>> s.example()  # doctest: +SKIP
    shape: (4,)
    Series: '' [list[str]]
    [
            ["zz", "zz"]
            ["zz", "xx", "yy"]
            []
            ["xx"]
    ]
    """
    # Translate deprecated keyword arguments into their replacements.
    if (null_prob := kwargs.pop("null_probability", None)) is not None:
        allow_null = _handle_null_probability_deprecation(null_prob)  # type: ignore[assignment]
    if (allow_inf := kwargs.pop("allow_infinities", None)) is not None:
        issue_deprecation_warning(
            "`allow_infinities` is deprecated. Use `allow_infinity` instead.",
            version="0.20.26",
        )
        kwargs["allow_infinity"] = allow_inf
    if (chunked := kwargs.pop("chunked", None)) is not None:
        issue_deprecation_warning(
            "`chunked` is deprecated. Use `allow_chunks` instead.",
            version="0.20.26",
        )
        allow_chunks = chunked
    if (size := kwargs.pop("size", None)) is not None:
        issue_deprecation_warning(
            "`size` is deprecated. Use `min_size` and `max_size` instead.",
            version="1.0.0",
        )
        min_size = max_size = size

    # Normalise dtype filters to fresh lists (so appending below never mutates
    # a collection owned by the caller).
    if isinstance(allowed_dtypes, (DataType, DataTypeClass)):
        allowed_dtypes = [allowed_dtypes]
    elif allowed_dtypes is not None:
        allowed_dtypes = list(allowed_dtypes)
    if isinstance(excluded_dtypes, (DataType, DataTypeClass)):
        excluded_dtypes = [excluded_dtypes]
    elif excluded_dtypes is not None:
        excluded_dtypes = list(excluded_dtypes)

    # Without nulls, the `Null` dtype is excluded unless explicitly allowed.
    if not allow_null and not (allowed_dtypes is not None and Null in allowed_dtypes):
        if excluded_dtypes is None:
            excluded_dtypes = [Null]
        else:
            excluded_dtypes.append(Null)

    # Only pick a concrete dtype when we are responsible for generating values.
    if strategy is None:
        if dtype is None:
            dtype_strat = dtypes(
                allowed_dtypes=allowed_dtypes,
                excluded_dtypes=excluded_dtypes,
                allow_time_zones=allow_time_zones,
            )
        else:
            dtype_strat = _instantiate_dtype(
                dtype,
                allowed_dtypes=allowed_dtypes,
                excluded_dtypes=excluded_dtypes,
                allow_time_zones=allow_time_zones,
            )
        dtype = draw(dtype_strat)

    if min_size == max_size:
        size = min_size
    else:
        size = draw(st.integers(min_value=min_size, max_value=max_size))

    if isinstance(name, st.SearchStrategy):
        name = draw(name)

    do_mask_out = (
        allow_masked_out
        and allow_null
        and isinstance(dtype, (List, Array, Struct))
        and draw(st.booleans())
    )

    if size == 0:
        values = []
    else:
        # Create series using dtype-specific strategy to generate values.
        # When masking out, the default strategy draws no nulls itself; nulls
        # are introduced by the mask instead.
        if strategy is None:
            strategy = data(
                dtype,  # type: ignore[arg-type]
                allow_null=allow_null and not do_mask_out,
                **kwargs,
            )

        values = draw(
            st.lists(
                strategy,
                min_size=size,
                max_size=size,
                unique_by=(flexhash if unique else None),
            )
        )

    s = Series(name=name, values=values, dtype=dtype)

    # Apply masking out of values: rows where the mask is False become null
    # (`when(...).then(...)` without `otherwise` yields null).
    if do_mask_out:
        if unique:
            # Every masked-out entry becomes null, so masking more than one
            # entry would create duplicate nulls and break `unique`. Drawing a
            # unique *mask* (as booleans) is not an option: a list of more than
            # two booleans can never be unique.
            mask_values = [True] * size
            if size > 0 and draw(st.booleans()):
                masked_idx = draw(st.integers(min_value=0, max_value=size - 1))
                mask_values[masked_idx] = False
        else:
            mask_values = draw(
                st.lists(
                    st.booleans(),
                    min_size=size,
                    max_size=size,
                )
            )

        mask = Series(name=None, values=mask_values, dtype=Boolean)
        s = select(when(mask).then(s).alias(s.name)).to_series()

    # Apply chunking
    if allow_chunks and size > 1 and draw(st.booleans()):
        split_at = size // 2
        s = s[:split_at].append(s[split_at:])

    return s


# Typing overload: with `lazy` omitted or False, the strategy yields DataFrames.
@overload
def dataframes(
    cols: int | column | Sequence[column] | None = None,
    *,
    lazy: Literal[False] = ...,
    min_cols: int = 0,
    max_cols: int = _COL_LIMIT,
    min_size: int = 0,
    max_size: int = _ROW_LIMIT,
    include_cols: Sequence[column] | column | None = None,
    allow_null: bool | Mapping[str, bool] = True,
    allow_chunks: bool = True,
    allow_masked_out: bool = True,
    allowed_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None,
    excluded_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None,
    allow_time_zones: bool = True,
    **kwargs: Any,
) -> SearchStrategy[DataFrame]: ...


# Typing overload: `lazy=True` (required keyword in this overload) yields LazyFrames.
@overload
def dataframes(
    cols: int | column | Sequence[column] | None = None,
    *,
    lazy: Literal[True],
    min_cols: int = 0,
    max_cols: int = _COL_LIMIT,
    min_size: int = 0,
    max_size: int = _ROW_LIMIT,
    include_cols: Sequence[column] | column | None = None,
    allow_null: bool | Mapping[str, bool] = True,
    allow_chunks: bool = True,
    allow_masked_out: bool = True,
    allowed_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None,
    excluded_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None,
    allow_time_zones: bool = True,
    **kwargs: Any,
) -> SearchStrategy[LazyFrame]: ...


@st.composite
def dataframes(
    draw: DrawFn,
    /,
    cols: int | column | Sequence[column] | None = None,
    *,
    lazy: bool = False,
    min_cols: int = 0,
    max_cols: int = _COL_LIMIT,
    min_size: int = 0,
    max_size: int = _ROW_LIMIT,
    include_cols: Sequence[column] | column | None = None,
    allow_null: bool | Mapping[str, bool] = True,
    allow_chunks: bool = True,
    allow_masked_out: bool = True,
    allowed_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None,
    excluded_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None,
    allow_time_zones: bool = True,
    **kwargs: Any,
) -> DataFrame | LazyFrame:
    """
    Hypothesis strategy for producing Polars DataFrames or LazyFrames.

    .. warning::
        This functionality is currently considered **unstable**. It may be
        changed at any point without it being considered a breaking change.

    Parameters
    ----------
    cols : {int, columns}, optional
        integer number of columns to create, or a sequence of `column` objects
        that describe the desired DataFrame column data.
    lazy : bool, optional
        produce a LazyFrame instead of a DataFrame.
    min_cols : int, optional
        if `cols` is not given, the minimum number of columns (defaults to 0).
        Note that a frame with zero columns always has height 0.
    max_cols : int, optional
        if `cols` is not given, the maximum number of columns (defaults to 5).
    min_size : int, optional
        if not passing an exact size, set the minimum number of rows in the
        DataFrame.
    max_size : int, optional
        if not passing an exact size, set the maximum number of rows in the
        DataFrame.
    include_cols : [column], optional
        a list of `column` objects to include in the generated DataFrame. note that
        explicitly provided columns are appended onto the list of existing columns
        (if any present).
    allow_null : bool or Mapping[str, bool]
        Allow nulls as possible values and allow the `Null` data type by default.
        Accepts either a boolean or a mapping of column names to booleans. Only
        used for columns whose own `allow_null` is not set.
    allow_chunks : bool
        Allow the DataFrame to contain multiple chunks.
    allow_masked_out : bool
        Allow nested (List, Array, Struct) columns to have entries set to null
        after construction ("masked out").
    allowed_dtypes : {list,set}, optional
        when automatically generating data, allow only these dtypes.
    excluded_dtypes : {list,set}, optional
        when automatically generating data, exclude these dtypes.
    allow_time_zones
        Allow generating `Datetime` columns with a time zone.
    **kwargs
        Additional keyword arguments that are passed to the underlying data generation
        strategies.

    size : int, optional
        if set, will create a DataFrame of exactly this size (and ignore
        the min_size/max_size len params).

        .. deprecated:: 1.0.0
            Use `min_size` and `max_size` instead.

    null_probability : {float, dict[str,float]}, optional
        percentage chance (expressed between 0.0 => 1.0) that a generated value is
        None. this is applied independently of any None values generated by the
        underlying strategy, and can be applied either on a per-column basis (if
        given as a `{col:pct}` dict), or globally. if null_probability is defined
        on a column, it takes precedence over the global value.

        .. deprecated:: 0.20.26
            Use `allow_null` instead.

    allow_infinities : bool, optional
        optionally disallow generation of +/-inf values for floating-point dtypes.

        .. deprecated:: 0.20.26
            Use `allow_infinity` instead.

    Notes
    -----
    In actual usage this is deployed as a unit test decorator, providing a strategy
    that generates DataFrames or LazyFrames with the given characteristics for
    the unit test. While developing a strategy/test, it can also be useful to
    call `.example()` directly on a given strategy to see concrete instances of
    the generated data.

    Examples
    --------
    The strategy is generally used to generate dataframes in a unit test:

    >>> from polars.testing.parametric import dataframes
    >>> from hypothesis import given
    >>> @given(df=dataframes(min_cols=1, min_size=3, max_size=5))
    ... def test_df_height(df: pl.DataFrame) -> None:
    ...     assert 3 <= df.height <= 5

    Drawing examples interactively is also possible with the `.example()` method.
    This should be avoided while running tests.

    >>> df = dataframes(allowed_dtypes=[pl.Datetime, pl.Float64], max_cols=3)
    >>> df.example()  # doctest: +SKIP
    shape: (3, 3)
    ┌─────────────┬────────────────────────────┬───────────┐
    │ col0        ┆ col1                       ┆ col2      │
    │ ---         ┆ ---                        ┆ ---       │
    │ f64         ┆ datetime[ns]               ┆ f64       │
    ╞═════════════╪════════════════════════════╪═══════════╡
    │ NaN         ┆ 1844-07-05 06:19:48.848808 ┆ 3.1436e16 │
    │ -1.9914e218 ┆ 2068-12-01 23:05:11.412277 ┆ 2.7415e16 │
    │ 0.5         ┆ 2095-11-19 22:05:17.647961 ┆ -0.5      │
    └─────────────┴────────────────────────────┴───────────┘

    Use :class:`column` for more control over exactly which columns are generated.

    >>> from polars.testing.parametric import column
    >>> dfs = dataframes(
    ...     [
    ...         column("x", dtype=pl.Int32),
    ...         column("y", dtype=pl.Float64),
    ...     ],
    ...     min_size=2,
    ...     max_size=2,
    ... )
    >>> dfs.example()  # doctest: +SKIP
    shape: (2, 2)
    ┌───────────┬────────────┐
    │ x         ┆ y          │
    │ ---       ┆ ---        │
    │ i32       ┆ f64        │
    ╞═══════════╪════════════╡
    │ -15836    ┆ 1.1755e-38 │
    │ 575050513 ┆ NaN        │
    └───────────┴────────────┘
    """
    # Translate deprecated keyword arguments into their replacements.
    if (null_prob := kwargs.pop("null_probability", None)) is not None:
        allow_null = _handle_null_probability_deprecation(null_prob)
    if (allow_inf := kwargs.pop("allow_infinities", None)) is not None:
        issue_deprecation_warning(
            "`allow_infinities` is deprecated. Use `allow_infinity` instead.",
            version="0.20.26",
        )
        kwargs["allow_infinity"] = allow_inf
    if (chunked := kwargs.pop("chunked", None)) is not None:
        issue_deprecation_warning(
            "`chunked` is deprecated. Use `allow_chunks` instead.",
            version="0.20.26",
        )
        allow_chunks = chunked
    if (size := kwargs.pop("size", None)) is not None:
        issue_deprecation_warning(
            "`size` is deprecated. Use `min_size` and `max_size` instead.",
            version="1.0.0",
        )
        min_size = max_size = size

    if isinstance(include_cols, column):
        include_cols = [include_cols]

    # Build a fresh list of column specs (never extend the caller's sequence).
    if cols is None:
        n_cols = draw(st.integers(min_value=min_cols, max_value=max_cols))
        cols = [column() for _ in range(n_cols)]
    elif isinstance(cols, int):
        cols = [column() for _ in range(cols)]
    elif isinstance(cols, column):
        cols = [cols]
    else:
        cols = list(cols)

    if include_cols:
        cols.extend(include_cols)

    # All columns share one height.
    if min_size == max_size:
        size = min_size
    else:
        size = draw(st.integers(min_value=min_size, max_value=max_size))

    # Resolve per-column name and null setting into locals rather than writing
    # them back onto the `column` objects: those objects belong to the caller,
    # are reused on every draw, and may be shared between strategies.
    resolved: list[tuple[str, bool, column]] = []
    for idx, c in enumerate(cols):
        col_name = f"col{idx}" if c.name is None else c.name
        if c.allow_null is not None:
            col_allow_null = c.allow_null
        elif isinstance(allow_null, Mapping):
            col_allow_null = allow_null.get(col_name, True)
        else:
            col_allow_null = allow_null
        resolved.append((col_name, col_allow_null, c))

    # Series-level and frame-level chunking are mutually exclusive (see below).
    allow_series_chunks = draw(st.booleans()) if allow_chunks else False

    # A shared string cache keeps categoricals from different columns compatible.
    with StringCache():
        series_by_name = {
            col_name: draw(
                series(
                    name=col_name,
                    dtype=c.dtype,
                    min_size=size,
                    max_size=size,
                    strategy=c.strategy,
                    allow_null=col_allow_null,
                    allow_chunks=allow_series_chunks,
                    allow_masked_out=allow_masked_out,
                    unique=c.unique,
                    allowed_dtypes=allowed_dtypes,
                    excluded_dtypes=excluded_dtypes,
                    allow_time_zones=allow_time_zones,
                    **kwargs,
                )
            )
            for col_name, col_allow_null, c in resolved
        }

    df = DataFrame(series_by_name)

    # Apply chunking
    if allow_chunks and size > 1 and not allow_series_chunks and draw(st.booleans()):
        split_at = size // 2
        df = df[:split_at].vstack(df[split_at:])

    if lazy:
        return df.lazy()

    return df


@dataclass
class column:
    """
    Define a column for use with the `dataframes` strategy.

    .. warning::
        This functionality is currently considered **unstable**. It may be
        changed at any point without it being considered a breaking change.

    Parameters
    ----------
    name : str, optional
        string column name. If not set, `dataframes` assigns a positional name
        of the form `col{idx}`.
    dtype : PolarsDataType, optional
        a polars dtype. If not set (and no `strategy` is given), a dtype is drawn
        from the allowed dtypes.
    strategy : strategy, optional
        supports overriding the default strategy for the given dtype.
    allow_null : bool, optional
        Allow nulls as possible values and allow the `Null` data type by default.
        If not set, the `allow_null` setting passed to `dataframes` is used.
    unique : bool, optional
        flag indicating that all values generated for the column should be unique.

    null_probability : float, optional
        percentage chance (expressed between 0.0 => 1.0) that a generated value is
        None. this is applied independently of any None values generated by the
        underlying strategy.

        .. deprecated:: 0.20.26
            Use `allow_null` instead.

    Examples
    --------
    >>> from polars.testing.parametric import column, dataframes
    >>> dfs = dataframes(
    ...     [
    ...         column("x", dtype=pl.Int32, allow_null=True),
    ...         column("y", dtype=pl.Float64),
    ...     ],
    ...     min_size=2,
    ...     max_size=2,
    ... )
    >>> dfs.example()  # doctest: +SKIP
    shape: (2, 2)
    ┌───────────┬────────────┐
    │ x         ┆ y          │
    │ ---       ┆ ---        │
    │ i32       ┆ f64        │
    ╞═══════════╪════════════╡
    │ null      ┆ 1.1755e-38 │
    │ 575050513 ┆ inf        │
    └───────────┴────────────┘
    """

    name: str | None = None
    dtype: PolarsDataType | None = None
    strategy: SearchStrategy[Any] | None = None
    # None means "inherit the `allow_null` setting of `dataframes`".
    allow_null: bool | None = None
    unique: bool = False

    # Deprecated; converted into `allow_null` by `__post_init__`.
    null_probability: float | None = None

    def __post_init__(self) -> None:
        # Translate the deprecated parameter (the helper also issues the warning).
        if self.null_probability is not None:
            self.allow_null = _handle_null_probability_deprecation(  # type: ignore[assignment]
                self.null_probability
            )


def _handle_null_probability_deprecation(
    null_probability: float | Mapping[str, float],
) -> bool | dict[str, bool]:
    """
    Translate a deprecated `null_probability` value into `allow_null` form.

    A deprecation warning is always issued first. Each probability is then
    validated and collapsed to a flag: zero maps to `False`, any other value
    in `[0.0, 1.0]` maps to `True`.

    Parameters
    ----------
    null_probability
        A single probability, or a mapping of column names to probabilities.

    Returns
    -------
    bool or dict[str, bool]
        A single flag, or a dict with the same keys as the input mapping.

    Raises
    ------
    InvalidArgument
        If a probability lies outside `[0.0, 1.0]` (NaN is rejected as well).
    """
    issue_deprecation_warning(
        "`null_probability` is deprecated. Use `allow_null` instead.",
        version="0.20.26",
    )

    def _to_flag(p: float) -> bool:
        # The chained comparison is False for NaN, so NaN is rejected too.
        in_range = 0.0 <= p <= 1.0
        if not in_range:
            msg = f"`null_probability` should be between 0.0 and 1.0, got {p!r}"
            raise InvalidArgument(msg)
        return bool(p)

    if not isinstance(null_probability, Mapping):
        return _to_flag(null_probability)

    flags: dict[str, bool] = {}
    for col_name, p in null_probability.items():
        flags[col_name] = _to_flag(p)
    return flags
