import numpy as np

from io import StringIO
from numba import cuda, float32, float64, int32, intp
from numba.cuda.testing import unittest, CUDATestCase
from numba.cuda.testing import (skip_on_cudasim, skip_with_nvdisasm,
                                skip_without_nvdisasm)


@skip_on_cudasim('Simulator does not generate code to be inspected')
class TestInspect(CUDATestCase):
    @property
    def cc(self):
        return cuda.current_context().device.compute_capability

    def test_monotyped(self):
        sig = (float32, int32)

        @cuda.jit(sig)
        def foo(x, y):
            pass

        file = StringIO()
        foo.inspect_types(file=file)
        typeanno = file.getvalue()
        # Function name in annotation
        self.assertIn("foo", typeanno)
        # Signature in annotation
        self.assertIn("(float32, int32)", typeanno)
        file.close()
        # Function name in LLVM
        llvm = foo.inspect_llvm(sig)
        self.assertIn("foo", llvm)

        # Kernel in LLVM
        self.assertIn('cuda.kernel.wrapper', llvm)

        # Wrapped device function body in LLVM
        self.assertIn("define linkonce_odr i32", llvm)

        asm = foo.inspect_asm(sig)

        # Function name in PTX
        self.assertIn("foo", asm)
        # NVVM inserted comments in PTX
        self.assertIn("Generated by NVIDIA NVVM Compiler", asm)

    def test_polytyped(self):
        @cuda.jit
        def foo(x, y):
            pass

        foo[1, 1](1, 1)
        foo[1, 1](1.2, 2.4)

        file = StringIO()
        foo.inspect_types(file=file)
        typeanno = file.getvalue()
        file.close()
        # Signature in annotation
        self.assertIn("({0}, {0})".format(intp), typeanno)
        self.assertIn("(float64, float64)", typeanno)

        # Signature in LLVM dict
        llvmirs = foo.inspect_llvm()
        self.assertEqual(2, len(llvmirs), )
        self.assertIn((intp, intp), llvmirs)
        self.assertIn((float64, float64), llvmirs)

        # Function name in LLVM
        self.assertIn("foo", llvmirs[intp, intp])
        self.assertIn("foo", llvmirs[float64, float64])

        # Kernels in LLVM
        self.assertIn('cuda.kernel.wrapper', llvmirs[intp, intp])
        self.assertIn('cuda.kernel.wrapper', llvmirs[float64, float64])

        # Wrapped device function bodies in LLVM
        self.assertIn("define linkonce_odr i32", llvmirs[intp, intp])
        self.assertIn("define linkonce_odr i32", llvmirs[float64, float64])

        asmdict = foo.inspect_asm()

        # Signature in assembly dict
        self.assertEqual(2, len(asmdict), )
        self.assertIn((intp, intp), asmdict)
        self.assertIn((float64, float64), asmdict)

        # NVVM inserted in PTX
        self.assertIn("foo", asmdict[intp, intp])
        self.assertIn("foo", asmdict[float64, float64])

    def _test_inspect_sass(self, kernel, name, sass):
        # Ensure function appears in output
        seen_function = False
        for line in sass.split():
            if '.text' in line and name in line:
                seen_function = True
        self.assertTrue(seen_function)

        self.assertRegex(sass, r'//## File ".*/test_inspect.py", line [0-9]')

        # Some instructions common to all supported architectures that should
        # appear in the output
        self.assertIn('S2R', sass)   # Special register to register
        self.assertIn('BRA', sass)   # Branch
        self.assertIn('EXIT', sass)  # Exit program

    @skip_without_nvdisasm('nvdisasm needed for inspect_sass()')
    def test_inspect_sass_eager(self):
        sig = (float32[::1], int32[::1])

        @cuda.jit(sig, lineinfo=True)
        def add(x, y):
            i = cuda.grid(1)
            if i < len(x):
                x[i] += y[i]

        self._test_inspect_sass(add, 'add', add.inspect_sass(sig))

    @skip_without_nvdisasm('nvdisasm needed for inspect_sass()')
    def test_inspect_sass_lazy(self):
        @cuda.jit(lineinfo=True)
        def add(x, y):
            i = cuda.grid(1)
            if i < len(x):
                x[i] += y[i]

        x = np.arange(10).astype(np.int32)
        y = np.arange(10).astype(np.float32)
        add[1, 10](x, y)

        signature = (int32[::1], float32[::1])
        self._test_inspect_sass(add, 'add', add.inspect_sass(signature))

    @skip_with_nvdisasm('Missing nvdisasm exception only generated when it is '
                        'not present')
    def test_inspect_sass_nvdisasm_missing(self):
        @cuda.jit((float32[::1],))
        def f(x):
            x[0] = 0

        with self.assertRaises(RuntimeError) as raises:
            f.inspect_sass()

        self.assertIn('nvdisasm has not been found', str(raises.exception))

    @skip_without_nvdisasm('nvdisasm needed for inspect_sass_cfg()')
    def test_inspect_sass_cfg(self):
        sig = (float32[::1], int32[::1])

        @cuda.jit(sig)
        def add(x, y):
            i = cuda.grid(1)
            if i < len(x):
                x[i] += y[i]

        self.assertRegex(
            add.inspect_sass_cfg(signature=sig),
            r'digraph\s*\w\s*{(.|\n)*\n}'
        )


if __name__ == '__main__':
    unittest.main()
