"""
Tests for Rolling Ball Filter
(skimage.restoration.rolling_ball)
"""

import numpy as np
import pytest

from skimage import data
from skimage.restoration._rolling_ball import rolling_ball
from skimage.restoration._rolling_ball import ellipsoid_kernel


@pytest.mark.parametrize(
    'dtype', [np.uint8, np.int32, np.float16, np.float32, np.float64]
)
def test_ellipsoid_const(dtype):
    img = 155 * np.ones((100, 100), dtype=dtype)
    kernel = ellipsoid_kernel((25, 53), 50)
    background = rolling_ball(img, kernel=kernel)
    assert np.allclose(img - background, np.zeros_like(img))
    assert background.dtype == img.dtype


def test_nan_const():
    img = 123 * np.ones((100, 100), dtype=float)
    img[20, 20] = np.nan
    img[50, 53] = np.nan

    kernel_shape = (10, 10)
    x = np.arange(-kernel_shape[1] // 2, kernel_shape[1] // 2 + 1)[np.newaxis, :]
    y = np.arange(-kernel_shape[0] // 2, kernel_shape[0] // 2 + 1)[:, np.newaxis]
    expected_img = np.zeros_like(img)
    expected_img[y + 20, x + 20] = np.nan
    expected_img[y + 50, x + 53] = np.nan
    kernel = ellipsoid_kernel(kernel_shape, 100)
    background = rolling_ball(img, kernel=kernel, nansafe=True)
    assert np.allclose(img - background, expected_img, equal_nan=True)


@pytest.mark.parametrize("radius", [1, 2.5, 10.346, 50])
def test_const_image(radius):
    # infinite plane light source at top left corner
    img = 23 * np.ones((100, 100), dtype=np.uint8)
    background = rolling_ball(img, radius=radius)
    assert np.allclose(img - background, np.zeros_like(img))


def test_radial_gradient():
    # spot light source at top left corner
    spot_radius = 50
    x, y = np.meshgrid(range(5), range(5))
    img = np.sqrt(np.clip(spot_radius**2 - y**2 - x**2, 0, None))

    background = rolling_ball(img, radius=5)
    assert np.allclose(img - background, np.zeros_like(img))


def test_linear_gradient():
    # linear light source centered at top left corner
    x, y = np.meshgrid(range(100), range(100))
    img = y * 20 + x * 20

    expected_img = 19 * np.ones_like(img)
    expected_img[0, 0] = 0

    background = rolling_ball(img, radius=1)
    assert np.allclose(img - background, expected_img)


@pytest.mark.parametrize("radius", [2, 10, 12.5, 50])
def test_preserve_peaks(radius):
    x, y = np.meshgrid(range(100), range(100))
    img = 0 * x + 0 * y + 10
    img[10, 10] = 20
    img[20, 20] = 35
    img[45, 26] = 156

    expected_img = img - 10
    background = rolling_ball(img, radius=radius)
    assert np.allclose(img - background, expected_img)


@pytest.mark.parametrize("num_threads", [None, 1, 2])
def test_threads(num_threads):
    # not testing if we use multiple threads
    # just checking if the API throws an exception
    img = 23 * np.ones((100, 100), dtype=np.uint8)
    rolling_ball(img, radius=10, num_threads=num_threads)
    rolling_ball(img, radius=10, nansafe=True, num_threads=num_threads)


def test_ndim():
    image = data.cells3d()[:5, 1, ...]
    kernel = ellipsoid_kernel((3, 100, 100), 100)
    rolling_ball(image, kernel=kernel)
