import infra.callisto.libraries.memoize as memoize


def test_cache():
    """Check call counting and the record limit of a ``get_cache`` cache.

    The wrapped function records every real invocation in ``seen``, so the
    length of ``seen`` tells how many calls were not served from the cache.
    """
    limit = 10
    total_calls = 100
    seen = []

    @memoize.get_cache(records_limit=limit)
    def f(x):
        seen.append(x)
        return x

    # Before any call: no invocations recorded and no cached records.
    assert len(seen) == 0
    assert len(f._cache) == 0

    # (argument, expected count) pairs: a repeated argument must neither
    # trigger a new invocation nor add a new record to the cache.
    for arg, expected in ((0, 1), (0, 1), (1, 2), (1, 2)):
        f(arg)
        assert len(seen) == expected
        assert len(f._cache) == expected

    for value in range(total_calls):
        f(value)

    # With more distinct arguments than the limit, the cache must stay
    # bounded while the function keeps being invoked.
    assert len(seen) >= total_calls
    assert len(f._cache) <= limit


def test_ttl_cache():
    """Check that a result wrapped in ``SaveWithTtl`` is recomputed after a wait.

    The function is called twice in a row (second call must be cached), then
    once more after sleeping for 0.15 seconds, at which point it must be
    invoked again.
    """
    import time

    seen = []

    @memoize.get_cache(records_limit=100)
    def f(x):
        seen.append(x)
        # NOTE(review): ttl is presumably in seconds (0.1 < the 0.15 s sleep
        # below) -- confirm against the memoize module.
        return memoize.SaveWithTtl(x, ttl=0.1)

    assert len(seen) == 0

    f(0)
    assert len(seen) == 1

    # Immediate repeat: served from the cache, no new invocation.
    f(0)
    assert len(seen) == 1

    time.sleep(0.15)
    # Sleeping alone must not trigger an invocation.
    assert len(seen) == 1

    # After the wait the same argument causes a fresh invocation.
    f(0)
    assert len(seen) == 2


def test_dont_save():
    """Check that a result wrapped in ``DontSave`` is not reused on the next call.

    Calling the function twice with the same argument must invoke it twice.
    """
    seen = []

    @memoize.get_cache(records_limit=100)
    def f(x):
        seen.append(x)
        return memoize.DontSave(x)

    assert len(seen) == 0

    # Every call, even with an identical argument, reaches the function.
    for expected_calls in (1, 2):
        f(0)
        assert len(seen) == expected_calls


def test_kw_args():
    """Check that calls made with a keyword argument are cached.

    Two identical ``f(x=1)`` calls must result in a single real invocation.
    """
    seen = []

    @memoize.get_cache(records_limit=100)
    def f(x):
        seen.append(x)
        return x

    assert len(seen) == 0

    # The first call invokes the function; the second one must hit the cache.
    for _ in range(2):
        f(x=1)
        assert len(seen) == 1


def test_multiple_function_same_args():
    """Check that two functions wrapped by one cache object do not mix results.

    ``f`` and ``g`` are decorated with the same object returned by
    ``get_cache`` and called with the same arguments; each must keep
    returning its own value, on the first pass and on the repeated one.
    """
    shared_cache = memoize.get_cache(records_limit=100)

    @shared_cache
    def f(x):
        return x

    @shared_cache
    def g(x):
        return x + 1

    # Two passes over 0..9: the second pass repeats arguments already used.
    for _ in range(2):
        for i in range(10):
            assert f(i) != g(i)
            assert f(i) + 1 == g(i)


def test_memoized():
    """Check ``memoize.memoized`` on two functions called with the same arguments.

    Both functions append to ``seen`` on every real invocation. After all
    ten arguments were used, exactly 20 invocations (10 per function) must
    have happened, and repeating the whole run must not add any.
    """
    seen = []

    @memoize.memoized
    def p(x):
        seen.append(x)
        return x

    @memoize.memoized
    def q(x):
        seen.append(x)
        return x + 1

    def check_all_arguments():
        # Two passes over 0..9; results of p and q must never be mixed up.
        for _ in range(2):
            for i in range(10):
                assert p(i) != q(i)
                assert p(i) + 1 == q(i)

    check_all_arguments()
    assert len(seen) == 20

    # A full repeat must be served without any new invocation.
    check_all_arguments()
    assert len(seen) == 20
