#include "stats.h"

#include <crypta/lib/native/yt/utils/helpers.h>

template <>
struct TLess<TTotalHistogram::TExternalStep> {
    bool operator()(const TTotalHistogram::TExternalStep& first, const TTotalHistogram::TExternalStep& second) const {
        return first.GetValue() < second.GetValue();
    }
};

template <>
struct TLess<TTotalHistogram::THistogram::TStep> {
    // Orders duration steps of a nested histogram by their Value; used by ClosedLowerBound.
    // The call operator is const (like the TExternalStep specialization above) so the comparator
    // can be invoked through const references / const copies, as standard algorithms may do.
    bool operator()(const TTotalHistogram::THistogram::TStep& first, const TTotalHistogram::THistogram::TStep& second) const {
        return first.GetValue() < second.GetValue();
    }
};

namespace NEdgeStats {
    // Unit conversion: all timestamps in this file are in seconds, durations are in whole days.
    static constexpr i64 SECONDS_PER_DAY = 24 * 60 * 60;
    // Per-day decay base for GetDefaultSurvivalFunctionValue (used when a type has no histogram).
    static constexpr double DEFAULT_DECAY_RATE = 0.75;
    // Floor for the compatibility / days-count PDF factors.
    static constexpr double DEFAULT_COMPATIBILITY_PR = 0.001;
    // Score returned for edge types whose histogram exists but is not significant.
    static constexpr double DEFAULT_SURVIVAL_PROBABILITY_FOR_TYPES_WITHOUT_DATES_INFO = 1e-4;
    // Additive floor / multiplier cap applied to "old but good" edges (high compatibility or many dates).
    static constexpr double MIN_SURVIVAL_SCORE_FOR_OLD_BUT_GOOD_EDGE = 0.01;
    static constexpr double MAX_RATE_FOR_SURVIVAL_SCORE_FOR_OLD_BUT_GOOD_EDGE = 100;
    // Days of silence tolerated before the duration starts to count against an edge.
    static constexpr double SURVIVAL_DURATION_ACCEPTABLE_DELAY = 14;

    // Gap (duration - daysCount, in days) beyond which the survival value collapses to EPS.
    static constexpr i64 MAX_DURATION = 120;
    // Thresholds that mark an edge as "old but good".
    static constexpr i64 TOO_MUCH_COUNT_FOR_BAD_SCORE = 11;
    static constexpr double TOO_MUCH_COMPATIBILITY_FOR_BAD_SCORE = 0.9;
    // Smallest non-zero probability used as a floor throughout.
    static constexpr double EPS = 1e-9;

    namespace NEdgeFields {
        // Column names read from input edge rows.
        // DATES is expected to be a list of date strings parsed by ParseISO8601DateTime (see
        // GetSortedTimestamps); the other four columns form the edge type key (see GetEdgeType).
        static const TString DATES("dates");
        static const TString SOURCE_TYPE("sourceType");
        static const TString LOG_SOURCE("logSource");
        static const TString ID1_TYPE("id1Type");
        static const TString ID2_TYPE("id2Type");

        // Packs the four edge-type components into a TEdgeType proto and returns its serialized
        // bytes. The result is used as an opaque grouping key (e.g. the MapReduce "EdgeType" column).
        TString GetEdgeType(const TString& sourceType, const TString& logSource, const TString& id1Type, const TString& id2Type) {
            TEdgeType proto;
            proto.SetSourceType(sourceType);
            proto.SetLogSource(logSource);
            proto.SetId1Type(id1Type);
            proto.SetId2Type(id2Type);

            TString serialized;
            Y_PROTOBUF_SUPPRESS_NODISCARD proto.SerializeToString(&serialized);
            return serialized;
        }

        // Builds the serialized edge-type key from a raw edge row.
        // Returns an empty string when any of the four key columns is absent, which callers treat
        // as "skip this row".
        TString GetEdgeType(const NYT::TNode& row) {
            const bool hasAllKeyColumns =
                row.HasKey(SOURCE_TYPE) &&
                row.HasKey(LOG_SOURCE) &&
                row.HasKey(ID1_TYPE) &&
                row.HasKey(ID2_TYPE);
            if (!hasAllKeyColumns) {
                return "";
            }

            const auto& sourceType = row[SOURCE_TYPE].AsString();
            const auto& logSource = row[LOG_SOURCE].AsString();
            const auto& id1Type = row[ID1_TYPE].AsString();
            const auto& id2Type = row[ID2_TYPE].AsString();
            return GetEdgeType(sourceType, logSource, id1Type, id2Type);
        }

        // Converts a date list element to the string passed to the ISO-8601 parser.
        TString ToString(const NYT::TNode& node) {
            return node.AsString();
        }

        // Parses every element of `dates` as an ISO-8601 date-time and returns the resulting unix
        // timestamps (seconds), shifted back by 4 hours and sorted ascending.
        // Elements that fail to parse are skipped: previously the return value of the parser was
        // ignored and an uninitialized value was appended instead.
        // NOTE(review): the reason for the fixed 4-hour shift is not visible here (timezone
        // normalization?) -- confirm before changing it.
        template <typename Dates>
        TVector<i64> GetSortedTimestamps(const Dates& dates) {
            TVector<i64> timestamps;
            timestamps.reserve(dates.size());
            for (const auto& date : dates) {
                i64 timestamp = 0;
                if (!ParseISO8601DateTime(ToString(date).data(), timestamp)) {
                    continue;
                }
                timestamps.push_back(timestamp - static_cast<i64>(TDuration::Hours(4).Seconds()));
            }
            std::sort(timestamps.begin(), timestamps.end());
            return timestamps;
        }

        // Reads the DATES column of an edge row. Returns an empty vector when the column is
        // missing, null, or not a list. Defined after the template so the call below resolves
        // to the template regardless of what the header declares.
        TVector<i64> GetSortedTimestamps(const NYT::TNode& row) {
            const auto& dates = row[NEdgeFields::DATES];
            if (dates.IsNull() || !dates.IsList()) {
                return {};
            }
            return GetSortedTimestamps(dates.AsList());
        }

    }

    // Lower bound of `value` within the non-empty range [first, last), clamped to the range's last
    // element: if every element is less than `value`, the iterator to the last element is returned
    // instead of `last`. Ordering is TLess<Value>.
    // Throws yexception("Array is empty") when the range is empty.
    // Requires random-access iterators (relational comparison of iterators).
    template <typename ForwardIterator, typename Value>
    ForwardIterator ClosedLowerBound(ForwardIterator first, ForwardIterator last, const Value& value) {
        if (!(first < last)) {
            ythrow yexception() << "Array is empty";
        }
        const ForwardIterator lastElement = std::prev(last);
        return LowerBound(first, lastElement, value, TLess<Value>());
    }

    // Fallback survival value for edge types with no collected histogram.
    // - No dates at all → EPS.
    // - duration <= daysCount → 1.
    // - Otherwise DEFAULT_DECAY_RATE^(duration - daysCount), collapsing to EPS once the exponent
    //   exceeds 60.
    double GetDefaultSurvivalFunctionValue(i64 daysCount, i64 duration) {
        if (daysCount == 0) {
            return EPS;
        }
        const i64 excessDays = duration - daysCount;
        if (excessDays <= 0) {
            return 1.;
        }
        if (excessDays > 60) {
            return EPS;
        }
        return std::pow(DEFAULT_DECAY_RATE, excessDays);
    }

    // Looks up the survival-function (SF) value for an edge with `daysCount` dates and a gap of
    // `duration` days, using the edge type's histogram.
    // Short-circuits:
    // - non-significant but non-empty histogram → EPS;
    // - daysCount == 0 → EPS * 1.1;
    // - non-zero duration exceeding daysCount by more than MAX_DURATION → EPS.
    // Otherwise it picks the days-count bucket and then the duration step via ClosedLowerBound
    // (first step with Value >= key, clamped to the last step). It throws if either list is empty.
    double GetSurvivalFunctionValue(i64 daysCount, i64 duration, const TTotalHistogram& totalHistogram) {
        if (totalHistogram.GetTotalCount() && !totalHistogram.GetIsSignificant()) {
            return EPS;
        }
        if (daysCount == 0) {
            return EPS * 1.1;
        }
        if (duration != 0 && duration - daysCount > MAX_DURATION) {
            return EPS;
        }

        TTotalHistogram::TExternalStep daysCountKey;
        daysCountKey.SetValue(daysCount);
        const auto& externalSteps = totalHistogram.GetExternalStep();
        const auto bucket = ClosedLowerBound(externalSteps.begin(), externalSteps.end(), daysCountKey);

        TTotalHistogram::THistogram::TStep durationKey;
        durationKey.SetValue(duration);
        const auto& durationSteps = bucket->GetHistogram().GetStep();
        return ClosedLowerBound(durationSteps.begin(), durationSteps.end(), durationKey)->GetSF();
    }

    double ComputeCompatibilityDatesWithHistogram(const TVector<i64>& timestamps, const TTotalHistogram& totalHistogram) {
        double compatibilityPr = DEFAULT_COMPATIBILITY_PR;
        double secondCompatibilityPr = DEFAULT_COMPATIBILITY_PR;
        if (timestamps.size() > 1) {
            i64 previousTimestamp = 0;
            i64 count = 0;
            for (auto timestamp : timestamps) {
                if (previousTimestamp) {
                    auto pr = GetSurvivalFunctionValue(count, (timestamp - previousTimestamp) / SECONDS_PER_DAY - 1, totalHistogram);
                    if (pr > compatibilityPr) {
                        secondCompatibilityPr = compatibilityPr;
                        compatibilityPr = pr;
                    } else if (pr > secondCompatibilityPr) {
                        secondCompatibilityPr = pr;
                    }
                    if (secondCompatibilityPr > 1. - EPS) {
                        break;
                    }
                }
                count++;
                previousTimestamp = timestamp;
            }
        }
        return secondCompatibilityPr;
    }

    // Cumulative share of users whose days-count bucket Value is <= timestamps.size().
    // Each share is Count / TotalCount. The result is floored at DEFAULT_COMPATIBILITY_PR.
    // Buckets are assumed to be ordered by Value: iteration stops at the first larger one.
    double ComputePDFDaysCount(const TVector<i64>& timestamps, const TTotalHistogram& totalHistogram) {
        const double observedDays = static_cast<double>(timestamps.size());
        const auto totalCount = static_cast<double>(totalHistogram.GetTotalCount());

        double cumulative = 0;
        if (totalCount) {
            for (const auto& externalStep : totalHistogram.GetExternalStep()) {
                if (externalStep.GetValue() > observedDays) {
                    break;
                }
                cumulative += static_cast<double>(externalStep.GetCount()) / totalCount;
            }
        }
        return Max(DEFAULT_COMPATIBILITY_PR, cumulative);
    }

    // Scores how likely an edge observed on `timestamps` (sorted ascending, seconds) is still alive
    // at `realCurrentTimestamp`, given its edge type's `totalHistogram`.
    //
    // Steps:
    // - "Now" is clamped to the histogram's LastObservedTimestamp when that is set.
    // - If the last date is at least one whole day after the clamped now:
    //     - and the real clock is past the clamp, the real gap is used, capped at
    //       SURVIVAL_DURATION_ACCEPTABLE_DELAY + 1 days;
    //     - otherwise 1 is returned.
    // - Non-significant non-empty histogram → DEFAULT_SURVIVAL_PROBABILITY_FOR_TYPES_WITHOUT_DATES_INFO.
    // - Empty histogram → GetDefaultSurvivalFunctionValue.
    // - Otherwise the result is (histogram SF, possibly boosted for "old but good" edges)
    //   * ComputeCompatibilityDatesWithHistogram * ComputePDFDaysCount.
    double GetSurvivalFunctionValue(i64 realCurrentTimestamp, const TVector<i64>& timestamps, const TTotalHistogram& totalHistogram) {
        i64 currentTimestamp = realCurrentTimestamp;
        if (totalHistogram.GetLastObservedTimestamp()) {
            currentTimestamp = Min(currentTimestamp, static_cast<i64>(totalHistogram.GetLastObservedTimestamp()));
        }

        const i64 daysCount = static_cast<i64>(timestamps.size());
        i64 duration = 0;
        if (!timestamps.empty()) {
            // Whole days since the last date. Integer division truncates toward zero, so only a
            // last date >= 1 full day after `currentTimestamp` yields a negative duration.
            duration = (currentTimestamp - timestamps.back()) / SECONDS_PER_DAY;
            if (duration < 0) {
                if (realCurrentTimestamp > currentTimestamp) {
                    // Fall back to the real clock; the fractional day count is truncated explicitly.
                    const double realDurationDays = static_cast<double>(realCurrentTimestamp - timestamps.back()) / SECONDS_PER_DAY;
                    duration = static_cast<i64>(Min(SURVIVAL_DURATION_ACCEPTABLE_DELAY + 1., realDurationDays));
                } else {
                    return 1.;
                }
            }
        }

        if (totalHistogram.GetTotalCount() && !totalHistogram.GetIsSignificant()) {
            return DEFAULT_SURVIVAL_PROBABILITY_FOR_TYPES_WITHOUT_DATES_INFO;
        }
        if (!totalHistogram.GetTotalCount()) {
            return GetDefaultSurvivalFunctionValue(daysCount, duration);
        }

        const double compatibility = ComputeCompatibilityDatesWithHistogram(timestamps, totalHistogram);

        // Forgive the first (SURVIVAL_DURATION_ACCEPTABLE_DELAY - 1) days of silence. Plain integer
        // arithmetic replaces the former MAX macro, which mixed int/double and narrowed implicitly;
        // the constant is integral (14), so the result is unchanged.
        const i64 acceptableDelayDays = static_cast<i64>(SURVIVAL_DURATION_ACCEPTABLE_DELAY);
        duration = std::max<i64>(0, duration - acceptableDelayDays + 1);

        double survivalScore = GetSurvivalFunctionValue(daysCount, duration, totalHistogram);

        // "Old but good" edges: very compatible or with many dates get a boosted, floored score.
        if (compatibility > TOO_MUCH_COMPATIBILITY_FOR_BAD_SCORE || daysCount > TOO_MUCH_COUNT_FOR_BAD_SCORE) {
            const double rate = Min(MAX_RATE_FOR_SURVIVAL_SCORE_FOR_OLD_BUT_GOOD_EDGE, static_cast<double>(daysCount));
            if (survivalScore * rate < MIN_SURVIVAL_SCORE_FOR_OLD_BUT_GOOD_EDGE) {
                survivalScore *= rate;
            }
            survivalScore = Min(1., survivalScore + MIN_SURVIVAL_SCORE_FOR_OLD_BUT_GOOD_EDGE);
        }

        const double pdf = ComputePDFDaysCount(timestamps, totalHistogram);
        return survivalScore * compatibility * pdf;
    }

    // Convenience overload: scores a raw edge row by extracting its sorted DATES timestamps.
    double GetSurvivalFunctionValue(i64 currentTimestamp, const NYT::TNode& row, const TTotalHistogram& totalHistogram) {
        const auto sortedTimestamps = NEdgeFields::GetSortedTimestamps(row);
        return GetSurvivalFunctionValue(currentTimestamp, sortedTimestamps, totalHistogram);
    }

    // Convenience overload: scores a TStatsQuery, which bundles the edge's dates and the edge
    // type's histogram.
    double GetSurvivalFunctionValue(i64 currentTimestamp, const TStatsQuery& query) {
        const auto& histogram = query.GetTotalHistogram();
        const auto sortedTimestamps = NEdgeFields::GetSortedTimestamps(query.GetDate());
        return GetSurvivalFunctionValue(currentTimestamp, sortedTimestamps, histogram);
    }

    // Convenience overload: scores a serialized TStatsQuery.
    // NOTE(review): the ParseFromString result is ignored, so malformed input is scored as
    // whatever was parsed (possibly an empty query). The argument is not modified despite being
    // a non-const reference.
    double GetSurvivalFunctionValue(i64 currentTimestamp, TString& stringTStatsQuery) {
        TStatsQuery parsedQuery;
        Y_PROTOBUF_SUPPRESS_NODISCARD parsedQuery.ParseFromString(stringTStatsQuery);
        return GetSurvivalFunctionValue(currentTimestamp, parsedQuery);
    }

    // Creates an empty collector whose bucket boundaries come from GetPowers for both axes.
    TStatsCollector::TStatsCollector(ui64 daysCountStepsSize, ui64 durationStepsSize)
        : Counts()
        , UsersCount(0)
    {
        // Both fields are later accumulated with Max(), so they must start at 0. Their declarations
        // (in stats.h) are not visible here; if they already carry in-class initializers this is
        // redundant but harmless. Assigned in the body to avoid member-order (-Wreorder) concerns.
        LastObservedTimestamp = 0;
        StatsCreationTime = 0;
        Init(daysCountStepsSize, durationStepsSize);
    }

    // Serializes the collected counts into a TTotalHistogram.
    // - One ExternalStep per days-count bucket, holding:
    //     - Value and Count;
    //     - PDF = Count / (sum over all buckets);
    //     - a nested histogram with one Step per duration bucket.
    // - Each nested Step holds Count, PDF within its bucket, and SF. SF is 1 minus the PDF mass of
    //   the preceding steps, floored at EPS.
    // - IsSignificant is set when any bucket at index > 1 has a non-zero count.
    // The nested histogram is built in place rather than copied from a temporary.
    TTotalHistogram TStatsCollector::ConvertToHistograms() {
        TTotalHistogram result;
        ui64 grandTotal = 0;

        for (size_t row = 0; row < DaysCountSteps.size(); ++row) {
            const auto& rowCounts = Counts[row];

            ui64 rowTotal = 0;
            for (size_t column = 0; column < DurationsSteps.size(); ++column) {
                rowTotal += rowCounts[column];
            }

            auto* externalStep = result.AddExternalStep();
            auto* histogram = externalStep->MutableHistogram();
            double survival = 1.;
            for (size_t column = 0; column < DurationsSteps.size(); ++column) {
                auto* step = histogram->AddStep();
                step->SetValue(DurationsSteps[column]);
                step->SetCount(rowCounts[column]);
                const double pdf = rowTotal
                    ? static_cast<double>(step->GetCount()) / static_cast<double>(rowTotal)
                    : 0.;
                step->SetPDF(pdf);
                step->SetSF(survival);
                survival = std::max(EPS, survival - pdf);
            }
            externalStep->SetCount(rowTotal);
            externalStep->SetValue(DaysCountSteps[row]);
            grandTotal += rowTotal;
        }

        size_t position = 0;
        for (auto& externalStep : *result.MutableExternalStep()) {
            const double pdf = grandTotal
                ? static_cast<double>(externalStep.GetCount()) / static_cast<double>(grandTotal)
                : 0.;
            if (position > 1 && externalStep.GetCount()) {
                result.SetIsSignificant(true);
            }
            externalStep.SetPDF(pdf);
            ++position;
        }

        result.SetTotalCount(grandTotal);
        result.SetUsersCount(UsersCount);
        result.SetLastObservedTimestamp(Min(LastObservedTimestamp, StatsCreationTime));
        result.SetStatsCreationTime(StatsCreationTime);
        return result;
    }

    void TStatsCollector::UpdateWith(const TTotalHistogram& totalHistogram) {
        LastObservedTimestamp = Max(LastObservedTimestamp, totalHistogram.GetLastObservedTimestamp());
        StatsCreationTime = Max(StatsCreationTime, totalHistogram.GetStatsCreationTime());
        UsersCount += totalHistogram.GetUsersCount();
        i64 daysCountIndex = 0;
        for (auto& daysCountStep : totalHistogram.GetExternalStep()) {
            if (static_cast<i64>(DaysCountSteps[daysCountIndex]) != static_cast<i64>(daysCountStep.GetValue())) {
                ythrow yexception() << "Mismatched histograms. DaysCountStep.Value " << DaysCountSteps[daysCountIndex] << " != " << daysCountStep.GetValue();
            }
            i64 durationIndex = 0;

            for (auto& step : daysCountStep.GetHistogram().GetStep()) {
                if (static_cast<i64>(DurationsSteps[durationIndex]) != static_cast<i64>(step.GetValue())) {
                    ythrow yexception() << "Mismatched histograms. Duration.Value " << DurationsSteps[durationIndex] << " != " << step.GetValue();
                }
                Counts[daysCountIndex][durationIndex] += step.GetCount();
                durationIndex++;
            }
            daysCountIndex++;
        }
    }

    // Records one event: `daysCount` prior dates, and a gap of `duration` seconds (0 marks the
    // death event). Each value goes into the first bucket whose boundary is >= it; values past the
    // last boundary go into the last bucket.
    void TStatsCollector::UpdateWith(ui64 daysCount, ui64 duration) {
        const ui64 durationDays = duration / SECONDS_PER_DAY;

        const auto daysCountIt = ClosedLowerBound(DaysCountSteps.begin(), DaysCountSteps.end(), daysCount);
        const auto durationIt = ClosedLowerBound(DurationsSteps.begin(), DurationsSteps.end(), durationDays);

        // ClosedLowerBound already clamps to the last element; these Min calls are a safety net.
        const i64 lastDaysCountIndex = static_cast<i64>(DaysCountSteps.size()) - 1;
        const i64 lastDurationIndex = static_cast<i64>(DurationsSteps.size()) - 1;
        const i64 row = Min(static_cast<i64>(daysCountIt - DaysCountSteps.begin()), lastDaysCountIndex);
        const i64 column = Min(static_cast<i64>(durationIt - DurationsSteps.begin()), lastDurationIndex);

        ++Counts[row][column];
    }

    // Accounts for one user/edge with the given sorted `timestamps` (seconds).
    // - The first date is recorded with a gap beyond the last duration bucket.
    // - Each later date i is recorded as (daysCount = i, gap = timestamps[i] - timestamps[i-1]).
    // - If there are no dates, or the last one is older than `deathTimeout` days relative to
    //   `currentTimestamp`, a final death event (daysCount = size, gap = 0) is recorded.
    // Gaps are computed by index rather than with 0 as a "no previous timestamp" sentinel, so a
    // genuine timestamp of 0 is handled like any other.
    void TStatsCollector::UpdateWith(TVector<i64>& timestamps, i64 currentTimestamp, i64 deathTimeout) {
        UsersCount++;
        StatsCreationTime = Max(StatsCreationTime, static_cast<ui64>(currentTimestamp));

        const ui64 firstDateDuration = (DurationsSteps.back() + 1) * SECONDS_PER_DAY;
        for (size_t i = 0; i < timestamps.size(); ++i) {
            LastObservedTimestamp = Max(LastObservedTimestamp, static_cast<ui64>(timestamps[i]));
            const ui64 duration = (i == 0)
                ? firstDateDuration
                : static_cast<ui64>(timestamps[i] - timestamps[i - 1]);
            UpdateWith(static_cast<ui64>(i), duration);
        }

        if (timestamps.empty() || (currentTimestamp - timestamps.back() > SECONDS_PER_DAY * deathTimeout)) {
            UpdateWith(static_cast<ui64>(timestamps.size()), 0);
        }
    }

    // Returns a copy of the [days-count bucket][duration bucket] counter table.
    const TVector<TVector<ui64>> TStatsCollector::GetCounts() {
        return this->Counts;
    }

    // Number of users accounted for, either directly or via merged histograms' UsersCount.
    ui64 TStatsCollector::GetUsersCount() {
        return this->UsersCount;
    }

    // Sets the bucket boundaries for both axes via GetPowers (0, 1, ratio, ratio^2, ...).
    // Sizes Counts as a zero-filled DaysCountSteps x DurationsSteps table.
    void TStatsCollector::Init(ui64 daysCountStepsSize, ui64 durationStepsSize) {
        DaysCountSteps = GetPowers(daysCountStepsSize);
        DurationsSteps = GetPowers(durationStepsSize);

        const size_t columns = DurationsSteps.size();
        Counts.resize(DaysCountSteps.size());
        for (size_t row = 0; row < Counts.size(); ++row) {
            Counts[row].resize(columns, 0);
        }
    }

    // Returns `size` boundaries: 0, 1, ratio, ratio^2, ... (unsigned arithmetic, no overflow check).
    TVector<ui64> TStatsCollector::GetPowers(ui64 size, i64 ratio) {
        TVector<ui64> powers;
        powers.reserve(size);
        size_t current = 0;
        while (powers.size() < size) {
            powers.push_back(current);
            current = current ? current * ratio : 1;
        }
        return powers;
    }

    // Mapper: groups raw edge rows by serialized edge type and emits one partial TEdgeStats
    // histogram per edge type seen in this job's input. Rows missing any edge-type column are
    // skipped.
    class TTransformDatesToHistogram: public NYT::IMapper<NYT::TTableReader<NYT::TNode>, NYT::TTableWriter<TEdgeStats>> {
    public:
        Y_SAVELOAD_JOB(CurrentTimestamp);

        TTransformDatesToHistogram() = default;

        // `explicit`: an i64 should never silently convert into a mapper.
        explicit TTransformDatesToHistogram(i64 currentTimestamp)
            : CurrentTimestamp(currentTimestamp)
        {
        }

        void Do(NYT::TTableReader<NYT::TNode>* input, NYT::TTableWriter<TEdgeStats>* output) override {
            THashMap<TString, TStatsCollector> statsMap;

            for (; input->IsValid(); input->Next()) {
                const auto& row = input->GetRow();
                const auto edgeType = NEdgeFields::GetEdgeType(row);
                if (edgeType.empty()) {
                    continue;
                }
                auto timestamps = NEdgeFields::GetSortedTimestamps(row);
                statsMap[edgeType].UpdateWith(timestamps, CurrentTimestamp, DEATH_TIMEOUT);
            }

            for (auto& [edgeType, collector] : statsMap) {
                TEdgeStats out;
                out.SetEdgeType(edgeType);
                *out.MutableTotalHistogram() = collector.ConvertToHistograms();
                output->AddRow(out);
            }
        }

    private:
        i64 CurrentTimestamp = 0;
        // Days without a new date after which an edge is recorded as dead. This is a class constant;
        // the former const non-static member needlessly disabled assignment of the mapper.
        static constexpr i64 DEATH_TIMEOUT = 30;
    };

    // Reducer (also passed as the reduce combiner in CollectStats): merges all partial TEdgeStats
    // rows of one EdgeType key into a single histogram. It also fills the decoded edge-type columns
    // and IsSignificant.
    class TMergeHistogram: public NYT::IReducer<NYT::TTableReader<TEdgeStats>, NYT::TTableWriter<TEdgeStats>> {
    public:
        void Do(NYT::TTableReader<TEdgeStats>* input, NYT::TTableWriter<TEdgeStats>* output) override {
            TStatsCollector merged;
            TString edgeType;
            for (; input->IsValid(); input->Next()) {
                const auto& row = input->GetRow();
                edgeType = row.GetEdgeType();
                merged.UpdateWith(row.GetTotalHistogram());
            }

            TEdgeStats result;
            result.SetEdgeType(edgeType);
            *result.MutableTotalHistogram() = merged.ConvertToHistograms();

            TEdgeType parsedType;
            Y_PROTOBUF_SUPPRESS_NODISCARD parsedType.ParseFromString(edgeType);
            result.SetSourceType(parsedType.GetSourceType());
            result.SetLogSource(parsedType.GetLogSource());
            result.SetId1Type(parsedType.GetId1Type());
            result.SetId2Type(parsedType.GetId2Type());
            result.SetIsSignificant(result.GetTotalHistogram().GetIsSignificant());

            output->AddRow(result);
        }
    };

    // Runs the edge-stats pipeline over `sources` into `destination`:
    // - MapReduce: TTransformDatesToHistogram maps; TMergeHistogram acts as combiner and reducer,
    //   keyed by EdgeType.
    // - Sorts `destination` by (sourceType, logSource, id1Type, id2Type).
    // - Sets the YQL proto-field attributes for TEdgeStats.
    void CollectStats(NYT::IClientPtr& client, i64 currentTimestamp, const TVector<TString>& sources, const TString& destination) {
        const auto outputSchema = NYT::TTableSchema()
            .AddColumn(NYT::TColumnSchema().Name("EdgeType").Type(NYT::VT_STRING))
            .AddColumn(NYT::TColumnSchema().Name("TotalHistogram").Type(NYT::VT_STRING))
            .AddColumn(NYT::TColumnSchema().Name("sourceType").Type(NYT::VT_STRING))
            .AddColumn(NYT::TColumnSchema().Name("logSource").Type(NYT::VT_STRING))
            .AddColumn(NYT::TColumnSchema().Name("id1Type").Type(NYT::VT_STRING))
            .AddColumn(NYT::TColumnSchema().Name("id2Type").Type(NYT::VT_STRING))
            .AddColumn(NYT::TColumnSchema().Name("IsSignificant").Type(NYT::VT_BOOLEAN));

        NYT::TMapReduceOperationSpec mapReduceSpec;
        for (const auto& source : sources) {
            mapReduceSpec.AddInput<NYT::TNode>(source);
        }
        mapReduceSpec.AddOutput<TEdgeStats>(NYT::TRichYPath(destination).Schema(outputSchema));
        mapReduceSpec.ReduceBy("EdgeType");

        NYT::TOperationOptions options;
        options.Spec(NYT::TNode()("data_size_per_sort_job", 1024 * 1024 * 256));

        client->MapReduce(
            mapReduceSpec,
            new NEdgeStats::TTransformDatesToHistogram(currentTimestamp),
            new NEdgeStats::TMergeHistogram,
            new NEdgeStats::TMergeHistogram,
            options);

        client->Sort(NYT::TSortOperationSpec()
            .AddInput(destination)
            .Output(destination)
            .SortBy({
                "sourceType",
                "logSource",
                "id1Type",
                "id2Type"}));

        NCrypta::SetYqlProtoFields<TEdgeStats>(client, destination);
    }

    // Runs CollectStats over every non-empty table directly inside `soupDir`.
    // row_count is only read for nodes whose type is "table".
    void CollectStats(NYT::IClientPtr& client, i64 currentTimestamp, const TString& soupDir, const TString& destination) {
        const auto listOptions = NYT::TListOptions().AttributeFilter(
            NYT::TAttributeFilter().AddAttribute("type").AddAttribute("row_count"));

        TVector<TString> sources;
        for (const auto& child : client->List(soupDir, listOptions)) {
            const auto& attributes = child.GetAttributes();
            const bool isTable = attributes["type"].AsString() == "table";
            if (isTable && attributes["row_count"].AsInt64()) {
                sources.push_back(TString::Join(soupDir, "/", child.AsString()));
            }
        }
        CollectStats(client, currentTimestamp, sources, destination);
    }
}

// Job registration for the operations launched in NEdgeStats::CollectStats (presumably required
// so YT job processes can instantiate these classes by name -- see the YT C++ API docs).
REGISTER_MAPPER(NEdgeStats::TTransformDatesToHistogram);
REGISTER_REDUCER(NEdgeStats::TMergeHistogram);
