from __future__ import annotations

import inspect
import os
import re
import sys
import warnings
from collections import Counter
from collections.abc import (
    Collection,
    Generator,
    Iterable,
    MappingView,
    Sequence,
    Sized,
)
from enum import Enum
from io import BytesIO
from pathlib import Path
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Literal,
    TypeVar,
    overload,
)

import polars as pl
from polars import functions as F
from polars.datatypes import (
    Boolean,
    Date,
    Datetime,
    Decimal,
    Duration,
    Int64,
    String,
    Time,
)
from polars.datatypes.group import FLOAT_DTYPES, INTEGER_DTYPES
from polars.dependencies import _check_for_numpy, import_optional, subprocess
from polars.dependencies import numpy as np

if TYPE_CHECKING:
    from collections.abc import Iterator, MutableMapping, Reversible

    from polars import DataFrame, Expr
    from polars._typing import PolarsDataType, SizeUnit

    if sys.version_info >= (3, 13):
        from typing import TypeIs
    else:
        from typing_extensions import TypeIs

    if sys.version_info >= (3, 10):
        from typing import ParamSpec, TypeGuard
    else:
        from typing_extensions import ParamSpec, TypeGuard

    P = ParamSpec("P")
    T = TypeVar("T")

# note: reversed views don't match as instances of MappingView, so the concrete
# types produced by calling `reversed(...)` on dict keys/values/items views are
# captured here for explicit isinstance checks (see `_is_generator`).
# These names only exist on Python 3.11+; any consumer must apply the same guard.
if sys.version_info >= (3, 11):
    _views: list[Reversible[Any]] = [{}.keys(), {}.values(), {}.items()]
    _reverse_mapping_views = tuple(type(reversed(view)) for view in _views)


def _process_null_values(
    null_values: None | str | Sequence[str] | dict[str, str] = None,
) -> None | str | Sequence[str] | list[tuple[str, str]]:
    if isinstance(null_values, dict):
        return list(null_values.items())
    else:
        return null_values


def _is_generator(val: object | Iterator[T]) -> TypeIs[Iterator[T]]:
    return (
        (isinstance(val, (Generator, Iterable)) and not isinstance(val, Sized))
        or isinstance(val, MappingView)
        or (sys.version_info >= (3, 11) and isinstance(val, _reverse_mapping_views))
    )


def _is_iterable_of(val: Iterable[object], eltype: type | tuple[type, ...]) -> bool:
    """Check whether the given iterable is of the given type(s)."""
    return all(isinstance(x, eltype) for x in val)


def is_path_or_str_sequence(
    val: object, *, allow_str: bool = False, include_series: bool = False
) -> TypeGuard[Sequence[str | Path]]:
    """
    Indicate whether `val` is a sequence made up of strings and/or paths.

    A lone string technically qualifies as a sequence of strings; it is only
    accepted when `allow_str=True` (with the default `allow_str=False` a single
    string yields False).
    """
    if allow_str is False and isinstance(val, str):
        return False

    # numpy arrays are judged by their dtype
    if _check_for_numpy(val) and isinstance(val, np.ndarray):
        return np.issubdtype(val.dtype, np.str_)

    # Series are only considered when explicitly requested
    if include_series and isinstance(val, pl.Series):
        return val.dtype == pl.String

    # bytes are never accepted; anything else must be a Sequence
    if isinstance(val, bytes) or not isinstance(val, Sequence):
        return False
    return _is_iterable_of(val, (Path, str))


def is_bool_sequence(
    val: object, *, include_series: bool = False
) -> TypeGuard[Sequence[bool]]:
    """Indicate whether `val` is a sequence whose elements are all booleans."""
    # numpy arrays are judged by their dtype
    if _check_for_numpy(val) and isinstance(val, np.ndarray):
        return val.dtype == np.bool_

    # Series are only considered when explicitly requested
    if include_series and isinstance(val, pl.Series):
        return val.dtype == pl.Boolean

    if not isinstance(val, Sequence):
        return False
    return _is_iterable_of(val, bool)


def is_int_sequence(
    val: object, *, include_series: bool = False
) -> TypeGuard[Sequence[int]]:
    """Indicate whether `val` is a sequence whose elements are all integers."""
    # numpy arrays are judged by their dtype
    if _check_for_numpy(val) and isinstance(val, np.ndarray):
        return np.issubdtype(val.dtype, np.integer)

    # Series are only considered when explicitly requested
    if include_series and isinstance(val, pl.Series):
        return val.dtype.is_integer()

    if not isinstance(val, Sequence):
        return False
    return _is_iterable_of(val, int)


def is_sequence(
    val: object, *, include_series: bool = False
) -> TypeGuard[Sequence[Any]]:
    """
    Indicate whether `val` is a numpy array or a (non-string) python sequence.

    With `include_series=True` a Series is also accepted.
    """
    if _check_for_numpy(val) and isinstance(val, np.ndarray):
        return True

    accepted = (pl.Series, Sequence) if include_series else Sequence
    return isinstance(val, accepted) and not isinstance(val, str)


def is_str_sequence(
    val: object, *, allow_str: bool = False, include_series: bool = False
) -> TypeGuard[Sequence[str]]:
    """
    Indicate whether `val` is a sequence whose elements are all strings.

    A lone string technically qualifies as a sequence of strings; it is only
    accepted when `allow_str=True` (with the default `allow_str=False` a single
    string yields False).
    """
    if allow_str is False and isinstance(val, str):
        return False

    # numpy arrays are judged by their dtype
    if _check_for_numpy(val) and isinstance(val, np.ndarray):
        return np.issubdtype(val.dtype, np.str_)

    # Series are only considered when explicitly requested
    if include_series and isinstance(val, pl.Series):
        return val.dtype == pl.String

    if not isinstance(val, Sequence):
        return False
    return _is_iterable_of(val, str)


def is_column(obj: Any) -> bool:
    """Return True when `obj` is an expression that is a basic/unaliased column."""
    from polars.expr import Expr

    # anything that is not an expression cannot be a column reference
    if not isinstance(obj, Expr):
        return False
    return obj.meta.is_column()


def warn_null_comparison(obj: Any) -> None:
    """Emit a UserWarning when `obj` is None (a likely-unintended comparison)."""
    if obj is not None:
        return

    warnings.warn(
        "Comparisons with None always result in null. Consider using `.is_null()` or `.is_not_null()`.",
        UserWarning,
        stacklevel=find_stacklevel(),
    )


def range_to_series(
    name: str, rng: range, dtype: PolarsDataType | None = None
) -> pl.Series:
    """
    Convert the given python range into a named Series.

    The dtype defaults to Int64. Integer dtypes are generated directly; any
    other dtype is produced by generating the default integer range and
    casting the result.
    """
    dtype = dtype or Int64
    if not dtype.is_integer():
        # non-integer target: build the integer range first, then cast
        values = F.int_range(
            start=rng.start, end=rng.stop, step=rng.step, eager=True
        ).cast(dtype)
    else:
        values = F.int_range(  # type: ignore[call-overload]
            start=rng.start, end=rng.stop, step=rng.step, dtype=dtype, eager=True
        )
    return values.alias(name)


def range_to_slice(rng: range) -> slice:
    """Build a slice carrying the same start, stop and step as the given range."""
    start, stop, step = rng.start, rng.stop, rng.step
    return slice(start, stop, step)


def _in_notebook() -> bool:
    try:
        from IPython import get_ipython

        if "IPKernelApp" not in get_ipython().config:  # pragma: no cover
            return False
    except ImportError:
        return False
    except AttributeError:
        return False
    return True


def arrlen(obj: Any) -> int | None:
    """Return `len(obj)`; None for strings, dicts, and objects without a length."""
    # strings and dicts have a length but are deliberately not treated as sequences
    if isinstance(obj, (str, dict)):
        return None
    try:
        return len(obj)
    except TypeError:
        return None


def normalize_filepath(path: str | Path, *, check_not_directory: bool = True) -> str:
    """
    Return `path` as a string with any leading home-directory marker expanded.

    Raises
    ------
    IsADirectoryError
        If `check_not_directory` is set and the expanded path is an existing
        directory.
    """
    # don't use pathlib here as it modifies slashes (s3:// -> s3:/)
    expanded = os.path.expanduser(path)  # noqa: PTH111
    if not check_not_directory:
        return expanded

    if os.path.exists(expanded) and os.path.isdir(expanded):  # noqa: PTH110, PTH112
        msg = f"expected a file path; {expanded!r} is a directory"
        raise IsADirectoryError(msg)
    return expanded


def parse_version(version: Sequence[str | int]) -> tuple[int, ...]:
    """
    Parse a version into a tuple of ints suitable for comparison.

    A string is split on "."; every component then has its non-digit
    characters removed before being converted to an int.
    """
    components = version.split(".") if isinstance(version, str) else version
    digits_only = (re.sub(r"\D", "", str(part)) for part in components)
    return tuple(map(int, digits_only))


def ordered_unique(values: Sequence[Any]) -> list[Any]:
    """Return the distinct values as a list, in order of first appearance."""
    # dict keys are unique and keep insertion order; the first occurrence wins
    return list(dict.fromkeys(values))


def deduplicate_names(names: Iterable[str]) -> list[str]:
    """
    Make repeated names distinct by suffixing a counter.

    The first occurrence of a name is kept as-is; the second occurrence gets
    the suffix "0", the third "1", and so on.
    """
    occurrences: MutableMapping[str, int] = Counter()
    unique_names = []
    for name in names:
        # a Counter reports 0 for names that have not been seen (without storing them)
        n_previous = occurrences[name]
        unique_names.append(name if n_previous == 0 else f"{name}{n_previous - 1}")
        occurrences[name] = n_previous + 1
    return unique_names


@overload
def scale_bytes(sz: int, unit: SizeUnit) -> int | float: ...


@overload
def scale_bytes(sz: Expr, unit: SizeUnit) -> Expr: ...


def scale_bytes(sz: int | Expr, unit: SizeUnit) -> int | float | Expr:
    """
    Convert a size in bytes to the requested unit.

    Accepted units are "b", "kb", "mb", "gb", "tb" and their long forms
    ("bytes", "kilobytes", ...). A byte unit returns `sz` untouched; every
    other unit divides by the matching power of 1024.

    Raises
    ------
    ValueError
        If `unit` is not one of the recognised names.
    """
    # power of 1024 associated with each recognised unit name
    power_of_1024 = {
        "b": 0,
        "bytes": 0,
        "kb": 1,
        "kilobytes": 1,
        "mb": 2,
        "megabytes": 2,
        "gb": 3,
        "gigabytes": 3,
        "tb": 4,
        "terabytes": 4,
    }
    power = power_of_1024.get(unit)
    if power is None:
        msg = f"`unit` must be one of {{'b', 'kb', 'mb', 'gb', 'tb'}}, got {unit!r}"
        raise ValueError(msg)
    if power == 0:
        return sz
    return sz / 1024**power


def _cast_repr_strings_with_schema(
    df: DataFrame, schema: dict[str, PolarsDataType | None]
) -> DataFrame:
    """
    Utility function to cast table repr/string values into frame-native types.

    Parameters
    ----------
    df
        Dataframe containing string-repr column data.
    schema
        DataFrame schema containing the desired end-state types; columns whose
        type is None are left untouched.

    Returns
    -------
    DataFrame
        The frame with the relevant columns replaced by their cast equivalents
        (the input frame itself if no column needs casting).

    Raises
    ------
    TypeError
        If a non-empty frame contains any column that is not String.

    Notes
    -----
    Table repr strings are less strict (or different) than equivalent CSV data, so need
    special handling; as this function is only used for reprs, parsing is flexible.
    """
    tp: PolarsDataType | None
    # every column must hold String data (the check is skipped for empty frames)
    if not df.is_empty():
        for tp in df.schema.values():
            if tp != String:
                msg = f"DataFrame should contain only String repr data; found {tp!r}"
                raise TypeError(msg)

    # lowercase spellings of non-finite floats; matching values are passed
    # through to the final cast as-is
    special_floats = {"-inf", "+inf", "inf", "nan"}

    # duration string scaling
    ns_sec = 1_000_000_000
    duration_scaling = {
        "ns": 1,
        "us": 1_000,
        "µs": 1_000,
        "ms": 1_000_000,
        "s": ns_sec,
        "m": ns_sec * 60,
        "h": ns_sec * 60 * 60,
        "d": ns_sec * 3_600 * 24,
        "w": ns_sec * 3_600 * 24 * 7,
    }

    # identify duration units and convert to nanoseconds
    # (each "<digits><unit>" pair is scaled and summed; a unit that is missing
    # from `duration_scaling` raises KeyError)
    def str_duration_(td: str | None) -> int | None:
        return (
            None
            if td is None
            else sum(
                int(value) * duration_scaling[unit.strip()]
                for value, unit in re.findall(r"(\d+)(\D+)", td)
            )
        )

    # column name -> expression producing the cast column
    cast_cols = {}
    for c, tp in schema.items():
        if tp is not None:
            if tp.base_type() == Datetime:
                tp_base = Datetime(tp.time_unit)  # type: ignore[union-attr]
                # drop any trailing run of uppercase letters/spaces
                # (presumably a timezone suffix -- TODO confirm)
                d = F.col(c).str.replace(r"[A-Z ]+$", "")
                # a 19-byte value ("YYYY-MM-DD HH:MM:SS") has no fractional part, so
                # a full one is appended; otherwise the existing fraction is
                # zero-padded. Slicing to 29 bytes leaves exactly nine fractional
                # digits for the "%9f" format
                cast_cols[c] = (
                    F.when(d.str.len_bytes() == 19)
                    .then(d + ".000000000")
                    .otherwise(d + "000000000")
                    .str.slice(0, 29)
                    .str.strptime(tp_base, "%Y-%m-%d %H:%M:%S.%9f")
                )
                # re-attach the target time zone, if the dtype carries one
                if getattr(tp, "time_zone", None) is not None:
                    cast_cols[c] = cast_cols[c].dt.replace_time_zone(tp.time_zone)  # type: ignore[union-attr]
            elif tp == Date:
                cast_cols[c] = F.col(c).str.strptime(tp, "%Y-%m-%d")  # type: ignore[arg-type]
            elif tp == Time:
                # same padding approach as for datetimes: 8 bytes is "HH:MM:SS",
                # and 18 bytes is that plus "." and nine fractional digits
                cast_cols[c] = (
                    F.when(F.col(c).str.len_bytes() == 8)
                    .then(F.col(c) + ".000000000")
                    .otherwise(F.col(c) + "000000000")
                    .str.slice(0, 18)
                    .str.strptime(tp, "%H:%M:%S.%9f")  # type: ignore[arg-type]
                )
            elif tp == Duration:
                # parse to integer nanoseconds, then cast through Duration("ns")
                # to the requested duration dtype
                cast_cols[c] = (
                    F.col(c)
                    .map_elements(str_duration_, return_dtype=Int64)
                    .cast(Duration("ns"))
                    .cast(tp)
                )
            elif tp == Boolean:
                # strict replacement of the two literal spellings "true"/"false"
                cast_cols[c] = F.col(c).replace_strict({"true": True, "false": False})
            elif tp in INTEGER_DTYPES:
                # strip everything other than digits and sign characters
                int_string = F.col(c).str.replace_all(r"[^\d+-]", "")
                # only non-empty strings are cast; there is no `otherwise` branch
                # (presumably yielding null for empty strings -- TODO confirm)
                cast_cols[c] = (
                    pl.when(int_string.str.len_bytes() > 0).then(int_string).cast(tp)
                )
            elif tp in FLOAT_DTYPES or tp.base_type() == Decimal:
                # identify integer/fractional parts
                # (split at the last non-digit character in the string)
                integer_part = F.col(c).str.replace(r"^(.*)\D(\d*)$", "$1")
                fractional_part = F.col(c).str.replace(r"^(.*)\D(\d*)$", "$2")
                cast_cols[c] = (
                    # check for empty string, special floats, or integer format
                    pl.when(
                        F.col(c).str.contains(r"^[+-]?\d*$")
                        | F.col(c).str.to_lowercase().is_in(special_floats)
                    )
                    .then(pl.when(F.col(c).str.len_bytes() > 0).then(F.col(c)))
                    # check for scientific notation
                    # (the first character that is neither a digit nor e/E is
                    # replaced with ".")
                    .when(F.col(c).str.contains("[eE]"))
                    .then(F.col(c).str.replace(r"[^eE\d]", "."))
                    .otherwise(
                        # recombine sanitised integer/fractional components
                        pl.concat_str(
                            integer_part.str.replace_all(r"[^\d+-]", ""),
                            fractional_part,
                            separator=".",
                        )
                    )
                    .cast(String)
                    .cast(tp)
                )
            elif tp != df.schema[c]:
                # any other dtype: plain cast, only when it differs from the current one
                cast_cols[c] = F.col(c).cast(tp)

    # leave the frame untouched when nothing needs casting
    return df.with_columns(**cast_cols) if cast_cols else df


# when building docs (with Sphinx) we need access to the functions
# associated with the namespaces from the class, as we don't have
# an instance; @sphinx_accessor is a @property that allows this.
# `NS` stands for the namespace type returned by the accessor's getter.
NS = TypeVar("NS")


class sphinx_accessor(property):
    """
    Property variant whose getter also works when accessed on the class.

    On instance access the getter receives the instance; on class access (no
    instance available) it receives the class instead. If the getter raises
    AttributeError or ImportError the accessor object itself is returned.
    """

    def __get__(  # type: ignore[override]
        self,
        instance: Any,
        cls: type[NS],
    ) -> NS:
        getter = self.fget
        try:
            # fall back to the owning class when there is no usable instance
            target = instance if isinstance(instance, cls) else cls
            return getter(target)  # type: ignore[misc]
        except (AttributeError, ImportError):
            return self  # type: ignore[return-value]


# raw value of the `BUILDING_SPHINX_DOCS` environment variable, read once at
# import time (None when the variable is not set)
BUILDING_SPHINX_DOCS = os.getenv("BUILDING_SPHINX_DOCS")


class _NoDefault(Enum):
    """Single-member enum providing the "no default supplied" sentinel value."""

    # "borrowed" from
    # https://github.com/pandas-dev/pandas/blob/e7859983a814b1823cf26e3b491ae2fa3be47c53/pandas/_libs/lib.pyx#L2736-L2748
    no_default = "NO_DEFAULT"

    def __repr__(self) -> str:
        # compact repr in place of the default Enum repr
        return "<no_default>"


# `no_default` is a sentinel indicating that no default value has been set; note that
# this should typically be used only when one of the valid parameter values is also
# None, as otherwise we cannot determine if the caller has explicitly set that value.
# `NoDefault` is the `Literal` type that contains only that sentinel member.
no_default = _NoDefault.no_default
NoDefault = Literal[_NoDefault.no_default]


def find_stacklevel() -> int:
    """
    Count the stack frames, starting at this one, that belong to Polars.

    The walk stops at the first frame whose file does not live under the Polars
    package directory (frames of `singledispatch` wrappers are skipped over as
    well); the number of frames passed on the way is returned.

    Taken from:
    https://github.com/pandas-dev/pandas/blob/ab89c53f48df67709a533b6a95ce3d911871a0a8/pandas/util/_exceptions.py#L30-L51
    """
    polars_root = str(Path(pl.__file__).parent)
    depth = 0

    # https://stackoverflow.com/questions/17407119/python-inspect-stack-is-slow
    frame = inspect.currentframe()
    try:
        while frame is not None:
            if not inspect.getfile(frame).startswith(polars_root):
                # ignore @singledispatch wrappers
                qualname = getattr(frame.f_code, "co_qualname", None)
                if not (qualname and qualname.startswith("singledispatch.")):
                    break
            frame = frame.f_back
            depth += 1
    finally:
        # https://docs.python.org/3/library/inspect.html
        # > Though the cycle detector will catch these, destruction of the frames
        # > (and local variables) can be made deterministic by removing the cycle
        # > in a finally clause.
        del frame
    return depth


def issue_warning(message: str, category: type[Warning], **kwargs: Any) -> None:
    """
    Emit a warning with an automatically determined stacklevel.

    Parameters
    ----------
    message
        The message associated with the warning.
    category
        The warning category.
    **kwargs
        Additional arguments for `warnings.warn`. Note that the `stacklevel` is
        determined automatically.
    """
    level = find_stacklevel()
    warnings.warn(message=message, category=category, stacklevel=level, **kwargs)


def _get_stack_locals(
    of_type: type | Collection[type] | Callable[[Any], bool] | None = None,
    *,
    named: str | Collection[str] | None = None,
    n_objects: int | None = None,
    n_frames: int | None = None,
) -> dict[str, Any]:
    """
    Retrieve f_locals from all (or the last 'n') stack frames from the calling location.

    Parameters
    ----------
    of_type
        Only return objects of this type; can be a single class, tuple of
        classes, or a callable that returns True/False if the object being
        tested is considered a match.
    n_objects
        If specified, return only the most recent `n` matching objects.
    n_frames
        If specified, look at objects in the last `n` stack frames only.
    named
        If specified, only return objects matching the given name(s).
    """
    objects = {}
    examined_frames = 0

    if isinstance(named, str):
        named = (named,)
    if n_frames is None:
        n_frames = sys.maxsize

    if inspect.isfunction(of_type):
        matches_type = of_type
    else:
        if isinstance(of_type, Collection):
            of_type = tuple(of_type)

        def matches_type(obj: Any) -> bool:  # type: ignore[misc]
            return isinstance(obj, of_type)  # type: ignore[arg-type]

    if named is not None:
        if isinstance(named, str):
            named = (named,)
        elif not isinstance(named, set):
            named = set(named)

    stack_frame = inspect.currentframe()
    stack_frame = getattr(stack_frame, "f_back", None)
    try:
        while stack_frame and examined_frames < n_frames:
            local_items = list(stack_frame.f_locals.items())
            for nm, obj in reversed(local_items):
                if (
                    nm not in objects
                    and (named is None or nm in named)
                    and (of_type is None or matches_type(obj))
                ):
                    objects[nm] = obj
                    if n_objects is not None and len(objects) >= n_objects:
                        return objects

            stack_frame = stack_frame.f_back
            examined_frames += 1
    finally:
        # https://docs.python.org/3/library/inspect.html
        # > Though the cycle detector will catch these, destruction of the frames
        # > (and local variables) can be made deterministic by removing the cycle
        # > in a finally clause.
        del stack_frame

    return objects


# this is called from rust
def _polars_warn(msg: str, category: type[Warning] = UserWarning) -> None:
    """Emit `msg` as a warning of the given category, with an automatic stacklevel."""
    level = find_stacklevel()
    warnings.warn(msg, category=category, stacklevel=level)


def extend_bool(
    value: bool | Sequence[bool],
    n_match: int,
    value_name: str,
    match_name: str,
) -> Sequence[bool]:
    """
    Return `value` as a sequence of exactly `n_match` booleans.

    A single bool is repeated `n_match` times; a sequence is returned as-is
    once its length has been validated.

    Raises
    ------
    ValueError
        If the resulting length differs from `n_match`.
    """
    if isinstance(value, bool):
        values = [value] * n_match
    else:
        values = value

    n_values = len(values)
    if n_values == n_match:
        return values

    msg = (
        f"the length of `{value_name}` ({n_values}) "
        f"does not match the length of `{match_name}` ({n_match})"
    )
    raise ValueError(msg)


def in_terminal_that_supports_colour() -> bool:
    """
    Indicate (within reason) whether we are in an interactive, colour-capable terminal.

    Note: this is not exhaustive, but it covers a lot (most?) of the common cases.
    """
    if not hasattr(sys.stdout, "isatty"):
        return False

    env = os.environ
    # can enhance as necessary, but this is a reasonable start
    if sys.stdout.isatty():
        colour_capable = (
            sys.platform != "win32"
            or "ANSICON" in env
            or "WT_SESSION" in env
            or env.get("TERM_PROGRAM") == "vscode"
            or env.get("TERM") == "xterm-256color"
        )
        if colour_capable:
            return True

    # PyCharm-hosted consoles are accepted even without a tty
    return env.get("PYCHARM_HOSTED") == "1"


def parse_percentiles(
    percentiles: Sequence[float] | float | None, *, inject_median: bool = False
) -> Sequence[float]:
    """
    Validate raw percentiles and return them sorted, optionally adding the median.

    Parameters
    ----------
    percentiles
        A single percentile (float or int), a sequence of percentiles, or None
        (treated as an empty sequence).
    inject_median
        If True, add the 50th percentile (0.5) when it is not already present.

    Returns
    -------
    list
        The percentiles in ascending order (duplicates are kept).

    Raises
    ------
    ValueError
        If any percentile is outside the range [0, 1].
    """
    if percentiles is None:
        percentiles = []
    elif isinstance(percentiles, (int, float)):
        # a lone scalar (including the integer bounds 0 and 1) becomes a 1-item list
        percentiles = [percentiles]

    if not all((0 <= p <= 1) for p in percentiles):
        msg = "`percentiles` must all be in the range [0, 1]"
        raise ValueError(msg)

    # partition around the median so that it can be slotted in between
    sub_50_percentiles = sorted(p for p in percentiles if p < 0.5)
    at_or_above_50_percentiles = sorted(p for p in percentiles if p >= 0.5)

    if inject_median and (
        not at_or_above_50_percentiles or at_or_above_50_percentiles[0] != 0.5
    ):
        at_or_above_50_percentiles = [0.5, *at_or_above_50_percentiles]

    return [*sub_50_percentiles, *at_or_above_50_percentiles]


def re_escape(s: str) -> str:
    """Return `s` with Polars (Rust) regex metacharacters backslash-escaped."""
    # note: almost the same as the standard python 're.escape' function, but
    # escapes _only_ those metachars with meaning to the rust regex crate
    re_rust_metachars = r"\\?()|\[\]{}^$#&~.+*-"
    metachar_pattern = re.compile(f"([{re_rust_metachars}])")
    # prefix each matched metacharacter with a single backslash
    return metachar_pattern.sub(r"\\\1", s)


# Don't rename or move. This is used by polars cloud
def display_dot_graph(
    *,
    dot: str,
    show: bool = True,
    output_path: str | Path | None = None,
    raw_output: bool = False,
    figsize: tuple[float, float] = (16.0, 12.0),
) -> str | None:
    """
    Render a graphviz dot string, optionally saving and/or displaying the result.

    Parameters
    ----------
    dot
        The graph description in dot format.
    show
        Display the rendered graph (inline SVG in a notebook, otherwise a
        matplotlib figure).
    output_path
        If given, write the rendered image bytes to this path.
    raw_output
        Return the dot string as-is; nothing is rendered, saved or shown.
    figsize
        Size of the matplotlib figure used outside notebooks.

    Raises
    ------
    ImportError
        If the graphviz `dot` binary cannot be run.
    """
    if raw_output:
        # we do not show a graph, nor save a graph to disk
        return dot

    image_format = "svg" if _in_notebook() else "png"
    dot_command = ["dot", "-Nshape=box", "-T" + image_format]

    try:
        rendered = subprocess.check_output(dot_command, input=f"{dot}".encode())
    except (ImportError, FileNotFoundError):
        msg = (
            "the graphviz `dot` binary should be on your PATH."
            "(If not installed you can download here: https://graphviz.org/download/)"
        )
        raise ImportError(msg) from None

    if output_path:
        Path(output_path).write_bytes(rendered)

    if not show:
        return None

    if not _in_notebook():
        # outside a notebook the image is shown in a matplotlib figure
        import_optional(
            "matplotlib",
            err_prefix="",
            err_suffix="should be installed to show graphs",
        )
        import matplotlib.image as mpimg
        import matplotlib.pyplot as plt

        plt.figure(figsize=figsize)
        image = mpimg.imread(BytesIO(rendered))
        plt.imshow(image)
        plt.show()
        return None

    from IPython.display import SVG, display

    return display(SVG(rendered))


def qualified_type_name(obj: Any, *, qualify_polars: bool = False) -> str:
    """
    Return the module-qualified type name of the given object as a string.

    Parameters
    ----------
    obj
        A class, or an instance whose class should be named.
    qualify_polars
        If False (default), omit the module path for our own (Polars) objects.

    Notes
    -----
    The module path is also omitted for builtins, and when the class reports
    no module at all.
    """
    klass = obj if isinstance(obj, type) else obj.__class__
    module, name = klass.__module__, klass.__name__

    is_unqualified = (
        not module
        or module == "builtins"
        or (not qualify_polars and module.startswith("polars."))
    )
    return name if is_unqualified else f"{module}.{name}"


def require_same_type(current: Any, other: Any) -> None:
    """
    Ensure that `other` is an instance of the type of `current`.

    Parameters
    ----------
    current
        The object the type of which is being checked against.
    other
        An object that has to be of the same type (subclass instances are
        accepted).

    Raises
    ------
    TypeError
        If `other` is not an instance of `type(current)`.
    """
    if isinstance(other, type(current)):
        return

    expected_name = qualified_type_name(current)
    actual_name = qualified_type_name(other)
    msg = f"expected `other` to be a {expected_name!r}, not {actual_name!r}"
    raise TypeError(msg)
