import itertools
import functools

import numpy as np
from scipy import ndimage as ndi

from .._shared.utils import _supported_float_type
from ..metrics import mean_squared_error
from ..util import img_as_float


def _interpolate_image(image, *, multichannel=False):
    """Replacing each pixel in ``image`` with the average of its neighbors.

    Parameters
    ----------
    image : ndarray
        Input data to be interpolated.
    multichannel : bool, optional
        Whether the last axis of the image is to be interpreted as multiple
        channels or another spatial dimension.

    Returns
    -------
    interp : ndarray
        Interpolated version of `image`.
    """
    spatialdims = image.ndim if not multichannel else image.ndim - 1
    conv_filter = ndi.generate_binary_structure(spatialdims, 1).astype(image.dtype)
    conv_filter.ravel()[conv_filter.size // 2] = 0
    conv_filter /= conv_filter.sum()

    if multichannel:
        interp = np.zeros_like(image)
        for i in range(image.shape[-1]):
            interp[..., i] = ndi.convolve(image[..., i], conv_filter, mode='mirror')
    else:
        interp = ndi.convolve(image, conv_filter, mode='mirror')
    return interp


def _generate_grid_slice(shape, *, offset, stride=3):
    """Generate slices of uniformly-spaced points in an array.

    Parameters
    ----------
    shape : tuple of int
        Shape of the mask.
    offset : int
        The offset of the grid of ones. Iterating over ``offset`` will cover
        the entire array. It should be between 0 and ``stride ** ndim``, not
        inclusive, where ``ndim = len(shape)``.
    stride : int, optional
        The spacing between ones, used in each dimension.

    Returns
    -------
    mask : ndarray
        The mask.

    Examples
    --------
    >>> shape = (4, 4)
    >>> array = np.zeros(shape, dtype=int)
    >>> grid_slice = _generate_grid_slice(shape, offset=0, stride=2)
    >>> array[grid_slice] = 1
    >>> print(array)
    [[1 0 1 0]
     [0 0 0 0]
     [1 0 1 0]
     [0 0 0 0]]

    Changing the offset moves the location of the 1s:

    >>> array = np.zeros(shape, dtype=int)
    >>> grid_slice = _generate_grid_slice(shape, offset=3, stride=2)
    >>> array[grid_slice] = 1
    >>> print(array)
    [[0 0 0 0]
     [0 1 0 1]
     [0 0 0 0]
     [0 1 0 1]]
    """
    phases = np.unravel_index(offset, (stride,) * len(shape))
    mask = tuple(slice(p, None, stride) for p in phases)

    return mask


def denoise_invariant(
    image, denoise_function, *, stride=4, masks=None, denoiser_kwargs=None
):
    """Apply a J-invariant version of a denoising function.

    Parameters
    ----------
    image : ndarray (M[, N[, ...]][, C]) of ints, uints or floats
        Input data to be denoised. `image` can be of any numeric type,
        but it is cast into a ndarray of floats (using `img_as_float`) for the
        computation of the denoised image.
    denoise_function : function
        Original denoising function.
    stride : int, optional
        Stride used in masking procedure that converts `denoise_function`
        to J-invariance.
    masks : list of ndarray, optional
        Set of masks to use for computing J-invariant output. If `None`,
        a full set of masks covering the image will be used.
    denoiser_kwargs:
        Keyword arguments passed to `denoise_function`.

    Returns
    -------
    output : ndarray
        Denoised image, of same shape as `image`.

    Notes
    -----
    A denoising function is J-invariant if the prediction it makes for each
    pixel does not depend on the value of that pixel in the original image.
    The prediction for each pixel may instead use all the relevant information
    contained in the rest of the image, which is typically quite significant.
    Any function can be converted into a J-invariant one using a simple masking
    procedure, as described in [1].

    The pixel-wise error of a J-invariant denoiser is uncorrelated to the noise,
    so long as the noise in each pixel is independent. Consequently, the average
    difference between the denoised image and the oisy image, the
    *self-supervised loss*, is the same as the difference between the denoised
    image and the original clean image, the *ground-truth loss* (up to a
    constant).

    This means that the best J-invariant denoiser for a given image can be found
    using the noisy data alone, by selecting the denoiser minimizing the self-
    supervised loss.

    References
    ----------
    .. [1] J. Batson & L. Royer. Noise2Self: Blind Denoising by Self-Supervision,
       International Conference on Machine Learning, p. 524-533 (2019).

    Examples
    --------
    >>> import skimage
    >>> from skimage.restoration import denoise_invariant, denoise_tv_chambolle
    >>> image = skimage.util.img_as_float(skimage.data.chelsea())
    >>> noisy = skimage.util.random_noise(image, var=0.2 ** 2)
    >>> denoised = denoise_invariant(noisy, denoise_function=denoise_tv_chambolle)
    """
    image = img_as_float(image)

    # promote float16->float32 if needed
    float_dtype = _supported_float_type(image.dtype)
    image = image.astype(float_dtype, copy=False)

    if denoiser_kwargs is None:
        denoiser_kwargs = {}

    multichannel = denoiser_kwargs.get('channel_axis', None) is not None
    interp = _interpolate_image(image, multichannel=multichannel)
    output = np.zeros_like(image)

    if masks is None:
        spatialdims = image.ndim if not multichannel else image.ndim - 1
        n_masks = stride**spatialdims
        masks = (
            _generate_grid_slice(image.shape[:spatialdims], offset=idx, stride=stride)
            for idx in range(n_masks)
        )

    for mask in masks:
        input_image = image.copy()
        input_image[mask] = interp[mask]
        output[mask] = denoise_function(input_image, **denoiser_kwargs)[mask]
    return output


def _product_from_dict(dictionary):
    """Utility function to convert parameter ranges to parameter combinations.

    Converts a dict of lists into a list of dicts whose values consist of the
    cartesian product of the values in the original dict.

    Parameters
    ----------
    dictionary : dict of lists
        Dictionary of lists to be multiplied.

    Yields
    ------
    selections : dicts of values
        Dicts containing individual combinations of the values in the input
        dict.
    """
    keys = dictionary.keys()
    for element in itertools.product(*dictionary.values()):
        yield dict(zip(keys, element))


def calibrate_denoiser(
    image,
    denoise_function,
    denoise_parameters,
    *,
    stride=4,
    approximate_loss=True,
    extra_output=False,
):
    """Calibrate a denoising function and return optimal J-invariant version.

    The returned function is partially evaluated with optimal parameter values
    set for denoising the input image.

    Parameters
    ----------
    image : ndarray
        Input data to be denoised (converted using `img_as_float`).
    denoise_function : function
        Denoising function to be calibrated.
    denoise_parameters : dict of list
        Ranges of parameters for `denoise_function` to be calibrated over.
    stride : int, optional
        Stride used in masking procedure that converts `denoise_function`
        to J-invariance.
    approximate_loss : bool, optional
        Whether to approximate the self-supervised loss used to evaluate the
        denoiser by only computing it on one masked version of the image.
        If False, the runtime will be a factor of `stride**image.ndim` longer.
    extra_output : bool, optional
        If True, return parameters and losses in addition to the calibrated
        denoising function

    Returns
    -------
    best_denoise_function : function
        The optimal J-invariant version of `denoise_function`.

    If `extra_output` is True, the following tuple is also returned:

    (parameters_tested, losses) : tuple (list of dict, list of int)
        List of parameters tested for `denoise_function`, as a dictionary of
        kwargs
        Self-supervised loss for each set of parameters in `parameters_tested`.


    Notes
    -----

    The calibration procedure uses a self-supervised mean-square-error loss
    to evaluate the performance of J-invariant versions of `denoise_function`.
    The minimizer of the self-supervised loss is also the minimizer of the
    ground-truth loss (i.e., the true MSE error) [1]. The returned function
    can be used on the original noisy image, or other images with similar
    characteristics.

    Increasing the stride increases the performance of `best_denoise_function`
     at the expense of increasing its runtime. It has no effect on the runtime
     of the calibration.

    References
    ----------
    .. [1] J. Batson & L. Royer. Noise2Self: Blind Denoising by Self-Supervision,
           International Conference on Machine Learning, p. 524-533 (2019).

    Examples
    --------
    >>> from skimage import color, data
    >>> from skimage.restoration import denoise_tv_chambolle
    >>> import numpy as np
    >>> img = color.rgb2gray(data.astronaut()[:50, :50])
    >>> rng = np.random.default_rng()
    >>> noisy = img + 0.5 * img.std() * rng.standard_normal(img.shape)
    >>> parameters = {'weight': np.arange(0.01, 0.3, 0.02)}
    >>> denoising_function = calibrate_denoiser(noisy, denoise_tv_chambolle,
    ...                                         denoise_parameters=parameters)
    >>> denoised_img = denoising_function(img)

    """
    parameters_tested, losses = _calibrate_denoiser_search(
        image,
        denoise_function,
        denoise_parameters=denoise_parameters,
        stride=stride,
        approximate_loss=approximate_loss,
    )

    idx = np.argmin(losses)
    best_parameters = parameters_tested[idx]

    best_denoise_function = functools.partial(
        denoise_invariant,
        denoise_function=denoise_function,
        stride=stride,
        denoiser_kwargs=best_parameters,
    )

    if extra_output:
        return best_denoise_function, (parameters_tested, losses)
    else:
        return best_denoise_function


def _calibrate_denoiser_search(
    image, denoise_function, denoise_parameters, *, stride=4, approximate_loss=True
):
    """Return a parameter search history with losses for a denoise function.

    Parameters
    ----------
    image : ndarray
        Input data to be denoised (converted using `img_as_float`).
    denoise_function : function
        Denoising function to be calibrated.
    denoise_parameters : dict of list
        Ranges of parameters for `denoise_function` to be calibrated over.
    stride : int, optional
        Stride used in masking procedure that converts `denoise_function`
        to J-invariance.
    approximate_loss : bool, optional
        Whether to approximate the self-supervised loss used to evaluate the
        denoiser by only computing it on one masked version of the image.
        If False, the runtime will be a factor of `stride**image.ndim` longer.

    Returns
    -------
    parameters_tested : list of dict
        List of parameters tested for `denoise_function`, as a dictionary of
        kwargs.
    losses : list of int
        Self-supervised loss for each set of parameters in `parameters_tested`.
    """
    image = img_as_float(image)
    parameters_tested = list(_product_from_dict(denoise_parameters))
    losses = []

    for denoiser_kwargs in parameters_tested:
        multichannel = denoiser_kwargs.get('channel_axis', None) is not None
        if not approximate_loss:
            denoised = denoise_invariant(
                image, denoise_function, stride=stride, denoiser_kwargs=denoiser_kwargs
            )
            loss = mean_squared_error(image, denoised)
        else:
            spatialdims = image.ndim if not multichannel else image.ndim - 1
            n_masks = stride**spatialdims
            mask = _generate_grid_slice(
                image.shape[:spatialdims], offset=n_masks // 2, stride=stride
            )

            masked_denoised = denoise_invariant(
                image, denoise_function, masks=[mask], denoiser_kwargs=denoiser_kwargs
            )

            loss = mean_squared_error(image[mask], masked_denoised[mask])

        losses.append(loss)

    return parameters_tested, losses
