#include "catboost_features_calculator.h"
#include "ads/bsyeti/libs/counter_lib/counters.h"
#include "library/cpp/protobuf/json/proto2json.h"

#include <crypta/lib/native/features_calculator/features_calculator.h>

#include <ads/bsyeti/libs/primitives/counter_proto/counter_ids.pb.h>
#include <yabs/proto/user_profile.pb.h>

#include <library/cpp/iterator/enumerate.h>
#include <library/cpp/iterator/zip.h>
#include <util/generic/string.h>
#include <util/generic/vector.h>
#include <util/string/join.h>
#include <util/string/strip.h>

using namespace NCrypta;

namespace {
    // Upper bound on how many queries are folded into the text feature.
    const size_t MAX_QUERIES = 20;

    // Builds a single text feature from at most MAX_QUERIES entries of `queries`.
    // The texts returned by `getter` are joined with ". ", the joined string is
    // stripped of surrounding whitespace, and a trailing '.' is appended; an empty
    // collection therefore yields ".".
    // NOTE(review): only TStringBuf views are collected before joining, so `getter`
    // must return a view into storage that outlives this call — confirm for all callers.
    template <typename TCollection, typename TGetter>
    TString GetQueriesText(const TCollection& queries, const TGetter& getter) {
        TVector<TStringBuf> texts;
        texts.reserve(MAX_QUERIES);

        for (const auto& query : queries) {
            // Every iteration appends exactly one view, so the size doubles as the index.
            if (texts.size() == MAX_QUERIES) {
                break;
            }
            texts.push_back(getter(query));
        }

        const TString joined = JoinSeq(". ", texts);
        return Join("", Strip(joined), ".");
    }

    // Adds features for a single user item under the prefix `alias`:
    // if the item carries a string value, one feature named after that value;
    // otherwise one feature per element of uint_values (none if it is empty).
    void AddUserItemToFeatures(const TFeaturesCalculator& calculator, TVector<float>& features, const NBSYeti::TUserItemProto& item, const TString& alias) {
        if (item.has_string_value()) {
            calculator.AddFeatureToVector(features, TFeaturesCalculator::GetFeatureName(alias, item.string_value()));
            return;
        }

        for (const auto& value : item.uint_values()) {
            calculator.AddFeatureToVector(features, TFeaturesCalculator::GetFeatureName(alias, value));
        }
    }
}

// Initialises the calculator from a feature-name mapping and two id -> prefix tables:
// one for counter ids and one for user-item keyword ids. The prefixes (aliases) are
// later combined with counter keys / item values to form feature names.
TCatboostFeaturesCalculator::TCatboostFeaturesCalculator(
    const NCrypta::TFeaturesMapping& featuresMapping,
    const TIdToPrefix& countersToFeatures,
    const TIdToPrefix& keywordsToFeatures)
    : FeaturesCalculator(featuresMapping)
    , CountersToFeatures(countersToFeatures)
    , KeywordsToFeatures(keywordsToFeatures)
{
}

// Builds the dense float-feature vector from a protobuf profile.
//
// The vector has FeaturesCalculator.GetSize() slots, zero-initialised, and is filled from:
//  * counters whose id is listed in CountersToFeatures — feature name is built from the
//    configured alias and the counter key, the counter value is passed as the feature value;
//  * user items whose keyword id is listed in KeywordsToFeatures (see AddUserItemToFeatures).
TVector<float> TCatboostFeaturesCalculator::PrepareFloatFeatures(const yabs::proto::Profile& profile) const {
    TVector<float> floatFeatures(FeaturesCalculator.GetSize(), 0);

    // Visitor handed to NBSYeti::IterateOverCounters for both plain and packed counters.
    // NOTE(review): NeedCounterId presumably lets the iterator skip unconfigured ids — confirm.
    struct {
        void OnCounterValue(NBSYeti::TCounterId id, NBSYeti::TCounterKey key, float value) {
            const auto counterSearch = Object.CountersToFeatures.find(id);
            // Do not rely on the iterator having filtered this id through NeedCounterId:
            // dereferencing end() would be undefined behaviour.
            if (counterSearch == Object.CountersToFeatures.end()) {
                return;
            }
            Object.FeaturesCalculator.AddFeatureToVector(FloatFeatures, TFeaturesCalculator::GetFeatureName(counterSearch->second, key), value);
        }

        bool NeedCounterId(NBSYeti::TCounterId id) {
            return Object.CountersToFeatures.find(id) != Object.CountersToFeatures.end();
        }

        void OnCounterIdEnd(NBSYeti::TCounterId /* id */) {
            // Nothing to finalise per counter id.
        }

        TVector<float>& FloatFeatures;
        const TCatboostFeaturesCalculator& Object;
    } visitor {.FloatFeatures = floatFeatures, .Object = *this};
    NBSYeti::IterateOverCounters(profile.counters(), profile.packed_counters(), visitor);

    // User items: only keywords that have a configured alias contribute features.
    for (const auto& item : profile.items()) {
        const auto it = KeywordsToFeatures.find(item.keyword_id());
        if (it == KeywordsToFeatures.end()) {
            continue;
        }
        AddUserItemToFeatures(FeaturesCalculator, floatFeatures, item, it->second);
    }

    return floatFeatures;
}

// Builds the dense float-feature vector from an in-memory NBSYeti::TProfile.
//
// The result is zero-initialised with FeaturesCalculator.GetSize() slots, then:
//  * for every configured counter id, each key stored under it contributes a feature
//    named from the alias and the key, valued by the counter value cast to float;
//  * for every configured keyword id, each matching user item contributes features
//    via AddUserItemToFeatures.
TVector<float> TCatboostFeaturesCalculator::PrepareFloatFeatures(const NBSYeti::TProfile& profile) const {
    TVector<float> result(FeaturesCalculator.GetSize(), 0);

    const auto& counters = profile.Counters;
    for (const auto& [counterId, alias] : CountersToFeatures) {
        for (auto locator = counters.FindFirstKey(counterId);
             NBSYeti::TCounters::IsOK(locator);
             locator = counters.FindNextKey(locator))
        {
            FeaturesCalculator.AddFeatureToVector(result, TFeaturesCalculator::GetFeatureName(alias, counters.GetKey(locator)), NBSYeti::CastValueTo<float>(counters.GetValue(locator)));
        }
    }

    for (const auto& [keywordId, alias] : KeywordsToFeatures) {
        for (const auto& [_, userItem] : profile.UserItems.FindRange(keywordId)) {
            AddUserItemToFeatures(FeaturesCalculator, result, userItem.Get(), alias);
        }
    }

    return result;
}

// Builds the text feature from the protobuf profile's queries (see GetQueriesText
// for the join/truncation rules).
TString TCatboostFeaturesCalculator::PrepareTextFeatures(const yabs::proto::Profile& profile) const {
    // The view points into the protobuf message, which outlives the call.
    const auto queryText = [](const auto& query) -> TStringBuf {
        return query.Getquery_text();
    };
    return GetQueriesText(profile.Getqueries(), queryText);
}

// Builds the text feature from the in-memory profile's queries (see GetQueriesText
// for the join/truncation rules). Entries of profile.Queries are accessed through
// `.second` to reach the query object.
TString TCatboostFeaturesCalculator::PrepareTextFeatures(const NBSYeti::TProfile& profile) const {
    // NOTE(review): the returned view must stay valid until GetQueriesText joins it;
    // this holds only if GetQueryText() returns a reference, not a temporary — confirm.
    const auto queryText = [](const auto& entry) -> TStringBuf {
        return entry.second.GetQueryText();
    };
    return GetQueriesText(profile.Queries, queryText);
}

// Length of the vectors returned by PrepareFloatFeatures, which are sized by the
// underlying features calculator.
size_t TCatboostFeaturesCalculator::GetNumFloatFeatures() const {
    const size_t numFeatures = FeaturesCalculator.GetSize();
    return numFeatures;
}

// Returns true only if `dataFilter` admits every counter id in CountersToFeatures,
// every keyword id in KeywordsToFeatures, and the KW_SITE_SEARCH_TEXT keyword.
bool TCatboostFeaturesCalculator::HasAllFeatures(const NBSYeti::TDataFilter& dataFilter) const {
    for (const auto& entry : CountersToFeatures) {
        if (!dataFilter.HasCounter(entry.first)) {
            return false;
        }
    }

    for (const auto& entry : KeywordsToFeatures) {
        if (!dataFilter.HasKeyword(entry.first)) {
            return false;
        }
    }

    // NOTE(review): presumably this keyword carries the queries used by
    // PrepareTextFeatures — confirm.
    const bool hasSearchText = dataFilter.HasKeyword(NBSData::NKeywords::KW_SITE_SEARCH_TEXT);
    return hasSearchText;
}
