import functools
import numpy as np
import pytest

from skimage._shared.testing import assert_
from skimage._shared.utils import _supported_float_type
from skimage.data import binary_blobs
from skimage.data import camera, chelsea
from skimage.metrics import mean_squared_error as mse
from skimage.restoration import calibrate_denoiser, denoise_wavelet
from skimage.restoration.j_invariant import denoise_invariant
from skimage.util import img_as_float, random_noise
from skimage.restoration.tests.test_denoise import xfail_without_pywt

test_img = img_as_float(camera())
test_img_color = img_as_float(chelsea())
test_img_3d = img_as_float(binary_blobs(64, n_dim=3)) / 2
noisy_img = random_noise(test_img, mode='gaussian', var=0.01)
noisy_img_color = random_noise(test_img_color, mode='gaussian', var=0.01)
noisy_img_3d = random_noise(test_img_3d, mode='gaussian', var=0.1)

_denoise_wavelet = functools.partial(denoise_wavelet, rescale_sigma=True)


@xfail_without_pywt
def test_invariant_denoise():
    denoised_img = denoise_invariant(noisy_img, _denoise_wavelet)

    denoised_mse = mse(denoised_img, test_img)
    original_mse = mse(noisy_img, test_img)
    assert_(denoised_mse < original_mse)


@xfail_without_pywt
@pytest.mark.parametrize('dtype', [np.float16, np.float32, np.float64])
def test_invariant_denoise_color(dtype):
    denoised_img_color = denoise_invariant(
        noisy_img_color.astype(dtype),
        _denoise_wavelet,
        denoiser_kwargs=dict(channel_axis=-1),
    )
    denoised_mse = mse(denoised_img_color, test_img_color)
    original_mse = mse(noisy_img_color, test_img_color)
    assert denoised_mse < original_mse
    assert denoised_img_color.dtype == _supported_float_type(dtype)


@xfail_without_pywt
def test_invariant_denoise_3d():
    denoised_img_3d = denoise_invariant(noisy_img_3d, _denoise_wavelet)

    denoised_mse = mse(denoised_img_3d, test_img_3d)
    original_mse = mse(noisy_img_3d, test_img_3d)
    assert_(denoised_mse < original_mse)


@xfail_without_pywt
def test_calibrate_denoiser_extra_output():
    parameter_ranges = {'sigma': np.linspace(0.1, 1, 5) / 2}
    _, (parameters_tested, losses) = calibrate_denoiser(
        noisy_img,
        _denoise_wavelet,
        denoise_parameters=parameter_ranges,
        extra_output=True,
    )

    all_denoised = [
        denoise_invariant(noisy_img, _denoise_wavelet, denoiser_kwargs=denoiser_kwargs)
        for denoiser_kwargs in parameters_tested
    ]

    ground_truth_losses = [mse(img, test_img) for img in all_denoised]
    assert_(np.argmin(losses) == np.argmin(ground_truth_losses))


@xfail_without_pywt
def test_calibrate_denoiser():
    parameter_ranges = {'sigma': np.linspace(0.1, 1, 5) / 2}

    denoiser = calibrate_denoiser(
        noisy_img, _denoise_wavelet, denoise_parameters=parameter_ranges
    )

    denoised_mse = mse(denoiser(noisy_img), test_img)
    original_mse = mse(noisy_img, test_img)
    assert_(denoised_mse < original_mse)


@xfail_without_pywt
def test_input_image_not_modified():
    input_image = noisy_img.copy()

    parameter_ranges = {'sigma': np.random.random(5) / 2}
    calibrate_denoiser(
        input_image, _denoise_wavelet, denoise_parameters=parameter_ranges
    )

    assert_(np.all(noisy_img == input_image))
