# Utilities for expression parsing
# Useful for backends which don't have any concept of expressions, such
# as pandas or PyArrow.
from __future__ import annotations

from enum import Enum, auto
from itertools import chain
from typing import TYPE_CHECKING, Any, Literal, Sequence, TypeVar, cast

from narwhals._utils import is_compliant_expr
from narwhals.dependencies import is_narwhals_series, is_numpy_array
from narwhals.exceptions import (
    InvalidOperationError,
    LengthChangingExprError,
    MultiOutputExpressionError,
    ShapeError,
)

if TYPE_CHECKING:
    from typing_extensions import Never, TypeIs

    from narwhals._compliant import CompliantExpr, CompliantFrameT
    from narwhals._compliant.typing import (
        AliasNames,
        CompliantExprAny,
        CompliantFrameAny,
        CompliantNamespaceAny,
        EagerNamespaceAny,
        EvalNames,
    )
    from narwhals.expr import Expr
    from narwhals.series import Series
    from narwhals.typing import IntoExpr, NonNestedLiteral, _1DArray

    T = TypeVar("T")


def is_expr(obj: Any) -> TypeIs[Expr]:
    """Return `True` if `obj` is an instance of the Narwhals `Expr` class.

    The import happens at call time (presumably to avoid an import cycle with
    `narwhals.expr`).
    """
    from narwhals.expr import Expr as NarwhalsExpr

    matches = isinstance(obj, NarwhalsExpr)
    return matches


def is_series(obj: Any) -> TypeIs[Series[Any]]:
    """Check whether `obj` is a Narwhals Series."""
    # Imported at call time, mirroring `is_expr` (presumably to avoid an import cycle).
    from narwhals.series import Series

    return isinstance(obj, Series)


def combine_evaluate_output_names(
    *exprs: CompliantExpr[CompliantFrameT, Any],
) -> EvalNames[CompliantFrameT]:
    """Build the output-name evaluator of an n-ary expression.

    Naming follows the left-hand rule: only the first output name of the first
    input is kept, e.g. `nw.sum_horizontal(expr1, expr2)` is named after `expr1`.

    Raises:
        AssertionError: If the first input is not a compliant expression
            (internal safety check; indicates a bug).
    """
    leading = exprs[0]
    if not is_compliant_expr(leading):  # pragma: no cover
        msg = f"Safety assertion failed, expected expression, got: {type(leading)}. Please report a bug."
        raise AssertionError(msg)

    def evaluate_output_names(df: CompliantFrameT) -> Sequence[str]:
        # Keep only the first name produced by the left-most input.
        return leading._evaluate_output_names(df)[:1]

    return evaluate_output_names


def combine_alias_output_names(*exprs: CompliantExprAny) -> AliasNames | None:
    """Derive the aliasing function of an n-ary expression from its first input.

    Left-hand rule: e.g. for `nw.sum_horizontal(expr1.alias(alias), expr2)`, the
    aliasing function of `expr1` is applied and only its first result is kept.

    Returns:
        A callable mapping names to (at most one) alias, or `None` when the
        first input has no aliasing function.
    """
    leading = exprs[0]
    if leading._alias_output_names is not None:

        def alias_output_names(names: Sequence[str]) -> Sequence[str]:
            return leading._alias_output_names(names)[:1]  # type: ignore[misc]

        return alias_output_names
    return None


def extract_compliant(
    plx: CompliantNamespaceAny,
    other: IntoExpr | NonNestedLiteral | _1DArray,
    *,
    str_as_lit: bool,
) -> CompliantExprAny | NonNestedLiteral:
    """Translate `other` into something the compliant namespace `plx` understands.

    - Narwhals `Expr`: converted with `_to_compliant_expr(plx)`.
    - `str` with `str_as_lit=False`: treated as a column name, `plx.col(other)`.
    - Narwhals Series: its compliant series, turned into an expression.
    - NumPy array: wrapped in a compliant series via `plx._series.from_numpy`
      (this assumes `plx` is an eager namespace, hence the cast), then turned
      into an expression.

    Anything else (e.g. a Python scalar, or a `str` with `str_as_lit=True`) is
    returned unchanged.
    """
    if is_expr(other):
        return other._to_compliant_expr(plx)
    if not str_as_lit and isinstance(other, str):
        return plx.col(other)
    if is_narwhals_series(other):
        compliant_series = other._compliant_series
        return compliant_series._to_expr()
    if is_numpy_array(other):
        eager_ns = cast("EagerNamespaceAny", plx)
        wrapped = eager_ns._series.from_numpy(other, context=eager_ns)
        return wrapped._to_expr()
    return other


def evaluate_output_names_and_aliases(
    expr: CompliantExprAny, df: CompliantFrameAny, exclude: Sequence[str]
) -> tuple[Sequence[str], Sequence[str]]:
    """Resolve the output names of `expr` on `df`, together with their aliases.

    Arguments:
        expr: Compliant expression whose output names are evaluated.
        df: Compliant frame the names are evaluated against.
        exclude: Column names to drop. Only applied when the expression's
            expansion kind is multi-unnamed (e.g. `nw.all()`); other
            expressions are returned unfiltered.

    Returns:
        A pair `(output_names, aliases)` of equal length. If `exclude` removes
        every selected column, both sequences are empty.
    """
    output_names = expr._evaluate_output_names(df)
    aliases = (
        output_names
        if expr._alias_output_names is None
        else expr._alias_output_names(output_names)
    )
    if exclude:
        assert expr._metadata is not None  # noqa: S101
        if expr._metadata.expansion_kind.is_multi_unnamed():
            excluded = frozenset(exclude)
            kept = [
                (name, alias)
                for name, alias in zip(output_names, aliases)
                if name not in excluded
            ]
            if not kept:
                # `zip(*[])` yields nothing, so unpacking it into two names
                # would raise ValueError; everything was excluded.
                return (), ()
            output_names, aliases = zip(*kept)
    return output_names, aliases


class ExprKind(Enum):
    """Classify which kind of expression we are dealing with.

    Member values are assigned by `auto()` in declaration order.
    """

    LITERAL = auto()
    """e.g. `nw.lit(1)`"""

    AGGREGATION = auto()
    """Reduces to a single value, not affected by row order, e.g. `nw.col('a').mean()`"""

    ORDERABLE_AGGREGATION = auto()
    """Reduces to a single value, affected by row order, e.g. `nw.col('a').arg_max()`"""

    ELEMENTWISE = auto()
    """Preserves length, can operate without context for surrounding rows, e.g. `nw.col('a').abs()`."""

    ORDERABLE_WINDOW = auto()
    """Depends on the rows around it and on their order, e.g. `diff`."""

    UNORDERABLE_WINDOW = auto()
    """Depends on the rows around it but not on their order, e.g. `rank`."""

    FILTRATION = auto()
    """Changes length, not affected by row order, e.g. `drop_nulls`."""

    ORDERABLE_FILTRATION = auto()
    """Changes length, affected by row order, e.g. `tail`."""

    NARY = auto()
    """Results from the combination of multiple expressions."""

    OVER = auto()
    """Results from calling `.over` on expression."""

    UNKNOWN = auto()
    """Based on the information we have, we can't determine the ExprKind."""

    @property
    def is_scalar_like(self) -> bool:
        """`True` only for `LITERAL` and `AGGREGATION`."""
        return self is ExprKind.LITERAL or self is ExprKind.AGGREGATION

    @property
    def is_orderable_window(self) -> bool:
        """`True` only for `ORDERABLE_WINDOW` and `ORDERABLE_AGGREGATION`."""
        return self is ExprKind.ORDERABLE_WINDOW or self is ExprKind.ORDERABLE_AGGREGATION

    @classmethod
    def from_expr(cls, obj: Expr) -> ExprKind:
        """Infer a coarse kind from the metadata of a Narwhals `Expr`.

        Only `LITERAL`, `AGGREGATION`, `ELEMENTWISE` or `UNKNOWN` can result.
        """
        metadata = obj._metadata
        if metadata.is_literal:
            kind = ExprKind.LITERAL
        elif metadata.is_scalar_like:
            kind = ExprKind.AGGREGATION
        elif metadata.is_elementwise:
            kind = ExprKind.ELEMENTWISE
        else:
            kind = ExprKind.UNKNOWN
        return kind

    @classmethod
    def from_into_expr(
        cls, obj: IntoExpr | NonNestedLiteral | _1DArray, *, str_as_lit: bool
    ) -> ExprKind:
        """Infer the kind of anything that can be turned into an expression.

        Expressions defer to `from_expr`; Series, NumPy arrays and column-name
        strings are elementwise; everything else is a literal.
        """
        if is_expr(obj):
            return cls.from_expr(obj)
        names_a_column = isinstance(obj, str) and not str_as_lit
        column_like = is_narwhals_series(obj) or is_numpy_array(obj) or names_a_column
        return ExprKind.ELEMENTWISE if column_like else ExprKind.LITERAL


def is_scalar_like(
    obj: ExprKind,
) -> TypeIs[Literal[ExprKind.LITERAL, ExprKind.AGGREGATION]]:
    """Narrow `obj` to the scalar-like kinds (`LITERAL` or `AGGREGATION`).

    A thin type-guard wrapper around the `ExprKind.is_scalar_like` property.
    """
    scalar_like = obj.is_scalar_like
    return scalar_like


class ExpansionKind(Enum):
    """Describe how an expression expands into output columns."""

    SINGLE = auto()
    """e.g. `nw.col('a'), nw.sum_horizontal(nw.all())`"""

    MULTI_NAMED = auto()
    """e.g. `nw.col('a', 'b')`"""

    MULTI_UNNAMED = auto()
    """e.g. `nw.all()`, nw.nth(0, 1)"""

    def is_multi_unnamed(self) -> bool:
        """Whether this is `MULTI_UNNAMED`."""
        return self == ExpansionKind.MULTI_UNNAMED

    def is_multi_output(self) -> bool:
        """Whether this kind can produce more than one output column."""
        return self is not ExpansionKind.SINGLE

    def __and__(self, other: ExpansionKind) -> Literal[ExpansionKind.MULTI_UNNAMED]:
        """Combine two expansion kinds; only `MULTI_UNNAMED & MULTI_UNNAMED` is supported.

        e.g. nw.selectors.all() - nw.selectors.numeric(). Any other combination
        is treated as ambiguous and raises `AssertionError`.
        """
        unnamed = ExpansionKind.MULTI_UNNAMED
        if self is not unnamed or other is not unnamed:
            msg = f"Unsupported ExpansionKind combination, got {self} and {other}, please report a bug."  # pragma: no cover
            raise AssertionError(msg)  # pragma: no cover
        return unnamed


class ExprMetadata:
    """Describe an expression's shape, output expansion and ordering needs.

    The `with_*` methods never modify `self`: each validates that one more
    operation may be applied and returns a fresh `ExprMetadata` describing the
    result. The static/class constructors build metadata for leaf expressions
    (aggregations, literals, column selectors) and for combinations of inputs.
    """

    __slots__ = (
        "expansion_kind",
        "has_windows",
        "is_elementwise",
        "is_literal",
        "is_scalar_like",
        "last_node",
        "n_orderable_ops",
        "preserves_length",
    )

    def __init__(
        self,
        expansion_kind: ExpansionKind,
        last_node: ExprKind,
        *,
        has_windows: bool = False,
        n_orderable_ops: int = 0,
        preserves_length: bool = True,
        is_elementwise: bool = True,
        is_scalar_like: bool = False,
        is_literal: bool = False,
    ) -> None:
        # Debug-only invariants: a literal is always scalar-like, and an
        # elementwise expression always preserves length.
        assert not is_literal or is_scalar_like  # noqa: S101
        assert not is_elementwise or preserves_length  # noqa: S101
        self.expansion_kind: ExpansionKind = expansion_kind
        self.last_node: ExprKind = last_node
        self.has_windows: bool = has_windows
        self.n_orderable_ops: int = n_orderable_ops
        self.is_elementwise: bool = is_elementwise
        self.preserves_length: bool = preserves_length
        self.is_scalar_like: bool = is_scalar_like
        self.is_literal: bool = is_literal

    def __init_subclass__(cls, /, *args: Any, **kwds: Any) -> Never:  # pragma: no cover
        msg = f"Cannot subclass {cls.__name__!r}"
        raise TypeError(msg)

    def __repr__(self) -> str:  # pragma: no cover
        fields = (
            ("expansion_kind", self.expansion_kind),
            ("last_node", self.last_node),
            ("has_windows", self.has_windows),
            ("n_orderable_ops", self.n_orderable_ops),
            ("is_elementwise", self.is_elementwise),
            ("preserves_length", self.preserves_length),
            ("is_scalar_like", self.is_scalar_like),
            ("is_literal", self.is_literal),
        )
        body = "".join(f"  {name}: {value},\n" for name, value in fields)
        return f"ExprMetadata(\n{body})"

    @property
    def is_filtration(self) -> bool:
        """Whether the expression changes length without reducing to a scalar."""
        return not (self.preserves_length or self.is_scalar_like)

    def _ensure_not_scalar_like(self, msg: str) -> None:
        # Shared guard: aggregations, windows and filtrations need a
        # non-scalar-like input.
        if self.is_scalar_like:
            raise InvalidOperationError(msg)

    def _ensure_over_is_allowed(self) -> None:
        # Shared guard for both flavours of `over`.
        if self.has_windows:
            msg = "Cannot nest `over` statements."
            raise InvalidOperationError(msg)
        if self.is_elementwise or self.is_filtration:
            msg = (
                "Cannot use `over` on expressions which are elementwise\n"
                "(e.g. `abs`) or which change length (e.g. `drop_nulls`)."
            )
            raise InvalidOperationError(msg)

    def with_aggregation(self) -> ExprMetadata:
        self._ensure_not_scalar_like(
            "Can't apply aggregations to scalar-like expressions."
        )
        return ExprMetadata(
            self.expansion_kind,
            ExprKind.AGGREGATION,
            is_literal=False,
            is_scalar_like=True,
            is_elementwise=False,
            preserves_length=False,
            n_orderable_ops=self.n_orderable_ops,
            has_windows=self.has_windows,
        )

    def with_orderable_aggregation(self) -> ExprMetadata:
        self._ensure_not_scalar_like(
            "Can't apply aggregations to scalar-like expressions."
        )
        return ExprMetadata(
            self.expansion_kind,
            ExprKind.ORDERABLE_AGGREGATION,
            is_literal=False,
            is_scalar_like=True,
            is_elementwise=False,
            preserves_length=False,
            n_orderable_ops=self.n_orderable_ops + 1,
            has_windows=self.has_windows,
        )

    def with_elementwise_op(self) -> ExprMetadata:
        # Elementwise ops keep every property of the input; only the last node changes.
        return ExprMetadata(
            self.expansion_kind,
            ExprKind.ELEMENTWISE,
            is_literal=self.is_literal,
            is_scalar_like=self.is_scalar_like,
            is_elementwise=self.is_elementwise,
            preserves_length=self.preserves_length,
            n_orderable_ops=self.n_orderable_ops,
            has_windows=self.has_windows,
        )

    def with_unorderable_window(self) -> ExprMetadata:
        self._ensure_not_scalar_like(
            "Can't apply unorderable window (`rank`, `is_unique`) to scalar-like expression."
        )
        return ExprMetadata(
            self.expansion_kind,
            ExprKind.UNORDERABLE_WINDOW,
            is_literal=False,
            is_scalar_like=False,
            is_elementwise=False,
            preserves_length=self.preserves_length,
            n_orderable_ops=self.n_orderable_ops,
            has_windows=self.has_windows,
        )

    def with_orderable_window(self) -> ExprMetadata:
        self._ensure_not_scalar_like(
            "Can't apply orderable window (e.g. `diff`, `shift`) to scalar-like expression."
        )
        return ExprMetadata(
            self.expansion_kind,
            ExprKind.ORDERABLE_WINDOW,
            is_literal=False,
            is_scalar_like=False,
            is_elementwise=False,
            preserves_length=self.preserves_length,
            n_orderable_ops=self.n_orderable_ops + 1,
            has_windows=self.has_windows,
        )

    def with_ordered_over(self) -> ExprMetadata:
        self._ensure_over_is_allowed()
        if not self.n_orderable_ops:
            msg = "Cannot use `order_by` in `over` on expression which isn't orderable."
            raise InvalidOperationError(msg)
        # When the last node is an orderable window/aggregation, it no longer
        # counts towards the remaining orderable ops once `order_by` is given.
        remaining_orderable_ops = (
            self.n_orderable_ops - 1
            if self.last_node.is_orderable_window
            else self.n_orderable_ops
        )
        return ExprMetadata(
            self.expansion_kind,
            ExprKind.OVER,
            is_literal=False,
            is_scalar_like=False,
            is_elementwise=False,
            preserves_length=True,
            n_orderable_ops=remaining_orderable_ops,
            has_windows=True,
        )

    def with_partitioned_over(self) -> ExprMetadata:
        self._ensure_over_is_allowed()
        return ExprMetadata(
            self.expansion_kind,
            ExprKind.OVER,
            is_literal=False,
            is_scalar_like=False,
            is_elementwise=False,
            preserves_length=True,
            n_orderable_ops=self.n_orderable_ops,
            has_windows=True,
        )

    def with_filtration(self) -> ExprMetadata:
        self._ensure_not_scalar_like(
            "Can't apply filtration (e.g. `drop_nulls`) to scalar-like expression."
        )
        return ExprMetadata(
            self.expansion_kind,
            ExprKind.FILTRATION,
            is_literal=False,
            is_scalar_like=False,
            is_elementwise=False,
            preserves_length=False,
            n_orderable_ops=self.n_orderable_ops,
            has_windows=self.has_windows,
        )

    def with_orderable_filtration(self) -> ExprMetadata:
        self._ensure_not_scalar_like(
            "Can't apply filtration (e.g. `drop_nulls`) to scalar-like expression."
        )
        return ExprMetadata(
            self.expansion_kind,
            ExprKind.ORDERABLE_FILTRATION,
            is_literal=False,
            is_scalar_like=False,
            is_elementwise=False,
            preserves_length=False,
            n_orderable_ops=self.n_orderable_ops + 1,
            has_windows=self.has_windows,
        )

    @staticmethod
    def aggregation() -> ExprMetadata:
        # Single-output aggregation that is not built on top of another expression.
        return ExprMetadata(
            ExpansionKind.SINGLE,
            ExprKind.AGGREGATION,
            is_scalar_like=True,
            preserves_length=False,
            is_elementwise=False,
        )

    @staticmethod
    def literal() -> ExprMetadata:
        # Single-output literal value.
        return ExprMetadata(
            ExpansionKind.SINGLE,
            ExprKind.LITERAL,
            is_literal=True,
            is_scalar_like=True,
            preserves_length=False,
            is_elementwise=False,
        )

    @staticmethod
    def selector_single() -> ExprMetadata:
        # e.g. `nw.col('a')`, `nw.nth(0)`
        return ExprMetadata(ExpansionKind.SINGLE, ExprKind.ELEMENTWISE)

    @staticmethod
    def selector_multi_named() -> ExprMetadata:
        # e.g. `nw.col('a', 'b')`
        return ExprMetadata(ExpansionKind.MULTI_NAMED, ExprKind.ELEMENTWISE)

    @staticmethod
    def selector_multi_unnamed() -> ExprMetadata:
        # e.g. `nw.all()`
        return ExprMetadata(ExpansionKind.MULTI_UNNAMED, ExprKind.ELEMENTWISE)

    @classmethod
    def from_binary_op(cls, lhs: Expr, rhs: IntoExpr, /) -> ExprMetadata:
        # Only `lhs` may be multi-output for now; allowing a multi-output `rhs`
        # is tracked in https://github.com/narwhals-dev/narwhals/issues/2244.
        return combine_metadata(
            lhs,
            rhs,
            str_as_lit=True,
            to_single_output=False,
            allow_multi_output=False,
        )

    @classmethod
    def from_horizontal_op(cls, *exprs: IntoExpr) -> ExprMetadata:
        # Horizontal ops accept multi-output inputs but always produce one column.
        return combine_metadata(
            *exprs,
            str_as_lit=False,
            to_single_output=True,
            allow_multi_output=True,
        )


def combine_metadata(  # noqa: C901, PLR0912
    *args: IntoExpr | object | None,
    str_as_lit: bool,
    allow_multi_output: bool,
    to_single_output: bool,
) -> ExprMetadata:
    """Combine metadata from `args`.

    Arguments:
        args: Arguments, maybe expressions, literals, or Series.
        str_as_lit: Whether to interpret strings as literals or as column names.
        allow_multi_output: Whether to allow multi-output inputs after the first
            one (the left-most argument may always be multi-output).
        to_single_output: Whether the result is always single-output, regardless
            of the inputs (e.g. `nw.sum_horizontal`).

    Returns:
        Metadata with last node `ExprKind.NARY`. The result:

        - preserves length if any expression input does, or if a column name
          or Series is given;
        - is elementwise if every expression input is;
        - is scalar-like / literal if every expression input is and no column
          name or Series is given;
        - has windows if any expression input does, and sums their orderable ops.

    Raises:
        MultiOutputExpressionError: If a non-leading argument is multi-output
            while `allow_multi_output` is False.
        LengthChangingExprError: If more than one input is a filtration.
        ShapeError: If a filtration is combined with a length-preserving input.
    """
    filtration_count = 0
    expansion = ExpansionKind.SINGLE
    any_windows = False
    orderable_ops = 0
    any_preserves_length = False
    all_elementwise = True
    all_scalar_like = True
    all_literal = True

    for position, arg in enumerate(args):
        if (isinstance(arg, str) and not str_as_lit) or is_series(arg):
            # Column names and Series behave like full-length columns.
            any_preserves_length = True
            all_scalar_like = False
            all_literal = False
            continue
        if not is_expr(arg):
            # Plain literal values contribute nothing to the combined metadata.
            continue

        meta = arg._metadata
        arg_expansion = meta.expansion_kind
        if arg_expansion.is_multi_output():
            if position > 0 and not allow_multi_output:
                # Left-most argument is always allowed to be multi-output.
                msg = (
                    "Multi-output expressions (e.g. nw.col('a', 'b'), nw.all()) "
                    "are not supported in this context."
                )
                raise MultiOutputExpressionError(msg)
            if not to_single_output:
                expansion = arg_expansion if position == 0 else expansion & arg_expansion

        if meta.has_windows:
            any_windows = True
        orderable_ops += meta.n_orderable_ops
        if meta.preserves_length:
            any_preserves_length = True
        if not meta.is_elementwise:
            all_elementwise = False
        if not meta.is_scalar_like:
            all_scalar_like = False
        if not meta.is_literal:
            all_literal = False
        if meta.is_filtration:
            filtration_count += 1

    if filtration_count > 1:
        msg = "Length-changing expressions can only be used in isolation, or followed by an aggregation"
        raise LengthChangingExprError(msg)
    if filtration_count and any_preserves_length:
        msg = "Cannot combine length-changing expressions with length-preserving ones or aggregations"
        raise ShapeError(msg)

    return ExprMetadata(
        expansion,
        ExprKind.NARY,
        has_windows=any_windows,
        n_orderable_ops=orderable_ops,
        preserves_length=any_preserves_length,
        is_elementwise=all_elementwise,
        is_scalar_like=all_scalar_like,
        is_literal=all_literal,
    )


def check_expressions_preserve_length(*args: IntoExpr, function_name: str) -> None:
    """Raise `ShapeError` unless every argument in `args` is length-preserving.

    Column names (`str`) and Series are accepted without inspection: this check
    runs lazily and cannot evaluate lengths, so any mismatch involving them is
    left to be detected later.
    """
    from narwhals.series import Series

    for arg in args:
        if isinstance(arg, (str, Series)):
            continue
        if is_expr(arg) and arg._metadata.preserves_length:
            continue
        msg = f"Expressions which aggregate or change length cannot be passed to '{function_name}'."
        raise ShapeError(msg)


def all_exprs_are_scalar_like(*args: IntoExpr, **kwargs: IntoExpr) -> bool:
    """Return whether every positional and keyword argument is a scalar-like `Expr`.

    Scalar-like means the expression's metadata has `is_scalar_like` set
    (an aggregation or a literal). This function never raises: any argument
    that is not a Narwhals `Expr` (including column names and Series) makes
    the result `False`. With no arguments at all, the result is `True`.
    """
    exprs = chain(args, kwargs.values())
    return all(is_expr(x) and x._metadata.is_scalar_like for x in exprs)


def apply_n_ary_operation(
    plx: CompliantNamespaceAny,
    function: Any,
    *comparands: IntoExpr | NonNestedLiteral | _1DArray,
    str_as_lit: bool,
) -> CompliantExprAny:
    """Convert `comparands` for namespace `plx` and pass them to `function`.

    If at least one comparand is not scalar-like, every scalar-like compliant
    expression is broadcast (via `.broadcast(kind)`) before `function` is
    called. Values that are not compliant expressions are passed through as-is.
    """
    kinds = [
        ExprKind.from_into_expr(comparand, str_as_lit=str_as_lit)
        for comparand in comparands
    ]
    needs_broadcast = not all(kind.is_scalar_like for kind in kinds)

    def prepare(
        comparand: IntoExpr | NonNestedLiteral | _1DArray, kind: ExprKind
    ) -> CompliantExprAny | NonNestedLiteral:
        # Convert one comparand, broadcasting it if it is a scalar-like expression
        # mixed with at least one non-scalar-like comparand.
        compliant = extract_compliant(plx, comparand, str_as_lit=str_as_lit)
        if needs_broadcast and is_compliant_expr(compliant) and is_scalar_like(kind):
            return compliant.broadcast(kind)
        return compliant

    prepared = [prepare(comparand, kind) for comparand, kind in zip(comparands, kinds)]
    return function(*prepared)
