#include "common.h"
#include "normalize.h"
#include "segment_features_calculator.h"

#include <crypta/lib/native/features_calculator/features_calculator.h>
#include <crypta/lib/proto/user_data/user_data_stats.pb.h>

#include <library/cpp/digest/md5/md5.h>
#include <library/cpp/iterator/functools.h>

#include <util/generic/algorithm.h>
#include <util/generic/string.h>
#include <util/generic/vector.h>
#include <util/string/join.h>

using namespace NCrypta::NLookalike;

// Constructs the calculator; the mapping is taken by value and moved into the
// TFeaturesCalculator initializer, so callers may pass either a copy or an rvalue.
TSegmentFeaturesCalculator::TSegmentFeaturesCalculator(NCrypta::TFeaturesMapping featuresMapping)
    : TFeaturesCalculator(std::move(featuresMapping))
{
}

// Builds the comma-separated float feature string for a segment.
//
// The result is the concatenation of two vectors:
//   1. the mean of the segment's main distribution, passed through Normalize();
//   2. a vector with FeaturesMapping.size() entries, initialised to 0 and populated
//      through AddFeatureToVector() with `count / total` ratios for gender, age,
//      income, region ("city") and stratum segment attributes.
//
// Throws (via Y_ENSURE) when the segment's total count is not positive.
TString TSegmentFeaturesCalculator::PrepareFloatFeatures(const NLab::TUserDataStats& segment) const {
    const auto& userDataStatsVector = segment.GetDistributions().GetMain().GetMean().GetData();
    TVector<float> segmentSite2vec(userDataStatsVector.begin(), userDataStatsVector.end());
    Normalize(segmentSite2vec);

    auto totalYuids = segment.GetCounts().GetTotal();
    Y_ENSURE(totalYuids > 0, "User data stats should contain at least one yandexuid");

    // Share of the segment that a given attribute count represents.
    const auto shareOfTotal = [totalYuids](const auto count) {
        return static_cast<float>(count) / totalYuids;
    };

    TVector<float> segmentCategoricalRatios(FeaturesMapping.size(), 0);

    for (const auto& gender : segment.GetAttributes().GetGender()) {
        AddFeatureToVector(segmentCategoricalRatios,
                           GetFeatureName("gender", static_cast<int>(gender.GetGender())),
                           shareOfTotal(gender.GetCount()));
    }

    for (const auto& age : segment.GetAttributes().GetAge()) {
        AddFeatureToVector(segmentCategoricalRatios,
                           GetFeatureName("age", static_cast<int>(age.GetAge())),
                           shareOfTotal(age.GetCount()));
    }

    for (const auto& income : segment.GetAttributes().GetIncome()) {
        AddFeatureToVector(segmentCategoricalRatios,
                           GetFeatureName("income", static_cast<int>(income.GetIncome())),
                           shareOfTotal(income.GetCount()));
    }

    for (const auto& region : segment.GetAttributes().GetRegion()) {
        const auto& regionFeature = GetFeatureName("city", region.GetRegion());
        // Regions that map to "city_0" are skipped entirely
        // (presumably region id 0 means "unknown" - TODO confirm).
        if (regionFeature != "city_0") {
            const auto regionRatio = shareOfTotal(region.GetCount());
            // When AddFeatureToVector() reports failure for the concrete city
            // (presumably the feature is absent from the mapping - verify against
            // its definition), the ratio is attributed to the "city_other" feature.
            if (!AddFeatureToVector(segmentCategoricalRatios, regionFeature, regionRatio)) {
                AddFeatureToVector(segmentCategoricalRatios, "city_other", regionRatio);
            }
        }
    }

    for (const auto& strata : segment.GetStratum().GetStrata()) {
        // Bind by const reference: iterating by value copied a whole message per element.
        for (const auto& segmentInfo : strata.GetSegment()) {
            AddFeatureToVector(segmentCategoricalRatios,
                               GetFeatureName(segmentInfo.GetSegment().GetKeyword(), segmentInfo.GetSegment().GetID()),
                               shareOfTotal(segmentInfo.GetCount()));
        }
    }

    return JoinSeq(",", NFuncTools::Concatenate(segmentSite2vec, segmentCategoricalRatios));
}

// Returns the tokens of the segment's host affinities with the largest weights,
// in descending weight order and joined by single spaces. No more than
// SegmentsAffinitiveSitesN tokens are emitted.
TString TSegmentFeaturesCalculator::PrepareAffinitiveSitesIds(const NLab::TUserDataStats& segment) {
    using TTokenStats = NLab::TUserDataStats::TWeightedTokenStats;

    const auto& hostTokens = segment.GetAffinities().GetHosts().GetToken();

    // Sorting is done on a local copy of the repeated field.
    TVector<TTokenStats> rankedHosts(hostTokens.begin(), hostTokens.end());
    Sort(rankedHosts, [](const auto& lhs, const auto& rhs) {
        return lhs.GetWeight() > rhs.GetWeight();
    });

    TVector<TString> selected;
    for (auto it = rankedHosts.begin();
         it != rankedHosts.end() && selected.size() < SegmentsAffinitiveSitesN;
         ++it)
    {
        selected.push_back(it->GetToken());
    }

    return JoinSeq(" ", selected);
}

// Returns identifiers of the segment's app affinities with the largest weights,
// in descending weight order and joined by single spaces. Each identifier is the
// string form of MD5::CalcHalfMix() applied to the app token. No more than
// SegmentsAffinitiveAppsN identifiers are emitted.
TString TSegmentFeaturesCalculator::PrepareAffinitiveApps(const NLab::TUserDataStats& segment) {
    using TTokenStats = NLab::TUserDataStats::TWeightedTokenStats;

    const auto& appTokens = segment.GetAffinities().GetApps().GetToken();

    // Sorting is done on a local copy of the repeated field.
    TVector<TTokenStats> rankedApps(appTokens.begin(), appTokens.end());
    Sort(rankedApps, [](const auto& lhs, const auto& rhs) {
        return lhs.GetWeight() > rhs.GetWeight();
    });

    TVector<TString> hashedApps;
    for (auto it = rankedApps.begin();
         it != rankedApps.end() && hashedApps.size() < SegmentsAffinitiveAppsN;
         ++it)
    {
        hashedApps.push_back(ToString(MD5::CalcHalfMix(it->GetToken())));
    }

    return JoinSeq(" ", hashedApps);
}
