#include "sslja3.h"

#include <util/string/builder.h>
#include <util/string/cast.h>

namespace {
    // Bounds-tolerant element access: yields v[i] when i is a valid index and 0
    // otherwise, so callers can read a trailing byte of a truncated buffer safely.
    ui8 TryGetValue(const TVector<ui8>& v, size_t i) {
        if (i >= v.size()) {
            return 0;
        }
        return v[i];
    }

    // Exclusion set with no members: passing it means "serialize every value".
    const THashSet<ui8> EMPTY_GREASE8{};
}

TString NSrvKernel::TSslJa3::Ja3HeaderValue() const noexcept {
    TStringBuilder result;

    if (Y_UNLIKELY(Data_ == nullptr)) {
        // No ClientHello data captured: version "0" followed by four empty lists.
        result << "0,,,,"sv;
        return result;
    }

    // Field 1: legacy version taken from the client's hello message.
    result << ToString(Data_->LegacyVersion) << ",";

    // Field 2: cipher suites, read from the very first byte.
    SerializeVector16bit(Data_->Ciphers, 0, result);
    result << ",";

    // Field 3: extension identifiers offered by the client.
    SerializeVectorInt(Data_->ClientExtensions, result);
    result << ",";

    // Field 4: elliptic curves; the leading 2 bytes are skipped (presumably a length prefix).
    SerializeVector16bit(Data_->EllipticCurves, 2, result);
    result << ",";

    // Field 5: EC point formats; the leading byte is skipped (presumably a length prefix).
    SerializeVector8bit(Data_->EllipticCurvesPointFormats, 1, result);

    return result;
}

TString NSrvKernel::TSslJa3::Ja4HeaderValue() const noexcept {
    TStringBuilder result;

    if (Y_UNLIKELY(Data_ == nullptr)) {
        // No ClientHello data captured: six empty fields.
        result << ",,,,,"sv;
        return result;
    }

    // Field 1: signature algorithms (leading 2 bytes skipped).
    SerializeVector16bit(Data_->SignatureAlgorithms, 2, result);
    result << ",";

    // Field 2: signature algorithms for certificates (leading 2 bytes skipped).
    SerializeVector16bit(Data_->SignatureAlgorithmsCert, 2, result);
    result << ",";

    // Field 3: supported TLS versions (leading byte skipped).
    SerializeVector16bit(Data_->SupportedVersions, 1, result);
    result << ",";

    // Field 4: application protocols advertised via ALPN.
    SerializeProtocols(result);
    result << ",";

    // Field 5: key share groups.
    SerializeKeyShare(result);
    result << ",";

    // Field 6: PSK key exchange modes (leading byte skipped, GREASE modes filtered).
    SerializeVector8bit(Data_->PskKeyExchangeModes, 1, PskKeyExchangeModesGrease_, result);

    return result;
}

void NSrvKernel::TSslJa3::SerializeVector16bit(const TVector<ui8>& v, size_t startPos, TStringBuilder& dest) const noexcept {
    bool first = true;
    for (size_t i = startPos; i < v.size(); i += 2) {
        const ui16 value = (v[i] << 8) | TryGetValue(v, i + 1);
        if (Grease_.find(value) == Grease_.end()) {
            if (!first) {
                dest << "-"sv;
            } else {
                first = false;
            }
            dest << ToString(value);
        }
    }
}

void NSrvKernel::TSslJa3::SerializeVectorInt(const TVector<int>& v, TStringBuilder& dest) const noexcept {
    // Appends every integer of v that is not in Grease_ as a '-'-separated
    // list of decimal numbers.
    auto separator = ""sv; // becomes "-" once the first value has been written
    for (const int item : v) {
        if (Grease_.find(item) != Grease_.end()) {
            continue;
        }
        dest << separator << ToString(item);
        separator = "-"sv;
    }
}

void NSrvKernel::TSslJa3::SerializeVector8bit(const TVector<ui8>& v, size_t startPos, TStringBuilder& dest) noexcept {
    // Convenience overload: serialize every byte from startPos with an empty exclusion set.
    return SerializeVector8bit(v, startPos, EMPTY_GREASE8, dest);
}

void NSrvKernel::TSslJa3::SerializeVector8bit(const TVector<ui8>& v, size_t startPos, const THashSet<ui8>& grease, TStringBuilder& dest) noexcept {
    bool first = true;
    for (size_t i = startPos; i < v.size(); ++i) {
        if (grease.find(v[i]) == grease.end()) {
            if (!first) {
                dest << "-"sv;
            } else {
                first = false;
            }
            dest << ToString(v[i]);
        }
    }
}

void NSrvKernel::TSslJa3::SerializeProtocols(TStringBuilder& dest) const noexcept {
    // Walks the raw ALPN data: after a 2-byte header, each entry is a 1-byte
    // name length followed by that many name bytes. Names are appended to dest
    // separated by '-'. Entries recognised by IsALPNGrease and zero-length
    // entries are skipped.
    const auto& protocols = Data_->ApplicationLayerProtocolNegotiation;
    bool first = true;
    for (size_t i = 2; i < protocols.size();) {
        const ui8 nameLen = protocols[i];
        ++i;
        if (nameLen == 0) {
            continue;
        }
        // Clamp to the buffer so a truncated or inflated length never reads past the end.
        const size_t nameEnd = Min(i + nameLen, protocols.size());
        if (IsALPNGrease(protocols, i, nameEnd)) {
            i = nameEnd;
            continue;
        }
        if (!first) {
            dest << "-"sv;
        } else {
            first = false;
        }
        for (; i < nameEnd; ++i) {
            const unsigned char value = protocols[i];
            // The name bytes are client-controlled, and this output is embedded in
            // Ja4HeaderValue(). Drop ',' (the field separator) and ASCII control
            // characters (NUL, CR, LF, ..., DEL) so the resulting header value
            // cannot be split or corrupted.
            if (value == ',' || value < 0x20 || value == 0x7F) {
                continue;
            }
            dest << value;
        }
    }
}

void NSrvKernel::TSslJa3::SerializeKeyShare(TStringBuilder& dest) const noexcept {
    // Walks the raw key-share data: after a 2-byte header, each entry is a
    // 2-byte group id, a 2-byte key length and the key bytes. Group ids not in
    // Grease_ are appended as a '-'-separated decimal list. Missing bytes of a
    // truncated entry read as 0, and the cursor advance always ends the walk
    // past the buffer.
    const auto& entries = Data_->KeyShare;
    auto separator = ""sv; // becomes "-" once the first group has been written
    size_t pos = 2;
    while (pos < entries.size()) {
        const ui16 group = (entries[pos] << 8) | TryGetValue(entries, pos + 1);
        const ui16 keyLen = (TryGetValue(entries, pos + 2) << 8) | TryGetValue(entries, pos + 3);
        if (Grease_.find(group) == Grease_.end()) {
            dest << separator << ToString(group);
            separator = "-"sv;
        }
        // Skip the 4-byte entry header plus the key bytes.
        pos += 4 + keyLen;
    }
}

bool NSrvKernel::TSslJa3::IsALPNGrease(const TVector<ui8>& v, size_t begin, size_t end) const noexcept {
    // True when v[begin, end) starts with ALPNGreasePrefix_. Assumes
    // begin <= end <= v.size(), which the caller guarantees by clamping end.
    const size_t prefixLen = ALPNGreasePrefix_.size();
    if (end - begin < prefixLen) {
        // Too short to contain the prefix.
        return false;
    }
    size_t matched = 0;
    while (matched < prefixLen && v[begin + matched] == ALPNGreasePrefix_[matched]) {
        ++matched;
    }
    return matched == prefixLen;
}
