# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
#          Christos Aridas
# License: MIT

import copy
import re
import warnings
from contextlib import suppress
from functools import partial
from inspect import isfunction

from sklearn import clone, config_context
from sklearn.exceptions import SkipTestWarning
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.utils._testing import SkipTest
from sklearn.utils.fixes import parse_version

from imblearn.combine import SMOTEENN, SMOTETomek
from imblearn.ensemble import (
    BalancedBaggingClassifier,
    BalancedRandomForestClassifier,
    EasyEnsembleClassifier,
    RUSBoostClassifier,
)
from imblearn.over_sampling import (
    ADASYN,
    SMOTE,
    SMOTEN,
    SMOTENC,
    SVMSMOTE,
    BorderlineSMOTE,
    KMeansSMOTE,
    RandomOverSampler,
)
from imblearn.pipeline import Pipeline
from imblearn.under_sampling import (
    ClusterCentroids,
    CondensedNearestNeighbour,
    InstanceHardnessThreshold,
    NearMiss,
    OneSidedSelection,
    RandomUnderSampler,
)
from imblearn.utils._sklearn_compat import sklearn_version
from imblearn.utils.testing import all_estimators

# Constructor arguments suitable for the common test suite, which uses very small
# datasets and is intended to run quickly.
# Each value is either a dict of keyword arguments or a list of such dicts; for a
# list, `_construct_instances` builds one instance per dict. Estimator classes that
# have no entry here are instantiated with their default arguments.
INIT_PARAMS = {
    # estimator
    BalancedBaggingClassifier: dict(random_state=42),
    BalancedRandomForestClassifier: dict(random_state=42),
    EasyEnsembleClassifier: [
        # AdaBoostClassifier does not allow nan values
        dict(random_state=42),
        # DecisionTreeClassifier allows nan values
        dict(estimator=DecisionTreeClassifier(random_state=42), random_state=42),
    ],
    # A minimal sampler + classifier pipeline.
    Pipeline: dict(
        steps=[
            ("sampler", RandomUnderSampler(random_state=0)),
            ("logistic", LogisticRegression()),
        ]
    ),
    # over-sampling
    ADASYN: dict(random_state=42),
    BorderlineSMOTE: dict(random_state=42),
    KMeansSMOTE: dict(random_state=0),
    RandomOverSampler: dict(random_state=42),
    SMOTE: dict(random_state=42),
    SMOTEN: dict(random_state=42),
    # SMOTENC is also listed in SKIPPED_ESTIMATORS, and `_construct_instances`
    # raises SkipTest for it before reading this table, so this entry is currently
    # unused.
    SMOTENC: dict(categorical_features=[0], random_state=42),
    SVMSMOTE: dict(random_state=42),
    # under-sampling
    ClusterCentroids: dict(random_state=42),
    CondensedNearestNeighbour: dict(random_state=42),
    InstanceHardnessThreshold: dict(random_state=42),
    # One instance per NearMiss `version`.
    NearMiss: [dict(version=1), dict(version=2), dict(version=3)],
    OneSidedSelection: dict(random_state=42),
    RandomUnderSampler: dict(random_state=42),
    # combination
    SMOTEENN: dict(random_state=42),
    SMOTETomek: dict(random_state=42),
}

# This dictionary stores parameters for specific checks. It also enables running the
# same check with multiple instances of the same estimator with different parameters.
# The special key "*" is intended to apply its parameters to all checks.
# TODO(devtools): allow third-party developers to pass test specific params to checks
# Layout: {estimator class: {check name: params}}, where `params` is either a dict of
# `set_params` keyword arguments or a list of such dicts (one variant per dict).
PER_ESTIMATOR_CHECK_PARAMS: dict = {
    Pipeline: {
        # `sampler__...` targets the pipeline step named "sampler".
        # NOTE(review): the class labels suggest this check uses iris-like data —
        # confirm against the check implementation.
        "check_classifiers_with_encoded_labels": dict(
            sampler__sampling_strategy={"setosa": 20, "virginica": 20}
        )
    }
}

# Estimator classes that `_construct_instances` refuses to build: it emits a
# SkipTestWarning and raises SkipTest for them.
SKIPPED_ESTIMATORS = [SMOTENC]


def _tested_estimators(type_filter=None):
    """Yield the estimator instances exercised by the common test suite.

    Parameters
    ----------
    type_filter : optional
        Forwarded unchanged to ``all_estimators``.

    Yields
    ------
    estimator : estimator instance
        Every instance produced by ``_construct_instances`` for each discovered
        estimator class. Classes for which construction raises ``SkipTest`` are
        silently left out.
    """
    for _name, estimator_cls in all_estimators(type_filter=type_filter):
        try:
            yield from _construct_instances(estimator_cls)
        except SkipTest:
            # Not constructible in the test suite: move on to the next class.
            continue


def _construct_instances(Estimator):
    """Construct Estimator instances if possible.

    If parameter sets in INIT_PARAMS are provided, use them. If there is a list
    of parameter sets, yield one instance for each set. Otherwise, yield a single
    instance built with the default constructor arguments.

    Every parameter set is deep-copied before being passed to the constructor.
    INIT_PARAMS stores live objects (estimator instances, lists), and without the
    copy all instances built from the same entry would share them with each other
    and with the module-level table. Fitting or mutating one instance could then
    affect the others.

    Parameters
    ----------
    Estimator : type
        Estimator class to instantiate.

    Yields
    ------
    estimator : Estimator
        A freshly constructed instance.

    Raises
    ------
    SkipTest
        If ``Estimator`` is listed in ``SKIPPED_ESTIMATORS``. A ``SkipTestWarning``
        is emitted first. Because this is a generator, both happen when iteration
        starts, not when the function is called.
    """
    if Estimator in SKIPPED_ESTIMATORS:
        msg = f"Can't instantiate estimator {Estimator.__name__}"
        # raise additional warning to be shown by pytest
        warnings.warn(msg, SkipTestWarning)
        raise SkipTest(msg)

    if Estimator not in INIT_PARAMS:
        yield Estimator()
        return

    param_sets = INIT_PARAMS[Estimator]
    if not isinstance(param_sets, list):
        param_sets = [param_sets]
    for params in param_sets:
        # Deep copy: never hand the shared objects stored in INIT_PARAMS to an
        # instance that tests may fit or mutate.
        yield Estimator(**copy.deepcopy(params))


def _get_check_estimator_ids(obj):
    """Build a readable pytest id for an estimator or a check.

    Intended to be passed as the ``ids`` callable of ``pytest.mark.parametrize``
    when parametrizing tests over estimators and checks.

    Parameters
    ----------
    obj : estimator, function or functools.partial
        Object to describe.

    Returns
    -------
    id : str or None
        - ``functools.partial``: the wrapped function's name. When keyword
          arguments are bound, they follow as ``name(k1=v1,k2=v2)``. Positional
          arguments are never shown.
        - plain function: its ``__name__``.
        - object exposing ``get_params``: its ``str`` rendering under
          ``print_changed_only=True``, with all whitespace removed.
        - anything else: ``None``.
    """
    if isinstance(obj, partial):
        func_name = obj.func.__name__
        if not obj.keywords:
            return func_name
        rendered = ",".join(f"{key}={value}" for key, value in obj.keywords.items())
        return f"{func_name}({rendered})"

    if isfunction(obj):
        return obj.__name__

    if not hasattr(obj, "get_params"):
        return None

    # Only non-default parameters are printed, which keeps the id short.
    with config_context(print_changed_only=True):
        text = str(obj)
    return re.sub(r"\s", "", text)


def _yield_instances_for_check(check, estimator_orig):
    """Yield instances for a check.

    For most estimators, this is a no-op: ``estimator_orig`` itself is yielded.

    For estimators which have an entry in PER_ESTIMATOR_CHECK_PARAMS, the check's
    parameter set is looked up by the check's name. If the check has no dedicated
    entry, the special ``"*"`` key (parameters for all checks) is used when
    present. One clone of ``estimator_orig`` is yielded per parameter set, with
    those parameters applied through ``set_params``. If neither key matches,
    ``estimator_orig`` is yielded unchanged.

    Parameters
    ----------
    check : function or functools.partial
        The check about to be run. Its name (or the wrapped function's name for a
        partial) is the lookup key.
    estimator_orig : estimator instance
        Template estimator. It is never modified; variants are built from clones.

    Yields
    ------
    estimator : estimator instance
    """
    # TODO(devtools): enable this behavior for third party estimators as well
    if type(estimator_orig) not in PER_ESTIMATOR_CHECK_PARAMS:
        yield estimator_orig
        return

    check_params = PER_ESTIMATOR_CHECK_PARAMS[type(estimator_orig)]

    try:
        check_name = check.__name__
    except AttributeError:
        # partial tests
        check_name = check.func.__name__

    if check_name in check_params:
        param_set = check_params[check_name]
    elif "*" in check_params:
        # "*" applies its parameters to every check without a dedicated entry.
        param_set = check_params["*"]
    else:
        yield estimator_orig
        return

    # A single dict is shorthand for a one-element list of parameter sets.
    if isinstance(param_set, dict):
        param_set = [param_set]

    for params in param_set:
        estimator = clone(estimator_orig)
        estimator.set_params(**params)
        yield estimator


# Checks expected to fail, keyed by estimator class and then by check name. Each
# value is the reason reported for the expected failure; several entries only carry
# the placeholder "FIXME". `_get_expected_failed_checks` looks entries up by the
# estimator's exact type.
PER_ESTIMATOR_XFAIL_CHECKS = {
    BalancedRandomForestClassifier: {
        "check_sample_weight_equivalence": "FIXME",
        "check_sample_weight_equivalence_on_sparse_data": "FIXME",
        "check_sample_weight_equivalence_on_dense_data": "FIXME",
    },
    NearMiss: {
        "check_samplers_fit_resample": "FIXME",
    },
    Pipeline: {
        "check_classifiers_train": "FIXME",
        "check_supervised_y_2d": "FIXME",
        "check_dont_overwrite_parameters": (
            "Pipeline changes the `steps` parameter, which it shouldn't. "
            "Therefore this test is x-fail until we fix this."
        ),
        "check_estimators_overwrite_params": (
            "Pipeline changes the `steps` parameter, which it shouldn't. "
            "Therefore this test is x-fail until we fix this."
        ),
    },
    # Several check names below presumably cover different scikit-learn versions
    # of the same check (e.g. sparse_data vs sparse_matrix/sparse_array) — verify.
    RUSBoostClassifier: {
        "check_sample_weight_equivalence": "FIXME",
        "check_sample_weight_equivalence_on_sparse_data": "FIXME",
        "check_sample_weight_equivalence_on_dense_data": "FIXME",
        "check_estimator_sparse_data": "FIXME",
        "check_estimator_sparse_matrix": "FIXME",
        "check_estimator_sparse_array": "FIXME",
    },
}

if sklearn_version < parse_version("1.4"):
    # With scikit-learn older than 1.4, every estimator returned by `all_estimators`
    # is marked as expected to fail `check_estimators_pickle`. The entry is added to
    # the estimator's existing xfail table, or a new table holding only that entry
    # is created.
    # NOTE(review): the reason for this failure is not recorded here — confirm.
    for _, Estimator in all_estimators():
        PER_ESTIMATOR_XFAIL_CHECKS.setdefault(Estimator, {})[
            "check_estimators_pickle"
        ] = "FIXME"


def _get_expected_failed_checks(estimator):
    """Get the checks expected to fail for an imbalanced-learn estimator.

    Parameters
    ----------
    estimator : estimator instance
        Looked up by its exact type in PER_ESTIMATOR_XFAIL_CHECKS, so a subclass
        does not inherit its parent's entries.

    Returns
    -------
    failed_checks : dict
        Mapping from check name to the reason it is expected to fail. It is empty
        when the estimator has no entry. A new dict is returned on every call, so
        callers may modify it without altering the shared module-level table.
    """
    # Copy: returning the stored dict would let a caller's mutation leak into
    # every later lookup for the same estimator type.
    return dict(PER_ESTIMATOR_XFAIL_CHECKS.get(type(estimator), {}))
