#pragma once

#include "utils.h"

#include <crypta/lab/proto/lookalike.pb.h>
#include <crypta/lib/native/vectors/vectors.h>
#include <crypta/lib/python/native_yt/cpp/registrar.h>
#include <crypta/lib/python/native_yt/cpp/proto.h>
#include <crypta/lookalike/lib/native/common.h>
#include <crypta/lookalike/lib/native/segment_embedding_model.h>
#include <crypta/lookalike/lib/native/user_embedding_model.h>
#include <crypta/lookalike/proto/user_embedding.pb.h>

#include <mapreduce/yt/interface/client.h>

#include <library/cpp/bloom_filter/bloomfilter.h>
#include <util/datetime/base.h>
#include <util/generic/vector.h>
#include <util/generic/queue.h>
#include <util/stream/str.h>
#include <util/string/cast.h>

#include <cmath>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <list>
#include <sstream>
#include <string>
#include <utility>

// NOTE(review): `using namespace` in a header pulls NYT and NLab names into every includer;
// the declarations below rely on these unqualified names (e.g. IMapper, TNode).
using namespace NYT;
using namespace NLab;
// Short aliases for the NCrypta::NLookalike embedding types used by the mappers below.
using TEmbedding = NCrypta::NLookalike::TEmbedding;
using TUserEmbedding = NCrypta::NLookalike::TUserEmbedding;

// Decodes a string-encoded vector into floats (TLookalikeMapper::TSegment feeds it
// meta.vector()). Defined out of line; the encoding format is not visible here.
TVector<float> UnpackVector(const TString& stringVector);

/// Mapper that scores TUserEmbedding rows against the configured lookalike segments and
/// writes TLookalikeIntermediateScoredRecord rows.
///
/// Configuration is a TLookalikeMapping proto held in `State` and serialized through
/// Save()/Load(). Start(), Do() and the private helpers are defined out of line.
class TPredictMapper: public IMapper<TTableReader<TUserEmbedding>, TTableWriter<TLookalikeIntermediateScoredRecord>> {
public:
    TPredictMapper()
        : State()
        , Segments()
        , SegmentEmbeddingModel(nullptr)
        , Outputs()
    {
    }
    /// Builds the mapper from a serialized TLookalikeMapping state.
    TPredictMapper(const TBuffer& buffer)
        : State(buffer)
        , Segments()
        , SegmentEmbeddingModel(nullptr)
        , Outputs()
    {
    }

    /// Per-segment scoring state; all methods are defined out of line.
    class TSegment {
    public:
        /// Counters kept per ui64 key in RegionCounts / DeviceProbabilities.
        /// Defaults: Segment = 0, Global = 0, Conditional = 1.
        struct TCountStats {
            double Segment = 0;
            double Global = 0;
            double Conditional = 1;
        };

        /// @param segmentId           id of the segment being scored.
        /// @param meta                segment metadata from the TLookalikeMapping state.
        /// @param globalUserDataStats global user statistics (non-owning pointer).
        /// @param maxFilterErrorRate  stored in MaxFilterErrorRate.
        /// @param segmentDssmVector   the segment's embedding.
        TSegment(TString segmentId, const TLookalikeMapping::TSegmentMeta& meta,
                 const TUserDataStats* globalUserDataStats, double maxFilterErrorRate,
                 const TEmbedding& segmentDssmVector);

        void YieldTop(TTableWriter<TLookalikeIntermediateScoredRecord>* output);

        void Add(const TUserEmbedding& userEmbedding, TLookalikeIntermediateScoredRecord* out, const TEmbedding& dssmVector);

    private:
        double ComputeScore(const TEmbedding& rowDssmVector);

        TString GetStorageIndex(const TUserData::TAttributes& attributes, bool yetFiltered = true);

        void InitRegionStats(size_t sizeRegions = 15);

        void InitDeviceStats();

        TUserData::TAttributes GetFilterAttributes(const TUserData::TAttributes& userDataAttributes);

        double GetProbability(const TUserData::TAttributes& attributes);

    private:
        // NOTE: declaration order is kept as-is; the out-of-line constructor initializes in this order.
        double MaxFilterErrorRate;
        double GlobalTotal;
        ui64 OutputSize;
        TString SegmentId;
        TString ExternalId;
        TEmbedding DssmVector;
        const TLookalikeOptions* Options;
        const TUserDataStats::TAttributesStats* SegmentAttributes;
        const TUserDataStats::TAttributesStats* GlobalAttributes;
        THashMap<TString, TOutputStorage<TLookalikeIntermediateScoredRecord>> OutputStorages;
        THashMap<ui64, double> OutputSizes;
        THashMap<ui64, TCountStats> RegionCounts;
        THashMap<ui64, TCountStats> DeviceProbabilities;
        double CurrentCount = 0;
        double TopProbability = 1.;
        double DeviceConditionalSum = 1.;
    };

    void Start(TWriter* writer) override;

    void Do(TTableReader<TUserEmbedding>* input, TTableWriter<TLookalikeIntermediateScoredRecord>* output) override;

    void Save(IOutputStream& output) const override {
        State.Save(output);
    }
    void Load(IInputStream& input) override {
        State.Load(input);
    }

private:
    void InitSegments();

    TLookalikeIntermediateScoredRecord* AddOutput(TUserEmbedding& userEmbedding);

    void YieldTop(TTableWriter<TLookalikeIntermediateScoredRecord>* output);

private:
    NNativeYT::TProtoState<TLookalikeMapping> State;
    TVector<TSegment> Segments;
    THolder<NCrypta::NLookalike::TSegmentEmbeddingModel> SegmentEmbeddingModel;
    std::list<TLookalikeIntermediateScoredRecord> Outputs;
    // Compile-time constant shared by all instances (a non-static const member would cost
    // per-object storage and disable assignment). Presumably caps buffered Outputs — see Do().
    static constexpr int MAX_COUNT = 100000;
};

/// Reducer that produces the final TLookalikeOutput rows of each segment.
///
/// Declared inputs (TInputs) are TUserDataStats rows, which may carry a Bloom filter of users
/// that must not be emitted, and TLookalikeIntermediateScoredRecord rows. Configuration is a
/// TLookalikeReducing proto held in `State`. Do() is defined out of line.
class TPredictReducer: public IReducer<TTableReader< ::google::protobuf::Message>, TTableWriter<TLookalikeOutput>> {
public:
    TPredictReducer()
        : State()
    {
    }
    TPredictReducer(const TBuffer& buffer)
        : State(buffer)
    {
    }

    using TInputs = std::tuple<TUserDataStats, TLookalikeIntermediateScoredRecord>;
    using TOutputs = std::tuple<TLookalikeOutput>;

    /// Output bookkeeping (quota, exclusion filter, reserve, statistics) for one segment.
    class TSegment {
    public:
        /// Counts emitted and filtered records and reports them as YT custom statistics.
        class TStatistics {
        public:
            TStatistics() = default;
            TStatistics(ui64 inputSize, ui64 outputSize, double skipRate) {
                Init(inputSize, outputSize, skipRate);
            }

            /// @param inputSize  segment input size from the reducing config.
            /// @param outputSize number of records the segment may emit.
            /// @param skipRate   sampling rate; TestInputSize = inputSize * skipRate (truncated)
            ///                   is the denominator of the reported rates.
            void Init(ui64 inputSize, ui64 outputSize, double skipRate) {
                InputSize = inputSize;
                OutputSize = outputSize;
                TestInputSize = static_cast<ui64>(static_cast<double>(inputSize) * skipRate);
            }

            void AddToFilter(ui64 count = 1) {
                CurrentFilteredRecordsCount += count;
            }

            /// Counts `count` emitted records and snapshots the filtered count at the moment the
            /// emitted count first reaches InputSize.
            void AddToOutput(ui64 count = 1) {
                if (CurrentOutputtedRecordsCount < InputSize && CurrentOutputtedRecordsCount + count >= InputSize) {
                    FilteredRecordsCountByInputSizeTime = CurrentFilteredRecordsCount;
                }
                CurrentOutputtedRecordsCount += count;
            }

            /// True once the output quota is met. `>=` so an AddToOutput(count > 1) cannot step past it.
            bool HasOutputEnoughRecords() const {
                return CurrentOutputtedRecordsCount >= OutputSize;
            }

            /// Emits the metrics; skipped when InputSize is 0 or larger than OutputSize.
            void Write() const {
                if (!InputSize || InputSize > OutputSize) {
                    return;
                }
                ::WriteCustomStatistics(
                    TNode()
                        ("Metrics", TNode()
                            ("InputSize", TestInputSize)
                            ("OutputSize", OutputSize)
                            ("FilteredRecords", TNode()
                                ("ByInputSize", TNode()
                                    ("Count", FilteredRecordsCountByInputSizeTime)
                                    ("Rate", ComputePercent(FilteredRecordsCountByInputSizeTime, TestInputSize)))
                                ("ByOutputSize", TNode()
                                    ("Count", CurrentFilteredRecordsCount)
                                    ("Rate", ComputePercent(CurrentFilteredRecordsCount, TestInputSize)))
                )));
            }

        private:
            ui64 InputSize = 0;
            ui64 TestInputSize = 0;
            ui64 OutputSize = 0;
            ui64 CurrentFilteredRecordsCount = 0;
            ui64 CurrentOutputtedRecordsCount = 0;
            ui64 FilteredRecordsCountByInputSizeTime = 0;

            /// numerator / denominator, or 1 when the denominator is 0.
            static double ComputeFraction(ui64 numerator, ui64 denominator) {
                if (!denominator) {
                    return 1.;
                }
                return static_cast<double>(numerator) / static_cast<double>(denominator);
            }

            /// Truncated integer percentage of ComputeFraction.
            static ui64 ComputePercent(ui64 numerator, ui64 denominator) {
                return static_cast<ui64>(ComputeFraction(numerator, denominator) * 100.);
            }
        };

        /// FIFO of not-filtered candidates. Each record remembers how many filtered candidates
        /// were pushed between it and the previously reserved record.
        class TReserve {
        public:
            struct TReserveRecord {
                TString yuid;
                double score = 0.;
                int filteredCount = 0;
                TReserveRecord() = default;
                TReserveRecord(TString yuid, double score, int filteredCount)
                    : yuid(std::move(yuid))
                    , score(score)
                    , filteredCount(filteredCount)
                {
                }
            };
            TReserve()
                : Reserve()
            {
            }

            /// Filtered candidates are only counted; others are queued together with the number of
            /// filtered candidates seen since the last queued one. Returns true if queued.
            bool Push(TString yuid, double score, bool filtered) {
                if (filtered) {
                    ++CurrentFilteredCount;
                    return false;
                }
                Reserve.push(TReserveRecord(std::move(yuid), score, CurrentFilteredCount));
                CurrentFilteredCount = 0;
                return true;
            }

            /// Removes and returns the oldest record. Precondition: !IsEmpty().
            TReserveRecord Pop() {
                TReserveRecord record = std::move(Reserve.front());
                Reserve.pop();
                return record;
            }

            /// Drops all queued records and the pending filtered count.
            void Clear() {
                Reserve.clear();
                CurrentFilteredCount = 0;
            }

            bool IsEmpty() const {
                return Reserve.empty();
            }

            const TQueue<TReserveRecord>& GetQueue() const {
                return Reserve;
            }

            ui64 Size() const {
                return Reserve.size();
            }

        private:
            TQueue<TReserveRecord> Reserve;
            // Filtered candidates pushed since the last queued record.
            int CurrentFilteredCount = 0;
        };

        TSegment()
            : Filter()
            , SegmentId()
            , ExternalId()
            , MaxCoverage()
            , Reserve()
            , Statistics()
        {
        }

        /// Loads the segment's config (counts, permanent id) from `State` and its optional
        /// exclusion Bloom filter from `segmentStats`.
        void Init(const TUserDataStats& segmentStats, NNativeYT::TProtoState<TLookalikeReducing>& State) {
            // Bind by reference: copying the whole segments map for every segment is expensive.
            const auto& segments = State->GetSegments();
            const auto skipRate = State->GetSamplingOptions().GetSkipRate();
            SegmentId = segmentStats.GetGroupID();
            const auto found = segments.find(SegmentId);
            using TSegmentInfo = decltype(found->second);
            // An unknown segment falls back to the proto default instance.
            const TSegmentInfo& segment = found != segments.end() ? found->second : TSegmentInfo::default_instance();
            Statistics.Init(segment.GetCounts().GetInput(), segment.GetCounts().GetOutput(), skipRate);
            MaxCoverage = segment.GetCounts().GetMaxCoverage();
            ExternalId = ToString(segment.GetPermanentId());
            WithoutFilter = !(segmentStats.HasFilter() && segmentStats.GetFilter().HasBloomFilter());
            if (!WithoutFilter) {
                TStringInput stringFilter(segmentStats.GetFilter().GetBloomFilter());
                Filter.Load(&stringFilter);
            }
            TestSample = TFilterIdentifier(skipRate);
        }

        /// True if `yandexuid` hits the segment's Bloom filter; always false without a filter.
        bool Has(const TString& yandexuid) {
            if (WithoutFilter) {
                return false;
            }
            return Filter.Has(yandexuid);
        }

        bool IsReadyToFinish() const {
            return Statistics.HasOutputEnoughRecords();
        }

        /// Queues a candidate unless the reserve already holds `maxReserveSize` records.
        /// Returns true if the candidate was queued.
        bool AddToReserve(const TString& yandexuid, double score, double maxReserveSize) {
            if (Reserve.Size() >= maxReserveSize) {
                return false;
            }
            return Reserve.Push(yandexuid, score, Has(yandexuid) && TestSample.Filter(yandexuid));
        }

        void WriteCustomStatistics() {
            Statistics.Write();
        }

        /// Passes up to `count` reserved records to Yield (which may still skip them), adds their
        /// filtered counts to the statistics, then discards the rest of the reserve.
        void YieldFromReserve(double count, TTableWriter<TLookalikeOutput>* output) {
            for (int i = 0; i < count && !Reserve.IsEmpty(); ++i) {
                auto record = Reserve.Pop();
                Yield(record.score, record.yuid, output);
                Statistics.AddToFilter(record.filteredCount);
            }
            Reserve.Clear();
        }

        /// Writes one output row unless the quota is met or the user hits the filter (that skip is
        /// counted when TestSample.Filter(yandexuid) holds). Returns true if a row was written.
        bool Yield(double score, const TString& yandexuid,
                   TTableWriter<TLookalikeOutput>* output) {
            if (Statistics.HasOutputEnoughRecords()) {
                return false;
            }
            if (Has(yandexuid)) {
                if (TestSample.Filter(yandexuid)) {
                    Statistics.AddToFilter();
                }
                return false;
            }
            TLookalikeOutput out;
            out.SetExternalId(ExternalId);
            out.SetSegmentId(SegmentId);
            out.SetScore(score);
            out.SetYandexuid(yandexuid);
            output->AddRow(out);
            Statistics.AddToOutput();
            return true;
        }

    private:
        TBloomFilter Filter;
        TString SegmentId;
        TString ExternalId;
        ui64 MaxCoverage;
        bool WithoutFilter = false;
        TReserve Reserve;
        TStatistics Statistics;
        TFilterIdentifier TestSample{};
    };

    void Do(TReader* input, TTableWriter<TLookalikeOutput>* output) override;

    void Save(IOutputStream& output) const override {
        State.Save(output);
    }
    void Load(IInputStream& input) override {
        State.Load(input);
    }

private:
    NNativeYT::TProtoState<TLookalikeReducing> State;
};

/// Mapper scoring TNode rows against lookalike segments configured by a `Mapping` proto
/// held in `State`. Do() and the private helpers are defined out of line.
class TLookalikeMapper: public IMapper<TTableReader<TNode>, TTableWriter<TNode>> {
public:
    TLookalikeMapper()
        : State()
    {
    }
    TLookalikeMapper(const TBuffer& buffer)
        : State(buffer)
    {
    }

    /// Per-segment scoring state built from the segment's meta message.
    class TSegment {
    public:
        /// @param segmentId id of the segment.
        /// @param meta      message exposing vector(), parameters().probability() and factors()
        ///                  (presumably a segment entry of the `Mapping` state — confirm).
        template <class Meta>
        TSegment(TString segmentId, const Meta& meta)
            : SegmentId(std::move(segmentId))
        {
            // Lower bound for every stored factor probability; also the value of each factor's
            // UNDEFINED bucket.
            const double minProbability = 1e-4;
            // Read meta fields by reference: copying the factors map would be a needless deep copy.
            Vector = UnpackVector(meta.vector());
            TopProbability = meta.parameters().probability();
            for (const auto& factor : meta.factors()) {
                auto& probabilities = FactorProbabilities[factor.first];
                for (const auto& probability : factor.second.probabilities()) {
                    probabilities[probability.first] = std::max(minProbability, probability.second);
                }
                probabilities[UNDEFINED] = minProbability;
            }
        }

        void YieldTop(TTableWriter<TNode>* output);

        void Add(const TNode& row, TNode* out, const TVector<float>& rowVector);

    private:
        TString UNDEFINED = "-";
        const TString SENTINEL = "$";
        double TopProbability;
        TString SegmentId;
        TVector<float> Vector;
        THashMap<TString, THashMap<TString, double>> FactorProbabilities;
        THashMap<TString, TOutputStorage<TNode>> OutputStorages;

        double ComputeScore(const TVector<float>& rowVector);

        TString GetStorageIndex(const THashMap<TString, TString>& factors);

        THashMap<TString, TString> GetFactors(const TNode& row);

        double GetProbability(const THashMap<TString,
                                             TString>& factors,
                              double minValue = 1e-8);
    };

    void Do(TTableReader<TNode>* input, TTableWriter<TNode>* output) override;

    void Save(IOutputStream& output) const override {
        State.Save(output);
    }
    void Load(IInputStream& input) override {
        State.Load(input);
    }

private:
    NNativeYT::TProtoState<Mapping> State;
    TVector<TSegment> Segments;
    std::list<TNode> Outputs;
    time_t OldestTimestamp = 0;
    // Compile-time constant shared by all instances rather than a per-object const member.
    static constexpr int MAX_COUNT = 100000;

    void InitSegments();

    TNode* AddOutput(TNode& row);

    void YieldTop(TTableWriter<TNode>* output);
};

/// YT reducer over TNode rows whose configuration is a `Reducing` proto kept in `State`
/// and round-tripped through Save()/Load(). The reduce logic is the out-of-line Do().
class TLookalikeReducer: public IReducer<TTableReader<TNode>, TTableWriter<TNode>> {
public:
    /// Default construction value-initializes `State` via its in-class initializer.
    TLookalikeReducer() = default;

    /// Builds `State` from `buffer` using TProtoState's buffer constructor.
    TLookalikeReducer(const TBuffer& buffer)
        : State(buffer)
    {
    }

    void Do(TTableReader<TNode>* input, TTableWriter<TNode>* output) override;

    void Save(IOutputStream& output) const override { State.Save(output); }

    void Load(IInputStream& input) override { State.Load(input); }

private:
    NNativeYT::TProtoState<Reducing> State{};
};

/// YT reducer over TNode rows configured by a `Reducing` proto kept in `State`; the join
/// logic is the out-of-line Do().
class TLookalikeJoiner: public IReducer<TTableReader<TNode>, TTableWriter<TNode>> {
public:
    TLookalikeJoiner()
        : State()
    {
    }
    /// Builds `State` from `buffer` using TProtoState's buffer constructor.
    TLookalikeJoiner(const TBuffer& buffer)
        : State(buffer)
    {
    }

    // `override` makes the compiler verify these match the base-class virtuals, as the other
    // job classes in this header do; a signature drift would otherwise silently stop overriding.
    void Do(TTableReader<TNode>* input, TTableWriter<TNode>* output) override;

    void Save(IOutputStream& output) const override {
        State.Save(output);
    }
    void Load(IInputStream& input) override {
        State.Load(input);
    }

private:
    NNativeYT::TProtoState<Reducing> State;
};
