#include <google/protobuf/util/message_differencer.h>
#include <library/cpp/testing/unittest/registar.h>

#include <crypta/lib/native/yaml/yaml2proto.h>
#include <crypta/lib/proto/user_data/user_data.pb.h>
#include <crypta/lib/proto/user_data/user_data_stats.pb.h>
#include <crypta/lab/lib/native/user_data_stats_aggregator.h>

using namespace NLab;
using namespace NLab::NEncodedUserData;

namespace {
    // Builds the input fixture used by EncodedUpdateWithCommonTest by parsing the
    // inline YAML document below into a TUserDataStats via NCrypta::Yaml2Proto.
    //
    // Every attribute, stratum and `counts` entry in the document has value 1.
    // Affinities are supplied in the `affinities_encoded` section, where each token
    // is a numeric `id` paired with a `count`:
    //   hosts: id 1; words: ids 1..5 with counts 1..5; apps: ids 1 and 2.
    TUserDataStats GetUserDataStatsEncoded() {
        const TString userDataYaml = "attributes:\n"
                                     "  age:\n"
                                     "  - age: from_0_to_17\n"
                                     "    count: 1\n"
                                     "  device:\n"
                                     "  - device: desktop\n"
                                     "    count: 1\n"
                                     "  gender:\n"
                                     "  - gender: male\n"
                                     "    count: 1\n"
                                     "  region:\n"
                                     "  - region: 100\n"
                                     "    count: 1\n"
                                     "  income:\n"
                                     "  - income: income_b1\n"
                                     "    count: 1\n"
                                     "  gender_age_income:\n"
                                     "  - gender_age_income:\n"
                                     "      gender: male\n"
                                     "      age: from_0_to_17\n"
                                     "      income: income_b1\n"
                                     "    count: 1\n"
                                     "stratum:\n"
                                     "  strata:\n"
                                     "  - strata:\n"
                                     "      device: desktop\n"
                                     "      country: russia\n"
                                     "      city: moscow\n"
                                     "      has_crypta_i_d: true\n"
                                     "    segment:\n"
                                     "    - segment:\n"
                                     "        keyword: 1\n"
                                     "        i_d: 2\n"
                                     "      count: 1\n"
                                     "    age:\n"
                                     "    - age: from_0_to_17\n"
                                     "      count: 1\n"
                                     "    count: 1\n"
                                     "distributions:\n"
                                     "  main:\n"
                                     "    mean:\n"
                                     "      data:\n"
                                     "      - 1.0\n"
                                     "      - 1.0\n"
                                     "    count: 1\n"
                                     "counts:\n"
                                     "  with_data: 1\n"
                                     "  total: 1\n"
                                     "  uniq_yuid: 1\n"
                                     "group_i_d: Group\n"
                                     "identifiers:\n"
                                     "  identifiers:\n"
                                     "  - key: str_1\n"
                                     "    value: str_2\n"
                                     "  not_unique: false\n"
                                     "segment_info:\n"
                                     "  info:\n"
                                     "    str_1: str_2\n"
                                     "affinities_encoded:\n"
                                     "  hosts:\n"
                                     "    token:\n"
                                     "    - id: 1\n"
                                     "      count: 1\n"
                                     "    tokens_count: 1\n"
                                     "    users_count: 1\n"
                                     "  words:\n"
                                     "    token:\n"
                                     "    - id: 1\n"
                                     "      count: 1\n"
                                     "    - id: 2\n"
                                     "      count: 2\n"
                                     "    - id: 3\n"
                                     "      count: 3\n"
                                     "    - id: 4\n"
                                     "      count: 4\n"
                                     "    - id: 5\n"
                                     "      count: 5\n"
                                     "    tokens_count: 5\n"
                                     "    users_count: 30\n"
                                     "  apps:\n"
                                     "    token:\n"
                                     "    - id: 1\n"
                                     "      count: 1\n"
                                     "    - id: 2\n"
                                     "      count: 5\n"
                                     "    tokens_count: 2\n"
                                     "    users_count: 8";

        return NCrypta::Yaml2Proto<TUserDataStats>(userDataYaml);
    }

    // Builds the input fixture used by UpdateWithCommonTest by parsing the inline
    // YAML document below into a TUserDataStats via NCrypta::Yaml2Proto.
    //
    // Unlike GetUserDataStatsEncoded(), affinities are supplied in the plain
    // `affinities` section, where each token carries its text (`token`), a
    // `weight` and a `count`; the stratum additionally has a `gender` entry.
    TUserDataStats GetUserDataStats() {
        const TString userDataYaml = "attributes:\n"
                                    "  age:\n"
                                    "  - age: from_0_to_17\n"
                                    "    count: 1\n"
                                    "  device:\n"
                                    "  - device: desktop\n"
                                    "    count: 1\n"
                                    "  gender:\n"
                                    "  - gender: male\n"
                                    "    count: 1\n"
                                    "  region:\n"
                                    "  - region: 100\n"
                                    "    count: 1\n"
                                    "  income:\n"
                                    "  - income: income_b1\n"
                                    "    count: 1\n"
                                    "  gender_age_income:\n"
                                    "  - gender_age_income:\n"
                                    "      gender: male\n"
                                    "      age: from_0_to_17\n"
                                    "      income: income_b1\n"
                                    "    count: 1\n"
                                    "stratum:\n"
                                    "  strata:\n"
                                    "  - strata:\n"
                                    "      device: desktop\n"
                                    "      country: russia\n"
                                    "      city: moscow\n"
                                    "      has_crypta_i_d: true\n"
                                    "    segment:\n"
                                    "    - segment:\n"
                                    "        keyword: 1\n"
                                    "        i_d: 2\n"
                                    "      count: 1\n"
                                    "    age:\n"
                                    "    - age: from_0_to_17\n"
                                    "      count: 1\n"
                                    "    gender:\n"
                                    "    - gender: male\n"
                                    "      count: 1\n"
                                    "    count: 1\n"
                                    "distributions:\n"
                                    "  main:\n"
                                    "    mean:\n"
                                    "      data:\n"
                                    "      - 1.0\n"
                                    "      - 1.0\n"
                                    "    count: 1\n"
                                    "counts:\n"
                                    "  with_data: 1\n"
                                    "  total: 1\n"
                                    "  uniq_yuid: 1\n"
                                    "group_i_d: Group\n"
                                    "identifiers:\n"
                                    "  identifiers:\n"
                                    "  - key: str_1\n"
                                    "    value: str_2\n"
                                    "  not_unique: false\n"
                                    "segment_info:\n"
                                    "  info:\n"
                                    "    str_1: str_2\n"
                                    "affinities:\n"
                                    "  hosts:\n"
                                    "    token:\n"
                                    "    - token: host_1\n"
                                    "      weight: 1.2\n"
                                    "      count: 1\n"
                                    "    tokens_count: 1\n"
                                    "    users_count: 1\n"
                                    "  words:\n"
                                    "    token:\n"
                                    "    - token: word_1\n"
                                    "      weight: 100\n"
                                    "      count: 1\n"
                                    "    - token: word_2\n"
                                    "      weight: 200\n"
                                    "      count: 2\n"
                                    "    - token: word_3\n"
                                    "      weight: 3\n"
                                    "      count: 3\n"
                                    "    - token: word_4\n"
                                    "      weight: 4\n"
                                    "      count: 4\n"
                                    "    - token: word_5\n"
                                    "      weight: 5\n"
                                    "      count: 5\n"
                                    "    tokens_count: 5\n"
                                    "    users_count: 30\n"
                                    "  apps:\n"
                                    "    token:\n"
                                    "    - token: app_1\n"
                                    "      weight: 10\n"
                                    "      count: 100\n"
                                    "    - token: app_2\n"
                                    "      weight: 20\n"
                                    "      count: 200\n"
                                    "    tokens_count: 2\n"
                                    "    users_count: 8";

        return NCrypta::Yaml2Proto<TUserDataStats>(userDataYaml);
    }

    // Fixture: stats that carry only word affinities, declared as one user and
    // two tokens -- "word_1" (weight 1) followed by "word_2" (weight 2), each
    // with count 1.
    TUserDataStats GetUserDataStatsWithAffinities1() {
        TUserDataStats fixture;

        auto& wordAffinities = *fixture.MutableAffinities()->MutableWords();
        wordAffinities.SetUsersCount(1);
        wordAffinities.SetTokensCount(2);

        auto& firstToken = *wordAffinities.MutableToken()->Add();
        firstToken.SetCount(1);
        firstToken.SetWeight(1);
        firstToken.SetToken("word_1");

        auto& secondToken = *wordAffinities.MutableToken()->Add();
        secondToken.SetCount(1);
        secondToken.SetWeight(2);
        secondToken.SetToken("word_2");

        return fixture;
    }

    // Fixture: stats that carry only word affinities, declared as one user and
    // two tokens -- "word_2" (weight 2) followed by "word_3" (weight 3), each
    // with count 1. "word_2" is the token shared with GetUserDataStatsWithAffinities1().
    TUserDataStats GetUserDataStatsWithAffinities2() {
        TUserDataStats fixture;

        auto& wordAffinities = *fixture.MutableAffinities()->MutableWords();
        wordAffinities.SetUsersCount(1);
        wordAffinities.SetTokensCount(2);

        auto& firstToken = *wordAffinities.MutableToken()->Add();
        firstToken.SetCount(1);
        firstToken.SetWeight(2);
        firstToken.SetToken("word_2");

        auto& secondToken = *wordAffinities.MutableToken()->Add();
        secondToken.SetCount(1);
        secondToken.SetWeight(3);
        secondToken.SetToken("word_3");

        return fixture;
    }

    // Feeds both word-affinity fixtures into a TUserDataStatsAggregator<> configured
    // with MaxTokensCount = 1000 and MinSampleRatio = 0, then returns what
    // MergeInto() produces. `accumulateAffinities` is passed through as the
    // AccumulateAffinities option.
    TUserDataStats MergeUserDataStats(bool accumulateAffinities) {
        TUserDataStatsAggregator<> aggregator({
            .MaxTokensCount = 1000,
            .MinSampleRatio = 0,
            .AccumulateAffinities = accumulateAffinities,
        });

        aggregator.UpdateWith(GetUserDataStatsWithAffinities1());
        aggregator.UpdateWith(GetUserDataStatsWithAffinities2());

        TUserDataStats merged;
        aggregator.MergeInto(merged);
        return merged;
    }

    // Flattens the word affinities of `stats` into a map keyed by token text;
    // each mapped value is the (weight, count) pair read from that token entry.
    THashMap<TString, std::pair<float, ui64>> GetWords(const TUserDataStats& stats) {
        THashMap<TString, std::pair<float, ui64>> wordsByToken;

        const auto& wordTokens = stats.GetAffinities().GetWords().GetToken();
        for (const auto& entry : wordTokens) {
            const std::pair<float, ui64> weightAndCount(entry.GetWeight(), entry.GetCount());
            wordsByToken.insert({entry.GetToken(), weightAndCount});
        }

        return wordsByToken;
    }
}

Y_UNIT_TEST_SUITE(TUserDataStatsAggregator) {
    // Feeds the encoded fixture into TEncodedUserDataStatsAggregator twice
    // (MaxTokensCount = 2, MinSampleRatio = 0.1, AccumulateAffinities = true),
    // merges the result using id -> weighted-token dictionaries, and compares it
    // against a reference message parsed from YAML.
    Y_UNIT_TEST(EncodedUpdateWithCommonTest) {
        TEncodedUserDataStatsAggregator statsAggregator({.MaxTokensCount = 2, .MinSampleRatio = 0.1, .AccumulateAffinities = true});
        statsAggregator.UpdateWith(GetUserDataStatsEncoded());
        statsAggregator.UpdateWith(GetUserDataStatsEncoded());

        TUserDataStats result;
        // Dictionaries keyed by the numeric token ids that appear in the encoded
        // fixture; each entry supplies the token text and a weight.
        TIdToWeightedTokenDict hosts = {
            {1, {.Token = "host_1", .Weight = 1}}
        };
        TIdToWeightedTokenDict words = {
            {1, {.Token = "word_1", .Weight = 100}},
            {2, {.Token = "word_2", .Weight = 100}},
            {3, {.Token = "word_3", .Weight = 2}},
            {4, {.Token = "word_4", .Weight = 1}},
            {5, {.Token = "word_5", .Weight = 1}},
        };

        TIdToWeightedTokenDict apps = {
            {1, {.Token = "app_1", .Weight = 10}},
            {2, {.Token = "app_2", .Weight = 20}},
        };

        // Dictionaries are passed in the order words, hosts, apps.
        statsAggregator.MergeInto(result, &words, &hosts, &apps);

        // Expected merge result: every count is twice the fixture's value, and
        // `distributions.main` gains a `mean2` entry. The reference has no
        // `group_i_d`, and its `words` section keeps only two tokens (ids 3 and 5).
        // NOTE(review): which word ids survive presumably depends on MaxTokensCount
        // and the weights in `words` above -- confirm against the aggregator.
        const TString refYaml = "attributes:\n"
                                "  age:\n"
                                "  - age: from_0_to_17\n"
                                "    count: 2\n"
                                "  device:\n"
                                "  - device: desktop\n"
                                "    count: 2\n"
                                "  gender:\n"
                                "  - gender: male\n"
                                "    count: 2\n"
                                "  region:\n"
                                "  - region: 100\n"
                                "    count: 2\n"
                                "  income:\n"
                                "  - income: income_b1\n"
                                "    count: 2\n"
                                "  gender_age_income:\n"
                                "  - gender_age_income:\n"
                                "      gender: male\n"
                                "      age: from_0_to_17\n"
                                "      income: income_b1\n"
                                "    count: 2\n"
                                "stratum:\n"
                                "  strata:\n"
                                "  - strata:\n"
                                "      device: desktop\n"
                                "      country: russia\n"
                                "      city: moscow\n"
                                "      has_crypta_i_d: true\n"
                                "    segment:\n"
                                "    - segment:\n"
                                "        keyword: 1\n"
                                "        i_d: 2\n"
                                "      count: 2\n"
                                "    age:\n"
                                "    - age: from_0_to_17\n"
                                "      count: 2\n"
                                "    count: 2\n"
                                "distributions:\n"
                                "  main:\n"
                                "    mean:\n"
                                "      data:\n"
                                "      - 1.41421\n"
                                "      - 1.41421\n"
                                "    count: 2\n"
                                "    mean2:\n"
                                "      data:\n"
                                "      - 1\n"
                                "      - 1\n"
                                "counts:\n"
                                "  with_data: 2\n"
                                "  total: 2\n"
                                "  uniq_yuid: 2\n"
                                "identifiers:\n"
                                "  identifiers:\n"
                                "  - key: str_1\n"
                                "    value: str_2\n"
                                "  not_unique: false\n"
                                "segment_info:\n"
                                "  info:\n"
                                "  - key: str_1\n"
                                "    value: str_2\n"
                                "affinities_encoded:\n"
                                "  hosts:\n"
                                "    token:\n"
                                "    - id: 1\n"
                                "      count: 2\n"
                                "    tokens_count: 2\n"
                                "    users_count: 2\n"
                                "  words:\n"
                                "    token:\n"
                                "    - id: 3\n"
                                "      count: 6\n"
                                "    - id: 5\n"
                                "      count: 10\n"
                                "    tokens_count: 10\n"
                                "    users_count: 60\n"
                                "  apps:\n"
                                "    token:\n"
                                "    - id: 2\n"
                                "      count: 10\n"
                                "    - id: 1\n"
                                "      count: 2\n"
                                "    tokens_count: 4\n"
                                "    users_count: 16";

        const auto& ref = NCrypta::Yaml2Proto<TUserDataStats>(refYaml);
        // Approximate comparison -- presumably because the reference spells
        // floating-point values in rounded form (e.g. 1.41421).
        UNIT_ASSERT(google::protobuf::util::MessageDifferencer::ApproximatelyEquals(ref, result));
    }

    // Feeds the plain (non-encoded) fixture into TUserDataStatsAggregator<> twice
    // (MaxTokensCount = 2, MinSampleRatio = 0.1, AccumulateAffinities = true),
    // merges the result, and compares it against a reference message parsed
    // from YAML.
    Y_UNIT_TEST(UpdateWithCommonTest) {
        TUserDataStatsAggregator<> statsAggregator({.MaxTokensCount = 2, .MinSampleRatio = 0.1, .AccumulateAffinities = true});
        statsAggregator.UpdateWith(GetUserDataStats());
        statsAggregator.UpdateWith(GetUserDataStats());

        TUserDataStats result;
        statsAggregator.MergeInto(result);

        // Expected merge result: every count is twice the fixture's value, and
        // `distributions.main` gains a `mean2` entry. The reference has no
        // `group_i_d`. Its `words` section keeps only two tokens, word_5 and
        // word_4, and the weights and counts of all retained tokens are twice
        // the fixture's values.
        const TString refYaml = "attributes:\n"
                                "  age:\n"
                                "  - age: from_0_to_17\n"
                                "    count: 2\n"
                                "  device:\n"
                                "  - device: desktop\n"
                                "    count: 2\n"
                                "  gender:\n"
                                "  - gender: male\n"
                                "    count: 2\n"
                                "  region:\n"
                                "  - region: 100\n"
                                "    count: 2\n"
                                "  income:\n"
                                "  - income: income_b1\n"
                                "    count: 2\n"
                                "  gender_age_income:\n"
                                "  - gender_age_income:\n"
                                "      gender: male\n"
                                "      age: from_0_to_17\n"
                                "      income: income_b1\n"
                                "    count: 2\n"
                                "stratum:\n"
                                "  strata:\n"
                                "  - strata:\n"
                                "      device: desktop\n"
                                "      country: russia\n"
                                "      city: moscow\n"
                                "      has_crypta_i_d: true\n"
                                "    segment:\n"
                                "    - segment:\n"
                                "        keyword: 1\n"
                                "        i_d: 2\n"
                                "      count: 2\n"
                                "    age:\n"
                                "    - age: from_0_to_17\n"
                                "      count: 2\n"
                                "    gender:\n"
                                "    - gender: male\n"
                                "      count: 2\n"
                                "    count: 2\n"
                                "distributions:\n"
                                "  main:\n"
                                "    mean:\n"
                                "      data:\n"
                                "      - 1.41421\n"
                                "      - 1.41421\n"
                                "    count: 2\n"
                                "    mean2:\n"
                                "      data:\n"
                                "      - 1\n"
                                "      - 1\n"
                                "counts:\n"
                                "  with_data: 2\n"
                                "  total: 2\n"
                                "  uniq_yuid: 2\n"
                                "identifiers:\n"
                                "  identifiers:\n"
                                "  - key: str_1\n"
                                "    value: str_2\n"
                                "  not_unique: false\n"
                                "segment_info:\n"
                                "  info:\n"
                                "    str_1: str_2\n"
                                "affinities:\n"
                                "  hosts:\n"
                                "    token:\n"
                                "    - token: host_1\n"
                                "      weight: 2.4\n"
                                "      count: 2\n"
                                "    tokens_count: 2\n"
                                "    users_count: 2\n"
                                "  words:\n"
                                "    token:\n"
                                "    - token: word_5\n"
                                "      weight: 10\n"
                                "      count: 10\n"
                                "    - token: word_4\n"
                                "      weight: 8\n"
                                "      count: 8\n"
                                "    tokens_count: 10\n"
                                "    users_count: 60\n"
                                "  apps:\n"
                                "    token:\n"
                                "    - token: app_2\n"
                                "      weight: 40\n"
                                "      count: 400\n"
                                "    - token: app_1\n"
                                "      weight: 20\n"
                                "      count: 200\n"
                                "    tokens_count: 4\n"
                                "    users_count: 16";

        const auto& ref = NCrypta::Yaml2Proto<TUserDataStats>(refYaml);
        // Approximate comparison -- presumably because the reference spells
        // floating-point values in rounded form (e.g. 1.41421, 2.4).
        UNIT_ASSERT(google::protobuf::util::MessageDifferencer::ApproximatelyEquals(ref, result));
    }

    // With AccumulateAffinities enabled, the token present in both fixtures
    // ("word_2") is expected to come out with weight 4 and count 2, and the
    // merged words section is expected to report 4 tokens and 2 users.
    Y_UNIT_TEST(UpdateWithAffinitiesTest) {
        const THashMap<TString, std::pair<float, ui64>> expectedWords = {
            {"word_1", {1, 1}},
            {"word_2", {4, 2}},
            {"word_3", {3, 1}},
        };

        const auto merged = MergeUserDataStats(true);
        const auto& mergedWords = merged.GetAffinities().GetWords();

        UNIT_ASSERT_EQUAL(expectedWords, GetWords(merged));
        UNIT_ASSERT_EQUAL(4, mergedWords.GetTokensCount());
        UNIT_ASSERT_EQUAL(2, mergedWords.GetUsersCount());
    }

    // An aggregator instantiated with TDefaultAffinitiesOptions is expected to
    // throw yexception when the input populates AffinitiesEncoded.
    Y_UNIT_TEST(UpdateWithWrongAffinitiesFormatTest) {
        TUserDataStats encodedOnlyStats;
        encodedOnlyStats.MutableAffinitiesEncoded()->MutableWords()->SetTokensCount(2);

        TUserDataStatsAggregator<TDefaultAffinitiesOptions> aggregator({
            .MaxTokensCount = 1000,
            .MinSampleRatio = 0,
            .AccumulateAffinities = true,
        });

        UNIT_ASSERT_EXCEPTION(aggregator.UpdateWith(encodedOnlyStats), yexception);
    }

    // With AccumulateAffinities disabled, the token present in both fixtures
    // ("word_2") is expected to keep weight 2 and count 1, and the merged words
    // section is expected to report 3 tokens and 1 user.
    Y_UNIT_TEST(UpdateWithAffinitiesWoAccumalationTest) {
        const THashMap<TString, std::pair<float, ui64>> expectedWords = {
            {"word_1", {1, 1}},
            {"word_2", {2, 1}},
            {"word_3", {3, 1}},
        };

        const auto merged = MergeUserDataStats(false);
        const auto& mergedWords = merged.GetAffinities().GetWords();

        UNIT_ASSERT_EQUAL(expectedWords, GetWords(merged));
        UNIT_ASSERT_EQUAL(3, mergedWords.GetTokensCount());
        UNIT_ASSERT_EQUAL(1, mergedWords.GetUsersCount());
    }
}
