#include "state_processor.h"

#include <crypta/lib/native/identifiers/lib/generic.h>
#include <crypta/graph/rt/fp/state_processor/proto/state.pb.h>
#include <crypta/graph/rt/lib/sensors/time_shift.h>

#include <library/cpp/string_utils/base64/base64.h>
#include <ads/bsyeti/libs/experiments/user_ids/user_ids.h>

namespace {
    using namespace NHerschel;

    // Accumulates statistics for a single identifier: the latest timestamp seen and the
    // summed counter. Every stat merged into one wrapper must carry the same id.
    class TIdStatWrapper {
    public:
        TIdStatWrapper() = default;

        explicit TIdStatWrapper(const TIdStat& right)
            : stat(right) {
        }

        explicit TIdStatWrapper(const TString& id, TTimestamp timestamp = 0, ui64 counter = 0) {
            stat.SetId(id);
            stat.SetTimeStamp(timestamp);
            stat.SetCounter(counter);
        }

        // Merges a stored stat: keeps the larger timestamp and adds its counter.
        // Throws yexception if `right` has a different id than the one already held.
        void AddIdstats(const TIdStat& right) {
            AddIdstats(right.GetId(), right.GetTimeStamp(), right.GetCounter());
        }

        void AddIdstats(const TIdStatWrapper& right) {
            AddIdstats(right.stat);
        }

        // Records `counter` occurrences of `id` at `timestamp`. The first call fixes the id;
        // later calls with another id throw yexception.
        void AddIdstats(const TString& id, const TTimestamp timestamp, const ui64 counter = 1) {
            if (!stat.HasId()) {
                stat.SetId(id);
            } else if (stat.GetId() != id) {
                ythrow yexception() << "Allow merge only equal ids!";
            }
            stat.SetTimeStamp(std::max(stat.GetTimeStamp(), timestamp));
            stat.SetCounter(stat.GetCounter() + counter);
        }

        // Returns a copy of the accumulated stat; the wrapper itself is left unchanged
        // (moving out of a member is impossible in a const method anyway).
        TIdStat GetStat() const {
            return stat;
        }

    private:
        TIdStat stat;
    };

    // Groups TIdStat records by id and selects the highest-priority ids.
    class TIdsAggregator {
    public:
        // Merges a previously stored stat; records with an empty id are ignored.
        void AddIdstats(const TIdStat& stat) {
            if (stat.GetId().empty()) {
                return;
            }
            mapping[stat.GetId()].AddIdstats(stat);
        }

        // Records one occurrence of `id` at `timestamp`.
        // Returns true when `id` was not yet known to this aggregator; empty ids are ignored (false).
        bool AddIdstats(const TString& id, const TTimestamp timestamp) {
            if (id.empty()) {
                return false;
            }
            const bool isNewId{mapping.find(id) == mapping.end()};
            mapping[id].AddIdstats(id, timestamp);
            return isNewId;
        }

        // Returns at most `limit` stats ordered by priority:
        //   1. ids with timestamp >= minTimestamp ("fresh"), by counter desc, then timestamp desc;
        //   2. older ("stale") ids, by timestamp desc, then counter desc;
        //   3. ties broken by id so the result does not depend on hash-map iteration order.
        // The comparator must be a strict weak ordering, otherwise Sort has undefined behaviour.
        TVector<TIdStat> Prioritize(const ui64 limit,
                                    [[maybe_unused]] const TTimestamp currentTimestamp,
                                    const TTimestamp minTimestamp) {
            TVector<TIdStat> items{};
            items.reserve(mapping.size());
            for (const auto& [key, wrapper] : mapping) {
                items.push_back(wrapper.GetStat());
            }

            Sort(items.begin(), items.end(), [minTimestamp](const TIdStat& left, const TIdStat& right) {
                const bool leftStale{left.GetTimeStamp() < minTimestamp};
                const bool rightStale{right.GetTimeStamp() < minTimestamp};
                if (leftStale != rightStale) {
                    // fresh ids always precede stale ones
                    return !leftStale;
                }
                if (leftStale) {
                    // both stale: most recent first, then most frequent
                    if (left.GetTimeStamp() != right.GetTimeStamp()) {
                        return left.GetTimeStamp() > right.GetTimeStamp();
                    }
                    if (left.GetCounter() != right.GetCounter()) {
                        return left.GetCounter() > right.GetCounter();
                    }
                } else {
                    // both fresh: most frequent first, then most recent
                    if (left.GetCounter() != right.GetCounter()) {
                        return left.GetCounter() > right.GetCounter();
                    }
                    if (left.GetTimeStamp() != right.GetTimeStamp()) {
                        return left.GetTimeStamp() > right.GetTimeStamp();
                    }
                }
                return left.GetId() < right.GetId();
            });

            if (items.size() > limit) {
                items.resize(limit);
            }
            return items;
        }

    private:
        THashMap<TString, TIdStatWrapper> mapping;
    };
}

namespace NHerschel {

    // Unpacks every framed TEventMessage in `data`, parses its TFpEvent body and groups
    // the events by the serialized fingerprint. The fingerprint is also used as the state
    // key requested from `stateManager`.
    // Events whose fingerprint is still in the debounce window, or is marked as banned,
    // are skipped. Skip/take counts are reported to Solomon.
    THerschelProcessor::TGroupedChunk THerschelProcessor::PrepareGroupedChunk(
            TString dataSource, THerschelProcessorBase::TManager& stateManager, TMessageBatch data) {

        NSFStats::TSolomonContext sctx{SensorsContext.Detached(), {{"place", "herschel_processor"}}};
        INFO_LOG << "Parse message of type " << dataSource << "\n";
        TGroupedChunk result;
        size_t messagesSkipped{0};
        size_t messagesTaken{0};
        size_t messagesBanSkipped{0};

        for (auto& messageData : data.Messages) {
            messageData.Unpack();

            TEventMessage message{};
            TStringBuf skip;

            for (NFraming::TUnpacker unpacker(messageData.Data); unpacker.NextFrame(message, skip);) {
                TFpEvent msg{};
                TString key{};

                Y_PROTOBUF_SUPPRESS_NODISCARD msg.ParseFromString(message.GetBody());
                Y_PROTOBUF_SUPPRESS_NODISCARD msg.GetFingerprint().SerializeToString(&key);
                NCrypta::SetMessageLag(sctx, message);

                if (auto duration{Debounce.Pull(key)}; duration.Defined()) {
                    // skip debounce rate
                    ++messagesSkipped;
                    NCrypta::SetDebounceLag(sctx, (*duration));
                    continue;
                }
                if (Banned.Pull(key)) {
                    // skip banned rate
                    ++messagesBanSkipped;
                    continue;
                }
                Debounce.Push(key);
                ++messagesTaken;

                auto stateRequest{stateManager.RequestState()};
                stateRequest->Set(Descriptors.HerschelStateDescriptor, key);

                // `msg` is not used after this point: move it instead of copying the protobuf.
                THerschelGroupedChunk chunk{std::move(msg), message.GetTimeStamp()};

                DEBUG_LOG << "TimeStamp: " << message.GetTimeStamp() << " Type: " << message.GetType() << Endl;
                result[std::move(stateRequest)].push_back(std::move(chunk));
            }
        }
        sctx.Get<NSFStats::TSumMetric<ui64>>("messages_count_ban_skipped").Inc(messagesBanSkipped);
        sctx.Get<NSFStats::TSumMetric<ui64>>("messages_count_skipped").Inc(messagesSkipped);
        sctx.Get<NSFStats::TSumMetric<ui64>>("messages_count_taken").Inc(messagesTaken);
        return result;
    }

    // For every fingerprint state in `groupedRows`:
    //   - merge the new ids into the state;
    //   - run the ban policy selected by the fingerprint version;
    //   - either publish the state to vulture or, when banned, publish an empty record.
    // Ban counters are reported in aggregate and per version.
    void THerschelProcessor::ProcessGroupedChunk(TString dataSource, TGroupedChunk groupedRows) {
        NSFStats::TSolomonContext sctx{SensorsContext.Detached(), {{"place", "herschel_processor"}}};

        INFO_LOG << "Processing message of type " << dataSource << Endl;
        const auto limit{Config.GetIdsLimit()};

        // Ban counters keyed by the EHerschelVersion name of the fingerprint.
        THashMap<TString, TBanCounterMetric> counter{};
        using TBanHistMetric = NSFStats::TSolomonThresholdMetric<
            1, 3, 5, 7, 9,
            11, 33, 55, 77, 99,
            111, 333, 555, 777, 999
        >;

        for (auto& [request, rows] : groupedRows) {
            auto& state{request->Get(Descriptors.HerschelStateDescriptor)->GetState()};
            // The state id is the serialized TFingerprint the request was keyed by.
            auto& key{request->Get(Descriptors.HerschelStateDescriptor)->GetStateId()};
            NCrypta::NEvent::TFingerprint keyFp{};
            Y_PROTOBUF_SUPPRESS_NODISCARD keyFp.ParseFromString(key);
            const TString version{NCrypta::NEvent::EHerschelVersion_Name(keyFp.GetVersion())};

            state.SetScope(keyFp.GetUAHash() ? EHerschelScope::IP_UA : EHerschelScope::IP);
            const ui64 newIdsCounter{UpdateHershelStateIds(state, limit, rows)};

            [[maybe_unused]] const bool stateWasSwitched{BanPopularFps(keyFp, state, newIdsCounter, &counter[version])};

            if (state.GetBan().GetIsBanned()) {
                Banned.Push(key);
                ClearVultureState(key);
                sctx.Get<TBanHistMetric>("ban_banned_hist").Add(state.GetBan().GetCounter());
                ++(counter[version].skipped);
            } else {
                UpdateVultureState(key, state);
                sctx.Get<TBanHistMetric>("ban_unbanned_hist").Add(state.GetBan().GetCounter());
                ++(counter[version].updated);
            }
            sctx.Get<TBanHistMetric>("ban_counter_hist").Add(state.GetBan().GetCounter());
        }

        // Writes one counter set into `ctx`. Shared by the aggregate and the per-version
        // contexts so both always publish the same metric names.
        const auto reportCounters{[](NSFStats::TSolomonContext& ctx, const TBanCounterMetric& ctr) {
            ctx.Get<NSFStats::TSumMetric<ui64>>("hysteresis_time").Inc(ctr.hysteresisTime);
            ctx.Get<NSFStats::TSumMetric<ui64>>("unban_time").Inc(ctr.unBanTime);
            ctx.Get<NSFStats::TSumMetric<ui64>>("skipped_fps").Inc(ctr.skipped);
            ctx.Get<NSFStats::TSumMetric<ui64>>("updated_fps").Inc(ctr.updated);
            ctx.Get<NSFStats::TSumMetric<ui64>>("changed_fps").Inc(ctr.changed);
            ctx.Get<NSFStats::TSumMetric<ui64>>("unbanned_fps").Inc(ctr.unBannedFps);
            ctx.Get<NSFStats::TSumMetric<ui64>>("banned_fps").Inc(ctr.bannedFps);
        }};

        for (const auto& [versionName, ctr] : counter) {
            reportCounters(sctx, ctr);

            NSFStats::TSolomonContext lctx{SensorsContext.Detached(), {{"place", "herschel_processor"}, {"version", versionName}}};
            reportCounters(lctx, ctr);
        }
    }

    // Hands the vulture rows collected in BruPacker to an async writer callback and resets
    // the packer. The callback captures a copy of Config and a shared pointer to the rows
    // (not `this`). It writes the rows into the Brusilov YT queue within the given transaction.
    NYT::TFuture<THerschelProcessor::TPrepareForAsyncWriteResult> THerschelProcessor::PrepareForAsyncWrite() {
        auto pendingRows = MakeAtomicShared<TVector<TYtQueue::TWriteRow>>(std::move(BruPacker.Finish()));
        BruPacker.Clear();

        auto writer = [Config = this->Config, pendingRows](NYT::NApi::ITransactionPtr tx) {
            TYtQueue{Config.GetBrusilov().GetQueue(), tx->GetClient()}.Write(tx, *pendingRows);
        };

        return NYT::MakeFuture<TPrepareForAsyncWriteResult>({.AsyncWriter = std::move(writer)});
    }

    // Applies the ban policy matching the fingerprint's Herschel version to `state`.
    // Returns true if the ban flag was switched by this call.
    bool THerschelProcessor::BanPopularFps(NCrypta::NEvent::TFingerprint key, THerschelState& state,
                                           const ui64 newIdsCounter, TBanCounterMetric* counter) {
        switch (key.GetVersion()) {
            case NCrypta::NEvent::EHerschelVersion::NAIVE:
                return BanPopularFpsNaive(state, newIdsCounter, counter);
            case NCrypta::NEvent::EHerschelVersion::MODEL_IP_UA_V0:
                return BanPopularFpsMlIpUaV0(state, newIdsCounter, counter);
        }
        // There is deliberately no `default:` label, so -Wswitch still reports newly added versions.
        // An unhandled value (e.g. one that appeared after a proto update) leaves the ban state
        // untouched. Previously this fell off the end of a non-void function, which is undefined behaviour.
        return false;
    }

    // Counter-based ban policy.
    // Within a window of BanCfg.DurationSec, new ids are added to the ban counter; an
    // expired window restarts the counter with `newIdsCounter` and stamps "now".
    // The state is banned once the counter reaches BanCfg.IdsLimit. It is unbanned when it
    // stays below the limit and the window stamp is older than BanCfg.HysteresisSec.
    // A zero duration or ids limit disables the policy.
    // Returns true if the ban flag was switched by this call.
    bool THerschelProcessor::BanPopularFpsNaive(THerschelState& state, const ui64 newIdsCounter,
                                                TBanCounterMetric* counter) {
        const auto& banCfg{Config.GetBanCfg()};
        const ui64 durationSec{banCfg.GetDurationSec()};
        const ui64 idsLimit{banCfg.GetIdsLimit()};

        if (!durationSec || !idsLimit) {
            // ban disabled
            return false;
        }

        // One clock read, so the window limit, the hysteresis limit and a new stamp agree.
        const TTimestamp now{TInstant::Now().Seconds()};
        const TTimestamp limitTimestamp{now - durationSec};
        const TTimestamp hysteresisTimestamp{now - banCfg.GetHysteresisSec()};
        const ui64 currentIdsCount{state.GetBan().GetCounter()};

        if (state.GetBan().GetStamp() < limitTimestamp) {
            // window expired: start a new one
            state.MutableBan()->SetCounter(newIdsCounter);
            state.MutableBan()->SetStamp(now);
            ++(counter->unBanTime);
        } else {
            state.MutableBan()->SetCounter(currentIdsCount + newIdsCounter);
        }

        bool stateWasSwitched{false};
        if (state.GetBan().GetCounter() >= idsLimit) {
            // step over ban
            if (!state.GetBan().GetIsBanned()) {
                ++(counter->bannedFps);
                ++(counter->changed);
                state.MutableBan()->SetIsBanned(true);
                stateWasSwitched = true;
            }
        } else if (state.GetBan().GetStamp() < hysteresisTimestamp) {
            // step over unban
            if (state.GetBan().GetIsBanned()) {
                ++(counter->unBannedFps);
                ++(counter->changed);
                state.MutableBan()->SetIsBanned(false);
                stateWasSwitched = true;
            }
            ++(counter->hysteresisTime);
        }
        return stateWasSwitched;
    }

    // Model-based ban policy.
    // While the ban stamp is newer than now - BanCfg.DurationSec, the state is left unchanged.
    // Otherwise the catboost model is applied to the state's IP, user agent and id counts.
    // The verdict for the state's scope (IP_UA or IP) decides the ban; an unknown scope is
    // treated as a negative verdict, i.e. banned.
    // Banning refreshes the stamp; unbanning does not.
    // Returns true if the ban flag was switched by this call.
    bool THerschelProcessor::BanPopularFpsMlIpUaV0(THerschelState& state, [[maybe_unused]] const ui64 newIdsCounter,
                                                   TBanCounterMetric* counter) {
        const ui64 durationSec{Config.GetBanCfg().GetDurationSec()};
        const TTimestamp now{TInstant::Now().Seconds()};
        const TTimestamp limitTimestamp{now - durationSec};

        if (state.GetBan().GetStamp() > limitTimestamp) {
            // wait till ban timeout is reached
            return false;
        }

        // Model features are only needed once the timeout has passed.
        NCrypta::THerschelStats stats{};
        stats.SetYandexuidCount(state.GetYandexuids().size());
        stats.SetIdfaCount(state.GetIdfas().size());
        stats.SetGaidCount(state.GetGaids().size());

        const NCrypta::THerschelCatboostApplier& model{Model};
        const auto verdict{model.Apply(state.GetUserIP(), state.GetUserAgent(), stats)};

        bool isCorrect{false};
        switch (state.GetScope()) {
            case EHerschelScope::IP_UA:
                isCorrect = verdict.GetIpUseragent();
                break;
            case EHerschelScope::IP:
                isCorrect = verdict.GetIp();
                break;
            default:
                break;
        }

        bool stateWasSwitched{false};
        if (!isCorrect) {
            if (!state.GetBan().GetIsBanned()) {
                ++(counter->bannedFps);
                ++(counter->changed);
                state.MutableBan()->SetIsBanned(true);
                state.MutableBan()->SetStamp(now);
                stateWasSwitched = true;
            }
        } else if (state.GetBan().GetIsBanned()) {
            ++(counter->unBannedFps);
            ++(counter->changed);
            state.MutableBan()->SetIsBanned(false);
            stateWasSwitched = true;
        }

        return stateWasSwitched;
    }

    // Merges the ids carried by `rows` into the IDFA / GAID / yandexuid lists of `state`.
    // Each list is rebuilt from TIdsAggregator::Prioritize and capped at `limit` entries.
    // UserIP and UserAgent are overwritten from every row, so the last row's values remain.
    // Returns the number of ids that were not present before (in the state or earlier in `rows`).
    ui64 THerschelProcessor::UpdateHershelStateIds(THerschelState& state, size_t limit, const TChunk& rows) {
        TIdsAggregator idfaAggr{};
        TIdsAggregator gaidAggr{};
        TIdsAggregator yandexuidAggr{};

        // Seed an aggregator with the stats already stored in the state.
        const auto seed{[](TIdsAggregator& aggregator, const auto& storedStats) {
            for (const auto& stat : storedStats) {
                aggregator.AddIdstats(stat);
            }
        }};
        seed(idfaAggr, state.GetIdfas());
        seed(gaidAggr, state.GetGaids());
        seed(yandexuidAggr, state.GetYandexuids());

        ui64 addedIds{0};
        for (const auto& row : rows) {
            state.SetUserIP(row.event.GetUserIP());
            state.SetUserAgent(row.event.GetUserAgent());

            for (const auto& protoId : row.event.GetIds()) {
                NIdentifiers::TGenericID identifier{protoId};
                const auto& value{identifier.GetValue()};

                // Pick the aggregator for this id type; other types are ignored.
                TIdsAggregator* target{nullptr};
                switch (protoId.GetType()) {
                    case NCrypta::NIdentifiersProto::NIdType::YANDEXUID:
                        target = &yandexuidAggr;
                        break;
                    case NCrypta::NIdentifiersProto::NIdType::IDFA:
                        target = &idfaAggr;
                        break;
                    case NCrypta::NIdentifiersProto::NIdType::GAID:
                        target = &gaidAggr;
                        break;
                    default:
                        break;
                }
                if (target != nullptr) {
                    addedIds += target->AddIdstats(value, row.TimeStamp);
                }
            }
        }

        const TTimestamp currentTimestamp{TInstant::Now().Seconds()};
        // Threshold handed to Prioritize: the moment 24 hours ago.
        const TTimestamp limitTimestamp{currentTimestamp - (60ull * 60 * 24)};
        ClearState(state);

        for (const auto& stat : idfaAggr.Prioritize(limit, currentTimestamp, limitTimestamp)) {
            *state.AddIdfas() = stat;
        }
        for (const auto& stat : gaidAggr.Prioritize(limit, currentTimestamp, limitTimestamp)) {
            *state.AddGaids() = stat;
        }
        for (const auto& stat : yandexuidAggr.Prioritize(limit, currentTimestamp, limitTimestamp)) {
            *state.AddYandexuids() = stat;
        }

        return addedIds;
    }

    // Empties the yandexuid, GAID and IDFA lists of `state`; other fields are kept.
    void THerschelProcessor::ClearState(THerschelState& state) {
        state.ClearYandexuids();
        state.ClearGaids();
        state.ClearIdfas();
    }

    void THerschelProcessor::ClearVultureState(const TString& key) {
        BruPacker.Add(
            NExperiments::GetBigbShardNumber(key, Config.GetBrusilov().GetReshardingModule()),
            CreateVultEvent(key, {})
        );
    }

    // Builds a TAssociatedUids record for fingerprint `key` and queues it as a vulture event.
    // The key record is the base64-encoded fingerprint. Value records are the first 2 IDFAs,
    // the first 2 GAIDs and the first 5 yandexuids of `herschel`, in list order.
    // Each value record carries its 1-based position within its list as crypta_graph_distance.
    void THerschelProcessor::UpdateVultureState(const TString& key, THerschelState& herschel) {
        using TSourceUniq = yabs::proto::Profile::TSourceUniq;
        using TRepeatedIdStat = NProtoBuf::RepeatedPtrField<TIdStat>;

        TAssociatedUids associated{};

        // Appends at most `maxCount` leading entries of `stats` as value records.
        const auto appendValues{
            [&associated](const TRepeatedIdStat& stats, const TSourceUniq::EIdType idType, const size_t maxCount) {
                size_t distance{0};
                for (const auto& stat : stats) {
                    if (distance == maxCount) {
                        break;
                    }
                    ++distance;
                    auto* record{associated.AddValueRecords()};
                    record->set_user_id(stat.GetId());
                    record->set_id_type(idType);
                    record->set_crypta_graph_distance(distance);
                    record->set_link_type(TSourceUniq::LT_CRYPTA_HERSCHEL);
                    record->add_link_types(TSourceUniq::LT_CRYPTA_HERSCHEL);
                }
            }
        };

        auto* keyRecord{associated.MutableKeyRecord()};
        keyRecord->set_user_id(Base64Encode(key));
        keyRecord->set_id_type(TSourceUniq::FINGERPRINT);

        appendValues(herschel.GetIdfas(), TSourceUniq::IDFA, 2);
        appendValues(herschel.GetGaids(), TSourceUniq::GAID, 2);
        appendValues(herschel.GetYandexuids(), TSourceUniq::YANDEX_UID, 5);

        const auto shard{NExperiments::GetBigbShardNumber(key, Config.GetBrusilov().GetReshardingModule())};
        BruPacker.Add(shard, CreateVultEvent(key, associated));
    }

    // Wraps a TVultureEvent for fingerprint `key` into a TEventMessage.
    // The envelope has type VULTURE, source "herschel" and the current time in seconds.
    // The event targets the VULT_DEFAULT and VULT_EXP locations with key prefix "he:".
    TEventMessage THerschelProcessor::CreateVultEvent(const TString& key, const TAssociatedUids& associated) {
        TVultureEvent payload{};
        // todo: move to config
        auto* locations{payload.MutableLocations()};
        locations->Add(NCrypta::NEvent::EVultureLocation::VULT_DEFAULT);
        locations->Add(NCrypta::NEvent::EVultureLocation::VULT_EXP);
        payload.SetKeyPrefix("he:");
        payload.SetId(key);
        payload.MutableAssociated()->CopyFrom(associated);

        TEventMessage envelope{};
        envelope.SetTimeStamp(TInstant::Now().Seconds());
        envelope.SetType(NCrypta::NEvent::EMessageType::VULTURE);
        envelope.SetSource("herschel");
        Y_PROTOBUF_SUPPRESS_NODISCARD payload.SerializeToString(envelope.MutableBody());
        return envelope;
    }
}
