#include "logs_transmitter_mock.h"
#include "logs_transmitter_multithread.h"

#include <library/cpp/testing/unittest/registar.h>

#include <util/generic/hash_set.h>

namespace NInfra::NPodAgent::NLogsTransmitterMultithreadTest {

// Shared logger; every TLogsTransmitterMultithread built by the fixtures below
// receives its own frame via logger.SpawnFrame().
static TLogger logger({});

// Input set shared by every test in this file: eight distinct containers
// (box_N / workload_N, N = 1..8). The split sizes asserted by the tests
// (3/3/2 for three threads, 1 each for ten threads) rely on there being exactly eight.
const THashSet<TPushContainer> pushContainers = {
    TPushContainer({TPortoContainerName({"box_1"}, "workload_1"), {"1"}, {"1"}}),
    TPushContainer({TPortoContainerName({"box_2"}, "workload_2"), {"2"}, {"2"}}),
    TPushContainer({TPortoContainerName({"box_3"}, "workload_3"), {"3"}, {"3"}}),
    TPushContainer({TPortoContainerName({"box_4"}, "workload_4"), {"4"}, {"4"}}),
    TPushContainer({TPortoContainerName({"box_5"}, "workload_5"), {"5"}, {"5"}}),
    TPushContainer({TPortoContainerName({"box_6"}, "workload_6"), {"6"}, {"6"}}),
    TPushContainer({TPortoContainerName({"box_7"}, "workload_7"), {"7"}, {"7"}}),
    TPushContainer({TPortoContainerName({"box_8"}, "workload_8"), {"8"}, {"8"}})
};

struct TTestLogsTransmitter : public TMockLogstransmitter {
    void TransmitLogs(const THashSet<TPushContainer>& pushContainers) override {
        TGuard<TMutex> g(mutex);
        for (const auto& container: pushContainers) {
            if (PushContainers.contains(container)) {
                throw yexception() << "Container " << container.Container << "already processed";
            }
        }
        ++TransmitCalls;
        PushContainers.insert(pushContainers.begin(), pushContainers.end());
    }

    size_t TransmitCalls = 0;
    THashSet<TPushContainer> PushContainers;
    TMutex mutex;
};

class ITestLogsTransmitterMultithreadCanon {
public:
    ITestLogsTransmitterMultithreadCanon(const ui32 numOfThreads, TLogsTransmitterPtr logsTransmitter)
        : NumOfThreads_(numOfThreads)
        , LogsTransmitter_(logsTransmitter)
    {}

    virtual ~ITestLogsTransmitterMultithreadCanon() = default;

    TLogsTransmitterMultithreadPtr GetLogsTransmitterMultithread() {
        TAtomicSharedPtr<IThreadPool> mtpQueue_ = new TThreadPool();
        mtpQueue_->Start(NumOfThreads_);

        return new TLogsTransmitterMultithread(
            NumOfThreads_
            , LogsTransmitter_
            , mtpQueue_
            , logger.SpawnFrame()
        );
    }

    void RunLogsTransmitterMultithread() {
        GetLogsTransmitterMultithread()->TransmitLogs(pushContainers);
    }

    ui32 GetNumOfThreads() const {
        return NumOfThreads_;
    }

    void DoTest() {
        Test();
    }

protected:
    virtual void Test() = 0;

protected:
    const ui32 NumOfThreads_;
    TLogsTransmitterPtr LogsTransmitter_;
};

Y_UNIT_TEST_SUITE(LogsTransmitterMultithreadSuite) {
    // Three threads, eight containers: Split must produce chunks of sizes 3/3/2 covering every
    // container, and TransmitLogs must hand each chunk to the underlying transmitter exactly once.
    Y_UNIT_TEST(TransmitLogsOkTest) {
        class TTest: public ITestLogsTransmitterMultithreadCanon {
        public:
            using ITestLogsTransmitterMultithreadCanon::ITestLogsTransmitterMultithreadCanon;

        protected:
            void Test() override {
                auto logsTransmitter = GetLogsTransmitterMultithread();
                auto splittedPushContainers = logsTransmitter->Split(pushContainers, GetNumOfThreads());
                UNIT_ASSERT_EQUAL(splittedPushContainers.size(), 3);

                UNIT_ASSERT_EQUAL(splittedPushContainers[0].size(), 3);
                UNIT_ASSERT_EQUAL(splittedPushContainers[1].size(), 3);
                UNIT_ASSERT_EQUAL(splittedPushContainers[2].size(), 2);

                // The chunk sizes sum to 8 == |pushContainers|, so an equal union also proves
                // the chunks are pairwise disjoint.
                THashSet<TPushContainer> unionSplittedContainers;
                for (const auto& chunk : splittedPushContainers) {
                    unionSplittedContainers.insert(chunk.begin(), chunk.end());
                }

                UNIT_ASSERT_EQUAL(pushContainers, unionSplittedContainers);

                logsTransmitter->TransmitLogs(pushContainers);

                // Checked downcast: fails loudly instead of reading garbage if the fixture was
                // built with an unexpected transmitter type.
                const auto* testTransmitter = dynamic_cast<const TTestLogsTransmitter*>(LogsTransmitter_.Get());
                UNIT_ASSERT(testTransmitter != nullptr);
                UNIT_ASSERT_EQUAL(testTransmitter->TransmitCalls, 3);
                UNIT_ASSERT_EQUAL(testTransmitter->PushContainers, pushContainers);
            }
        };

        ui32 numOfThreads = 3;
        TLogsTransmitterPtr logsTransmitter = new TTestLogsTransmitter();
        TTest test(numOfThreads, logsTransmitter);
        test.DoTest();
    }

    // More threads (10) than containers (8): Split must yield one single-container chunk per
    // container, so only 8 transmissions happen and the surplus threads receive no work.
    Y_UNIT_TEST(FewLogsTransmittersNotRunTest) {
        class TTest: public ITestLogsTransmitterMultithreadCanon {
        public:
            using ITestLogsTransmitterMultithreadCanon::ITestLogsTransmitterMultithreadCanon;

        protected:
            void Test() override {
                auto logsTransmitter = GetLogsTransmitterMultithread();
                auto splittedPushContainers = logsTransmitter->Split(pushContainers, GetNumOfThreads());
                UNIT_ASSERT_EQUAL(splittedPushContainers.size(), 8);

                for (const auto& chunk : splittedPushContainers) {
                    UNIT_ASSERT_EQUAL(chunk.size(), 1);
                }

                // Eight singleton chunks whose union equals the eight inputs are necessarily disjoint.
                THashSet<TPushContainer> unionSplittedContainers;
                for (const auto& chunk : splittedPushContainers) {
                    unionSplittedContainers.insert(chunk.begin(), chunk.end());
                }

                UNIT_ASSERT_EQUAL(pushContainers, unionSplittedContainers);

                logsTransmitter->TransmitLogs(pushContainers);

                // Checked downcast instead of an unchecked C-style cast.
                const auto* testTransmitter = dynamic_cast<const TTestLogsTransmitter*>(LogsTransmitter_.Get());
                UNIT_ASSERT(testTransmitter != nullptr);
                UNIT_ASSERT_EQUAL(testTransmitter->TransmitCalls, 8);
                UNIT_ASSERT_EQUAL(testTransmitter->PushContainers, pushContainers);
            }
        };

        ui32 numOfThreads = 10;
        TLogsTransmitterPtr logsTransmitter = new TTestLogsTransmitter();
        TTest test(numOfThreads, logsTransmitter);
        test.DoTest();
    }

    // Every underlying TransmitLogs call throws. The test expects all three chunk
    // transmissions to still be attempted, and relies on RunLogsTransmitterMultithread
    // returning without propagating the error (an escaping exception would fail the test).
    Y_UNIT_TEST(LogsTransmitterThrowExceptionTest) {
        // Counts the call under the base class's mutex, then always fails.
        struct TTestLogsTransmitterFails: public TTestLogsTransmitter {
            void TransmitLogs(const THashSet<TPushContainer>&) override {
                TGuard<TMutex> g(mutex);
                ++TransmitCalls;
                throw yexception() << "Transmit log fails";
            }
        };

        class TTest: public ITestLogsTransmitterMultithreadCanon {
        public:
            using ITestLogsTransmitterMultithreadCanon::ITestLogsTransmitterMultithreadCanon;

        protected:
            void Test() override {
                RunLogsTransmitterMultithread();

                // Checked downcast; TTestLogsTransmitterFails derives from TTestLogsTransmitter.
                const auto* testTransmitter = dynamic_cast<const TTestLogsTransmitter*>(LogsTransmitter_.Get());
                UNIT_ASSERT(testTransmitter != nullptr);
                UNIT_ASSERT_EQUAL(testTransmitter->TransmitCalls, 3);
            }
        };

        ui32 numOfThreads = 3;
        TLogsTransmitterPtr logsTransmitter = new TTestLogsTransmitterFails();
        TTest test(numOfThreads, logsTransmitter);
        test.DoTest();
    }

}

} //NInfra::NPodAgent::NLogsTransmitterMultithreadTest
