from statsmodels.compat.python import lrange

from io import BytesIO
from itertools import product

import numpy as np
from numpy.testing import assert_, assert_raises
import pandas as pd
import pytest

from statsmodels.api import datasets

# utilities for the tests

try:
    import matplotlib.pyplot as plt  # noqa:F401
except ImportError:
    pass

# other functions to be tested for accuracy
# the main drawing function
from statsmodels.graphics.mosaicplot import (
    _hierarchical_split,
    _key_splitting,
    _normalize_split,
    _reduce_dict,
    _split_rect,
    mosaic,
)


@pytest.mark.matplotlib
def test_data_conversion(close_figures):
    # It will not reorder the elements
    # so the dictionary will look odd
    # as it key order has the c and b
    # keys swapped
    import pandas
    _, ax = plt.subplots(4, 4)
    data = {'ax': 1, 'bx': 2, 'cx': 3}
    mosaic(data, ax=ax[0, 0], title='basic dict', axes_label=False)
    data = pandas.Series(data)
    mosaic(data, ax=ax[0, 1], title='basic series', axes_label=False)
    data = [1, 2, 3]
    mosaic(data, ax=ax[0, 2], title='basic list', axes_label=False)
    data = np.asarray(data)
    mosaic(data, ax=ax[0, 3], title='basic array', axes_label=False)
    plt.close("all")

    data = {('ax', 'cx'): 1, ('bx', 'cx'): 2, ('ax', 'dx'): 3, ('bx', 'dx'): 4}
    mosaic(data, ax=ax[1, 0], title='compound dict', axes_label=False)
    mosaic(data, ax=ax[2, 0], title='inverted keys dict', index=[1, 0], axes_label=False)
    data = pandas.Series(data)
    mosaic(data, ax=ax[1, 1], title='compound series', axes_label=False)
    mosaic(data, ax=ax[2, 1], title='inverted keys series', index=[1, 0])
    data = [[1, 2], [3, 4]]
    mosaic(data, ax=ax[1, 2], title='compound list', axes_label=False)
    mosaic(data, ax=ax[2, 2], title='inverted keys list', index=[1, 0])
    data = np.array([[1, 2], [3, 4]])
    mosaic(data, ax=ax[1, 3], title='compound array', axes_label=False)
    mosaic(data, ax=ax[2, 3], title='inverted keys array', index=[1, 0], axes_label=False)
    plt.close("all")

    gender = ['male', 'male', 'male', 'female', 'female', 'female']
    pet = ['cat', 'dog', 'dog', 'cat', 'dog', 'cat']
    data = pandas.DataFrame({'gender': gender, 'pet': pet})
    mosaic(data, ['gender'], ax=ax[3, 0], title='dataframe by key 1', axes_label=False)
    mosaic(data, ['pet'], ax=ax[3, 1], title='dataframe by key 2', axes_label=False)
    mosaic(data, ['gender', 'pet'], ax=ax[3, 2], title='both keys', axes_label=False)
    mosaic(data, ['pet', 'gender'], ax=ax[3, 3], title='keys inverted', axes_label=False)
    plt.close("all")
    plt.suptitle('testing data conversion (plot 1 of 4)')


@pytest.mark.matplotlib
def test_mosaic_simple(close_figures):
    # display a simple plot of 4 categories of data, splitted in four
    # levels with increasing size for each group
    # creation of the levels
    key_set = (['male', 'female'], ['old', 'adult', 'young'],
               ['worker', 'unemployed'], ['healty', 'ill'])
    # the cartesian product of all the categories is
    # the complete set of categories
    keys = list(product(*key_set))
    data = dict(zip(keys, range(1, 1 + len(keys))))
    # which colours should I use for the various categories?
    # put it into a dict
    props = {}
    #males and females in blue and red
    props[('male',)] = {'color': 'b'}
    props[('female',)] = {'color': 'r'}
    # all the groups corresponding to ill groups have a different color
    for key in keys:
        if 'ill' in key:
            if 'male' in key:
                props[key] = {'color': 'BlueViolet' , 'hatch': '+'}
            else:
                props[key] = {'color': 'Crimson' , 'hatch': '+'}
    # mosaic of the data, with given gaps and colors
    mosaic(data, gap=0.05, properties=props, axes_label=False)
    plt.suptitle('syntetic data, 4 categories (plot 2 of 4)')


@pytest.mark.matplotlib
def test_mosaic(close_figures):
    # make the same analysis on a known dataset

    # load the data and clean it a bit
    affairs = datasets.fair.load_pandas()
    datas = affairs.exog
    # any time greater than 0 is cheating
    datas['cheated'] = affairs.endog > 0
    # sort by the marriage quality and give meaningful name
    # [rate_marriage, age, yrs_married, children,
    # religious, educ, occupation, occupation_husb]
    datas = datas.sort_values(['rate_marriage', 'religious'])

    num_to_desc = {1: 'awful', 2: 'bad', 3: 'intermediate',
                   4: 'good', 5: 'wonderful'}
    datas['rate_marriage'] = datas['rate_marriage'].map(num_to_desc)
    num_to_faith = {1: 'non religious', 2: 'poorly religious', 3: 'religious',
                    4: 'very religious'}
    datas['religious'] = datas['religious'].map(num_to_faith)
    num_to_cheat = {False: 'faithful', True: 'cheated'}
    datas['cheated'] = datas['cheated'].map(num_to_cheat)
    # finished cleaning
    _, ax = plt.subplots(2, 2)
    mosaic(datas, ['rate_marriage', 'cheated'], ax=ax[0, 0],
           title='by marriage happiness')
    mosaic(datas, ['religious', 'cheated'], ax=ax[0, 1],
           title='by religiosity')
    mosaic(datas, ['rate_marriage', 'religious', 'cheated'], ax=ax[1, 0],
           title='by both', labelizer=lambda k:'')
    ax[1, 0].set_xlabel('marriage rating')
    ax[1, 0].set_ylabel('religion status')
    mosaic(datas, ['religious', 'rate_marriage'], ax=ax[1, 1],
           title='inter-dependence', axes_label=False)
    plt.suptitle("extramarital affairs (plot 3 of 4)")


@pytest.mark.matplotlib
def test_mosaic_very_complex(close_figures):
    # make a scattermatrix of mosaic plots to show the correlations between
    # each pair of variable in a dataset. Could be easily converted into a
    # new function that does this automatically based on the type of data
    key_name = ['gender', 'age', 'health', 'work']
    key_base = (['male', 'female'], ['old', 'young'],
                ['healty', 'ill'], ['work', 'unemployed'])
    keys = list(product(*key_base))
    data = dict(zip(keys, range(1, 1 + len(keys))))
    props = {}
    props[('male', 'old')] = {'color': 'r'}
    props[('female',)] = {'color': 'pink'}
    L = len(key_base)
    _, axes = plt.subplots(L, L)
    for i in range(L):
        for j in range(L):
            m = set(range(L)).difference({i, j})
            if i == j:
                axes[i, i].text(0.5, 0.5, key_name[i],
                                ha='center', va='center')
                axes[i, i].set_xticks([])
                axes[i, i].set_xticklabels([])
                axes[i, i].set_yticks([])
                axes[i, i].set_yticklabels([])
            else:
                ji = max(i, j)
                ij = min(i, j)
                temp_data = {(k[ij], k[ji]) + tuple(k[r] for r in m): v
                                  for k, v in data.items()}

                keys = list(temp_data.keys())
                for k in keys:
                    value = _reduce_dict(temp_data, k[:2])
                    temp_data[k[:2]] = value
                    del temp_data[k]
                mosaic(temp_data, ax=axes[i, j], axes_label=False,
                       properties=props, gap=0.05, horizontal=i > j)
    plt.suptitle('old males should look bright red,  (plot 4 of 4)')


@pytest.mark.matplotlib
def test_axes_labeling(close_figures):
    from numpy.random import rand
    key_set = (['male', 'female'], ['old', 'adult', 'young'],
               ['worker', 'unemployed'], ['yes', 'no'])
    # the cartesian product of all the categories is
    # the complete set of categories
    keys = list(product(*key_set))
    data = dict(zip(keys, rand(len(keys))))
    lab = lambda k: ''.join(s[0] for s in k)
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))
    mosaic(data, ax=ax1, labelizer=lab, horizontal=True, label_rotation=45)
    mosaic(data, ax=ax2, labelizer=lab, horizontal=False,
           label_rotation=[0, 45, 90, 0])
    #fig.tight_layout()
    fig.suptitle("correct alignment of the axes labels")


@pytest.mark.smoke
@pytest.mark.matplotlib
def test_mosaic_empty_cells(close_figures):
    # GH#2286
    import pandas as pd
    mydata = pd.DataFrame({'id2': {64: 'Angelica',
                                   65: 'DXW_UID', 66: 'casuid01',
                                   67: 'casuid01', 68: 'EC93_uid',
                                   69: 'EC93_uid', 70: 'EC93_uid',
                                   60: 'DXW_UID',  61: 'AtmosFox',
                                   62: 'DXW_UID', 63: 'DXW_UID'},
                           'id1': {64: 'TGP',
                                   65: 'Retention01', 66: 'default',
                                   67: 'default', 68: 'Musa_EC_9_3',
                                   69: 'Musa_EC_9_3', 70: 'Musa_EC_9_3',
                                   60: 'default', 61: 'default',
                                   62: 'default', 63: 'default'}})

    ct = pd.crosstab(mydata.id1, mydata.id2)
    _, vals = mosaic(ct.T.unstack())
    _, vals = mosaic(mydata, ['id1','id2'])


eq = lambda x, y: assert_(np.allclose(x, y))


def test_recursive_split():
    keys = list(product('mf'))
    data = dict(zip(keys, [1] * len(keys)))
    res = _hierarchical_split(data, gap=0)
    assert_(list(res.keys()) == keys)
    res[('m',)] = (0.0, 0.0, 0.5, 1.0)
    res[('f',)] = (0.5, 0.0, 0.5, 1.0)
    keys = list(product('mf', 'yao'))
    data = dict(zip(keys, [1] * len(keys)))
    res = _hierarchical_split(data, gap=0)
    assert_(list(res.keys()) == keys)
    res[('m', 'y')] = (0.0, 0.0, 0.5, 1 / 3)
    res[('m', 'a')] = (0.0, 1 / 3, 0.5, 1 / 3)
    res[('m', 'o')] = (0.0, 2 / 3, 0.5, 1 / 3)
    res[('f', 'y')] = (0.5, 0.0, 0.5, 1 / 3)
    res[('f', 'a')] = (0.5, 1 / 3, 0.5, 1 / 3)
    res[('f', 'o')] = (0.5, 2 / 3, 0.5, 1 / 3)


def test__reduce_dict():
    data = dict(zip(list(product('mf', 'oy', 'wn')), [1] * 8))
    eq(_reduce_dict(data, ('m',)), 4)
    eq(_reduce_dict(data, ('m', 'o')), 2)
    eq(_reduce_dict(data, ('m', 'o', 'w')), 1)
    data = dict(zip(list(product('mf', 'oy', 'wn')), lrange(8)))
    eq(_reduce_dict(data, ('m',)), 6)
    eq(_reduce_dict(data, ('m', 'o')), 1)
    eq(_reduce_dict(data, ('m', 'o', 'w')), 0)


def test__key_splitting():
    # subdivide starting with an empty tuple
    base_rect = {tuple(): (0, 0, 1, 1)}
    res = _key_splitting(base_rect, ['a', 'b'], [1, 1], tuple(), True, 0)
    assert_(list(res.keys()) == [('a',), ('b',)])
    eq(res[('a',)], (0, 0, 0.5, 1))
    eq(res[('b',)], (0.5, 0, 0.5, 1))
    # subdivide a in two sublevel
    res_bis = _key_splitting(res, ['c', 'd'], [1, 1], ('a',), False, 0)
    assert_(list(res_bis.keys()) == [('a', 'c'), ('a', 'd'), ('b',)])
    eq(res_bis[('a', 'c')], (0.0, 0.0, 0.5, 0.5))
    eq(res_bis[('a', 'd')], (0.0, 0.5, 0.5, 0.5))
    eq(res_bis[('b',)], (0.5, 0, 0.5, 1))
    # starting with a non empty tuple and uneven distribution
    base_rect = {('total',): (0, 0, 1, 1)}
    res = _key_splitting(base_rect, ['a', 'b'], [1, 2], ('total',), True, 0)
    assert_(list(res.keys()) == [('total',) + (e,) for e in ['a', 'b']])
    eq(res[('total', 'a')], (0, 0, 1 / 3, 1))
    eq(res[('total', 'b')], (1 / 3, 0, 2 / 3, 1))


def test_proportion_normalization():
    # extremes should give the whole set, as well
    # as if 0 is inserted
    eq(_normalize_split(0.), [0.0, 0.0, 1.0])
    eq(_normalize_split(1.), [0.0, 1.0, 1.0])
    eq(_normalize_split(2.), [0.0, 1.0, 1.0])
    # negative values should raise ValueError
    assert_raises(ValueError, _normalize_split, -1)
    assert_raises(ValueError, _normalize_split, [1., -1])
    assert_raises(ValueError, _normalize_split, [1., -1, 0.])
    # if everything is zero it will complain
    assert_raises(ValueError, _normalize_split, [0.])
    assert_raises(ValueError, _normalize_split, [0., 0.])
    # one-element array should return the whole interval
    eq(_normalize_split([0.5]), [0.0, 1.0])
    eq(_normalize_split([1.]), [0.0, 1.0])
    eq(_normalize_split([2.]), [0.0, 1.0])
    # simple division should give two pieces
    for x in [0.3, 0.5, 0.9]:
        eq(_normalize_split(x), [0., x, 1.0])
    # multiple division should split as the sum of the components
    for x, y in [(0.25, 0.5), (0.1, 0.8), (10., 30.)]:
        eq(_normalize_split([x, y]), [0., x / (x + y), 1.0])
    for x, y, z in [(1., 1., 1.), (0.1, 0.5, 0.7), (10., 30., 40)]:
        eq(_normalize_split(
            [x, y, z]), [0., x / (x + y + z), (x + y) / (x + y + z), 1.0])


def test_false_split():
    # if you ask it to be divided in only one piece, just return the original
    # one
    pure_square = [0., 0., 1., 1.]
    conf_h = dict(proportion=[1], gap=0.0, horizontal=True)
    conf_v = dict(proportion=[1], gap=0.0, horizontal=False)
    eq(_split_rect(*pure_square, **conf_h), pure_square)
    eq(_split_rect(*pure_square, **conf_v), pure_square)
    conf_h = dict(proportion=[1], gap=0.5, horizontal=True)
    conf_v = dict(proportion=[1], gap=0.5, horizontal=False)
    eq(_split_rect(*pure_square, **conf_h), pure_square)
    eq(_split_rect(*pure_square, **conf_v), pure_square)

    # identity on a void rectangle should not give anything strange
    null_square = [0., 0., 0., 0.]
    conf = dict(proportion=[1], gap=0.0, horizontal=True)
    eq(_split_rect(*null_square, **conf), null_square)
    conf = dict(proportion=[1], gap=1.0, horizontal=True)
    eq(_split_rect(*null_square, **conf), null_square)

    # splitting a negative rectangle should raise error
    neg_square = [0., 0., -1., 0.]
    conf = dict(proportion=[1], gap=0.0, horizontal=True)
    assert_raises(ValueError, _split_rect, *neg_square, **conf)
    conf = dict(proportion=[1, 1], gap=0.0, horizontal=True)
    assert_raises(ValueError, _split_rect, *neg_square, **conf)
    conf = dict(proportion=[1], gap=0.5, horizontal=True)
    assert_raises(ValueError, _split_rect, *neg_square, **conf)
    conf = dict(proportion=[1, 1], gap=0.5, horizontal=True)
    assert_raises(ValueError, _split_rect, *neg_square, **conf)


def test_rect_pure_split():
    pure_square = [0., 0., 1., 1.]
    # division in two equal pieces from the perfect square
    h_2split = [(0.0, 0.0, 0.5, 1.0), (0.5, 0.0, 0.5, 1.0)]
    conf_h = dict(proportion=[1, 1], gap=0.0, horizontal=True)
    eq(_split_rect(*pure_square, **conf_h), h_2split)

    v_2split = [(0.0, 0.0, 1.0, 0.5), (0.0, 0.5, 1.0, 0.5)]
    conf_v = dict(proportion=[1, 1], gap=0.0, horizontal=False)
    eq(_split_rect(*pure_square, **conf_v), v_2split)

    # division in two non-equal pieces from the perfect square
    h_2split = [(0.0, 0.0, 1 / 3, 1.0), (1 / 3, 0.0, 2 / 3, 1.0)]
    conf_h = dict(proportion=[1, 2], gap=0.0, horizontal=True)
    eq(_split_rect(*pure_square, **conf_h), h_2split)

    v_2split = [(0.0, 0.0, 1.0, 1 / 3), (0.0, 1 / 3, 1.0, 2 / 3)]
    conf_v = dict(proportion=[1, 2], gap=0.0, horizontal=False)
    eq(_split_rect(*pure_square, **conf_v), v_2split)

    # division in three equal pieces from the perfect square
    h_2split = [(0.0, 0.0, 1 / 3, 1.0), (1 / 3, 0.0, 1 / 3, 1.0), (2 / 3, 0.0,
                 1 / 3, 1.0)]
    conf_h = dict(proportion=[1, 1, 1], gap=0.0, horizontal=True)
    eq(_split_rect(*pure_square, **conf_h), h_2split)

    v_2split = [(0.0, 0.0, 1.0, 1 / 3), (0.0, 1 / 3, 1.0, 1 / 3), (0.0, 2 / 3,
                 1.0, 1 / 3)]
    conf_v = dict(proportion=[1, 1, 1], gap=0.0, horizontal=False)
    eq(_split_rect(*pure_square, **conf_v), v_2split)

    # division in three non-equal pieces from the perfect square
    h_2split = [(0.0, 0.0, 1 / 4, 1.0), (1 / 4, 0.0, 1 / 2, 1.0), (3 / 4, 0.0,
                 1 / 4, 1.0)]
    conf_h = dict(proportion=[1, 2, 1], gap=0.0, horizontal=True)
    eq(_split_rect(*pure_square, **conf_h), h_2split)

    v_2split = [(0.0, 0.0, 1.0, 1 / 4), (0.0, 1 / 4, 1.0, 1 / 2), (0.0, 3 / 4,
                 1.0, 1 / 4)]
    conf_v = dict(proportion=[1, 2, 1], gap=0.0, horizontal=False)
    eq(_split_rect(*pure_square, **conf_v), v_2split)

    # splitting on a void rectangle should give multiple void
    null_square = [0., 0., 0., 0.]
    conf = dict(proportion=[1, 1], gap=0.0, horizontal=True)
    eq(_split_rect(*null_square, **conf), [null_square, null_square])
    conf = dict(proportion=[1, 2], gap=1.0, horizontal=True)
    eq(_split_rect(*null_square, **conf), [null_square, null_square])


def test_rect_deformed_split():
    non_pure_square = [1., -1., 1., 0.5]
    # division in two equal pieces from the perfect square
    h_2split = [(1.0, -1.0, 0.5, 0.5), (1.5, -1.0, 0.5, 0.5)]
    conf_h = dict(proportion=[1, 1], gap=0.0, horizontal=True)
    eq(_split_rect(*non_pure_square, **conf_h), h_2split)

    v_2split = [(1.0, -1.0, 1.0, 0.25), (1.0, -0.75, 1.0, 0.25)]
    conf_v = dict(proportion=[1, 1], gap=0.0, horizontal=False)
    eq(_split_rect(*non_pure_square, **conf_v), v_2split)

    # division in two non-equal pieces from the perfect square
    h_2split = [(1.0, -1.0, 1 / 3, 0.5), (1 + 1 / 3, -1.0, 2 / 3, 0.5)]
    conf_h = dict(proportion=[1, 2], gap=0.0, horizontal=True)
    eq(_split_rect(*non_pure_square, **conf_h), h_2split)

    v_2split = [(1.0, -1.0, 1.0, 1 / 6), (1.0, 1 / 6 - 1, 1.0, 2 / 6)]
    conf_v = dict(proportion=[1, 2], gap=0.0, horizontal=False)
    eq(_split_rect(*non_pure_square, **conf_v), v_2split)


def test_gap_split():
    pure_square = [0., 0., 1., 1.]

    # null split
    conf_h = dict(proportion=[1], gap=1.0, horizontal=True)
    eq(_split_rect(*pure_square, **conf_h), pure_square)

    # equal split
    h_2split = [(0.0, 0.0, 0.25, 1.0), (0.75, 0.0, 0.25, 1.0)]
    conf_h = dict(proportion=[1, 1], gap=1.0, horizontal=True)
    eq(_split_rect(*pure_square, **conf_h), h_2split)

    # disequal split
    h_2split = [(0.0, 0.0, 1 / 6, 1.0), (0.5 + 1 / 6, 0.0, 1 / 3, 1.0)]
    conf_h = dict(proportion=[1, 2], gap=1.0, horizontal=True)
    eq(_split_rect(*pure_square, **conf_h), h_2split)


@pytest.mark.matplotlib
def test_default_arg_index(close_figures):
    # 2116
    df = pd.DataFrame({'size' : ['small', 'large', 'large', 'small', 'large',
                                 'small'],
                       'length' : ['long', 'short', 'short', 'long', 'long',
                                   'short']})
    assert_raises(ValueError, mosaic, data=df, title='foobar')


@pytest.mark.matplotlib
def test_missing_category(close_figures):
    # GH5639
    animal = ['dog', 'dog', 'dog', 'cat', 'dog', 'cat', 'cat',
              'dog', 'dog', 'cat']
    size = ['medium', 'large', 'medium', 'medium', 'medium', 'medium',
            'large', 'large', 'large', 'small']
    testdata = pd.DataFrame({'animal': animal, 'size': size})
    testdata['size'] = pd.Categorical(testdata['size'],
                                      categories=['small', 'medium', 'large'])
    testdata = testdata.sort_values('size')
    fig, _ = mosaic(testdata, ['animal', 'size'])
    bio = BytesIO()
    fig.savefig(bio, format='png')
