from typing import Any

from polars import DataFrame
from polars._typing import IndexOrder
from polars.datatypes import Array, List
from polars.dependencies import numpy as np


def frame_to_numpy(
    df: DataFrame,
    *,
    writable: bool,
    target: str,
    order: IndexOrder = "fortran",
) -> np.ndarray[Any, Any]:
    """
    Materialise a DataFrame as a NumPy array for hand-off to Jax or PyTorch.

    Parameters
    ----------
    df
        Frame to convert.
    writable
        Forwarded unchanged to the underlying `to_numpy` call.
    target
        Name of the destination format; only interpolated into error messages.
    order
        Index order forwarded to `DataFrame.to_numpy`. It is not used when the
        frame consists of a single Array column, since that column is converted
        through `Series.to_numpy` instead.

    Returns
    -------
    numpy.ndarray
        The converted array; its dtype is never `object`.

    Raises
    ------
    TypeError
        If any column has a List dtype, or if the converted array ends up with
        `object` dtype.
    """
    schema = df.schema

    # List columns are rejected up front; only the first offender is reported.
    list_column = next((name for name, dtype in schema.items() if dtype == List), None)
    if list_column is not None:
        msg = f"cannot convert List column {list_column!r} to {target} (use Array dtype instead)"
        raise TypeError(msg) from None

    # A lone Array column is converted via the Series rather than the frame.
    is_single_array_column = df.width == 1 and schema.dtypes()[0] == Array
    result = (
        df[df.columns[0]].to_numpy(writable=writable)
        if is_single_array_column
        else df.to_numpy(writable=writable, order=order)
    )

    if result.dtype != object:
        return result

    msg = f"cannot convert DataFrame to {target} (mixed type columns result in `object` dtype)\n{schema!r}"
    raise TypeError(msg)
