from __future__ import annotations

import functools
import re
from contextlib import suppress
from inspect import isclass
from typing import TYPE_CHECKING, Any

from polars.datatypes import (
    Binary,
    Boolean,
    Date,
    Datetime,
    Decimal,
    Duration,
    Float32,
    Float64,
    Int8,
    Int16,
    Int32,
    Int64,
    Int128,
    List,
    Null,
    String,
    Time,
    UInt8,
    UInt16,
    UInt32,
    UInt64,
)
from polars.datatypes._parse import parse_py_type_into_dtype
from polars.datatypes.group import (
    INTEGER_DTYPES,
    UNSIGNED_INTEGER_DTYPES,
)

if TYPE_CHECKING:
    from polars._typing import PolarsDataType


def dtype_from_database_typename(
    value: str,
    *,
    raise_unmatched: bool = True,
) -> PolarsDataType | None:
    """
    Attempt to infer Polars dtype from database cursor `type_code` string value.

    The name is upper-cased and stripped of any "TYPE" substring, an optional
    trailing modifier in parentheses or square brackets is split off, and the
    remaining name is matched against groups of known aliases (array, float,
    integer, decimal, string, binary, boolean, null, temporal).

    Parameters
    ----------
    value
        Database type name, optionally carrying a modifier, eg: "VARCHAR(64)",
        "NUMERIC(10,2)" or "DATETIME[ns]".
    raise_unmatched
        If True (the default), raise when no dtype can be inferred; if False,
        return None instead.

    Returns
    -------
    PolarsDataType or None
        The inferred dtype. None is returned when nothing matched and
        `raise_unmatched` is False, and also for datetime/timestamp names that
        mention a timezone without saying "WITHOUT" (the zone is unknown).

    Raises
    ------
    ValueError
        If no dtype could be inferred and `raise_unmatched` is True.

    Examples
    --------
    >>> dtype_from_database_typename("INT2")
    Int16
    >>> dtype_from_database_typename("NVARCHAR")
    String
    >>> dtype_from_database_typename("NUMERIC(10,2)")
    Decimal(precision=10, scale=2)
    >>> dtype_from_database_typename("TIMESTAMP WITHOUT TZ")
    Datetime(time_unit='us', time_zone=None)
    """
    dtype: PolarsDataType | None = None

    # normalise string name/case (eg: 'IntegerType' -> 'INTEGER')
    original_value = value
    value = value.upper().replace("TYPE", "")

    # extract optional type modifier (eg: 'VARCHAR(64)' -> '64')
    if re.search(r"\([\w,: ]+\)$", value):
        modifier = value[value.find("(") + 1 : -1]
        value = value.split("(")[0]
    elif (
        not value.startswith(("<", ">")) and re.search(r"\[[\w,\]\[: ]+]$", value)
    ) or value.endswith(("[S]", "[MS]", "[US]", "[NS]")):
        modifier = value[value.find("[") + 1 : -1]
        value = value.split("[")[0]
    else:
        modifier = ""

    # array dtypes
    array_aliases = ("ARRAY", "LIST", "[]")
    if value.endswith(array_aliases) or value.startswith(array_aliases):
        # drop (the first occurrence of) each alias, leaving the element type
        for a in array_aliases:
            value = value.replace(a, "", 1) if value else ""

        nested: PolarsDataType | None = None
        if value:
            # the element type is either wrapped in angle brackets ('<INT>')
            # or left over as a bare name once an "OF" connective and any
            # non-word characters have been stripped (' OF INT' -> 'INT').
            # note: only index into `value` when it is non-empty; a bare
            # alias such as 'ARRAY' leaves nothing behind to inspect.
            if inner_value := dtype_from_database_typename(
                value[1:-1]
                if (value[0], value[-1]) == ("<", ">")
                else re.sub(r"\W", "", re.sub(r"\WOF\W", "", value)),
                raise_unmatched=False,
            ):
                nested = inner_value

        # otherwise fall back to the modifier as the element type
        if nested is None and modifier:
            nested = dtype_from_database_typename(
                value=modifier,
                raise_unmatched=False,
            )
        if nested:
            dtype = List(nested)

    # float dtypes
    elif value.startswith("FLOAT") or ("DOUBLE" in value) or (value == "REAL"):
        dtype = (
            Float32
            if value == "FLOAT4"
            or (value.endswith(("16", "32")) or (modifier in ("16", "32")))
            else Float64
        )

    # integer dtypes
    elif ("INTERVAL" not in value) and (
        value.startswith(("INT", "UINT", "UNSIGNED"))
        or value.endswith(("INT", "SERIAL"))
        or ("INTEGER" in value)
        or value in ("TINY", "SHORT", "LONG", "LONGLONG", "ROWID")
    ):
        # determine the bit-width from the name; `sz` may end up None
        sz: Any
        if "HUGEINT" in value:
            sz = 128
        elif (
            "LARGE" in value or value.startswith("BIG") or value in ("INT8", "LONGLONG")
        ):
            sz = 64
        elif "MEDIUM" in value or value in ("INT4", "UINT4", "LONG", "SERIAL"):
            sz = 32
        elif "SMALL" in value or value in ("INT2", "UINT2", "SHORT"):
            sz = 16
        elif "TINY" in value:
            sz = 8
        elif n := re.sub(r"^\D+", "", value):
            # numeric suffix: values <= 8 are taken as a byte count and
            # converted to bits; larger values are used as-is
            if (sz := int(n)) <= 8:
                sz = sz * 8
        else:
            sz = None

        # with no size from the name, try the modifier (digits only)
        sz = modifier if (not sz and modifier) else sz
        if not isinstance(sz, int):
            sz = int(sz) if isinstance(sz, str) and sz.isdigit() else None
        if (
            ("U" in value and "MEDIUM" not in value)
            or ("UNSIGNED" in value)
            or value == "ROWID"
        ):
            dtype = integer_dtype_from_nbits(sz, unsigned=True, default=UInt64)
        else:
            dtype = integer_dtype_from_nbits(sz, unsigned=False, default=Int64)

    # number types (note: 'number' alone is not that helpful and requires refinement)
    elif "NUMBER" in value and "CARDINAL" in value:
        dtype = UInt64

    # decimal dtypes
    elif (is_dec := ("DECIMAL" in value)) or ("NUMERIC" in value):
        if "," in modifier:
            prec, scale = modifier.split(",")
            dtype = Decimal(int(prec), int(scale))
        else:
            # without an explicit precision/scale: DECIMAL stays Decimal,
            # while a bare NUMERIC is mapped to Float64
            dtype = Decimal if is_dec else Float64

    # string dtypes
    elif (
        any(tp in value for tp in ("VARCHAR", "STRING", "TEXT", "UNICODE"))
        or value.startswith(("STR", "CHAR", "BPCHAR", "NCHAR", "UTF"))
        or value.endswith(("_UTF8", "_UTF16", "_UTF32"))
    ):
        dtype = String

    # binary dtypes
    elif value in ("BYTEA", "BYTES", "BLOB", "CLOB", "BINARY"):
        dtype = Binary

    # boolean dtypes
    elif value.startswith("BOOL"):
        dtype = Boolean

    # null dtype; odd, but valid
    elif value == "NULL":
        dtype = Null

    # temporal dtypes
    elif value.startswith(("DATETIME", "TIMESTAMP")) and not (value.endswith("[D]")):
        if any((tz in value.replace(" ", "")) for tz in ("TZ", "TIMEZONE")):
            if "WITHOUT" not in value:
                return None  # there's a timezone, but we don't know what it is
        # the modifier (if any) carries the precision or unit; default is 'us'
        unit = timeunit_from_precision(modifier) if modifier else "us"
        dtype = Datetime(time_unit=(unit or "us"))  # type: ignore[arg-type]
    else:
        # ignore any digits in the name before matching the remaining aliases
        value = re.sub(r"\d", "", value)
        if value in ("INTERVAL", "TIMEDELTA", "DURATION"):
            dtype = Duration
        elif value == "DATE":
            dtype = Date
        elif value == "TIME":
            dtype = Time

    if not dtype and raise_unmatched:
        msg = f"cannot infer dtype from {original_value!r} string value"
        raise ValueError(msg)

    return dtype


def dtype_from_cursor_description(
    cursor: Any,
    description: tuple[Any, ...],
) -> PolarsDataType | None:
    """
    Attempt to infer Polars dtype from database cursor description `type_code`.

    Parameters
    ----------
    cursor
        The originating cursor; not referenced by this function's body.
    description
        A single column description; the first five fields are read as
        `(type_code, display_size, internal_size, precision, scale)` and
        any further fields are ignored.

    Returns
    -------
    PolarsDataType or None
        The inferred (and, where the size/precision fields allow it, refined)
        dtype, or None if `type_code` is neither a class nor a string, or no
        dtype could be derived from it.
    """
    type_code, _disp_size, internal_size, precision, scale, *_ = description

    inferred: PolarsDataType | None = None
    if isclass(type_code):
        # python types, eg: int, float, str, etc
        with suppress(TypeError):
            inferred = parse_py_type_into_dtype(type_code)  # type: ignore[arg-type]
    elif isinstance(type_code, str):
        # database/sql type names, eg: "VARCHAR", "NUMERIC", "BLOB", etc
        inferred = dtype_from_database_typename(
            value=type_code,
            raise_unmatched=False,
        )

    # nothing to refine if no dtype was inferred
    if inferred is None:
        return None

    # a 4-byte internal size narrows a 64-bit float to 32 bits
    if inferred == Float64 and internal_size == 4:
        return Float32

    # use the byte-width to pick the matching integer dtype, keeping signedness
    if inferred in INTEGER_DTYPES and internal_size in (2, 4, 8):
        return integer_dtype_from_nbits(
            internal_size * 8,
            unsigned=(inferred in UNSIGNED_INTEGER_DTYPES),
            default=inferred,
        )

    # attach precision/scale to an unparametrised decimal when both are usable
    has_usable_precision_and_scale = (
        inferred == Decimal
        and isinstance(precision, int)
        and isinstance(scale, int)
        and precision <= 38
        and scale <= 38
    )
    if has_usable_precision_and_scale:
        return Decimal(precision, scale)

    return inferred


@functools.lru_cache(8)
def integer_dtype_from_nbits(
    bits: int,
    *,
    unsigned: bool,
    default: PolarsDataType | None = None,
) -> PolarsDataType | None:
    """
    Return matching Polars integer dtype from num bits and signed/unsigned flag.

    Parameters
    ----------
    bits
        Width of the integer in bits; 8, 16, 32, 64 and 128 are recognised.
    unsigned
        Whether the unsigned variant is wanted.
    default
        Value returned when the bits/unsigned combination is not recognised.

    Returns
    -------
    PolarsDataType or None
        The matching integer dtype, otherwise `default` (None if not given).

    Examples
    --------
    >>> integer_dtype_from_nbits(8, unsigned=False)
    Int8
    >>> integer_dtype_from_nbits(32, unsigned=True)
    UInt32
    """
    signed_by_width = {8: Int8, 16: Int16, 32: Int32, 64: Int64, 128: Int128}
    unsigned_by_width = {
        8: UInt8,
        16: UInt16,
        32: UInt32,
        64: UInt64,
        128: Int128,  # UInt128 not (yet?) supported
    }

    # key the table on (bits, unsigned), covering both signedness variants
    lookup = {(width, False): tp for width, tp in signed_by_width.items()}
    lookup.update({(width, True): tp for width, tp in unsigned_by_width.items()})

    matched = lookup.get((bits, unsigned))
    return default if matched is None else matched


def timeunit_from_precision(precision: int | str | None) -> str | None:
    """
    Return `time_unit` from integer precision value.

    Parameters
    ----------
    precision
        Number of fractional-second digits (as an int or digit string), or a
        unit name ("s", "ms", "us", "ns" in any case; "s" is reported as "ms").

    Returns
    -------
    str or None
        One of "ms", "us" or "ns"; None if `precision` is falsy or is not a
        value that can be divided (eg: an unrecognised string).

    Examples
    --------
    >>> timeunit_from_precision(3)
    'ms'
    >>> timeunit_from_precision(5)
    'us'
    >>> timeunit_from_precision(7)
    'ns'
    """
    import math

    if not precision:
        return None

    if isinstance(precision, str):
        if precision.isdigit():
            precision = int(precision)
        else:
            precision = precision.lower()
            if precision in ("s", "ms", "us", "ns"):
                return "ms" if precision == "s" else precision

    # round the digit count up to a whole number of 3-digit groups; anything
    # that cannot be divided/rounded (eg: a leftover string) has no time unit
    try:
        groups = int(math.ceil(precision / 3))  # type: ignore[operator]
    except TypeError:
        return None

    # clamp to between one and three groups: 3 -> ms, 6 -> us, 9 -> ns
    groups = min(max(1, groups), 3)
    return ("ms", "us", "ns")[groups - 1]
