#!/usr/bin/env python

import numpy as np
import pytest
from numpy.testing import assert_allclose

import pywt
from pywt import data
from pywt._utils import AxisError

# tolerances used in accuracy comparisons
tol_single = 1e-6
tol_double = 1e-13
atol = 1e-7


####
# 1d mra tests
####

@pytest.mark.parametrize('wavelet', ['db2', 'sym4', 'coif5'])
@pytest.mark.parametrize('transform', ['dwt', 'swt'])
@pytest.mark.parametrize('mode', pywt.Modes.modes)
@pytest.mark.parametrize(
    'dtype', ['float32', 'float64', 'complex64', 'complex128']
)
def test_mra_roundtrip(wavelet, transform, mode, dtype):
    x = data.ecg()[:64].astype(dtype)
    if x.dtype.kind == 'c':
        # fill some data for the imaginary channel
        x.imag = x[::-1].real

    if transform == 'swt':
        # swt mode only supports periodization
        if mode != 'periodization':
            with pytest.raises(ValueError):
                pywt.mra(x, wavelet, transform=transform, mode=mode)
            return

    coeffs = pywt.mra(x, wavelet, transform=transform, mode=mode)
    assert isinstance(coeffs, list)
    assert isinstance(coeffs[0], np.ndarray)
    # assert all(isinstance(coeffs[i], dict) for i in range(1, len(coeffs)))

    y = pywt.imra(coeffs)
    rtol = tol_single if x.real.dtype.kind == 'f' else tol_double
    assert_allclose(x, y, rtol=rtol, atol=rtol)


@pytest.mark.parametrize('wavelet', ['rbio1.3', 'bior2.4'])
@pytest.mark.parametrize('transform', ['dwt', 'swt'])
def test_mra_warns_on_non_orthogonal(wavelet, transform):
    dtype = np.float64
    x = data.ecg()[:64].astype(dtype)

    assert not pywt.Wavelet(wavelet).orthogonal

    if transform == 'swt':
        # bi-orthogonal wavelets raise a warning for SWT case
        msg = 'norm=True, but the wavelet is not orthogonal'
        with pytest.warns(UserWarning, match=msg):
            coeffs = pywt.mra(x, wavelet, transform=transform)
    else:
        coeffs = pywt.mra(x, wavelet, transform=transform)

    y = pywt.imra(coeffs)
    rtol = tol_single if x.real.dtype.kind == 'f' else tol_double
    assert_allclose(x, y, rtol=rtol, atol=rtol)


@pytest.mark.parametrize('axis', [0, -1, 1, 2, -3])
@pytest.mark.parametrize('ndim', [1, 2, 3])
@pytest.mark.parametrize('transform', ['dwt', 'swt'])
@pytest.mark.parametrize('dtype', [np.float64, np.complex128])
def test_mra_axis(transform, ndim, axis, dtype):
    # Test transforms over a specific axis of 1D, 2D or 3D data
    if ndim == 1:
        x = data.ecg()[:64]
    elif ndim == 2:
        x = data.camera()[:64, :32]
    elif ndim == 3:
        x = data.camera()[:48, :8]
        x = np.stack((x,) * 8, axis=-1)
    x = x.astype(dtype, copy=False)

    # out of range axis
    if axis < -x.ndim or axis >= x.ndim:
        with pytest.raises(AxisError):
            pywt.mra(x, 'db1', transform=transform, axis=axis)
        return

    coeffs = pywt.mra(x, 'db1', transform=transform, axis=axis)
    y = pywt.imra(coeffs)
    rtol = tol_single if x.real.dtype.kind == 'f' else tol_double
    assert_allclose(x, y, rtol=rtol, atol=rtol)


####
# 2d mra tests
####

@pytest.mark.parametrize('wavelet', ['db2', 'sym4', 'coif5'])
@pytest.mark.parametrize('transform', ['dwt2', 'swt2'])
@pytest.mark.parametrize('mode', pywt.Modes.modes)
@pytest.mark.parametrize(
    'dtype', ['float32', 'float64', 'complex64', 'complex128']
)
def test_mra2_roundtrip(wavelet, transform, mode, dtype):
    x = data.camera()[:32, :16].astype(dtype, copy=False)
    if x.dtype.kind == 'c':
        # fill some data for the imaginary channel
        x.imag = x[::-1, :].real

    if transform == 'swt2':
        # swt mode only supports periodization
        if mode != 'periodization':
            with pytest.raises(ValueError):
                pywt.mra2(x, wavelet, transform=transform, mode=mode)
            return

    coeffs = pywt.mra2(x, wavelet, transform=transform, mode=mode)
    assert isinstance(coeffs, list)
    assert isinstance(coeffs[0], np.ndarray)
    # assert all(isinstance(coeffs[i], dict) for i in range(1, len(coeffs)))

    y = pywt.imra2(coeffs)
    rtol = tol_single if x.real.dtype.kind == 'f' else tol_double
    assert_allclose(x, y, rtol=rtol, atol=rtol)


@pytest.mark.parametrize('wavelet', ['rbio1.3', 'bior2.4'])
@pytest.mark.parametrize('transform', ['dwt2', 'swt2'])
def test_mra2_warns_on_non_orthogonal(wavelet, transform):
    dtype = np.float64
    x = data.camera()[:32, :8].astype(dtype, copy=False)

    assert not pywt.Wavelet(wavelet).orthogonal

    if transform == 'swt2':
        # bi-orthogonal wavelets raise a warning for SWT case
        msg = 'norm=True, but the wavelets used are not orthogonal'
        with pytest.warns(UserWarning, match=msg):
            coeffs = pywt.mra2(x, wavelet, transform=transform)
    else:
        coeffs = pywt.mra2(x, wavelet, transform=transform)

    y = pywt.imra2(coeffs)
    rtol = tol_single if x.real.dtype.kind == 'f' else tol_double
    assert_allclose(x, y, rtol=rtol, atol=rtol)


@pytest.mark.parametrize('transform', ['dwt2', 'swt2'])
@pytest.mark.parametrize('ndim', [2, 3])
@pytest.mark.parametrize('axes', [(0, 1), (-2, -1), (0, 2), (-3, 1), (0, 4)])
@pytest.mark.parametrize('dtype', [np.float64, np.complex128])
def test_mra2_axes(transform, axes, ndim, dtype):
    # Test transforms over various axes of 2D or 3D data.
    x = data.camera()[:32, :16].astype(dtype, copy=False)
    if ndim == 3:
        x = np.stack((x,) * 8, axis=-1)

    # out of range axis
    if any(axis < -x.ndim or axis >= x.ndim for axis in axes):
        with pytest.raises(AxisError):
            pywt.mra2(x, 'db1', transform=transform, axes=axes)
        return

    coeffs = pywt.mra2(x, 'db1', transform=transform, axes=axes)
    y = pywt.imra2(coeffs)
    rtol = tol_single if x.real.dtype.kind == 'f' else tol_double
    assert_allclose(x, y, rtol=rtol, atol=rtol)


####
# nd mra tests
####

@pytest.mark.slow
@pytest.mark.parametrize('wavelet', ['sym2', ])
@pytest.mark.parametrize('transform', ['dwtn', 'swtn'])
@pytest.mark.parametrize('mode', pywt.Modes.modes)
@pytest.mark.parametrize(
    'dtype', ['float32', 'float64', 'complex64', 'complex128']
)
@pytest.mark.parametrize('ndim', [1, 2, 3])
def test_mran_roundtrip(wavelet, transform, mode, dtype, ndim):
    if ndim == 1:
        x = data.ecg()[:48].astype(dtype, copy=False)
    elif ndim == 2:
        x = data.camera()[:16, :8].astype(dtype, copy=False)
    elif ndim == 3:
        x = data.camera()[:16, :8].astype(dtype, copy=False)
        x = np.stack((x,) * 8, axis=-1)

    if x.dtype.kind == 'c':
        # fill some data for the imaginary channel
        x.imag = x[::-1, ...].real

    if transform == 'swtn':
        # swt mode only supports periodization
        if mode != 'periodization':
            with pytest.raises(ValueError):
                pywt.mran(x, wavelet, transform=transform, mode=mode)
            return

    coeffs = pywt.mran(x, wavelet, transform=transform, mode=mode)
    assert isinstance(coeffs, list)
    assert isinstance(coeffs[0], np.ndarray)
    # assert all(isinstance(coeffs[i], dict) for i in range(1, len(coeffs)))

    y = pywt.imran(coeffs)
    rtol = tol_single if x.real.dtype.kind == 'f' else tol_double
    assert_allclose(x, y, rtol=rtol, atol=rtol)


@pytest.mark.parametrize('wavelet', ['rbio1.3', 'bior2.4'])
@pytest.mark.parametrize('transform', ['dwtn', 'swtn'])
def test_mran_warns_on_non_orthogonal(wavelet, transform):
    dtype = np.float64
    x = data.camera()[:32, :8].astype(dtype, copy=False)

    assert not pywt.Wavelet(wavelet).orthogonal

    if transform == 'swtn':
        # bi-orthogonal wavelets raise a warning for SWT case
        msg = 'norm=True, but the wavelets used are not orthogonal'
        with pytest.warns(UserWarning, match=msg):
            coeffs = pywt.mran(x, wavelet, transform=transform)
    else:
        coeffs = pywt.mran(x, wavelet, transform=transform)

    y = pywt.imran(coeffs)
    rtol = tol_single if x.real.dtype.kind == 'f' else tol_double
    assert_allclose(x, y, rtol=rtol, atol=rtol)


@pytest.mark.parametrize(
    'axes', [(0, 1), (-2, -1), (0, 2), (-3, 1), (0, 4), (-3, -2, -1),
             (0, 2, 1), (0, 5, 1), (0,), (1,), (2,), (-2,),  (-3,), (-4,)])
@pytest.mark.parametrize('transform', ['dwtn', 'swtn'])
def test_mran_axes(axes, transform):
    # Test with transforms over 1, 2 or 3 axes of 3d data.
    # Cases with out of range axes are also tested
    dtype = np.float64
    x = data.camera()[:32, :16].astype(dtype, copy=False)
    x3d = np.stack((x,) * 8, axis=-1)

    # out of range axis
    if any(axis < -x.ndim or axis >= x.ndim for axis in axes):
        with pytest.raises(AxisError):
            pywt.mran(x, 'db1', transform='dwtn', axes=axes)
        return

    coeffs = pywt.mran(x3d, 'db1', transform='dwtn', axes=axes)
    y = pywt.imran(coeffs)
    rtol = tol_single if x3d.real.dtype.kind == 'f' else tol_double
    assert_allclose(x3d, y, rtol=rtol, atol=rtol)
