from __future__ import annotations

import operator
from functools import reduce
from typing import TYPE_CHECKING, Iterable, Sequence, cast

import dask.dataframe as dd
import pandas as pd

from narwhals._compliant import CompliantThen, CompliantWhen, LazyNamespace
from narwhals._compliant.namespace import DepthTrackingNamespace
from narwhals._dask.dataframe import DaskLazyFrame
from narwhals._dask.expr import DaskExpr
from narwhals._dask.selectors import DaskSelectorNamespace
from narwhals._dask.utils import (
    align_series_full_broadcast,
    narwhals_to_native_dtype,
    validate_comparand,
)
from narwhals._expression_parsing import (
    ExprKind,
    combine_alias_output_names,
    combine_evaluate_output_names,
)
from narwhals._utils import Implementation

if TYPE_CHECKING:
    import dask.dataframe.dask_expr as dx

    from narwhals._utils import Version
    from narwhals.typing import ConcatMethod, IntoDType, NonNestedLiteral


class DaskNamespace(
    LazyNamespace[DaskLazyFrame, DaskExpr, dd.DataFrame],
    DepthTrackingNamespace[DaskLazyFrame, DaskExpr],
):
    """Dask implementation of the narwhals lazy namespace.

    Builds `DaskExpr` expressions (literals, `len`, horizontal reductions,
    string concatenation, `when`) and concatenates `DaskLazyFrame`s.
    """

    _implementation: Implementation = Implementation.DASK

    @property
    def selectors(self) -> DaskSelectorNamespace:
        """Selector namespace built from this namespace."""
        return DaskSelectorNamespace.from_namespace(self)

    @property
    def _expr(self) -> type[DaskExpr]:
        # Expression class instantiated by the constructors below.
        return DaskExpr

    @property
    def _lazyframe(self) -> type[DaskLazyFrame]:
        # Lazy-frame class that wraps native dask frames.
        return DaskLazyFrame

    def __init__(self, *, backend_version: tuple[int, ...], version: Version) -> None:
        """Store the dask backend version and the narwhals API version.

        Both are forwarded to the expressions and frames this namespace builds.
        """
        self._backend_version = backend_version
        self._version = version

    def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> DaskExpr:
        """Expression yielding the literal `value`, with output name `"literal"`.

        Arguments:
            value: The literal value.
            dtype: Optional narwhals dtype. When given, it is converted with
                `narwhals_to_native_dtype` and used for the underlying series.
        """

        def func(df: DaskLazyFrame) -> list[dx.Series]:
            if dtype is not None:
                native_dtype = narwhals_to_native_dtype(dtype, self._version)
                native_pd_series = pd.Series([value], dtype=native_dtype, name="literal")
            else:
                native_pd_series = pd.Series([value], name="literal")
            # Use the same partition count as the frame being evaluated against.
            npartitions = df._native_frame.npartitions
            dask_series = dd.from_pandas(native_pd_series, npartitions=npartitions)
            return [dask_series[0].to_series()]

        return self._expr(
            func,
            depth=0,
            function_name="lit",
            evaluate_output_names=lambda _df: ["literal"],
            alias_output_names=None,
            backend_version=self._backend_version,
            version=self._version,
        )

    def len(self) -> DaskExpr:
        """Expression giving the number of rows (`.size` of the first column).

        The output name is `"len"`.
        """

        def func(df: DaskLazyFrame) -> list[dx.Series]:
            # We don't allow dataframes with 0 columns, so `[0]` is safe.
            return [df._native_frame[df.columns[0]].size.to_series()]

        return self._expr(
            func,
            depth=0,
            function_name="len",
            evaluate_output_names=lambda _df: ["len"],
            alias_output_names=None,
            backend_version=self._backend_version,
            version=self._version,
        )

    def all_horizontal(self, *exprs: DaskExpr) -> DaskExpr:
        """Row-wise logical AND (`&`) of all outputs of `exprs`."""

        def func(df: DaskLazyFrame) -> list[dx.Series]:
            series = align_series_full_broadcast(
                df, *(s for _expr in exprs for s in _expr(df))
            )
            return [reduce(operator.and_, series)]

        return self._expr(
            call=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="all_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            backend_version=self._backend_version,
            version=self._version,
        )

    def any_horizontal(self, *exprs: DaskExpr) -> DaskExpr:
        """Row-wise logical OR (`|`) of all outputs of `exprs`."""

        def func(df: DaskLazyFrame) -> list[dx.Series]:
            series = align_series_full_broadcast(
                df, *(s for _expr in exprs for s in _expr(df))
            )
            return [reduce(operator.or_, series)]

        return self._expr(
            call=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="any_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            backend_version=self._backend_version,
            version=self._version,
        )

    def sum_horizontal(self, *exprs: DaskExpr) -> DaskExpr:
        """Row-wise sum of all outputs of `exprs`.

        Uses `concat(axis=1).sum(axis=1)` with dask's default `skipna` behaviour.
        """

        def func(df: DaskLazyFrame) -> list[dx.Series]:
            series = align_series_full_broadcast(
                df, *(s for _expr in exprs for s in _expr(df))
            )
            return [dd.concat(series, axis=1).sum(axis=1)]

        return self._expr(
            call=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="sum_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            backend_version=self._backend_version,
            version=self._version,
        )

    def concat(
        self, items: Iterable[DaskLazyFrame], *, how: ConcatMethod
    ) -> DaskLazyFrame:
        """Concatenate lazy frames row-wise.

        Arguments:
            items: Frames to concatenate; must contain at least one.
            how: `"vertical"` requires every frame to have the same column names
                in the same order. `"diagonal"` keeps the union of columns
                (dask `concat` with `join="outer"`).

        Raises:
            AssertionError: If `items` is empty.
            TypeError: If `how="vertical"` and column names differ.
            NotImplementedError: For any other `how`.
        """
        # Materialise first. `items` may be a one-shot iterator, for which
        # `not items` is always False, so emptiness must be checked on the list.
        dfs = [item._native_frame for item in items]
        if not dfs:
            msg = "No items to concatenate"  # pragma: no cover
            raise AssertionError(msg)
        cols_0 = dfs[0].columns
        if how == "vertical":
            for i, df in enumerate(dfs[1:], start=1):
                cols_current = df.columns
                # The length check short-circuits before the element-wise `==`.
                if not (
                    (len(cols_current) == len(cols_0)) and (cols_current == cols_0).all()
                ):
                    msg = (
                        "unable to vstack, column names don't match:\n"
                        f"   - dataframe 0: {cols_0.to_list()}\n"
                        f"   - dataframe {i}: {cols_current.to_list()}\n"
                    )
                    raise TypeError(msg)
            return DaskLazyFrame(
                dd.concat(dfs, axis=0, join="inner"),
                backend_version=self._backend_version,
                version=self._version,
            )
        if how == "diagonal":
            return DaskLazyFrame(
                dd.concat(dfs, axis=0, join="outer"),
                backend_version=self._backend_version,
                version=self._version,
            )

        raise NotImplementedError

    def mean_horizontal(self, *exprs: DaskExpr) -> DaskExpr:
        """Row-wise mean that ignores nulls.

        Computed as the sum of non-null values divided by the count of non-null
        values.
        """

        def func(df: DaskLazyFrame) -> list[dx.Series]:
            expr_results = [s for _expr in exprs for s in _expr(df)]
            # Numerator: nulls count as 0. Denominator: 1 per non-null value.
            series = align_series_full_broadcast(df, *(s.fillna(0) for s in expr_results))
            non_na = align_series_full_broadcast(
                df, *(1 - s.isna() for s in expr_results)
            )
            num = reduce(lambda x, y: x + y, series)  # pyright: ignore[reportOperatorIssue]
            den = reduce(lambda x, y: x + y, non_na)  # pyright: ignore[reportOperatorIssue]
            return [cast("dx.Series", num / den)]  # pyright: ignore[reportOperatorIssue]

        return self._expr(
            call=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="mean_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            backend_version=self._backend_version,
            version=self._version,
        )

    def min_horizontal(self, *exprs: DaskExpr) -> DaskExpr:
        """Row-wise minimum of all outputs of `exprs`."""

        def func(df: DaskLazyFrame) -> list[dx.Series]:
            series = align_series_full_broadcast(
                df, *(s for _expr in exprs for s in _expr(df))
            )

            return [dd.concat(series, axis=1).min(axis=1)]

        return self._expr(
            call=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="min_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            backend_version=self._backend_version,
            version=self._version,
        )

    def max_horizontal(self, *exprs: DaskExpr) -> DaskExpr:
        """Row-wise maximum of all outputs of `exprs`."""

        def func(df: DaskLazyFrame) -> list[dx.Series]:
            series = align_series_full_broadcast(
                df, *(s for _expr in exprs for s in _expr(df))
            )

            return [dd.concat(series, axis=1).max(axis=1)]

        return self._expr(
            call=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="max_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            backend_version=self._backend_version,
            version=self._version,
        )

    def when(self, predicate: DaskExpr) -> DaskWhen:
        """Start a `when/then/otherwise` chain conditioned on `predicate`."""
        return DaskWhen.from_expr(predicate, context=self)

    def concat_str(
        self, *exprs: DaskExpr, separator: str, ignore_nulls: bool
    ) -> DaskExpr:
        """Row-wise string concatenation of the outputs of `exprs`.

        Arguments:
            exprs: Expressions whose values are cast to `str` and joined.
                The output name is taken from the first expression.
            separator: String inserted between consecutive values.
            ignore_nulls: If False, a null anywhere in a row makes that row's
                result null. If True, nulls are skipped entirely: only the
                non-null values are joined. No leading, trailing or doubled
                separators are produced.
        """

        def func(df: DaskLazyFrame) -> list[dx.Series]:
            expr_results = [s for _expr in exprs for s in _expr(df)]
            # Broadcast once and derive both views from the same aligned series.
            aligned = list(align_series_full_broadcast(df, *expr_results))
            null_mask = [s.isna() for s in aligned]
            series = [s.astype(str) for s in aligned]

            if not ignore_nulls:
                any_null = reduce(operator.or_, null_mask)
                joined = reduce(lambda x, y: x + separator + y, series)
                return [joined.where(~any_null, None)]

            # Blank out nulls so they contribute nothing to the concatenation.
            init_value, *values = [
                s.where(~nm, "") for s, nm in zip(series, null_mask)
            ]
            result = init_value
            # Per row: has a non-null value already been written to `result`?
            seen_non_null = ~null_mask[0]
            for value, nm in zip(values, null_mask[1:]):
                # A separator precedes a value only when that value is non-null
                # AND something non-null came before it. Trailing or consecutive
                # nulls therefore never leave stray separators behind.
                needs_separator = seen_non_null & ~nm
                sep = needs_separator.map({True: separator, False: ""}, meta=str)
                result = result + sep + value
                seen_non_null = seen_non_null | ~nm
            return [result]

        return self._expr(
            call=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="concat_str",
            evaluate_output_names=getattr(
                exprs[0], "_evaluate_output_names", lambda _df: ["literal"]
            ),
            alias_output_names=getattr(exprs[0], "_alias_output_names", None),
            backend_version=self._backend_version,
            version=self._version,
        )


class DaskWhen(CompliantWhen[DaskLazyFrame, "dx.Series", DaskExpr]):
    """Dask implementation of the `when` step of a when/then/otherwise chain.

    Evaluates the condition, the `then` value and, optionally, the `otherwise`
    value against a frame. It combines them with `Series.where`.
    """

    @property
    def _then(self) -> type[DaskThen]:
        # Class for the `then` step; presumably consumed by `CompliantWhen`.
        return DaskThen

    def __call__(self, df: DaskLazyFrame) -> Sequence[dx.Series]:
        """Evaluate the conditional expression against `df`.

        Returns a one-element list with the resulting series. Rows where the
        condition holds take the `then` value. Other rows take the `otherwise`
        value, or `Series.where`'s default fill when no `otherwise` was given.
        """
        # Expression values are evaluated to their first output series.
        # Non-expression values (presumably scalars, or None for a missing
        # `otherwise`) are passed through unchanged.
        then_value = (
            self._then_value(df)[0]
            if isinstance(self._then_value, DaskExpr)
            else self._then_value
        )
        otherwise_value = (
            self._otherwise_value(df)[0]
            if isinstance(self._otherwise_value, DaskExpr)
            else self._otherwise_value
        )

        condition = self._condition(df)[0]
        # re-evaluate DataFrame if the condition aggregates to force
        #   then/otherwise to be evaluated against the aggregated frame
        # NOTE(review): then/otherwise were already evaluated against the
        # original `df` above, before this swap. Inside this branch, `condition`
        # is rebuilt via `broadcast(ExprKind.AGGREGATION)` against the ORIGINAL
        # `df`. The one-column frame built from the aggregated condition
        # (`new_df`) then becomes the `df` passed to
        # `align_series_full_broadcast` below. The intent of this ordering is
        # not evident from here; confirm.
        assert self._condition._metadata is not None  # noqa: S101
        if self._condition._metadata.is_scalar_like:
            new_df = df._with_native(condition.to_frame())
            condition = self._condition.broadcast(ExprKind.AGGREGATION)(df)[0]
            df = new_df

        if self._otherwise_value is None:
            # No `otherwise`: rows where `condition` is False get `where`'s
            # default fill value.
            (condition, then_series) = align_series_full_broadcast(
                df, condition, then_value
            )
            # presumably validates that the two series can be combined; see
            # `narwhals._dask.utils.validate_comparand`
            validate_comparand(condition, then_series)
            return [then_series.where(condition)]  # pyright: ignore[reportArgumentType]
        (condition, then_series, otherwise_series) = align_series_full_broadcast(
            df, condition, then_value, otherwise_value
        )
        validate_comparand(condition, then_series)
        validate_comparand(condition, otherwise_series)
        return [then_series.where(condition, otherwise_series)]  # pyright: ignore[reportArgumentType]


class DaskThen(CompliantThen[DaskLazyFrame, "dx.Series", DaskExpr], DaskExpr):
    """Dask `then` step of a when/then/otherwise chain.

    Defines nothing of its own. All behaviour is inherited from
    `CompliantThen` and `DaskExpr`.
    """
