import functools
import inspect
import sys
import warnings

import numpy as np

from ._warnings import all_warnings, warn

# Public names exported by ``from <module> import *``. Helpers defined in this
# file that are not listed here are not part of the star-import surface.
__all__ = [
    'deprecate_func',
    'get_bound_method_class',
    'all_warnings',
    'safe_as_int',
    'check_shape_equality',
    'check_nD',
    'warn',
    'reshape_nd',
    'identity',
    'slice_at_axis',
    "deprecate_parameter",
    "DEPRECATED",
]


def _count_wrappers(func):
    """Count the number of wrappers around `func`."""
    unwrapped = func
    count = 0
    while hasattr(unwrapped, "__wrapped__"):
        unwrapped = unwrapped.__wrapped__
        count += 1
    return count


def _warning_stacklevel(func):
    """Find stacklevel for a warning raised from a wrapper around `func`.

    The number of wrappers around `func` itself is compared with the number
    of wrappers around the object found under ``func.__qualname__`` in the
    module named by ``func.__module__``. The difference tells how many
    additional wrappers were applied on top of `func`.

    Parameters
    ----------
    func : Callable
        The function being wrapped. It must be reachable through attribute
        access from its module, following its qualified name.

    Returns
    -------
    stacklevel : int
        The stacklevel. Minimum of 2.

    Raises
    ------
    RuntimeError
        If the qualified name of `func` cannot be resolved starting from its
        module, e.g., because `func` is a closure or the module is not in
        `sys.modules`.
    """
    # Count number of wrappers around `func`
    wrapped_count = _count_wrappers(func)

    # Resolve the global version of `func` by walking its qualified name one
    # attribute at a time, so that nested names such as ``Class.method`` are
    # looked up on the class and not on the module.
    module = sys.modules.get(func.__module__)
    try:
        global_func = module
        for name in func.__qualname__.split("."):
            global_func = getattr(global_func, name)
    except AttributeError as e:
        raise RuntimeError(
            f"Could not access `{func.__qualname__}` in {module!r}, "
            f" may be a closure. Set stacklevel manually. ",
        ) from e

    # Count number of total wrappers around global version of `func`
    global_wrapped_count = _count_wrappers(global_func)

    stacklevel = global_wrapped_count - wrapped_count + 1
    return max(stacklevel, 2)


def _get_stack_length(func):
    """Count the wrappers around the module-global object named like `func`.

    The object stored under ``func.__name__`` in ``func.__globals__`` is used
    when present; otherwise `func` itself is inspected.
    """
    registered = func.__globals__.get(func.__name__, func)
    return _count_wrappers(registered)


class _DecoratorBaseClass:
    """Base class helping decorators pick the stacklevel of their warnings.

    `_stack_length` is a class-level dictionary keyed by function name that
    holds the number of times a function is wrapped by a decorator.

    With `stack_length` the total number of wrappers around a decorated
    function and `stack_rank` the rank of a decorator within that stack, the
    stacklevel of a warning is ``stacklevel = 1 + stack_length - stack_rank``.
    """

    # Defined on the class, hence shared by every subclass and instance.
    _stack_length = {}

    def get_stack_length(self, func):
        """Return the stored stack length for `func`, or compute it.

        The computed value is only a fallback for names missing from
        `_stack_length`; it is not written back to the dictionary.
        """
        fallback = _get_stack_length(func)
        return self._stack_length.get(func.__name__, fallback)


class change_default_value(_DecoratorBaseClass):
    """Decorator for changing the default value of an argument.

    A `FutureWarning` is emitted whenever the decorated function is called
    without an explicit value for `arg_name` (neither positionally nor by
    keyword). The call itself is forwarded unchanged.

    Parameters
    ----------
    arg_name : str
        The name of the argument to be updated.
    new_value : any
        The argument new value.
    changed_version : str
        The package version in which the change will be introduced.
    warning_msg : str, optional
        Optional warning message. If None, a generic warning message
        is used.

    """

    def __init__(self, arg_name, *, new_value, changed_version, warning_msg=None):
        self.arg_name = arg_name
        self.new_value = new_value
        self.warning_msg = warning_msg
        self.changed_version = changed_version

    def __call__(self, func):
        """Wrap `func` so that relying on the old default emits a warning.

        Raises
        ------
        ValueError
            If `arg_name` is not a parameter of `func`.
        """
        parameters = inspect.signature(func).parameters
        arg_idx = list(parameters.keys()).index(self.arg_name)
        old_value = parameters[self.arg_name].default

        # Number of wrappers already around `func` when this decorator is
        # applied, i.e. the rank of this decorator in the stack.
        stack_rank = _count_wrappers(func)

        # Build the generic message in a local variable instead of storing it
        # on the instance: the text embeds the default value of *this*
        # function, so it must not leak into later uses of the same decorator
        # object on other functions.
        warning_msg = self.warning_msg
        if warning_msg is None:
            warning_msg = (
                f'The new recommended value for {self.arg_name} is '
                f'{self.new_value}. Until version {self.changed_version}, '
                f'the default {self.arg_name} value is {old_value}. '
                f'From version {self.changed_version}, the {self.arg_name} '
                f'default value will be {self.new_value}. To avoid '
                f'this warning, please explicitly set {self.arg_name} value.'
            )

        @functools.wraps(func)
        def fixed_func(*args, **kwargs):
            stacklevel = 1 + self.get_stack_length(func) - stack_rank
            if len(args) < arg_idx + 1 and self.arg_name not in kwargs:
                # warn that arg_name default value changed:
                warnings.warn(warning_msg, FutureWarning, stacklevel=stacklevel)
            return func(*args, **kwargs)

        return fixed_func


class PatchClassRepr(type):
    """Metaclass giving its classes a compact ``<ClassName>`` representation.

    Useful when a class object appears as a default value in a rendered
    signature.
    """

    def __repr__(cls):
        return "<{}>".format(cls.__name__)


class DEPRECATED(metaclass=PatchClassRepr):
    """Signal value to help with deprecating parameters that use None.

    This class is used as a proxy object (the class itself, not an instance)
    to signal that a parameter has not been set. This is useful if ``None``
    is already used for a different purpose or just to highlight a deprecated
    parameter in the signature. Thanks to its metaclass it is rendered as
    ``<DEPRECATED>``.
    """


class deprecate_parameter:
    """Deprecate a parameter of a function.

    Parameters
    ----------
    deprecated_name : str
        The name of the deprecated parameter.
    start_version : str
        The package version in which the warning was introduced.
    stop_version : str
        The package version in which the warning will be replaced by
        an error / the deprecation is completed.
    template : str, optional
        If given, this message template is used instead of the default one.
        It is formatted with the keys ``deprecated_name``,
        ``deprecated_version``, ``changed_version``, ``func_name`` and
        ``new_name``.
    new_name : str, optional
        If given, the default message will recommend the new parameter name and an
        error will be raised if the user uses both old and new names for the
        same parameter.
    modify_docstring : bool, optional
        If the wrapped function has a docstring, add the deprecated parameters
        to the "Other Parameters" section.
    stacklevel : int, optional
        This decorator attempts to detect the appropriate stacklevel for the
        deprecation warning automatically. If this fails, e.g., due to
        decorating a closure, you can set the stacklevel manually. The
        outermost decorator should have stacklevel 2, the next inner one
        stacklevel 3, etc.

    Notes
    -----
    Assign `DEPRECATED` as the new default value for the deprecated parameter.
    This marks the status of the parameter also in the signature and rendered
    HTML docs.

    This decorator can be stacked to deprecate more than one parameter.

    Examples
    --------
    >>> from skimage._shared.utils import deprecate_parameter, DEPRECATED
    >>> @deprecate_parameter(
    ...     "b", new_name="c", start_version="0.1", stop_version="0.3"
    ... )
    ... def foo(a, b=DEPRECATED, *, c=None):
    ...     return a, c

    Calling ``foo(1, b=2)``  will warn with::

        FutureWarning: Parameter `b` is deprecated since version 0.1 and will
        be removed in 0.3 (or later). To avoid this warning, please use the
        parameter `c` instead. For more details, see the documentation of
        `foo`.
    """

    DEPRECATED = DEPRECATED  # Make signal value accessible for convenience

    # Default message used when the parameter is removed without replacement.
    remove_parameter_template = (
        "Parameter `{deprecated_name}` is deprecated since version "
        "{deprecated_version} and will be removed in {changed_version} (or "
        "later). To avoid this warning, please do not use the parameter "
        "`{deprecated_name}`. For more details, see the documentation of "
        "`{func_name}`."
    )

    # Default message used when `new_name` replaces the deprecated parameter.
    replace_parameter_template = (
        "Parameter `{deprecated_name}` is deprecated since version "
        "{deprecated_version} and will be removed in {changed_version} (or "
        "later). To avoid this warning, please use the parameter `{new_name}` "
        "instead. For more details, see the documentation of `{func_name}`."
    )

    def __init__(
        self,
        deprecated_name,
        *,
        start_version,
        stop_version,
        template=None,
        new_name=None,
        modify_docstring=True,
        stacklevel=None,
    ):
        """Store the decorator configuration; see the class docstring."""
        self.deprecated_name = deprecated_name
        self.new_name = new_name
        self.template = template
        self.start_version = start_version
        self.stop_version = stop_version
        self.modify_docstring = modify_docstring
        self.stacklevel = stacklevel

    def __call__(self, func):
        """Wrap `func` so that using the deprecated parameter warns.

        Raises
        ------
        ValueError
            If `deprecated_name` (or `new_name`, when given) is not a
            parameter of `func`; raised by the ``list.index`` lookups below.
        RuntimeError
            If the default of the deprecated parameter is not `DEPRECATED`.
        """
        parameters = inspect.signature(func).parameters
        # Position of the deprecated parameter in the signature, used to
        # detect it among positional arguments.
        deprecated_idx = list(parameters.keys()).index(self.deprecated_name)
        if self.new_name:
            new_idx = list(parameters.keys()).index(self.new_name)
        else:
            # `False` (not `None` or 0) marks "no replacement"; it is tested
            # with `is not False` below so that a valid index 0 is kept.
            new_idx = False

        if parameters[self.deprecated_name].default is not DEPRECATED:
            raise RuntimeError(
                f"Expected `{self.deprecated_name}` to have the value {DEPRECATED!r} "
                f"to indicate its status in the rendered signature."
            )

        # Pick the message template: user supplied > replace > remove.
        if self.template is not None:
            template = self.template
        elif self.new_name is not None:
            template = self.replace_parameter_template
        else:
            template = self.remove_parameter_template
        # The message is formatted once, at decoration time.
        warning_message = template.format(
            deprecated_name=self.deprecated_name,
            deprecated_version=self.start_version,
            changed_version=self.stop_version,
            func_name=func.__qualname__,
            new_name=self.new_name,
        )

        @functools.wraps(func)
        def fixed_func(*args, **kwargs):
            # `DEPRECATED` doubles as the "not passed" marker for both values.
            deprecated_value = DEPRECATED
            new_value = DEPRECATED

            # Extract value of deprecated parameter
            if len(args) > deprecated_idx:
                deprecated_value = args[deprecated_idx]
                # Overwrite old with DEPRECATED if replacement exists
                if self.new_name is not None:
                    args = (
                        args[:deprecated_idx]
                        + (DEPRECATED,)
                        + args[deprecated_idx + 1 :]
                    )
            if self.deprecated_name in kwargs.keys():
                deprecated_value = kwargs[self.deprecated_name]
                # Overwrite old with DEPRECATED if replacement exists
                if self.new_name is not None:
                    kwargs[self.deprecated_name] = DEPRECATED

            # Extract value of new parameter (if present)
            if new_idx is not False and len(args) > new_idx:
                new_value = args[new_idx]
            if self.new_name and self.new_name in kwargs.keys():
                new_value = kwargs[self.new_name]

            if deprecated_value is not DEPRECATED:
                # An explicitly configured stacklevel wins over detection.
                stacklevel = (
                    self.stacklevel
                    if self.stacklevel is not None
                    else _warning_stacklevel(func)
                )
                # The warning is emitted first, so it is also issued when the
                # ValueError below is raised.
                warnings.warn(
                    warning_message, category=FutureWarning, stacklevel=stacklevel
                )

                if new_value is not DEPRECATED:
                    raise ValueError(
                        f"Both deprecated parameter `{self.deprecated_name}` "
                        f"and new parameter `{self.new_name}` are used. Use "
                        f"only the latter to avoid conflicting values."
                    )
                elif self.new_name is not None:
                    # Assign old value to new one
                    kwargs[self.new_name] = deprecated_value

            return func(*args, **kwargs)

        # Only touch the docstring when there is one to extend.
        if self.modify_docstring and func.__doc__ is not None:
            newdoc = _docstring_add_deprecated(
                func, {self.deprecated_name: self.new_name}, self.start_version
            )
            fixed_func.__doc__ = newdoc

        return fixed_func


def _docstring_add_deprecated(func, kwarg_mapping, deprecated_version):
    """Add deprecated kwarg(s) to the "Other Params" section of a docstring.

    Parameters
    ----------
    func : function
        The function whose docstring we wish to update.
    kwarg_mapping : dict
        A dict containing {old_arg: new_arg} key/value pairs, see
        `deprecate_parameter`.
    deprecated_version : str
        A major.minor version string specifying when old_arg was
        deprecated.

    Returns
    -------
    new_doc : str
        The updated docstring. Returns the original docstring if numpydoc is
        not available.
    """
    if func.__doc__ is None:
        return None
    try:
        from numpydoc.docscrape import FunctionDoc, Parameter
    except ImportError:
        # Return an unmodified docstring if numpydoc is not available.
        return func.__doc__

    Doc = FunctionDoc(func)
    for old_arg, new_arg in kwarg_mapping.items():
        desc = []
        if new_arg is None:
            desc.append(f'`{old_arg}` is deprecated.')
        else:
            desc.append(f'Deprecated in favor of `{new_arg}`.')

        desc += ['', f'.. deprecated:: {deprecated_version}']
        Doc['Other Parameters'].append(
            Parameter(name=old_arg, type='DEPRECATED', desc=desc)
        )
    new_docstring = str(Doc)

    # new_docstring will have a header starting with:
    #
    # .. function:: func.__name__
    #
    # and some additional blank lines. We strip these off below.
    split = new_docstring.split('\n')
    no_header = split[1:]
    while not no_header[0].strip():
        no_header.pop(0)

    # Store the initial description before any of the Parameters fields.
    # Usually this is a single line, but the while loop covers any case
    # where it is not.
    descr = no_header.pop(0)
    while no_header[0].strip():
        descr += '\n    ' + no_header.pop(0)
    descr += '\n\n'
    # '\n    ' rather than '\n' here to restore the original indentation.
    final_docstring = descr + '\n    '.join(no_header)
    # strip any extra spaces from ends of lines
    final_docstring = '\n'.join([line.rstrip() for line in final_docstring.split('\n')])
    return final_docstring


class channel_as_last_axis:
    """Decorator for automatically making channels axis last for all arrays.

    This decorator reorders axes for compatibility with functions that only
    support channels along the last axis. After the function call is complete
    the channels axis is restored back to its original position.

    The channel axis is read from the ``channel_axis`` keyword argument of
    the call; a ``channel_axis`` passed positionally is not detected.

    Parameters
    ----------
    channel_arg_positions : tuple of int, optional
        Positional arguments at the positions specified in this tuple are
        assumed to be multichannel arrays. The default is to assume only the
        first argument to the function is a multichannel array.
    channel_kwarg_names : tuple of str, optional
        A tuple containing the names of any keyword arguments corresponding to
        multichannel arrays. Names that are not passed in a given call are
        skipped.
    multichannel_output : bool, optional
        A boolean that should be True if the output of the function is a
        multichannel array and False otherwise. When True, the last axis of
        the output is moved back to the original channel position. This
        decorator does not currently support the general case of functions
        with multiple outputs where some or all are multichannel.

    """

    def __init__(
        self,
        channel_arg_positions=(0,),
        channel_kwarg_names=(),
        multichannel_output=True,
    ):
        self.arg_positions = set(channel_arg_positions)
        self.kwarg_names = set(channel_kwarg_names)
        self.multichannel_output = multichannel_output

    def __call__(self, func):
        @functools.wraps(func)
        def fixed_func(*args, **kwargs):
            channel_axis = kwargs.get('channel_axis', None)

            # No channel axis given: nothing to reorder.
            if channel_axis is None:
                return func(*args, **kwargs)

            # TODO: convert scalars to a tuple in anticipation of eventually
            #       supporting a tuple of channel axes. Right now, only an
            #       integer or a single-element tuple is supported, though.
            if np.isscalar(channel_axis):
                channel_axis = (channel_axis,)
            # Reject an empty sequence as well, instead of failing later with
            # an IndexError when the axis is looked up.
            if len(channel_axis) != 1:
                raise ValueError("only a single channel axis is currently supported")

            # Channels are already last: call through unchanged.
            if channel_axis == (-1,) or channel_axis == -1:
                return func(*args, **kwargs)

            source_axis = channel_axis[0]

            new_args = tuple(
                np.moveaxis(arg, source_axis, -1) if pos in self.arg_positions else arg
                for pos, arg in enumerate(args)
            )

            # Only reorder keyword arrays that were actually passed, so that
            # optional array keywords may be left at their default.
            for name in self.kwarg_names:
                if name in kwargs:
                    kwargs[name] = np.moveaxis(kwargs[name], source_axis, -1)

            # now that we have moved the channels axis to the last position,
            # change the channel_axis argument to -1
            kwargs["channel_axis"] = -1

            # Call the function with the fixed arguments
            out = func(*new_args, **kwargs)
            if self.multichannel_output:
                # Restore the channel axis to where the caller had it.
                out = np.moveaxis(out, -1, source_axis)
            return out

        return fixed_func


class deprecate_func(_DecoratorBaseClass):
    """Decorate a deprecated function and warn when it is called.

    Adapted from <http://wiki.python.org/moin/PythonDecoratorLibrary>.

    Parameters
    ----------
    deprecated_version : str
        The package version when the deprecation was introduced.
    removed_version : str, optional
        The package version in which the deprecated function will be removed.
        If not given, the removal is not mentioned in the message.
    hint : str, optional
        A hint on how to address this deprecation,
        e.g., "Use `skimage.submodule.alternative_func` instead."

    Examples
    --------
    >>> @deprecate_func(
    ...     deprecated_version="1.0.0",
    ...     removed_version="1.2.0",
    ...     hint="Use `bar` instead."
    ... )
    ... def foo():
    ...     pass

    Calling ``foo`` will warn with::

        FutureWarning: `foo` is deprecated since version 1.0.0
        and will be removed in version 1.2.0. Use `bar` instead.
    """

    def __init__(self, *, deprecated_version, removed_version=None, hint=None):
        self.deprecated_version = deprecated_version
        self.removed_version = removed_version
        self.hint = hint

    def __call__(self, func):
        """Return a wrapper around `func` that emits a `FutureWarning`.

        The deprecation message is also prepended to the docstring of the
        wrapper.
        """
        message = (
            f"`{func.__name__}` is deprecated since version "
            f"{self.deprecated_version}"
        )
        if self.removed_version:
            message += f" and will be removed in version {self.removed_version}"
        # Always close the first sentence, also when no removal version is
        # given; otherwise the hint would run into the version number.
        message += "."
        if self.hint:
            # Prepend space and make sure it closes with "."
            message += f" {self.hint.rstrip('.')}."

        # Number of wrappers already around `func` when this decorator is
        # applied, i.e. the rank of this decorator in the stack.
        stack_rank = _count_wrappers(func)

        @functools.wraps(func)
        def wrapped(*args, **kwargs):
            stacklevel = 1 + self.get_stack_length(func) - stack_rank
            warnings.warn(message, category=FutureWarning, stacklevel=stacklevel)
            return func(*args, **kwargs)

        # modify docstring to display deprecation warning
        doc = f'**Deprecated:** {message}'
        if wrapped.__doc__ is None:
            wrapped.__doc__ = doc
        else:
            wrapped.__doc__ = doc + '\n\n    ' + wrapped.__doc__

        return wrapped


def get_bound_method_class(m):
    """Return the class for a bound method.

    Parameters
    ----------
    m : bound method
        Any object exposing the instance it is bound to as ``__self__``.

    Returns
    -------
    cls : type
        The class of the instance `m` is bound to.
    """
    # The former Python 2 branch (``im_class``) was selected through a
    # lexicographic comparison of ``sys.version`` strings, which is fragile;
    # this module already requires Python 3 (f-strings), so only the
    # ``__self__`` lookup is needed.
    return m.__self__.__class__


def safe_as_int(val, atol=1e-3):
    """
    Cast values to integers, but only if they are close to integers already.

    Parameters
    ----------
    val : scalar or iterable of scalars
        Number or container of numbers meant to be used as integers, e.g.,
        for indexing purposes, but which may not carry integer type.
    atol : float
        Largest absolute distance from the nearest integer for which a value
        in ``val`` is still accepted.

    Returns
    -------
    val_int : NumPy scalar or ndarray of dtype `np.int64`
        The input value(s) rounded and converted to dtype `np.int64`.

    Raises
    ------
    ValueError
        If any value is further than ``atol`` away from the nearest integer.

    Notes
    -----
    The fractional part of each value is obtained as ``val`` modulo 1.
    Fractional parts above 0.5 are replaced by their distance to one, so
    that the result is the distance to the nearest integer. This distance is
    then compared against ``atol``.

    Examples
    --------
    >>> safe_as_int(7.0)
    7

    >>> safe_as_int([9, 4, 2.9999999999])
    array([9, 4, 3])

    >>> safe_as_int(53.1)
    Traceback (most recent call last):
        ...
    ValueError: Integer argument required but received 53.1, check inputs.

    >>> safe_as_int(53.01, atol=0.01)
    53

    """
    fraction = np.asarray(val) % 1

    # Distance to the nearest integer; works for scalars and arrays alike.
    distance = np.where(fraction > 0.5, 1 - fraction, fraction)

    try:
        np.testing.assert_allclose(distance, 0, atol=atol)
    except AssertionError:
        raise ValueError(f'Integer argument required but received {val}, check inputs.')

    rounded = np.round(val)
    return rounded.astype(np.int64)


def check_shape_equality(*images):
    """Check that all images have the same shape.

    Parameters
    ----------
    *images : objects with a ``shape`` attribute
        The images to compare. Calling the function without any image, or
        with a single one, passes trivially.

    Raises
    ------
    ValueError
        If any image has a shape different from the first one.
    """
    if not images:
        # Nothing to compare; avoid indexing into an empty tuple.
        return
    reference_shape = images[0].shape
    if any(image.shape != reference_shape for image in images[1:]):
        raise ValueError('Input images must have the same dimensions.')


def slice_at_axis(sl, axis):
    """
    Construct tuple of slices to slice an array in the given dimension.

    Parameters
    ----------
    sl : slice
        The slice for the given dimension.
    axis : int
        The axis to which `sl` is applied. All other dimensions are left
        "unsliced". Negative values count from the last dimension.

    Returns
    -------
    sl : tuple of slices
        An index tuple that applies `sl` along `axis`; an ``Ellipsis``
        stands in for the dimensions on the far side of `axis`.

    Examples
    --------
    >>> slice_at_axis(slice(None, 3, -1), 1)
    (slice(None, None, None), slice(None, 3, -1), Ellipsis)
    >>> slice_at_axis(slice(None, 3, -1), -1)
    (Ellipsis, slice(None, 3, -1))
    """
    if axis < 0:
        # Count from the end: the Ellipsis absorbs the leading dimensions and
        # ``-axis - 1`` full slices cover the dimensions after `axis`.
        # (Multiplying a tuple by a negative number yields an empty tuple,
        # which would silently slice the first axis instead.)
        return (...,) + (sl,) + (slice(None),) * (-axis - 1)
    return (slice(None),) * axis + (sl,) + (...,)


def reshape_nd(arr, ndim, dim):
    """Give a 1D array `ndim` dimensions, all of length one except `dim`.

    Parameters
    ----------
    arr : array, shape (N,)
        Input array
    ndim : int
        Number of dimensions of the reshaped array.
    dim : int
        The dimension/axis that receives the N elements.

    Returns
    -------
    arr_reshaped : array, shape ([1, ...], N, [1,...])
        `arr` reshaped to the desired shape.

    Raises
    ------
    ValueError
        If `arr` is not one-dimensional.

    Examples
    --------
    >>> rng = np.random.default_rng()
    >>> arr = rng.random(7)
    >>> reshape_nd(arr, 2, 0).shape
    (7, 1)
    >>> reshape_nd(arr, 3, 1).shape
    (1, 7, 1)
    >>> reshape_nd(arr, 4, -1).shape
    (1, 1, 1, 7)
    """
    is_one_dimensional = arr.ndim == 1
    if not is_one_dimensional:
        raise ValueError("arr must be a 1D array")
    # Start from all singletons, then let NumPy infer the length along `dim`.
    target_shape = [1 for _ in range(ndim)]
    target_shape[dim] = -1
    return np.reshape(arr, target_shape)


def check_nD(array, ndim, arg_name='image'):
    """
    Verify an array meets the desired ndims and array isn't empty.

    Parameters
    ----------
    array : array-like
        Input array to be validated
    ndim : int or iterable of ints
        Allowable ndim or ndims for the array. NumPy integer scalars are
        accepted as well as Python ints.
    arg_name : str, optional
        The name of the array in the original function.

    Raises
    ------
    ValueError
        If `array` is empty, or if its number of dimensions is not among
        the allowed ones.

    """
    array = np.asanyarray(array)
    if array.size == 0:
        raise ValueError(f"The parameter `{arg_name}` cannot be an empty array")

    if isinstance(ndim, (int, np.integer)):
        allowed = (ndim,)
    else:
        # Materialize once: `ndim` is used both for the membership test and
        # for the error message, which a one-shot iterable cannot serve.
        allowed = tuple(ndim)

    if array.ndim not in allowed:
        allowed_text = '-or-'.join(str(n) for n in allowed)
        raise ValueError(
            f"The parameter `{arg_name}` must be a {allowed_text}-dimensional array"
        )


def convert_to_float(image, preserve_range):
    """Return `image` as a floating point array.

    Parameters
    ----------
    image : ndarray
        Input image.
    preserve_range : bool
        If True, the values are kept and only the dtype may change. If
        False, the conversion is delegated to ``img_as_float``. Also see
        https://scikit-image.org/docs/dev/user_guide/data_types.html

    Notes
    -----
    * `float16` input is always cast to `float32`, whatever the value of
      `preserve_range`.
    * Input images with `float32` data type are not upcast.

    Returns
    -------
    image : ndarray
        Transformed version of the input.

    """
    if image.dtype == np.float16:
        return image.astype(np.float32)

    if not preserve_range:
        from ..util.dtype import img_as_float

        return img_as_float(image)

    # Single and double precision floats are returned untouched; everything
    # else is converted to double precision.
    if image.dtype.char in 'df':
        return image
    return image.astype(float)


def _validate_interpolation_order(image_dtype, order):
    """Validate and return spline interpolation's order.

    Parameters
    ----------
    image_dtype : dtype
        Image dtype.
    order : int, optional
        The order of the spline interpolation. The order has to be in
        the range 0-5. See `skimage.transform.warp` for detail.

    Returns
    -------
    order : int
        if input order is None, returns 0 if image_dtype is bool and 1
        otherwise. Otherwise, image_dtype is checked and input order
        is validated accordingly (order > 0 is not supported for bool
        image dtype)

    """

    if order is None:
        return 0 if image_dtype == bool else 1

    if order < 0 or order > 5:
        raise ValueError("Spline interpolation order has to be in the " "range 0-5.")

    if image_dtype == bool and order != 0:
        raise ValueError(
            "Input image dtype is bool. Interpolation is not defined "
            "with bool data type. Please set order to 0 or explicitly "
            "cast input image to another data type."
        )

    return order


def _to_np_mode(mode):
    """Convert padding modes from `ndi.correlate` to `np.pad`."""
    mode_translation_dict = dict(nearest='edge', reflect='symmetric', mirror='reflect')
    if mode in mode_translation_dict:
        mode = mode_translation_dict[mode]
    return mode


def _to_ndimage_mode(mode):
    """Convert from `numpy.pad` mode name to the corresponding ndimage mode."""
    mode_translation_dict = dict(
        constant='constant',
        edge='nearest',
        symmetric='reflect',
        reflect='mirror',
        wrap='wrap',
    )
    if mode not in mode_translation_dict:
        raise ValueError(
            f"Unknown mode: '{mode}', or cannot translate mode. The "
            f"mode should be one of 'constant', 'edge', 'symmetric', "
            f"'reflect', or 'wrap'. See the documentation of numpy.pad for "
            f"more info."
        )
    return _fix_ndimage_mode(mode_translation_dict[mode])


def _fix_ndimage_mode(mode):
    # SciPy 1.6.0 introduced grid variants of constant and wrap which
    # have less surprising behavior for images. Use these when available
    grid_modes = {'constant': 'grid-constant', 'wrap': 'grid-wrap'}
    return grid_modes.get(mode, mode)


# Lookup table from dtype character code to the floating point type that is
# used for it.
new_float_type = {
    **{
        np.dtype(source).char: target
        for source, target in (
            # preserved types
            (np.float32, np.float32),
            (np.float64, np.float64),
            (np.complex64, np.complex64),
            (np.complex128, np.complex128),
            # altered types
            (np.float16, np.float32),
        )
    },
    # The extended precision types are addressed by character code because
    # they don't exist on windows.
    'g': np.float64,  # np.float128
    'G': np.complex128,  # np.complex256
}


def _supported_float_type(input_dtype, allow_complex=False):
    """Return an appropriate floating-point dtype for a given dtype.

    float32, float64, complex64, complex128 are preserved.
    float16 is promoted to float32.
    complex256 is demoted to complex128.
    Other types are cast to float64.

    Parameters
    ----------
    input_dtype : np.dtype or tuple of np.dtype
        The input dtype. If a tuple of multiple dtypes is provided, each
        dtype is first converted to a supported floating point type and the
        final dtype is then determined by applying `np.result_type` on the
        sequence of supported floating point types.
    allow_complex : bool, optional
        If False, raise a ValueError on complex-valued inputs. This applies
        to every element when `input_dtype` is a tuple.

    Returns
    -------
    float_type : dtype
        Floating-point dtype for the image.

    Raises
    ------
    ValueError
        If a complex dtype is given while `allow_complex` is False.
    """
    if isinstance(input_dtype, tuple):
        # Forward `allow_complex` so that tuples follow the same rule as
        # single dtypes instead of always rejecting complex entries.
        return np.result_type(
            *(
                _supported_float_type(d, allow_complex=allow_complex)
                for d in input_dtype
            )
        )
    input_dtype = np.dtype(input_dtype)
    if not allow_complex and input_dtype.kind == 'c':
        raise ValueError("complex valued input is not supported")
    # Character codes missing from the table fall back to float64.
    return new_float_type.get(input_dtype.char, np.float64)


def identity(image, *args, **kwargs):
    """Return `image` as is.

    Any additional positional or keyword arguments are accepted and
    ignored.
    """
    return image


def as_binary_ndarray(array, *, variable_name):
    """Convert `array` to a numpy.ndarray of dtype bool.

    Parameters
    ----------
    array : array-like
        Input of dtype bool, or containing only the values 0 and 1.
    variable_name : str
        Name used for `array` in the error message.

    Returns
    -------
    binary : ndarray of dtype bool
        The boolean version of `array`.

    Raises
    ------
    ValueError:
        An error including the given `variable_name` if `array` can not be
        safely cast to a boolean array.
    """
    candidate = np.asarray(array)
    needs_check = candidate.dtype != bool
    if needs_check:
        neither_zero_nor_one = (candidate != 1) & (candidate != 0)
        if np.any(neither_zero_nor_one):
            raise ValueError(
                f"{variable_name} array is not of dtype boolean or "
                f"contains values other than 0 and 1 so cannot be "
                f"safely cast to boolean array."
            )
    return np.asarray(candidate, dtype=bool)
