#include <library/cpp/testing/gtest/gtest.h>
#include <maps/libs/concurrent/include/latch.h>
#include <maps/wikimap/mapspro/services/mrc/libs/common/include/threadpool_wrapper.h>

#include <chrono>
#include <stdexcept>
#include <vector>

namespace maps::mrc::common::tests {
using namespace ::testing;

namespace {
// Worker-thread count used for every pool created in these tests.
constexpr int THREADS_NUM{4};
// Number of tasks each test submits to its pool.
constexpr int TASKS_NUM{10};
} // namespace

/// Submits TASKS_NUM independent tasks to a pool of THREADS_NUM threads and
/// checks that every task ran exactly once and that no task reported an
/// exception.
TEST(Threadpool_should, test_threadpool) {
    // Each task arrives once; the extra count is for the test thread itself.
    concurrent::Latch latch(TASKS_NUM + 1);

    // One counter per task. std::vector<bool> must NOT be used here: its
    // elements are packed bits sharing storage words, so concurrent writes to
    // different indices from different threads are a data race. With
    // std::vector<int> every task writes its own distinct object. Counting
    // runs (instead of setting a flag) also detects a task executed twice.
    std::vector<int> runs(TASKS_NUM, 0);
    const std::vector<int> expectedRuns(TASKS_NUM, 1);

    ThreadpoolWrapper threads(THREADS_NUM);
    for (int i = 0; i < TASKS_NUM; ++i) {
        // Capture only what the task touches; `i` by value so every task
        // keeps its own index.
        threads->add([&runs, &latch, i]() {
            ++runs[i];
            latch.arrive();
        });
    }

    threads->join();
    // Rendezvous with all tasks before inspecting `runs`; presumably the latch
    // also orders the tasks' writes before this read — confirm against
    // concurrent::Latch.
    latch.arriveAndWait();
    threads.checkExceptions();

    EXPECT_EQ(runs, expectedRuns);
}

/// A task that throws must not be silently lost: once the pool is joined,
/// ThreadpoolWrapper::checkExceptions() is expected to rethrow the error.
TEST(Threadpool_should, test_threadpool_exception) {
    // Index of the single task that fails; all others complete normally.
    constexpr int FAILING_TASK = 1;

    ThreadpoolWrapper pool(THREADS_NUM);
    for (int taskId = 0; taskId < TASKS_NUM; ++taskId) {
        // The task only needs its own index, so capture just that by value.
        pool->add([taskId]() {
            if (taskId != FAILING_TASK) {
                return;
            }
            throw std::runtime_error("test");
        });
    }

    pool->join();
    EXPECT_THROW(pool.checkExceptions(), std::runtime_error);
}

} // namespace maps::mrc::common::tests
