package ru.yandex.crypta.lab.utils;

import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collector;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import org.glassfish.jersey.internal.util.Producer;

import ru.yandex.crypta.audience.proto.TAge;
import ru.yandex.crypta.audience.proto.TAgeCount;
import ru.yandex.crypta.audience.proto.TGender;
import ru.yandex.crypta.audience.proto.TGenderCount;
import ru.yandex.crypta.audience.proto.TIncome;
import ru.yandex.crypta.audience.proto.TIncomeCount;
import ru.yandex.crypta.audience.proto.TSegment;
import ru.yandex.crypta.audience.proto.TStrata;
import ru.yandex.crypta.audience.proto.TUserDataStats;
import ru.yandex.crypta.lab.proto.Segment;
import ru.yandex.crypta.lab.proto.SegmentGroup;
import ru.yandex.crypta.lab.proto.TAgeAffinity;
import ru.yandex.crypta.lab.proto.TGenderAffinity;
import ru.yandex.crypta.lab.proto.TIncomeAffinity;
import ru.yandex.crypta.lab.proto.TSegmentAffinity;
import ru.yandex.crypta.lab.proto.TSimpleSampleStats;
import ru.yandex.crypta.lab.proto.TStringAffinity;

public class Affinities {

    private final int priorSampleSize;

    public Affinities(int priorSampleSize) {
        this.priorSampleSize = priorSampleSize;
    }

    public int getPriorSampleSize() {
        return priorSampleSize;
    }

    private Stream<TUserDataStats.TStrataStats> strataStats(TUserDataStats stats) {
        return stats.getStratum().getStrataList().stream();
    }

    private Collector<TUserDataStats.TStrataStats, ?, Map<TStrata, Long>> countsPerStrata() {
        return Collectors.toMap(TUserDataStats.TStrataStats::getStrata, TUserDataStats.TStrataStats::getCount);
    }

    private Collector<TUserDataStats.TStrataStats, ?, Map<TStrata, Map<TSegment, Long>>> segmentCountsPerStrata() {
        return Collectors.toMap(
                TUserDataStats.TStrataStats::getStrata,
                each -> each.getSegmentList().stream().collect(Collectors.toMap(
                        TUserDataStats.TSegmentCount::getSegment, TUserDataStats.TSegmentCount::getCount))
        );
    }

    private Collector<TUserDataStats.TStrataStats, ?, Map<TStrata, Map<TGender, Long>>> genderCountsPerStrata() {
        return Collectors.toMap(
                TUserDataStats.TStrataStats::getStrata,
                each -> each.getGenderList().stream().collect(Collectors.toMap(
                        TGenderCount::getGender, TGenderCount::getCount
                ))
        );
    }

    private Collector<TUserDataStats.TStrataStats, ?, Map<TStrata, Map<TAge, Long>>> ageCountsPerStrata() {
        return Collectors.toMap(
                TUserDataStats.TStrataStats::getStrata,
                each -> each.getAgeList().stream().collect(Collectors.toMap(
                        TAgeCount::getAge, TAgeCount::getCount
                ))
        );
    }

    private Collector<TUserDataStats.TStrataStats, ?, Map<TStrata, Map<TIncome, Long>>> incomeCountsPerStrata() {
        return Collectors.toMap(
                TUserDataStats.TStrataStats::getStrata,
                each -> each.getIncomeList().stream().collect(Collectors.toMap(
                        TIncomeCount::getIncome, TIncomeCount::getCount
                ))
        );
    }

    public void computeForTokens(Producer<TStringAffinity.Builder> result, TUserDataStats.TTokensStats stats,
                                 TUserDataStats.TTokensStats base)
    {
        double usersCount = stats.getUsersCount();

        Map<String, Float> perTokenBaseAffinity = base.getTokenList().stream().collect(
                Collectors.toMap(TUserDataStats.TWeightedTokenStats::getToken,
                        each -> each.getWeight() / base.getUsersCount())
        );

        double minBaseAffinity = base
                .getTokenList()
                .stream()
                .mapToDouble(each -> each.getWeight() / base.getUsersCount())
                .filter(each -> each != 0.0)
                .min()
                .orElse(1.0);

        stats.getTokenList().forEach(each -> {
            double affinity = each.getWeight() / usersCount;
            double ratio = each.getCount() / usersCount;
            var token = each.getToken();
            // TODO enable adjusted affinity
            double adjustedAffinity = (affinity + priorSampleSize) / (perTokenBaseAffinity.getOrDefault(token, (float) minBaseAffinity) + priorSampleSize);
            result.call()
                    .setValue(token)
                    .setAffinity(adjustedAffinity)
                    .setRatio(ratio);
        });
    }

    private <T> void collectAffinities(
            Map<TStrata, Map<T, Long>> globalPropertyPerStrata,
            Map<TStrata, Map<T, Long>> localPropertyPerStrata,
            TUserDataStats globalStats,
            TUserDataStats stats,
            Map<T, Double> affinities,
            Map<T, Double> ratios
    ) {
        Map<TStrata, Long> globalCountsPerStrata =
                strataStats(globalStats).collect(countsPerStrata());
        Map<TStrata, Long> localCountsPerStrata =
                strataStats(stats).collect(countsPerStrata());

        double localGrandTotal = localCountsPerStrata.values().stream().mapToLong(x -> x).sum();

        for (TStrata stratum : globalPropertyPerStrata.keySet()) {
            if (!localPropertyPerStrata.containsKey(stratum)) {
                continue;
            }

            double globalTotal = globalCountsPerStrata.get(stratum);
            double localTotal = localCountsPerStrata.get(stratum);
            double stratumWeight = localTotal / (localGrandTotal + 1.0);

            Map<T, Long> globalPropertyInTheStratum = globalPropertyPerStrata.get(stratum);

            for (T property : globalPropertyInTheStratum.keySet()) {
                Map<T, Long> localPropertyInTheStratum = localPropertyPerStrata.get(stratum);

                if (!localPropertyInTheStratum.containsKey(property)) {
                    continue;
                }
                double globalCount = globalPropertyInTheStratum.get(property);
                double globalRatio = globalCount / (globalTotal + 1.0);

                double priorTotal = priorSampleSize;
                double priorCount = priorTotal * globalRatio;

                double localCount = localPropertyInTheStratum.get(property);
                double localRatio = (localCount + priorCount) / (localTotal + priorTotal);

                double affinity = (localRatio / globalRatio);

                double runningAffinity = affinities.getOrDefault(property, 0.0);
                affinities.put(property, runningAffinity + Math.log(affinity) * stratumWeight);
                double runningRatio = ratios.getOrDefault(property, 0.0);
                ratios.put(property, runningRatio + (localCount / localTotal) * stratumWeight);
            }
        }
    }

    public void computeForSegments(TSimpleSampleStats.Builder result,
                                   TUserDataStats stats,
                                   TUserDataStats globalStats,
                                   Map<Long, Map<Long, Segment>> exportsMapping,
                                   Map<String, List<SegmentGroup>> parents)
    {
        Map<TStrata, Map<TSegment, Long>> globalSegmentsPerStrata =
                strataStats(globalStats).collect(segmentCountsPerStrata());
        Map<TStrata, Map<TSegment, Long>> localSegmentsPerStrata =
                strataStats(stats).collect(segmentCountsPerStrata());

        Map<TSegment, Double> affinities = new HashMap<>();
        Map<TSegment, Double> ratios = new HashMap<>();

        collectAffinities(
                globalSegmentsPerStrata, localSegmentsPerStrata, globalStats, stats, affinities, ratios
        );

        affinities.forEach((k, v) -> {
            Segment segment = exportsMapping.getOrDefault(Integer.toUnsignedLong(k.getKeyword()), Collections.emptyMap()).get(Integer.toUnsignedLong(k.getID()));
            if (Objects.isNull(segment)) {
                return;
            }
            Optional<Segment.Export> export = segment.getExports()
                    .getExportsList().stream()
                    .filter(each -> exportsMatch(k, each))
                    .findFirst();

            TSegmentAffinity.Builder affinityBuilder = result.getAffinitiesBuilder().addBySegmentBuilder();
            affinityBuilder
                    .setID(segment.getId())
                    .setName(segment.getName())
                    .setExportType(export.map(Segment.Export::getType).orElse(Segment.Export.Type.DEFAULT))
                    .setKeywordID(k.getKeyword())
                    .setSegmentID(k.getID())
                    .setAffinity(Math.exp(v))
                    .setRatio(ratios.getOrDefault(k, 0.0));
            if (parents.containsKey(segment.getId())) {
                for (SegmentGroup segmentGroup : parents.get(segment.getId())) {
                    affinityBuilder.addParentBuilder()
                            .setID(segmentGroup.getId())
                            .setName(segmentGroup.getName());
                }
            }
        });
    }

    public void computeForSocdem(TSimpleSampleStats.Builder result,
                                 TUserDataStats stats,
                                 TUserDataStats globalStats)
    {
        Map<TStrata, Map<TGender, Long>> globalGenderPerStrata =
                strataStats(globalStats).collect(genderCountsPerStrata());
        Map<TStrata, Map<TAge, Long>> globalAgePerStrata =
                strataStats(globalStats).collect(ageCountsPerStrata());
        Map<TStrata, Map<TIncome, Long>> globalIncomePerStrata =
                strataStats(globalStats).collect(incomeCountsPerStrata());

        Map<TStrata, Map<TGender, Long>> localGenderPerStrata =
                strataStats(stats).collect(genderCountsPerStrata());
        Map<TStrata, Map<TAge, Long>> localAgePerStrata =
                strataStats(stats).collect(ageCountsPerStrata());
        Map<TStrata, Map<TIncome, Long>> localIncomePerStrata =
                strataStats(stats).collect(incomeCountsPerStrata());

        Map<TGender, Double> genderAffinities = new HashMap<>();
        Map<TAge, Double> ageAffinities = new HashMap<>();
        Map<TIncome, Double> incomeAffinities = new HashMap<>();

        Map<TGender, Double> genderRatios = new HashMap<>();
        Map<TAge, Double> ageRatios = new HashMap<>();
        Map<TIncome, Double> incomeRatios = new HashMap<>();

        collectAffinities(
                globalGenderPerStrata, localGenderPerStrata, globalStats, stats, genderAffinities, genderRatios
        );

        collectAffinities(globalAgePerStrata, localAgePerStrata, globalStats, stats, ageAffinities, ageRatios);

        collectAffinities(
                globalIncomePerStrata, localIncomePerStrata, globalStats, stats, incomeAffinities, incomeRatios
        );

        genderAffinities.forEach((k, v) -> {
            TGenderAffinity.Builder affinityBuilder = result.getAffinitiesBuilder().addByGenderBuilder();
            affinityBuilder
                    .setGender(k)
                    .setAffinity(Math.exp(v))
                    .setRatio(genderRatios.getOrDefault(k, 0.0));
        });

        ageAffinities.forEach((k, v) -> {
            TAgeAffinity.Builder affinityBuilder = result.getAffinitiesBuilder().addByAgeBuilder();
            affinityBuilder
                    .setAge(k)
                    .setAffinity(Math.exp(v))
                    .setRatio(ageRatios.getOrDefault(k, 0.0));
        });

        incomeAffinities.forEach((k, v) -> {
            TIncomeAffinity.Builder affinityBuilder = result.getAffinitiesBuilder().addByIncomeBuilder();
            affinityBuilder
                    .setIncome(k)
                    .setAffinity(Math.exp(v))
                    .setRatio(incomeRatios.getOrDefault(k, 0.0));
        });
    }

    private boolean exportsMatch(TSegment k, Segment.Export x) {
        // Due to https://a.yandex-team.ru/review/1533658/files/1#file-0-48780279:R15
        // one should unify types
        return x.getKeywordId() == (long) k.getKeyword() && x.getSegmentId() == (long) k.getID();
    }
}
