#include <mail/ratesrv/src/router/counters_worker.h>
#include <mail/ratesrv/src/router/network_agent.h>
#include <mail/ratesrv/src/router/node_manager.h>
#include <mail/ratesrv/src/context.h>

#include <mail/ratesrv/ut/mock/application.h>
#include <mail/ratesrv/ut/mock/storage.h>

#include <yplatform/task_context.h>
#include <yplatform/find.h>
#include <yplatform/yield.h>
#include <yplatform/time_traits.h>

#include <gtest/gtest.h>

#include <condition_variable>
#include <functional>
#include <mutex>
#include <memory>

namespace {

using namespace testing;
using namespace NRateSrv;
using namespace NRateSrv::NRouter;
using namespace NRateSrv::NMock;

class TNetworkAgentImpl : public INetworkAgent {
public:
    using TRequestCallback = std::function<ui64(size_t, NStorage::TRequest, EMessageTypes, TCallback)>;
    using TCompleteRequestCallback = std::function<TCallback(ui64, bool)>;

    TNetworkAgentImpl(
        TRequestCallback requestCallback,
        TCompleteRequestCallback completeCallback,
        TNodeManagerPtr nodeManager
    )
        : RequestCallback(std::move(requestCallback))
        , CompleteCallback(std::move(completeCallback))
        , NodeManager(std::move(nodeManager))
    {}

    ui64 SendRequest(
        TNodeManagerPtr,
        TContextPtr,
        size_t nodeNum,
        NStorage::TRequest request,
        EMessageTypes type,
        TCallback callback) override
    {
        return RequestCallback(nodeNum, std::move(request), type, std::move(callback));
    }

    TCallback CompleteRequest(ui64 requestId, bool success) override {
        return CompleteCallback(requestId, success);
    }

    TNodeManagerPtr GetNodeManager() {
        return NodeManager;
    }

private:
    TRequestCallback RequestCallback;
    TCompleteRequestCallback CompleteCallback;
    TNodeManagerPtr NodeManager;
};

} // namespace

// Test fixture for TCountersWorker.
//
// Each test fills Request with three counters (two well-formed names and one
// malformed name), runs a TCountersWorker against a scripted TNetworkAgentImpl
// and the mock TStorage registered by TApplication, blocks until the worker
// reports its result, and then inspects Response.
class TTestCountersWorker : public Test {
protected:
    // Two valid counter names plus "invalid", which Compare(bool) expects the
    // worker to report as ECounterState::InvalidName.
    void SetUp() override {
        Request.emplace("id1", TCounterData{"group:limit:key1", 100});
        Request.emplace("id2", TCounterData{"group:limit:key2", 200});
        Request.emplace("id3", TCounterData{"invalid", 300});
    }

    // Runs the worker with a single node "localhost" (the `true` flag presumably
    // marks it as the local node -- confirm against TNode). Every counter should
    // be served without going through the network agent, so any call into the
    // agent makes this return false.
    bool CheckWithoutRemoteHosts(bool isIncrease) {
        bool success = true;
        TNodeManager::TNodes nodes{TNode{"localhost", true}};
        auto nodeManager = std::make_shared<TNodeManager>(std::move(nodes));
        auto agent = std::make_shared<TNetworkAgentImpl>(
            [&success](size_t, NStorage::TRequest, EMessageTypes, INetworkAgent::TCallback) -> ui64 {
                success = false;
                return 0;
            },
            [&success](ui64, bool) -> INetworkAgent::TCallback {
                success = false;
                return {};
            },
            nodeManager
        );
        RunWorker(agent, isIncrease);

        return success;
    }

    // Runs the worker against three remote nodes with a scripted agent:
    //  * the first request sent to each node is never answered, so the worker
    //    has to give up on it (presumably via the configured timeout);
    //  * every later request to that node is answered asynchronously on the
    //    "storage" reactor with MakeStorageResponse().
    // Request ids encode the attempt as `nodeCount * attempt + nodeNum`, so ids
    // below nodeCount are first attempts. Returns false if the agent is told a
    // first attempt succeeded or a retry failed.
    bool CheckWithRemoteHosts(bool isIncrease) {
        bool success = true;
        TNodeManager::TNodes nodes{TNode{"addr1", false}, TNode{"addr2", false}, TNode{"addr3", false}};
        auto nodeManager = std::make_shared<TNodeManager>(std::move(nodes));
        const size_t nodeCount = nodeManager->Count();
        // Number of requests sent so far to each node, indexed by node number.
        std::vector<size_t> tryNums(nodeCount, 0);

        // No `this` capture: MakeStorageResponse is static, so neither this
        // handler nor the one it posts to the storage reactor depends on the
        // fixture object.
        auto requestCallback = [isIncrease, &tryNums, nodeCount]
            (size_t nodeNum, NStorage::TRequest request, EMessageTypes, INetworkAgent::TCallback callback) -> ui64 {
                const ui64 requestId = nodeCount * tryNums[nodeNum] + nodeNum;
                if (tryNums[nodeNum]++ == 0) {
                    // First attempt to this node: intentionally left unanswered.
                    return requestId;
                }
                auto storageReactor = yplatform::find_reactor("storage");
                boost::asio::post(
                    *storageReactor->io(),
                    [request = std::move(request), requestId, callback = std::move(callback), isIncrease] {
                        callback(requestId, MakeStorageResponse(request, isIncrease));
                    });
                return requestId;
            };

        auto agent = std::make_shared<TNetworkAgentImpl>(
            requestCallback,
            [&origSuccess = success, nodeCount](ui64 requestId, bool success) -> INetworkAgent::TCallback {
                // First attempts (id < nodeCount) must fail; retries must succeed.
                if ((requestId < nodeCount && success) || (requestId >= nodeCount && !success)) {
                    origSuccess = false;
                }
                return {};
            },
            nodeManager);

        RunWorker(agent, isIncrease);

        return success;
    }

    // Points the mock storage at MakeStorageResponse, builds a TCountersWorker
    // over Request with the given agent, spawns it and blocks until the
    // worker's completion callback has stored its result in Response.
    // NOTE(review): WaitDone() has no timeout, so a worker that never
    // completes hangs the test.
    void RunWorker(std::shared_ptr<TNetworkAgentImpl> agent, bool isIncrease) {
        TApplication::Instance();
        auto storage = yplatform::find<TStorage>("storage");
        // The storage module is obtained from the application (presumably a
        // process-wide singleton), so it can outlive this fixture; its callbacks
        // therefore must not capture `this`.
        storage->SetGetCallback([](NStorage::TRequest request) -> NStorage::TResponse {
            return MakeStorageResponse(request, false);
        });
        storage->SetIncreaseCallback([](NStorage::TRequest request) -> NStorage::TResponse {
            return MakeStorageResponse(request, true);
        });

        auto configuration = std::make_shared<TConfiguration>();
        configuration->SetTimeout(yplatform::time_traits::seconds(1));
        configuration->SetHashSeed(100);
        configuration->SetBucketCount(1000);
        configuration->SetNodeHashSalt("ratesrv_salt");
        configuration->SetBaseBanDuration(yplatform::time_traits::seconds(1));

        auto ctx = boost::make_shared<TContext>("testid");

        auto reactor = yplatform::find_reactor("global");
        auto worker = std::make_shared<TCountersWorker>(
            *reactor,
            configuration,
            agent,
            agent->GetNodeManager(),
            ctx,
            Request,
            isIncrease ? ERequestMode::Increase : ERequestMode::Get,
            [&origResponse = Response, this](TCounterResponse response) {
                origResponse = std::move(response);
                SetDone();
            }
        );

        yplatform::spawn(worker);
        WaitDone();
    }

    // True if `value` is what MakeStorageResponse produces for `data`:
    // state Ok, Current equal to the requested value (increase) or 1 (get),
    // and Available == -100.
    static bool Compare(const TCounterData& data, const TCounterValue& value, bool isIncrease) {
        return
            value.State == ECounterState::Ok &&
            value.Current == (isIncrease ? data.Value : 1) &&
            value.Available == -100;
    }

    // True if Response has exactly the ids in Request, both valid counters
    // match Compare(), and "id3" was rejected as InvalidName.
    bool Compare(bool isIncrease) const {
        if (Request.size() != Response.size()) {
            return false;
        }

        for (const auto& entry : Request) {
            if (Response.count(entry.first) == 0) {
                return false;
            }
        }

        // `.at()` rather than `operator[]`: lookups must not insert entries.
        return
            Compare(Request.at("id1"), Response.at("id1"), isIncrease) &&
            Compare(Request.at("id2"), Response.at("id2"), isIncrease) &&
            Response.at("id3").State == ECounterState::InvalidName;
    }

    // Mock storage answer: every key in every group/limit of `request` gets
    // {Ok, key.Value (increase) or 1 (get), -100}, keyed by the key's Index.
    // Static so it can be used from callbacks that must not capture `this`.
    static NStorage::TResponse MakeStorageResponse(const NStorage::TRequest& request, bool isIncrease) {
        NStorage::TResponse response;
        for (const auto& [groupName, limits] : request.GetRefGroups()) {
            for (const auto& [limitName, keys] : limits) {
                for (const auto& key : keys) {
                    response.emplace(key.Index, TCounterValue{ECounterState::Ok, isIncrease ? key.Value : 1, -100});
                }
            }
        }
        return response;
    }

private:
    // Signals WaitDone(); called from the worker's completion callback.
    void SetDone() {
        {
            std::lock_guard lock(Mutex);
            Done = true;
        }
        Cv.notify_one();
    }

    // Blocks the test thread until SetDone() has been called.
    void WaitDone() {
        std::unique_lock lock(Mutex);
        Cv.wait(lock, [&done = Done]{ return done; });
    }

protected:
    TCounterRequest Request;
    TCounterResponse Response;

private:
    std::mutex Mutex;
    std::condition_variable Cv;
    bool Done = false;
};

// Get mode with only the local node: the network agent must never be called,
// the two valid counters must report {Ok, Current == 1, Available == -100},
// and "invalid" must be rejected as InvalidName.
TEST_F(TTestCountersWorker, GetWithoutRemoteHosts) {
    ASSERT_TRUE(CheckWithoutRemoteHosts(false));
    ASSERT_TRUE(Compare(false));
}

// Increase mode with only the local node: the network agent must never be
// called, and each valid counter's Current must equal its requested value.
TEST_F(TTestCountersWorker, IncreaseWithoutRemoteHosts) {
    ASSERT_TRUE(CheckWithoutRemoteHosts(true));
    ASSERT_TRUE(Compare(true));
}

// Get mode with three remote nodes whose first request is never answered:
// first attempts must be completed as failed, retries as successful, and the
// final response must still match the Get expectations.
TEST_F(TTestCountersWorker, GetWithRemoteHosts) {
    ASSERT_TRUE(CheckWithRemoteHosts(false));
    ASSERT_TRUE(Compare(false));
}

// Increase mode with three remote nodes whose first request is never answered:
// same retry expectations as GetWithRemoteHosts, with Increase-mode values.
TEST_F(TTestCountersWorker, IncreaseWithRemoteHosts) {
    ASSERT_TRUE(CheckWithRemoteHosts(true));
    ASSERT_TRUE(Compare(true));
}
