#pragma once

#include "encoded_user_data.h"
#include "normal_suffucient_statistic_aggregator.h"
#include "strata_aggregator.h"
#include "utils.h"

#include "normal_suffucient_statistic_aggregator.h"

#include <crypta/lib/proto/user_data/user_data_stats.pb.h>
#include <library/cpp/bloom_filter/bloomfilter.h>

#include <util/generic/map.h>
#include <util/generic/hash.h>
#include <util/stream/str.h>

/// Ordering for TGenderAgeIncome: values are compared lexicographically by
/// gender, then age, then income. Presumably this is what lets the message be
/// used as a key of the TMap declared further down in this header.
template<>
struct TLess<NLab::TUserDataStats::TGenderAgeIncome> {
    bool operator()(const NLab::TUserDataStats::TGenderAgeIncome& lhs, const NLab::TUserDataStats::TGenderAgeIncome& rhs) const {
        // Both tuples of references are created and consumed inside this single
        // expression, so the accessor results outlive the comparison.
        return std::forward_as_tuple(lhs.gender(), lhs.age(), lhs.income())
             < std::forward_as_tuple(rhs.gender(), rhs.age(), rhs.income());
    }
};

namespace NLab {
    /// Affinity traits for the representation whose tokens are keyed by a string.
    /// TUserDataStatsAggregator uses this type as its default TAffinitiesOptions.
    /// Only declarations are given here; the definitions are not visible in this header.
    struct TDefaultAffinitiesOptions {
        /// Value stored per token in TTokensMap.
        struct TWeightWithCount {
            float Weight = 0.0f;
            ui64 Count = 0;
        };

        // Protobuf-side types.
        using TProtoAffinities = TUserDataStats::TAffinitiveStats;
        using TProtoTokensStats = TUserDataStats::TTokensStats;
        using TTokenStats = TUserDataStats::TWeightedTokenStats;

        // In-memory accumulator types.
        using TTokensMap = THashMap<TString, TWeightWithCount>;
        using TTokenMapIter = TTokensMap::const_iterator;

        // Access to the affinities part of a TUserDataStats message.
        static bool HasAffinities(const TUserDataStats& userDataStats);
        static const TProtoAffinities& GetAffinities(const TUserDataStats& userDataStats);
        static TProtoAffinities* MutableAffinities(TUserDataStats& userDataStats);

        // Folding proto token stats into the map. The aggregator calls the first one
        // when Settings.AccumulateAffinities is set and the "Wo" variant otherwise.
        static void UpdateAffinitiesWithTokensStats(TTokensMap& tokensMap, const TProtoTokensStats& tokensStats);
        static void UpdateAffinitiesWithTokensStatsWoAggregation(TTokensMap& tokensMap, const TProtoTokensStats& tokensStats);

        // Per-entry helpers used when the aggregator writes its result:
        // GetCount feeds the MinSampleRatio check, GetScore ranks the tokens
        // (its dictionary parameter is unnamed in this declaration) and
        // FillTokensStats writes one map entry into an output token.
        static ui64 GetCount(TTokensMap::const_reference kv);
        static float GetScore(TTokensMap::const_reference kv, const NEncodedUserData::TIdToWeightedTokenDict*);
        static void FillTokensStats(TTokenStats* protoToken, TTokensMap::const_reference kv);
    };

    /// Affinity traits for the encoded representation: tokens are keyed by a ui32
    /// and the map stores a ui32 value per token.
    /// Only declarations are given here; the definitions are not visible in this header.
    struct TAffinitiesEncodedOptions {
        // Protobuf-side types.
        using TProtoAffinities = TUserDataStats::TAffinitiveStatsEncoded;
        using TProtoTokensStats = TUserDataStats::TTokensStatsEncoded;
        using TTokenStats = TUserDataStats::TTokenStatsEncoded;

        // In-memory accumulator types.
        using TTokensMap = THashMap<ui32, ui32>;
        using TTokenMapIter = TTokensMap::const_iterator;

        // Access to the affinities part of a TUserDataStats message.
        static bool HasAffinities(const TUserDataStats& userDataStats);
        static const TProtoAffinities& GetAffinities(const TUserDataStats& userDataStats);
        static TProtoAffinities* MutableAffinities(TUserDataStats& userDataStats);

        // Folding proto token stats into the map. The aggregator calls the first one
        // when Settings.AccumulateAffinities is set and the "Wo" variant otherwise.
        static void UpdateAffinitiesWithTokensStats(TTokensMap& tokensMap, const TProtoTokensStats& tokensStats);
        static void UpdateAffinitiesWithTokensStatsWoAggregation(TTokensMap& tokensMap, const TProtoTokensStats& tokensStats);

        // Per-entry helpers used when the aggregator writes its result:
        // GetCount feeds the MinSampleRatio check, GetScore ranks the tokens
        // (dict is whatever dictionary the caller of MergeInto supplied, possibly
        // nullptr) and FillTokensStats writes one map entry into an output token.
        static ui64 GetCount(TTokensMap::const_reference kv);
        static float GetScore(TTokensMap::const_reference kv, const NEncodedUserData::TIdToWeightedTokenDict* dict);
        static void FillTokensStats(TTokenStats* protoToken, TTokensMap::const_reference kv);
    };

    template<typename TAffinitiesOptions = TDefaultAffinitiesOptions>
    class TUserDataStatsAggregator {
    public:
        /// Tunables of the aggregator.
        struct TSettings {
            /// Upper bound on the number of tokens written per affinity kind.
            int MaxTokensCount = 1'000'000;
            /// A token is written only if its count divided by the users count
            /// reaches this ratio.
            double MinSampleRatio = 0.001;
            /// Chooses between the accumulating and the non-accumulating
            /// affinity update.
            bool AccumulateAffinities = true;
        };

        TUserDataStatsAggregator() = default;
        explicit TUserDataStatsAggregator(const TSettings& settings)
            : Settings(settings)
        {
            Filter.Options.SetCapacity(0);
        }

        /// Folds one stats record into the aggregate.
        ///
        /// The filter part of the record is always applied first. A record that has
        /// a filter and a zero total count contributes nothing else and is dropped
        /// right after the filter update.
        void UpdateWith(const TUserDataStats& userDataStats) {
            const bool filterApplied = UpdateFilter(userDataStats);
            if (filterApplied) {
                if (userDataStats.GetCounts().GetTotal() == 0) {
                    // Filter-only record: nothing more to aggregate.
                    return;
                }
            }

            // Identifiers are taken from the record while nothing has been counted yet;
            // this check must stay ahead of UpdateCounts.
            if (GetTotalCount() == 0) {
                Identifiers = userDataStats.GetIdentifiers();
            }

            UpdateCounts(userDataStats);
            UpdateAttributes(userDataStats);
            UpdateAffinities(userDataStats);

            if (userDataStats.HasDistributions()) {
                const auto& mainDistribution = userDataStats.GetDistributions().GetMain();
                FeatureDistribution.Aggregator.UpdateWith(mainDistribution);
            }

            // Segment info is copied only while the stored one has no entries.
            if (SegmentInfo.GetInfo().size() == 0) {
                SegmentInfo.CopyFrom(userDataStats.GetSegmentInfo());
            }
        }

        /// Writes the aggregated state into userDataStats.
        ///
        /// Distributions, attributes and affinities of the target are cleared and
        /// refilled; counts, segment info and identifiers are overwritten via CopyFrom;
        /// the stratum is delegated to the strata aggregator. The filter is written only
        /// when the stored filter capacity is positive, otherwise the target's filter
        /// field is left untouched.
        ///
        /// words, hosts and apps are optional dictionaries (nullptr allowed) that are
        /// forwarded to TAffinitiesOptions::GetScore for the matching token kind.
        void MergeInto(TUserDataStats& userDataStats, const NEncodedUserData::TIdToWeightedTokenDict* words = nullptr, const NEncodedUserData::TIdToWeightedTokenDict* hosts = nullptr,
            const NEncodedUserData::TIdToWeightedTokenDict* apps = nullptr) {
            // Distributions
            auto& outDistributions = *userDataStats.MutableDistributions();
            outDistributions.Clear();
            FeatureDistribution.Aggregator.MergeInto(*outDistributions.MutableMain());

            // Stratum
            StrataAggregator.MergeInto(userDataStats.MutableStratum());

            // Attributes: one output entry per key of every detailed counter.
            auto& outAttributes = *userDataStats.MutableAttributes();
            outAttributes.Clear();

            for (const auto& entry : DetailedCounts.Gender) {
                auto& out = *outAttributes.AddGender();
                out.SetGender(entry.first);
                out.SetCount(entry.second);
            }
            for (const auto& entry : DetailedCounts.Age) {
                auto& out = *outAttributes.AddAge();
                out.SetAge(entry.first);
                out.SetCount(entry.second);
            }
            for (const auto& entry : DetailedCounts.Device) {
                auto& out = *outAttributes.AddDevice();
                out.SetDevice(entry.first);
                out.SetCount(entry.second);
            }
            for (const auto& entry : DetailedCounts.Region) {
                auto& out = *outAttributes.AddRegion();
                out.SetRegion(entry.first);
                out.SetCount(entry.second);
            }
            for (const auto& entry : DetailedCounts.Income) {
                auto& out = *outAttributes.AddIncome();
                out.SetIncome(entry.first);
                out.SetCount(entry.second);
            }
            for (const auto& entry : DetailedCounts.GenderAgeIncome) {
                auto& out = *outAttributes.AddGenderAgeIncome();
                out.MutableGenderAgeIncome()->CopyFrom(entry.first);
                out.SetCount(entry.second);
            }

            // Affinities
            auto& outAffinities = *TAffinitiesOptions::MutableAffinities(userDataStats);
            outAffinities.Clear();
            FillTokensStats(*outAffinities.MutableWords(), Affinity.Words, words);
            FillTokensStats(*outAffinities.MutableHosts(), Affinity.Hosts, hosts);
            FillTokensStats(*outAffinities.MutableApps(), Affinity.Apps, apps);

            // Counts
            userDataStats.MutableCounts()->CopyFrom(Counts);

            // Filter: emitted only when a capacity has been recorded.
            if (Filter.Options.GetCapacity() > 0) {
                auto& outFilter = *userDataStats.MutableFilter();
                outFilter.MutableOptions()->CopyFrom(Filter.Options);
                TString serialized;
                TStringOutput serializedStream(serialized);
                Filter.Instance.Save(&serializedStream);
                outFilter.SetBloomFilter(serialized);
            }

            // Segment info
            userDataStats.MutableSegmentInfo()->CopyFrom(SegmentInfo);

            // Identifiers
            userDataStats.MutableIdentifiers()->CopyFrom(Identifiers);
        }

        /// Returns the Total field of the counts aggregated so far.
        ui64 GetTotalCount() const {
            const ui64 total = Counts.GetTotal();
            return total;
        }

    private:
        // Accumulated state of one affinity kind (words, hosts or apps).
        struct TAffinitiveStats {
            // In accumulating mode this is the sum of the users counts of the merged
            // records; in non-accumulating mode it is 1 when the map is non-empty, else 0.
            ui64 UsersCount = 0;
            // In accumulating mode this is the sum of the tokens counts of the merged
            // records; in non-accumulating mode it is the size of TokensMap.
            ui64 TokensCount = 0;
            // Per-token data; the key and value types come from TAffinitiesOptions.
            typename TAffinitiesOptions::TTokensMap TokensMap;
        };

        // Per-attribute counters: each map goes from an attribute value to the sum of
        // the counts reported for that value by the merged records.
        struct TDetailedCounts {
            THashMap<TGender, ui64> Gender;
            THashMap<TAge, ui64> Age;
            THashMap<TDevice, ui64> Device;
            THashMap<ui64, ui64> Region;
            THashMap<TIncome, ui64> Income;
            // Keyed by the whole (gender, age, income) message; presumably ordered through
            // the TLess specialization declared at the top of this header.
            TMap<TUserDataStats::TGenderAgeIncome, ui64> GenderAgeIncome;
        };

        // Wrapper around the aggregator that is fed with the Main distribution of every
        // record that has distributions.
        struct TFeatureDistribution {
            TNormalSufficientStatisticAggregator Aggregator;
        };

        // Accumulated affinity state, one TAffinitiveStats per token kind.
        struct TAffinity {
            TAffinitiveStats Words;
            TAffinitiveStats Hosts;
            TAffinitiveStats Apps;
        };

        // Aggregated Bloom filter together with the options it was built with.
        struct TFilter {
            // The filter itself; a bit count of 0 is treated as "not created yet".
            TBloomFilter Instance;
            // Options taken from the incoming records; a capacity of 0 is treated as
            // "no options recorded", and no filter is written on output in that case.
            TFilterOptions Options;
        };

        // Running sums of the Total, WithData and UniqYuid counters.
        TUserDataStats::TCounts Counts;
        // Per-attribute counters filled from the Attributes part of the records.
        TDetailedCounts DetailedCounts;

        // Aggregated Main distribution.
        TFeatureDistribution FeatureDistribution;
        // Aggregated words, hosts and apps affinities.
        TAffinity Affinity;

        // Aggregated Bloom filter and its options.
        TFilter Filter;
        // Identifiers taken from a record seen while the total count was still zero.
        TIdentifiers Identifiers;

        // Aggregates the Stratum part of the records that have attributes.
        TStrataAggregator StrataAggregator;
        // Segment info copied from a record while the stored one had no entries.
        TSegmentInfo SegmentInfo;

        // Settings given at construction (defaults otherwise).
        TSettings Settings;

        bool UpdateFilter(const TUserDataStats& userDataStats) {
            if (userDataStats.HasFilter()) {
                const auto& protoFilter = userDataStats.GetFilter();
                if (protoFilter.HasBloomFilter()) {
                    auto stringFilter = protoFilter.GetBloomFilter();
                    TBloomFilter currentFilter;
                    TStringInput istream(stringFilter);
                    currentFilter.Load(&istream);
                    Filter.Instance = Filter.Instance.GetBitCount() ? Filter.Instance.Union(currentFilter) : currentFilter;
                } else {
                    if (Filter.Instance.GetBitCount() == 0) {
                        Filter.Options = protoFilter.GetOptions();
                        TBloomFilter bloomFilter(Filter.Options.GetCapacity(), Filter.Options.GetErrorRate());
                        Filter.Instance = bloomFilter;
                    }
                    if (protoFilter.HasSingle()) {
                        Filter.Instance.Add(protoFilter.GetSingle());
                    }
                }

                if (Filter.Options.GetCapacity() == 0) {
                    Filter.Options = protoFilter.GetOptions();
                }

                return true;
            }

            return false;
        }
        /// Adds the Total, WithData and UniqYuid counters of the record to the running sums.
        void UpdateCounts(const TUserDataStats& userDataStats) {
            const auto& delta = userDataStats.GetCounts();

            const auto total = Counts.GetTotal() + delta.GetTotal();
            const auto withData = Counts.GetWithData() + delta.GetWithData();
            const auto uniqYuid = Counts.GetUniqYuid() + delta.GetUniqYuid();

            Counts.SetTotal(total);
            Counts.SetWithData(withData);
            Counts.SetUniqYuid(uniqYuid);
        }

        /// Adds the per-attribute counts of the record to the detailed counters and
        /// feeds the record's stratum to the strata aggregator.
        ///
        /// Records without attributes are ignored entirely; note that their stratum is
        /// not aggregated either.
        void UpdateAttributes(const TUserDataStats& userDataStats) {
            if (!userDataStats.HasAttributes()) {
                return;
            }

            const auto& incoming = userDataStats.GetAttributes();

            for (const auto& item : incoming.GetGender()) {
                DetailedCounts.Gender[item.GetGender()] += item.GetCount();
            }
            for (const auto& item : incoming.GetAge()) {
                DetailedCounts.Age[item.GetAge()] += item.GetCount();
            }
            for (const auto& item : incoming.GetDevice()) {
                DetailedCounts.Device[item.GetDevice()] += item.GetCount();
            }
            for (const auto& item : incoming.GetRegion()) {
                DetailedCounts.Region[item.GetRegion()] += item.GetCount();
            }
            for (const auto& item : incoming.GetIncome()) {
                DetailedCounts.Income[item.GetIncome()] += item.GetCount();
            }
            for (const auto& item : incoming.GetGenderAgeIncome()) {
                DetailedCounts.GenderAgeIncome[item.GetGenderAgeIncome()] += item.GetCount();
            }

            StrataAggregator.UpdateWith(userDataStats.GetStratum());
        }

        /// Folds the words, hosts and apps affinities of the record into the aggregate.
        ///
        /// Fails through Y_ENSURE with "Wrong affinity format" when the record has an
        /// affinities type set but TAffinitiesOptions does not report affinities for it.
        /// A record with no affinities type set is accepted and changes nothing.
        void UpdateAffinities(const TUserDataStats& userDataStats) {
            const bool hasAffinities = TAffinitiesOptions::HasAffinities(userDataStats);
            const bool affinityTypeNotSet = userDataStats.GetAffinitiesTypeCase() == TUserDataStats::AffinitiesTypeCase::AFFINITIESTYPE_NOT_SET;
            Y_ENSURE(hasAffinities || affinityTypeNotSet, "Wrong affinity format");

            if (!hasAffinities) {
                return;
            }

            const bool accumulate = Settings.AccumulateAffinities;
            const auto& incoming = TAffinitiesOptions::GetAffinities(userDataStats);

            if (incoming.HasWords()) {
                UpdateWithTokenStats(Affinity.Words, incoming.GetWords(), accumulate);
            }
            if (incoming.HasHosts()) {
                UpdateWithTokenStats(Affinity.Hosts, incoming.GetHosts(), accumulate);
            }
            if (incoming.HasApps()) {
                UpdateWithTokenStats(Affinity.Apps, incoming.GetApps(), accumulate);
            }
        }

        /// Dispatches to the accumulating update when accumulateStats is true and to the
        /// non-accumulating one otherwise.
        void UpdateWithTokenStats(TUserDataStatsAggregator::TAffinitiveStats& stats, const typename TAffinitiesOptions::TProtoTokensStats& tokensStats, bool accumulateStats) {
            if (!accumulateStats) {
                UpdateWithTokenStatsWoAccumulation(stats, tokensStats);
                return;
            }
            UpdateWithTokenStats(stats, tokensStats);
        }

        /// Accumulating update: lets TAffinitiesOptions fold the proto tokens into the
        /// map, then adds the proto's tokens and users counts to the running totals.
        void UpdateWithTokenStats(TUserDataStatsAggregator::TAffinitiveStats& stats, const typename TAffinitiesOptions::TProtoTokensStats& tokensStats) {
            auto& tokensMap = stats.TokensMap;
            TAffinitiesOptions::UpdateAffinitiesWithTokensStats(tokensMap, tokensStats);

            stats.TokensCount = stats.TokensCount + tokensStats.GetTokensCount();
            stats.UsersCount = stats.UsersCount + tokensStats.GetUsersCount();
        }

        /// Non-accumulating update: lets TAffinitiesOptions fold the proto tokens into the
        /// map, then derives the counters from the map itself. The tokens count becomes
        /// the number of map entries and the users count becomes 1 for a non-empty map,
        /// 0 for an empty one.
        void UpdateWithTokenStatsWoAccumulation(TUserDataStatsAggregator::TAffinitiveStats& stats, const typename TAffinitiesOptions::TProtoTokensStats& tokensStats) {
            TAffinitiesOptions::UpdateAffinitiesWithTokensStatsWoAggregation(stats.TokensMap, tokensStats);

            // Must be read after the map has been updated.
            const ui64 distinctTokens = stats.TokensMap.size();
            stats.TokensCount = distinctTokens;
            stats.UsersCount = distinctTokens == 0 ? 0 : 1;
        }

        /// Writes the aggregated tokens of one affinity kind into tokensStats.
        ///
        /// The token list and both counters of tokensStats are overwritten. A token is
        /// kept when its count divided by stats.UsersCount reaches
        /// Settings.MinSampleRatio. The kept tokens are ranked by
        /// TAffinitiesOptions::GetScore, highest first, and at most
        /// Settings.MaxTokensCount of them are emitted. A non-positive MaxTokensCount
        /// emits no tokens.
        ///
        /// dict is forwarded to GetScore unchanged and may be nullptr.
        void FillTokensStats(typename TAffinitiesOptions::TProtoTokensStats& tokensStats, const TUserDataStatsAggregator::TAffinitiveStats& stats, const NEncodedUserData::TIdToWeightedTokenDict* dict) {
            tokensStats.MutableToken()->Clear();
            tokensStats.SetTokensCount(stats.TokensCount);
            tokensStats.SetUsersCount(stats.UsersCount);

            // Score paired with a pointer to the map entry, so map entries are not copied.
            using TTokensVector = TVector<std::pair<float, typename TAffinitiesOptions::TTokensMap::const_pointer>>;
            TTokensVector tokens;
            tokens.reserve(stats.TokensMap.size());

            const ui64 usersCount = stats.UsersCount;
            for (const auto& kv : stats.TokensMap) {
                const ui64 count = TAffinitiesOptions::GetCount(kv);
                // With zero users the ratio is a division by zero. Spell out the outcome
                // the IEEE division used to give instead of relying on it: a positive
                // count (+inf) passes, a zero count (NaN) does not.
                const bool frequentEnough = usersCount == 0
                    ? count > 0
                    : static_cast<float>(count) / usersCount >= Settings.MinSampleRatio;
                if (frequentEnough) {
                    tokens.emplace_back(TAffinitiesOptions::GetScore(kv, dict), &kv);
                }
            }

            // Work in size_t and clamp a negative limit to 0: a negative value used to
            // produce an iterator before begin(), which is an invalid range for
            // partial_sort, and casting size() to int could truncate.
            const size_t limit = Settings.MaxTokensCount > 0 ? static_cast<size_t>(Settings.MaxTokensCount) : 0;
            const size_t size = std::min(tokens.size(), limit);
            const auto topEnd = tokens.begin() + static_cast<typename TTokensVector::difference_type>(size);

            // Equal scores are ordered by the entry pointers (the second pair member), so
            // the relative order of ties is not stable across runs.
            std::partial_sort(tokens.begin(), topEnd, tokens.end(), std::greater<typename TTokensVector::value_type>());
            for (auto it = tokens.begin(); it != topEnd; ++it) {
                TAffinitiesOptions::FillTokensStats(tokensStats.AddToken(), *it->second);
            }
        }
    };

    // Aggregator instantiated with the encoded affinity traits (ui32-keyed tokens).
    using TEncodedUserDataStatsAggregator = TUserDataStatsAggregator<TAffinitiesEncodedOptions>;
}
