"""
This is testing the equivalence between some estimators with internal nearest
neighbors computations, and the corresponding pipeline versions with
KNeighborsTransformer or RadiusNeighborsTransformer to precompute the
neighbors.
"""

import numpy as np

from sklearn.base import clone
from sklearn.cluster import DBSCAN, SpectralClustering
from sklearn.cluster.tests.common import generate_clustered_data
from sklearn.datasets import make_blobs
from sklearn.manifold import TSNE, Isomap, SpectralEmbedding
from sklearn.neighbors import (
    KNeighborsRegressor,
    KNeighborsTransformer,
    LocalOutlierFactor,
    RadiusNeighborsRegressor,
    RadiusNeighborsTransformer,
)
from sklearn.pipeline import make_pipeline
from sklearn.utils._testing import assert_array_almost_equal


def test_spectral_clustering():
    # Test chaining KNeighborsTransformer and SpectralClustering
    n_neighbors = 5
    X, _ = make_blobs(random_state=0)

    # compare the chained version and the compact version
    est_chain = make_pipeline(
        KNeighborsTransformer(n_neighbors=n_neighbors, mode="connectivity"),
        SpectralClustering(
            n_neighbors=n_neighbors, affinity="precomputed", random_state=42
        ),
    )
    est_compact = SpectralClustering(
        n_neighbors=n_neighbors, affinity="nearest_neighbors", random_state=42
    )
    labels_compact = est_compact.fit_predict(X)
    labels_chain = est_chain.fit_predict(X)
    assert_array_almost_equal(labels_chain, labels_compact)


def test_spectral_embedding():
    # Test chaining KNeighborsTransformer and SpectralEmbedding
    n_neighbors = 5

    n_samples = 1000
    centers = np.array(
        [
            [0.0, 5.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 4.0, 0.0, 0.0],
            [1.0, 0.0, 0.0, 5.0, 1.0],
        ]
    )
    S, true_labels = make_blobs(
        n_samples=n_samples, centers=centers, cluster_std=1.0, random_state=42
    )

    # compare the chained version and the compact version
    est_chain = make_pipeline(
        KNeighborsTransformer(n_neighbors=n_neighbors, mode="connectivity"),
        SpectralEmbedding(
            n_neighbors=n_neighbors, affinity="precomputed", random_state=42
        ),
    )
    est_compact = SpectralEmbedding(
        n_neighbors=n_neighbors, affinity="nearest_neighbors", random_state=42
    )
    St_compact = est_compact.fit_transform(S)
    St_chain = est_chain.fit_transform(S)
    assert_array_almost_equal(St_chain, St_compact)


def test_dbscan():
    # Test chaining RadiusNeighborsTransformer and DBSCAN
    radius = 0.3
    n_clusters = 3
    X = generate_clustered_data(n_clusters=n_clusters)

    # compare the chained version and the compact version
    est_chain = make_pipeline(
        RadiusNeighborsTransformer(radius=radius, mode="distance"),
        DBSCAN(metric="precomputed", eps=radius),
    )
    est_compact = DBSCAN(eps=radius)

    labels_chain = est_chain.fit_predict(X)
    labels_compact = est_compact.fit_predict(X)
    assert_array_almost_equal(labels_chain, labels_compact)


def test_isomap():
    # Test chaining KNeighborsTransformer and Isomap with
    # neighbors_algorithm='precomputed'
    algorithm = "auto"
    n_neighbors = 10

    X, _ = make_blobs(random_state=0)
    X2, _ = make_blobs(random_state=1)

    # compare the chained version and the compact version
    est_chain = make_pipeline(
        KNeighborsTransformer(
            n_neighbors=n_neighbors, algorithm=algorithm, mode="distance"
        ),
        Isomap(n_neighbors=n_neighbors, metric="precomputed"),
    )
    est_compact = Isomap(n_neighbors=n_neighbors, neighbors_algorithm=algorithm)

    Xt_chain = est_chain.fit_transform(X)
    Xt_compact = est_compact.fit_transform(X)
    assert_array_almost_equal(Xt_chain, Xt_compact)

    Xt_chain = est_chain.transform(X2)
    Xt_compact = est_compact.transform(X2)
    assert_array_almost_equal(Xt_chain, Xt_compact)


def test_tsne():
    # Test chaining KNeighborsTransformer and TSNE
    max_iter = 250
    perplexity = 5
    n_neighbors = int(3.0 * perplexity + 1)

    rng = np.random.RandomState(0)
    X = rng.randn(20, 2)

    for metric in ["minkowski", "sqeuclidean"]:
        # compare the chained version and the compact version
        est_chain = make_pipeline(
            KNeighborsTransformer(
                n_neighbors=n_neighbors, mode="distance", metric=metric
            ),
            TSNE(
                init="random",
                metric="precomputed",
                perplexity=perplexity,
                method="barnes_hut",
                random_state=42,
                max_iter=max_iter,
            ),
        )
        est_compact = TSNE(
            init="random",
            metric=metric,
            perplexity=perplexity,
            max_iter=max_iter,
            method="barnes_hut",
            random_state=42,
        )

        Xt_chain = est_chain.fit_transform(X)
        Xt_compact = est_compact.fit_transform(X)
        assert_array_almost_equal(Xt_chain, Xt_compact)


def test_lof_novelty_false():
    # Test chaining KNeighborsTransformer and LocalOutlierFactor
    n_neighbors = 4

    rng = np.random.RandomState(0)
    X = rng.randn(40, 2)

    # compare the chained version and the compact version
    est_chain = make_pipeline(
        KNeighborsTransformer(n_neighbors=n_neighbors, mode="distance"),
        LocalOutlierFactor(
            metric="precomputed",
            n_neighbors=n_neighbors,
            novelty=False,
            contamination="auto",
        ),
    )
    est_compact = LocalOutlierFactor(
        n_neighbors=n_neighbors, novelty=False, contamination="auto"
    )

    pred_chain = est_chain.fit_predict(X)
    pred_compact = est_compact.fit_predict(X)
    assert_array_almost_equal(pred_chain, pred_compact)


def test_lof_novelty_true():
    # Test chaining KNeighborsTransformer and LocalOutlierFactor
    n_neighbors = 4

    rng = np.random.RandomState(0)
    X1 = rng.randn(40, 2)
    X2 = rng.randn(40, 2)

    # compare the chained version and the compact version
    est_chain = make_pipeline(
        KNeighborsTransformer(n_neighbors=n_neighbors, mode="distance"),
        LocalOutlierFactor(
            metric="precomputed",
            n_neighbors=n_neighbors,
            novelty=True,
            contamination="auto",
        ),
    )
    est_compact = LocalOutlierFactor(
        n_neighbors=n_neighbors, novelty=True, contamination="auto"
    )

    pred_chain = est_chain.fit(X1).predict(X2)
    pred_compact = est_compact.fit(X1).predict(X2)
    assert_array_almost_equal(pred_chain, pred_compact)


def test_kneighbors_regressor():
    # Test chaining KNeighborsTransformer and classifiers/regressors
    rng = np.random.RandomState(0)
    X = 2 * rng.rand(40, 5) - 1
    X2 = 2 * rng.rand(40, 5) - 1
    y = rng.rand(40, 1)

    n_neighbors = 12
    radius = 1.5
    # We precompute more neighbors than necessary, to have equivalence between
    # k-neighbors estimator after radius-neighbors transformer, and vice-versa.
    factor = 2

    k_trans = KNeighborsTransformer(n_neighbors=n_neighbors, mode="distance")
    k_trans_factor = KNeighborsTransformer(
        n_neighbors=int(n_neighbors * factor), mode="distance"
    )

    r_trans = RadiusNeighborsTransformer(radius=radius, mode="distance")
    r_trans_factor = RadiusNeighborsTransformer(
        radius=int(radius * factor), mode="distance"
    )

    k_reg = KNeighborsRegressor(n_neighbors=n_neighbors)
    r_reg = RadiusNeighborsRegressor(radius=radius)

    test_list = [
        (k_trans, k_reg),
        (k_trans_factor, r_reg),
        (r_trans, r_reg),
        (r_trans_factor, k_reg),
    ]

    for trans, reg in test_list:
        # compare the chained version and the compact version
        reg_compact = clone(reg)
        reg_precomp = clone(reg)
        reg_precomp.set_params(metric="precomputed")

        reg_chain = make_pipeline(clone(trans), reg_precomp)

        y_pred_chain = reg_chain.fit(X, y).predict(X2)
        y_pred_compact = reg_compact.fit(X, y).predict(X2)
        assert_array_almost_equal(y_pred_chain, y_pred_compact)
