import math
import pytest
import numpy as np

from scipy import stats, special
import scipy._lib._elementwise_iterative_method as eim
from scipy.conftest import array_api_compatible
from scipy._lib._array_api import array_namespace, is_cupy, is_numpy, xp_ravel, xp_size
from scipy._lib._array_api_no_0d import (xp_assert_close, xp_assert_equal,
                                         xp_assert_less)

from scipy.optimize.elementwise import find_minimum, find_root
from scipy.optimize._tstutils import _CHANDRUPATLA_TESTS

from itertools import permutations
from .test_zeros import TestScalarRootFinders


def _vectorize(xp):
    # xp-compatible version of np.vectorize
    # assumes arguments are all arrays of the same shape
    def decorator(f):
        def wrapped(*arg_arrays):
            shape = arg_arrays[0].shape
            arg_arrays = [xp_ravel(arg_array, xp=xp) for arg_array in arg_arrays]
            res = []
            for i in range(math.prod(shape)):
                arg_scalars = [arg_array[i] for arg_array in arg_arrays]
                res.append(f(*arg_scalars))
            return res

        return wrapped

    return decorator


# These tests were originally written for the private `optimize._chandrupatla`
# interfaces, but now we want the tests to check the behavior of the public
# `optimize.elementwise` interfaces. Therefore, rather than importing
# `_chandrupatla`/`_chandrupatla_minimize` from `_chandrupatla.py`, we import
# `find_root`/`find_minimum` from `optimize.elementwise` and wrap those
# functions to conform to the private interface. This may look a little strange,
# since it effectively just inverts the interface transformation done within the
# `find_root`/`find_minimum` functions, but it allows us to run the original,
# unmodified tests on the public interfaces, simplifying the PR that adds
# the public interfaces. We'll refactor this when we want to @parametrize the
# tests over multiple `method`s.
def _wrap_chandrupatla(func):
    def _chandrupatla_wrapper(f, *bracket, **kwargs):
        # avoid passing arguments to `find_minimum` to this function
        tol_keys = {'xatol', 'xrtol', 'fatol', 'frtol'}
        tolerances = {key: kwargs.pop(key) for key in tol_keys if key in kwargs}
        _callback = kwargs.pop('callback', None)
        if callable(_callback):
            def callback(res):
                if func == find_root:
                    res.xl, res.xr = res.bracket
                    res.fl, res.fr = res.f_bracket
                else:
                    res.xl, res.xm, res.xr = res.bracket
                    res.fl, res.fm, res.fr = res.f_bracket
                res.fun = res.f_x
                del res.bracket
                del res.f_bracket
                del res.f_x
                return _callback(res)
        else:
            callback = _callback

        res = func(f, bracket, tolerances=tolerances, callback=callback, **kwargs)
        if func == find_root:
            res.xl, res.xr = res.bracket
            res.fl, res.fr = res.f_bracket
        else:
            res.xl, res.xm, res.xr = res.bracket
            res.fl, res.fm, res.fr = res.f_bracket
        res.fun = res.f_x
        del res.bracket
        del res.f_bracket
        del res.f_x
        return res
    return _chandrupatla_wrapper


_chandrupatla_root = _wrap_chandrupatla(find_root)
_chandrupatla_minimize = _wrap_chandrupatla(find_minimum)


def f1(x):
    return 100*(1 - x**3.)**2 + (1-x**2.) + 2*(1-x)**2.


def f2(x):
    return 5 + (x - 2.)**6


def f3(x):
    xp = array_namespace(x)
    return xp.exp(x) - 5*x


def f4(x):
    return x**5. - 5*x**3. - 20.*x + 5.


def f5(x):
    return 8*x**3 - 2*x**2 - 7*x + 3


def _bracket_minimum(func, x1, x2):
    phi = 1.61803398875
    maxiter = 100
    f1 = func(x1)
    f2 = func(x2)
    step = x2 - x1
    x1, x2, f1, f2, step = ((x2, x1, f2, f1, -step) if f2 > f1
                            else (x1, x2, f1, f2, step))

    for i in range(maxiter):
        step *= phi
        x3 = x2 + step
        f3 = func(x3)
        if f3 < f2:
            x1, x2, f1, f2 = x2, x3, f2, f3
        else:
            break
    return x1, x2, x3, f1, f2, f3


cases = [
    (f1, -1, 11),
    (f1, -2, 13),
    (f1, -4, 13),
    (f1, -8, 15),
    (f1, -16, 16),
    (f1, -32, 19),
    (f1, -64, 20),
    (f1, -128, 21),
    (f1, -256, 21),
    (f1, -512, 19),
    (f1, -1024, 24),
    (f2, -1, 8),
    (f2, -2, 6),
    (f2, -4, 6),
    (f2, -8, 7),
    (f2, -16, 8),
    (f2, -32, 8),
    (f2, -64, 9),
    (f2, -128, 11),
    (f2, -256, 13),
    (f2, -512, 12),
    (f2, -1024, 13),
    (f3, -1, 11),
    (f3, -2, 11),
    (f3, -4, 11),
    (f3, -8, 10),
    (f3, -16, 14),
    (f3, -32, 12),
    (f3, -64, 15),
    (f3, -128, 18),
    (f3, -256, 18),
    (f3, -512, 19),
    (f3, -1024, 19),
    (f4, -0.05, 9),
    (f4, -0.10, 11),
    (f4, -0.15, 11),
    (f4, -0.20, 11),
    (f4, -0.25, 11),
    (f4, -0.30, 9),
    (f4, -0.35, 9),
    (f4, -0.40, 9),
    (f4, -0.45, 10),
    (f4, -0.50, 10),
    (f4, -0.55, 10),
    (f5, -0.05, 6),
    (f5, -0.10, 7),
    (f5, -0.15, 8),
    (f5, -0.20, 10),
    (f5, -0.25, 9),
    (f5, -0.30, 8),
    (f5, -0.35, 7),
    (f5, -0.40, 7),
    (f5, -0.45, 9),
    (f5, -0.50, 9),
    (f5, -0.55, 8)
]


@array_api_compatible
@pytest.mark.usefixtures("skip_xp_backends")
@pytest.mark.skip_xp_backends('jax.numpy',
                              reason='JAX arrays do not support item assignment.')
@pytest.mark.skip_xp_backends('array_api_strict',
                              reason='Currently uses fancy indexing assignment.')
class TestChandrupatlaMinimize:

    def f(self, x, loc):
        xp = array_namespace(x, loc)
        res = -xp.exp(-1/2 * (x-loc)**2) / (2*xp.pi)**0.5
        return xp.asarray(res, dtype=x.dtype)[()]

    @pytest.mark.parametrize('dtype', ('float32', 'float64'))
    @pytest.mark.parametrize('loc', [0.6, np.linspace(-1.05, 1.05, 10)])
    def test_basic(self, loc, xp, dtype):
        # Find mode of normal distribution. Compare mode against location
        # parameter and value of pdf at mode against expected pdf.
        rtol = {'float32': 5e-3, 'float64': 5e-7}[dtype]
        dtype = getattr(xp, dtype)
        bracket = (xp.asarray(xi, dtype=dtype) for xi in (-5, 0, 5))
        loc = xp.asarray(loc, dtype=dtype)
        fun = xp.broadcast_to(xp.asarray(-stats.norm.pdf(0), dtype=dtype), loc.shape)

        res = _chandrupatla_minimize(self.f, *bracket, args=(loc,))
        xp_assert_close(res.x, loc, rtol=rtol)
        xp_assert_equal(res.fun, fun)

    @pytest.mark.parametrize('shape', [tuple(), (12,), (3, 4), (3, 2, 2)])
    def test_vectorization(self, shape, xp):
        # Test for correct functionality, output shapes, and dtypes for various
        # input shapes.
        loc = xp.linspace(-0.05, 1.05, 12).reshape(shape) if shape else xp.asarray(0.6)
        args = (loc,)
        bracket = xp.asarray(-5.), xp.asarray(0.), xp.asarray(5.)
        xp_test = array_namespace(loc)  # need xp.stack

        @_vectorize(xp)
        def chandrupatla_single(loc_single):
            return _chandrupatla_minimize(self.f, *bracket, args=(loc_single,))

        def f(*args, **kwargs):
            f.f_evals += 1
            return self.f(*args, **kwargs)
        f.f_evals = 0

        res = _chandrupatla_minimize(f, *bracket, args=args)
        refs = chandrupatla_single(loc)

        attrs = ['x', 'fun', 'success', 'status', 'nfev', 'nit',
                 'xl', 'xm', 'xr', 'fl', 'fm', 'fr']
        for attr in attrs:
            ref_attr = xp_test.stack([getattr(ref, attr) for ref in refs])
            res_attr = xp_ravel(getattr(res, attr))
            xp_assert_equal(res_attr, ref_attr)
            assert getattr(res, attr).shape == shape

        xp_assert_equal(res.fun, self.f(res.x, *args))
        xp_assert_equal(res.fl, self.f(res.xl, *args))
        xp_assert_equal(res.fm, self.f(res.xm, *args))
        xp_assert_equal(res.fr, self.f(res.xr, *args))
        assert xp.max(res.nfev) == f.f_evals
        assert xp.max(res.nit) == f.f_evals - 3

        assert xp_test.isdtype(res.success.dtype, 'bool')
        assert xp_test.isdtype(res.status.dtype, 'integral')
        assert xp_test.isdtype(res.nfev.dtype, 'integral')
        assert xp_test.isdtype(res.nit.dtype, 'integral')


    def test_flags(self, xp):
        # Test cases that should produce different status flags; show that all
        # can be produced simultaneously.
        def f(xs, js):
            funcs = [lambda x: (x - 2.5) ** 2,
                     lambda x: x - 10,
                     lambda x: (x - 2.5) ** 4,
                     lambda x: xp.full_like(x, xp.asarray(xp.nan))]
            res = []
            for i in range(xp_size(js)):
                x = xs[i, ...]
                j = int(xp_ravel(js)[i])
                res.append(funcs[j](x))
            return xp.stack(res)

        args = (xp.arange(4, dtype=xp.int64),)
        bracket = (xp.asarray([0]*4, dtype=xp.float64),
                   xp.asarray([2]*4, dtype=xp.float64),
                   xp.asarray([np.pi]*4, dtype=xp.float64))
        res = _chandrupatla_minimize(f, *bracket, args=args, maxiter=10)

        ref_flags = xp.asarray([eim._ECONVERGED, eim._ESIGNERR, eim._ECONVERR,
                                eim._EVALUEERR], dtype=xp.int32)
        xp_assert_equal(res.status, ref_flags)

    def test_convergence(self, xp):
        # Test that the convergence tolerances behave as expected
        rng = np.random.default_rng(2585255913088665241)
        p = xp.asarray(rng.random(size=3))
        bracket = (xp.asarray(-5), xp.asarray(0), xp.asarray(5))
        args = (p,)
        kwargs0 = dict(args=args, xatol=0, xrtol=0, fatol=0, frtol=0)

        kwargs = kwargs0.copy()
        kwargs['xatol'] = 1e-3
        res1 = _chandrupatla_minimize(self.f, *bracket, **kwargs)
        j1 = xp.abs(res1.xr - res1.xl)
        tol = xp.asarray(4*kwargs['xatol'], dtype=p.dtype)
        xp_assert_less(j1, xp.full((3,), tol, dtype=p.dtype))
        kwargs['xatol'] = 1e-6
        res2 = _chandrupatla_minimize(self.f, *bracket, **kwargs)
        j2 = xp.abs(res2.xr - res2.xl)
        tol = xp.asarray(4*kwargs['xatol'], dtype=p.dtype)
        xp_assert_less(j2, xp.full((3,), tol, dtype=p.dtype))
        xp_assert_less(j2, j1)

        kwargs = kwargs0.copy()
        kwargs['xrtol'] = 1e-3
        res1 = _chandrupatla_minimize(self.f, *bracket, **kwargs)
        j1 = xp.abs(res1.xr - res1.xl)
        tol = xp.asarray(4*kwargs['xrtol']*xp.abs(res1.x), dtype=p.dtype)
        xp_assert_less(j1, tol)
        kwargs['xrtol'] = 1e-6
        res2 = _chandrupatla_minimize(self.f, *bracket, **kwargs)
        j2 = xp.abs(res2.xr - res2.xl)
        tol = xp.asarray(4*kwargs['xrtol']*xp.abs(res2.x), dtype=p.dtype)
        xp_assert_less(j2, tol)
        xp_assert_less(j2, j1)

        kwargs = kwargs0.copy()
        kwargs['fatol'] = 1e-3
        res1 = _chandrupatla_minimize(self.f, *bracket, **kwargs)
        h1 = xp.abs(res1.fl - 2 * res1.fm + res1.fr)
        tol = xp.asarray(2*kwargs['fatol'], dtype=p.dtype)
        xp_assert_less(h1, xp.full((3,), tol, dtype=p.dtype))
        kwargs['fatol'] = 1e-6
        res2 = _chandrupatla_minimize(self.f, *bracket, **kwargs)
        h2 = xp.abs(res2.fl - 2 * res2.fm + res2.fr)
        tol = xp.asarray(2*kwargs['fatol'], dtype=p.dtype)
        xp_assert_less(h2, xp.full((3,), tol, dtype=p.dtype))
        xp_assert_less(h2, h1)

        kwargs = kwargs0.copy()
        kwargs['frtol'] = 1e-3
        res1 = _chandrupatla_minimize(self.f, *bracket, **kwargs)
        h1 = xp.abs(res1.fl - 2 * res1.fm + res1.fr)
        tol = xp.asarray(2*kwargs['frtol']*xp.abs(res1.fun), dtype=p.dtype)
        xp_assert_less(h1, tol)
        kwargs['frtol'] = 1e-6
        res2 = _chandrupatla_minimize(self.f, *bracket, **kwargs)
        h2 = xp.abs(res2.fl - 2 * res2.fm + res2.fr)
        tol = xp.asarray(2*kwargs['frtol']*abs(res2.fun), dtype=p.dtype)
        xp_assert_less(h2, tol)
        xp_assert_less(h2, h1)

    def test_maxiter_callback(self, xp):
        # Test behavior of `maxiter` parameter and `callback` interface
        loc = xp.asarray(0.612814)
        bracket = (xp.asarray(-5), xp.asarray(0), xp.asarray(5))
        maxiter = 5

        res = _chandrupatla_minimize(self.f, *bracket, args=(loc,),
                                     maxiter=maxiter)
        assert not xp.any(res.success)
        assert xp.all(res.nfev == maxiter+3)
        assert xp.all(res.nit == maxiter)

        def callback(res):
            callback.iter += 1
            callback.res = res
            assert hasattr(res, 'x')
            if callback.iter == 0:
                # callback is called once with initial bracket
                assert (res.xl, res.xm, res.xr) == bracket
            else:
                changed_xr = (res.xl == callback.xl) & (res.xr != callback.xr)
                changed_xl = (res.xl != callback.xl) & (res.xr == callback.xr)
                assert xp.all(changed_xr | changed_xl)

            callback.xl = res.xl
            callback.xr = res.xr
            assert res.status == eim._EINPROGRESS
            xp_assert_equal(self.f(res.xl, loc), res.fl)
            xp_assert_equal(self.f(res.xm, loc), res.fm)
            xp_assert_equal(self.f(res.xr, loc), res.fr)
            xp_assert_equal(self.f(res.x, loc), res.fun)
            if callback.iter == maxiter:
                raise StopIteration

        callback.xl = xp.nan
        callback.xr = xp.nan
        callback.iter = -1  # callback called once before first iteration
        callback.res = None

        res2 = _chandrupatla_minimize(self.f, *bracket, args=(loc,),
                                      callback=callback)

        # terminating with callback is identical to terminating due to maxiter
        # (except for `status`)
        for key in res.keys():
            if key == 'status':
                assert res[key] == eim._ECONVERR
                # assert callback.res[key] == eim._EINPROGRESS
                assert res2[key] == eim._ECALLBACK
            else:
                assert res2[key] == callback.res[key] == res[key]

    @pytest.mark.parametrize('case', cases)
    def test_nit_expected(self, case, xp):
        # Test that `_chandrupatla` implements Chandrupatla's algorithm:
        # in all 55 test cases, the number of iterations performed
        # matches the number reported in the original paper.
        func, x1, nit = case

        # Find bracket using the algorithm in the paper
        step = 0.2
        x2 = x1 + step
        x1, x2, x3, f1, f2, f3 = _bracket_minimum(func, x1, x2)

        # Use tolerances from original paper
        xatol = 0.0001
        fatol = 0.000001
        xrtol = 1e-16
        frtol = 1e-16

        bracket = xp.asarray(x1), xp.asarray(x2), xp.asarray(x3, dtype=xp.float64)
        res = _chandrupatla_minimize(func, *bracket, xatol=xatol,
                                     fatol=fatol, xrtol=xrtol, frtol=frtol)
        xp_assert_equal(res.nit, xp.asarray(nit, dtype=xp.int32))

    @pytest.mark.parametrize("loc", (0.65, [0.65, 0.7]))
    @pytest.mark.parametrize("dtype", ('float16', 'float32', 'float64'))
    def test_dtype(self, loc, dtype, xp):
        # Test that dtypes are preserved
        dtype = getattr(xp, dtype)

        loc = xp.asarray(loc, dtype=dtype)
        bracket = (xp.asarray(-3, dtype=dtype),
                   xp.asarray(1, dtype=dtype),
                   xp.asarray(5, dtype=dtype))

        xp_test = array_namespace(loc)  # need astype
        def f(x, loc):
            assert x.dtype == dtype
            return xp_test.astype((x - loc)**2, dtype)

        res = _chandrupatla_minimize(f, *bracket, args=(loc,))
        assert res.x.dtype == dtype
        xp_assert_close(res.x, loc, rtol=math.sqrt(xp.finfo(dtype).eps))

    def test_input_validation(self, xp):
        # Test input validation for appropriate error messages

        message = '`func` must be callable.'
        bracket = xp.asarray(-4), xp.asarray(0), xp.asarray(4)
        with pytest.raises(ValueError, match=message):
            _chandrupatla_minimize(None, *bracket)

        message = 'Abscissae and function output must be real numbers.'
        bracket = xp.asarray(-4 + 1j), xp.asarray(0), xp.asarray(4)
        with pytest.raises(ValueError, match=message):
            _chandrupatla_minimize(lambda x: x, *bracket)

        message = "...be broadcast..."
        bracket = xp.asarray([-2, -3]), xp.asarray([0, 0]), xp.asarray([3, 4, 5])
        # raised by `np.broadcast, but the traceback is readable IMO
        with pytest.raises((ValueError, RuntimeError), match=message):
            _chandrupatla_minimize(lambda x: x, *bracket)

        message = "The shape of the array returned by `func` must be the same"
        bracket = xp.asarray([-3, -3]), xp.asarray([0, 0]), xp.asarray([5, 5])
        with pytest.raises(ValueError, match=message):
            _chandrupatla_minimize(lambda x: [x[0, ...], x[1, ...], x[1, ...]],
                                   *bracket)

        message = 'Tolerances must be non-negative scalars.'
        bracket = xp.asarray(-4), xp.asarray(0), xp.asarray(4)
        with pytest.raises(ValueError, match=message):
            _chandrupatla_minimize(lambda x: x, *bracket, xatol=-1)
        with pytest.raises(ValueError, match=message):
            _chandrupatla_minimize(lambda x: x, *bracket, xrtol=xp.nan)
        with pytest.raises(ValueError, match=message):
            _chandrupatla_minimize(lambda x: x, *bracket, fatol='ekki')
        with pytest.raises(ValueError, match=message):
            _chandrupatla_minimize(lambda x: x, *bracket, frtol=xp.nan)

        message = '`maxiter` must be a non-negative integer.'
        with pytest.raises(ValueError, match=message):
            _chandrupatla_minimize(lambda x: x, *bracket, maxiter=1.5)
        with pytest.raises(ValueError, match=message):
            _chandrupatla_minimize(lambda x: x, *bracket, maxiter=-1)

        message = '`callback` must be callable.'
        with pytest.raises(ValueError, match=message):
            _chandrupatla_minimize(lambda x: x, *bracket, callback='shrubbery')

    def test_bracket_order(self, xp):
        # Confirm that order of points in bracket doesn't
        xp_test = array_namespace(xp.asarray(1.))  # need `xp.newaxis`
        loc = xp.linspace(-1, 1, 6)[:, xp_test.newaxis]
        brackets = xp.asarray(list(permutations([-5, 0, 5]))).T
        res = _chandrupatla_minimize(self.f, *brackets, args=(loc,))
        assert xp.all(xp.isclose(res.x, loc) | (res.fun == self.f(loc, loc)))
        ref = res.x[:, 0]  # all columns should be the same
        xp_test = array_namespace(loc)  # need `xp.broadcast_arrays
        xp_assert_close(*xp_test.broadcast_arrays(res.x.T, ref), rtol=1e-15)

    def test_special_cases(self, xp):
        # Test edge cases and other special cases

        # Test that integers are not passed to `f`
        xp_test = array_namespace(xp.asarray(1.))  # need `xp.isdtype`
        def f(x):
            assert xp_test.isdtype(x.dtype, "real floating")
            return (x - 1)**2

        bracket = xp.asarray(-7), xp.asarray(0), xp.asarray(8)
        with np.errstate(invalid='ignore'):
            res = _chandrupatla_minimize(f, *bracket, fatol=0, frtol=0)
        assert res.success
        xp_assert_close(res.x, xp.asarray(1.), rtol=1e-3)
        xp_assert_close(res.fun, xp.asarray(0.), atol=1e-200)

        # Test that if all elements of bracket equal minimizer, algorithm
        # reports convergence
        def f(x):
            return (x-1)**2

        bracket = xp.asarray(1), xp.asarray(1), xp.asarray(1)
        res = _chandrupatla_minimize(f, *bracket)
        assert res.success
        xp_assert_equal(res.x, xp.asarray(1.))

        # Test maxiter = 0. Should do nothing to bracket.
        def f(x):
            return (x-1)**2

        bracket = xp.asarray(-3), xp.asarray(1.1), xp.asarray(5)
        res = _chandrupatla_minimize(f, *bracket, maxiter=0)
        assert res.xl, res.xr == bracket
        assert res.nit == 0
        assert res.nfev == 3
        assert res.status == -2
        assert res.x == 1.1  # best so far

        # Test scalar `args` (not in tuple)
        def f(x, c):
            return (x-c)**2 - 1

        bracket = xp.asarray(-1), xp.asarray(0), xp.asarray(1)
        c = xp.asarray(1/3)
        res = _chandrupatla_minimize(f, *bracket, args=(c,))
        xp_assert_close(res.x, c)

        # Test zero tolerances
        def f(x):
            return -xp.sin(x)

        bracket = xp.asarray(0), xp.asarray(1), xp.asarray(xp.pi)
        res = _chandrupatla_minimize(f, *bracket, xatol=0, xrtol=0, fatol=0, frtol=0)
        assert res.success
        # found a minimum exactly (according to floating point arithmetic)
        assert res.xl < res.xm < res.xr
        assert f(res.xl) == f(res.xm) == f(res.xr)


@array_api_compatible
@pytest.mark.usefixtures("skip_xp_backends")
@pytest.mark.skip_xp_backends('array_api_strict',
                              reason='Currently uses fancy indexing assignment.')
@pytest.mark.skip_xp_backends('jax.numpy',
                              reason='JAX arrays do not support item assignment.')
@pytest.mark.skip_xp_backends('cupy',
                              reason='cupy/cupy#8391')
class TestChandrupatla(TestScalarRootFinders):

    def f(self, q, p):
        return special.ndtr(q) - p

    @pytest.mark.parametrize('p', [0.6, np.linspace(-0.05, 1.05, 10)])
    def test_basic(self, p, xp):
        # Invert distribution CDF and compare against distribution `ppf`
        a, b = xp.asarray(-5.), xp.asarray(5.)
        res = _chandrupatla_root(self.f, a, b, args=(xp.asarray(p),))
        ref = xp.asarray(stats.norm().ppf(p), dtype=xp.asarray(p).dtype)
        xp_assert_close(res.x, ref)

    @pytest.mark.parametrize('shape', [tuple(), (12,), (3, 4), (3, 2, 2)])
    def test_vectorization(self, shape, xp):
        # Test for correct functionality, output shapes, and dtypes for various
        # input shapes.
        p = (np.linspace(-0.05, 1.05, 12).reshape(shape) if shape
             else np.float64(0.6))
        p_xp = xp.asarray(p)
        args_xp = (p_xp,)
        dtype = p_xp.dtype
        xp_test = array_namespace(p_xp)  # need xp.bool

        @np.vectorize
        def chandrupatla_single(p):
            return _chandrupatla_root(self.f, -5, 5, args=(p,))

        def f(*args, **kwargs):
            f.f_evals += 1
            return self.f(*args, **kwargs)
        f.f_evals = 0

        res = _chandrupatla_root(f, xp.asarray(-5.), xp.asarray(5.), args=args_xp)
        refs = chandrupatla_single(p).ravel()

        ref_x = [ref.x for ref in refs]
        ref_x = xp.reshape(xp.asarray(ref_x, dtype=dtype), shape)
        xp_assert_close(res.x, ref_x)

        ref_fun = [ref.fun for ref in refs]
        ref_fun = xp.reshape(xp.asarray(ref_fun, dtype=dtype), shape)
        xp_assert_close(res.fun, ref_fun, atol=1e-15)
        xp_assert_equal(res.fun, self.f(res.x, *args_xp))

        ref_success = [bool(ref.success) for ref in refs]
        ref_success = xp.reshape(xp.asarray(ref_success, dtype=xp_test.bool), shape)
        xp_assert_equal(res.success, ref_success)

        ref_flag = [ref.status for ref in refs]
        ref_flag = xp.reshape(xp.asarray(ref_flag, dtype=xp.int32), shape)
        xp_assert_equal(res.status, ref_flag)

        ref_nfev = [ref.nfev for ref in refs]
        ref_nfev = xp.reshape(xp.asarray(ref_nfev, dtype=xp.int32), shape)
        if is_numpy(xp):
            xp_assert_equal(res.nfev, ref_nfev)
            assert xp.max(res.nfev) == f.f_evals
        else:  # different backend may lead to different nfev
            assert res.nfev.shape == shape
            assert res.nfev.dtype == xp.int32

        ref_nit = [ref.nit for ref in refs]
        ref_nit = xp.reshape(xp.asarray(ref_nit, dtype=xp.int32), shape)
        if is_numpy(xp):
            xp_assert_equal(res.nit, ref_nit)
            assert xp.max(res.nit) == f.f_evals-2
        else:
            assert res.nit.shape == shape
            assert res.nit.dtype == xp.int32

        ref_xl = [ref.xl for ref in refs]
        ref_xl = xp.reshape(xp.asarray(ref_xl, dtype=dtype), shape)
        xp_assert_close(res.xl, ref_xl)

        ref_xr = [ref.xr for ref in refs]
        ref_xr = xp.reshape(xp.asarray(ref_xr, dtype=dtype), shape)
        xp_assert_close(res.xr, ref_xr)

        xp_assert_less(res.xl, res.xr)
        finite = xp.isfinite(res.x)
        assert xp.all((res.x[finite] == res.xl[finite])
                      | (res.x[finite] == res.xr[finite]))

        # PyTorch and CuPy don't solve to the same accuracy as NumPy - that's OK.
        atol = 1e-15 if is_numpy(xp) else 1e-9

        ref_fl = [ref.fl for ref in refs]
        ref_fl = xp.reshape(xp.asarray(ref_fl, dtype=dtype), shape)
        xp_assert_close(res.fl, ref_fl, atol=atol)
        xp_assert_equal(res.fl, self.f(res.xl, *args_xp))

        ref_fr = [ref.fr for ref in refs]
        ref_fr = xp.reshape(xp.asarray(ref_fr, dtype=dtype), shape)
        xp_assert_close(res.fr, ref_fr, atol=atol)
        xp_assert_equal(res.fr, self.f(res.xr, *args_xp))

        assert xp.all(xp.abs(res.fun[finite]) ==
                      xp.minimum(xp.abs(res.fl[finite]),
                                 xp.abs(res.fr[finite])))

    def test_flags(self, xp):
        # Test cases that should produce different status flags; show that all
        # can be produced simultaneously.
        def f(xs, js):
            # Note that full_like and int(j) shouldn't really be required. CuPy
            # is just really picky here, so I'm making it a special case to
            # make sure the other backends work when the user is less careful.
            assert js.dtype == xp.int64
            if is_cupy(xp):
                funcs = [lambda x: x - 2.5,
                         lambda x: x - 10,
                         lambda x: (x - 0.1)**3,
                         lambda x: xp.full_like(x, xp.asarray(xp.nan))]
                return [funcs[int(j)](x) for x, j in zip(xs, js)]

            funcs = [lambda x: x - 2.5,
                     lambda x: x - 10,
                     lambda x: (x - 0.1) ** 3,
                     lambda x: xp.nan]
            return [funcs[j](x) for x, j in zip(xs, js)]

        args = (xp.arange(4, dtype=xp.int64),)
        a, b = xp.asarray([0.]*4), xp.asarray([xp.pi]*4)
        res = _chandrupatla_root(f, a, b, args=args, maxiter=2)

        ref_flags = xp.asarray([eim._ECONVERGED,
                                eim._ESIGNERR,
                                eim._ECONVERR,
                                eim._EVALUEERR], dtype=xp.int32)
        xp_assert_equal(res.status, ref_flags)

    def test_convergence(self, xp):
        # Test that the convergence tolerances behave as expected
        rng = np.random.default_rng(2585255913088665241)
        p = xp.asarray(rng.random(size=3))
        bracket = (-xp.asarray(5.), xp.asarray(5.))
        args = (p,)
        kwargs0 = dict(args=args, xatol=0, xrtol=0, fatol=0, frtol=0)

        kwargs = kwargs0.copy()
        kwargs['xatol'] = 1e-3
        res1 = _chandrupatla_root(self.f, *bracket, **kwargs)
        xp_assert_less(res1.xr - res1.xl, xp.full_like(p, xp.asarray(1e-3)))
        kwargs['xatol'] = 1e-6
        res2 = _chandrupatla_root(self.f, *bracket, **kwargs)
        xp_assert_less(res2.xr - res2.xl, xp.full_like(p, xp.asarray(1e-6)))
        xp_assert_less(res2.xr - res2.xl, res1.xr - res1.xl)

        kwargs = kwargs0.copy()
        kwargs['xrtol'] = 1e-3
        res1 = _chandrupatla_root(self.f, *bracket, **kwargs)
        xp_assert_less(res1.xr - res1.xl, 1e-3 * xp.abs(res1.x))
        kwargs['xrtol'] = 1e-6
        res2 = _chandrupatla_root(self.f, *bracket, **kwargs)
        xp_assert_less(res2.xr - res2.xl, 1e-6 * xp.abs(res2.x))
        xp_assert_less(res2.xr - res2.xl, res1.xr - res1.xl)

        kwargs = kwargs0.copy()
        kwargs['fatol'] = 1e-3
        res1 = _chandrupatla_root(self.f, *bracket, **kwargs)
        xp_assert_less(xp.abs(res1.fun), xp.full_like(p, xp.asarray(1e-3)))
        kwargs['fatol'] = 1e-6
        res2 = _chandrupatla_root(self.f, *bracket, **kwargs)
        xp_assert_less(xp.abs(res2.fun), xp.full_like(p, xp.asarray(1e-6)))
        xp_assert_less(xp.abs(res2.fun), xp.abs(res1.fun))

        kwargs = kwargs0.copy()
        kwargs['frtol'] = 1e-3
        x1, x2 = bracket
        f0 = xp.minimum(xp.abs(self.f(x1, *args)), xp.abs(self.f(x2, *args)))
        res1 = _chandrupatla_root(self.f, *bracket, **kwargs)
        xp_assert_less(xp.abs(res1.fun), 1e-3*f0)
        kwargs['frtol'] = 1e-6
        res2 = _chandrupatla_root(self.f, *bracket, **kwargs)
        xp_assert_less(xp.abs(res2.fun), 1e-6*f0)
        xp_assert_less(xp.abs(res2.fun), xp.abs(res1.fun))

    def test_maxiter_callback(self, xp):
        # Test behavior of `maxiter` parameter and `callback` interface
        p = xp.asarray(0.612814)
        bracket = (xp.asarray(-5.), xp.asarray(5.))
        maxiter = 5

        def f(q, p):
            res = special.ndtr(q) - p
            f.x = q
            f.fun = res
            return res
        f.x = None
        f.fun = None

        res = _chandrupatla_root(f, *bracket, args=(p,), maxiter=maxiter)
        assert not xp.any(res.success)
        assert xp.all(res.nfev == maxiter+2)
        assert xp.all(res.nit == maxiter)

        def callback(res):
            callback.iter += 1
            callback.res = res
            assert hasattr(res, 'x')
            if callback.iter == 0:
                # callback is called once with initial bracket
                assert (res.xl, res.xr) == bracket
            else:
                changed = (((res.xl == callback.xl) & (res.xr != callback.xr))
                           | ((res.xl != callback.xl) & (res.xr == callback.xr)))
                assert xp.all(changed)

            callback.xl = res.xl
            callback.xr = res.xr
            assert res.status == eim._EINPROGRESS
            xp_assert_equal(self.f(res.xl, p), res.fl)
            xp_assert_equal(self.f(res.xr, p), res.fr)
            xp_assert_equal(self.f(res.x, p), res.fun)
            if callback.iter == maxiter:
                raise StopIteration
        callback.iter = -1  # callback called once before first iteration
        callback.res = None
        callback.xl = None
        callback.xr = None

        res2 = _chandrupatla_root(f, *bracket, args=(p,), callback=callback)

        # terminating with callback is identical to terminating due to maxiter
        # (except for `status`)
        for key in res.keys():
            if key == 'status':
                xp_assert_equal(res[key], xp.asarray(eim._ECONVERR, dtype=xp.int32))
                xp_assert_equal(res2[key], xp.asarray(eim._ECALLBACK, dtype=xp.int32))
            elif key.startswith('_'):
                continue
            else:
                xp_assert_equal(res2[key], res[key])

    @pytest.mark.parametrize('case', _CHANDRUPATLA_TESTS)
    def test_nit_expected(self, case, xp):
        # Test that `_chandrupatla` implements Chandrupatla's algorithm:
        # in all 40 test cases, the number of iterations performed
        # matches the number reported in the original paper.
        f, bracket, root, nfeval, id = case
        # Chandrupatla's criterion is equivalent to
        # abs(x2-x1) < 4*abs(xmin)*xrtol + xatol, but we use the more standard
        # abs(x2-x1) < abs(xmin)*xrtol + xatol. Therefore, set xrtol to 4x
        # that used by Chandrupatla in tests.
        bracket = (xp.asarray(bracket[0], dtype=xp.float64),
                   xp.asarray(bracket[1], dtype=xp.float64))
        root = xp.asarray(root, dtype=xp.float64)

        res = _chandrupatla_root(f, *bracket, xrtol=4e-10, xatol=1e-5)
        xp_assert_close(res.fun, xp.asarray(f(root), dtype=xp.float64),
                        rtol=1e-8, atol=2e-3)
        xp_assert_equal(res.nfev, xp.asarray(nfeval, dtype=xp.int32))

    @pytest.mark.parametrize("root", (0.622, [0.622, 0.623]))
    @pytest.mark.parametrize("dtype", ('float16', 'float32', 'float64'))
    def test_dtype(self, root, dtype, xp):
        # Test that dtypes are preserved
        not_numpy = not is_numpy(xp)
        if not_numpy and dtype == 'float16':
            pytest.skip("`float16` dtype only supported for NumPy arrays.")

        dtype = getattr(xp, dtype, None)
        if dtype is None:
            pytest.skip(f"{xp} does not support {dtype}")

        def f(x, root):
            res = (x - root) ** 3.
            if is_numpy(xp):  # NumPy does not preserve dtype
                return xp.asarray(res, dtype=dtype)
            return res

        a, b = xp.asarray(-3, dtype=dtype), xp.asarray(3, dtype=dtype)
        root = xp.asarray(root, dtype=dtype)
        res = _chandrupatla_root(f, a, b, args=(root,), xatol=1e-3)
        try:
            xp_assert_close(res.x, root, atol=1e-3)
        except AssertionError:
            assert res.x.dtype == dtype
            xp.all(res.fun == 0)

    def test_input_validation(self, xp):
        # Test input validation for appropriate error messages

        def func(x):
            return x

        message = '`func` must be callable.'
        with pytest.raises(ValueError, match=message):
            bracket = xp.asarray(-4), xp.asarray(4)
            _chandrupatla_root(None, *bracket)

        message = 'Abscissae and function output must be real numbers.'
        with pytest.raises(ValueError, match=message):
            bracket = xp.asarray(-4+1j), xp.asarray(4)
            _chandrupatla_root(func, *bracket)

        # raised by `np.broadcast, but the traceback is readable IMO
        message = "...not be broadcast..."  # all messages include this part
        with pytest.raises((ValueError, RuntimeError), match=message):
            bracket = xp.asarray([-2, -3]), xp.asarray([3, 4, 5])
            _chandrupatla_root(func, *bracket)

        message = "The shape of the array returned by `func`..."
        with pytest.raises(ValueError, match=message):
            bracket = xp.asarray([-3, -3]), xp.asarray([5, 5])
            _chandrupatla_root(lambda x: [x[0], x[1], x[1]], *bracket)

        message = 'Tolerances must be non-negative scalars.'
        bracket = xp.asarray(-4), xp.asarray(4)
        with pytest.raises(ValueError, match=message):
            _chandrupatla_root(func, *bracket, xatol=-1)
        with pytest.raises(ValueError, match=message):
            _chandrupatla_root(func, *bracket, xrtol=xp.nan)
        with pytest.raises(ValueError, match=message):
            _chandrupatla_root(func, *bracket, fatol='ekki')
        with pytest.raises(ValueError, match=message):
            _chandrupatla_root(func, *bracket, frtol=xp.nan)

        message = '`maxiter` must be a non-negative integer.'
        with pytest.raises(ValueError, match=message):
            _chandrupatla_root(func, *bracket, maxiter=1.5)
        with pytest.raises(ValueError, match=message):
            _chandrupatla_root(func, *bracket, maxiter=-1)

        message = '`callback` must be callable.'
        with pytest.raises(ValueError, match=message):
            _chandrupatla_root(func, *bracket, callback='shrubbery')

    def test_special_cases(self, xp):
        # Test edge cases and other special cases

        # Test infinite function values
        def f(x):
            return 1 / x + 1 - 1 / (-x + 1)

        a, b = xp.asarray([0.1, 0., 0., 0.1]),  xp.asarray([0.9, 1.0, 0.9, 1.0])

        with np.errstate(divide='ignore', invalid='ignore'):
            res = _chandrupatla_root(f, a, b)

        assert xp.all(res.success)
        xp_assert_close(res.x[1:], xp.full((3,), res.x[0]))

        # Test that integers are not passed to `f`
        # (otherwise this would overflow)
        xp_test = array_namespace(a)  # need isdtype
        def f(x):
            assert xp_test.isdtype(x.dtype, "real floating")
            # this would overflow if x were an xp integer dtype
            return x ** 31 - 1

        # note that all inputs are integer type; result is automatically default float
        res = _chandrupatla_root(f, xp.asarray(-7), xp.asarray(5))
        assert res.success
        xp_assert_close(res.x, xp.asarray(1.))

        # Test that if both ends of bracket equal root, algorithm reports
        # convergence.
        def f(x, root):
            return x**2 - root

        root = xp.asarray([0, 1])
        res = _chandrupatla_root(f, xp.asarray(1), xp.asarray(1), args=(root,))
        xp_assert_equal(res.success, xp.asarray([False, True]))
        xp_assert_equal(res.x, xp.asarray([xp.nan, 1.]))

        def f(x):
            return 1/x

        with np.errstate(invalid='ignore'):
            inf = xp.asarray(xp.inf)
            res = _chandrupatla_root(f, inf, inf)
        assert res.success
        xp_assert_equal(res.x, xp.asarray(xp.inf))

        # Test maxiter = 0. Should do nothing to bracket.
        def f(x):
            return x**3 - 1

        a, b = xp.asarray(-3.), xp.asarray(5.)
        res = _chandrupatla_root(f, a, b, maxiter=0)
        xp_assert_equal(res.success, xp.asarray(False))
        xp_assert_equal(res.status, xp.asarray(-2, dtype=xp.int32))
        xp_assert_equal(res.nit, xp.asarray(0, dtype=xp.int32))
        xp_assert_equal(res.nfev, xp.asarray(2, dtype=xp.int32))
        xp_assert_equal(res.xl, a)
        xp_assert_equal(res.xr, b)
        # The `x` attribute is the one with the smaller function value
        xp_assert_equal(res.x, a)
        # Reverse bracket; check that this is still true
        res = _chandrupatla_root(f, -b, -a, maxiter=0)
        xp_assert_equal(res.x, -a)

        # Test maxiter = 1
        res = _chandrupatla_root(f, a, b, maxiter=1)
        xp_assert_equal(res.success, xp.asarray(True))
        xp_assert_equal(res.status, xp.asarray(0, dtype=xp.int32))
        xp_assert_equal(res.nit, xp.asarray(1, dtype=xp.int32))
        xp_assert_equal(res.nfev, xp.asarray(3, dtype=xp.int32))
        xp_assert_close(res.x, xp.asarray(1.))

        # Test scalar `args` (not in tuple)
        def f(x, c):
            return c*x - 1

        res = _chandrupatla_root(f, xp.asarray(-1), xp.asarray(1), args=xp.asarray(3))
        xp_assert_close(res.x, xp.asarray(1/3))

        # # TODO: Test zero tolerance
        # # ~~What's going on here - why are iterations repeated?~~
        # # tl goes to zero when xatol=xrtol=0. When function is nearly linear,
        # # this causes convergence issues.
        # def f(x):
        #     return np.cos(x)
        #
        # res = _chandrupatla_root(f, 0, np.pi, xatol=0, xrtol=0)
        # assert res.nit < 100
        # xp = np.nextafter(res.x, np.inf)
        # xm = np.nextafter(res.x, -np.inf)
        # assert np.abs(res.fun) < np.abs(f(xp))
        # assert np.abs(res.fun) < np.abs(f(xm))
