import numpy as np
import pytest
from scipy import ndimage as ndi
from numpy.testing import assert_allclose, assert_array_equal, assert_equal

from skimage import color, data, transform
from skimage._shared._warnings import expected_warnings
from skimage._shared.testing import fetch, assert_stacklevel
from skimage.morphology import gray, footprints, footprint_rectangle
from skimage.util import img_as_uint, img_as_ubyte


@pytest.fixture
def cam_image():
    from skimage import data

    return np.ascontiguousarray(data.camera()[64:112, 64:96])


@pytest.fixture
def cell3d_image():
    from skimage import data

    return np.ascontiguousarray(data.cells3d()[30:48, 0, 20:36, 20:32])


gray_morphology_funcs = (
    gray.erosion,
    gray.dilation,
    gray.opening,
    gray.closing,
    gray.white_tophat,
    gray.black_tophat,
)


class TestMorphology:
    # These expected outputs were generated with skimage v0.22.0 + PR #6695
    # using:
    #
    #   from skimage.morphology.tests.test_gray import TestMorphology
    #   import numpy as np
    #   output = TestMorphology()._build_expected_output()
    #   np.savez_compressed('gray_morph_output.npz', **output)

    def _build_expected_output(self):
        def square(n):
            return footprint_rectangle((n, n))

        footprints_2D = (
            square,
            footprints.diamond,
            footprints.disk,
            footprints.star,
        )

        image = img_as_ubyte(
            transform.downscale_local_mean(color.rgb2gray(data.coffee()), (20, 20))
        )

        output = {}
        for n in range(1, 4):
            for strel in footprints_2D:
                for func in gray_morphology_funcs:
                    key = f'{strel.__name__}_{n}_{func.__name__}'
                    output[key] = func(image, strel(n))

        return output

    def test_gray_morphology(self):
        expected = dict(np.load(fetch('data/gray_morph_output.npz')))
        calculated = self._build_expected_output()
        assert_equal(expected, calculated)

    def test_gray_closing_extensive(self):
        img = data.coins()
        footprint = np.array([[0, 0, 1], [0, 1, 1], [1, 1, 1]])

        # Default mode="reflect" is not extensive for backwards-compatibility
        result_default = gray.closing(img, footprint=footprint)
        assert not np.all(result_default >= img)

        result = gray.closing(img, footprint=footprint, mode="ignore")
        assert np.all(result >= img)

    def test_gray_opening_anti_extensive(self):
        img = data.coins()
        footprint = np.array([[0, 0, 1], [0, 1, 1], [1, 1, 1]])

        # Default mode="reflect" is not extensive for backwards-compatibility
        result_default = gray.opening(img, footprint=footprint)
        assert not np.all(result_default <= img)

        result_ignore = gray.opening(img, footprint=footprint, mode="ignore")
        assert np.all(result_ignore <= img)

    @pytest.mark.parametrize("func", gray_morphology_funcs)
    @pytest.mark.parametrize("mode", gray._SUPPORTED_MODES)
    def test_supported_mode(self, func, mode):
        img = np.ones((10, 10))
        func(img, mode=mode)

    @pytest.mark.parametrize("func", gray_morphology_funcs)
    @pytest.mark.parametrize("mode", ["", "symmetric", 3, None])
    def test_unsupported_mode(self, func, mode):
        img = np.ones((10, 10))
        with pytest.raises(ValueError, match="unsupported mode"):
            func(img, mode=mode)


class TestEccentricStructuringElements:
    def setup_class(self):
        self.black_pixel = 255 * np.ones((6, 6), dtype=np.uint8)
        self.black_pixel[2, 2] = 0
        self.white_pixel = 255 - self.black_pixel
        self.footprints = [
            footprint_rectangle((2, 2)),
            footprint_rectangle((2, 1)),
            footprint_rectangle((2, 1)),
        ]

    def test_dilate_erode_symmetry(self):
        for s in self.footprints:
            c = gray.erosion(self.black_pixel, s)
            d = gray.dilation(self.white_pixel, s)
            assert np.all(c == (255 - d))

    def test_open_black_pixel(self):
        for s in self.footprints:
            gray_open = gray.opening(self.black_pixel, s)
            assert np.all(gray_open == self.black_pixel)

    def test_close_white_pixel(self):
        for s in self.footprints:
            gray_close = gray.closing(self.white_pixel, s)
            assert np.all(gray_close == self.white_pixel)

    def test_open_white_pixel(self):
        for s in self.footprints:
            assert np.all(gray.opening(self.white_pixel, s) == 0)

    def test_close_black_pixel(self):
        for s in self.footprints:
            assert np.all(gray.closing(self.black_pixel, s) == 255)

    def test_white_tophat_white_pixel(self):
        for s in self.footprints:
            tophat = gray.white_tophat(self.white_pixel, s)
            assert np.all(tophat == self.white_pixel)

    def test_black_tophat_black_pixel(self):
        for s in self.footprints:
            tophat = gray.black_tophat(self.black_pixel, s)
            assert np.all(tophat == self.white_pixel)

    def test_white_tophat_black_pixel(self):
        for s in self.footprints:
            tophat = gray.white_tophat(self.black_pixel, s)
            assert np.all(tophat == 0)

    def test_black_tophat_white_pixel(self):
        for s in self.footprints:
            tophat = gray.black_tophat(self.white_pixel, s)
            assert np.all(tophat == 0)


gray_functions = [
    gray.erosion,
    gray.dilation,
    gray.opening,
    gray.closing,
    gray.white_tophat,
    gray.black_tophat,
]


@pytest.mark.parametrize("function", gray_functions)
def test_default_footprint(function):
    strel = footprints.diamond(radius=1)
    image = np.array(
        [
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 1, 1, 1, 1, 1, 1, 0, 0],
            [0, 0, 1, 1, 1, 1, 1, 1, 0, 0],
            [0, 0, 1, 1, 1, 1, 1, 1, 0, 0],
            [0, 0, 1, 1, 1, 0, 0, 1, 0, 0],
            [0, 0, 1, 1, 1, 0, 0, 1, 0, 0],
            [0, 0, 1, 1, 1, 0, 0, 1, 0, 0],
            [0, 0, 1, 1, 1, 1, 1, 1, 0, 0],
            [0, 0, 1, 1, 1, 1, 1, 1, 0, 0],
            [0, 0, 1, 1, 1, 1, 1, 1, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        ],
        np.uint8,
    )
    im_expected = function(image, strel)
    im_test = function(image)
    assert_array_equal(im_expected, im_test)


def test_3d_fallback_default_footprint():
    # 3x3x3 cube inside a 7x7x7 image:
    image = np.zeros((7, 7, 7), bool)
    image[2:-2, 2:-2, 2:-2] = 1

    opened = gray.opening(image)

    # expect a "hyper-cross" centered in the 5x5x5:
    image_expected = np.zeros((7, 7, 7), dtype=bool)
    image_expected[2:5, 2:5, 2:5] = ndi.generate_binary_structure(3, 1)
    assert_array_equal(opened, image_expected)


gray_3d_fallback_functions = [gray.closing, gray.opening]


@pytest.mark.parametrize("function", gray_3d_fallback_functions)
def test_3d_fallback_cube_footprint(function):
    # 3x3x3 cube inside a 7x7x7 image:
    image = np.zeros((7, 7, 7), bool)
    image[2:-2, 2:-2, 2:-2] = 1

    cube = np.ones((3, 3, 3), dtype=np.uint8)

    new_image = function(image, cube)
    assert_array_equal(new_image, image)


def test_3d_fallback_white_tophat():
    image = np.zeros((7, 7, 7), dtype=bool)
    image[2, 2:4, 2:4] = 1
    image[3, 2:5, 2:5] = 1
    image[4, 3:5, 3:5] = 1

    with expected_warnings([r'operator.*deprecated|\A\Z']):
        new_image = gray.white_tophat(image)
    footprint = ndi.generate_binary_structure(3, 1)
    with expected_warnings([r'operator.*deprecated|\A\Z']):
        image_expected = ndi.white_tophat(
            image.view(dtype=np.uint8), footprint=footprint
        )
    assert_array_equal(new_image, image_expected)


def test_3d_fallback_black_tophat():
    image = np.ones((7, 7, 7), dtype=bool)
    image[2, 2:4, 2:4] = 0
    image[3, 2:5, 2:5] = 0
    image[4, 3:5, 3:5] = 0

    with expected_warnings([r'operator.*deprecated|\A\Z']):
        new_image = gray.black_tophat(image)
    footprint = ndi.generate_binary_structure(3, 1)
    with expected_warnings([r'operator.*deprecated|\A\Z']):
        image_expected = ndi.black_tophat(
            image.view(dtype=np.uint8), footprint=footprint
        )
    assert_array_equal(new_image, image_expected)


def test_2d_ndimage_equivalence():
    image = np.zeros((9, 9), np.uint8)
    image[2:-2, 2:-2] = 128
    image[3:-3, 3:-3] = 196
    image[4, 4] = 255

    opened = gray.opening(image)
    closed = gray.closing(image)

    footprint = ndi.generate_binary_structure(2, 1)
    ndimage_opened = ndi.grey_opening(image, footprint=footprint)
    ndimage_closed = ndi.grey_closing(image, footprint=footprint)

    assert_array_equal(opened, ndimage_opened)
    assert_array_equal(closed, ndimage_closed)


# float test images
im = np.array(
    [
        [0.55, 0.72, 0.6, 0.54, 0.42],
        [0.65, 0.44, 0.89, 0.96, 0.38],
        [0.79, 0.53, 0.57, 0.93, 0.07],
        [0.09, 0.02, 0.83, 0.78, 0.87],
        [0.98, 0.8, 0.46, 0.78, 0.12],
    ]
)

eroded = np.array(
    [
        [0.55, 0.44, 0.54, 0.42, 0.38],
        [0.44, 0.44, 0.44, 0.38, 0.07],
        [0.09, 0.02, 0.53, 0.07, 0.07],
        [0.02, 0.02, 0.02, 0.78, 0.07],
        [0.09, 0.02, 0.46, 0.12, 0.12],
    ]
)

dilated = np.array(
    [
        [0.72, 0.72, 0.89, 0.96, 0.54],
        [0.79, 0.89, 0.96, 0.96, 0.96],
        [0.79, 0.79, 0.93, 0.96, 0.93],
        [0.98, 0.83, 0.83, 0.93, 0.87],
        [0.98, 0.98, 0.83, 0.78, 0.87],
    ]
)

opened = np.array(
    [
        [0.55, 0.55, 0.54, 0.54, 0.42],
        [0.55, 0.44, 0.54, 0.44, 0.38],
        [0.44, 0.53, 0.53, 0.78, 0.07],
        [0.09, 0.02, 0.78, 0.78, 0.78],
        [0.09, 0.46, 0.46, 0.78, 0.12],
    ]
)

closed = np.array(
    [
        [0.72, 0.72, 0.72, 0.54, 0.54],
        [0.72, 0.72, 0.89, 0.96, 0.54],
        [0.79, 0.79, 0.79, 0.93, 0.87],
        [0.79, 0.79, 0.83, 0.78, 0.87],
        [0.98, 0.83, 0.78, 0.78, 0.78],
    ]
)


def test_float():
    assert_allclose(gray.erosion(im), eroded)
    assert_allclose(gray.dilation(im), dilated)
    assert_allclose(gray.opening(im), opened)
    assert_allclose(gray.closing(im), closed)


def test_uint16():
    im16, eroded16, dilated16, opened16, closed16 = map(
        img_as_uint, [im, eroded, dilated, opened, closed]
    )
    assert_allclose(gray.erosion(im16), eroded16)
    assert_allclose(gray.dilation(im16), dilated16)
    assert_allclose(gray.opening(im16), opened16)
    assert_allclose(gray.closing(im16), closed16)


def test_discontiguous_out_array():
    image = np.array([[5, 6, 2], [7, 2, 2], [3, 5, 1]], np.uint8)
    out_array_big = np.zeros((5, 5), np.uint8)
    out_array = out_array_big[::2, ::2]
    expected_dilation = np.array(
        [
            [7, 0, 6, 0, 6],
            [0, 0, 0, 0, 0],
            [7, 0, 7, 0, 2],
            [0, 0, 0, 0, 0],
            [7, 0, 5, 0, 5],
        ],
        np.uint8,
    )
    expected_erosion = np.array(
        [
            [5, 0, 2, 0, 2],
            [0, 0, 0, 0, 0],
            [2, 0, 2, 0, 1],
            [0, 0, 0, 0, 0],
            [3, 0, 1, 0, 1],
        ],
        np.uint8,
    )
    gray.dilation(image, out=out_array)
    assert_array_equal(out_array_big, expected_dilation)
    gray.erosion(image, out=out_array)
    assert_array_equal(out_array_big, expected_erosion)


def test_1d_erosion():
    image = np.array([1, 2, 3, 2, 1])
    expected = np.array([1, 1, 2, 1, 1])
    eroded = gray.erosion(image)
    assert_array_equal(eroded, expected)


@pytest.mark.parametrize(
    "function",
    ["erosion", "dilation", "closing", "opening", "white_tophat", "black_tophat"],
)
@pytest.mark.parametrize("nrows", [3, 7, 11])
@pytest.mark.parametrize("ncols", [3, 7, 11])
@pytest.mark.parametrize("decomposition", ['separable', 'sequence'])
def test_rectangle_decomposition(cam_image, function, nrows, ncols, decomposition):
    """Validate footprint decomposition for various shapes.

    comparison is made to the case without decomposition.
    """
    footprint_ndarray = footprint_rectangle((nrows, ncols), decomposition=None)
    footprint = footprint_rectangle((nrows, ncols), decomposition=decomposition)
    func = getattr(gray, function)
    expected = func(cam_image, footprint=footprint_ndarray)
    out = func(cam_image, footprint=footprint)
    assert_array_equal(expected, out)


@pytest.mark.parametrize(
    "function",
    ["erosion", "dilation", "closing", "opening", "white_tophat", "black_tophat"],
)
@pytest.mark.parametrize("radius", (2, 3))
@pytest.mark.parametrize("decomposition", ['sequence'])
def test_diamond_decomposition(cam_image, function, radius, decomposition):
    """Validate footprint decomposition for various shapes.

    comparison is made to the case without decomposition.
    """
    footprint_ndarray = footprints.diamond(radius, decomposition=None)
    footprint = footprints.diamond(radius, decomposition=decomposition)
    func = getattr(gray, function)
    expected = func(cam_image, footprint=footprint_ndarray)
    out = func(cam_image, footprint=footprint)
    assert_array_equal(expected, out)


@pytest.mark.parametrize(
    "function",
    ["erosion", "dilation", "closing", "opening", "white_tophat", "black_tophat"],
)
@pytest.mark.parametrize("m", (0, 1, 3, 5))
@pytest.mark.parametrize("n", (0, 1, 2, 3))
@pytest.mark.parametrize("decomposition", ['sequence'])
@pytest.mark.filterwarnings(
    "ignore:.*falling back to decomposition='separable':UserWarning:skimage"
)
def test_octagon_decomposition(cam_image, function, m, n, decomposition):
    """Validate footprint decomposition for various shapes.

    comparison is made to the case without decomposition.
    """
    if m == 0 and n == 0:
        with pytest.raises(ValueError):
            footprints.octagon(m, n, decomposition=decomposition)
    else:
        footprint_ndarray = footprints.octagon(m, n, decomposition=None)
        footprint = footprints.octagon(m, n, decomposition=decomposition)
        func = getattr(gray, function)
        expected = func(cam_image, footprint=footprint_ndarray)
        out = func(cam_image, footprint=footprint)
        assert_array_equal(expected, out)


@pytest.mark.parametrize(
    "function",
    ["erosion", "dilation", "closing", "opening", "white_tophat", "black_tophat"],
)
@pytest.mark.parametrize("shape", [(5, 5, 5), (5, 5, 7)])
@pytest.mark.parametrize("decomposition", ['separable', 'sequence'])
def test_cube_decomposition(cell3d_image, function, shape, decomposition):
    """Validate footprint decomposition for various shapes.

    comparison is made to the case without decomposition.
    """
    footprint_ndarray = footprint_rectangle(shape, decomposition=None)
    footprint = footprint_rectangle(shape, decomposition=decomposition)
    func = getattr(gray, function)
    expected = func(cell3d_image, footprint=footprint_ndarray)
    out = func(cell3d_image, footprint=footprint)
    assert_array_equal(expected, out)


@pytest.mark.parametrize(
    "function",
    ["erosion", "dilation", "closing", "opening", "white_tophat", "black_tophat"],
)
@pytest.mark.parametrize("radius", (3,))
@pytest.mark.parametrize("decomposition", ['sequence'])
def test_octahedron_decomposition(cell3d_image, function, radius, decomposition):
    """Validate footprint decomposition for various shapes.

    comparison is made to the case without decomposition.
    """
    footprint_ndarray = footprints.octahedron(radius, decomposition=None)
    footprint = footprints.octahedron(radius, decomposition=decomposition)
    func = getattr(gray, function)
    expected = func(cell3d_image, footprint=footprint_ndarray)
    out = func(cell3d_image, footprint=footprint)
    assert_array_equal(expected, out)


@pytest.mark.parametrize("func", [gray.erosion, gray.dilation])
@pytest.mark.parametrize("name", ["shift_x", "shift_y"])
@pytest.mark.parametrize("value", [True, False, None])
def test_deprecated_shift(func, name, value):
    img = np.ones(10)
    func(img)  # Shouldn't warn

    regex = "`shift_x` and `shift_y` are deprecated"
    with pytest.warns(FutureWarning, match=regex) as record:
        func(img, **{name: value})
    assert_stacklevel(record)
