#include "handler.h"

#include <balancer/kernel/helpers/default_instance.h>
#include <balancer/kernel/helpers/reqid.h>
#include <balancer/kernel/helpers/yuid.h>
#include <balancer/kernel/http/parser/common_headers.h>
#include <balancer/kernel/http/parser/header_validation.h>
#include <balancer/kernel/http/parser/http.h>
#include <balancer/kernel/log/errorlog.h>
#include <balancer/kernel/memory/chunks.h>
#include <balancer/kernel/module/iface.h>
#include <balancer/kernel/net/socket.h>
#include <balancer/kernel/net/sockops.h>
#include <balancer/kernel/regexp/regexp_pire.h>
#include <balancer/kernel/ssl/sslja3.h>

#include <library/cpp/containers/stack_array/stack_array.h>
#include <library/cpp/regex/pire/regexp.h>

#include <util/generic/intrlist.h>
#include <util/generic/string.h>
#include <util/stream/format.h>
#include <util/system/hostname.h>

#include <alloca.h>

namespace NModHeaders {
    using namespace NSrvKernel;

    namespace {
        // Value of the first "host" header of the current request, or an empty string
        // when the descriptor carries no request at all.
        TString GetHost(const TConnDescr& descr) noexcept {
            // ApplyFunc's Url case treats a null Request as possible; do the same here
            // instead of dereferencing a null pointer.
            if (!descr.Request) {
                return {};
            }
            return TString{descr.Request->Headers().GetFirstValue("host")};
        }

        // "<host><url>" for the current request, with a '/' inserted between the two
        // parts when the URL does not already start with one. Empty whenever GetHost()
        // returns an empty string (which includes the "no request" case, so the
        // Request dereference below only happens for a non-null Request).
        TString GetLocation(const TConnDescr& descr) noexcept {
            if (auto host = GetHost(descr)) {
                auto url = ToString(descr.Request->RequestLine().GetURL());
                if (url && url[0] != '/') {
                    host.append('/');
                }
                return host + url;
            }

            return {};
        }

        TString GetSslClientHandshakeInfo(const TConnDescr& descr) noexcept {
            // Comma-separated summary of the client TLS handshake state:
            // "handshake-time=<d>, [no-]tls-tickets, handshake-ts=<t>, cipher-id=<c>, protocol-id=<p>".
            const auto& ssl = descr.Ssl();

            TStringBuilder summary;
            summary.reserve(60);
            summary << "handshake-time=" << ssl.HandshakeDuration;
            summary << ", " << (ssl.HandshakeUsedTlsTickets ? "tls-tickets" : "no-tls-tickets");
            summary << ", handshake-ts=" << ssl.HandshakeCompleted.TimeT();
            summary << ", cipher-id=" << ssl.CurrentCipherId;
            summary << ", protocol-id=" << ssl.CurrentProtocolId;
            return summary;
        }

        TString GetSslEarlyData(const TConnDescr& descr) noexcept {
            // "1" when the SSL state's EarlyData flag is set, "0" otherwise.
            if (descr.Ssl().EarlyData) {
                return "1";
            }
            return "0";
        }

        TString GetExpStatic(const TConnDescr& descr) noexcept {
            // Serializes the connection's experiments as "<exp>,0,<bucket>" entries
            // joined with ';', or "no" when there are none.
            const auto& experiments = descr.Properties->Parent.Experiments;
            if (experiments.empty()) {
                return "no";
            }

            TStringBuilder result;
            result.reserve(15);

            TStringBuf separator; // empty before the first entry, ";" afterwards
            for (const auto& [expName, bucket] : experiments) {
                result << separator << expName << ",0," << bucket;
                separator = ";";
            }

            return result;
        }

        TString GetTcpInfoImpl(const TConnDescr& descr) noexcept {
            // Formats selected TCP_INFO fields of the client socket, or "unknown" when
            // there is no socket or the query fails.
            auto& socket = descr.Properties->Parent.Socket;
            if (!socket) {
                return "unknown";
            }

            TTcpInfo tcpInfo;
            Y_TRY(TError, err) {
                return GetTcpInfo(*socket).AssignTo(tcpInfo);
            } Y_CATCH {
                return "unknown";
            };

            TStringBuilder out;
            out << "v=2" // See SERP-51544
                << "; rtt=" << tcpInfo.Rtt
                << "; rttvar=" << tcpInfo.RttVar
                << "; snd_cwnd=" << tcpInfo.SndCwnd
                << "; total_retrans=" << tcpInfo.TotalRetrans;
            return out;
        }

        TString GenTime(const TCreateHeaderParameters& parameters) noexcept {
            // Current time, optionally shifted by a configured signed duration,
            // rendered as an RFC 822 date string.
            TInstant moment = TInstant::Now();

            if (const auto* shift = std::get_if<TSignedDuration>(&parameters)) {
                if (shift->Negative) {
                    moment -= shift->Value;
                } else {
                    moment += shift->Value;
                }
            }

            return moment.ToRfc822String();
        }

        ui64 RandomNonZeroUI64() {
            // Assumes RandomNumber<ui64>(n) returns a value in [0, n) (util convention —
            // confirm); adding one then yields a value in [1, max] without overflow.
            constexpr ui64 upperBound = std::numeric_limits<ui64>::max();
            return 1 + RandomNumber<ui64>(upperBound);
        }

        TString GenTraceparent() noexcept {
            // Builds a W3C Trace Context "traceparent" value:
            //   <version> "-" <trace-id, 32 hex> "-" <parent-id, 16 hex> "-" <flags>
            // https://www.w3.org/TR/trace-context/#traceparent-header
            // Always emitted with the "sampled" flag ("01"); "00" would mean unsampled.
            constexpr TStringBuf version = "00";
            constexpr TStringBuf sampledFlags = "01";

            // The trace-id must not be all zeroes: when the high half is zero, the low
            // half is drawn from the non-zero generator.
            const ui64 traceIdHigh = RandomNumber<ui64>();
            const ui64 traceIdLow = (traceIdHigh != 0) ? RandomNumber<ui64>() : RandomNonZeroUI64();
            // The parent-id must not be all zeroes either.
            const ui64 parentId = RandomNonZeroUI64();

            TStringBuilder traceparent;
            traceparent.reserve(55);
            traceparent << version
                        << '-'
                        << Hex(traceIdHigh, ENumberFormatFlag::HF_FULL)
                        << Hex(traceIdLow, ENumberFormatFlag::HF_FULL)
                        << '-'
                        << Hex(parentId, ENumberFormatFlag::HF_FULL)
                        << '-'
                        << sampledFlags;
            // The spec requires lower-case hex; normalize whatever case Hex() produced.
            traceparent.to_lower();
            return traceparent;
        }

// Yields descr.Ssl().ClientCertData->prop when client certificate data is present,
// otherwise Nothing(). Used by ApplyFunc for the SslClientCert* header functions.
#define GetSslProp(descr, prop) \
    (descr.Ssl().ClientCertData ? descr.Ssl().ClientCertData->prop : Nothing())

        TString ApplyFunc(EHeaderFunc func, const TConnDescr& descr, const TCreateHeaderParameters& parameters) noexcept {
            // Computes the value of a header produced by the configured header function.
            // Only a reference to the Properties member is taken here; it is
            // dereferenced solely by the cases that need it.
            const auto& props = descr.Properties;

            switch (func) {
            // Identifier generators.
            case EHeaderFunc::ReqId:
                return GenReqId(props->Start, props->Random);
            case EHeaderFunc::MarketReqId:
                return GenMarketReqId(props->Start, props->Random);
            case EHeaderFunc::AdfoxReqId:
                return GenAdfoxReqId(props->Random);
            case EHeaderFunc::SearchReqId:
                return GenSearchReqId(props->Start, props->Random, HostName());
            case EHeaderFunc::Yuid:
                return GetYuid(props->Random, props->Start);
            case EHeaderFunc::Traceparent:
                return GenTraceparent();

            // Addresses and host names.
            case EHeaderFunc::RealIp:
                return descr.RemoteAddrStr();
            case EHeaderFunc::RealPort:
                return ::ToString(props->Parent.RemoteAddress->Port());
            case EHeaderFunc::LocalIp:
                return descr.LocalAddrStr();
            case EHeaderFunc::LocalPort:
                return ::ToString(props->Parent.LocalAddress->Port());
            case EHeaderFunc::Hostname:
                return HostName();

            // Request-derived values.
            case EHeaderFunc::Url:
                return descr.Request ? ToString(descr.Request->RequestLine().GetURL()) : TString();
            case EHeaderFunc::Location:
                if (TString location = GetLocation(descr)) {
                    return location;
                }
                return "undefined";
            case EHeaderFunc::Host:
                if (TString host = GetHost(descr)) {
                    return host;
                }
                return "undefined";
            case EHeaderFunc::Scheme:
                return props->UserConnIsSsl ? "https" : "http";

            // TLS state.
            case EHeaderFunc::Proto:
                return ::ToString(descr.Ssl().NextProto);
            case EHeaderFunc::SslClientCertCn:
                return GetSslProp(descr, CN).GetOrElse("undefined");
            case EHeaderFunc::SslClientCertSubject:
                return GetSslProp(descr, Subject).GetOrElse("undefined");
            case EHeaderFunc::SslClientCertVerifyResult: {
                const auto verifyResult = GetSslProp(descr, VerifyResult);
                return verifyResult ? ::ToString(*verifyResult) : TString("undefined");
            }
            case EHeaderFunc::SslClientCertSerialNumber:
                return GetSslProp(descr, SerialNumber).GetOrElse("undefined");
            case EHeaderFunc::SslClientHandshakeInfo:
                return GetSslClientHandshakeInfo(descr);
            case EHeaderFunc::SslEarlyData:
                return GetSslEarlyData(descr);
            case EHeaderFunc::SslTicketName:
                return descr.Ssl().TicketName;
            case EHeaderFunc::SslTicketIV:
                return descr.Ssl().TicketIV;
            case EHeaderFunc::Ja3:
                return TSslJa3(descr.Ssl().Ja3Data).Ja3HeaderValue();
            case EHeaderFunc::Ja4:
                return TSslJa3(descr.Ssl().Ja3Data).Ja4HeaderValue();

            // Miscellaneous connection data.
            case EHeaderFunc::ExpStatic:
                return GetExpStatic(descr);
            case EHeaderFunc::TcpInfo:
                return GetTcpInfoImpl(descr);
            case EHeaderFunc::P0f:
                return props->Parent.P0f.GetOrElse("");
            case EHeaderFunc::Time:
                return GenTime(parameters);
            }
        }
    }


    TString TCreateHeader::GenValue(const NSrvKernel::TConnDescr& descrRequest, const THashMap<TString, TSharedFileReReader>& files) const noexcept {
        // Produces the header value according to the alternative held by Value_:
        // a literal string, a header function, a common function, or the contents of
        // a file. A missing or invalid file yields an empty value.
        return std::visit([&](auto&& val) -> TString {
            using TVal = std::decay_t<decltype(val)>;
            if constexpr (std::is_same_v<TVal, TString>) {
                return val;
            } else if constexpr (std::is_same_v<TVal, EHeaderFunc>) {
                return ApplyFunc(val, descrRequest, Parameters_);
            } else if constexpr (std::is_same_v<TVal, ECommonFunc>) {
                return ApplyFunc(val, descrRequest);
            } else if constexpr (std::is_same_v<TVal, TValueFromFile>) {
                const auto reader = files.find(val.FileName_);
                if (reader == files.end()) {
                    return {};
                }

                TString fileValue = reader->second.Data().Data();
                if (CheckHeaderValue(fileValue)) {
                    return fileValue;
                }

                Y_HTTPD_LOG_IMPL(descrRequest.ErrorLog, TLOG_ERR, "headers", descrRequest,
                                 "invalid header value " << fileValue.Quote() << " for header \"" <<
                                 Name_ << "\" in file \"" << val.FileName_ << "\" will be ignored");
                return {};
            }
        }, Value_);
    }

    bool TCreateHeader::AllowEmpty() const noexcept {
        // Only values produced by a header function (EHeaderFunc) may be emitted empty.
        // Literal strings, common functions and file-backed values may not.
        return std::holds_alternative<EHeaderFunc>(Value_);
    }

    bool THeadersHandler::Delete(TString regexp) noexcept {
        // Merges the new pattern (case-insensitive) into the single delete automaton
        // and returns CheckRestrictedHeaderRegexp's verdict on the combined result.
        const TFsm pattern(regexp, TFsm::TOptions().SetCaseInsensitive(true));
        Delete_ = TFsm::Glue(Delete_, pattern);
        return CheckRestrictedHeaderRegexp(Delete_);
    }

    void THeadersHandler::CopyValue(TString src, TString dst, bool weak) noexcept {
        // Registers a copy of header `src` into header `dst`. The forward (Copy_) and
        // reverse (CopyReverse_) entries share one index, which Apply() uses to address
        // its per-call arrays of size Copy_.size().
        TString srcLowerCase = to_lower(src);
        TString dstLowerCase = to_lower(dst);
        const ui32 idx = Copy_.size();
        Copy_.emplace(srcLowerCase, TFuncIdx<TCopyHeader>({src, dst, weak}, idx));

        // If `src` (case-insensitively) was already registered, the emplace above
        // inserted nothing and `idx` names no Copy_ entry. Registering the reverse
        // entry anyway would make Apply() index its Copy_-sized arrays out of bounds.
        if (Copy_.size() > idx) {
            CopyReverse_.emplace(dstLowerCase, TFuncIdx<TCopyReverseHeader>({weak}, idx));
        }
    }

    void THeadersHandler::Create(TCreateHeader func, bool weak) noexcept {
        // Registers a create rule in the weak or strong table under its lower-cased
        // name. The rule's index is the table size at registration time.
        auto addTo = [&](auto& table) {
            table.emplace(to_lower(func.GetName()), TFuncIdx<TCreateHeader>(func, table.size()));
        };

        if (weak) {
            addTo(CreateFuncsWeak_);
        } else {
            addTo(CreateFuncs_);
        }
    }

    void THeadersHandler::Append(TCreateHeader func, bool weak) noexcept {
        // Registers an append rule in the weak or strong table under its lower-cased
        // name. The rule's index is the table size at registration time.
        const TString key = to_lower(func.GetName());
        auto registerIn = [&](auto& table) {
            table.emplace(key, TFuncIdx<TCreateHeader>(func, table.size()));
        };

        if (!weak) {
            registerIn(AppendFuncs_);
        } else {
            registerIn(AppendFuncsWeak_);
        }
    }

    void THeadersHandler::AddFile(TValueFromFile fromFile) {
        // Remembers a file-backed header value; SetupFileReaders() opens readers for these.
        FromFiles_.emplace_back(std::move(fromFile));
    }

    void THeadersHandler::SetupFileReaders(THashMap<TString, TSharedFileReReader>& readers, IWorkerCtl* process) {
        // Creates (or replaces) one shared re-reader per registered file, keyed by file
        // name. The 1-second argument is presumably the re-read interval — confirm in
        // the SharedFiles() API.
        const TDuration rereadPeriod = TDuration::Seconds(1);
        for (const auto& source : FromFiles_) {
            readers[source.FileName_] = process->SharedFiles()->FileReReader(source.FileName_, rereadPeriod);
        }
    }

    void THeadersHandler::Apply(THeaders& headers, const TConnDescr& prevDescr, const THashMap<TString, TSharedFileReReader>& files) const noexcept {
        // Rewrites `headers` according to the configured rules (delete, copy, create,
        // create-weak, append, append-weak). The result is assembled in `newHeaders`
        // and swapped into `headers` at the end.
        //
        // Every rule carries an Index assigned at registration time. The stack arrays
        // below are indexed by it and record, for this call only, which rules were
        // already handled or must be suppressed while the existing headers are scanned.
        THeaders newHeaders = headers.UseBuffer();

        // Copy operations suppressed because their weak destination is already present.
        NStackArray::TStackArray<bool> excludedCopyOps(ALLOC_ON_STACK(bool, Copy_.size()));
        Fill(excludedCopyOps.begin(), excludedCopyOps.end(), false);
        // First value of each copy source found in the input; nullptr means the source
        // is absent (or has no values) and the copy is skipped.
        NStackArray::TStackArray<const TStringStorage*> sourceCopyPtr(ALLOC_ON_STACK(const TStringStorage*, Copy_.size()));
        Fill(sourceCopyPtr.begin(), sourceCopyPtr.end(), nullptr);
        // Strong create operations already done in place, or blocked by a strong copy destination.
        NStackArray::TStackArray<bool> excludedCreateFuncOps(ALLOC_ON_STACK(bool, CreateFuncs_.size()));
        Fill(excludedCreateFuncOps.begin(), excludedCreateFuncOps.end(), false);
        // Weak create operations whose header is already present in the input.
        NStackArray::TStackArray<bool> excludedCreateWeakFuncOps(ALLOC_ON_STACK(bool, CreateFuncsWeak_.size()));
        Fill(excludedCreateWeakFuncOps.begin(), excludedCreateWeakFuncOps.end(), false);
        // Strong append operations already applied to an existing header.
        NStackArray::TStackArray<bool> excludedAppendFuncOps(ALLOC_ON_STACK(bool, AppendFuncs_.size()));
        Fill(excludedAppendFuncOps.begin(), excludedAppendFuncOps.end(), false);

        for (const auto& header: headers) {
            TString headerLowerCase = TString(header.first.AsStringBuf());
            headerLowerCase.to_lower();
            NRegExp::TMatcher matcher(Delete_);
            bool skipCopyToNewHeaders = false;
            bool skipCreateHeader = false;

            // If the name matches the delete expression, neither copy this header
            // through nor append to it.
            if (Match(matcher, headerLowerCase).Final()) {
                skipCopyToNewHeaders = true;
            }

            // Remember the first value of a copy source. Guard against an empty value
            // list (the check further below shows it can happen); indexing it would be
            // out of bounds. Create and append operations still apply to the source.
            const auto copyIt = Copy_.find(headerLowerCase);
            if (copyIt != Copy_.end() && !header.second.empty()) {
                sourceCopyPtr[copyIt->second.Index] = &header.second[0];
            }

            // An existing header that is a copy destination: a weak copy is dropped in
            // favour of the existing header; a strong copy suppresses every other
            // operation on this header.
            const auto copyReverseIt = CopyReverse_.find(headerLowerCase);
            if (copyReverseIt != CopyReverse_.end()) {
                if (copyReverseIt->second.Func.Weak) {
                    const auto copyIdx = copyReverseIt->second.Index;
                    // The array is sized by Copy_; never write past it even if the
                    // reverse table holds an index with no forward entry.
                    if (static_cast<size_t>(copyIdx) < Copy_.size()) {
                        excludedCopyOps[copyIdx] = true;
                    }
                } else {
                    skipCopyToNewHeaders = true;
                    skipCreateHeader = true;
                }
            }

            // A strong create on an existing header replaces it in place (the original
            // value is not copied through) and is then not repeated after the loop.
            const auto createFuncIt = CreateFuncs_.find(headerLowerCase);
            if (createFuncIt != CreateFuncs_.end()) {
                if (skipCreateHeader) {
                    excludedCreateFuncOps[createFuncIt->second.Index] = true;
                } else {
                    TString value = createFuncIt->second.Func.GenValue(prevDescr, files);
                    if (value || createFuncIt->second.Func.AllowEmpty()) {
                        newHeaders.Add(
                            createFuncIt->second.Func.GetName(),
                            std::move(value)
                        );
                        skipCopyToNewHeaders = true;
                        excludedCreateFuncOps[createFuncIt->second.Index] = true; // already done
                    }
                }
            }

            // A weak create is dropped whenever its header is present in the input.
            // Note: weak create does not take the delete expression into account.
            const auto createFuncWeakIt = CreateFuncsWeak_.find(headerLowerCase);
            if (createFuncWeakIt != CreateFuncsWeak_.end()) {
                excludedCreateWeakFuncOps[createFuncWeakIt->second.Index] = true;
            }

            if (!skipCopyToNewHeaders && header.second) {
                const auto appendFuncIt = AppendFuncs_.find(headerLowerCase);
                const auto appendFuncWeakIt = AppendFuncsWeak_.find(headerLowerCase);

                TVector<TStringStorage> values = std::move(header.second);

                // When both a strong and a weak append exist for this name, only the
                // strong one is applied. A strong append applied here is not repeated
                // after the loop.
                if (appendFuncIt != AppendFuncs_.end()) {
                    values[0] = TStringStorage(TString(TStringBuilder{}
                        << values[0].AsStringBuf() << appendFuncIt->second.Func.GenDelimiter()
                        << appendFuncIt->second.Func.GenValue(prevDescr, files)));

                    excludedAppendFuncOps[appendFuncIt->second.Index] = true;
                } else if (appendFuncWeakIt != AppendFuncsWeak_.end()) {
                    values[0] = TStringStorage(TString(TStringBuilder{}
                        << values[0].AsStringBuf() << appendFuncWeakIt->second.Func.GenDelimiter()
                        << appendFuncWeakIt->second.Func.GenValue(prevDescr, files)));
                }

                newHeaders.Add(header.first, std::move(values));
            }
        }

        // A strong append whose header was not present (or was not copied through)
        // behaves like a strong create. Copy destinations are left to the copy step.
        for (const auto&[header, op] : AppendFuncs_) {
            if (excludedAppendFuncOps[op.Index]) {
                continue;
            }
            if (CopyReverse_.find(header) != CopyReverse_.end()) {
                continue;
            }

            newHeaders.Add(
                op.Func.GetName(),
                op.Func.GenValue(prevDescr, files)
            );
        }

        // Remaining strong creates. If a header with the same name was not replaced in
        // the loop above, the created one is added as a separate header.
        for (const auto&[header, op] : CreateFuncs_) {
            if (excludedCreateFuncOps[op.Index]) {
                continue;
            }
            if (CopyReverse_.find(header) != CopyReverse_.end()) {
                continue;
            }
            TString value = op.Func.GenValue(prevDescr, files);
            if (value || op.Func.AllowEmpty()) {
                newHeaders.Add(
                    op.Func.GetName(),
                    std::move(value)
                );
            }
        }

        // Weak creates: same as strong ones, but only for headers absent from the input.
        for (const auto&[header, op] : CreateFuncsWeak_) {
            if (excludedCreateWeakFuncOps[op.Index]) {
                continue;
            }
            if (CopyReverse_.find(header) != CopyReverse_.end()) {
                continue;
            }
            TString value = op.Func.GenValue(prevDescr, files);
            if (value || op.Func.AllowEmpty()) {
                newHeaders.Add(
                    op.Func.GetName(),
                    std::move(value)
                );
            }
        }

        // Copy operations whose source was seen and which were not suppressed.
        for (const auto&[header, op]: Copy_) {
            if (excludedCopyOps[op.Index] || sourceCopyPtr[op.Index] == nullptr) {
                continue;
            }
            newHeaders.Add(op.Func.Dst, *sourceCopyPtr[op.Index]);
        }

        headers.Swap(newHeaders);
    }

    bool MatchHeaderPrefix(TStringBuf header, TStringBuf prefix) {
        // True when `header` begins with `prefix`; an empty prefix always matches.
        if (prefix.size() > header.size()) {
            return false;
        }
        return header.substr(0, prefix.size()) == prefix;
    }

    TCreateHeader ParseHeaderFunc(const TString& key, TStringBuf value, TStringBuf delimiter) {
        // Resolution order: EHeaderFunc name, then ECommonFunc name, then the
        // parameterised "time..." form. Anything else is a config error.
        try {
            return {key, ::FromString<EHeaderFunc>(value), {}, delimiter};
        } catch (const yexception&) {
            // Not an EHeaderFunc name; try the next candidate.
        }

        try {
            return {key, ::FromString<ECommonFunc>(value), {}, delimiter};
        } catch (const yexception&) {
            // Not an ECommonFunc name either.
        }

        if (!MatchHeaderPrefix(value, "time")) {
            ythrow TConfigParseError() << "unknown func value: " << TString{value}.Quote();
        }
        return SubParserHeaderFuncTime(key, value, delimiter);
    }

    TCreateHeader SubParserHeaderFuncTime(const TString& key, TStringBuf value, TStringBuf delimiter) {
        // Parses "time:<duration>", "time:+<duration>" or "time:-<duration>" into a Time
        // header function shifted by that signed duration. The caller has already
        // checked the "time" prefix. Every malformed input, including an unparsable
        // duration, is reported as TConfigParseError.
        if (value.size() < 6 || value[4] != ':') {
            ythrow TConfigParseError() << "incorrect parameter for time: " << TString{value}.Quote();
        }

        const bool hasSign = (value[5] == '-' || value[5] == '+');
        if (hasSign && value.size() < 7) {
            ythrow TConfigParseError() << "incorrect parameter for time: " << TString{value}.Quote();
        }
        const bool isNegative = hasSign && value[5] == '-';

        // FromString would throw its own (non-config) exception type on bad input;
        // parse without throwing so the error is reported consistently.
        TDuration time;
        if (!TryFromString<TDuration>(value.substr(hasSign ? 6 : 5), time)) {
            ythrow TConfigParseError() << "incorrect parameter for time: " << TString{value}.Quote();
        }
        return {key, EHeaderFunc::Time, TSignedDuration(time, isNegative), delimiter };
    }
}
