"""
Test the memory module.
"""

# Author: Gael Varoquaux <gael dot varoquaux at normalesup dot org>
# Copyright (c) 2009 Gael Varoquaux
# License: BSD Style, 3 clauses.

import datetime
import functools
import gc
import logging
import os
import os.path
import pathlib
import pickle
import shutil
import sys
import textwrap
import time

import pytest

from joblib._store_backends import FileSystemStoreBackend, StoreBackendBase
from joblib.hashing import hash
from joblib.memory import (
    _FUNCTION_HASHES,
    _STORE_BACKENDS,
    JobLibCollisionWarning,
    MemorizedFunc,
    MemorizedResult,
    Memory,
    NotMemorizedFunc,
    NotMemorizedResult,
    _build_func_identifier,
    _store_backend_factory,
    expires_after,
    register_store_backend,
)
from joblib.parallel import Parallel, delayed
from joblib.test.common import np, with_multiprocessing, with_numpy
from joblib.testing import parametrize, raises, warns


###############################################################################
# Module-level variables for the tests
def f(x, y=1):
    """A module-level function for testing purposes."""
    return x**2 + y


###############################################################################
# Helper function for the tests
def check_identity_lazy(func, accumulator, location):
    """Given a function and an accumulator (a list that grows every
    time the function is called), check that the function can be
    decorated by memory to be a lazy identity.
    """
    # Call each function with several arguments, and check that it is
    # evaluated only once per argument.
    memory = Memory(location=location, verbose=0)
    func = memory.cache(func)
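    # Each distinct argument must be computed exactly once: the second call
    # with the same argument has to hit the cache, so the accumulator only
    # grows on the first call per argument.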
    for i in range(3):
        for _ in range(2):
            assert func(i) == i
            assert len(accumulator) == i + 1


def corrupt_single_cache_item(memory):
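    # Overwrite the single cached output.pkl with plain text so that any
    # later attempt to unpickle the cached result fails.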
    (single_cache_item,) = memory.store_backend.get_items()
    output_filename = os.path.join(single_cache_item.path, "output.pkl")
    with open(output_filename, "w") as f:
        f.write("garbage")


def monkeypatch_cached_func_warn(func, monkeypatch_fixture):
    # Need monkeypatch because pytest does not
    # capture stdlib logging output (see
    # https://github.com/pytest-dev/pytest/issues/2079)

    recorded = []

    def append_to_record(item):
        recorded.append(item)

    monkeypatch_fixture.setattr(func, "warn", append_to_record)
    return recorded


###############################################################################
# Tests
def test_memory_integration(tmpdir):
    """Simple test of memory lazy evaluation."""
    accumulator = list()

    # Rmk: this function has the same name as a module-level function,
    # thus it serves as a test that the two are identified as different
    # functions.
    def f(arg):
        accumulator.append(1)
        return arg

    check_identity_lazy(f, accumulator, tmpdir.strpath)

    # Now test clearing
    for compress in (False, True):
        for mmap_mode in ("r", None):
            memory = Memory(
                location=tmpdir.strpath,
                verbose=10,
                mmap_mode=mmap_mode,
                compress=compress,
            )
            # First clear the cache directory, to check that our code can
            # handle that
            # NOTE: this line would raise an exception, as the database file is
            # still open; we ignore the error since we want to test what
            # happens if the directory disappears
            shutil.rmtree(tmpdir.strpath, ignore_errors=True)
            g = memory.cache(f)
            g(1)
            g.clear(warn=False)
            current_accumulator = len(accumulator)
            out = g(1)

        assert len(accumulator) == current_accumulator + 1
        # Also, check that Memory.eval works similarly
        assert memory.eval(f, 1) == out
        assert len(accumulator) == current_accumulator + 1

    # Now do a smoke test with a function defined in __main__, as the name
    # mangling rules are more complex
    f.__module__ = "__main__"
    memory = Memory(location=tmpdir.strpath, verbose=0)
    memory.cache(f)(1)


@parametrize("call_before_reducing", [True, False])
def test_parallel_call_cached_function_defined_in_jupyter(tmpdir, call_before_reducing):
    # Calling an interactively defined memory.cache()'d function inside a
    # Parallel call used to clear the existing cache related to the said
    # function (https://github.com/joblib/joblib/issues/1035)

    # This test checks that this is no longer the case.

    # TODO: test that the cache related to the function persists across
    # ipython sessions (provided that no code changes were made to the
    # function's source)?

    # The first part of the test makes the necessary low-level calls to emulate
    # the definition of a function in a jupyter notebook cell. Joblib has
    # some custom code to treat functions defined specifically in jupyter
    # notebooks/ipython sessions -- we want to test this code, which requires
    # the emulation to be rigorous.
    for session_no in [0, 1]:
        ipython_cell_source = """
        def f(x):
            return x
        """

        ipython_cell_id = "<ipython-input-{}-000000000000>".format(session_no)

        my_locals = {}
        exec(
            compile(
                textwrap.dedent(ipython_cell_source),
                filename=ipython_cell_id,
                mode="exec",
            ),
            # TODO when Python 3.11 is the minimum supported version, use
            # locals=my_locals instead of passing globals and locals in the
            # next two lines as positional arguments
            None,
            my_locals,
        )
        f = my_locals["f"]
        f.__module__ = "__main__"

        # Preliminary sanity checks, and tests checking that joblib properly
        # identified f as an interactive function defined in a jupyter notebook
        assert f(1) == 1
        assert f.__code__.co_filename == ipython_cell_id

        memory = Memory(location=tmpdir.strpath, verbose=0)
        cached_f = memory.cache(f)

        assert len(os.listdir(tmpdir / "joblib")) == 1
        f_cache_relative_directory = os.listdir(tmpdir / "joblib")[0]
        assert "ipython-input" in f_cache_relative_directory

        f_cache_directory = tmpdir / "joblib" / f_cache_relative_directory

        if session_no == 0:
            # The cache should be empty as cached_f has not been called yet.
            assert os.listdir(f_cache_directory) == ["f"]
            assert os.listdir(f_cache_directory / "f") == []

            if call_before_reducing:
                cached_f(3)
                # Two entries were just created: func_code.py, and a folder
                # containing the information (input hashes/output) of
                # cached_f(3)
                assert len(os.listdir(f_cache_directory / "f")) == 2

                # Now, testing #1035: when calling a cached function, joblib
                # used to dynamically inspect the underlying function to
                # extract its source code (to verify it matches the source code
                # of the function as last inspected by joblib) -- however,
                # source code introspection fails for dynamic functions sent to
                # child processes, which would eventually make joblib clear
                # the cache associated with f
                Parallel(n_jobs=2)(delayed(cached_f)(i) for i in [1, 2])
            else:
                # Submit the function to the joblib child processes, although
                # the function has never been called in the parent yet. This
                # triggers a specific code branch inside
                # MemorizedFunc.__reduce__.
                Parallel(n_jobs=2)(delayed(cached_f)(i) for i in [1, 2])
                # Ensure the child process has time to close the file.
                # Wait up to 5 seconds for slow CI runs
                for _ in range(25):
                    if len(os.listdir(f_cache_directory / "f")) == 3:
                        break
                    time.sleep(0.2)  # pragma: no cover
                assert len(os.listdir(f_cache_directory / "f")) == 3

                cached_f(3)

            # Making sure f's cache does not get cleared after the parallel
            # calls, and contains ALL cached function calls (f(1), f(2), f(3))
            # and 'func_code.py'
            assert len(os.listdir(f_cache_directory / "f")) == 4
        else:
            # For the second session, there should be an already existing cache
            assert len(os.listdir(f_cache_directory / "f")) == 4

            cached_f(3)

            # The previous cache should not be invalidated after calling the
            # function in a new session
            assert len(os.listdir(f_cache_directory / "f")) == 4


def test_no_memory():
    """Test memory with location=None: no memoize"""
    accumulator = list()

    def ff(arg):
        accumulator.append(1)
        return arg

    memory = Memory(location=None, verbose=0)
    gg = memory.cache(ff)
    for _ in range(4):
        current_accumulator = len(accumulator)
        gg(1)
        assert len(accumulator) == current_accumulator + 1


def test_memory_kwarg(tmpdir):
    "Test memory with a function with keyword arguments."
    accumulator = list()

    def g(arg1=None, arg2=1):
        accumulator.append(1)
        return arg1

    check_identity_lazy(g, accumulator, tmpdir.strpath)

    memory = Memory(location=tmpdir.strpath, verbose=0)
    g = memory.cache(g)
    # Smoke test with an explicit keyword argument:
    assert g(arg1=30, arg2=2) == 30


def test_memory_lambda(tmpdir):
    "Test memory with a function with a lambda."
    accumulator = list()

    def helper(x):
        """A helper function to define l as a lambda."""
        accumulator.append(1)
        return x

    check_identity_lazy(lambda x: helper(x), accumulator, tmpdir.strpath)


def test_memory_name_collision(tmpdir):
    "Check that name collisions with functions will raise warnings"
    memory = Memory(location=tmpdir.strpath, verbose=0)

    @memory.cache
    def name_collision(x):
        """A first function called name_collision"""
        return x

    a = name_collision

    @memory.cache
    def name_collision(x):
        """A second function called name_collision"""
        return x

    b = name_collision

    with warns(JobLibCollisionWarning) as warninfo:
        a(1)
        b(1)

    assert len(warninfo) == 1
    assert "collision" in str(warninfo[0].message)


def test_memory_warning_lambda_collisions(tmpdir):
    # Check that multiple use of lambda will raise collisions
    memory = Memory(location=tmpdir.strpath, verbose=0)
    a = memory.cache(lambda x: x)
    b = memory.cache(lambda x: x + 1)

    with warns(JobLibCollisionWarning) as warninfo:
        assert a(0) == 0
        assert b(1) == 2
        assert a(1) == 1

    # The two lambdas are distinct functions that share the name '<lambda>',
    # so the repeated code-change detections raise collision warnings
    assert len(warninfo) == 4


def test_memory_warning_collision_detection(tmpdir):
    # Check that collisions impossible to detect will raise appropriate
    # warnings.
    memory = Memory(location=tmpdir.strpath, verbose=0)
    a1 = eval("lambda x: x")
    a1 = memory.cache(a1)
    b1 = eval("lambda x: x+1")
    b1 = memory.cache(b1)

    with warns(JobLibCollisionWarning) as warninfo:
        a1(1)
        b1(1)
        a1(0)

    assert len(warninfo) == 2
    assert "cannot detect" in str(warninfo[0].message).lower()


def test_memory_partial(tmpdir):
    "Test memory with functools.partial."
    accumulator = list()

    def func(x, y):
        """A helper function to define l as a lambda."""
        accumulator.append(1)
        return y

    function = functools.partial(func, 1)

    check_identity_lazy(function, accumulator, tmpdir.strpath)


def test_memory_eval(tmpdir):
    "Smoke test memory with a function with a function defined in an eval."
    memory = Memory(location=tmpdir.strpath, verbose=0)

    m = eval("lambda x: x")
    mm = memory.cache(m)

    assert mm(1) == 1


def count_and_append(x=[]):
    """A function with a side effect in its arguments.

    Return the length of its argument and append one element.
    """
    len_x = len(x)
    x.append(None)
    return len_x


def test_argument_change(tmpdir):
    """Check that if a function has a side effect in its arguments, it
    should use the hash of changing arguments.
    """
    memory = Memory(location=tmpdir.strpath, verbose=0)
    func = memory.cache(count_and_append)
    # call the function for the first time; it should cache the result
    # with argument x=[]
    assert func() == 0
    # the second time the argument is x=[None], which is not cached
    # yet, so the function should be called a second time
    assert func() == 1


@with_numpy
@parametrize("mmap_mode", [None, "r"])
def test_memory_numpy(tmpdir, mmap_mode):
    "Test memory with a function with numpy arrays."
    accumulator = list()

    def n(arg=None):
        accumulator.append(1)
        return arg

    memory = Memory(location=tmpdir.strpath, mmap_mode=mmap_mode, verbose=0)
    cached_n = memory.cache(n)

    rnd = np.random.RandomState(0)
    for i in range(3):
        a = rnd.random_sample((10, 10))
        for _ in range(3):
            assert np.all(cached_n(a) == a)
            assert len(accumulator) == i + 1


@with_numpy
def test_memory_numpy_check_mmap_mode(tmpdir, monkeypatch):
    """Check that mmap_mode is respected even at the first call"""

    memory = Memory(location=tmpdir.strpath, mmap_mode="r", verbose=0)

    @memory.cache()
    def twice(a):
        return a * 2

    a = np.ones(3)

    b = twice(a)
    c = twice(a)

    assert isinstance(c, np.memmap)
    assert c.mode == "r"

    assert isinstance(b, np.memmap)
    assert b.mode == "r"

    # Corrupt the file. Deleting the b and c mmaps is necessary
    # to be able to edit the file.
    del b
    del c
    gc.collect()
    corrupt_single_cache_item(memory)

    # Make sure that corrupting the file causes recomputation and that
    # a warning is issued.
    recorded_warnings = monkeypatch_cached_func_warn(twice, monkeypatch)
    d = twice(a)
    assert len(recorded_warnings) == 1
    exception_msg = "Exception while loading results"
    assert exception_msg in recorded_warnings[0]
    # Asserts that the recomputation returns a mmap
    assert isinstance(d, np.memmap)
    assert d.mode == "r"


def test_memory_exception(tmpdir):
    """Smoketest the exception handling of Memory."""
    memory = Memory(location=tmpdir.strpath, verbose=0)

    class MyException(Exception):
        pass

    @memory.cache
    def h(exc=0):
        if exc:
            raise MyException

    # Call once, to initialise the cache
    h()

    for _ in range(3):
        # Call 3 times, to be sure that the Exception is always raised
        with raises(MyException):
            h(1)


def test_memory_ignore(tmpdir):
    "Test the ignore feature of memory"
    memory = Memory(location=tmpdir.strpath, verbose=0)
    accumulator = list()

    @memory.cache(ignore=["y"])
    def z(x, y=1):
        accumulator.append(1)

    assert z.ignore == ["y"]

    z(0, y=1)
    assert len(accumulator) == 1
    z(0, y=1)
    assert len(accumulator) == 1
    z(0, y=2)
    assert len(accumulator) == 1


def test_memory_ignore_decorated(tmpdir):
    "Test the ignore feature of memory on a decorated function"
    memory = Memory(location=tmpdir.strpath, verbose=0)
    accumulator = list()

    def decorate(f):
        @functools.wraps(f)
        def wrapped(*args, **kwargs):
            return f(*args, **kwargs)

        return wrapped

    @memory.cache(ignore=["y"])
    @decorate
    def z(x, y=1):
        accumulator.append(1)

    assert z.ignore == ["y"]

    z(0, y=1)
    assert len(accumulator) == 1
    z(0, y=1)
    assert len(accumulator) == 1
    z(0, y=2)
    assert len(accumulator) == 1


def test_memory_args_as_kwargs(tmpdir):
    """Non-regression test against 0.12.0 changes.

    https://github.com/joblib/joblib/pull/751
    """
    memory = Memory(location=tmpdir.strpath, verbose=0)

    @memory.cache
    def plus_one(a):
        return a + 1

    # It's possible to pass a positional arg as a kwarg.
    assert plus_one(1) == 2
    assert plus_one(a=1) == 2

    # However, a positional argument that joblib hadn't seen
    # before would cause a failure if it was passed as a kwarg.
    assert plus_one(a=2) == 3


@parametrize("ignore, verbose, mmap_mode", [(["x"], 100, "r"), ([], 10, None)])
def test_partial_decoration(tmpdir, ignore, verbose, mmap_mode):
    "Check cache may be called with kwargs before decorating"
    memory = Memory(location=tmpdir.strpath, verbose=0)

    @memory.cache(ignore=ignore, verbose=verbose, mmap_mode=mmap_mode)
    def z(x):
        pass

    assert z.ignore == ignore
    assert z._verbose == verbose
    assert z.mmap_mode == mmap_mode


def test_func_dir(tmpdir):
    # Test the creation of the memory cache directory for the function.
    memory = Memory(location=tmpdir.strpath, verbose=0)
    path = __name__.split(".")
    path.append("f")
    path = tmpdir.join("joblib", *path).strpath

    g = memory.cache(f)
    # Test that the function directory is created on demand
    func_id = _build_func_identifier(f)
    location = os.path.join(g.store_backend.location, func_id)
    assert location == path
    assert os.path.exists(path)
    assert memory.location == os.path.dirname(g.store_backend.location)

    # Test that the code is stored.
    # For the following test to be robust to previous execution, we clear
    # the in-memory store
    _FUNCTION_HASHES.clear()
    assert not g._check_previous_func_code()
    assert os.path.exists(os.path.join(path, "func_code.py"))
    assert g._check_previous_func_code()

    # Test the robustness to failure of loading previous results.
    args_id = g._get_args_id(1)
    output_dir = os.path.join(g.store_backend.location, g.func_id, args_id)
    a = g(1)
    assert os.path.exists(output_dir)
    os.remove(os.path.join(output_dir, "output.pkl"))
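    # With the persisted output gone, g(1) must silently recompute and
    # return the same value.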
    assert a == g(1)


def test_persistence(tmpdir):
    # Test that memorized functions can be pickled and restored.
    memory = Memory(location=tmpdir.strpath, verbose=0)
    g = memory.cache(f)
    output = g(1)

    h = pickle.loads(pickle.dumps(g))
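    # The unpickled clone points at the same store location, so it can read
    # the cache entries written by the original function.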

    args_id = h._get_args_id(1)
    output_dir = os.path.join(h.store_backend.location, h.func_id, args_id)
    assert os.path.exists(output_dir)
    assert output == h.store_backend.load_item([h.func_id, args_id])
    memory2 = pickle.loads(pickle.dumps(memory))
    assert memory.store_backend.location == memory2.store_backend.location

    # Smoke test that pickling a memory with location=None works
    memory = Memory(location=None, verbose=0)
    pickle.loads(pickle.dumps(memory))
    g = memory.cache(f)
    gp = pickle.loads(pickle.dumps(g))
    gp(1)


@pytest.mark.parametrize("consider_cache_valid", [True, False])
def test_check_call_in_cache(tmpdir, consider_cache_valid):
    for func in (
        MemorizedFunc(
            f, tmpdir.strpath, cache_validation_callback=lambda _: consider_cache_valid
        ),
        Memory(location=tmpdir.strpath, verbose=0).cache(
            f, cache_validation_callback=lambda _: consider_cache_valid
        ),
    ):
        result = func.check_call_in_cache(2)
        assert isinstance(result, bool)
        assert not result
        assert func(2) == 5
        result = func.check_call_in_cache(2)
        assert isinstance(result, bool)
        assert result == consider_cache_valid
        func.clear()

    func = NotMemorizedFunc(f)
    assert not func.check_call_in_cache(2)


def test_call_and_shelve(tmpdir):
    # Test MemorizedFunc outputting a reference to cache.

    for func, Result in zip(
        (
            MemorizedFunc(f, tmpdir.strpath),
            NotMemorizedFunc(f),
            Memory(location=tmpdir.strpath, verbose=0).cache(f),
            Memory(location=None).cache(f),
        ),
        (MemorizedResult, NotMemorizedResult, MemorizedResult, NotMemorizedResult),
    ):
        assert func(2) == 5
        result = func.call_and_shelve(2)
        assert isinstance(result, Result)
        assert result.get() == 5

        result.clear()
        with raises(KeyError):
            result.get()
        result.clear()  # Do nothing if there is no cache.


def test_call_and_shelve_lazily_load_stored_result(tmpdir):
    """Check call_and_shelve only load stored data if needed."""
    test_access_time_file = tmpdir.join("test_access")
    test_access_time_file.write("test_access")
    test_access_time = os.stat(test_access_time_file.strpath).st_atime
    # check that the file system access time resolution is finer than the
    # test wait timings.
    time.sleep(0.5)
    assert test_access_time_file.read() == "test_access"

    if test_access_time == os.stat(test_access_time_file.strpath).st_atime:
        # Skip this test when access time cannot be retrieved with enough
        # precision from the file system (e.g. NTFS on windows).
        pytest.skip("filesystem does not support fine-grained access time attribute")

    memory = Memory(location=tmpdir.strpath, verbose=0)
    func = memory.cache(f)
    args_id = func._get_args_id(2)
    result_path = os.path.join(
        memory.store_backend.location, func.func_id, args_id, "output.pkl"
    )
    assert func(2) == 5
    first_access_time = os.stat(result_path).st_atime
    time.sleep(1)

    # Should not access the stored data
    result = func.call_and_shelve(2)
    assert isinstance(result, MemorizedResult)
    assert os.stat(result_path).st_atime == first_access_time
    time.sleep(1)

    # Read the stored data => last access time is greater than first_access
    assert result.get() == 5
    assert os.stat(result_path).st_atime > first_access_time


def test_memorized_pickling(tmpdir):
    for func in (MemorizedFunc(f, tmpdir.strpath), NotMemorizedFunc(f)):
        filename = tmpdir.join("pickling_test.dat").strpath
        result = func.call_and_shelve(2)
        with open(filename, "wb") as fp:
            pickle.dump(result, fp)
        with open(filename, "rb") as fp:
            result2 = pickle.load(fp)
        assert result2.get() == result.get()
        os.remove(filename)


def test_memorized_repr(tmpdir):
    func = MemorizedFunc(f, tmpdir.strpath)
    result = func.call_and_shelve(2)

    func2 = MemorizedFunc(f, tmpdir.strpath)
    result2 = func2.call_and_shelve(2)
    assert result.get() == result2.get()
    assert repr(func) == repr(func2)

    # Smoke test with NotMemorizedFunc
    func = NotMemorizedFunc(f)
    repr(func)
    repr(func.call_and_shelve(2))

    # Smoke test for message output (increase code coverage)
    func = MemorizedFunc(f, tmpdir.strpath, verbose=11, timestamp=time.time())
    result = func.call_and_shelve(11)
    result.get()

    func = MemorizedFunc(f, tmpdir.strpath, verbose=11)
    result = func.call_and_shelve(11)
    result.get()

    func = MemorizedFunc(f, tmpdir.strpath, verbose=5, timestamp=time.time())
    result = func.call_and_shelve(11)
    result.get()

    func = MemorizedFunc(f, tmpdir.strpath, verbose=5)
    result = func.call_and_shelve(11)
    result.get()


def test_memory_file_modification(capsys, tmpdir, monkeypatch):
    # Test that modifying a Python file after loading it does not lead to
    # recomputation
    dir_name = tmpdir.mkdir("tmp_import").strpath
    filename = os.path.join(dir_name, "tmp_joblib_.py")
    content = "def f(x):\n    print(x)\n    return x\n"
    with open(filename, "w") as module_file:
        module_file.write(content)

    # Load the module:
    monkeypatch.syspath_prepend(dir_name)
    import tmp_joblib_ as tmp

    memory = Memory(location=tmpdir.strpath, verbose=0)
    f = memory.cache(tmp.f)
    # First call f a few times
    f(1)
    f(2)
    f(1)

    # Now modify the module where f is stored without modifying f
    with open(filename, "w") as module_file:
        module_file.write("\n\n" + content)

    # And call f a couple more times
    f(1)
    f(1)

    # Flush the .pyc files
    shutil.rmtree(dir_name)
    os.mkdir(dir_name)
    # Now modify the module where f is stored, modifying f
    content = 'def f(x):\n    print("x=%s" % x)\n    return x\n'
    with open(filename, "w") as module_file:
        module_file.write(content)

    # And call f more times prior to reloading: the cache should not be
    # invalidated at this point as the active function definition has not
    # changed in memory yet.
    f(1)
    f(1)

    # Now reload
    sys.stdout.write("Reloading\n")
    sys.modules.pop("tmp_joblib_")
    import tmp_joblib_ as tmp

    f = memory.cache(tmp.f)

    # And call f more times
    f(1)
    f(1)

    out, err = capsys.readouterr()
    assert out == "1\n2\nReloading\nx=1\n"


def _function_to_cache(a, b):
    # Just a placeholder function to be mutated by tests
    pass


def _sum(a, b):
    return a + b


def _product(a, b):
    return a * b


def test_memory_in_memory_function_code_change(tmpdir):
    _function_to_cache.__code__ = _sum.__code__

    memory = Memory(location=tmpdir.strpath, verbose=0)
    f = memory.cache(_function_to_cache)

    assert f(1, 2) == 3
    assert f(1, 2) == 3

    with warns(JobLibCollisionWarning):
        # Check that inline function modification triggers a cache invalidation
        _function_to_cache.__code__ = _product.__code__
        assert f(1, 2) == 2
        assert f(1, 2) == 2


def test_clear_memory_with_none_location():
    memory = Memory(location=None)
    memory.clear()


def func_with_kwonly_args(a, b, *, kw1="kw1", kw2="kw2"):
    return a, b, kw1, kw2


def func_with_signature(a: int, b: float) -> float:
    return a + b


def test_memory_func_with_kwonly_args(tmpdir):
    memory = Memory(location=tmpdir.strpath, verbose=0)
    func_cached = memory.cache(func_with_kwonly_args)

    assert func_cached(1, 2, kw1=3) == (1, 2, 3, "kw2")

    # Making sure that providing a keyword-only argument by
    # position raises an exception
    with raises(ValueError) as excinfo:
        func_cached(1, 2, 3, kw2=4)
    excinfo.match("Keyword-only parameter 'kw1' was passed as positional parameter")

    # A keyword-only parameter passed by position should still raise
    # ValueError, even once a correct call has been cached
    func_cached(1, 2, kw1=3, kw2=4)

    with raises(ValueError) as excinfo:
        func_cached(1, 2, 3, kw2=4)
    excinfo.match("Keyword-only parameter 'kw1' was passed as positional parameter")

    # Test 'ignore' parameter
    func_cached = memory.cache(func_with_kwonly_args, ignore=["kw2"])
    assert func_cached(1, 2, kw1=3, kw2=4) == (1, 2, 3, 4)
    assert func_cached(1, 2, kw1=3, kw2="ignored") == (1, 2, 3, 4)


def test_memory_func_with_signature(tmpdir):
    memory = Memory(location=tmpdir.strpath, verbose=0)
    func_cached = memory.cache(func_with_signature)

    assert func_cached(1, 2.0) == 3.0


def _setup_toy_cache(tmpdir, num_inputs=10):
    memory = Memory(location=tmpdir.strpath, verbose=0)

    @memory.cache()
    def get_1000_bytes(arg):
        return "a" * 1000

    inputs = list(range(num_inputs))
    for arg in inputs:
        get_1000_bytes(arg)

    func_id = _build_func_identifier(get_1000_bytes)
    hash_dirnames = [get_1000_bytes._get_args_id(arg) for arg in inputs]
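    # Each cached call is stored in its own directory, named after the hash
    # of the call's arguments, under the function's cache folder.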

    full_hashdirs = [
        os.path.join(get_1000_bytes.store_backend.location, func_id, dirname)
        for dirname in hash_dirnames
    ]
    return memory, full_hashdirs, get_1000_bytes


def test__get_items(tmpdir):
    memory, expected_hash_dirs, _ = _setup_toy_cache(tmpdir)
    items = memory.store_backend.get_items()
    hash_dirs = [ci.path for ci in items]
    assert set(hash_dirs) == set(expected_hash_dirs)

    def get_files_size(directory):
        full_paths = [os.path.join(directory, fn) for fn in os.listdir(directory)]
        return sum(os.path.getsize(fp) for fp in full_paths)

    expected_hash_cache_sizes = [get_files_size(hash_dir) for hash_dir in hash_dirs]
    hash_cache_sizes = [ci.size for ci in items]
    assert hash_cache_sizes == expected_hash_cache_sizes

    output_filenames = [os.path.join(hash_dir, "output.pkl") for hash_dir in hash_dirs]

    expected_last_accesses = [
        datetime.datetime.fromtimestamp(os.path.getatime(fn)) for fn in output_filenames
    ]
    last_accesses = [ci.last_access for ci in items]
    assert last_accesses == expected_last_accesses


def test__get_items_to_delete(tmpdir):
    # test empty cache
    memory, _, _ = _setup_toy_cache(tmpdir, num_inputs=0)
    items_to_delete = memory.store_backend._get_items_to_delete("1K")
    assert items_to_delete == []

    memory, expected_hash_cachedirs, _ = _setup_toy_cache(tmpdir)
    items = memory.store_backend.get_items()
    # bytes_limit set to keep only one cache item (each hash cache
    # folder is about 1000 bytes + metadata)
    items_to_delete = memory.store_backend._get_items_to_delete("2K")
    nb_hashes = len(expected_hash_cachedirs)
    assert set.issubset(set(items_to_delete), set(items))
    assert len(items_to_delete) == nb_hashes - 1

    # Sanity check bytes_limit=2048 is the same as bytes_limit='2K'
    items_to_delete_2048b = memory.store_backend._get_items_to_delete(2048)
    assert sorted(items_to_delete) == sorted(items_to_delete_2048b)

    # bytes_limit greater than the size of the cache
    items_to_delete_empty = memory.store_backend._get_items_to_delete("1M")
    assert items_to_delete_empty == []

    # All the cache items need to be deleted
    bytes_limit_too_small = 500
    items_to_delete_500b = memory.store_backend._get_items_to_delete(
        bytes_limit_too_small
    )
    assert set(items_to_delete_500b) == set(items)

    # Test LRU property: surviving cache items should all have a more
    # recent last_access than the ones that have been deleted
    items_to_delete_6000b = memory.store_backend._get_items_to_delete(6000)
    surviving_items = set(items).difference(items_to_delete_6000b)

    assert max(ci.last_access for ci in items_to_delete_6000b) <= min(
        ci.last_access for ci in surviving_items
    )


def test_memory_reduce_size_bytes_limit(tmpdir):
    memory, _, _ = _setup_toy_cache(tmpdir)
    ref_cache_items = memory.store_backend.get_items()

    # By default memory.bytes_limit is None and reduce_size is a noop
    memory.reduce_size()
    cache_items = memory.store_backend.get_items()
    assert sorted(ref_cache_items) == sorted(cache_items)

    # No cache items are deleted if bytes_limit is greater than the size
    # of the cache
    memory.reduce_size(bytes_limit="1M")
    cache_items = memory.store_backend.get_items()
    assert sorted(ref_cache_items) == sorted(cache_items)

    # bytes_limit is set so that only two cache items are kept
    memory.reduce_size(bytes_limit="3K")
    cache_items = memory.store_backend.get_items()
    assert set.issubset(set(cache_items), set(ref_cache_items))
    assert len(cache_items) == 2

    # bytes_limit set so that no cache item is kept
    bytes_limit_too_small = 500
    memory.reduce_size(bytes_limit=bytes_limit_too_small)
    cache_items = memory.store_backend.get_items()
    assert cache_items == []


def test_memory_reduce_size_items_limit(tmpdir):
    memory, _, _ = _setup_toy_cache(tmpdir)
    ref_cache_items = memory.store_backend.get_items()

    # By default reduce_size is a noop
    memory.reduce_size()
    cache_items = memory.store_backend.get_items()
    assert sorted(ref_cache_items) == sorted(cache_items)

    # No cache items are deleted if items_limit is at least the number
    # of items in the cache
    memory.reduce_size(items_limit=10)
    cache_items = memory.store_backend.get_items()
    assert sorted(ref_cache_items) == sorted(cache_items)

    # items_limit is set so that only two cache items are kept
    memory.reduce_size(items_limit=2)
    cache_items = memory.store_backend.get_items()
    assert set.issubset(set(cache_items), set(ref_cache_items))
    assert len(cache_items) == 2

    # items_limit set so that no cache item is kept
    memory.reduce_size(items_limit=0)
    cache_items = memory.store_backend.get_items()
    assert cache_items == []


def test_memory_reduce_size_age_limit(tmpdir):
    memory, _, put_cache = _setup_toy_cache(tmpdir)
    ref_cache_items = memory.store_backend.get_items()

    # By default reduce_size is a noop
    memory.reduce_size()
    cache_items = memory.store_backend.get_items()
    assert sorted(ref_cache_items) == sorted(cache_items)

    # No cache items are deleted if age_limit is large.
    memory.reduce_size(age_limit=datetime.timedelta(days=1))
    cache_items = memory.store_backend.get_items()
    assert sorted(ref_cache_items) == sorted(cache_items)

    # age_limit is set so that only two cache items are kept
    time.sleep(1)
    put_cache(-1)
    put_cache(-2)
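    # The original items are now more than one second old, while the two
    # entries just added are fresh and should be the only survivors.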
    memory.reduce_size(age_limit=datetime.timedelta(seconds=1))
    cache_items = memory.store_backend.get_items()
    assert not set.issubset(set(cache_items), set(ref_cache_items))
    assert len(cache_items) == 2

    # ensure age_limit is forced to be positive
    with pytest.raises(ValueError, match="has to be a positive"):
        memory.reduce_size(age_limit=datetime.timedelta(seconds=-1))

    # age_limit set so that no cache item is kept
    time.sleep(0.001)  # make sure the age is different
    memory.reduce_size(age_limit=datetime.timedelta(seconds=0))
    cache_items = memory.store_backend.get_items()
    assert cache_items == []


def test_memory_clear(tmpdir):
    memory, _, g = _setup_toy_cache(tmpdir)
    memory.clear()

    assert os.listdir(memory.store_backend.location) == []

    # Check that the in-memory cache of function hashes is also reset.
    assert not g._check_previous_func_code(stacklevel=4)


def fast_func_with_complex_output():
    complex_obj = ["a" * 1000] * 1000
    return complex_obj


def fast_func_with_conditional_complex_output(complex_output=True):
    complex_obj = {str(i): i for i in range(int(1e5))}
    return complex_obj if complex_output else "simple output"


@with_multiprocessing
def test_cached_function_race_condition_when_persisting_output(tmpdir, capfd):
    # Test race condition where multiple processes are writing into
    # the same output.pkl. See
    # https://github.com/joblib/joblib/issues/490 for more details.
    memory = Memory(location=tmpdir.strpath)
    func_cached = memory.cache(fast_func_with_complex_output)

    Parallel(n_jobs=2)(delayed(func_cached)() for i in range(3))

    stdout, stderr = capfd.readouterr()

    # Checking both stdout and stderr (ongoing PR #434 may change
    # logging destination) to make sure there is no exception while
    # loading the results
    exception_msg = "Exception while loading results"
    assert exception_msg not in stdout
    assert exception_msg not in stderr


@with_multiprocessing
def test_cached_function_race_condition_when_persisting_output_2(tmpdir, capfd):
    # Test race condition in first attempt at solving
    # https://github.com/joblib/joblib/issues/490. The race condition
    # was due to the delay between seeing the cache directory created
    # (interpreted as the result being cached) and the output.pkl being
    # pickled.
    memory = Memory(location=tmpdir.strpath)
    func_cached = memory.cache(fast_func_with_conditional_complex_output)

    Parallel(n_jobs=2)(
        delayed(func_cached)(True if i % 2 == 0 else False) for i in range(3)
    )

    stdout, stderr = capfd.readouterr()

    # Checking both stdout and stderr (ongoing PR #434 may change
    # logging destination) to make sure there is no exception while
    # loading the results
    exception_msg = "Exception while loading results"
    assert exception_msg not in stdout
    assert exception_msg not in stderr


def test_memory_recomputes_after_an_error_while_loading_results(tmpdir, monkeypatch):
    memory = Memory(location=tmpdir.strpath)

    def func(arg):
        # This makes sure that the timestamps returned by two calls of
        # func are different. This is needed on Windows where
        # time.time resolution may not be accurate enough
        time.sleep(0.01)
        return arg, time.time()

    cached_func = memory.cache(func)
    input_arg = "arg"
    arg, timestamp = cached_func(input_arg)

    # Make sure the function is correctly cached
    assert arg == input_arg

    # Corrupting output.pkl to make sure that an error happens when
    # loading the cached result
    corrupt_single_cache_item(memory)

    # Make sure that corrupting the file causes recomputation and that
    # a warning is issued.
    recorded_warnings = monkeypatch_cached_func_warn(cached_func, monkeypatch)
    recomputed_arg, recomputed_timestamp = cached_func(arg)
    assert len(recorded_warnings) == 1
    exception_msg = "Exception while loading results"
    assert exception_msg in recorded_warnings[0]
    assert recomputed_arg == arg
    assert recomputed_timestamp > timestamp

    # Corrupting output.pkl to make sure that an error happens when
    # loading the cached result
    corrupt_single_cache_item(memory)
    reference = cached_func.call_and_shelve(arg)
    try:
        reference.get()
        raise AssertionError(
            "It should normally not be possible to load a corrupted "
            "MemorizedResult"
        )
    except KeyError as e:
        message = "is corrupted"
        assert message in str(e.args)


class IncompleteStoreBackend(StoreBackendBase):
    """This backend cannot be instantiated and should raise a TypeError."""

    pass


class DummyStoreBackend(StoreBackendBase):
    """A dummy store backend that does nothing."""

    def _open_item(self, *args, **kwargs):
        """Open an item on store."""
        "Does nothing"

    def _item_exists(self, location):
        """Check if an item location exists."""
        "Does nothing"

    def _move_item(self, src, dst):
        """Move an item from src to dst in store."""
        "Does nothing"

    def create_location(self, location):
        """Create location on store."""
        "Does nothing"

    def exists(self, obj):
        """Check if an object exists in the store"""
        return False

    def clear_location(self, obj):
        """Clear object on store"""
        "Does nothing"

    def get_items(self):
        """Returns the whole list of items available in cache."""
        return []

    def configure(self, location, *args, **kwargs):
        """Configure the store"""
        "Does nothing"


@parametrize("invalid_prefix", [None, dict(), list()])
def test_register_invalid_store_backends_key(invalid_prefix):
    # verify the right exceptions are raised when passing a wrong backend key.
    with raises(ValueError) as excinfo:
        register_store_backend(invalid_prefix, None)
    excinfo.match(r"Store backend name should be a string*")


def test_register_invalid_store_backends_object():
    # verify the right exceptions are raised when passing a wrong backend
    # object.
    with raises(ValueError) as excinfo:
        register_store_backend("fs", None)
    excinfo.match(r"Store backend should inherit StoreBackendBase*")


def test_memory_default_store_backend():
    # test that an unknown backend name raises a TypeError
    with raises(TypeError) as excinfo:
        Memory(location="/tmp/joblib", backend="unknown")
    excinfo.match(r"Unknown location*")


def test_warning_on_unknown_location_type():
    class NonSupportedLocationClass:
        pass

    unsupported_location = NonSupportedLocationClass()

    with warns(UserWarning) as warninfo:
        _store_backend_factory("local", location=unsupported_location)

    expected_message = (
        "Instantiating a backend using a "
        "NonSupportedLocationClass as a location is not "
        "supported by joblib"
    )
    assert expected_message in str(warninfo[0].message)


def test_instantiate_incomplete_store_backend():
    # Verify that registering an external incomplete store backend raises an
    # exception when one tries to instantiate it.
    backend_name = "isb"
    register_store_backend(backend_name, IncompleteStoreBackend)
    assert (backend_name, IncompleteStoreBackend) in _STORE_BACKENDS.items()
    with raises(TypeError) as excinfo:
        _store_backend_factory(backend_name, "fake_location")
    excinfo.match(
        r"Can't instantiate abstract class IncompleteStoreBackend "
        "(without an implementation for|with) abstract methods*"
    )


def test_dummy_store_backend():
    # Verify that registering an external store backend works.

    backend_name = "dsb"
    register_store_backend(backend_name, DummyStoreBackend)
    assert (backend_name, DummyStoreBackend) in _STORE_BACKENDS.items()

    backend_obj = _store_backend_factory(backend_name, "dummy_location")
    assert isinstance(backend_obj, DummyStoreBackend)


def test_instantiate_store_backend_with_pathlib_path():
    # Instantiate a FileSystemStoreBackend using a pathlib.Path object
    path = pathlib.Path("some_folder")
    backend_obj = _store_backend_factory("local", path)
    try:
        assert backend_obj.location == "some_folder"
    finally:  # remove cache folder after test
        shutil.rmtree("some_folder", ignore_errors=True)


def test_filesystem_store_backend_repr(tmpdir):
    # Verify string representation of a filesystem store backend.

    repr_pattern = 'FileSystemStoreBackend(location="{location}")'
    backend = FileSystemStoreBackend()
    assert backend.location is None

    repr(backend)  # Should not raise an exception

    assert str(backend) == repr_pattern.format(location=None)

    # backend location is passed explicitly via the configure method (called
    # by the internal _store_backend_factory function)
    backend.configure(tmpdir.strpath)

    assert str(backend) == repr_pattern.format(location=tmpdir.strpath)

    repr(backend)  # Should not raise an exception


def test_memory_objects_repr(tmpdir):
    # Verify printable reprs of MemorizedResult, MemorizedFunc and Memory.

    def my_func(a, b):
        return a + b

    memory = Memory(location=tmpdir.strpath, verbose=0)
    memorized_func = memory.cache(my_func)

    memorized_func_repr = "MemorizedFunc(func={func}, location={location})"

    assert str(memorized_func) == memorized_func_repr.format(
        func=my_func, location=memory.store_backend.location
    )

    memorized_result = memorized_func.call_and_shelve(42, 42)

    memorized_result_repr = (
        'MemorizedResult(location="{location}", func="{func}", args_id="{args_id}")'
    )

    assert str(memorized_result) == memorized_result_repr.format(
        location=memory.store_backend.location,
        func=memorized_result.func_id,
        args_id=memorized_result.args_id,
    )

    assert str(memory) == "Memory(location={location})".format(
        location=memory.store_backend.location
    )


def test_memorized_result_pickle(tmpdir):
    # Verify a MemorizedResult object can be pickled/unpickled. Non-regression
    # test introduced following issue
    # https://github.com/joblib/joblib/issues/747

    memory = Memory(location=tmpdir.strpath)

    @memory.cache
    def g(x):
        return x**2

    memorized_result = g.call_and_shelve(4)
    memorized_result_pickle = pickle.dumps(memorized_result)
    memorized_result_loads = pickle.loads(memorized_result_pickle)

    assert (
        memorized_result.store_backend.location
        == memorized_result_loads.store_backend.location
    )
    assert memorized_result.func == memorized_result_loads.func
    assert memorized_result.args_id == memorized_result_loads.args_id
    assert str(memorized_result) == str(memorized_result_loads)


def compare(left, right, ignored_attrs=None):
    if ignored_attrs is None:
        ignored_attrs = []

    left_vars = vars(left)
    right_vars = vars(right)
    assert set(left_vars.keys()) == set(right_vars.keys())
    for attr in left_vars.keys():
        if attr in ignored_attrs:
            continue
        assert left_vars[attr] == right_vars[attr]


@pytest.mark.parametrize(
    "memory_kwargs",
    [
        {"compress": 3, "verbose": 2},
        {"mmap_mode": "r", "verbose": 5, "backend_options": {"parameter": "unused"}},
    ],
)
def test_memory_pickle_dump_load(tmpdir, memory_kwargs):
    memory = Memory(location=tmpdir.strpath, **memory_kwargs)

    memory_reloaded = pickle.loads(pickle.dumps(memory))

    # Compare Memory instance before and after pickle roundtrip
    compare(memory.store_backend, memory_reloaded.store_backend)
    compare(
        memory,
        memory_reloaded,
        ignored_attrs=set(["store_backend", "timestamp", "_func_code_id"]),
    )
    assert hash(memory) == hash(memory_reloaded)

    func_cached = memory.cache(f)

    func_cached_reloaded = pickle.loads(pickle.dumps(func_cached))

    # Compare MemorizedFunc instance before/after pickle roundtrip
    compare(func_cached.store_backend, func_cached_reloaded.store_backend)
    compare(
        func_cached,
        func_cached_reloaded,
        ignored_attrs=set(["store_backend", "timestamp", "_func_code_id"]),
    )
    assert hash(func_cached) == hash(func_cached_reloaded)

    # Compare MemorizedResult instance before/after pickle roundtrip
    memorized_result = func_cached.call_and_shelve(1)
    memorized_result_reloaded = pickle.loads(pickle.dumps(memorized_result))

    compare(memorized_result.store_backend, memorized_result_reloaded.store_backend)
    compare(
        memorized_result,
        memorized_result_reloaded,
        ignored_attrs=set(["store_backend", "timestamp", "_func_code_id"]),
    )
    assert hash(memorized_result) == hash(memorized_result_reloaded)


def test_info_log(tmpdir, caplog):
    caplog.set_level(logging.INFO)
    x = 3

    memory = Memory(location=tmpdir.strpath, verbose=20)

    @memory.cache
    def f(x):
        return x**2

    _ = f(x)
    assert "Querying" in caplog.text
    caplog.clear()

    memory = Memory(location=tmpdir.strpath, verbose=0)

    @memory.cache
    def f(x):
        return x**2

    _ = f(x)
    assert "Querying" not in caplog.text
    caplog.clear()


class TestCacheValidationCallback:
    "Tests on parameter `cache_validation_callback`"

    def foo(self, x, d, delay=None):
        d["run"] = True
        if delay is not None:
            time.sleep(delay)
        return x * 2

    def test_invalid_cache_validation_callback(self, memory):
        "Test invalid values for `cache_validation_callback"
        match = "cache_validation_callback needs to be callable. Got True."
        with pytest.raises(ValueError, match=match):
            memory.cache(cache_validation_callback=True)

    @pytest.mark.parametrize("consider_cache_valid", [True, False])
    def test_constant_cache_validation_callback(self, memory, consider_cache_valid):
        "Test expiry of old results"
        f = memory.cache(
            self.foo,
            cache_validation_callback=lambda _: consider_cache_valid,
            ignore=["d"],
        )

        d1, d2 = {"run": False}, {"run": False}
        assert f(2, d1) == 4
        assert f(2, d2) == 4
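        # The first call always runs foo; whether the second call reruns it
        # depends on the constant validity reported by the callback.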

        assert d1["run"]
        assert d2["run"] != consider_cache_valid

    def test_memory_only_cache_long_run(self, memory):
        "Test cache validity based on run duration."

        def cache_validation_callback(metadata):
            duration = metadata["duration"]
            if duration > 0.1:
                return True

        f = memory.cache(
            self.foo, cache_validation_callback=cache_validation_callback, ignore=["d"]
        )

        # Short runs are not cached
        d1, d2 = {"run": False}, {"run": False}
        assert f(2, d1, delay=0) == 4
        assert f(2, d2, delay=0) == 4
        assert d1["run"]
        assert d2["run"]

        # Longer runs are cached
        d1, d2 = {"run": False}, {"run": False}
        assert f(2, d1, delay=0.2) == 4
        assert f(2, d2, delay=0.2) == 4
        assert d1["run"]
        assert not d2["run"]

    def test_memory_expires_after(self, memory):
        "Test expiry of old cached results"

        f = memory.cache(
            self.foo, cache_validation_callback=expires_after(seconds=0.3), ignore=["d"]
        )

        d1, d2, d3 = {"run": False}, {"run": False}, {"run": False}
        assert f(2, d1) == 4
        assert f(2, d2) == 4
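        # Sleeping past the 0.3 second expiry makes the cached result stale,
        # so the next call must recompute.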
        time.sleep(0.5)
        assert f(2, d3) == 4

        assert d1["run"]
        assert not d2["run"]
        assert d3["run"]


class TestMemorizedFunc:
    "Tests for the MemorizedFunc and NotMemorizedFunc classes"

    @staticmethod
    def f(x, counter):
        counter[x] = counter.get(x, 0) + 1
        return counter[x]

    def test_call_method_memorized(self, memory):
        "Test calling the function"

        f = memory.cache(self.f, ignore=["counter"])

        counter = {}
        assert f(2, counter) == 1
        assert f(2, counter) == 1

        x, meta = f.call(2, counter)
        assert x == 2, "f has not been called properly"
        assert isinstance(meta, dict), (
            "Metadata are not returned by MemorizedFunc.call."
        )

    def test_call_method_not_memorized(self, memory):
        "Test calling the function"

        f = NotMemorizedFunc(self.f)

        counter = {}
        assert f(2, counter) == 1
        assert f(2, counter) == 2

        x, meta = f.call(2, counter)
        assert x == 3, "f has not been called properly"
        assert isinstance(meta, dict), (
            "Metadata are not returned by NotMemorizedFunc.call."
        )


@with_numpy
@parametrize(
    "location",
    [
        "test_cache_dir",
        pathlib.Path("test_cache_dir"),
        pathlib.Path("test_cache_dir").resolve(),
    ],
)
def test_memory_creates_gitignore(location):
    """Test that using the memory object automatically creates a `.gitignore` file
    within the new cache directory."""

    mem = Memory(location)
    arr = np.asarray([[1, 2, 3], [4, 5, 6]])
    costly_operation = mem.cache(np.square)
    costly_operation(arr)

    location = pathlib.Path(location)

    try:
        path_to_gitignore_file = os.path.join(location, ".gitignore")
        gitignore_file_content = "# Created by joblib automatically.\n*\n"
        with open(path_to_gitignore_file) as f:
            assert gitignore_file_content == f.read()

    finally:  # remove cache folder after test
        shutil.rmtree(location, ignore_errors=True)
