#include "consumer.h"

#include "utils.h"

#include <passport/infra/libs/cpp/utils/string/coder.h>
#include <passport/infra/libs/cpp/utils/string/string_utils.h>

#include <library/cpp/string_utils/tskv_format/builder.h>

#include <util/stream/format.h>

#include <algorithm>

namespace NPassport::NBbAccess {
    // CGI argument names read from requests; several of them double as tskv output keys.
    static const TString DBFIELDS("dbfields");
    static const TString ATTRIBUTES("attributes");
    static const TString PHONE_ATTRIBUTES("phone_attributes");
    static const TString HOST("host");
    static const TString UID("uid");

    // tskv output keys that do not correspond to a single request argument.
    static const TString ARGS("args");
    static const TString GET_ARGS("get_args");
    static const TString REQUEST_NUMBER_THROUGH_ONE_CONNECTION("request_number_through_one_connection");

    // Currently referenced only by the disabled sessguard check in TConsumer::Map.
    static const TString SESSGUARD("sessguard");
    static const TString USERIP("userip");

    /// Folds one observation into the running aggregate (sum, count, min, max).
    void TCumulativeValue::Map(TCumulativeValue::TType v) {
        Max_ = std::max(Max_, v);
        Min_ = std::min(Min_, v);
        ++Count_;
        Sumvalue_ += v;
    }

    /// Merges this partial aggregate into `to`: sums and counts are added, extremes are combined.
    void TCumulativeValue::Reduce(TCumulativeValue& to) const {
        to.Max_ = std::max(Max_, to.Max_);
        to.Min_ = std::min(Min_, to.Min_);
        to.Count_ += Count_;
        to.Sumvalue_ += Sumvalue_;
    }

    void TConsumer::Reserve() {
        Dbfields_.reserve(20);
        Attributes_.reserve(20);
        PhoneAttributes_.reserve(20);
        Args_.reserve(20);
    }

    namespace {
        // host patterns that need sessguard
        const std::vector<TString> SESSGUARD_HOSTS = {"passport.yandex"};

        bool NeedsSessguard(const TCgiParams& params) {
            auto it = params.find(HOST);
            if (it == params.end()) {
                return false;
            }

            for (auto& name : SESSGUARD_HOSTS) {
                if (it->second.Contains(name)) {
                    return true;
                }
            }
            return false;
        }
    }

    static bool HasUserPort(const TCgiParams& params) {
        auto it = params.find("user_port");
        if (it == params.end()) {
            return false;
        }

        ui16 port = 0;
        return TryIntFromString<10, ui16>(it->second, port);
    }

    // Argument names containing any of these characters are not recorded in Args_ (see TConsumer::Map).
    static const TString SKIP_ARG_CHARS("@; \t\v\r<>\\");
    // Maximum number of raw log lines kept in Examples_.
    static const size_t EXAMPLE_COUNT(3);

    /// Accumulates statistics from one parsed request into this consumer.
    /// Note: *data.Params is modified (listed arguments are erased by ProcessListedArg()).
    void TConsumer::Map(TData& data) {
        // Params is dereferenced unconditionally below; EnsureIsOk() throws if it is missing.
        data.EnsureIsOk();

        ++Total_;
        if (data.IsHttps) {
            ++Https_;
        }
        if (data.IsCacheHit) {
            ++CacheHit_;
        }
        if (HasUserPort(*data.Params)) {
            ++ReqsWithUserPort_;
        }

        RespTime_ += data.RespTime;
        RespSize_ += data.RespSize;

        if (!Ip_.contains(data.Ip)) {
            Ip_.emplace(data.Ip);
        }
        if (Examples_.size() < EXAMPLE_COUNT) {
            Examples_.emplace(data.Line);
        }

        // ProcessListedArg() erases each listed argument from *data.Params after collecting its
        // values, so these arguments are not recorded in Args_ below.
        ProcessListedArg(*data.Params, Dbfields_, DBFIELDS);
        ProcessListedArg(*data.Params, Attributes_, ATTRIBUTES);
        ProcessListedArg(*data.Params, PhoneAttributes_, PHONE_ATTRIBUTES);

        // Argument names containing any of SKIP_ARG_CHARS are not recorded.
        auto skipArg = [](const TStringBuf buf) {
            for (char c : SKIP_ARG_CHARS) {
                if (buf.Contains(c)) {
                    return true;
                }
            }

            return false;
        };

        for (const auto& [name, value] : *data.Params) {
            if (skipArg(name)) {
                continue;
            }

            if (!Args_.contains(name)) {
                Args_.emplace(name);
            }
        }

        // EnsureIsOk() validates only Params, so GetParams may be absent here:
        // skip the GET argument names instead of dereferencing it.
        if (data.GetParams) {
            for (const auto& [name, value] : *data.GetParams) {
                if (!GetArgs_.contains(name)) {
                    GetArgs_.emplace(name);
                }
            }
        }

        ProcArgCount(Uids_, UID, *data.Params);
        if (data.RequestNumberThroughOneConnection) {
            RequestNumberThroughOneConnection_.Map(*data.RequestNumberThroughOneConnection);
        }

        // Remembering userip for sessguard hosts that came without a sessguard is disabled;
        // Y_UNUSED keeps NeedsSessguard referenced until the check below is re-enabled.
        Y_UNUSED(NeedsSessguard);
        // if (NeedsSessguard(*data.Params) && !data.Params->contains(SESSGUARD)) {
        //     auto it = data.Params->find(USERIP);
        //     if (it != data.Params->end()) {
        //         UserIpNoSessguard_.emplace(it->second);
        //     }
        // }

        // Count requests per host, as reported by GetHostFromRequest().
        std::optional<TStringBuf> host = GetHostFromRequest(*data.Params);
        if (host) {
            auto it = Hosts_.find(*host);
            if (it == Hosts_.end()) {
                Hosts_.emplace(*host, 1);
            } else {
                ++it->second;
            }
        }
    }

    /// Merges all statistics accumulated by this consumer into `to`.
    void TConsumer::Reduce(TConsumer& to) const {
        to.Total_ += Total_;
        to.Https_ += Https_;
        to.CacheHit_ += CacheHit_;
        to.ReqsWithUserPort_ += ReqsWithUserPort_;
        to.RespSize_ += RespSize_;
        to.RespTime_ += RespTime_;
        ReduceSet(Dbfields_, to.Dbfields_);
        ReduceSet(Attributes_, to.Attributes_);
        ReduceSet(PhoneAttributes_, to.PhoneAttributes_);
        ReduceSet(Args_, to.Args_);
        ReduceSet(GetArgs_, to.GetArgs_);
        ReduceSet(Ip_, to.Ip_);
        ReduceSet(UserIpNoSessguard_, to.UserIpNoSessguard_);
        Uids_.Reduce(to.Uids_);
        RequestNumberThroughOneConnection_.Reduce(to.RequestNumberThroughOneConnection_);

        // Hosts_ is filled by Map() and printed as top_hosts, so its counters must be summed too;
        // otherwise the reduced consumer would report only the hosts it saw itself.
        for (const auto& [host, count] : Hosts_) {
            auto it = to.Hosts_.find(host);
            if (it == to.Hosts_.end()) {
                to.Hosts_.emplace(host, count);
            } else {
                it->second += count;
            }
        }

        // Keep at most EXAMPLE_COUNT example lines in total.
        for (const TString& e : Examples_) {
            if (to.Examples_.size() >= EXAMPLE_COUNT) {
                break;
            }
            to.Examples_.emplace(e);
        }
    }

    // Number of most frequent hosts that PrintRaw() reports as top_hosts.
    static const size_t HOSTS_TO_SERIALIZE{5};

    /// Writes this consumer's statistics as a single tskv record to `stream`.
    void TConsumer::PrintRaw(TStringBuf grantType, TStringBuf method, TStringBuf consumer, IOutputStream& stream) const {
        // Divisor for the averages below. If Total_ were 0 (a consumer that never saw a request),
        // this would be an integer division by zero. In that case the response sums are expected
        // to be 0 too, so the averages print as 0.
        const ui64 total = Total_ > 0 ? Total_ : 1;

        NTskvFormat::TLogBuilder tskv;
        tskv.Add("method", method);
        tskv.Add("grant_type", grantType);
        tskv.Add("consumer", consumer);
        tskv.Add("https", ToString(Https_));
        tskv.Add("cache_hit", ToString(CacheHit_));
        tskv.Add("reqs_with_user_port", ToString(ReqsWithUserPort_));
        // Average response time in milliseconds.
        tskv.Add("avg_resp_time", ToString(RespTime_.MicroSeconds() / total / 1000.));
        tskv.Add("avg_resp_size", ToString(RespSize_ / total));

        // Emits "min:avg:max" for a cumulative value; skipped unless its maximum is positive.
        auto print = [&tskv](const TString& key, const TCumulativeValue& value) {
            if (value.GetMax() > 0) {
                tskv.Add(key,
                         TStringBuilder() << value.GetMin()
                                          << ":" << Prec(value.GetAvg(), PREC_POINT_DIGITS, 1)
                                          << ":" << value.GetMax());
            }
        };

        print(UID, Uids_);
        print(REQUEST_NUMBER_THROUGH_ONE_CONNECTION, RequestNumberThroughOneConnection_);

        PrintRawListedArg(Dbfields_, DBFIELDS, tskv);
        PrintRawListedArg(Attributes_, ATTRIBUTES, tskv);
        PrintRawListedArg(PhoneAttributes_, PHONE_ATTRIBUTES, tskv);
        PrintRawListedArg(Args_, ARGS, tskv);
        PrintRawListedArg(GetArgs_, GET_ARGS, tskv);

        if (TString value = SerializeCountedValues(Hosts_, HOSTS_TO_SERIALIZE); value) {
            tskv.Add("top_hosts", value);
        }
        tskv.Add("all_ip", SerializeSet(Ip_));
        tskv.Add("some_example", SerializeSet(Examples_));

        if (!UserIpNoSessguard_.empty()) {
            // tskv.Add("userip_with_no_sessguard", SerializeSet(UserIpNoSessguard_)); // TODO: PASSP-27110
        }
        tskv.Add("total_reqs", ToString(Total_));

        tskv.End();
        stream << "tskv\t" << tskv.Str();
    }

    /// Writes one human-readable line: the consumer name, MakeRatio(Total_, totalMethod)
    /// (presumably this consumer's share of the method's requests), and the cache hit percentage.
    void TConsumer::PrintPretty(TStringBuf consumer, ui64 totalMethod, IOutputStream& stream) const {
        const double cacheHitPercent = 100. * CacheHit_ / Total_;

        stream << LeftPad(consumer, 32) << " ";
        stream << LeftPad(MakeRatio(Total_, totalMethod), 64);
        stream << " cache_hit=" << Prec(cacheHitPercent, PREC_POINT_DIGITS, 1) << "%" << Endl;
    }

    /// Joins the set's elements, each followed by a newline, and base64-encodes the result
    /// (NUtils::BinToBase64 with its second argument set to true).
    TString TConsumer::SerializeSet(const TConsumer::TSet& set) {
        TString joined;
        for (const TString& item : set) {
            joined += item;
            joined += '\n';
        }
        return NUtils::BinToBase64(joined, true);
    }

    /// Selects up to `topCount` values with the highest counts, joins them (each followed by a
    /// newline) and base64-encodes the result. Returns the encoding of an empty string when
    /// `values` is empty.
    TString TConsumer::SerializeCountedValues(const TCountedValues& values, size_t topCount) {
        // Views into the keys of `values`, which outlives this function.
        std::vector<std::pair<TStringBuf, ui64>> pairs;
        pairs.reserve(values.size());
        for (const auto& v : values) {
            pairs.emplace_back(v.first, v.second);
        }

        const size_t toPrint = std::min(pairs.size(), topCount);

        // Only the first `toPrint` entries are needed, so a partial sort suffices. Equal counts
        // are ordered by value, so the selection and its order do not depend on the container's
        // iteration order.
        std::partial_sort(pairs.begin(), pairs.begin() + toPrint, pairs.end(), [](const auto& l, const auto& r) {
            if (l.second != r.second) {
                return l.second > r.second;
            }
            return l.first < r.first;
        });

        TStringStream res;
        for (size_t idx = 0; idx < toPrint; ++idx) {
            res << pairs[idx].first << Endl;
        }

        return NUtils::BinToBase64(res.Str());
    }

    // Non-yandex domains whose hosts (the domain itself or any subdomain) GetHostFromRequest()
    // reports without further checks.
    static const std::vector<TString> KNOWN_SIDE_DOMAINS{
        "beru.ru",
        "edadeal.ru",
        "edastage.ru",
        "kinopoisk.ru",
        "yandexsport.ru",
    };

    std::optional<TStringBuf> TConsumer::GetHostFromRequest(const TCgiParams& params) {
        auto it = params.find("host");
        if (it == params.end()) {
            return {};
        }

        TStringBuf param = it->second;
        param.SkipPrefix(".");

        for (const TString& dom : KNOWN_SIDE_DOMAINS) {
            TStringBuf copy(param);
            if (!copy.ChopSuffix(dom)) {
                continue;
            }

            if (copy.empty() || copy.EndsWith(".")) {
                return param;
            }
        }

        for (size_t idx = 0; param && idx < 3; ++idx) {
            TStringBuf part = TStringBuf(param).RNextTok(".");

            if (part == "yandex" || part == "yandex-team") {
                return param;
            }

            if (part.size() > 3) {
                return {};
            }

            param.RNextTok(".");
        }

        return {};
    }

    /// Adds each comma-separated item of argument `arg` to `set`, then erases `arg` from `params`.
    /// Does nothing when the argument is absent.
    void TConsumer::ProcessListedArg(TCgiParams& params, TConsumer::TSet& set, const TString& arg) {
        auto argIt = params.find(arg);
        if (argIt == params.end()) {
            return;
        }

        for (TStringBuf rest = argIt->second; rest;) {
            const TStringBuf item = rest.NextTok(',');
            if (!set.contains(item)) {
                set.emplace(item);
            }
        }

        params.erase(argIt);
    }

    /// Adds `arg` = sorted, comma-joined elements of `set` to the tskv record; skipped for an empty set.
    void TConsumer::PrintRawListedArg(const TConsumer::TSet& set, const TString& arg, NTskvFormat::TLogBuilder& tskv) {
        if (!set.empty()) {
            // Sorting makes the output independent of the set's iteration order.
            std::vector<TString> sorted;
            sorted.reserve(set.size());
            sorted.insert(sorted.end(), set.begin(), set.end());
            std::sort(sorted.begin(), sorted.end());

            TString joined;
            for (const TString& item : sorted) {
                NUtils::AppendSeparated(joined, ',', item);
            }
            tskv.Add(arg, joined);
        }
    }

    /// Records into `val` how many non-empty comma-separated items argument `arg` holds.
    /// Nothing is recorded when the argument is absent.
    void TConsumer::ProcArgCount(TCumulativeValue& val, const TString& arg, const TCgiParams& params) {
        const auto argIt = params.find(arg);
        if (argIt != params.end()) {
            ui64 tokens = 0;
            for (TStringBuf rest = argIt->second; rest;) {
                if (rest.NextTok(',')) {
                    ++tokens;
                }
            }
            val.Map(tokens);
        }
    }

    /// Adds every element of `from` to `to`; elements already present are left as is.
    /// TODO: a node-moving merge() would need a mutable `from`, because it extracts
    /// nodes from its source.
    void TConsumer::ReduceSet(const TConsumer::TSet& from, TConsumer::TSet& to) {
        to.insert(from.begin(), from.end());
    }

    void TConsumer::TData::EnsureIsOk() const {
        Y_ENSURE(Params, "params are NULL");
    }
}
