#include "segments_user_data_stats_aggregator.h"

using namespace NLab;

namespace {
    // Writes the per-user token counters into a tokens-stats message: the number of
    // tokens is stored as-is, and UsersCount is set to 1 only when the user actually
    // has at least one token (otherwise UsersCount is left untouched).
    template<typename T>
    void SetTokensStatsCounts(T& stats, size_t size) {
        stats.SetTokensCount(size);
        if (size != 0) {
            stats.SetUsersCount(1);
        }
    }

    // Copies the stratification keys (device, country, city and the HasCryptaID flag)
    // from a user's attributes into the given strata message.
    void FillStrata(NLab::TStrata* strata, const NLab::TUserData::TAttributes& attributes) {
        auto& target = *strata;
        target.SetDevice(attributes.GetDevice());
        target.SetCountry(attributes.GetCountry());
        target.SetCity(attributes.GetCity());
        target.SetHasCryptaID(attributes.GetHasCryptaID());
    }
}

// Stores the job state. The skip filter starts with a rate of 0.0; Start() rebuilds it
// from the state's sampling skip rate.
TSegmentsUserDataStatsAggregator::TSegmentsUserDataStatsAggregator(const NNativeYT::TProtoState<TUserDataStatsOptions>& state)
    : State(state)
    , SkipFilter(0.0)
{
}

// Converts one user record into a single-user stats record.
//
// Full statistics are computed only when the record has data, is unique and is not
// skipped; otherwise the empty stats template is used. GroupID and Identifiers are
// carried over when present in the source. Counts.Total is 1 unless skipped, and
// Counts.UniqYuid is 1 unless the identifiers are marked NotUnique (independently of skip).
void TSegmentsUserDataStatsAggregator::ConvertUserDataToUserDataStats(NLab::TUserDataStats& userDataStats, const NLab::TUserData& userData, bool skip) {
    const auto& ids = userData.GetIdentifiers();
    const bool isUnique = !ids.GetNotUnique();
    const bool hasUsableData = !userData.GetWithoutData() && isUnique && !skip;

    if (hasUsableData) {
        GetStats(userDataStats, userData);
    } else {
        userDataStats.CopyFrom(GetEmptyStats());
    }

    if (userData.HasGroupID()) {
        userDataStats.SetGroupID(userData.GetGroupID());
    }
    if (userData.HasIdentifiers()) {
        userDataStats.MutableIdentifiers()->CopyFrom(ids);
    }

    auto& counts = *userDataStats.MutableCounts();
    counts.SetTotal(static_cast<ui64>(skip ? 0 : 1));
    counts.SetUniqYuid(static_cast<ui64>(isUnique));
}

// Folds a stats record into the aggregator of its GroupID, creating that aggregator
// on first use (via the container's operator[]).
void TSegmentsUserDataStatsAggregator::UpdateWith(const NLab::TUserDataStats& userDataStats) {
    StatsAggregatorsBySegment[userDataStats.GetGroupID()].UpdateWith(userDataStats);
}

// Returns true when a set of used strata segments is configured and `segment` is not
// in it. With no set configured, nothing is ignored.
// NOTE(review): mirrors TSegmentsUserDataStatsAggregator::IgnoreStrataSegment; kept
// with external linkage in case it is declared/used elsewhere — confirm before removing.
bool IgnoreStrataSegment(const TMaybe<TSet<NLab::TSegment>>& usedStrataSegments, const NLab::TSegment& segment) {
    if (!usedStrataSegments.Defined()) {
        return false;
    }
    return usedStrataSegments->count(segment) == 0;
}

// Resets per-job state from `state`.
//
// The function:
//  - clears the buffered per-segment aggregators;
//  - optionally rebuilds the set of used strata segments from
//    State.UsedStrataSegments, when Flags.IgnoreUnusedStrataSegments is set, and then
//    removes those segments from State;
//  - rebuilds the skip filter from the sampling skip rate.
// The table writer argument is unused.
void TSegmentsUserDataStatsAggregator::Start(NYT::TTableWriter<TUserDataStats>*, const NNativeYT::TProtoState<TUserDataStatsOptions>& state) {
    State = state;
    StatsAggregatorsBySegment.clear();
    UsedStrataSegments.Clear();
    if (State->GetFlags().GetIgnoreUnusedStrataSegments()) {
        const auto& repeatedSegments = State->GetUsedStrataSegments().GetSegment();
        // Assign a temporary so the set is moved into the TMaybe rather than
        // deep-copied from a named local (each element is a full proto message).
        UsedStrataSegments = TSet<NLab::TSegment>(repeatedSegments.begin(), repeatedSegments.end());
        // The set owns its own copies now, so the originals can be dropped from State.
        State->ClearUsedStrataSegments();
    }
    SkipFilter = TFilterIdentifier(State->GetSamplingOptions().GetSkipRate());
}

// Flushes all buffered aggregators once their number exceeds MAX_SEGMENTS_IN_STATS
// (presumably to bound memory use between flushes).
void TSegmentsUserDataStatsAggregator::EmitIfLarge(NYT::TTableWriter<TUserDataStats>* output) {
    if (StatsAggregatorsBySegment.size() <= MAX_SEGMENTS_IN_STATS) {
        return;
    }
    Emit(output);
}

// Writes one row per buffered segment aggregator and then discards all of them.
// GroupID is set to the aggregator's key, or explicitly cleared when the key is
// falsy (e.g. empty), so such rows carry no GroupID at all.
void TSegmentsUserDataStatsAggregator::Emit(NYT::TTableWriter<TUserDataStats>* output) {
    for (auto& entry : StatsAggregatorsBySegment) {
        TUserDataStats row;
        entry.second.MergeInto(row);
        if (!entry.first) {
            row.ClearGroupID();
        } else {
            row.SetGroupID(entry.first);
        }
        output->AddRow(row);
    }
    StatsAggregatorsBySegment.clear();
}

// Flushes every aggregator still buffered at the end of the job.
void TSegmentsUserDataStatsAggregator::Finish(NYT::TTableWriter<TUserDataStats>* output) {
    this->Emit(output);
}

// Returns the aggregator for `segmentID`, inserting a default-constructed one via
// the container's operator[] if the key is not present yet.
NLab::TUserDataStatsAggregator<>& TSegmentsUserDataStatsAggregator::GetOrCreateSegmentAggr(const TString& segmentID) {
    auto& aggregator = StatsAggregatorsBySegment[segmentID];
    return aggregator;
}

// Attaches per-segment metadata from State.Segments to `userDataStats`.
//
// If the user's GroupID has no entry in State.Segments, nothing is written and the
// function returns false. Otherwise it copies the segment Info into SegmentInfo.
// When the segment's filter options have a non-zero capacity, it also replaces
// userDataStats.Filter:
//  - the filter is empty when capacity == 1;
//  - for capacity > 1, it carries the capacity and error rate and holds the user's
//    Yandexuid as its single element.
// Returns true when the segment was found.
bool TSegmentsUserDataStatsAggregator::FillMetaFromOptions(TUserDataStats& userDataStats, const TUserData& userData) {
    const auto& segmentsOptions = State->GetSegments();
    // One lookup, bound by reference: avoids a second search and a deep copy of the
    // per-segment options message.
    const auto optionsIt = segmentsOptions.find(userData.GetGroupID());
    if (optionsIt == segmentsOptions.end()) {
        return false;
    }

    const auto& options = optionsIt->second;
    userDataStats.MutableSegmentInfo()->CopyFrom(options.GetInfo());

    const auto& filterOptions = options.GetFilterOptions();
    const auto capacity = filterOptions.GetCapacity();
    if (capacity) {
        auto* filter = userDataStats.MutableFilter();
        // Clear first so the result matches a full replacement (CopyFrom) of the field.
        filter->Clear();
        if (capacity > 1) {
            filter->MutableOptions()->SetCapacity(capacity);
            filter->MutableOptions()->SetErrorRate(filterOptions.GetErrorRate());
            filter->SetSingle(userData.GetYandexuid());
        }
    }
    return true;
}

// Sampling decision for `identifier`, delegated to the skip filter (built from the
// sampling skip rate in Start()).
bool TSegmentsUserDataStatsAggregator::Skip(const TString& identifier) {
    const bool filteredOut = SkipFilter.Filter(identifier);
    return filteredOut;
}

// Stats template for records that contribute no data: only Counts.WithData is set, to 0.
TUserDataStats TSegmentsUserDataStatsAggregator::GetEmptyStats() {
    TUserDataStats empty;
    auto* counts = empty.MutableCounts();
    counts->SetWithData(0);
    return empty;
}

// True when a set of used strata segments was configured (see Start()) and `segment`
// is not part of it; with no set configured, every segment is kept.
bool TSegmentsUserDataStatsAggregator::IgnoreStrataSegment(const NLab::TSegment& segment) {
    return UsedStrataSegments.Defined() && UsedStrataSegments->count(segment) == 0;
}

// Builds the statistics contributed by a single user record, overwriting
// `userDataStats`.
//
// The record contributes:
//  - one count-1 bucket per attribute histogram (gender, age, device, region, income,
//    and the gender/age/income combination);
//  - one stratum entry (strata keys, non-ignored segments, age/gender/income buckets);
//  - unless disabled by flags, the main vector distribution;
//  - unless disabled by flags, the affinity token stats, plain or encoded, whichever
//    the record has.
// Counts.WithData is set to 1.
void TSegmentsUserDataStatsAggregator::GetStats(TUserDataStats& userDataStats, const NLab::TUserData& userData) {
    userDataStats.Clear();

    const auto& attrs = userData.GetAttributes();
    auto* attrStats = userDataStats.MutableAttributes();

    // --- Attribute histograms: a single bucket each, counted once. ---
    auto* gender = attrStats->AddGender();
    gender->SetGender(attrs.GetGender());
    gender->SetCount(1);

    auto* age = attrStats->AddAge();
    age->SetAge(attrs.GetAge());
    age->SetCount(1);

    auto* device = attrStats->AddDevice();
    device->SetDevice(attrs.GetDevice());
    device->SetCount(1);

    auto* region = attrStats->AddRegion();
    region->SetRegion(attrs.GetRegion());
    region->SetCount(1);

    auto* income = attrStats->AddIncome();
    income->SetIncome(attrs.GetIncome());
    income->SetCount(1);

    auto* combined = attrStats->AddGenderAgeIncome();
    {
        auto* key = combined->MutableGenderAgeIncome();
        key->SetGender(attrs.GetGender());
        key->SetAge(attrs.GetAge());
        key->SetIncome(attrs.GetIncome());
    }
    combined->SetCount(1);

    // --- Stratum: one entry keyed by the strata attributes. ---
    auto* stratumEntry = userDataStats.MutableStratum()->AddStrata();
    FillStrata(stratumEntry->MutableStrata(), attrs);

    for (const auto& segment : userData.GetSegments().GetSegment()) {
        if (IgnoreStrataSegment(segment)) {
            continue;
        }
        auto* segmentBucket = stratumEntry->AddSegment();
        segmentBucket->MutableSegment()->CopyFrom(segment);
        segmentBucket->SetCount(1);
    }

    // The bucket pointers above stay valid: nothing else is appended to those fields.
    stratumEntry->AddAge()->CopyFrom(*age);
    stratumEntry->AddGender()->CopyFrom(*gender);
    stratumEntry->AddIncome()->CopyFrom(*income);
    stratumEntry->SetCount(1);

    const auto& flags = State->GetFlags();

    // --- Vector distribution. ---
    if (!flags.GetIgnoreDistributions()) {
        auto* mainDistribution = userDataStats.MutableDistributions()->MutableMain();
        mainDistribution->MutableMean()->CopyFrom(userData.GetVectors().GetVector());
        mainDistribution->SetCount(1);
    }

    // --- Affinities: plain tokens take precedence over the encoded form. ---
    if (!flags.GetIgnoreAffinities()) {
        if (userData.HasAffinities()) {
            auto& target = *userDataStats.MutableAffinities();
            const auto& source = userData.GetAffinities();
            FillTokensStats(*target.MutableWords(), source.GetWords());
            FillTokensStats(*target.MutableHosts(), source.GetHosts());
            FillTokensStats(*target.MutableApps(), source.GetApps());
        } else if (userData.HasAffinitiesEncoded()) {
            auto& target = *userDataStats.MutableAffinitiesEncoded();
            const auto& source = userData.GetAffinitiesEncoded();
            FillTokensStatsEncoded(*target.MutableWords(), source.GetWords());
            FillTokensStatsEncoded(*target.MutableHosts(), source.GetHosts());
            FillTokensStatsEncoded(*target.MutableApps(), source.GetApps());
        }
    }

    userDataStats.MutableCounts()->SetWithData(1);
}

// Appends one count-1 entry per input token (token value and weight copied), then
// records the token total and, if non-zero, a user count of 1.
void TSegmentsUserDataStatsAggregator::FillTokensStats(TUserDataStats::TTokensStats& stats, const TTokens& tokens) {
    const auto& items = tokens.GetToken();
    for (const auto& item : items) {
        auto* out = stats.AddToken();
        out->SetCount(1);
        out->SetWeight(item.GetWeight());
        *out->MutableToken() = item.GetToken();
    }
    SetTokensStatsCounts(stats, items.size());
}

// Encoded variant of FillTokensStats: appends one count-1 entry per token id, then
// records the token total and, if non-zero, a user count of 1.
void TSegmentsUserDataStatsAggregator::FillTokensStatsEncoded(TUserDataStats::TTokensStatsEncoded& stats, const ::google::protobuf::RepeatedField<ui32>& tokens) {
    const int tokenCount = tokens.size();
    for (int i = 0; i < tokenCount; ++i) {
        auto* out = stats.AddToken();
        out->SetId(tokens.Get(i));
        out->SetCount(1);
    }
    SetTokensStatsCounts(stats, tokens.size());
}

// Wraps a single vector as a sufficient statistic of one observation:
// Mean is a copy of `vector` and Count is 1.
TNormalSufficientStatistic TSegmentsUserDataStatsAggregator::ToStatistic(const TVectorType& vector) {
    TNormalSufficientStatistic result;
    result.SetCount(1);
    result.MutableMean()->CopyFrom(vector);
    return result;
}
