#include <library/cpp/testing/gtest/gtest.h>
#include <maps/libs/common/include/exception.h>
#include <maps/wikimap/mapspro/services/mrc/libs/common/include/algorithm/retry.h>
#include <maps/wikimap/mapspro/services/mrc/libs/common/include/stopwatch.h>

#include <chrono>
#include <iostream>

namespace maps::mrc::common::tests {
using namespace ::testing;

// Verifies that retryOnException gives up once the attempt limit is reached.
// The callable only succeeds on its 5th invocation, but the policy allows
// just 4 attempts, so MaxRetryNumberReached must be thrown after exactly
// 4 invocations. The elapsed time must cover the 3 back-off waits between
// them (100 + 200 + 400 ms).
TEST(Retry_should, test_retry_failure) {
    size_t number = 0;
    auto func = [&](int a, int b) {
        number++;
        if (number == 5) {
            return (a + b);
        }
        throw maps::Exception("Fail");
    };

    Stopwatch watch(Stopwatch::Running);
    // ASSERT_THROW aborts the test if nothing, or an exception of another
    // type, is thrown. This replaces a manual try/catch plus an
    // "unreachable code" failure.
    ASSERT_THROW(
        retryOnException<maps::Exception>(
            RetryPolicy()
                .setInitialTimeout(std::chrono::milliseconds(100))
                .setMaxAttempts(4)
                .setTimeoutBackoff(2),
            //callable
            func,
            //args
            1, 2
        ),
        MaxRetryNumberReached
    );
    // Read the clock right away so assertion bookkeeping is not timed.
    auto msPassed = watch.elapsed<std::chrono::milliseconds>();

    // The callable must have been invoked once per allowed attempt.
    EXPECT_EQ(number, 4u);

    // No wait is expected after the final failed attempt.
    const size_t EXPECTED_TIMEOUT = 100 + 200 + 400;
    EXPECT_GE(msPassed, EXPECTED_TIMEOUT);
    EXPECT_LT(msPassed, EXPECTED_TIMEOUT * 2);
}

// Verifies that retryOnException keeps retrying until the callable succeeds.
// The callable succeeds on its 5th invocation and the policy allows exactly
// 5 attempts, so no exception may escape and the callable must run exactly
// 5 times. The elapsed time must cover the 4 back-off waits between attempts
// (100 + 200 + 400 + 800 ms).
TEST(Retry_should, test_retry_ok) {
    size_t number = 0;
    auto func = [&](int a, int b, int c) {
        number++;
        if (number == 5) {
            return (a + b + c);
        }
        throw maps::Exception("Fail");
    };

    Stopwatch watch(Stopwatch::Running);
    EXPECT_NO_THROW({
        retryOnException<maps::Exception>(
            RetryPolicy()
                .setInitialTimeout(std::chrono::milliseconds(100))
                .setMaxAttempts(5)
                .setTimeoutBackoff(2),
            func,
            100, 500, 3
        );
    });
    // Read the clock right away so assertion bookkeeping is not timed.
    auto msPassed = watch.elapsed<std::chrono::milliseconds>();

    // Four failures followed by one success: exactly five invocations.
    EXPECT_EQ(number, 5u);

    // Only the waits between attempts count; nothing follows the success.
    const size_t EXPECTED_TIMEOUT = 100 + 200 + 400 + 800;
    EXPECT_GE(msPassed, EXPECTED_TIMEOUT);
    EXPECT_LT(msPassed, EXPECTED_TIMEOUT * 2);
}

} // namespace maps::mrc::common::tests
