from functools import partial, reduce

import numpy as np

from ._multilevel import (
    _prep_axes_wavedecn,
    wavedec,
    wavedec2,
    wavedecn,
    waverec,
    waverec2,
    waverecn,
)
from ._swt import iswt, iswt2, iswtn, swt, swt2, swt_max_level, swtn
from ._utils import _modes_per_axis, _wavelets_per_axis

# Public API: forward (mra*) and inverse (imra*) multiresolution analyses in 1D, 2D and nD.
__all__ = ["mra", "mra2", "mran", "imra", "imra2", "imran"]


def mra(data, wavelet, level=None, axis=-1, transform='swt',
        mode='periodization'):
    """Forward 1D multiresolution analysis.

    It is a projection onto the wavelet subspaces.

    Parameters
    ----------
    data : array_like
        Input data. It is converted with ``np.asarray`` first.
    wavelet : Wavelet object or name string
        Wavelet to use
    level : int, optional
        Decomposition level (must be >= 0). If level is None (default), the
        maximum level is chosen by the underlying transform (see
        `dwt_max_level` / `swt_max_level`).
    axis : int, optional
        Axis over which to compute the transform. If not given, the last axis
        is used. It is forwarded to both the SWT and the DWT.
    transform : {'dwt', 'swt'}
        Whether to use the DWT or SWT for the transforms.
    mode : str, optional
        Signal extension mode, see `Modes` (default: 'periodization').
        ``transform='swt'`` only supports 'periodization'; any mode may be
        used with ``transform='dwt'``.

    Returns
    -------
    mra_coeffs : list of ndarray
        ``[A_n, D_n, ..., D_1]``: the projection onto the approximation
        subspace followed by the detail projections, in the same order as the
        wavelet coefficients (coarsest first; see `wavedec`). Each entry is
        the inverse transform of a single coefficient array with all others
        zeroed, trimmed to ``data.shape`` if the reconstruction is larger.

    Raises
    ------
    ValueError
        If `transform` is not recognized, or if ``transform='swt'`` is used
        with a mode other than 'periodization'.

    See Also
    --------
    imra, swt

    Notes
    -----
    This is sometimes referred to as an additive decomposition because the
    inverse transform (``imra``) is just the sum of the coefficient arrays
    [1]_. The decomposition using ``transform='dwt'`` corresponds to section
    2.2 while that using an undecimated transform (``transform='swt'``) is
    described in section 3.2 and appendix A.

    This transform does not share the variance partition property of ``swt``
    with `norm=True`. It does however, result in coefficients that are
    temporally aligned regardless of the symmetry of the wavelet used.

    The redundancy of this transform is ``(level + 1)``.

    References
    ----------
    .. [1] Donald B. Percival and Harold O. Mofjeld. Analysis of Subtidal
        Coastal Sea Level Fluctuations Using Wavelets. Journal of the American
        Statistical Association Vol. 92, No. 439 (Sep., 1997), pp. 868-880.
        https://doi.org/10.2307/2965551

    """
    if transform == 'swt':
        if mode != 'periodization':
            raise ValueError(
                "transform swt only supports mode='periodization'")
        kwargs = {"wavelet": wavelet, "axis": axis, "norm": True}
        forward = partial(swt, level=level, trim_approx=True, **kwargs)
        inverse = partial(iswt, **kwargs)
        is_swt = True
    elif transform == 'dwt':
        kwargs = {"wavelet": wavelet, "mode": mode, "axis": axis}
        forward = partial(wavedec, level=level, **kwargs)
        inverse = partial(waverec, **kwargs)
        is_swt = False
    else:
        raise ValueError(f"unrecognized transform: {transform}")

    # Accept any array_like (e.g. plain lists): ``data.shape`` is used below.
    data = np.asarray(data)
    wav_coeffs = forward(data)
    nc = len(wav_coeffs)

    if is_swt:
        # All SWT coefficient arrays share one shape: a single zeros array,
        # referenced nc times, saves memory.
        zeros = [np.zeros_like(wav_coeffs[0])] * nc
    else:
        # DWT coefficient arrays shrink with level, so each needs its own.
        zeros = [np.zeros_like(c) for c in wav_coeffs]

    # Working coefficient list: all zeros except the entry being projected.
    tmp = list(zeros)
    trim = tuple(slice(sz) for sz in data.shape)

    mra_coeffs = []
    for j in range(nc):
        tmp[j] = wav_coeffs[j]
        rec = inverse(tmp)
        if rec.shape != data.shape:
            # the inverse may pad (e.g. odd lengths): trim excess samples
            rec = rec[trim]
        mra_coeffs.append(rec)
        # restore the zeros for the next projection
        tmp[j] = zeros[j]
    return mra_coeffs


def imra(mra_coeffs):
    """Inverse 1D multiresolution analysis: sum all MRA components.

    Parameters
    ----------
    mra_coeffs : list of ndarray
        Multiresolution analysis coefficients as returned by `mra`.

    Returns
    -------
    rec : ndarray
        The reconstructed signal, i.e. the element-wise sum of every entry
        of `mra_coeffs`.

    See Also
    --------
    mra

    References
    ----------
    .. [1] Donald B. Percival and Harold O. Mofjeld. Analysis of Subtidal
        Coastal Sea Level Fluctuations Using Wavelets. Journal of the American
        Statistical Association Vol. 92, No. 439 (Sep., 1997), pp. 868-880.
        https://doi.org/10.2307/2965551
    """
    def _accumulate(running_total, component):
        # out-of-place ``+``: the caller's arrays are left untouched
        return running_total + component

    return reduce(_accumulate, mra_coeffs)


def mra2(data, wavelet, level=None, axes=(-2, -1), transform='swt2',
         mode='periodization'):
    """Forward 2D multiresolution analysis.

    It is a projection onto wavelet subspaces.

    Parameters
    ----------
    data : array_like
        Input data. It is converted with ``np.asarray`` first.
    wavelet : Wavelet object or name string, or 2-tuple of wavelets
        Wavelet to use.  This can also be a tuple containing a wavelet to
        apply along each axis in `axes`.
    level : int, optional
        Decomposition level (must be >= 0). If level is None (default), the
        maximum level is used: for the SWT it is the smallest
        `swt_max_level` among the transformed axes; for the DWT it is chosen
        by `wavedec2`.
    axes : 2-tuple of ints, optional
        Axes over which to compute the transform. Repeated elements are not
        allowed. They are forwarded to both the SWT and the DWT.
    transform : {'dwt2', 'swt2'}
        Whether to use the DWT or SWT for the transforms.
    mode : str or 2-tuple of str, optional
        Signal extension mode, see `Modes` (default: 'periodization').
        ``transform='swt2'`` only supports 'periodization'; any mode may be
        used with ``transform='dwt2'``.

    Returns
    -------
    mra_coeffs : list
        ``[A_n, (d_n_1, d_n_2, d_n_3), ..., (d_1_1, d_1_2, d_1_3)]``: the
        approximation projection followed, per level (coarsest first), by a
        tuple of three detail projections ordered like the detail tuples of
        `wavedec2` / `swt2`. Every array is the inverse transform of a single
        subband with all others zeroed, trimmed to ``data.shape`` if the
        reconstruction is larger.

    Raises
    ------
    ValueError
        If `transform` is not recognized, or if ``transform='swt2'`` is used
        with a mode other than 'periodization'.

    Notes
    -----
    This is sometimes referred to as an additive decomposition because the
    inverse transform (``imra2``) is just the sum of the coefficient arrays
    [1]_. The decomposition using ``transform='dwt2'`` corresponds to section
    2.2 while that using an undecimated transform (``transform='swt2'``) is
    described in section 3.2 and appendix A.

    This transform does not share the variance partition property of ``swt2``
    with `norm=True`. It does however, result in coefficients that are
    spatially aligned regardless of the symmetry of the wavelet used.

    The redundancy of this transform is ``3 * level + 1``.

    See Also
    --------
    imra2, swt2

    References
    ----------
    .. [1] Donald B. Percival and Harold O. Mofjeld. Analysis of Subtidal
        Coastal Sea Level Fluctuations Using Wavelets. Journal of the American
        Statistical Association Vol. 92, No. 439 (Sep., 1997), pp. 868-880.
        https://doi.org/10.2307/2965551
    """
    if transform == 'swt2':
        if mode != 'periodization':
            raise ValueError(
                "transform swt only supports mode='periodization'")
        data = np.asarray(data)
        if level is None:
            # Only the transformed axes limit the achievable SWT depth; other
            # (batch) axes must not be taken into account.
            level = min(swt_max_level(data.shape[ax]) for ax in axes)
        kwargs = {"wavelet": wavelet, "axes": axes, "norm": True}
        forward = partial(swt2, level=level, trim_approx=True, **kwargs)
        inverse = partial(iswt2, **kwargs)
    elif transform == 'dwt2':
        kwargs = {"wavelet": wavelet, "mode": mode, "axes": axes}
        forward = partial(wavedec2, level=level, **kwargs)
        inverse = partial(waverec2, **kwargs)
    else:
        raise ValueError(f"unrecognized transform: {transform}")

    # Accept any array_like (e.g. nested lists): ``data.shape`` is used below.
    data = np.asarray(data)
    wav_coeffs = forward(data)
    trim = tuple(slice(sz) for sz in data.shape)

    def _project(coeffs):
        """Inverse-transform ``coeffs`` and trim any padding to data.shape."""
        rec = inverse(coeffs)
        if rec.shape != data.shape:
            rec = rec[trim]
        return rec

    # Zero-filled copy of the coefficient layout; exactly one subband at a
    # time is swapped in for reconstruction.
    approx_zeros = np.zeros_like(wav_coeffs[0])
    tmp = [approx_zeros] + [[np.zeros_like(c) for c in detail]
                            for detail in wav_coeffs[1:]]

    # Approximation component.
    tmp[0] = wav_coeffs[0]
    mra_coeffs = [_project(tmp)]
    tmp[0] = approx_zeros

    # Detail components: one projection per (level, orientation) subband.
    for j, detail in enumerate(wav_coeffs[1:], start=1):
        dcoeffs = []
        for n in range(3):
            zeros = tmp[j][n]
            tmp[j][n] = detail[n]
            dcoeffs.append(_project(tmp))
            tmp[j][n] = zeros
        mra_coeffs.append(tuple(dcoeffs))
    return mra_coeffs


def imra2(mra_coeffs):
    """Inverse 2D multiresolution analysis via summation.

    Parameters
    ----------
    mra_coeffs : list
        Multiresolution analysis coefficients as returned by `mra2`: an
        approximation array followed by one tuple of three detail arrays per
        level.

    Returns
    -------
    rec : ndarray
        The reconstructed signal (sum of all components). The input arrays
        are not modified.

    See Also
    --------
    mra2

    References
    ----------
    .. [1] Donald B. Percival and Harold O. Mofjeld. Analysis of Subtidal
        Coastal Sea Level Fluctuations Using Wavelets. Journal of the American
        Statistical Association Vol. 92, No. 439 (Sep., 1997), pp. 868-880.
        https://doi.org/10.2307/2965551
    """
    # Out-of-place addition: an in-place ``+=`` on ``mra_coeffs[0]`` would
    # overwrite the caller's approximation array and would raise a casting
    # error when the approximation dtype is narrower (e.g. int) than the
    # details (e.g. float).
    rec = mra_coeffs[0]
    for details in mra_coeffs[1:]:
        for band in details:
            rec = rec + band
    return rec


def mran(data, wavelet, level=None, axes=None, transform='swtn',
         mode='periodization'):
    """Forward nD multiresolution analysis.

    It is a projection onto the wavelet subspaces.

    Parameters
    ----------
    data : array_like
        Input data. It is converted with ``np.asarray`` first.
    wavelet : Wavelet object or name string, or tuple of wavelets
        Wavelet to use. This can also be a tuple containing a wavelet to
        apply along each axis in `axes`.
    level : int, optional
        Decomposition level (must be >= 0). If level is None (default), the
        maximum level is used: for the SWT it is the smallest
        `swt_max_level` among the transformed axes; for the DWT it is chosen
        by `wavedecn`.
    axes : tuple of ints, optional
        Axes over which to compute the transform. Repeated elements are not
        allowed.
    transform : {'dwtn', 'swtn'}
        Whether to use the DWT or SWT for the transforms.
    mode : str or tuple of str, optional
        Signal extension mode, see `Modes` (default: 'periodization').
        ``transform='swtn'`` only supports 'periodization'; any mode may be
        used with ``transform='dwtn'``.

    Returns
    -------
    mra_coeffs : list
        The approximation projection followed, per level (coarsest first), by
        a dict of detail projections keyed like the detail dicts of
        `wavedecn` / `swtn`. Every array is the inverse transform of a single
        subband with all others zeroed, trimmed to ``data.shape`` if the
        reconstruction is larger.

    Raises
    ------
    ValueError
        If `transform` is not recognized, or if ``transform='swtn'`` is used
        with a mode other than 'periodization'.

    See Also
    --------
    imran, swtn

    Notes
    -----
    This is sometimes referred to as an additive decomposition because the
    inverse transform (``imran``) is just the sum of the coefficient arrays
    [1]_. The decomposition using ``transform='dwtn'`` corresponds to section
    2.2 while that using an undecimated transform (``transform='swtn'``) is
    described in section 3.2 and appendix A.

    This transform does not share the variance partition property of ``swtn``
    with `norm=True`. It does however, result in coefficients that are
    aligned with the input regardless of the symmetry of the wavelet used.

    The redundancy of this transform is ``(2**n - 1) * level + 1`` where ``n``
    corresponds to the number of axes transformed.

    References
    ----------
    .. [1] Donald B. Percival and Harold O. Mofjeld. Analysis of Subtidal
        Coastal Sea Level Fluctuations Using Wavelets. Journal of the American
        Statistical Association Vol. 92, No. 439 (Sep., 1997), pp. 868-880.
        https://doi.org/10.2307/2965551
    """
    # Accept any array_like (e.g. nested lists): ``data.shape`` is used below.
    data = np.asarray(data)
    axes, _, _ = _prep_axes_wavedecn(data.shape, axes)
    wavelets = _wavelets_per_axis(wavelet, axes)

    if transform == 'swtn':
        if mode != 'periodization':
            raise ValueError(
                "transform swt only supports mode='periodization'")
        if level is None:
            # Only the transformed axes limit the achievable SWT depth; other
            # (batch) axes must not be taken into account.
            level = min(swt_max_level(data.shape[ax]) for ax in axes)
        kwargs = {"wavelet": wavelets, "axes": axes, "norm": True}
        forward = partial(swtn, level=level, trim_approx=True, **kwargs)
        inverse = partial(iswtn, **kwargs)
    elif transform == 'dwtn':
        modes = _modes_per_axis(mode, axes)
        kwargs = {"wavelet": wavelets, "mode": modes, "axes": axes}
        forward = partial(wavedecn, level=level, **kwargs)
        inverse = partial(waverecn, **kwargs)
    else:
        raise ValueError(f"unrecognized transform: {transform}")

    wav_coeffs = forward(data)
    trim = tuple(slice(sz) for sz in data.shape)

    def _project(coeffs):
        """Inverse-transform ``coeffs`` and trim any padding to data.shape."""
        rec = inverse(coeffs)
        if rec.shape != data.shape:
            rec = rec[trim]
        return rec

    # Zero-filled copy of the coefficient layout; exactly one subband at a
    # time is swapped in for reconstruction.
    approx_zeros = np.zeros_like(wav_coeffs[0])
    tmp = [approx_zeros] + [{k: np.zeros_like(v) for k, v in detail.items()}
                            for detail in wav_coeffs[1:]]

    # Approximation component.
    tmp[0] = wav_coeffs[0]
    mra_coeffs = [_project(tmp)]
    tmp[0] = approx_zeros

    # Detail components: one projection per (level, subband key).
    for j, detail in enumerate(wav_coeffs[1:], start=1):
        dcoeffs = {}
        for key, band in detail.items():
            zeros = tmp[j][key]
            tmp[j][key] = band
            dcoeffs[key] = _project(tmp)
            tmp[j][key] = zeros
        mra_coeffs.append(dcoeffs)
    return mra_coeffs


def imran(mra_coeffs):
    """Inverse nD multiresolution analysis via summation.

    Parameters
    ----------
    mra_coeffs : list
        Multiresolution analysis coefficients as returned by `mran`: an
        approximation array followed by one dict of detail arrays per level.

    Returns
    -------
    rec : ndarray
        The reconstructed signal (sum of all components). The input arrays
        are not modified.

    See Also
    --------
    mran

    References
    ----------
    .. [1] Donald B. Percival and Harold O. Mofjeld. Analysis of Subtidal
        Coastal Sea Level Fluctuations Using Wavelets. Journal of the American
        Statistical Association Vol. 92, No. 439 (Sep., 1997), pp. 868-880.
        https://doi.org/10.2307/2965551
    """
    # Out-of-place addition: an in-place ``+=`` on ``mra_coeffs[0]`` would
    # overwrite the caller's approximation array and would raise a casting
    # error when the approximation dtype is narrower than the details.
    rec = mra_coeffs[0]
    for details in mra_coeffs[1:]:
        for band in details.values():
            rec = rec + band
    return rec
