import numpy as np
import pytest

from skimage.util.compare import compare_images
from skimage._shared.testing import assert_stacklevel


def test_compare_images_ValueError_shape():
    """Comparing a (10, 10) image with a (10, 1) image must raise ValueError."""
    square = np.zeros((10, 10), dtype=np.uint8)
    column = np.zeros((10, 1), dtype=np.uint8)

    with pytest.raises(ValueError):
        compare_images(square, column)


def test_compare_images_ValueError_args():
    """Passing ``method="unknown"`` must raise ValueError."""
    threes = np.full((10, 10), 3.0)
    zeros = np.zeros((10, 10))

    with pytest.raises(ValueError):
        compare_images(threes, zeros, method="unknown")


def test_compare_images_diff():
    """Check the ``'diff'`` result for two partially overlapping blocks.

    Both uint8 inputs hold a 255-valued block on rows 3-7; the second one
    additionally covers columns 0-2. The expected float output is 1 exactly
    on those extra columns and 0 everywhere else.
    """
    first = np.zeros((10, 10), dtype=np.uint8)
    first[3:8, 3:8] = 255

    second = np.zeros_like(first)
    second[3:8, :8] = 255

    expected = np.zeros(first.shape, dtype=np.float64)
    expected[3:8, :3] = 1

    actual = compare_images(first, second, method='diff')
    np.testing.assert_array_equal(actual, expected)


def test_compare_images_replaced_param():
    """Check the FutureWarnings emitted for deprecated call signatures.

    Covers calls that use the old ``image1``/``image2`` keyword names and a
    call that passes ``method`` positionally. Each call must warn with the
    matching message and still return the expected difference image.
    """
    first = np.zeros((10, 10), dtype=np.uint8)
    first[3:8, 3:8] = 255
    second = np.zeros_like(first)
    second[3:8, :8] = 255
    expected = np.zeros(first.shape, dtype=np.float64)
    expected[3:8, :3] = 1

    # Every (positional, keyword) combination below uses the deprecated
    # `image2` keyword, optionally together with `image1`.
    deprecated_calls = (
        ((), {"image1": first, "image2": second}),
        ((), {"image0": first, "image2": second}),
        ((first,), {"image2": second}),
    )
    regex = ".*Please use `image0, image1`.*"
    for args, kwargs in deprecated_calls:
        # NOTE(review): `assert_stacklevel` presumably compares the warning's
        # origin with the line directly above its own call, so keep it
        # immediately after the `compare_images` line -- confirm.
        with pytest.warns(FutureWarning, match=regex) as record:
            result = compare_images(*args, **kwargs)
        assert_stacklevel(record)
        np.testing.assert_array_equal(result, expected)

    # Test making "method" keyword-only here as well
    # so whole test can be removed in one go
    regex = ".*Please pass `method=`.*"
    with pytest.warns(FutureWarning, match=regex) as record:
        result = compare_images(first, second, "diff")
    assert_stacklevel(record)
    np.testing.assert_array_equal(result, expected)


def test_compare_images_blend():
    """Check the ``'blend'`` result for two partially overlapping blocks.

    Where both inputs are 255 the expected output is 1; where only the
    second input is 255 (columns 0-2 of rows 3-7) it is 0.5; elsewhere 0.
    """
    first = np.zeros((10, 10), dtype=np.uint8)
    first[3:8, 3:8] = 255

    second = np.zeros_like(first)
    second[3:8, :8] = 255

    expected = np.zeros(first.shape, dtype=np.float64)
    expected[3:8, :3] = 0.5
    expected[3:8, 3:8] = 1

    actual = compare_images(first, second, method='blend')
    np.testing.assert_array_equal(actual, expected)


def test_compare_images_checkerboard_default():
    """Check ``'checkerboard'`` output when no ``n_tiles`` is passed.

    An all-0 and an all-255 16x16 uint8 image are combined; the expected
    result alternates between 0 and 1 in squares of 2x2 pixels.
    """
    side = 2**4
    dark = np.zeros((side, side), dtype=np.uint8)
    bright = np.full((side, side), 255, dtype=np.uint8)

    res = compare_images(dark, bright, method='checkerboard')

    starts_dark = np.tile([0.0, 0.0, 1.0, 1.0], side // 4)
    starts_bright = np.tile([1.0, 1.0, 0.0, 0.0], side // 4)
    for row in range(side):
        # The row pattern flips every two rows.
        expected = starts_dark if (row // 2) % 2 == 0 else starts_bright
        np.testing.assert_array_equal(res[row, :], expected)


def test_compare_images_checkerboard_tuple():
    """Check ``'checkerboard'`` output for ``n_tiles=(4, 8)``.

    An all-0 and an all-255 16x16 uint8 image are combined; the expected
    result alternates between 0 and 1 in tiles that are 4 rows tall and
    2 columns wide.
    """
    side = 2**4
    dark = np.zeros((side, side), dtype=np.uint8)
    bright = np.full((side, side), 255, dtype=np.uint8)

    res = compare_images(dark, bright, method='checkerboard', n_tiles=(4, 8))

    starts_dark = np.tile([0.0, 0.0, 1.0, 1.0], side // 4)
    starts_bright = np.tile([1.0, 1.0, 0.0, 0.0], side // 4)
    for row in range(side):
        # The row pattern flips every four rows.
        expected = starts_dark if (row // 4) % 2 == 0 else starts_bright
        np.testing.assert_array_equal(res[row, :], expected)
