from __future__ import absolute_import, division, print_function

import os
import warnings
from random import random
from time import sleep
from uuid import uuid4

import pytest

from .. import Parallel, delayed, parallel_backend, parallel_config
from .._dask import DaskDistributedBackend
from ..parallel import AutoBatchingMixin, ThreadingBackend
from .common import np, with_numpy
from .test_parallel import (
    _recursive_backend_info,
    _test_deadlock_with_generator,
    _test_parallel_unordered_generator_returns_fastest_first,  # noqa: E501
)

distributed = pytest.importorskip("distributed")
dask = pytest.importorskip("dask")

# These imports need to be after the pytest.importorskip hence the noqa: E402
from distributed import Client, LocalCluster, get_client  # noqa: E402
from distributed.metrics import time  # noqa: E402

# Note: pytest requires manually importing every fixture used in the tests,
# along with the fixtures they depend on.
from distributed.utils_test import cleanup, cluster, inc  # noqa: E402, F401


@pytest.fixture(scope="function", autouse=True)
def avoid_dask_env_leaks(tmp_path):
    """Restore selected environment variables once each test has finished.

    Starting a dask nanny might change the environment, so the variables
    listed in ``ParallelBackendBase.MAX_NUM_THREADS_VARS`` are snapshotted
    before the test body runs and put back afterwards.
    """
    from joblib._parallel_backends import ParallelBackendBase

    # ``None`` in the snapshot means "the variable was not set".
    saved_env = {}
    for var_name in ParallelBackendBase.MAX_NUM_THREADS_VARS:
        saved_env[var_name] = os.environ.get(var_name)

    yield

    # Put every variable back, dropping the ones that did not exist before.
    for var_name, previous in saved_env.items():
        if previous is not None:
            os.environ[var_name] = previous
        else:
            os.environ.pop(var_name, None)


def noop(*args, **kwargs):
    """Accept any positional/keyword arguments, do nothing, return None."""
    return None


def slow_raise_value_error(condition, duration=0.05):
    """Sleep for ``duration`` seconds, then raise if ``condition`` is truthy.

    Raises
    ------
    ValueError
        When ``condition`` evaluates to True (always after the sleep).
    """
    sleep(duration)
    if not condition:
        return
    raise ValueError("condition evaluated to True")


def count_events(event_name, client):
    """Count, for each worker, the log entries named ``event_name``.

    The ``log`` attribute of every worker is fetched through ``client.run``;
    an entry is counted when its second field equals ``event_name``.

    Returns a dict mapping each key returned by ``client.run`` to its count.
    """
    logs_by_worker = client.run(lambda dask_worker: dask_worker.log)
    return {
        worker: sum(1 for entry in list(log) if entry[1] == event_name)
        for worker, log in logs_by_worker.items()
    }


def test_simple(loop):
    # Basic round trip through the dask backend: results come back in order,
    # a ValueError raised by a task propagates to the caller, and the backend
    # can still be used after that failure.
    expected = [inc(i) for i in range(10)]

    with cluster() as (s, [a, b]), Client(s["address"], loop=loop):
        with parallel_config(backend="dask"):
            first_run = Parallel()(delayed(inc)(i) for i in range(10))
            assert first_run == expected

            with pytest.raises(ValueError):
                Parallel()(
                    delayed(slow_raise_value_error)(i == 3) for i in range(10)
                )

            second_run = Parallel()(delayed(inc)(i) for i in range(10))
            assert second_run == expected


def test_dask_backend_uses_autobatching(loop):
    # The dask backend must reuse the batch size logic of AutoBatchingMixin,
    # and the effective batch size must grow when many short tasks are run.
    batch_size_impl = DaskDistributedBackend.compute_batch_size
    assert batch_size_impl is AutoBatchingMixin.compute_batch_size

    with cluster() as (s, [a, b]), Client(s["address"], loop=loop):
        with parallel_config(backend="dask"), Parallel() as parallel:
            # The backend should be initialized with a default
            # batch size of 1:
            dask_backend = parallel._backend
            assert isinstance(dask_backend, DaskDistributedBackend)
            assert dask_backend.parallel is parallel
            assert dask_backend._effective_batch_size == 1

            # Launch many short tasks that should trigger
            # auto-batching:
            n_tasks = int(1e4)
            parallel(delayed(lambda: None)() for _ in range(n_tasks))
            assert dask_backend._effective_batch_size > 10


@pytest.mark.parametrize("n_jobs", [2, -1])
@pytest.mark.parametrize("context", [parallel_config, parallel_backend])
def test_parallel_unordered_generator_returns_fastest_first_with_dask(n_jobs, context):
    # Run the shared "fastest first" check with the dask backend selected
    # through either context manager; None is passed as its first argument.
    with distributed.Client(n_workers=2, threads_per_worker=2):
        with context("dask"):
            _test_parallel_unordered_generator_returns_fastest_first(None, n_jobs)


@with_numpy
@pytest.mark.parametrize("n_jobs", [2, -1])
@pytest.mark.parametrize("return_as", ["generator", "generator_unordered"])
@pytest.mark.parametrize("context", [parallel_config, parallel_backend])
def test_deadlock_with_generator_and_dask(context, return_as, n_jobs):
    # Run the shared generator deadlock check with the dask backend selected
    # through either context manager; None is passed as its first argument.
    with distributed.Client(n_workers=2, threads_per_worker=2):
        with context("dask"):
            _test_deadlock_with_generator(None, return_as, n_jobs)


@with_numpy
@pytest.mark.parametrize("context", [parallel_config, parallel_backend])
def test_nested_parallelism_with_dask(context):
    # Every nesting level reported by _recursive_backend_info must use the
    # dask backend, both with and without a large array argument.

    def check_all_levels_use_dask(levels):
        assert len(levels) == 4
        assert all(name == "DaskDistributedBackend" for name, _ in levels)

    with distributed.Client(n_workers=2, threads_per_worker=2):
        # 10 MB of data as argument to trigger implicit scattering
        payload = np.ones(int(1e7), dtype=np.uint8)
        for _ in range(2):
            with context("dask"):
                levels = _recursive_backend_info(data=payload)
            check_all_levels_use_dask(levels)

        # No argument
        with context("dask"):
            levels = _recursive_backend_info()
        check_all_levels_use_dask(levels)


def random2():
    """Return a new pseudo-random float obtained from ``random()``."""
    value = random()
    return value


def test_dont_assume_function_purity(loop):
    # Two calls to the same argument-less function must yield two distinct
    # results when run through the dask backend.
    with cluster() as (s, [a, b]), Client(s["address"], loop=loop):
        with parallel_config(backend="dask"):
            first, second = Parallel()(delayed(random2)() for _ in range(2))
            assert first != second


@pytest.mark.parametrize("mixed", [True, False])
def test_dask_funcname(loop, mixed):
    # Check the repr of joblib._dask.Batch for homogeneous and mixed task
    # lists, then check that the first field of every scheduler transition
    # log entry contains the "batch_of_inc" name.
    from joblib._dask import Batch

    if not mixed:
        tasks = [delayed(inc)(i) for i in range(4)]
        batch_repr = "batch_of_inc_4_calls"
    else:
        # Alternate between inc (even i) and abs (odd i).
        tasks = [delayed(abs)(i) if i % 2 else delayed(inc)(i) for i in range(4)]
        batch_repr = "mixed_batch_of_inc_4_calls"

    assert repr(Batch(tasks)) == batch_repr

    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop) as client:
            with parallel_config(backend="dask"):
                _ = Parallel(batch_size=2, pre_dispatch="all")(tasks)

            def f(dask_scheduler):
                # Executed on the scheduler: return a copy of its transition log.
                return list(dask_scheduler.transition_log)

            # NOTE(review): the updated ``batch_repr`` (4 -> 2 calls per batch)
            # is not used by the assertion below, which only looks for the
            # generic "batch_of_inc" substring -- confirm whether a stricter
            # check against ``batch_repr`` was intended.
            batch_repr = batch_repr.replace("4", "2")
            log = client.run_on_scheduler(f)
            assert all("batch_of_inc" in tup[0] for tup in log)


def test_no_undesired_distributed_cache_hit():
    # Dask has a pickle cache for callables that are called many times. Because
    # the dask backends used to wrap both the functions and the arguments
    # under instances of the Batch callable class this caching mechanism could
    # lead to bugs as described in: https://github.com/joblib/joblib/pull/1055
    # The joblib-dask backend has been refactored to avoid bundling the
    # arguments as an attribute of the Batch instance to avoid this problem.
    # This test serves as a non-regression test.

    # Use a large number of input arguments to give the AutoBatchingMixin
    # enough tasks to kick in.
    lists = [[] for _ in range(100)]
    np = pytest.importorskip("numpy")
    X = np.arange(int(1e6))

    def isolated_operation(list_, data=None):
        # Append a fresh unique token to ``list_`` and return it; when ``data``
        # is given, first check that it is equal to ``X``.
        if data is not None:
            np.testing.assert_array_equal(data, X)
        list_.append(uuid4().hex)
        return list_

    cluster = LocalCluster(n_workers=1, threads_per_worker=2)
    client = Client(cluster)
    try:
        with parallel_config(backend="dask"):
            # dispatches joblib.parallel.BatchedCalls
            res = Parallel()(delayed(isolated_operation)(list_) for list_ in lists)

        # The original arguments should not have been mutated as the mutation
        # happens in the dask worker process.
        assert lists == [[] for _ in range(100)]

        # Here we did not pass any large numpy array as argument to
        # isolated_operation so no scattering event should happen under the
        # hood.
        counts = count_events("receive-from-scatter", client)
        assert sum(counts.values()) == 0
        # Each returned list holds exactly one token: no state was shared
        # between calls.
        assert all([len(r) == 1 for r in res])

        with parallel_config(backend="dask"):
            # Append a large array which will be scattered by dask, and
            # dispatch joblib._dask.Batch
            res = Parallel()(
                delayed(isolated_operation)(list_, data=X) for list_ in lists
            )

        # This time, auto-scattering should have kicked in.
        counts = count_events("receive-from-scatter", client)
        assert sum(counts.values()) > 0
        assert all([len(r) == 1 for r in res])
    finally:
        # Always release the client and the cluster, even on failure.
        client.close(timeout=30)
        cluster.close(timeout=30)


class CountSerialized(object):
    """Wrapper around a value ``x`` that counts its ``__reduce__`` calls.

    ``count`` starts at 0 and is incremented each time ``__reduce__`` is
    invoked on this instance. The reduction only carries ``x``, so an object
    rebuilt from it starts again with ``count == 0``.
    """

    def __init__(self, x):
        self.count = 0
        self.x = x

    def __reduce__(self):
        # Record the call on the instance being reduced.
        self.count = self.count + 1
        return CountSerialized, (self.x,)

    def __add__(self, other):
        # ``other`` may be another wrapper (use its ``x``) or a plain value.
        return self.x + getattr(other, "x", other)

    # Addition is handled the same way whichever side the wrapper is on.
    __radd__ = __add__


def add5(a, b, c, d=0, e=0):
    """Return ``a + b + c + d + e``, added left to right."""
    total = a + b
    total = total + c
    total = total + d
    total = total + e
    return total


def test_manual_scatter(loop):
    # Let's check that the number of times scattered and non-scattered
    # variables are serialized is consistent between `joblib.Parallel` calls
    # and equivalent native `client.submit` call.

    # Number of serializations can vary from one dask version to another, so
    # this test only checks that `joblib.Parallel` does not add more
    # serialization steps than a native `client.submit` call, but does not
    # check for an exact number of serialization steps.

    # w, x and y are scattered; z is always passed as a regular argument.
    w, x, y, z = (CountSerialized(i) for i in range(4))

    f = delayed(add5)
    tasks = [f(x, y, z, d=4, e=5) for _ in range(10)]
    tasks += [
        f(x, z, y, d=5, e=4),
        f(y, x, z, d=x, e=5),
        f(z, z, x, d=z, e=y),
    ]
    # Reference results computed locally, without any serialization.
    expected = [func(*args, **kwargs) for func, args, kwargs in tasks]

    with cluster() as (s, _):
        with Client(s["address"], loop=loop) as client:  # noqa: F841
            with parallel_config(backend="dask", scatter=[w, x, y]):
                results_parallel = Parallel(batch_size=1)(tasks)
                assert results_parallel == expected

            # Check that an error is raised for bad arguments, as scatter must
            # take a list/tuple
            with pytest.raises(TypeError):
                with parallel_config(backend="dask", loop=loop, scatter=1):
                    pass

    # Scattered variables only serialized during scatter. Checking with an
    # extra variable as this count can vary from one dask version
    # to another.
    n_serialization_scatter_with_parallel = w.count
    assert x.count == n_serialization_scatter_with_parallel
    assert y.count == n_serialization_scatter_with_parallel
    n_serialization_with_parallel = z.count

    # Reset the cluster and the serialization count
    for var in (w, x, y, z):
        var.count = 0

    with cluster() as (s, _):
        with Client(s["address"], loop=loop) as client:  # noqa: F841
            # Scatter the same objects by hand and substitute the resulting
            # futures for them (looked up by identity) in each native submit.
            scattered = dict()
            for obj in w, x, y:
                scattered[id(obj)] = client.scatter(obj, broadcast=True)
            results_native = [
                client.submit(
                    func,
                    *(scattered.get(id(arg), arg) for arg in args),
                    **dict(
                        (key, scattered.get(id(value), value))
                        for (key, value) in kwargs.items()
                    ),
                    key=str(uuid4()),
                ).result()
                for (func, args, kwargs) in tasks
            ]
            assert results_native == expected

    # Now check that the number of serialization steps is the same for joblib
    # and native dask calls.
    n_serialization_scatter_native = w.count
    assert x.count == n_serialization_scatter_native
    assert y.count == n_serialization_scatter_native

    assert n_serialization_scatter_with_parallel == n_serialization_scatter_native

    # Only the first two components are needed for the comparison below.
    # Ignoring the rest avoids a ValueError on version strings whose later
    # components are not plain integers (e.g. "2024.8.0+3.gabcdef").
    distributed_version = tuple(
        int(v) for v in distributed.__version__.split(".")[:2]
    )
    if distributed_version < (2023, 4):
        # Previous to 2023.4, the serialization was adding an extra call to
        # __reduce__ for the last job `f(z, z, x, d=z, e=y)`, because `z`
        # appears both in the args and kwargs, which is not the case when
        # running with joblib. Cope with this discrepancy.
        assert z.count == n_serialization_with_parallel + 1
    else:
        assert z.count == n_serialization_with_parallel


# When the same IOLoop is used for multiple clients in a row, use
# loop_in_thread instead of loop to prevent the Client from closing it.  See
# dask/distributed #4112
def test_auto_scatter(loop_in_thread):
    # Check, through the worker event logs, when array arguments go through
    # a scatter operation and when they do not.
    np = pytest.importorskip("numpy")
    data1 = np.ones(int(1e4), dtype=np.uint8)
    data2 = np.ones(int(1e4), dtype=np.uint8)
    # Two distinct array objects, each one referenced three times.
    data_to_process = ([data1] * 3) + ([data2] * 3)

    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop_in_thread) as client:
            with parallel_config(backend="dask"):
                # Passing the same data as arg and kwarg triggers a single
                # scatter operation whose result is reused.
                Parallel()(
                    delayed(noop)(data, data, i, opt=data)
                    for i, data in enumerate(data_to_process)
                )
            # By default large arrays are automatically scattered with
            # broadcast=1 which means that one worker must directly receive
            # the data from the scatter operation once.
            # Two distinct arrays were passed, hence two events in total
            # across both workers.
            counts = count_events("receive-from-scatter", client)
            assert counts[a["address"]] + counts[b["address"]] == 2

    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop_in_thread) as client:
            with parallel_config(backend="dask"):
                Parallel()(delayed(noop)(data1[:3], i) for i in range(5))
            # Small arrays are passed within the task definition without going
            # through a scatter operation.
            counts = count_events("receive-from-scatter", client)
            assert counts[a["address"]] == 0
            assert counts[b["address"]] == 0


@pytest.mark.parametrize("retry_no", list(range(2)))
def test_nested_scatter(loop, retry_no):
    # Smoke test: an outer dask-backed Parallel call runs tasks that each
    # start an inner dask-backed Parallel call on slices of the array they
    # received. No result is asserted; the calls only have to complete.
    # ``retry_no`` is not used in the body: the parametrization simply runs
    # the test twice.
    np = pytest.importorskip("numpy")

    NUM_INNER_TASKS = 10
    NUM_OUTER_TASKS = 10

    def my_sum(x, i, j):
        # ``i`` and ``j`` are accepted but not used in the computation.
        return np.sum(x)

    def outer_function_joblib(array, i):
        # Runs as a task of the outer call and starts the nested one.
        client = get_client()  # noqa
        with parallel_config(backend="dask"):
            results = Parallel()(
                delayed(my_sum)(array[j:], i, j) for j in range(NUM_INNER_TASKS)
            )
        return sum(results)

    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop) as _:
            with parallel_config(backend="dask"):
                my_array = np.ones(10000)
                _ = Parallel()(
                    delayed(outer_function_joblib)(my_array[i:], i)
                    for i in range(NUM_OUTER_TASKS)
                )


def test_nested_backend_context_manager(loop_in_thread):
    # Nested Parallel calls started from tasks of a dask-backed Parallel call
    # must report at most two distinct process ids per outer task. The check
    # is run twice, with two successive clients on the same scheduler.
    def get_nested_pids():
        # Collect the pids seen by two consecutive nested Parallel calls.
        pids = set(Parallel(n_jobs=2)(delayed(os.getpid)() for _ in range(2)))
        pids |= set(Parallel(n_jobs=2)(delayed(os.getpid)() for _ in range(2)))
        return pids

    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop_in_thread) as client:
            with parallel_config(backend="dask"):
                pid_groups = Parallel(n_jobs=2)(
                    delayed(get_nested_pids)() for _ in range(10)
                )
                for pid_group in pid_groups:
                    assert len(set(pid_group)) <= 2

        # No deadlocks
        with Client(s["address"], loop=loop_in_thread) as client:  # noqa: F841
            with parallel_config(backend="dask"):
                pid_groups = Parallel(n_jobs=2)(
                    delayed(get_nested_pids)() for _ in range(10)
                )
                for pid_group in pid_groups:
                    assert len(set(pid_group)) <= 2


def test_nested_backend_context_manager_implicit_n_jobs(loop):
    # Check that Parallel with no explicit n_jobs value automatically selects
    # all the dask workers, including in nested calls.

    def _backend_type(p):
        # Class name of the backend attached to the Parallel instance ``p``.
        return p._backend.__class__.__name__

    def get_nested_implicit_n_jobs():
        # Runs inside a task: report the backend name and n_jobs of a
        # Parallel instance created with default arguments.
        with Parallel() as p:
            return _backend_type(p), p.n_jobs

    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop) as client:  # noqa: F841
            with parallel_config(backend="dask"):
                with Parallel() as p:
                    assert _backend_type(p) == "DaskDistributedBackend"
                    assert p.n_jobs == -1
                    all_nested_n_jobs = p(
                        delayed(get_nested_implicit_n_jobs)() for _ in range(2)
                    )
                # The nested instances must report the same backend and n_jobs
                # as the outer one.
                for backend_type, nested_n_jobs in all_nested_n_jobs:
                    assert backend_type == "DaskDistributedBackend"
                    assert nested_n_jobs == -1


def test_errors(loop):
    # Selecting the dask backend without having created a client first must
    # raise a ValueError whose message mentions creating a dask client.
    with pytest.raises(ValueError) as excinfo:
        with parallel_config(backend="dask"):
            pass

    message = str(excinfo.value).lower()
    assert "create a dask client" in message


def test_correct_nested_backend(loop):
    # The backend seen at the innermost nesting level depends on the
    # ``require`` constraint forwarded by ``outer``.
    scenarios = (
        # No requirement, should be us
        (None, DaskDistributedBackend),
        # Require threads, should be threading
        ("sharedmem", ThreadingBackend),
    )

    with cluster() as (s, [a, b]), Client(s["address"], loop=loop):
        for nested_require, expected_backend in scenarios:
            with parallel_config(backend="dask"):
                result = Parallel(n_jobs=2)(
                    delayed(outer)(nested_require=nested_require) for _ in range(1)
                )
                assert isinstance(result[0][0][0], expected_backend)


def outer(nested_require):
    """Run ``middle(nested_require)`` once via a thread-preferring Parallel.

    Returns the one-element list of results produced by that call.
    """
    runner = Parallel(n_jobs=2, prefer="threads")
    return runner(delayed(middle)(nested_require) for _ in range(1))


def middle(require):
    """Run ``inner()`` once via a Parallel created with ``require=require``.

    Returns the one-element list of results produced by that call.
    """
    runner = Parallel(n_jobs=2, require=require)
    return runner(delayed(inner)() for _ in range(1))


def inner():
    """Return the backend object of a Parallel created with defaults."""
    default_parallel = Parallel()
    return default_parallel._backend


def test_secede_with_no_processes(loop):
    # https://github.com/dask/distributed/issues/1775
    # Smoke test with a client created with processes=False: the Parallel
    # call only has to complete.
    in_process_client = Client(loop=loop, processes=False, set_as_default=True)
    with in_process_client, parallel_config(backend="dask"):
        Parallel(n_jobs=4)(delayed(id)(i) for i in range(2))


def _worker_address(_):
    """Return the address of the dask worker running this call.

    The single positional argument is ignored.
    """
    from distributed import get_worker

    current_worker = get_worker()
    return current_worker.address


def test_dask_backend_keywords(loop):
    # With ``workers=<address>`` passed to parallel_config, every task must
    # report that it ran on the worker with that address.
    with cluster() as (s, [a, b]), Client(s["address"], loop=loop):
        for worker in (a, b):
            target = worker["address"]
            with parallel_config(backend="dask", workers=target):
                seq = Parallel()(delayed(_worker_address)(i) for i in range(10))
                assert seq == [target] * 10


def test_scheduler_tasks_cleanup(loop):
    # After a Parallel call, the scheduler must forget all tasks within 5
    # seconds and the client must hold no futures.
    with Client(processes=False, loop=loop) as client:
        with parallel_config(backend="dask"):
            Parallel()(delayed(inc)(i) for i in range(10))

        scheduler = client.cluster.scheduler
        deadline = time() + 5
        # Poll until the scheduler task table is empty, failing on timeout.
        while scheduler.tasks:
            sleep(0.01)
            assert time() < deadline

        assert not client.futures


def _numeric_version(version):
    """Return the leading all-digit components of ``version`` as a tuple.

    Parsing stops at the first dot-separated piece that is not purely
    numeric, e.g. "2.1.1" -> (2, 1, 1) and "2.2.0rc1" -> (2, 2).
    """
    numbers = []
    for piece in version.split("."):
        if not piece.isdecimal():
            break
        numbers.append(int(piece))
    return tuple(numbers)


@pytest.mark.parametrize("cluster_strategy", ["adaptive", "late_scaling"])
@pytest.mark.skipif(
    # Compare versions numerically: comparing the raw strings is
    # lexicographic and misorders them (e.g. "1.3.0" >= "1.28.0" is True).
    (1, 28, 0) <= _numeric_version(distributed.__version__) <= (2, 1, 1),
    reason="distributed bug - https://github.com/dask/distributed/pull/2841",
)
def test_wait_for_workers(cluster_strategy):
    # Start a cluster without any worker and only then ask for workers: the
    # Parallel call must wait for one to show up instead of failing.
    cluster = LocalCluster(n_workers=0, processes=False, threads_per_worker=2)
    client = Client(cluster)
    try:
        if cluster_strategy == "adaptive":
            cluster.adapt(minimum=0, maximum=2)
        elif cluster_strategy == "late_scaling":
            # Tell the cluster to start workers but this is a non-blocking call
            # and new workers might take time to connect. In this case the
            # Parallel call should wait for at least one worker to come up
            # before starting to schedule work.
            cluster.scale(2)

        with parallel_config(backend="dask"):
            # The following should wait a bit for at least one worker to
            # become available.
            Parallel()(delayed(inc)(i) for i in range(10))
    finally:
        # Runs even if adapt()/scale() raised, so nothing is left running.
        client.close()
        cluster.close()


def test_wait_for_workers_timeout():
    # Check the two failure modes of a Parallel call on a cluster that never
    # gets any worker, depending on ``wait_for_workers_timeout``.
    # Start a cluster with 0 worker:
    cluster = LocalCluster(n_workers=0, processes=False, threads_per_worker=2)
    client = Client(cluster)
    try:
        with parallel_config(backend="dask", wait_for_workers_timeout=0.1):
            # Short timeout: a TimeoutError reporting the elapsed timeout is
            # expected. ``msg`` is used as a regular expression by ``match``.
            msg = "DaskDistributedBackend has no worker after 0.1 seconds."
            with pytest.raises(TimeoutError, match=msg):
                Parallel()(delayed(inc)(i) for i in range(10))

        with parallel_config(backend="dask", wait_for_workers_timeout=0):
            # No timeout: fallback to generic joblib failure:
            msg = "DaskDistributedBackend has no active worker"
            with pytest.raises(RuntimeError, match=msg):
                Parallel()(delayed(inc)(i) for i in range(10))
    finally:
        # Always release the client and the cluster, even on failure.
        client.close()
        cluster.close()


@pytest.mark.parametrize("backend", ["loky", "multiprocessing"])
def test_joblib_warning_inside_dask_daemonic_worker(backend):
    # Run a joblib Parallel call with the given backend from inside a dask
    # worker and check that exactly one UserWarning mentioning
    # "distributed.worker.daemon" is emitted.
    cluster = LocalCluster(n_workers=2)
    client = Client(cluster)
    try:

        def func_using_joblib_parallel():
            # Somehow trying to check the warning type here (e.g. with
            # pytest.warns(UserWarning)) makes the test hang. Work-around:
            # return the warning record to the client and the warning check is
            # done client-side.
            with warnings.catch_warnings(record=True) as record:
                Parallel(n_jobs=2, backend=backend)(delayed(inc)(i) for i in range(10))

            return record

        # Execute the function on a worker and fetch the recorded warnings.
        fut = client.submit(func_using_joblib_parallel)
        record = fut.result()

        assert len(record) == 1
        warning = record[0].message
        assert isinstance(warning, UserWarning)
        assert "distributed.worker.daemon" in str(warning)
    finally:
        # Always release the client and the cluster, even on failure.
        client.close(timeout=30)
        cluster.close(timeout=30)
