#include "encoded_extended_ca_rule.h"

#include <crypta/lib/native/lemmer/get_lemmas.h>

#include<library/cpp/langmask/langmask.h>

#include <util/string/split.h>

using namespace NCrypta::NSiberia::NCustomAudience;
using namespace NLab::NEncodedUserData;

namespace {
    // Inserts every element of `source` into `target`.
    // `target` must provide an iterator-range insert(first, last); `source` must
    // provide begin()/end(). Existing contents of `target` are left in place.
    template<typename TTargetSet, typename TSourceRange>
    void FillHashSet(TTargetSet& target, const TSourceRange& source) {
        const auto first = source.begin();
        const auto last = source.end();
        target.insert(first, last);
    }
}

// Builds a phrase from already-encoded word ids.
// Words is initialized from `words` and HasUnknownWords from `hasUnknownWords`;
// no validation or other work is performed.
TEncodedExtendedCaRule::TPhrase::TPhrase(const TVector<ui32>& words, bool hasUnknownWords)
    : Words(words)
    , HasUnknownWords(hasUnknownWords)
{}

// Encodes a space-separated phrase into dictionary ids.
//
// `s` is split on ' ' (empty parts are skipped). Every part is lemmatized with
// GetLemmas over NLanguageMasks::BasicLanguages(), and the first returned lemma is
// looked up in `dict`. Ids of resolved words are collected in input order.
// A word that produces no lemmas, or whose first lemma is missing from `dict`,
// is dropped and reported through the "has unknown words" flag of the returned
// phrase; it is deliberately not treated as an error.
TEncodedExtendedCaRule::TPhrase TEncodedExtendedCaRule::TPhrase::FromString(const TString& s, const TTokenToIdDict& dict) {
    TVector<ui32> encodedWords;
    bool sawUnknownWord = false;

    for (auto&& chunk : StringSplitter(s).Split(' ').SkipEmpty()) {
        const auto& lemmas = GetLemmas(TUtf16String::FromUtf8(chunk.Token()), NLanguageMasks::BasicLanguages());

        if (!lemmas.empty()) {
            // Only the first lemma is consulted.
            const auto firstLemma = TString::FromUtf16(*lemmas.begin());

            if (const auto found = dict.find(firstLemma); found != dict.end()) {
                encodedWords.push_back(found->second);
                continue;
            }
        }

        // Reached when there is no lemma at all or the lemma is not in the dictionary.
        sawUnknownWord = true;
    }

    return TPhrase(encodedWords, sawUnknownWord);
}

// Builds a phrase group from already-encoded phrases.
// Phrases is initialized from `phrases` and Negative from `negative`;
// no validation or other work is performed.
TEncodedExtendedCaRule::TPhraseGroup::TPhraseGroup(const TVector<TPhrase>& phrases, bool negative)
    : Phrases(phrases)
    , Negative(negative)
{}

// Converts a proto group of textual phrases into an encoded phrase group.
// Each item of `group` is encoded with TPhrase::FromString against `dict`, keeping
// the item order; the group's Negative flag is carried over unchanged.
TEncodedExtendedCaRule::TPhraseGroup TEncodedExtendedCaRule::TPhraseGroup::FromProto(const TGroup& group, const TTokenToIdDict& dict) {
    const auto& rawPhrases = group.GetItems();

    TVector<TPhrase> encodedPhrases;
    encodedPhrases.reserve(rawPhrases.size());

    for (const auto& phraseText : rawPhrases) {
        encodedPhrases.emplace_back(TPhrase::FromString(phraseText, dict));
    }

    return TPhraseGroup(encodedPhrases, group.GetNegative());
}

// Builds a token group from already-encoded token ids.
// Items is initialized from `items`, HasUnknownTokens from `hasUnknownTokens`
// and Negative from `negative`; no validation or other work is performed.
TEncodedExtendedCaRule::TTokenGroup::TTokenGroup(const TVector<ui32>& items, bool hasUnknownTokens, bool negative)
    : Items(items)
    , HasUnknownTokens(hasUnknownTokens)
    , Negative(negative)
{}

// Converts a proto group of raw tokens into an encoded token group.
// Each item of `group` is looked up in `dict` as-is. Ids of tokens that are found
// are collected in item order; tokens absent from `dict` are dropped and reported
// through the "has unknown tokens" flag instead of failing. The group's Negative
// flag is carried over unchanged.
TEncodedExtendedCaRule::TTokenGroup TEncodedExtendedCaRule::TTokenGroup::FromProto(const TGroup& group, const TTokenToIdDict& dict) {
    const auto& rawTokens = group.GetItems();

    TVector<ui32> encodedIds;
    encodedIds.reserve(rawTokens.size());
    bool sawUnknownToken = false;

    for (const auto& rawToken : rawTokens) {
        const auto found = dict.find(rawToken);
        if (found != dict.end()) {
            encodedIds.push_back(found->second);
        } else {
            // Missing tokens are tolerated: remember that one occurred and move on.
            sawUnknownToken = true;
        }
    }

    return TTokenGroup(encodedIds, sawUnknownToken, group.GetNegative());
}

// Encodes a proto kernel.
// AggregateByOr is taken from the proto. Phrase groups are encoded against
// `wordDict`, host groups against `hostDict` and app groups against `appDict`;
// each encoded group is appended to Phrases, Hosts and Apps respectively, in proto order.
TEncodedExtendedCaRule::TEncodedKernel::TEncodedKernel(
    const TKernel& kernel,
    const TTokenToIdDict& wordDict,
    const TTokenToIdDict& hostDict,
    const TTokenToIdDict& appDict
) : AggregateByOr(kernel.GetAggregateByOr())
{
    for (const auto& phraseGroupProto : kernel.GetPhrases()) {
        Phrases.emplace_back(TPhraseGroup::FromProto(phraseGroupProto, wordDict));
    }

    // Hosts and apps are both plain token groups and differ only in the dictionary used.
    const auto appendTokenGroups = [](const auto& groupProtos, const TTokenToIdDict& dict, auto& encodedGroups) {
        for (const auto& groupProto : groupProtos) {
            encodedGroups.emplace_back(TTokenGroup::FromProto(groupProto, dict));
        }
    };

    appendTokenGroups(kernel.GetHosts(), hostDict, Hosts);
    appendTokenGroups(kernel.GetApps(), appDict, Apps);
}

// Builds the encoded form of an extended custom-audience rule.
// The original proto is kept in OriginalRule. Ages, Incomes, Regions, Countries and
// Devices are filled from the corresponding repeated fields of `rule`.
// Kernel is set only when the rule has a kernel with at least one phrase, host or
// app group; otherwise it is left as default-constructed.
TEncodedExtendedCaRule::TEncodedExtendedCaRule(
    const TExtendedCaRule& rule,
    const TTokenToIdDict& wordDict,
    const TTokenToIdDict& hostDict,
    const TTokenToIdDict& appDict
) : OriginalRule(rule)
{
    FillHashSet(Ages, rule.GetAges());
    FillHashSet(Incomes, rule.GetIncomes());
    FillHashSet(Regions, rule.GetRegions());
    FillHashSet(Countries, rule.GetCountries());
    FillHashSet(Devices, rule.GetDevices());

    if (!rule.HasKernel()) {
        return;
    }

    const auto& protoKernel = rule.GetKernel();
    const bool hasAnyGroup = !protoKernel.GetPhrases().empty()
        || !protoKernel.GetHosts().empty()
        || !protoKernel.GetApps().empty();

    // A kernel with no groups at all is treated the same as no kernel.
    if (hasAnyGroup) {
        Kernel = MakeMaybe<TEncodedKernel>(protoKernel, wordDict, hostDict, appDict);
    }
}
