#include <crypta/audience/lib/affinity/affinity.h>

#include <crypta/lab/lib/native/stats.h>

// Used by ComputeAffinities in two roles: the minimum global count a segment must
// have inside a stratum for that stratum to be taken into account, and the
// pseudo-count weight of the global prior when smoothing the local probability.
constexpr double MIN_SIZE = 1000.0;

using NLab::TSegment;
using NLab::TStrata;
using NLab::TUserDataStats;

// Makes `x` safe to divide by: a zero value (of either sign) is replaced with 1.0,
// every other value is returned unchanged.
double SafeDenominator(double x) {
    if (x == 0.0) {
        return 1.0;
    }
    return x;
}

// A ratio stored as a separate numerator (Count) and denominator (Total), so that
// two fractions can be combined component-wise before the division is performed.
struct TFraction {
    TFraction(double count, double total)
        : Count(count)
        , Total(total)
    {
    }

    double Count; // numerator
    double Total; // denominator

    // Returns Count / Total; a zero Total is replaced by 1.0 (see SafeDenominator),
    // so the division never produces inf/NaN because of an empty denominator.
    double Value() const {
        return Count / SafeDenominator(Total);
    }

    // Returns a new fraction whose numerator and denominator are the sums of the
    // corresponding parts of this fraction and `other`. Neither operand is
    // modified, hence the method is const and usable on const fractions.
    TFraction AdditiveSmooth(const TFraction& other) const {
        return TFraction(Count + other.Count, Total + other.Total);
    }
};

// Builds a strata -> count map from the per-strata entries of `stats`.
// If the same strata key occurs more than once, the last entry wins.
TMap<TStrata, double> ComputeStratumTotals(const TUserDataStats& stats) {
    TMap<TStrata, double> totals;
    for (const auto& stratum : stats.GetStratum().GetStrata()) {
        const auto& key = stratum.GetStrata();
        const double count = stratum.GetCount();
        totals[key] = count;
    }
    return totals;
}

// Builds a two-level map strata -> segment -> count from `stats`.
// A strata key appears in the result only if it has at least one segment entry.
TMap<TStrata, TMap<TSegment, double>> ComputeSegmentCounts(const TUserDataStats& stats) {
    TMap<TStrata, TMap<TSegment, double>> counts;
    for (const auto& stratum : stats.GetStratum().GetStrata()) {
        const auto& strataKey = stratum.GetStrata();
        for (const auto& segment : stratum.GetSegment()) {
            // The outer entry is created here (not before the loop) on purpose:
            // strata without segments must not get an empty inner map.
            counts[strataKey][segment.GetSegment()] = segment.GetCount();
        }
    }
    return counts;
}

// Returns the sum of all mapped values of `mapping` (0.0 for an empty map).
double SumValues(const TMap<TStrata, double>& mapping) {
    double total = 0.0;
    for (auto it = mapping.begin(); it != mapping.end(); ++it) {
        total += it->second;
    }
    return total;
}

// Returns the value stored under `key` in `mapping`, or 0.0 when the key is absent.
// Unlike operator[], the lookup never inserts a default entry into the map.
template <class TMapping, class TKey>
static double LookupOrZero(const TMapping& mapping, const TKey& key) {
    const auto found = mapping.find(key);
    return found != mapping.end() ? found->second : 0.0;
}

// Computes an affinity value for every segment in `targetSegments`.
//
// For each strata present in the global segment counts, the local probability of
// the segment (smoothed with MIN_SIZE pseudo-observations distributed according to
// the global probability) is divided by the global probability. The logarithms of
// these ratios are summed with weights equal to the strata's share of the local
// grand total, and the affinity is the exponent of that sum. Strata where the
// global count of the segment is below MIN_SIZE are skipped, so the weights that
// are actually used may add up to less than one.
//
// A segment for which no strata qualifies gets affinity exp(0) == 1. A segment
// listed in `targetSegments` more than once yields a single entry with the same
// value it would have if listed once.
TMap<TSegment, double> ComputeAffinities(
    const TVector<TSegment>& targetSegments,
    const TUserDataStats& localStats,
    const TUserDataStats& globalStats) {
    const TMap<TStrata, TMap<TSegment, double>> localSegments = ComputeSegmentCounts(localStats);
    const TMap<TStrata, TMap<TSegment, double>> globalSegments = ComputeSegmentCounts(globalStats);

    const auto localStratumTotals = ComputeStratumTotals(localStats);
    const auto globalStratumTotals = ComputeStratumTotals(globalStats);

    const auto localGrandTotal = SumValues(localStratumTotals);

    TMap<TSegment, double> affinities{};

    for (const auto& eachSegment : targetSegments) {
        // Accumulated in a local so that exp() is applied exactly once per
        // computation, even when the same segment is requested several times.
        double logAffinity = 0.0;

        for (const auto& eachStrata : globalSegments) {
            // Bound by reference: copying the key on every iteration is wasted work.
            const auto& strata = eachStrata.first;

            const double globalCount = LookupOrZero(eachStrata.second, eachSegment);
            if (globalCount < MIN_SIZE) {
                continue;
            }

            const double localStratumTotal = LookupOrZero(localStratumTotals, strata);
            const auto localStrata = localSegments.find(strata);
            const double localCount = localStrata != localSegments.end()
                ? LookupOrZero(localStrata->second, eachSegment)
                : 0.0;

            auto globalProbability = TFraction(
                globalCount,
                LookupOrZero(globalStratumTotals, strata)
            );

            // Prior: MIN_SIZE pseudo-observations split according to the global probability.
            auto priorProbability = TFraction(
                MIN_SIZE * globalProbability.Value(),
                MIN_SIZE
            );
            auto localProbability = TFraction(
                localCount,
                localStratumTotal
            );
            auto probabilityRatio = TFraction(
                localProbability.AdditiveSmooth(priorProbability).Value(),
                globalProbability.Value()
            );
            auto strataWeight = TFraction(
                localStratumTotal,
                localGrandTotal
            );

            logAffinity += strataWeight.Value() * log(probabilityRatio.Value());
        }

        affinities[eachSegment] = exp(logAffinity);
    }

    return affinities;
}
