#include "filter.h"

using namespace NCrypta::NSiberia::NCustomAudience;

// Returns true when the user data passes both the attribute filter and the
// kernel filter of the rule. The kernel filter is only evaluated when the
// attribute filter has already passed.
bool NFilter::Filter(const NLab::TUserData& userData, const TEncodedCaRule& rule) {
    if (!FilterAttributes(userData, rule)) {
        return false;
    }

    return FilterKernel(userData, rule);
}

// Checks the user's attributes against the attribute constraints of the rule.
//
// Returns false when the user data carries no attributes at all. Otherwise
// every constraint must hold: the gender must match if the original rule sets
// one, and for each of ages / devices / regions / incomes / countries / cities
// the user's value must be in the rule's set. An empty set means the
// corresponding attribute is not constrained.
bool NFilter::FilterAttributes(const NLab::TUserData& userData, const TEncodedCaRule& rule) {
    if (!userData.HasAttributes()) {
        return false;
    }

    const auto& attrs = userData.GetAttributes();

    if (rule.OriginalRule.HasGender() && !(rule.OriginalRule.GetGender() == attrs.GetGender())) {
        return false;
    }
    if (!rule.Ages.empty() && !rule.Ages.contains(attrs.GetAge())) {
        return false;
    }
    if (!rule.Devices.empty() && !rule.Devices.contains(attrs.GetDevice())) {
        return false;
    }
    if (!rule.Regions.empty() && !rule.Regions.contains(attrs.GetRegion())) {
        return false;
    }
    if (!rule.Incomes.empty() && !rule.Incomes.contains(attrs.GetIncome())) {
        return false;
    }
    if (!rule.Countries.empty() && !rule.Countries.contains(attrs.GetCountry())) {
        return false;
    }
    if (!rule.Cities.empty() && !rule.Cities.contains(attrs.GetCity())) {
        return false;
    }

    return true;
}

// Checks the user against the kernel part of the rule.
//
// If the rule reports an empty kernel, the user passes unless the rule has
// unknown items. Otherwise a single hit is enough: a matching segment, or a
// matching token among words, hosts, apps, affinitive sites or top common
// sites (checked in that order).
bool NFilter::FilterKernel(const NLab::TUserData& userData, const TEncodedCaRule& rule) {
    if (rule.IsKernelEmpty()) {
        return !rule.HasUnknownItems();
    }

    if (FilterSegments(userData, rule.OriginalRule.GetSegments())) {
        return true;
    }

    const auto& affinities = userData.GetAffinitiesEncoded();

    return HasTokens(affinities.GetWords(), rule.Words) ||
           HasTokens(affinities.GetHosts(), rule.Hosts) ||
           HasTokens(affinities.GetApps(), rule.Apps) ||
           HasTokens(affinities.GetAffinitiveSites(), rule.AffinitiveSites) ||
           HasTokens(affinities.GetTopCommonSites(), rule.TopCommonSites);
}

// Returns true if at least one of the rule's segments is present among the
// user's segments. An empty rule segment list never matches.
bool NFilter::FilterSegments(const NLab::TUserData& userData, const ::google::protobuf::RepeatedPtrField<NLab::TSegment>& ruleSegments) {
    if (ruleSegments.empty()) {
        return false;
    }

    // Index the user's segments once so each rule segment is a hash lookup.
    const auto& userSegments = userData.GetSegments().GetSegment();
    THashSet<NLab::TSegment> userSegmentIndex;
    userSegmentIndex.reserve(userSegments.size());
    userSegmentIndex.insert(userSegments.begin(), userSegments.end());

    for (const auto& ruleSegment : ruleSegments) {
        if (userSegmentIndex.contains(ruleSegment)) {
            return true;
        }
    }

    return false;
}

// Returns true if any id from the user's token list is contained in the
// rule's token set. An empty rule token set never matches.
bool NFilter::HasTokens(const ::google::protobuf::RepeatedField<ui32>& userDataTokenIdsProto, const THashSet<ui32>& ruleTokenIds) {
    if (ruleTokenIds.empty()) {
        return false;
    }

    for (const auto& tokenId : userDataTokenIdsProto) {
        if (ruleTokenIds.contains(tokenId)) {
            return true;
        }
    }

    return false;
}

// Returns true if the id of any weighted token from the user's list is
// contained in the rule's token set. An empty rule token set never matches.
bool NFilter::HasTokens(const ::google::protobuf::RepeatedPtrField<NLab::TWeightedTokenEncoded>& userDataTokenIdsProto, const THashSet<ui32>& ruleTokenIds) {
    if (ruleTokenIds.empty()) {
        return false;
    }

    for (const auto& weightedToken : userDataTokenIdsProto) {
        if (ruleTokenIds.contains(weightedToken.GetId())) {
            return true;
        }
    }

    return false;
}

// Returns true when the user data passes both the attribute filter and the
// kernel filter of the extended rule. The kernel filter is only evaluated
// when the attribute filter has already passed.
bool NFilter::Filter(const NLab::TUserData& userData, const TEncodedExtendedCaRule& rule) {
    if (!FilterAttributes(userData, rule)) {
        return false;
    }

    return FilterKernel(userData, rule);
}

// Checks the user's attributes against the attribute constraints of the
// extended rule.
//
// Returns false when the user data carries no attributes at all. Otherwise
// every constraint must hold: the gender must match if the original rule sets
// one, and for each of ages / incomes / regions / countries / devices the
// user's value must be in the rule's set. An empty set means the
// corresponding attribute is not constrained.
bool NFilter::FilterAttributes(const NLab::TUserData& userData, const TEncodedExtendedCaRule& rule) {
    if (!userData.HasAttributes()) {
        return false;
    }

    const auto& attrs = userData.GetAttributes();

    if (rule.OriginalRule.HasGender() && !(rule.OriginalRule.GetGender() == attrs.GetGender())) {
        return false;
    }
    if (!rule.Ages.empty() && !rule.Ages.contains(attrs.GetAge())) {
        return false;
    }
    if (!rule.Incomes.empty() && !rule.Incomes.contains(attrs.GetIncome())) {
        return false;
    }
    if (!rule.Regions.empty() && !rule.Regions.contains(attrs.GetRegion())) {
        return false;
    }
    if (!rule.Countries.empty() && !rule.Countries.contains(attrs.GetCountry())) {
        return false;
    }
    if (!rule.Devices.empty() && !rule.Devices.contains(attrs.GetDevice())) {
        return false;
    }

    return true;
}

// Checks the user against the kernel of the extended rule.
//
// A rule without a defined kernel accepts every user. Otherwise the kernel's
// phrase, host and app groups are evaluated against the user's words, hosts
// and apps respectively, skipping any category whose group list is empty:
//  - with AggregateByOr set, one satisfied non-empty category is enough
//    (and the result is false if every category is empty);
//  - without it, every non-empty category must be satisfied.
bool NFilter::FilterKernel(const NLab::TUserData& userData, const TEncodedExtendedCaRule& rule) {
    if (!rule.Kernel.Defined()) {
        return true;
    }

    const auto& affinities = userData.GetAffinitiesEncoded();

    const auto& phraseGroups = rule.Kernel->Phrases;
    const auto& hostGroups = rule.Kernel->Hosts;
    const auto& appGroups = rule.Kernel->Apps;

    if (rule.Kernel->AggregateByOr) {
        return (!phraseGroups.empty() && SatisfyPhraseGroups(affinities.GetWords(), phraseGroups)) ||
               (!hostGroups.empty() && SatisfyTokenGroups(affinities.GetHosts(), hostGroups)) ||
               (!appGroups.empty() && SatisfyTokenGroups(affinities.GetApps(), appGroups));
    }

    return (phraseGroups.empty() || SatisfyPhraseGroups(affinities.GetWords(), phraseGroups)) &&
           (hostGroups.empty() || SatisfyTokenGroups(affinities.GetHosts(), hostGroups)) &&
           (appGroups.empty() || SatisfyTokenGroups(affinities.GetApps(), appGroups));
}

// Returns true when every word of the phrase is present in `words`.
// A phrase flagged as having unknown words never matches.
bool NFilter::HasPhrase(const THashSet<ui32>& words, const TEncodedExtendedCaRule::TPhrase& phrase) {
    if (phrase.HasUnknownWords) {
        return false;
    }

    bool allPresent = true;
    for (const auto& wordId : phrase.Words) {
        if (!words.contains(wordId)) {
            allPresent = false;
            break;
        }
    }

    return allPresent;
}

// Evaluates a single phrase group against the user's words.
// A non-negative group is satisfied when at least one of its phrases matches;
// a negative group is satisfied when none of its phrases match.
bool NFilter::SatisfyPhraseGroup(const THashSet<ui32>& words, const TEncodedExtendedCaRule::TPhraseGroup& group) {
    bool anyPhraseMatched = false;
    for (const auto& candidate : group.Phrases) {
        if (HasPhrase(words, candidate)) {
            anyPhraseMatched = true;
            break;
        }
    }

    return group.Negative ? !anyPhraseMatched : anyPhraseMatched;
}

// Returns true when every phrase group is satisfied by the user's words.
// An empty group list is trivially satisfied.
bool NFilter::SatisfyPhraseGroups(const ::google::protobuf::RepeatedField<ui32>& userDataTokenIdsProto, const TVector<TEncodedExtendedCaRule::TPhraseGroup>& groups) {
    if (groups.empty()) {
        return true;
    }

    // Build the lookup set once for all groups. It is pre-sized to the number
    // of user tokens so it does not rehash while being filled, the same way
    // FilterSegments builds its set.
    THashSet<ui32> words;
    words.reserve(userDataTokenIdsProto.size());
    words.insert(userDataTokenIdsProto.begin(), userDataTokenIdsProto.end());

    for (const auto& group : groups) {
        if (!SatisfyPhraseGroup(words, group)) {
            return false;
        }
    }

    return true;
}

// Evaluates a single token group against the user's tokens.
//
// A group with no items is satisfied if it is negative, or if it is not
// flagged as having unknown tokens. Otherwise a non-negative group is
// satisfied when at least one item is among the user's tokens, and a negative
// group is satisfied when none of its items are.
bool NFilter::SatisfyTokenGroup(const THashSet<ui32>& tokens, const TEncodedExtendedCaRule::TTokenGroup& group) {
    if (group.Items.empty()) {
        return group.Negative || !group.HasUnknownTokens;
    }

    bool anyItemPresent = false;
    for (const auto& candidate : group.Items) {
        if (tokens.contains(candidate)) {
            anyItemPresent = true;
            break;
        }
    }

    return group.Negative ? !anyItemPresent : anyItemPresent;
}

// Returns true when every token group is satisfied by the user's tokens.
// An empty group list is trivially satisfied.
bool NFilter::SatisfyTokenGroups(const ::google::protobuf::RepeatedField<ui32>& userDataTokenIdsProto, const TVector<TEncodedExtendedCaRule::TTokenGroup>& groups) {
    if (groups.empty()) {
        return true;
    }

    // Build the lookup set once for all groups. It is pre-sized to the number
    // of user tokens so it does not rehash while being filled, the same way
    // FilterSegments builds its set.
    THashSet<ui32> tokens;
    tokens.reserve(userDataTokenIdsProto.size());
    tokens.insert(userDataTokenIdsProto.begin(), userDataTokenIdsProto.end());

    for (const auto& group : groups) {
        if (!SatisfyTokenGroup(tokens, group)) {
            return false;
        }
    }

    return true;
}
