#pragma once

#include "error_code.h"

#include <mail/nwsmtp/src/context.h>
#include <mail/nwsmtp/src/log.h>
#include <mail/ymod_ratesrv/helper/src/client_helper.h>
#include <mail/ymod_ratesrv/src/client.h>

#include <yplatform/coroutine.h>
#include <yplatform/find.h>
#include <yplatform/yield.h>

#include <functional>
#include <memory>
#include <sstream>
#include <type_traits>
#include <unordered_map>

namespace NNwSmtp::NRateSrv {

// Tag used by TClient below as the log "where" argument and as the prefix of
// the task-context name passed to ratesrv.
// Declared `inline` (C++17) rather than `static`: a `static` object in a header
// is a distinct entity in every translation unit, so the TClient template
// members that odr-use it would refer to different objects in different TUs
// (an ODR violation, since std::string is not a literal type). `inline` yields
// a single program-wide object with the same value.
inline const std::string RATESRV{"RATESRV"};

// Compile-time request selectors for TClient:
//   TGetRequestTag::value == false    -> read-only "get" request;
//   TUpdateRequestTag::value == true  -> counter "increase" request.
struct TGetRequestTag : public std::false_type {};
struct TUpdateRequestTag : public std::true_type {};

template<typename TRequestTag>
class TClient {
public:
    using TYieldCtx = yplatform::yield_context<TClient>;
    using TCallback = std::function<void(TErrorCode)>;

    TClient(
        TContextPtr context,
        bool writeLogForEveryCounter,
        bool writeFinalLogOnlyOnError,
        TCallback callback
    )
        : Context(std::move(context))
        , WriteLogForEveryCounter(writeLogForEveryCounter)
        , WriteFinalLogOnlyOnError(writeFinalLogOnlyOnError)
        , Callback(std::move(callback))
        , ClientHelper(yplatform::find<NYmodRateSrv::TClientHelper, std::shared_ptr>("ratesrv_client_helper"))
        , ErrorCode(EError::Accept)
    {}

    template <typename... TArgs>
    void AddPartToRequest(std::string target, const std::string& limitName, TArgs&&... args) {
        auto limitPtr = ClientHelper->GetLimit(limitName);
        if (!limitPtr || !limitPtr->IsEnabled()) {
            return;
        }
        const auto& limit = *limitPtr;
        auto [it, success] = CounterContexts.emplace(
            std::to_string(CounterContexts.size()),
            TCounterContext{std::move(target), std::move(limitPtr)});

        if constexpr (TRequestTag::value) {
            ClientHelper->AddKeyToIncreaseRequest(Request, limit, it->first, std::forward<TArgs>(args)...);
        } else {
            ClientHelper->AddKeyToGetRequest(Request, limit, it->first, std::forward<TArgs>(args)...);
        }
    }

    void operator()(TYieldCtx yieldCtx, boost::system::error_code ec = {}, NYmodRateSrv::TResponse response = {}) {
        reenter (yieldCtx) {
            if (Request.Size() == 0) {
                yield break;
            }
            if (Request.HasPartWithDomain("zero")) {
                ErrorCode = EError::Discard;
                yield break;
            }
            yield SendRequest(yieldCtx);
            if (!ec) {
                HandleResponse(std::move(response));
            }
            WriteFinalLog(ec);
        }

        if (yieldCtx.is_complete()) {
            Callback(std::move(ErrorCode));
        }
    }

private:
    void SendRequest(TYieldCtx yieldCtx) {
        auto client = yplatform::find<NYmodRateSrv::IClient>("ratesrv_client");
        std::string where{RATESRV + ":"};
        if constexpr (TRequestTag::value) {
            client->SendIncreaseRequest(Context->CreateTaskContext(std::move(where)), Request,
                std::move(yieldCtx));
        } else {
            client->SendGetRequest(Context->CreateTaskContext(std::move(where)), Request,
                std::move(yieldCtx));
        }
    }

    void HandleResponse(NYmodRateSrv::TResponse response) {
        std::ostringstream logStream;
        for (auto it = response.begin(); it != response.end();) {
            const auto& [id, counter] = *it;
            auto contextIt = CounterContexts.find(id);
            if (contextIt == CounterContexts.end()) {
                ++it;
                continue;
            }
            const auto& context = contextIt->second;
            if (WriteLogForEveryCounter) {
                WriteCounterLog(logStream, counter, context.Target, context.Limit->GetLimitName());
            }
            if (!context.Limit->IsDryRun()) {
                if constexpr (TRequestTag::value) {
                    if (counter.State == NYmodRateSrv::ECounterState::Exceeded) {
                        ErrorCode = EError::Reject;
                    }
                } else {
                    if (counter.State == NYmodRateSrv::ECounterState::Ok && counter.Available <= 0) {
                        ErrorCode = EError::Reject;
                    }
                }
            }

            it = response.erase(it);
            CounterContexts.erase(contextIt);
        }
        if (WriteLogForEveryCounter) {
            if (auto logStr = logStream.str(); !logStr.empty()) {
                NWLOG_L(notice, RATESRV, "Limits info: " + logStr);
            }
        }

        for (const auto& [id, counter] : response) {
            const auto idCopy{id};
            NWLOG_L(error, RATESRV, "Unexpected counter in response, id " + idCopy);
        }
        for (const auto& [id, context] : CounterContexts) {
            const auto limitName{context.Limit->GetLimitName()};
            const auto target{context.Target};
            NWLOG_L(error, RATESRV, "No data for limit [" + limitName + "] of [" + target + "]");
        }
    }

    void WriteCounterLog(
        std::ostringstream& logStream,
        const NYmodRateSrv::TCounter& counter,
        const std::string& target,
        const std::string& limitName)
    {
        std::string description;
        bool isError = false;

        switch (counter.State) {
            case NYmodRateSrv::ECounterState::Ok:
                description = "ok";
                break;
            case NYmodRateSrv::ECounterState::Exceeded:
                description = "exceeded";
                break;
            case NYmodRateSrv::ECounterState::Error:
                description = "error (" + counter.Description + ")";
                isError = true;
                break;
            case NYmodRateSrv::ECounterState::Unknown:
                description = "unknown";
                isError = true;
                break;
        }

        logStream << "[" << limitName << " of " << target << ": " << description;
        if (!isError) {
            logStream << "; " << counter.Current << ";" << counter.Available;
        }
        logStream << "] ";
    }

    void WriteFinalLog(boost::system::error_code ec) {
        if (!ec && WriteFinalLogOnlyOnError) {
            return;
        }

        using namespace std::string_literals;
        if (ec) {
            NWLOG_L(error, RATESRV, "failed to "s + (TRequestTag::value ? "update" : "check") + " limits: " +
                ec.message());
        } else {
            NWLOG_L(notice, RATESRV, "limits "s + (TRequestTag::value ? "updated" : "checked") +
                " successfully");
        }
    }

private:
    TContextPtr Context;
    bool WriteLogForEveryCounter;
    bool WriteFinalLogOnlyOnError;
    TCallback Callback;

    std::shared_ptr<NYmodRateSrv::TClientHelper> ClientHelper;
    TErrorCode ErrorCode;

    struct TCounterContext {
        std::string Target;
        NYmodRateSrv::TLimitPtr Limit;
    };

    std::unordered_map<std::string, TCounterContext> CounterContexts;
    NYmodRateSrv::TRequest Request;
};

} // namespace NNwSmtp::NRateSrv
