import numpy as np
import sys
import traceback

from numba import jit, njit
from numba.core import types, errors, utils
from numba.tests.support import (TestCase, expected_failure_py311,
                                 expected_failure_py312,
                                 expected_failure_py313,
                                 )
import unittest


# Keyword-argument sets passed to ``jit`` by the tests below, one dict per
# compilation mode. ``TestRaising.check_against_python`` accepts only these.

# Object mode is forced.
force_pyobj_flags = {'nopython': False, 'forceobj': True}
# nopython mode with the ``_nrt`` flag switched off.
no_pyobj_flags = {'nopython': True, '_nrt': False}
# nopython mode with the ``_nrt`` flag switched on.
no_pyobj_flags_w_nrt = {'nopython': True, '_nrt': True}
# nopython mode with ``_nrt`` on and ``nogil`` additionally set.
no_gil_flags = {'nopython': True, 'nogil': True, '_nrt': True}


class MyError(Exception):
    """Plain user-defined exception with no behaviour of its own."""


class OtherError(Exception):
    """Second plain user-defined exception, distinct from ``MyError``."""


class UDEArgsToSuper(Exception):
    """User-defined exception that forwards ``arg`` to ``Exception.__init__``.

    ``arg`` therefore ends up in ``self.args``, while ``value0`` is stored
    only as an instance attribute.

    Parameters
    ----------
    arg
        Value passed on to the base ``Exception`` constructor.
    value0
        Extra value kept on the instance as ``self.value0``.
    """

    def __init__(self, arg, value0):
        super(UDEArgsToSuper, self).__init__(arg)
        self.value0 = value0

    def __eq__(self, other):
        """Return True when ``other`` is the same kind of exception and has
        equal ``args`` and ``value0``."""
        if not isinstance(other, self.__class__):
            return False
        # Every field has to match.  (Accumulating with ``|=`` starting from
        # True would make this comparison always succeed.)
        return self.args == other.args and self.value0 == other.value0

    def __hash__(self):
        # Hash exactly the fields that __eq__ compares so that equal
        # instances always hash equal.
        return hash((self.args, self.value0))


class UDENoArgSuper(Exception):
    """User-defined exception that passes nothing to ``Exception.__init__``.

    ``self.args`` is therefore empty; both constructor arguments are stored
    only as instance attributes.

    Parameters
    ----------
    arg
        Value kept on the instance as ``self.deferarg``.
    value0
        Value kept on the instance as ``self.value0``.
    """

    def __init__(self, arg, value0):
        super(UDENoArgSuper, self).__init__()
        self.deferarg = arg
        self.value0 = value0

    def __eq__(self, other):
        """Return True when ``other`` is the same kind of exception and has
        equal ``args``, ``deferarg`` and ``value0``."""
        if not isinstance(other, self.__class__):
            return False
        # Every field has to match.  (Accumulating with ``|=`` starting from
        # True would make this comparison always succeed.)
        return (self.args == other.args
                and self.deferarg == other.deferarg
                and self.value0 == other.value0)

    def __hash__(self):
        # Hash exactly the fields that __eq__ compares so that equal
        # instances always hash equal.
        return hash((self.args, self.deferarg, self.value0))


def raise_class(exc):
    """Build a use-case function that raises exception *classes*.

    The returned ``raiser(i)`` raises ``exc`` for ``i == 1``, ``ValueError``
    for ``i == 2``, ``np.linalg.LinAlgError`` for ``i == 3`` and otherwise
    returns ``i`` unchanged.
    """
    def raiser(i):
        if i == 1:
            raise exc
        if i == 2:
            raise ValueError
        if i == 3:
            # The exception type is looked up on a module (issue #1624)
            raise np.linalg.LinAlgError
        return i

    return raiser


def raise_instance(exc, arg):
    """Build a use-case function that raises exception *instances*.

    The returned ``raiser(i)`` raises ``exc(arg, 1)`` for ``i == 1``,
    ``ValueError(arg, 2)`` for ``i == 2``, ``np.linalg.LinAlgError(arg, 3)``
    for ``i == 3`` and otherwise returns ``i`` unchanged.  ``arg`` is fixed
    when the use-case is built.
    """
    def raiser(i):
        if i == 1:
            raise exc(arg, 1)
        if i == 2:
            raise ValueError(arg, 2)
        if i == 3:
            raise np.linalg.LinAlgError(arg, 3)
        return i

    return raiser


def raise_instance_runtime_args(exc):
    """Build a use-case function whose exception argument arrives at call time.

    The returned ``raiser(i, arg)`` raises ``exc(arg, 1)`` for ``i == 1``,
    ``ValueError(arg, 2)`` for ``i == 2``, ``np.linalg.LinAlgError(arg, 3)``
    for ``i == 3`` and otherwise returns ``i`` unchanged.
    """
    def raiser(i, arg):
        if i == 1:
            raise exc(arg, 1)
        if i == 2:
            raise ValueError(arg, 2)
        if i == 3:
            raise np.linalg.LinAlgError(arg, 3)
        return i

    return raiser


def reraise():
    """Use-case consisting solely of a bare ``raise`` statement."""
    raise


def outer_function(inner):
    """Wrap ``inner`` in a function that raises for one input of its own.

    The returned ``outer(i)`` raises ``OtherError("bar", 3)`` when ``i == 3``
    and otherwise returns whatever ``inner(i)`` returns (or lets whatever
    ``inner`` raises propagate).
    """
    def outer(i):
        if i == 3:
            raise OtherError("bar", 3)
        result = inner(i)
        return result

    return outer


def assert_usecase(i):
    """Use-case with an ``assert`` that fails with message "bar" unless
    ``i == 1``."""
    assert i == 1, "bar"


def ude_bug_usecase():
    """Use-case that deliberately constructs ``UDEArgsToSuper`` without its
    two required arguments."""
    raise UDEArgsToSuper()  # oops user forgot args to exception ctor


def raise_runtime_value(arg):
    """Use-case raising ``ValueError`` whose sole argument is ``arg``."""
    raise ValueError(arg)


class TestRaising(TestCase):
    """Tests that exceptions raised by jitted functions match the ones the
    same functions raise when run by the interpreter.

    The ``check_*`` methods are parametrised helpers taking one of the
    module-level flag dicts; the ``test_*`` methods invoke them for the
    individual compilation modes.
    """

    def test_unituple_index_error(self):
        # Indexing ``a.shape`` out of range must raise IndexError with the
        # same message as a tuple index error.
        def pyfunc(a, i):
            return a.shape[i]

        cfunc = njit((types.Array(types.int32, 1, 'A'), types.int32),)(pyfunc)

        a = np.empty(2, dtype=np.int32)

        self.assertEqual(cfunc(a, 0), pyfunc(a, 0))

        with self.assertRaises(IndexError) as cm:
            cfunc(a, 2)
        self.assertEqual(str(cm.exception), "tuple index out of range")

    def _raised_bottom_frames(self, func, args):
        """Call ``func(*args)`` and return the last two entries of the
        formatted traceback of the exception it raises.

        Those two entries are the location the error was raised from and the
        error itself.  The test fails if ``func`` does not raise.
        """
        try:
            func(*args)
        except Exception:
            return traceback.format_exception(*sys.exc_info())[-2:]
        self.fail("expected %r to raise an exception" % (func,))

    def check_against_python(self, exec_mode, pyfunc, cfunc,
                             expected_error_class, *args):
        """
        Check that ``cfunc(*args)`` raises the same error as
        ``pyfunc(*args)``.

        ``exec_mode`` must be one of the module-level flag dicts.  Both
        callables have to raise ``expected_error_class`` with equal ``args``;
        the user-defined exceptions are additionally compared with ``==``.
        When ``exec_mode`` is ``no_pyobj_flags`` the bottom of the two
        tracebacks is compared as well.
        """

        assert exec_mode in (force_pyobj_flags, no_pyobj_flags,
                             no_pyobj_flags_w_nrt, no_gil_flags)

        # invariant of mode, check the error class and args are the same
        with self.assertRaises(expected_error_class) as pyerr:
            pyfunc(*args)
        with self.assertRaises(expected_error_class) as jiterr:
            cfunc(*args)
        self.assertEqual(pyerr.exception.args, jiterr.exception.args)

        # special equality check for UDEs
        if isinstance(pyerr.exception, (UDEArgsToSuper, UDENoArgSuper)):
            self.assertTrue(pyerr.exception == jiterr.exception)

        # in npm check bottom of traceback matches as frame injection with
        # location info should ensure this
        if exec_mode is no_pyobj_flags:

            # we only care about the bottom two frames, the error and the
            # location it was raised.
            expected_frames = self._raised_bottom_frames(pyfunc, args)
            got_frames = self._raised_bottom_frames(cfunc, args)

            # check exception and the injected frame are the same
            for expf, gotf in zip(expected_frames, got_frames):
                # Note use of assertIn not assertEqual, Py 3.11 has markers (^)
                # that point to the variable causing the problem, Numba doesn't
                # do this so only the start of the string will match.
                self.assertIn(gotf, expf)

    def check_raise_class(self, flags):
        """
        Check raising of exception classes (no instantiation in user code).
        """
        pyfunc = raise_class(MyError)
        cfunc = jit((types.int32,), **flags)(pyfunc)
        self.assertEqual(cfunc(0), 0)
        self.check_against_python(flags, pyfunc, cfunc, MyError, 1)
        self.check_against_python(flags, pyfunc, cfunc, ValueError, 2)
        # Use the public name, matching what the use-case raises; the
        # ``np.linalg.linalg`` path is a deprecated alias in NumPy 2.
        self.check_against_python(flags, pyfunc, cfunc,
                                  np.linalg.LinAlgError, 3)

    def test_raise_class_nopython(self):
        self.check_raise_class(flags=no_pyobj_flags)

    def test_raise_class_objmode(self):
        self.check_raise_class(flags=force_pyobj_flags)

    def check_raise_instance(self, flags):
        """
        Check raising of exception instances built from constant arguments.
        """
        for clazz in [MyError, UDEArgsToSuper,
                      UDENoArgSuper]:
            pyfunc = raise_instance(clazz, "some message")
            cfunc = jit((types.int32,), **flags)(pyfunc)

            self.assertEqual(cfunc(0), 0)
            self.check_against_python(flags, pyfunc, cfunc, clazz, 1)
            self.check_against_python(flags, pyfunc, cfunc, ValueError, 2)
            self.check_against_python(flags, pyfunc, cfunc,
                                      np.linalg.LinAlgError, 3)

    def test_raise_instance_objmode(self):
        self.check_raise_instance(flags=force_pyobj_flags)

    def test_raise_instance_nopython(self):
        self.check_raise_instance(flags=no_pyobj_flags)

    def check_raise_nested(self, flags, **jit_args):
        """
        Check exception propagation from nested functions.
        """
        for clazz in [MyError, UDEArgsToSuper,
                      UDENoArgSuper]:
            inner_pyfunc = raise_instance(clazz, "some message")
            pyfunc = outer_function(inner_pyfunc)
            inner_cfunc = jit(**jit_args)(inner_pyfunc)
            cfunc = jit(**jit_args)(outer_function(inner_cfunc))

            self.check_against_python(flags, pyfunc, cfunc, clazz, 1)
            self.check_against_python(flags, pyfunc, cfunc, ValueError, 2)
            self.check_against_python(flags, pyfunc, cfunc, OtherError, 3)

    def test_raise_nested_objmode(self):
        self.check_raise_nested(force_pyobj_flags, forceobj=True)

    def test_raise_nested_nopython(self):
        self.check_raise_nested(no_pyobj_flags, nopython=True)

    def check_reraise(self, flags):
        """
        Check that a bare ``raise`` inside an ``except`` block re-raises the
        exception currently being handled.
        """
        def raise_exc(exc):
            raise exc

        def gen_impl(fn, op, err):
            # ``op`` raises, the handler then calls ``fn`` (the re-raiser).
            def impl():
                try:
                    op()
                except err:
                    fn()
            return impl

        pyfunc = reraise
        cfunc = jit((), **flags)(pyfunc)
        for op, err in [(lambda: raise_exc(ZeroDivisionError),
                         ZeroDivisionError),
                        (lambda: raise_exc(UDEArgsToSuper("msg", 1)),
                         UDEArgsToSuper),
                        (lambda: raise_exc(UDENoArgSuper("msg", 1)),
                         UDENoArgSuper)]:
            pybased = gen_impl(pyfunc, op, err)
            cbased = gen_impl(cfunc, op, err)
            self.check_against_python(flags, pybased, cbased, err,)

    def test_reraise_objmode(self):
        self.check_reraise(flags=force_pyobj_flags)

    def test_reraise_nopython(self):
        self.check_reraise(flags=no_pyobj_flags)

    def check_raise_invalid_class(self, cls, flags):
        """
        Check that raising ``cls``, which is not an exception, gives the
        interpreter's TypeError.
        """
        pyfunc = raise_class(cls)
        cfunc = jit((types.int32,), **flags)(pyfunc)
        with self.assertRaises(TypeError) as cm:
            cfunc(1)
        self.assertEqual(str(cm.exception),
                         "exceptions must derive from BaseException")

    def test_raise_invalid_class_objmode(self):
        self.check_raise_invalid_class(int, flags=force_pyobj_flags)
        self.check_raise_invalid_class(1, flags=force_pyobj_flags)

    def test_raise_invalid_class_nopython(self):
        # In nopython mode the invalid raise is expected to be rejected with
        # an UnsupportedError rather than reaching the TypeError check.
        msg = "Encountered unsupported constant type used for exception"
        with self.assertRaises(errors.UnsupportedError) as raises:
            self.check_raise_invalid_class(int, flags=no_pyobj_flags)
        self.assertIn(msg, str(raises.exception))
        with self.assertRaises(errors.UnsupportedError) as raises:
            self.check_raise_invalid_class(1, flags=no_pyobj_flags)
        self.assertIn(msg, str(raises.exception))

    def test_raise_bare_string_nopython(self):
        @njit
        def foo():
            raise "illegal"
        msg = ("Directly raising a string constant as an exception is not "
               "supported")
        with self.assertRaises(errors.UnsupportedError) as raises:
            foo()
        self.assertIn(msg, str(raises.exception))

    def check_assert_statement(self, flags):
        """
        Check that a failing ``assert`` raises the same AssertionError as in
        the interpreter, and that a passing one does not raise.
        """
        pyfunc = assert_usecase
        cfunc = jit((types.int32,), **flags)(pyfunc)
        cfunc(1)
        self.check_against_python(flags, pyfunc, cfunc, AssertionError, 2)

    def test_assert_statement_objmode(self):
        self.check_assert_statement(flags=force_pyobj_flags)

    def test_assert_statement_nopython(self):
        self.check_assert_statement(flags=no_pyobj_flags)

    def check_raise_from_exec_string(self, flags):
        """
        Check raising from functions that were defined through ``exec`` of a
        source string.
        """
        # issue #3428
        simple_raise = "def f(a):\n  raise exc('msg', 10)"
        assert_raise = "def f(a):\n  assert a != 1"
        py312_pep695_raise = "def f[T: int](a: T) -> T:\n  assert a != 1"
        py312_pep695_raise_2 = "def f[T: int\n](a: T) -> T:\n  assert a != 1"
        test_cases = [
            (assert_raise, AssertionError),
            (simple_raise, UDEArgsToSuper),
            (simple_raise, UDENoArgSuper),
        ]
        if utils.PYVERSION >= (3, 12):
            # Added for https://github.com/numba/numba/issues/9443
            test_cases.append((py312_pep695_raise, AssertionError))
            test_cases.append((py312_pep695_raise_2, AssertionError))
        for f_text, exc in test_cases:
            loc = {}
            # ``exc`` is made available to the generated source as a global.
            exec(f_text, {'exc': exc}, loc)
            pyfunc = loc['f']
            cfunc = jit((types.int32,), **flags)(pyfunc)
            self.check_against_python(flags, pyfunc, cfunc, exc, 1)

    def test_assert_from_exec_string_objmode(self):
        self.check_raise_from_exec_string(flags=force_pyobj_flags)

    def test_assert_from_exec_string_nopython(self):
        self.check_raise_from_exec_string(flags=no_pyobj_flags)

    def check_user_code_error_traceback(self, flags):
        # this test checks that if a user tries to compile code that contains
        # a bug in exception initialisation (e.g. missing arg) then this also
        # has a frame injected with the location information.
        pyfunc = ude_bug_usecase
        cfunc = jit((), **flags)(pyfunc)
        self.check_against_python(flags, pyfunc, cfunc, TypeError)

    def test_user_code_error_traceback_objmode(self):
        self.check_user_code_error_traceback(flags=force_pyobj_flags)

    def test_user_code_error_traceback_nopython(self):
        self.check_user_code_error_traceback(flags=no_pyobj_flags)

    def check_raise_runtime_value(self, flags):
        """
        Check raising an exception whose argument is a runtime string.
        """
        pyfunc = raise_runtime_value
        cfunc = jit((types.string,), **flags)(pyfunc)
        self.check_against_python(flags, pyfunc, cfunc, ValueError, 'hello')

    def test_raise_runtime_value_objmode(self):
        self.check_raise_runtime_value(flags=force_pyobj_flags)

    def test_raise_runtime_value_nopython(self):
        self.check_raise_runtime_value(flags=no_pyobj_flags_w_nrt)

    def test_raise_runtime_value_nogil(self):
        self.check_raise_runtime_value(flags=no_gil_flags)

    def check_raise_instance_with_runtime_args(self, flags):
        """
        Check raising exception instances whose first argument is a runtime
        string.
        """
        for clazz in [MyError, UDEArgsToSuper,
                      UDENoArgSuper]:
            pyfunc = raise_instance_runtime_args(clazz)
            cfunc = jit((types.int32, types.string), **flags)(pyfunc)

            self.assertEqual(cfunc(0, 'test'), 0)
            self.check_against_python(flags, pyfunc, cfunc, clazz, 1, 'hello')
            self.check_against_python(flags, pyfunc, cfunc, ValueError, 2,
                                      'world')
            self.check_against_python(flags, pyfunc, cfunc,
                                      np.linalg.LinAlgError, 3, 'linalg')

    def test_raise_instance_with_runtime_args_objmode(self):
        self.check_raise_instance_with_runtime_args(flags=force_pyobj_flags)

    def test_raise_instance_with_runtime_args_nopython(self):
        self.check_raise_instance_with_runtime_args(flags=no_pyobj_flags_w_nrt)

    def test_raise_instance_with_runtime_args_nogil(self):
        self.check_raise_instance_with_runtime_args(flags=no_gil_flags)

    def test_dynamic_raise_bad_args(self):
        # Exception arguments that cannot be converted to Python objects are
        # expected to be rejected with a TypingError.
        def raise_literal_dict():
            raise ValueError({'a': 1, 'b': np.ones(4)})

        def raise_range():
            raise ValueError(range(3))

        def raise_rng(rng):
            raise ValueError(rng.bit_generator)

        funcs = [
            (raise_literal_dict, ()),
            (raise_range, ()),
            (raise_rng, (types.npy_rng,)),
        ]

        msg = '.*Cannot convert native .* to a Python object.*'
        for pyfunc, argtypes in funcs:
            with self.assertRaisesRegex(errors.TypingError, msg):
                njit(argtypes)(pyfunc)

    def test_dynamic_raise_dict(self):
        @njit
        def raise_literal_dict2():
            raise ValueError({'a': 1, 'b': 3})

        msg = "{a: 1, b: 3}"
        with self.assertRaisesRegex(ValueError, msg):
            raise_literal_dict2()

    def test_disable_nrt(self):
        # Raising with a runtime argument while ``_nrt`` is off is expected
        # to be reported as a NumbaRuntimeError.
        @njit(_nrt=False)
        def raise_with_no_nrt(i):
            raise ValueError(i)

        msg = 'NRT required but not enabled'
        with self.assertRaisesRegex(errors.NumbaRuntimeError, msg):
            raise_with_no_nrt(123)

    def test_try_raise(self):

        @njit
        def raise_(a):
            raise ValueError(a)

        @njit
        def try_raise(a):
            # Swallowing the exception is the behaviour under test here.
            try:
                raise_(a)
            except Exception:
                pass
            return a + 1

        self.assertEqual(try_raise.py_func(3), try_raise(3))

    @expected_failure_py311
    @expected_failure_py312
    @expected_failure_py313
    def test_dynamic_raise(self):

        @njit
        def raise_(a):
            raise ValueError(a)

        @njit
        def try_raise_(a):
            try:
                raise_(a)
            except Exception:
                raise ValueError(a)

        args = [
            1,
            1.1,
            'hello',
            np.ones(3),
            [1, 2],
            (1, 2),
            {1, 2},
        ]
        for fn in (raise_, try_raise_):
            for arg in args:
                with self.assertRaises(ValueError) as e:
                    fn(arg)
                self.assertEqual((arg,), e.exception.args)


# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
