from __future__ import annotations

from functools import lru_cache
from typing import TYPE_CHECKING, Any

import duckdb

from narwhals._utils import Version, isinstance_or_issubclass

if TYPE_CHECKING:
    from duckdb import DuckDBPyRelation, Expression
    from duckdb.typing import DuckDBPyType

    from narwhals._duckdb.dataframe import DuckDBLazyFrame
    from narwhals._duckdb.expr import DuckDBExpr
    from narwhals.dtypes import DType
    from narwhals.typing import IntoDType

UNITS_DICT = {
    "y": "year",
    "q": "quarter",
    "mo": "month",
    "d": "day",
    "h": "hour",
    "m": "minute",
    "s": "second",
    "ms": "millisecond",
    "us": "microsecond",
    "ns": "nanosecond",
}
"""Mapping from narwhals interval-unit abbreviations to DuckDB date part names."""

col = duckdb.ColumnExpression
"""Alias for `duckdb.ColumnExpression`."""

lit = duckdb.ConstantExpression
"""Alias for `duckdb.ConstantExpression`."""

when = duckdb.CaseExpression
"""Alias for `duckdb.CaseExpression`."""


def concat_str(*exprs: Expression, separator: str = "") -> Expression:
    """Concatenate many strings, NULL inputs are skipped.

    Wraps [concat] and [concat_ws] `FunctionExpression`(s).

    Arguments:
        exprs: Native columns.
        separator: String that will be used to separate the values of each column.

    Returns:
        A new native expression.

    [concat]: https://duckdb.org/docs/stable/sql/functions/char.html#concatstring-
    [concat_ws]: https://duckdb.org/docs/stable/sql/functions/char.html#concat_wsseparator-string-
    """
    return (
        duckdb.FunctionExpression("concat_ws", lit(separator), *exprs)
        if separator
        else duckdb.FunctionExpression("concat", *exprs)
    )
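
# A minimal usage sketch of `concat_str` (illustrative; assumes a relation `rel`
# with string columns "a" and "b"):
#
#     rel.select(concat_str(col("a"), col("b"), separator="-").alias("joined"))
#
# Both `concat` and `concat_ws` skip NULL inputs, unlike the `||` operator,
# which propagates NULL.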


def evaluate_exprs(
    df: DuckDBLazyFrame, /, *exprs: DuckDBExpr
) -> list[tuple[str, Expression]]:
    """Evaluate `exprs` against `df`, returning (name, native expression) pairs."""
    native_results: list[tuple[str, Expression]] = []
    for expr in exprs:
        native_series_list = expr._call(df)
        output_names = expr._evaluate_output_names(df)
        if expr._alias_output_names is not None:
            output_names = expr._alias_output_names(output_names)
        if len(output_names) != len(native_series_list):  # pragma: no cover
            msg = f"Internal error: expected {len(output_names)} results (output names: {output_names}), but got {len(native_series_list)}"
            raise AssertionError(msg)
        native_results.extend(zip(output_names, native_series_list))
    return native_results
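
# The pairs returned by `evaluate_exprs` can be fed straight into a projection,
# roughly as follows (illustrative only, not the exact call sites in this package):
#
#     rel.select(*(native.alias(name) for name, native in evaluate_exprs(df, *exprs)))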


class DeferredTimeZone:
    """Object which gets passed between `native_to_narwhals_dtype` calls.

    DuckDB stores the time zone in the connection, rather than in the dtypes, so
    this ensures that when calculating the schema of a dataframe with multiple
    timezone-aware columns, the connection's time zone is only fetched once.

    Note: we cannot make the time zone a cached `DuckDBLazyFrame` property because
    the time zone can be modified after `DuckDBLazyFrame` creation:

    ```python
    df = nw.from_native(rel)
    print(df.collect_schema())
    rel.query("set timezone = 'Asia/Kolkata'")
    print(df.collect_schema())  # should change to reflect new time zone
    ```
    """

    _cached_time_zone: str | None = None

    def __init__(self, rel: DuckDBPyRelation) -> None:
        self._rel = rel

    @property
    def time_zone(self) -> str:
        """Fetch relation time zone (if it wasn't calculated already)."""
        if self._cached_time_zone is None:
            self._cached_time_zone = fetch_rel_time_zone(self._rel)
        return self._cached_time_zone


def native_to_narwhals_dtype(
    duckdb_dtype: DuckDBPyType, version: Version, deferred_time_zone: DeferredTimeZone
) -> DType:
    """Convert a native DuckDB type to the corresponding narwhals dtype."""
    duckdb_dtype_id = duckdb_dtype.id
    dtypes = version.dtypes

    # Handle nested data types first
    if duckdb_dtype_id == "list":
        return dtypes.List(
            native_to_narwhals_dtype(duckdb_dtype.child, version, deferred_time_zone)
        )

    if duckdb_dtype_id == "struct":
        children = duckdb_dtype.children
        return dtypes.Struct(
            [
                dtypes.Field(
                    name=child[0],
                    dtype=native_to_narwhals_dtype(child[1], version, deferred_time_zone),
                )
                for child in children
            ]
        )

    if duckdb_dtype_id == "array":
        child, size = duckdb_dtype.children
        shape: list[int] = [size[1]]

        while child[1].id == "array":
            child, size = child[1].children
            shape.insert(0, size[1])

        inner = native_to_narwhals_dtype(child[1], version, deferred_time_zone)
        return dtypes.Array(inner=inner, shape=tuple(shape))

    if duckdb_dtype_id == "enum":
        if version is Version.V1:
            return dtypes.Enum()  # type: ignore[call-arg]
        categories = duckdb_dtype.children[0][1]
        return dtypes.Enum(categories=categories)

    if duckdb_dtype_id == "timestamp with time zone":
        return dtypes.Datetime(time_zone=deferred_time_zone.time_zone)

    return _non_nested_native_to_narwhals_dtype(duckdb_dtype_id, version)
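
# Sketch of how `DeferredTimeZone` and `native_to_narwhals_dtype` fit together:
# map every column of a relation to a narwhals dtype while fetching the
# connection's time zone at most once (illustrative; `Version.MAIN` is the
# non-stable narwhals namespace):
#
#     deferred_time_zone = DeferredTimeZone(rel)
#     schema = {
#         name: native_to_narwhals_dtype(dtype, Version.MAIN, deferred_time_zone)
#         for name, dtype in zip(rel.columns, rel.types)
#     }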


def fetch_rel_time_zone(rel: duckdb.DuckDBPyRelation) -> str:
    """Fetch the connection-level `TimeZone` setting for `rel`."""
    result = rel.query(
        "duckdb_settings()", "select value from duckdb_settings() where name = 'TimeZone'"
    ).fetchone()
    assert result is not None  # noqa: S101
    return result[0]  # type: ignore[no-any-return]


@lru_cache(maxsize=16)
def _non_nested_native_to_narwhals_dtype(duckdb_dtype_id: str, version: Version) -> DType:
    dtypes = version.dtypes
    return {
        "hugeint": dtypes.Int128(),
        "bigint": dtypes.Int64(),
        "integer": dtypes.Int32(),
        "smallint": dtypes.Int16(),
        "tinyint": dtypes.Int8(),
        "uhugeint": dtypes.UInt128(),
        "ubigint": dtypes.UInt64(),
        "uinteger": dtypes.UInt32(),
        "usmallint": dtypes.UInt16(),
        "utinyint": dtypes.UInt8(),
        "double": dtypes.Float64(),
        "float": dtypes.Float32(),
        "varchar": dtypes.String(),
        "date": dtypes.Date(),
        "timestamp": dtypes.Datetime(),
        "boolean": dtypes.Boolean(),
        "interval": dtypes.Duration(),
        "decimal": dtypes.Decimal(),
        "time": dtypes.Time(),
        "blob": dtypes.Binary(),
    }.get(duckdb_dtype_id, dtypes.Unknown())


def narwhals_to_native_dtype(dtype: IntoDType, version: Version) -> str:  # noqa: C901, PLR0912, PLR0915
    """Convert a narwhals dtype to its DuckDB SQL type string."""
    dtypes = version.dtypes
    if isinstance_or_issubclass(dtype, dtypes.Decimal):
        msg = "Casting to Decimal is not supported yet."
        raise NotImplementedError(msg)
    if isinstance_or_issubclass(dtype, dtypes.Float64):
        return "DOUBLE"
    if isinstance_or_issubclass(dtype, dtypes.Float32):
        return "FLOAT"
    if isinstance_or_issubclass(dtype, dtypes.Int128):
        return "INT128"
    if isinstance_or_issubclass(dtype, dtypes.Int64):
        return "BIGINT"
    if isinstance_or_issubclass(dtype, dtypes.Int32):
        return "INTEGER"
    if isinstance_or_issubclass(dtype, dtypes.Int16):
        return "SMALLINT"
    if isinstance_or_issubclass(dtype, dtypes.Int8):
        return "TINYINT"
    if isinstance_or_issubclass(dtype, dtypes.UInt128):
        return "UINT128"
    if isinstance_or_issubclass(dtype, dtypes.UInt64):
        return "UBIGINT"
    if isinstance_or_issubclass(dtype, dtypes.UInt32):
        return "UINTEGER"
    if isinstance_or_issubclass(dtype, dtypes.UInt16):  # pragma: no cover
        return "USMALLINT"
    if isinstance_or_issubclass(dtype, dtypes.UInt8):  # pragma: no cover
        return "UTINYINT"
    if isinstance_or_issubclass(dtype, dtypes.String):
        return "VARCHAR"
    if isinstance_or_issubclass(dtype, dtypes.Boolean):  # pragma: no cover
        return "BOOLEAN"
    if isinstance_or_issubclass(dtype, dtypes.Time):
        return "TIME"
    if isinstance_or_issubclass(dtype, dtypes.Binary):
        return "BLOB"
    if isinstance_or_issubclass(dtype, dtypes.Categorical):
        msg = "Categorical not supported by DuckDB"
        raise NotImplementedError(msg)
    if isinstance_or_issubclass(dtype, dtypes.Enum):
        if version is Version.V1:
            msg = "Converting to Enum is not supported in narwhals.stable.v1"
            raise NotImplementedError(msg)
        if isinstance(dtype, dtypes.Enum):
            categories = "'" + "', '".join(dtype.categories) + "'"
            return f"ENUM ({categories})"
        msg = "Can not cast / initialize Enum without categories present"
        raise ValueError(msg)

    if isinstance_or_issubclass(dtype, dtypes.Datetime):
        _time_unit = dtype.time_unit
        _time_zone = dtype.time_zone
        msg = "todo"
        raise NotImplementedError(msg)
    if isinstance_or_issubclass(dtype, dtypes.Duration):  # pragma: no cover
        _time_unit = dtype.time_unit
        msg = "todo"
        raise NotImplementedError(msg)
    if isinstance_or_issubclass(dtype, dtypes.Date):  # pragma: no cover
        return "DATE"
    if isinstance_or_issubclass(dtype, dtypes.List):
        inner = narwhals_to_native_dtype(dtype.inner, version)
        return f"{inner}[]"
    if isinstance_or_issubclass(dtype, dtypes.Struct):  # pragma: no cover
        inner = ", ".join(
            f'"{field.name}" {narwhals_to_native_dtype(field.dtype, version)}'
            for field in dtype.fields
        )
        return f"STRUCT({inner})"
    if isinstance_or_issubclass(dtype, dtypes.Array):  # pragma: no cover
        shape = dtype.shape
        duckdb_shape_fmt = "".join(f"[{item}]" for item in shape)
        inner_dtype: Any = dtype
        for _ in shape:
            inner_dtype = inner_dtype.inner
        duckdb_inner = narwhals_to_native_dtype(inner_dtype, version)
        return f"{duckdb_inner}{duckdb_shape_fmt}"
    msg = f"Unknown dtype: {dtype}"  # pragma: no cover
    raise AssertionError(msg)
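
# Example conversions (illustrative; `nw` is the main narwhals namespace):
#
#     narwhals_to_native_dtype(nw.Int64(), Version.MAIN)            # "BIGINT"
#     narwhals_to_native_dtype(nw.List(nw.String()), Version.MAIN)  # "VARCHAR[]"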


def generate_partition_by_sql(*partition_by: str | Expression) -> str:
    """Return a `partition by` clause, or an empty string if no keys are given."""
    if not partition_by:
        return ""
    by_sql = ", ".join([f"{col(x) if isinstance(x, str) else x}" for x in partition_by])
    return f"partition by {by_sql}"


def generate_order_by_sql(*order_by: str, ascending: bool) -> str:
    """Return an `order by` clause.

    Ascending sorts use `nulls first`; descending sorts use `nulls last`.
    """
    if ascending:
        by_sql = ", ".join([f"{col(x)} asc nulls first" for x in order_by])
    else:
        by_sql = ", ".join([f"{col(x)} desc nulls last" for x in order_by])
    return f"order by {by_sql}"
