"""
Testing for Multi-layer Perceptron module (sklearn.neural_network)
"""

# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

import re
import sys
import warnings
from io import StringIO

import joblib
import numpy as np
import pytest

from sklearn.datasets import (
    load_digits,
    load_iris,
    make_multilabel_classification,
    make_regression,
)
from sklearn.exceptions import ConvergenceWarning
from sklearn.metrics import roc_auc_score
from sklearn.neural_network import MLPClassifier, MLPRegressor
from sklearn.preprocessing import LabelBinarizer, MinMaxScaler, scale
from sklearn.utils._testing import (
    assert_allclose,
    assert_almost_equal,
    assert_array_equal,
    ignore_warnings,
)
from sklearn.utils.fixes import CSR_CONTAINERS

ACTIVATION_TYPES = ["identity", "logistic", "tanh", "relu"]

X_digits, y_digits = load_digits(n_class=3, return_X_y=True)

X_digits_multi = MinMaxScaler().fit_transform(X_digits[:200])
y_digits_multi = y_digits[:200]

X_digits, y_digits = load_digits(n_class=2, return_X_y=True)
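# From here on, X_digits / y_digits hold the 2-class digits data; the
# dtype-casting tests near the end of this module reuse them directly.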

X_digits_binary = MinMaxScaler().fit_transform(X_digits[:200])
y_digits_binary = y_digits[:200]

classification_datasets = [
    (X_digits_multi, y_digits_multi),
    (X_digits_binary, y_digits_binary),
]

X_reg, y_reg = make_regression(
    n_samples=200, n_features=10, bias=20.0, noise=100.0, random_state=7
)
y_reg = scale(y_reg)
regression_datasets = [(X_reg, y_reg)]

iris = load_iris()

X_iris = iris.data
y_iris = iris.target


def test_alpha():
    # Test that larger alpha yields weights closer to zero
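    # (alpha sets the strength of the L2 penalty added to the loss, roughly
    # 0.5 * alpha * ||coefs||^2 scaled by the sample size, so a larger alpha
    # should shrink the fitted weights.)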
    X = X_digits_binary[:100]
    y = y_digits_binary[:100]

    alpha_vectors = []
    alpha_values = np.arange(2)
    absolute_sum = lambda x: np.sum(np.abs(x))

    for alpha in alpha_values:
        mlp = MLPClassifier(hidden_layer_sizes=10, alpha=alpha, random_state=1)
        with ignore_warnings(category=ConvergenceWarning):
            mlp.fit(X, y)
        alpha_vectors.append(
            np.array([absolute_sum(mlp.coefs_[0]), absolute_sum(mlp.coefs_[1])])
        )

    for i in range(len(alpha_values) - 1):
        assert (alpha_vectors[i] > alpha_vectors[i + 1]).all()


def test_fit():
    # Test that the algorithm's solution matches a manually worked-out example.
    X = np.array([[0.6, 0.8, 0.7]])
    y = np.array([0])
    mlp = MLPClassifier(
        solver="sgd",
        learning_rate_init=0.1,
        alpha=0.1,
        activation="logistic",
        random_state=1,
        max_iter=1,
        hidden_layer_sizes=2,
        momentum=0,
    )
    # set weights
    mlp.coefs_ = [0] * 2
    mlp.intercepts_ = [0] * 2
    mlp.n_outputs_ = 1
    mlp.coefs_[0] = np.array([[0.1, 0.2], [0.3, 0.1], [0.5, 0]])
    mlp.coefs_[1] = np.array([[0.1], [0.2]])
    mlp.intercepts_[0] = np.array([0.1, 0.1])
    mlp.intercepts_[1] = np.array([1.0])
    mlp.n_features_in_ = 3

    # Initialize parameters
    mlp.n_iter_ = 0
    mlp.learning_rate_ = 0.1

    # Compute the number of layers
    mlp.n_layers_ = 3

    # Pre-allocate gradient matrices
    mlp._coef_grads = [0] * (mlp.n_layers_ - 1)
    mlp._intercept_grads = [0] * (mlp.n_layers_ - 1)

    mlp.out_activation_ = "logistic"
    mlp.t_ = 0
    mlp.best_loss_ = np.inf
    mlp.loss_curve_ = []
    mlp._no_improvement_count = 0
    mlp._intercept_velocity = [
        np.zeros_like(intercepts) for intercepts in mlp.intercepts_
    ]
    mlp._coef_velocity = [np.zeros_like(coefs) for coefs in mlp.coefs_]

    mlp.partial_fit(X, y, classes=[0, 1])
    # Manually worked out example
    # h1 = g(X1 * W_i1 + b11) = g(0.6 * 0.1 + 0.8 * 0.3 + 0.7 * 0.5 + 0.1)
    #       =  0.679178699175393
    # h2 = g(X2 * W_i2 + b12) = g(0.6 * 0.2 + 0.8 * 0.1 + 0.7 * 0 + 0.1)
    #         = 0.574442516811659
    # o1 = g(h * W2 + b21) = g(0.679 * 0.1 + 0.574 * 0.2 + 1)
    #       = 0.7654329236196236
    # d21 = -(0 - 0.765) = 0.765
    # d11 = (1 - 0.679) * 0.679 * 0.765 * 0.1 = 0.01667
    # d12 = (1 - 0.574) * 0.574 * 0.765 * 0.2 = 0.0374
    # W1grad11 = X1 * d11 + alpha * W11 = 0.6 * 0.01667 + 0.1 * 0.1 = 0.0200
    # W1grad12 = X1 * d12 + alpha * W12 = 0.6 * 0.0374 + 0.1 * 0.2 = 0.04244
    # W1grad21 = X2 * d11 + alpha * W13 = 0.8 * 0.01667 + 0.1 * 0.3 = 0.043336
    # W1grad22 = X2 * d12 + alpha * W14 = 0.8 * 0.0374 + 0.1 * 0.1 = 0.03992
    # W1grad31 = X3 * d11 + alpha * W15 = 0.7 * 0.01667 + 0.1 * 0.5 = 0.061669
    # W1grad32 = X3 * d12 + alpha * W16 = 0.7 * 0.0374 + 0.1 * 0 = 0.02618
    # W2grad1 = h1 * d21 + alpha * W21 = 0.679 * 0.765 + 0.1 * 0.1 = 0.5294
    # W2grad2 = h2 * d21 + alpha * W22 = 0.574 * 0.765 + 0.1 * 0.2 = 0.45911
    # b1grad1 = d11 = 0.01667
    # b1grad2 = d12 = 0.0374
    # b2grad = d21 = 0.765
    # W1 = W1 - eta * [W1grad11, .., W1grad32] = [[0.1, 0.2], [0.3, 0.1],
    #          [0.5, 0]] - 0.1 * [[0.0200, 0.04244], [0.043336, 0.03992],
    #          [0.061669, 0.02618]] = [[0.098, 0.195756], [0.2956664,
    #          0.096008], [0.4938331, -0.002618]]
    # W2 = W2 - eta * [W2grad1, W2grad2] = [[0.1], [0.2]] - 0.1 *
    #        [[0.5294], [0.45911]] = [[0.04706], [0.154089]]
    # b1 = b1 - eta * [b1grad1, b1grad2] = 0.1 - 0.1 * [0.01667, 0.0374]
    #         = [0.098333, 0.09626]
    # b2 = b2 - eta * b2grad = 1.0 - 0.1 * 0.765 = 0.9235
    assert_almost_equal(
        mlp.coefs_[0],
        np.array([[0.098, 0.195756], [0.2956664, 0.096008], [0.4938331, -0.002618]]),
        decimal=3,
    )
    assert_almost_equal(mlp.coefs_[1], np.array([[0.04706], [0.154089]]), decimal=3)
    assert_almost_equal(mlp.intercepts_[0], np.array([0.098333, 0.09626]), decimal=3)
    assert_almost_equal(mlp.intercepts_[1], np.array(0.9235), decimal=3)
    # Testing output
    #  h1 = g(X1 * W_i1 + b11) = g(0.6 * 0.098 + 0.8 * 0.2956664 +
    #               0.7 * 0.4938331 + 0.098333) = 0.677
    #  h2 = g(X2 * W_i2 + b12) = g(0.6 * 0.195756 + 0.8 * 0.096008 +
    #            0.7 * -0.002618 + 0.09626) = 0.572
    #  o1 = h * W2 + b21 = 0.677 * 0.04706 +
    #             0.572 * 0.154089 + 0.9235 = 1.043
    #  prob = sigmoid(o1) = 0.739
    assert_almost_equal(mlp.predict_proba(X)[0, 1], 0.739, decimal=3)
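

# Illustrative reference (not collected by pytest): re-derive the single SGD
# step checked in test_fit with plain NumPy. This sketch assumes the update
# rule exercised above: log-loss with a logistic output, an L2 penalty of
# strength `alpha` added to the coefficient gradients, and a plain SGD step
# with momentum=0. The helper name is purely for documentation.
def _manual_sgd_step_reference():
    def sigmoid(z):
        return 1.0 / (1.0 + np.exp(-z))

    X = np.array([[0.6, 0.8, 0.7]])
    y = np.array([[0.0]])
    alpha, lr = 0.1, 0.1
    W1 = np.array([[0.1, 0.2], [0.3, 0.1], [0.5, 0.0]])
    W2 = np.array([[0.1], [0.2]])
    b1 = np.array([0.1, 0.1])
    b2 = np.array([1.0])

    # forward pass
    h = sigmoid(X @ W1 + b1)  # hidden activations h1, h2
    o = sigmoid(h @ W2 + b2)  # output probability o1

    # backpropagated deltas for log-loss with a logistic output unit
    d_out = o - y
    d_hidden = (d_out @ W2.T) * h * (1 - h)

    # gradients, including the L2 penalty term, averaged over n_samples (= 1)
    n_samples = X.shape[0]
    W1_grad = (X.T @ d_hidden + alpha * W1) / n_samples
    W2_grad = (h.T @ d_out + alpha * W2) / n_samples
    b1_grad = d_hidden.mean(axis=0)
    b2_grad = d_out.mean(axis=0)

    # plain SGD update; the returned values match the hand-worked numbers
    # above to the three decimals checked in test_fit
    return (
        W1 - lr * W1_grad,
        W2 - lr * W2_grad,
        b1 - lr * b1_grad,
        b2 - lr * b2_grad,
    )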


def test_gradient():
    # Test gradient.

    # This makes sure that the activation functions and their derivatives
    # are correct. The numerical and analytical computation of the gradient
    # should be close.
    for n_labels in [2, 3]:
        n_samples = 5
        n_features = 10
        random_state = np.random.RandomState(seed=42)
        X = random_state.rand(n_samples, n_features)
        y = 1 + np.mod(np.arange(n_samples) + 1, n_labels)
        Y = LabelBinarizer().fit_transform(y)

        for activation in ACTIVATION_TYPES:
            mlp = MLPClassifier(
                activation=activation,
                hidden_layer_sizes=10,
                solver="lbfgs",
                alpha=1e-5,
                learning_rate_init=0.2,
                max_iter=1,
                random_state=1,
            )
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", ConvergenceWarning)
                mlp.fit(X, y)

            theta = np.hstack([l.ravel() for l in mlp.coefs_ + mlp.intercepts_])

            layer_units = [X.shape[1]] + [mlp.hidden_layer_sizes] + [mlp.n_outputs_]

            activations = []
            deltas = []
            coef_grads = []
            intercept_grads = []

            activations.append(X)
            for i in range(mlp.n_layers_ - 1):
                activations.append(np.empty((X.shape[0], layer_units[i + 1])))
                deltas.append(np.empty((X.shape[0], layer_units[i + 1])))

                fan_in = layer_units[i]
                fan_out = layer_units[i + 1]
                coef_grads.append(np.empty((fan_in, fan_out)))
                intercept_grads.append(np.empty(fan_out))

            # analytically compute the gradients
            def loss_grad_fun(t):
                return mlp._loss_grad_lbfgs(
                    t, X, Y, activations, deltas, coef_grads, intercept_grads
                )

            [value, grad] = loss_grad_fun(theta)
            numgrad = np.zeros(np.size(theta))
            n = np.size(theta, 0)
            E = np.eye(n)
            epsilon = 1e-5
            # numerically compute the gradients
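            # central difference: (f(theta + eps * e_i) - f(theta - eps * e_i)) / (2 * eps)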
            for i in range(n):
                dtheta = E[:, i] * epsilon
                numgrad[i] = (
                    loss_grad_fun(theta + dtheta)[0] - loss_grad_fun(theta - dtheta)[0]
                ) / (epsilon * 2.0)
            assert_almost_equal(numgrad, grad)


@pytest.mark.parametrize("X,y", classification_datasets)
def test_lbfgs_classification(X, y):
    # Test lbfgs on classification.
    # It should achieve a score higher than 0.95 for the binary and multi-class
    # versions of the digits dataset.
    X_train = X[:150]
    y_train = y[:150]
    X_test = X[150:]
    expected_shape_dtype = (X_test.shape[0], y_train.dtype.kind)

    for activation in ACTIVATION_TYPES:
        mlp = MLPClassifier(
            solver="lbfgs",
            hidden_layer_sizes=50,
            max_iter=150,
            shuffle=True,
            random_state=1,
            activation=activation,
        )
        mlp.fit(X_train, y_train)
        y_predict = mlp.predict(X_test)
        assert mlp.score(X_train, y_train) > 0.95
        assert (y_predict.shape[0], y_predict.dtype.kind) == expected_shape_dtype


@pytest.mark.parametrize("X,y", regression_datasets)
def test_lbfgs_regression(X, y):
    # Test lbfgs on the regression dataset.
    for activation in ACTIVATION_TYPES:
        mlp = MLPRegressor(
            solver="lbfgs",
            hidden_layer_sizes=50,
            max_iter=200,
            tol=1e-3,
            shuffle=True,
            random_state=1,
            activation=activation,
        )
        mlp.fit(X, y)
        if activation == "identity":
            assert mlp.score(X, y) > 0.80
        else:
            # Non-linear activations perform much better than the linear bottleneck:
            assert mlp.score(X, y) > 0.98


@pytest.mark.parametrize("X,y", classification_datasets)
def test_lbfgs_classification_maxfun(X, y):
    # Test the lbfgs parameter max_fun.
    # It caps the number of loss function calls independently of max_iter,
    # which in turn bounds the number of lbfgs iterations.
    max_fun = 10
    # classification tests
    for activation in ACTIVATION_TYPES:
        mlp = MLPClassifier(
            solver="lbfgs",
            hidden_layer_sizes=50,
            max_iter=150,
            max_fun=max_fun,
            shuffle=True,
            random_state=1,
            activation=activation,
        )
        with pytest.warns(ConvergenceWarning):
            mlp.fit(X, y)
            assert max_fun >= mlp.n_iter_


@pytest.mark.parametrize("X,y", regression_datasets)
def test_lbfgs_regression_maxfun(X, y):
    # Test the lbfgs parameter max_fun.
    # It caps the number of loss function calls independently of max_iter,
    # which in turn bounds the number of lbfgs iterations.
    max_fun = 10
    # regression tests
    for activation in ACTIVATION_TYPES:
        mlp = MLPRegressor(
            solver="lbfgs",
            hidden_layer_sizes=50,
            tol=0.0,
            max_iter=150,
            max_fun=max_fun,
            shuffle=True,
            random_state=1,
            activation=activation,
        )
        with pytest.warns(ConvergenceWarning):
            mlp.fit(X, y)
            assert max_fun >= mlp.n_iter_


def test_learning_rate_warmstart():
    # Tests that warm_start reuses past solutions.
    X = [[3, 2], [1, 6], [5, 6], [-2, -4]]
    y = [1, 1, 1, 0]
    for learning_rate in ["invscaling", "constant"]:
        mlp = MLPClassifier(
            solver="sgd",
            hidden_layer_sizes=4,
            learning_rate=learning_rate,
            max_iter=1,
            power_t=0.25,
            warm_start=True,
        )
        with ignore_warnings(category=ConvergenceWarning):
            mlp.fit(X, y)
            prev_eta = mlp._optimizer.learning_rate
            mlp.fit(X, y)
            post_eta = mlp._optimizer.learning_rate

        if learning_rate == "constant":
            assert prev_eta == post_eta
        elif learning_rate == "invscaling":
            assert mlp.learning_rate_init / pow(8 + 1, mlp.power_t) == post_eta


def test_multilabel_classification():
    # Test that multi-label classification works as expected.
    # test fit method
    X, y = make_multilabel_classification(
        n_samples=50, random_state=0, return_indicator=True
    )
    mlp = MLPClassifier(
        solver="lbfgs",
        hidden_layer_sizes=50,
        alpha=1e-5,
        max_iter=150,
        random_state=0,
        activation="logistic",
        learning_rate_init=0.2,
    )
    mlp.fit(X, y)
    assert mlp.score(X, y) > 0.97

    # test partial fit method
    mlp = MLPClassifier(
        solver="sgd",
        hidden_layer_sizes=50,
        max_iter=150,
        random_state=0,
        activation="logistic",
        alpha=1e-5,
        learning_rate_init=0.2,
    )
    for i in range(100):
        mlp.partial_fit(X, y, classes=[0, 1, 2, 3, 4])
    assert mlp.score(X, y) > 0.9

    # Make sure early stopping still works now that splitting is stratified by
    # default (stratification is disabled for multilabel classification)
    mlp = MLPClassifier(early_stopping=True)
    mlp.fit(X, y).predict(X)


def test_multioutput_regression():
    # Test that multi-output regression works as expected
    X, y = make_regression(n_samples=200, n_targets=5, random_state=11)
    mlp = MLPRegressor(
        solver="lbfgs", hidden_layer_sizes=50, max_iter=200, tol=1e-2, random_state=1
    )
    mlp.fit(X, y)
    assert mlp.score(X, y) > 0.9


def test_partial_fit_classes_error():
    # Tests that passing different classes to partial_fit raises an error
    X = [[3, 2]]
    y = [0]
    clf = MLPClassifier(solver="sgd")
    clf.partial_fit(X, y, classes=[0, 1])
    with pytest.raises(ValueError):
        clf.partial_fit(X, y, classes=[1, 2])


def test_partial_fit_classification():
    # Test partial_fit on classification.
    # `partial_fit` should yield the same results as `fit` for binary and
    # multi-class classification.
    for X, y in classification_datasets:
        mlp = MLPClassifier(
            solver="sgd",
            max_iter=100,
            random_state=1,
            tol=0,
            alpha=1e-5,
            learning_rate_init=0.2,
        )

        with ignore_warnings(category=ConvergenceWarning):
            mlp.fit(X, y)
        pred1 = mlp.predict(X)
        mlp = MLPClassifier(
            solver="sgd", random_state=1, alpha=1e-5, learning_rate_init=0.2
        )
        for i in range(100):
            mlp.partial_fit(X, y, classes=np.unique(y))
        pred2 = mlp.predict(X)
        assert_array_equal(pred1, pred2)
        assert mlp.score(X, y) > 0.95


def test_partial_fit_unseen_classes():
    # Non-regression test for bug 6994
    # Tests for labeling errors in partial_fit when previously unseen labels appear

    clf = MLPClassifier(random_state=0)
    clf.partial_fit([[1], [2], [3]], ["a", "b", "c"], classes=["a", "b", "c", "d"])
    clf.partial_fit([[4]], ["d"])
    assert clf.score([[1], [2], [3], [4]], ["a", "b", "c", "d"]) > 0


def test_partial_fit_regression():
    # Test partial_fit on regression.
    # `partial_fit` should yield the same results as `fit` for regression.
    X = X_reg
    y = y_reg

    for momentum in [0, 0.9]:
        mlp = MLPRegressor(
            solver="sgd",
            max_iter=100,
            activation="relu",
            random_state=1,
            learning_rate_init=0.01,
            batch_size=X.shape[0],
            momentum=momentum,
        )
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", ConvergenceWarning)
            mlp.fit(X, y)
        pred1 = mlp.predict(X)
        mlp = MLPRegressor(
            solver="sgd",
            activation="relu",
            learning_rate_init=0.01,
            random_state=1,
            batch_size=X.shape[0],
            momentum=momentum,
        )
        for i in range(100):
            mlp.partial_fit(X, y)

        pred2 = mlp.predict(X)
        assert_allclose(pred1, pred2)
        score = mlp.score(X, y)
        assert score > 0.65


def test_partial_fit_errors():
    # Test partial_fit error handling.
    X = [[3, 2], [1, 6]]
    y = [1, 0]

    # no classes passed
    with pytest.raises(ValueError):
        MLPClassifier(solver="sgd").partial_fit(X, y, classes=[2])

    # lbfgs doesn't support partial_fit
    assert not hasattr(MLPClassifier(solver="lbfgs"), "partial_fit")


def test_nonfinite_params():
    # Check that MLPRegressor raises ValueError when dealing with non-finite
    # parameter values
    rng = np.random.RandomState(0)
    n_samples = 10
    fmax = np.finfo(np.float64).max
    X = fmax * rng.uniform(size=(n_samples, 2))
    y = rng.standard_normal(size=n_samples)

    clf = MLPRegressor()
    msg = (
        "Solver produced non-finite parameter weights. The input data may contain large"
        " values and need to be preprocessed."
    )
    with pytest.raises(ValueError, match=msg):
        with warnings.catch_warnings():
            # RuntimeWarning: overflow encountered in square
            warnings.simplefilter("ignore")
            clf.fit(X, y)


def test_predict_proba_binary():
    # Test that predict_proba works as expected in the binary case.
    X = X_digits_binary[:50]
    y = y_digits_binary[:50]

    clf = MLPClassifier(hidden_layer_sizes=5, activation="logistic", random_state=1)
    with ignore_warnings(category=ConvergenceWarning):
        clf.fit(X, y)
    y_proba = clf.predict_proba(X)
    y_log_proba = clf.predict_log_proba(X)

    (n_samples, n_classes) = y.shape[0], 2

    proba_max = y_proba.argmax(axis=1)
    proba_log_max = y_log_proba.argmax(axis=1)

    assert y_proba.shape == (n_samples, n_classes)
    assert_array_equal(proba_max, proba_log_max)
    assert_allclose(y_log_proba, np.log(y_proba))

    assert roc_auc_score(y, y_proba[:, 1]) == 1.0


def test_predict_proba_multiclass():
    # Test that predict_proba works as expected in the multi-class case.
    X = X_digits_multi[:10]
    y = y_digits_multi[:10]

    clf = MLPClassifier(hidden_layer_sizes=5)
    with ignore_warnings(category=ConvergenceWarning):
        clf.fit(X, y)
    y_proba = clf.predict_proba(X)
    y_log_proba = clf.predict_log_proba(X)

    (n_samples, n_classes) = y.shape[0], np.unique(y).size

    proba_max = y_proba.argmax(axis=1)
    proba_log_max = y_log_proba.argmax(axis=1)

    assert y_proba.shape == (n_samples, n_classes)
    assert_array_equal(proba_max, proba_log_max)
    assert_allclose(y_log_proba, np.log(y_proba))


def test_predict_proba_multilabel():
    # Test that predict_proba works as expected for multilabel.
    # Multilabel should not use softmax, which would force each row of
    # probabilities to sum to 1
    X, Y = make_multilabel_classification(
        n_samples=50, random_state=0, return_indicator=True
    )
    n_samples, n_classes = Y.shape

    clf = MLPClassifier(solver="lbfgs", hidden_layer_sizes=30, random_state=0)
    clf.fit(X, Y)
    y_proba = clf.predict_proba(X)

    assert y_proba.shape == (n_samples, n_classes)
    assert_array_equal(y_proba > 0.5, Y)

    y_log_proba = clf.predict_log_proba(X)
    proba_max = y_proba.argmax(axis=1)
    proba_log_max = y_log_proba.argmax(axis=1)

    assert (y_proba.sum(1) - 1).dot(y_proba.sum(1) - 1) > 1e-10
    assert_array_equal(proba_max, proba_log_max)
    assert_allclose(y_log_proba, np.log(y_proba))
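

# Illustrative sketch (not collected by pytest): contrast softmax, which
# normalizes each row of probabilities to sum to 1, with the element-wise
# logistic/sigmoid used for multilabel outputs, whose row sums are
# unconstrained. Names below are for documentation only.
def _softmax_vs_sigmoid_sketch():
    rng = np.random.RandomState(0)
    logits = rng.randn(3, 4)

    softmax = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
    sigmoid = 1.0 / (1.0 + np.exp(-logits))

    # softmax rows sum to 1; sigmoid rows generally do not
    assert_allclose(softmax.sum(axis=1), np.ones(3))
    return softmax, sigmoid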


def test_shuffle():
    # Test that the shuffle parameter affects the training process (it should)
    X, y = make_regression(n_samples=50, n_features=5, n_targets=1, random_state=0)

    # The coefficients are identical when both fits use the same shuffle setting
    for shuffle in [True, False]:
        mlp1 = MLPRegressor(
            hidden_layer_sizes=1,
            max_iter=1,
            batch_size=1,
            random_state=0,
            shuffle=shuffle,
        )
        mlp2 = MLPRegressor(
            hidden_layer_sizes=1,
            max_iter=1,
            batch_size=1,
            random_state=0,
            shuffle=shuffle,
        )
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", ConvergenceWarning)
            mlp1.fit(X, y)
            mlp2.fit(X, y)

        assert np.array_equal(mlp1.coefs_[0], mlp2.coefs_[0])

    # The coefficients differ when one fit shuffles and the other does not
    mlp1 = MLPRegressor(
        hidden_layer_sizes=1, max_iter=1, batch_size=1, random_state=0, shuffle=True
    )
    mlp2 = MLPRegressor(
        hidden_layer_sizes=1, max_iter=1, batch_size=1, random_state=0, shuffle=False
    )
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", ConvergenceWarning)
        mlp1.fit(X, y)
        mlp2.fit(X, y)

    assert not np.array_equal(mlp1.coefs_[0], mlp2.coefs_[0])


@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
def test_sparse_matrices(csr_container):
    # Test that sparse and dense input matrices output the same results.
    X = X_digits_binary[:50]
    y = y_digits_binary[:50]
    X_sparse = csr_container(X)
    mlp = MLPClassifier(solver="lbfgs", hidden_layer_sizes=15, random_state=1)
    mlp.fit(X, y)
    pred1 = mlp.predict(X)
    mlp.fit(X_sparse, y)
    pred2 = mlp.predict(X_sparse)
    assert_almost_equal(pred1, pred2)
    pred1 = mlp.predict(X)
    pred2 = mlp.predict(X_sparse)
    assert_array_equal(pred1, pred2)


def test_tolerance():
    # Test tolerance.
    # It should force the solver to exit the loop when it converges.
    X = [[3, 2], [1, 6]]
    y = [1, 0]
    clf = MLPClassifier(tol=0.5, max_iter=3000, solver="sgd")
    clf.fit(X, y)
    assert clf.max_iter > clf.n_iter_


def test_verbose_sgd():
    # Test verbose.
    X = [[3, 2], [1, 6]]
    y = [1, 0]
    clf = MLPClassifier(solver="sgd", max_iter=2, verbose=10, hidden_layer_sizes=2)
    old_stdout = sys.stdout
    sys.stdout = output = StringIO()

    with ignore_warnings(category=ConvergenceWarning):
        clf.fit(X, y)
    clf.partial_fit(X, y)

    sys.stdout = old_stdout
    assert "Iteration" in output.getvalue()


@pytest.mark.parametrize("MLPEstimator", [MLPClassifier, MLPRegressor])
def test_early_stopping(MLPEstimator):
    X = X_digits_binary[:100]
    y = y_digits_binary[:100]
    tol = 0.2
    mlp_estimator = MLPEstimator(
        tol=tol, max_iter=3000, solver="sgd", early_stopping=True
    )
    mlp_estimator.fit(X, y)
    assert mlp_estimator.max_iter > mlp_estimator.n_iter_

    assert mlp_estimator.best_loss_ is None
    assert isinstance(mlp_estimator.validation_scores_, list)

    valid_scores = mlp_estimator.validation_scores_
    best_valid_score = mlp_estimator.best_validation_score_
    assert max(valid_scores) == best_valid_score
    assert best_valid_score + tol > valid_scores[-2]
    assert best_valid_score + tol > valid_scores[-1]

    # check that the attributes `validation_scores_` and `best_validation_score_`
    # are set to None when `early_stopping=False`
    mlp_estimator = MLPEstimator(
        tol=tol, max_iter=3000, solver="sgd", early_stopping=False
    )
    mlp_estimator.fit(X, y)
    assert mlp_estimator.validation_scores_ is None
    assert mlp_estimator.best_validation_score_ is None
    assert mlp_estimator.best_loss_ is not None


def test_adaptive_learning_rate():
    X = [[3, 2], [1, 6]]
    y = [1, 0]
    clf = MLPClassifier(tol=0.5, max_iter=3000, solver="sgd", learning_rate="adaptive")
    clf.fit(X, y)
    assert clf.max_iter > clf.n_iter_
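    # With tol=0.5 the loss quickly stops improving; the adaptive schedule then
    # repeatedly divides the learning rate by 5 until it falls below 1e-6 and
    # training stops.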
    assert 1e-6 > clf._optimizer.learning_rate


def test_warm_start():
    X = X_iris
    y = y_iris

    y_2classes = np.array([0] * 75 + [1] * 75)
    y_3classes = np.array([0] * 40 + [1] * 40 + [2] * 70)
    y_3classes_alt = np.array([0] * 50 + [1] * 50 + [3] * 50)
    y_4classes = np.array([0] * 37 + [1] * 37 + [2] * 38 + [3] * 38)
    y_5classes = np.array([0] * 30 + [1] * 30 + [2] * 30 + [3] * 30 + [4] * 30)

    # No error raised
    clf = MLPClassifier(
        hidden_layer_sizes=2, solver="lbfgs", warm_start=True, random_state=42, tol=1e-2
    ).fit(X, y)
    clf.fit(X, y)
    clf.fit(X, y_3classes)

    for y_i in (y_2classes, y_3classes_alt, y_4classes, y_5classes):
        clf = MLPClassifier(
            hidden_layer_sizes=2,
            solver="lbfgs",
            warm_start=True,
            random_state=42,
            tol=1e-2,
        ).fit(X, y)
        message = (
            "warm_start can only be used where `y` has the same "
            "classes as in the previous call to fit."
            " Previously got [0 1 2], `y` has %s" % np.unique(y_i)
        )
        with pytest.raises(ValueError, match=re.escape(message)):
            clf.fit(X, y_i)


@pytest.mark.parametrize("MLPEstimator", [MLPClassifier, MLPRegressor])
def test_warm_start_full_iteration(MLPEstimator):
    # Non-regression test for:
    # https://github.com/scikit-learn/scikit-learn/issues/16812
    # Check that a warm-started MLP estimator completes `max_iter` iterations
    # on each call to fit.
    X, y = X_iris, y_iris
    max_iter = 3
    clf = MLPEstimator(
        hidden_layer_sizes=2, solver="sgd", warm_start=True, max_iter=max_iter
    )
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", ConvergenceWarning)
        clf.fit(X, y)
        assert max_iter == clf.n_iter_
        clf.fit(X, y)
        assert max_iter == clf.n_iter_


def test_n_iter_no_change():
    # Test n_iter_no_change using the binary digits dataset; its fitting
    # process is not prone to loss-curve fluctuations
    X = X_digits_binary[:100]
    y = y_digits_binary[:100]
    tol = 0.01
    max_iter = 3000

    # test multiple n_iter_no_change
    for n_iter_no_change in [2, 5, 10, 50, 100]:
        clf = MLPClassifier(
            tol=tol, max_iter=max_iter, solver="sgd", n_iter_no_change=n_iter_no_change
        )
        clf.fit(X, y)

        # validate n_iter_no_change
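        # training stops only once the no-improvement count exceeds
        # n_iter_no_change, so the final count is n_iter_no_change + 1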
        assert clf._no_improvement_count == n_iter_no_change + 1
        assert max_iter > clf.n_iter_


@pytest.mark.filterwarnings("ignore::sklearn.exceptions.ConvergenceWarning")
def test_n_iter_no_change_inf():
    # Test n_iter_no_change=inf using the binary digits dataset;
    # the fitting process should run for all max_iter iterations
    X = X_digits_binary[:100]
    y = y_digits_binary[:100]

    # set a ridiculous tolerance
    # this should always trigger _update_no_improvement_count()
    tol = 1e9

    # fit
    n_iter_no_change = np.inf
    max_iter = 3000
    clf = MLPClassifier(
        tol=tol, max_iter=max_iter, solver="sgd", n_iter_no_change=n_iter_no_change
    )
    clf.fit(X, y)

    # validate n_iter_no_change doesn't cause early stopping
    assert clf.n_iter_ == max_iter

    # validate _update_no_improvement_count() was always triggered
    assert clf._no_improvement_count == clf.n_iter_ - 1


def test_early_stopping_stratified():
    # Make sure data splitting for early stopping is stratified
    X = [[1, 2], [2, 3], [3, 4], [4, 5]]
    y = [0, 0, 0, 1]

    mlp = MLPClassifier(early_stopping=True)
    with pytest.raises(
        ValueError, match="The least populated class in y has only 1 member"
    ):
        mlp.fit(X, y)


def test_mlp_classifier_dtypes_casting():
    # Compare predictions for different dtypes
    mlp_64 = MLPClassifier(
        alpha=1e-5, hidden_layer_sizes=(5, 3), random_state=1, max_iter=100, tol=1e-1
    )
    mlp_64.fit(X_digits[:300], y_digits[:300])
    pred_64 = mlp_64.predict(X_digits[300:])
    proba_64 = mlp_64.predict_proba(X_digits[300:])

    mlp_32 = MLPClassifier(
        alpha=1e-5, hidden_layer_sizes=(5, 3), random_state=1, max_iter=100, tol=1e-1
    )
    mlp_32.fit(X_digits[:300].astype(np.float32), y_digits[:300])
    pred_32 = mlp_32.predict(X_digits[300:].astype(np.float32))
    proba_32 = mlp_32.predict_proba(X_digits[300:].astype(np.float32))

    assert_array_equal(pred_64, pred_32)
    assert_allclose(proba_64, proba_32, rtol=1e-02)


def test_mlp_regressor_dtypes_casting():
    mlp_64 = MLPRegressor(
        alpha=1e-5, hidden_layer_sizes=(5, 3), random_state=1, max_iter=150, tol=1e-3
    )
    mlp_64.fit(X_digits[:300], y_digits[:300])
    pred_64 = mlp_64.predict(X_digits[300:])

    mlp_32 = MLPRegressor(
        alpha=1e-5, hidden_layer_sizes=(5, 3), random_state=1, max_iter=150, tol=1e-3
    )
    mlp_32.fit(X_digits[:300].astype(np.float32), y_digits[:300])
    pred_32 = mlp_32.predict(X_digits[300:].astype(np.float32))

    assert_allclose(pred_64, pred_32, rtol=5e-04)


@pytest.mark.parametrize("dtype", [np.float32, np.float64])
@pytest.mark.parametrize("Estimator", [MLPClassifier, MLPRegressor])
def test_mlp_param_dtypes(dtype, Estimator):
    # Checks if input dtype is used for network parameters
    # and predictions
    X, y = X_digits.astype(dtype), y_digits
    mlp = Estimator(
        alpha=1e-5, hidden_layer_sizes=(5, 3), random_state=1, max_iter=50, tol=1e-1
    )
    mlp.fit(X[:300], y[:300])
    pred = mlp.predict(X[300:])

    assert all([intercept.dtype == dtype for intercept in mlp.intercepts_])

    assert all([coef.dtype == dtype for coef in mlp.coefs_])

    if Estimator == MLPRegressor:
        assert pred.dtype == dtype


def test_mlp_loading_from_joblib_partial_fit(tmp_path):
    """Loading from MLP and partial fitting updates weights. Non-regression
    test for #19626."""
    pre_trained_estimator = MLPRegressor(
        hidden_layer_sizes=(42,), random_state=42, learning_rate_init=0.01, max_iter=200
    )
    features, target = [[2]], [4]

    # Fit on x=2, y=4
    pre_trained_estimator.fit(features, target)

    # dump and load model
    pickled_file = tmp_path / "mlp.pkl"
    joblib.dump(pre_trained_estimator, pickled_file)
    load_estimator = joblib.load(pickled_file)

    # Train for more epochs on the point x=2, y=1
    fine_tune_features, fine_tune_target = [[2]], [1]

    for _ in range(200):
        load_estimator.partial_fit(fine_tune_features, fine_tune_target)

    # The fine-tuned model has learned the new target
    predicted_value = load_estimator.predict(fine_tune_features)
    assert_allclose(predicted_value, fine_tune_target, rtol=1e-4)


@pytest.mark.parametrize("Estimator", [MLPClassifier, MLPRegressor])
def test_preserve_feature_names(Estimator):
    """Check that feature names are preserved when early stopping is enabled.

    Feature names are required for consistency checks during scoring.

    Non-regression test for gh-24846
    """
    pd = pytest.importorskip("pandas")
    rng = np.random.RandomState(0)

    X = pd.DataFrame(data=rng.randn(10, 2), columns=["colname_a", "colname_b"])
    y = pd.Series(data=np.full(10, 1), name="colname_y")

    model = Estimator(early_stopping=True, validation_fraction=0.2)

    with warnings.catch_warnings():
        warnings.simplefilter("error", UserWarning)
        model.fit(X, y)


@pytest.mark.parametrize("MLPEstimator", [MLPClassifier, MLPRegressor])
def test_mlp_warm_start_with_early_stopping(MLPEstimator):
    """Check that early stopping works with warm start."""
    mlp = MLPEstimator(
        max_iter=10, random_state=0, warm_start=True, early_stopping=True
    )
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", ConvergenceWarning)
        mlp.fit(X_iris, y_iris)
        n_validation_scores = len(mlp.validation_scores_)
        mlp.set_params(max_iter=20)
        mlp.fit(X_iris, y_iris)
    assert len(mlp.validation_scores_) > n_validation_scores


@pytest.mark.parametrize("MLPEstimator", [MLPClassifier, MLPRegressor])
@pytest.mark.parametrize("solver", ["sgd", "adam", "lbfgs"])
def test_mlp_warm_start_no_convergence(MLPEstimator, solver):
    """Check that we stop the number of iteration at `max_iter` when warm starting.

    Non-regression test for:
    https://github.com/scikit-learn/scikit-learn/issues/24764
    """
    model = MLPEstimator(
        solver=solver,
        warm_start=True,
        early_stopping=False,
        max_iter=10,
        n_iter_no_change=np.inf,
        random_state=0,
    )

    with pytest.warns(ConvergenceWarning):
        model.fit(X_iris, y_iris)
    assert model.n_iter_ == 10

    model.set_params(max_iter=20)
    with pytest.warns(ConvergenceWarning):
        model.fit(X_iris, y_iris)
    assert model.n_iter_ == 20


@pytest.mark.parametrize("MLPEstimator", [MLPClassifier, MLPRegressor])
def test_mlp_partial_fit_after_fit(MLPEstimator):
    """Check partial fit does not fail after fit when early_stopping=True.

    Non-regression test for gh-25693.
    """
    mlp = MLPEstimator(early_stopping=True, random_state=0).fit(X_iris, y_iris)

    msg = "partial_fit does not support early_stopping=True"
    with pytest.raises(ValueError, match=msg):
        mlp.partial_fit(X_iris, y_iris)


def test_mlp_diverging_loss():
    """Test that a diverging model does not raise errors when early stopping is enabled.

    Non-regression test for:
    https://github.com/scikit-learn/scikit-learn/issues/29504
    """
    mlp = MLPRegressor(
        hidden_layer_sizes=100,
        activation="identity",
        solver="sgd",
        alpha=0.0001,
        learning_rate="constant",
        learning_rate_init=1,
        shuffle=True,
        max_iter=20,
        early_stopping=True,
        n_iter_no_change=10,
        random_state=0,
    )

    with warnings.catch_warnings():
        # RuntimeWarning: overflow encountered in matmul
        # ConvergenceWarning: Stochastic Optimizer: Maximum iteration
        warnings.simplefilter("ignore", RuntimeWarning)
        warnings.simplefilter("ignore", ConvergenceWarning)
        mlp.fit(X_iris, y_iris)

    # In python, float("nan") != float("nan")
    assert str(mlp.validation_scores_[-1]) == str(np.nan)
    assert isinstance(mlp.validation_scores_[-1], float)
