#include "processor.h"

#include <passport/infra/libs/cpp/unistat/builder.h>
#include <passport/infra/libs/cpp/utils/thread_local_id.h>
#include <passport/infra/libs/cpp/utils/log/global.h>

#include <util/generic/yexception.h>
#include <util/string/cast.h>
#include <util/system/compiler.h>
#include <util/system/thread.h>

namespace NPassport::NYsa {
    TRequest::TRequest()
        : CreationTime(TInstant::Now())
        , RequestId_(NUtils::GetThreadLocalRequestId())
    {
    }

    TRequest::TRequest(const TString& requestId)
        : CreationTime(TInstant::Now())
        , RequestId_(requestId)
    {
    }

    // Names of the request headers that carry the connection endpoints;
    // they are read by TRequest::InitLazy().
    namespace {
        const TString USER_IP = "X-Real-Ip";
        const TString USER_PORT = "X-Real-Port";
        const TString SERVER_IP = "X-Server-Ip";
        const TString SERVER_PORT = "X-Server-Port";
    }

    // Lazily fills the cached endpoint fields (UserIp_, UserPort_, ServerIp_, ServerPort_)
    // from Headers. ServerPort_ is assigned last, so a truthy ServerPort_ means the cache is
    // already populated and the call does nothing.
    // Throws (Y_ENSURE) when a header is missing or TryIntFromString<10> rejects a port
    // value; fields assigned before the failing one keep their new values.
    // NOTE(review): if ServerPort_ is a plain integer, a server port of 0 looks the same as
    // "not parsed yet", so such a request re-parses its headers on every call — confirm.
    void TRequest::InitLazy() const {
        if (!ServerPort_) {
            const auto header = [this](const TString& key) -> const TString& {
                const auto found = Headers.find(key);
                Y_ENSURE(found != Headers.end(), "missing header: " << key);
                return found->second;
            };

            const auto port = [&header](const TString& key) -> ui16 {
                const TString& raw = header(key);
                ui16 parsed = 0;
                Y_ENSURE(TryIntFromString<10>(raw, parsed),
                         "invalid port (" << key << "): '" << raw << "'");
                return parsed;
            };

            UserIp_ = header(USER_IP);
            UserPort_ = port(USER_PORT);
            ServerIp_ = header(SERVER_IP);
            ServerPort_ = port(SERVER_PORT);
        }
    }

    // Validates the p0f response stored in PofResponse_.
    // Returns false for P0F_STATUS_NOMATCH and true for P0F_STATUS_OK.
    // Throws (Y_ENSURE) when the magic is not P0F_FLOW_RESP_MAGIC or the status is
    // anything else.
    bool TFingerprint::CheckResponse() const {
        const auto& resp = PofResponse_;
        Y_ENSURE(resp.magic == P0F_FLOW_RESP_MAGIC, "Impossible 'magic': " << resp.magic);

        const bool matched = resp.status != P0F_STATUS_NOMATCH;
        if (matched) {
            Y_ENSURE(resp.status == P0F_STATUS_OK, "Impossible status: " << resp.status);
        }
        return matched;
    }

    // True when both the `in_tls` and `tls_client_hello_recvd` flags of the p0f flow
    // response are equal to 1.
    bool TFingerprint::IsConnectionSecure() const {
        const auto& flow = PofResponse_.f;
        const bool inTls = flow.in_tls == 1;
        return inTls && flow.tls_client_hello_recvd == 1;
    }

    // Overwrites every byte of the raw p0f response (PofResponse_) with `c`.
    void TFingerprint::ResetPofResponose(char c) {
        void* const raw = &PofResponse_;
        const size_t bytes = sizeof(PofResponse_);
        memset(raw, c, bytes);
    }

    // Creates a processor whose debt never grows past `debtLimit` entries (see Run())
    // and which performs every p0f lookup through `callback` (see MakeQuery()).
    // `callback` is taken by value and moved into place to avoid a second copy.
    TProcessor::TProcessor(size_t debtLimit, TCallbackPofQuery callback)
        : DebtLimit_(debtLimit)
        , CallbackForQuery_(std::move(callback))
    {
        // A processor without a query callback cannot answer anything.
        Y_VERIFY(CallbackForQuery_);
        // Pre-size the request queue to avoid reallocations while it is being filled.
        Requests_.reserve(10000);
        TLog::Info() << "Processor succesfully constructed";
    }

    // Processes one batch. It first retries the requests kept in Debt_, then runs the
    // requests queued in Requests_; Requests_ is cleared afterwards.
    //  - A debt request is returned once RunOne() yields a returnable response, or once
    //    the request is older than maxDebtAge relative to `now`. Otherwise it stays in debt.
    //  - A new request with a non-returnable response is moved to Debt_ while the debt
    //    holds fewer than DebtLimit_ entries. Otherwise it is returned as is.
    // The resulting debt size is stored in DebtSize_.
    TResponses TProcessor::Run(TInstant now) {
        // Called from a single thread - the main event loop,
        // so it is safe to call p0f functions here.
        struct TThreadNamer {
            TThreadNamer() {
                TThread::SetCurrentThreadName("ysa_proc");
            }
        } thread_local namer;

        // Debt requests older than this are returned with whatever result they have.
        const TDuration maxDebtAge = TDuration::Seconds(15);

        TResponses resps;
        resps.reserve(Requests_.size() + Debt_.size());

        // Retries one debt request. Returns true when the request has been handed over to
        // `resps` and must be removed from Debt_.
        // The request-id guard lives only inside this call. It points into the debt
        // element, so it must be destroyed before the caller erases that element.
        auto finishDebtRequest = [&](TRequest& request) -> bool {
            NUtils::TRequestIdGuard g(&request.RequestId());
            TResponse resp = RunOne(request);

            if (resp.IsReturnable() || request.CreationTime + maxDebtAge < now) {
                TLog::Debug() << "TProcessor: " << (resp.Finger ? "successful" : "failed")
                              << " request from debt"
                              << ". src_ip=" << request.UserIp()
                              << ". src_port=" << request.UserPort()
                              << ". dst_ip=" << request.ServerIp()
                              << ". dst_port=" << request.ServerPort();

                ReturnResponse(std::move(request), std::move(resp), resps);
                return true;
            }
            return false;
        };

        size_t debtSize = Debt_.size();
        for (auto it = Debt_.begin(); it != Debt_.end();) {
            if (finishDebtRequest(*it)) {
                // erase() returns the next valid iterator for every standard sequence
                // container, unlike erase(it++), which is only safe for node-based ones.
                it = Debt_.erase(it);
                --debtSize;
            } else {
                ++it;
            }
        }

        for (TRequest& request : Requests_) {
            NUtils::TRequestIdGuard g(&request.RequestId());
            TResponse resp = RunOne(request);

            if (!resp.IsReturnable() && debtSize < DebtLimit_) {
                Debt_.push_back(std::move(request));
                ++debtSize;
                continue;
            }

            ReturnResponse(std::move(request), std::move(resp), resps);
        }

        DebtSize_ = debtSize;

        Requests_.clear();
        return resps;
    }

    // Returns the value currently held by the DebtSize_ signal
    // (Run() stores the remaining debt size there at the end of each pass).
    size_t TProcessor::GetDebtSize() const {
        const auto value = DebtSize_.GetValue();
        return value;
    }

    // Registers the DebtSize_ signal with the given unistat builder.
    void TProcessor::AddUnistat(NUnistat::TBuilder& builder) const {
        const auto& debtSignal = this->DebtSize_;
        builder.Add(debtSignal);
    }

    // Runs Proc() for a single request. Any std::exception it throws is stored in
    // TResponse::Error (prefixed with "Process: ") instead of being propagated.
    TResponse TProcessor::RunOne(const TRequest& request) const {
        TResponse response;
        try {
            Proc(request, response);
        } catch (const std::exception& err) {
            response.Error = TStringBuilder() << "Process: " << err.what();
        }
        return response;
    }

    // Stamps `resp` with the current time as FetchTime, attaches the originating
    // request to it and appends it to `resps`. Both `request` and `resp` are moved from.
    void TProcessor::ReturnResponse(TRequest&& request, TResponse&& resp, TResponses& resps) {
        TResponse& out = resp;
        out.FetchTime = TInstant::Now();
        out.Request = std::move(request);
        resps.emplace_back(std::move(out));
    }

    // Converts the request endpoints to addresses and stores the p0f lookup result in
    // `res.Finger`. Exceptions from the request accessors, Ip4Or6FromString or MakeQuery
    // propagate to the caller.
    void TProcessor::Proc(const TRequest& request, TResponse& res) const {
        const auto& userIp = request.UserIp();
        const auto& serverIp = request.ServerIp();
        res.Finger = MakeQuery(Ip4Or6FromString(userIp.c_str()), request.UserPort(),
                               Ip4Or6FromString(serverIp.c_str()), request.ServerPort());
    }

    // Builds the p0f flow query for the given endpoints and lets CallbackForQuery_ fill a
    // freshly allocated fingerprint. Returns that fingerprint, or an empty pointer when
    // CheckResponse() reports no match. Exceptions from PrepareQuery, the callback or
    // CheckResponse propagate.
    TFingerprintPtr TProcessor::MakeQuery(const TIp4Or6& userIp,
                                          ui16 userPort,
                                          const TIp4Or6& serverIp,
                                          ui16 serverPort) const {
        const p0f_api_query request = PrepareQuery(userIp, userPort, serverIp, serverPort);

        TFingerprintPtr fingerprint = std::make_shared<TFingerprint>();
        CallbackForQuery_(request, fingerprint->MutatePof());

        if (fingerprint->CheckResponse()) {
            return fingerprint;
        }
        return nullptr;
    }

    // Builds a p0f flow query (magic P0F_FLOW_QUERY_MAGIC) from the user and server
    // endpoints. Both addresses must belong to the same family, otherwise Y_ENSURE throws.
    // Address encoding: an IPv6 address copies sizeof(TIp6::Data) bytes; an IPv4 address
    // writes its 4 bytes followed by 12 zero bytes.
    p0f_api_query TProcessor::PrepareQuery(const TIp4Or6& userIp,
                                           ui16 userPort,
                                           const TIp4Or6& serverIp,
                                           ui16 serverPort) {
        const bool isV6 = std::holds_alternative<TIp6>(userIp);
        Y_ENSURE(isV6 == std::holds_alternative<TIp6>(serverIp),
                 "Ip addresses must both ipv4 or ipv6");

        auto writeAddr = [](const TIp4Or6& addr, ui8* dst) {
            if (const TIp6* v6 = std::get_if<TIp6>(&addr)) {
                memcpy(dst, v6->Data, sizeof(TIp6::Data));
                return;
            }
            const TIp4 v4 = std::get<TIp4>(addr);
            memcpy(dst, &v4, 4);
            memset(dst + 4, 0, 12);
        };

        p0f_api_query query{};
        query.magic = P0F_FLOW_QUERY_MAGIC;
        query.f.ip_ver = isV6 ? P0F_ADDR_IPV6 : P0F_ADDR_IPV4;

        writeAddr(userIp, query.f.src);
        query.f.sport = userPort;

        writeAddr(serverIp, query.f.dst);
        query.f.dport = serverPort;

        return query;
    }
}
