#include "cycles_checker.h"

#include <balancer/kernel/http/parser/http.h>

#include <util/string/split.h>


using namespace NSrvKernel;

namespace NModHttp {

// Name of the request header that TCyclesChecker reads and rewrites.
// The code in this file emits the value as "<uid>-<counter>" entries joined by CyclesIdDelimiter.
const TStringBuf CyclesHeader{"X-Yandex-Internal-Cycles"};
// Separates one "<uid>-<counter>" entry from the next inside the header value.
constexpr char CyclesIdDelimiter = ';';
// Separates the uid from its counter inside a single entry.
constexpr char CyclesCounterDelimiter = '-';


// Fallback limits substituted when the corresponding constructor argument is zero.
constexpr uint64_t FallbackMaxCycles = 1;
constexpr uint64_t FallbackMaxLength = 4096;

// Builds a checker for the balancer identified by configUid.
//   configUid - identifier written into the cycles header; checked to be non-empty via Y_VERIFY.
//   maxCycles - largest counter value accepted by Check(); zero is replaced by 1.
//   maxLength - header length above which Check() reports a length alert; zero is replaced by 4096.
TCyclesChecker::TCyclesChecker(const TString& configUid, uint64_t maxCycles, uint64_t maxLength)
    : ConfigUid_(configUid)
    , MaxCycles_(maxCycles == 0 ? FallbackMaxCycles : maxCycles)
    , MaxLength_(maxLength == 0 ? FallbackMaxLength : maxLength)
{
    Y_VERIFY(!ConfigUid_.empty());
}

// Inspects the cycles header of the request, bumps this balancer's counter and
// writes the updated header back.
//
// The header value is a list of "<uid>-<counter>" entries joined by CyclesIdDelimiter.
//   - No header: a fresh "<uid>-1" header is set.
//   - Header without an entry for our uid: ";<uid>-1" is appended.
//   - Header with our entry: the counter is incremented in place; if the incremented
//     value would exceed MaxCycles_, an error is returned and the header is left untouched.
//   - Malformed data (our uid present but never as a well-delimited entry, a zero or
//     unparsable counter, or any std::exception while processing): formatAlert is filled
//     and the header is reset to a fresh "<uid>-1".
//
// Parameters:
//   request     - request whose headers are read and rewritten.
//   lengthAlert - filled when the rewritten header is longer than MaxLength_
//                 (the header is still written).
//   formatAlert - filled with a description when the incoming header is malformed.
//
// Returns an empty TError on success (including the reset paths), or an error
// describing the detected request cycle.
TError TCyclesChecker::Check(TRequest& request, TString& lengthAlert, TString& formatAlert) const noexcept {
    auto headerValue = request.Headers().GetFirstValue(CyclesHeader);
    if (!headerValue) { // in first balancer set on route
        return SetNewHeader(request);
    }

    try {
        TString newHeaderValue(headerValue);
        TString uid = MakeUid();
        const uint64_t uidLen = uid.length();
        Y_ASSERT(uidLen);

        // Our uid may also occur as a part of another balancer's uid
        // (e.g. "balancer" inside "balancer2-1;balancer-3"), so do not stop at the
        // first substring match: keep scanning until an occurrence is found that is
        // preceded by the entry delimiter (or starts the header) and is followed by
        // the counter delimiter.
        bool uidSeen = false;
        auto uidPos = headerValue.find(uid);
        while (uidPos != std::string::npos) {
            uidSeen = true;
            const bool firstOk = !uidPos || headerValue[uidPos - 1] == CyclesIdDelimiter;
            const bool lastOk = (uidPos + uidLen) < headerValue.length() && headerValue[uidPos + uidLen] == CyclesCounterDelimiter;
            if (firstOk && lastOk) {
                break;
            }
            uidPos = headerValue.find(uid, uidPos + 1);
        }

        if (uidPos == std::string::npos) {
            if (uidSeen) {
                // The uid text is present but never as a well-formed entry.
                formatAlert = uid + ": invalid data: " + headerValue;
                return SetNewHeader(request);
            }
            // First visit of this balancer: append "<delimiter><uid>-1".
            AddCounter(uid, 1);
            newHeaderValue += CyclesIdDelimiter;
            newHeaderValue += uid;
        } else {
            // The counter occupies everything between the counter delimiter and
            // the next entry delimiter (or the end of the header).
            const auto counterPos = uidPos + uidLen + 1;
            const auto counterEndPos = headerValue.find(CyclesIdDelimiter, uidPos + uidLen);
            const auto counterLen = (counterEndPos != std::string::npos ? counterEndPos : headerValue.length()) - counterPos;
            const uint64_t counter = FromString(headerValue.substr(counterPos, counterLen));
            if (!counter) {
                formatAlert = uid + ": invalid counter value for " + headerValue;
                return SetNewHeader(request);
            }
            // Compare before incrementing: "counter + 1 > MaxCycles_" without the
            // wrap-around to zero that a maximal counter value would cause.
            if (counter >= MaxCycles_) {
                return Y_MAKE_ERROR(yexception{} << "Request cycle detected in " << headerValue << " for " << uid);
            }
            // newHeaderValue is a copy of headerValue, so the offsets are valid for it.
            newHeaderValue.replace(counterPos, counterLen, ToString(counter + 1));
        }

        if (newHeaderValue.length() > MaxLength_) {
            lengthAlert = "Header length: " + ToString(newHeaderValue.length()) + ", max " + ToString(MaxLength_);
        }

        request.Headers().Replace(CyclesHeader, newHeaderValue);
    } catch (const std::exception& e) {
        // Parsing problems (e.g. a non-numeric counter) are reported as a format
        // alert and the header is started over.
        formatAlert = e.what();
        return SetNewHeader(request);
    }

    return {};
}

// Overwrites the cycles header of the request with a fresh "<uid>-1" value,
// discarding whatever the header contained before. Always returns an empty TError.
TError TCyclesChecker::SetNewHeader(TRequest& request) const noexcept {
    TString initialValue = MakeUid();
    AddCounter(initialValue, 1); // first pass through this balancer

    auto&& headers = request.Headers();
    headers.Replace(CyclesHeader, initialValue);
    return {};
}

// Returns the identifier used for this checker's entry in the cycles header.
// Currently this is a copy of the config uid passed to the constructor.
TString TCyclesChecker::MakeUid() const {
    return ConfigUid_;
}

// Turns an id into a header entry by appending the counter delimiter and the
// decimal counter in place, e.g. "uid" with counter 1 becomes "uid-1".
void TCyclesChecker::AddCounter(TString& id, uint64_t counter) const {
    const TString counterText = ToString(counter);
    id += CyclesCounterDelimiter;
    id += counterText;
}

}

