import sys
import warnings

import numpy as np
from numpy.testing import assert_, assert_equal, IS_PYPY
import pytest
from pytest import raises as assert_raises

import scipy.special as sc
from scipy.special._ufuncs import _sf_error_test_function

_sf_error_code_map = {
    # skip 'ok'
    'singular': 1,
    'underflow': 2,
    'overflow': 3,
    'slow': 4,
    'loss': 5,
    'no_result': 6,
    'domain': 7,
    'arg': 8,
    'other': 9,
    'memory': 10,
}

_sf_error_actions = [
    'ignore',
    'warn',
    'raise'
]


def _check_action(fun, args, action):
    # TODO: special expert should correct
    # the coercion at the true location?
    args = np.asarray(args, dtype=np.dtype("long"))
    if action == 'warn':
        with pytest.warns(sc.SpecialFunctionWarning):
            fun(*args)
    elif action == 'raise':
        with assert_raises(sc.SpecialFunctionError):
            fun(*args)
    else:
        # action == 'ignore', make sure there are no warnings/exceptions
        with warnings.catch_warnings():
            warnings.simplefilter("error")
            fun(*args)


def test_geterr():
    """Each item from ``sc.geterr()`` has a known category and action."""
    for category, action in sc.geterr().items():
        assert_(category in _sf_error_code_map)
        assert_(action in _sf_error_actions)


@pytest.mark.thread_unsafe
def test_seterr():
    """Exercise ``sc.seterr`` for every category/action combination.

    For each pair the test checks that `seterr` returns the settings that
    were in force before the call, that only the targeted category
    changed, and that `_sf_error_test_function` then reacts according to
    the new action. The settings found on entry are restored at the end.
    """
    state_on_entry = sc.geterr()
    cases = [
        (category, error_code, action)
        for category, error_code in _sf_error_code_map.items()
        for action in _sf_error_actions
    ]
    try:
        for category, error_code, action in cases:
            before = sc.geterr()
            returned = sc.seterr(**{category: action})
            # seterr hands back the configuration it replaced.
            assert_(before == returned)

            after = sc.geterr()
            assert_(after[category] == action)

            # Apart from the targeted category nothing may have changed.
            del before[category]
            del after[category]
            assert_(before == after)

            _check_action(_sf_error_test_function, (error_code,), action)
    finally:
        # Put back whatever was configured before the test started.
        sc.seterr(**state_on_entry)


@pytest.mark.skipif(IS_PYPY, reason="Test not meaningful on PyPy")
def test_sf_error_special_refcount():
    """Regression test for gh-16233.

    Raising a `SpecialFunctionError` must leave the reference count of the
    `scipy.special` module unchanged.
    """
    baseline = sys.getrefcount(sc)
    with sc.errstate(all='raise'), \
            pytest.raises(sc.SpecialFunctionError, match='domain error'):
        sc.ndtri(2.0)
    observed = sys.getrefcount(sc)
    assert baseline == observed


def test_errstate_pyx_basic():
    """``errstate(singular='raise')`` makes ``sc.loggamma(0)`` raise.

    Also checks that the error settings are back to their previous values
    once the context manager has exited.
    """
    # NOTE(review): going by the test name, loggamma presumably stands in
    # for the Cython-implemented functions -- not verifiable from here.
    saved = sc.geterr()
    with sc.errstate(singular='raise'), \
            pytest.raises(sc.SpecialFunctionError):
        sc.loggamma(0)
    assert_equal(saved, sc.geterr())


def test_errstate_c_basic():
    """``errstate(domain='raise')`` makes ``sc.spence(-1)`` raise.

    Also checks that the error settings are back to their previous values
    once the context manager has exited.
    """
    # NOTE(review): going by the test name, spence presumably stands in
    # for the C-implemented functions -- not verifiable from here.
    saved = sc.geterr()
    with sc.errstate(domain='raise'), \
            pytest.raises(sc.SpecialFunctionError):
        sc.spence(-1)
    assert_equal(saved, sc.geterr())


def test_errstate_cpp_basic():
    """``errstate(underflow='raise')`` makes ``sc.wrightomega(-1000)`` raise.

    Also checks that the error settings are back to their previous values
    once the context manager has exited.
    """
    # NOTE(review): going by the test name, wrightomega presumably stands
    # in for the C++-implemented functions -- not verifiable from here.
    saved = sc.geterr()
    with sc.errstate(underflow='raise'), \
            pytest.raises(sc.SpecialFunctionError):
        sc.wrightomega(-1000)
    assert_equal(saved, sc.geterr())


def test_errstate_cpp_scipy_special():
    """``errstate(singular='raise')`` makes ``sc.lambertw(0, 1)`` raise.

    Also checks that the error settings are back to their previous values
    once the context manager has exited.
    """
    # NOTE(review): which implementation layer lambertw represents is only
    # hinted at by the test name -- not verifiable from here.
    saved = sc.geterr()
    with sc.errstate(singular='raise'), \
            pytest.raises(sc.SpecialFunctionError):
        sc.lambertw(0, 1)
    assert_equal(saved, sc.geterr())


def test_errstate_cpp_alt_ufunc_machinery():
    """``errstate(singular='raise')`` makes ``sc.gammaln(0)`` raise.

    Also checks that the error settings are back to their previous values
    once the context manager has exited.
    """
    # NOTE(review): the "alternative ufunc machinery" gammaln is meant to
    # cover is only hinted at by the test name -- not verifiable from here.
    saved = sc.geterr()
    with sc.errstate(singular='raise'), \
            pytest.raises(sc.SpecialFunctionError):
        sc.gammaln(0)
    assert_equal(saved, sc.geterr())


@pytest.mark.thread_unsafe
def test_errstate():
    """Exercise ``sc.errstate`` for every category/action combination.

    Inside the context `_sf_error_test_function` must react according to
    the selected action; after leaving it the error settings must equal
    those captured just before entering.
    """
    cases = [
        (category, error_code, action)
        for category, error_code in _sf_error_code_map.items()
        for action in _sf_error_actions
    ]
    for category, error_code, action in cases:
        baseline = sc.geterr()
        with sc.errstate(**{category: action}):
            _check_action(_sf_error_test_function, (error_code,), action)
        assert_equal(baseline, sc.geterr())


def test_errstate_all_but_one():
    """``errstate(all='raise', singular='ignore')`` exempts one category.

    With that configuration ``sc.gammaln(0)`` must complete without
    raising while ``sc.spence(-1.0)`` must still raise. The error settings
    must be back to their previous values after the context exits.
    """
    saved = sc.geterr()
    with sc.errstate(all='raise', singular='ignore'):
        # Must not raise: the 'singular' category is set to 'ignore'.
        sc.gammaln(0)
        # Must raise: every other category is still set to 'raise'.
        with pytest.raises(sc.SpecialFunctionError):
            sc.spence(-1.0)
    assert_equal(saved, sc.geterr())
