#include "counts.h"

namespace NCrypta {
    namespace NGraphEngine {

        // Counts the vertices of the graph whose type equals SearchType.
        //
        // graph: graph whose vertices are scanned.
        // Returns the number of vertices for which GetType() == SearchType.
        ui64 TTypesCountScoringStrategy::Count(const TGraph& graph) const {
            // Accumulate directly in ui64 so the counter matches the return type
            // instead of going through a signed int and an implicit conversion.
            ui64 matched = 0;
            for (const auto& vertex : graph.GetVertices()) {
                if (vertex.GetType() == SearchType) {
                    ++matched;
                }
            }
            return matched;
        }

        // Maps a count to its score using the configured ranges.
        //
        // count: value to score.
        // Returns -1 for a zero count when PenalizeEmpty is set; otherwise the
        // probability reported by the range with the smallest key >= count, or
        // TheRestProb when count is above every configured key.
        double THistogramCountsScore::GetScore(ui64 count) const {
            const bool isEmpty = (count == 0);
            if (isEmpty && Counts.PenalizeEmpty) {
                return -1.;
            }

            // Ranges are keyed by their upper bound, so lower_bound picks the
            // first range whose upper bound is not below count.
            const auto bucket = Counts.Ranges.lower_bound(count);
            const bool aboveAllRanges = (bucket == Counts.Ranges.end());
            return aboveAllRanges ? Counts.TheRestProb : bucket->second.GetProbability(count);
        }

        double THistogramCountsScore::TRange::GetProbabilityInterval() const {
            ui64 range = To - From;
            return Value / ((1. + range) * range / 2);
        }

        // Evaluates this range's probability function for count, after clamping
        // count into [From, To].
        //
        // count: value to evaluate; values below From are treated as From and
        //        values above To are treated as To.
        // Returns ProbFunc(*this, clamped count).
        double THistogramCountsScore::TRange::GetProbability(ui64 count) const {
            // Clamp without recursion: the recursive form bounced between the
            // "below From" and "above To" branches forever when From > To.
            ui64 clamped = count;
            if (clamped < From) {
                clamped = From;
            } else if (clamped > To) {
                clamped = To;
            }
            return ProbFunc(*this, clamped);
        }

        // Creates a range over [from, to] whose probability function is UniformValue,
        // i.e. every count in the range is scored with value.
        THistogramCountsScore::TRange THistogramCountsScore::TRange::BuildUniform(ui64 from, ui64 to, double value) {
            TRange uniformRange{
                .From = from,
                .To = to,
                .Value = value,
                .ProbFunc = UniformValue,
            };
            return uniformRange;
        }

        // Creates a range over [from, to] whose probability function is IncreasingValue;
        // value is the total weight that function distributes over the range.
        THistogramCountsScore::TRange THistogramCountsScore::TRange::BuildIncreasing(ui64 from, ui64 to, double value) {
            TRange increasingRange{
                .From = from,
                .To = to,
                .Value = value,
                .ProbFunc = IncreasingValue,
            };
            return increasingRange;
        }

        // Creates a range over [from, to] whose probability function is DecreasingValue;
        // value is the total weight that function distributes over the range.
        THistogramCountsScore::TRange THistogramCountsScore::TRange::BuildDecreasing(ui64 from, ui64 to, double value) {
            TRange decreasingRange{
                .From = from,
                .To = to,
                .Value = value,
                .ProbFunc = DecreasingValue,
            };
            return decreasingRange;
        }

        // Tells whether key lies in the half-open interval (range.From, range.To]:
        // the lower bound is exclusive, the upper bound inclusive.
        bool THistogramCountsScore::CheckInterval(const TRange& range, ui64 key) {
            const bool aboveLowerBound = range.From < key;
            const bool withinUpperBound = key <= range.To;
            return aboveLowerBound && withinUpperBound;
        }

        double THistogramCountsScore::IncreasingValue(const TRange& range, ui64 count) {
            if (count <= range.From) {
                return IncreasingValue(range, count + 1);
            }
            return (count - range.From + 0.) * range.GetProbabilityInterval();
        }
        double THistogramCountsScore::DecreasingValue(const TRange& range, ui64 count) {
            if (count <= range.From) {
                return DecreasingValue(range, count + 1);
            }
            return (range.To - count + 1.) * range.GetProbabilityInterval();
        }

        // Constant probability function: every count is scored with range.Value.
        // The count argument is intentionally unused; it is only present so the
        // signature matches IncreasingValue and DecreasingValue.
        double THistogramCountsScore::UniformValue(const TRange& range, ui64 /*count*/) {
            return range.Value;
        }

        // Starts the builder with TheRestProb = 1, which is the score GetScore returns
        // for counts above every configured range unless AndTheRestAs overrides it.
        // NOTE(review): PenalizeEmpty is not set here; presumably it has a default
        // in the header -- confirm.
        THistogramCountsScore::Builder::Builder() {
            Counts.TheRestProb = 1.;
        }

        // Registers a uniform range that scores counts up to and including count
        // with prob.
        //
        // count: inclusive upper bound of the range; also its key in Counts.Ranges.
        // prob:  score returned for counts that fall into the range.
        //
        // The range starts at the closest already-registered key below count,
        // or at 0 when there is none.
        void THistogramCountsScore::Builder::LessOrEqualAs(ui64 count, double prob) {
            ui64 from = 0;
            // lower_bound yields the first key >= count, i.e. the *next* range.
            // Taking its key as the lower bound produced From > To whenever a
            // larger key was already present; the preceding key is the one wanted.
            auto neighbour = Counts.Ranges.lower_bound(count);
            if (neighbour != Counts.Ranges.begin()) {
                --neighbour;
                from = neighbour->first;
            }
            Counts.Ranges[count] = TRange::BuildUniform(from, count, prob);
        }

        // Registers a linearly increasing range over [from, to], keyed by to.
        //
        // from, to:         bounds of the range; from must not exceed to.
        // probabilityRange: total weight handed to TRange::BuildIncreasing.
        // Throws yexception when from > to.
        // NOTE(review): from == to passes this check, but GetProbabilityInterval
        // divides by zero for a zero-width range -- confirm whether that is intended.
        void THistogramCountsScore::Builder::UniformIncreasingRange(ui64 from, ui64 to, double probabilityRange) {
            const bool boundsReversed = to < from;
            if (boundsReversed) {
                ythrow yexception() << "Invalid range";
            }

            Counts.Ranges[to] = TRange::BuildIncreasing(from, to, probabilityRange);
        }

        // Registers a linearly decreasing range over [from, to], keyed by to.
        //
        // from, to:         bounds of the range; from must not exceed to.
        // probabilityRange: total weight handed to TRange::BuildDecreasing.
        // Throws yexception when from > to.
        // NOTE(review): from == to passes this check, but GetProbabilityInterval
        // divides by zero for a zero-width range -- confirm whether that is intended.
        void THistogramCountsScore::Builder::UniformDecreasingRange(ui64 from, ui64 to, double probabilityRange) {
            const bool boundsReversed = to < from;
            if (boundsReversed) {
                ythrow yexception() << "Invalid range";
            }

            Counts.Ranges[to] = TRange::BuildDecreasing(from, to, probabilityRange);
        }

        // Sets the fallback score (TheRestProb) that GetScore returns for counts
        // above every configured range, replacing the constructor's default of 1.
        void THistogramCountsScore::Builder::AndTheRestAs(double restProb) {
            Counts.TheRestProb = restProb;
        }

        // Turns on the empty-count penalty: with this flag set, GetScore returns -1
        // for a count of 0 before consulting any range.
        void THistogramCountsScore::Builder::AndPenalizeEmpty() {
            Counts.PenalizeEmpty = true;
        }

        // Produces a THistogramCountsScore from the settings accumulated so far.
        // The builder's Counts is passed to the constructor and left untouched here.
        THistogramCountsScore THistogramCountsScore::Builder::Build() {
            return THistogramCountsScore(Counts);
        }

        double TAbstractCountScoringStrategy::ComputeScore(const TGraph& graph) const {
            return Options.HistogramCountsScore.GetScore(Count(graph));
        }

        // Returns the weight configured in Options.ScoreWeight.
        // The graph argument is ignored: the weight does not depend on the graph.
        double TAbstractCountScoringStrategy::ComputeWeight(const TGraph& /*graph*/) const {
            return Options.ScoreWeight;
        }

        // Returns the total number of vertices in the graph, as reported by
        // vertices_size(), converted to ui64.
        ui64 TVerticesCountScoringStrategy::Count(const TGraph& graph) const {
            const auto verticesTotal = graph.vertices_size();
            return static_cast<ui64>(verticesTotal);
        }

        // Returns how many entries of graph.GetIdsInfo().GetDevicesInfo() report
        // GetIsActive().
        ui64 CountActiveDevices(const TGraph& graph) {
            ui64 active = 0;
            for (const auto& device : graph.GetIdsInfo().GetDevicesInfo()) {
                active += device.GetIsActive() ? 1 : 0;
            }
            return active;
        }

        // Count for this strategy is the number of active devices in the graph;
        // delegates to CountActiveDevices.
        ui64 TDevicesCountStrategy::Count(const TGraph& graph) const {
            return CountActiveDevices(graph);
        }

        // Returns 1 when the graph has at least one active device and at least one
        // active non-mobile browser; returns 0 otherwise.
        double TCrossDevicesStrategy::ComputeScore(const TGraph& graph) const {
            // Without an active device the browsers are not inspected at all.
            if (CountActiveDevices(graph) == 0) {
                return 0.;
            }

            for (const auto& browserInfo : graph.GetIdsInfo().GetBrowsersInfo()) {
                const bool isDesktop = !browserInfo.GetIsMobile();
                if (isDesktop && browserInfo.GetIsActive()) {
                    // An active desktop browser next to an active device.
                    return 1.;
                }
            }
            return 0.;
        }

        // Counts distinct non-synthetic logins plus PUID vertices that have no
        // edge to any LOGIN vertex.
        //
        // First pass: walk the vertices in order, remembering the positions of
        // LOGIN and PUID vertices and collecting logins that IsSyntheticLogin rejects
        // as synthetic. Second pass: walk the edges and drop every PUID position
        // that is connected to a LOGIN position.
        // NOTE(review): positions are tracked in ui64 but stored in THashSet<int>,
        // so they are narrowed on insert -- confirm graphs stay below INT_MAX vertices.
        ui64 TLoginsCountScoringStrategy::Count(const TGraph& graph) const {
            THashSet<TString> realLogins{};
            THashSet<int> loginPositions{};
            THashSet<int> puidsWithoutLogin{};

            ui64 position = 0;
            for (const auto& vertex : graph.GetVertices()) {
                const auto vertexType = vertex.GetType();
                if (vertexType == EIdType::LOGIN) {
                    TString login = vertex.GetLogin().value();
                    if (!IsSyntheticLogin(login)) {
                        realLogins.insert(login);
                    }
                    // Synthetic logins still count as a login endpoint for edges.
                    loginPositions.insert(position);
                } else if (vertexType == EIdType::PUID) {
                    puidsWithoutLogin.insert(position);
                }
                ++position;
            }

            // Removes puidSide from the pending set when the edge links it to a login.
            const auto dropIfLinkedToLogin = [&](const auto& puidSide, const auto& loginSide) {
                if (puidsWithoutLogin.contains(puidSide) && loginPositions.contains(loginSide)) {
                    puidsWithoutLogin.erase(puidSide);
                }
            };

            // Edges are checked in both directions.
            for (const auto& edge : graph.GetEdges()) {
                dropIfLinkedToLogin(edge.GetVertex1(), edge.GetVertex2());
                dropIfLinkedToLogin(edge.GetVertex2(), edge.GetVertex1());
            }

            return realLogins.size() + puidsWithoutLogin.size();
        }

        // Penalizes graphs whose socdem entries mix genders.
        //
        // Counts entries with gender "m" and "f" (any other value is ignored) and
        // returns minus the share of the less frequent gender among them, i.e. a
        // value in [-0.5, 0]. Returns 0 when no entry is "m" or "f".
        double TSocdemScoringStrategy::ComputeScore(const TGraph &graph) const {
            double males = 0.;
            double females = 0.;
            for (const auto& socdem : graph.GetIdsInfo().GetSocdemInfo()) {
                if (socdem.GetGender() == "m") {
                    males += 1.;
                } else if (socdem.GetGender() == "f") {
                    females += 1.;
                }
            }

            const double known = males + females;
            if (known == 0.) {
                return 0.;
            }
            return -Min(males, females) / known;
        }
    }
}
