#include <util/string/join.h>
#include <mail/so/spamstop/tools/so-common/safe_recode.h>
#include "spstat.h"
#include "rulesholder.h"
#include "processing_context.h"
#include "arform.h"

// Sentinel compared against the result of RulesHolder.GetRuleToAntirule():
// when the lookup yields this value, WorkedRule() skips its anti-rule handling.
#define NOT_ANTI_RULE -1


// Replaces the stored label with `str`; an empty `str` leaves the current
// label untouched.
void TShAttrs::SetLabel(TString str) {
    if (str.empty()) {
        return;
    }
    Label = std::move(str);
}

// Returns the label wrapped in parentheses, or an empty string when no label
// is set.
TString TShAttrs::BracketedLabel() const {
    if (Label.empty()) {
        return {};
    }
    return "(" + Label + ")";
}

// Updates the counter and, when provided, the label / shingle string / put flag.
//   pText      - new label; ignored when empty
//   psh        - new value for shString; ignored when empty
//   bput       - when true, raises bPut (it is never lowered here)
//   incCounter - amount added to count
// NOTE(review): a zero incCounter does not add anything but resets count to 1 —
// confirm this is the intended contract.
void TShAttrs::Increment(const TStringBuf &pText, const TStringBuf &psh, bool bput, int incCounter = 0) {
    if (incCounter == 0) {
        count = 1;
    } else {
        count += incCounter;
    }

    if (pText) {
        SetLabel(TString{pText});
    }
    if (psh) {
        shString.assign(psh);
    }
    if (bput) {
        bPut = true;
    }
}

// Stream adaptor that writes `buf` with tab, newline and carriage return
// replaced by a space, and with backslash and double quote prefixed by a
// backslash. All other characters are written unchanged.
struct TMLLogEscaper {

    friend IOutputStream &operator<<(IOutputStream &stream, const TMLLogEscaper &escaper) {
        for (const char c : escaper.buf) {
            switch (c) {
                case '\t':
                case '\n':
                case '\r':
                    stream << ' ';
                    break;
                case '\\':
                case '"':
                    stream << '\\' << c;
                    break;
                default:
                    stream << c;
                    break;
            }
        }
        return stream;
    }

    TStringBuf buf;
};

// Writes the builder as one tskv record: the fixed header with the format
// prefix and current unix time, followed by a tab-separated key=value pair for
// every stored part (values go through TMLLogEscaper).
IOutputStream &operator<<(IOutputStream &stream, const TMlLogBuilder &builder) {
    stream << "tskv\ttskv_format=" << builder.Prefix;
    stream << "\tunixtime=" << Now().TimeT();

    for (const auto &part : builder.Parts) {
        stream << '\t' << part.first << '=' << TMLLogEscaper{part.second};
    }
    return stream;
}

void TMlLogBuilder::_Add(const TStringBuf &label, TString str) {
    try {
        str = SafeRecode(str);
    } catch (...) {
        ErrorLogger << TLOG_ERR << "fail to recode " << label << ':' << CurrentExceptionMessage();
    }

    const TString loweredLabel = to_lower(TString{label});
    auto it = Parts.find(loweredLabel);
    if (it != Parts.end()) {
        it->second = std::move(str);
    } else {
        Parts.emplace(loweredLabel, std::move(str));
    }
}

// Builds a JSON array with the names that occur in exactly one of Left / Right
// (symmetric difference of the two de-duplicated name sets).
NJson::TJsonValue TClassifiersContext::TVecDiff::ToJson() const {
    // Work on string views so the names themselves are not copied.
    TVector<TStringBuf> lhs(Left.cbegin(), Left.cend());
    TVector<TStringBuf> rhs(Right.cbegin(), Right.cend());

    // The set algorithm below needs sorted, duplicate-free input.
    SortUnique(lhs);
    SortUnique(rhs);

    TVector<TStringBuf> onlyOneSide;
    SetSymmetricDifference(rhs.cbegin(), rhs.cend(), lhs.cbegin(), lhs.cend(),
                           std::back_inserter(onlyOneSide));

    NJson::TJsonValue report;
    for (const auto &name : onlyOneSide) {
        report.AppendValue(name);
    }
    return report;
}

// Builds a JSON object describing how the Left and Right maps differ:
//   "mis" - keys present in exactly one of the maps
//   "ne"  - keys present in both maps whose values compare unequal
NJson::TJsonValue TClassifiersContext::TMapDiff::ToJson() const {
    // Sorted, duplicate-free views of a map's keys (required by the set algorithms).
    const auto sortedKeys = [](const auto &map) {
        TVector<TStringBuf> keys(Reserve(map.size()));
        for (const auto &entry : map) {
            keys.emplace_back(entry.first);
        }
        SortUnique(keys);
        return keys;
    };

    const TVector<TStringBuf> leftKeys = sortedKeys(Left);
    const TVector<TStringBuf> rightKeys = sortedKeys(Right);

    TVector<TStringBuf> onlyOneSide;
    SetSymmetricDifference(rightKeys.cbegin(), rightKeys.cend(), leftKeys.cbegin(), leftKeys.cend(),
                           std::back_inserter(onlyOneSide));

    TVector<TStringBuf> common;
    SetIntersection(rightKeys.cbegin(), rightKeys.cend(), leftKeys.cbegin(), leftKeys.cend(),
                    std::back_inserter(common));

    NJson::TJsonValue report;

    // Both entries are created even when they end up with no elements.
    NJson::TJsonValue &missing = report["mis"];
    for (const auto &key : onlyOneSide) {
        missing.AppendValue(key);
    }

    NJson::TJsonValue &notEqual = report["ne"];
    for (const auto &key : common) {
        if (Left.at(key) != Right.at(key)) {
            notEqual.AppendValue(key);
        }
    }

    return report;
}

// Compares every addon's feature set against Features and returns a JSON
// object keyed by addon name. Each addon entry holds:
//   "absent" - array of addon features that are not found in Features
//   "ne"     - object {feature: addon weight} for features found in Features
//              with a different weight
NJson::TJsonValue TClassifiersContext::TAddonsDiff::ToJson() const {
    using TWeighted = std::pair<const TString, float>;

    NJson::TJsonValue report;

    for (const auto&[name, features]: Addons) {
        TVector<TStringBuf> missing;
        TVector<const TWeighted *> mismatched;

        for (const auto &entry : features) {
            const auto *baseWeight = MapFindPtr(Features, entry.first);
            if (!baseWeight) {
                missing.emplace_back(entry.first);
            } else if (*baseWeight != entry.second) {
                // Pointers stay valid: `features` is not modified while we hold them.
                mismatched.emplace_back(&entry);
            }
        }

        NJson::TJsonValue &addonReport = report[name];

        NJson::TJsonValue &absent = addonReport["absent"];
        for (const auto &feature : missing) {
            absent.AppendValue(feature);
        }

        NJson::TJsonValue &notEqual = addonReport["ne"];
        for (const TWeighted *entry : mismatched) {
            notEqual[entry->first] = entry->second;
        }
    }

    return report;
}

// Assembles a JSON report with three sections:
//   "rules_diff"  - TVecDiff of the two contexts' RulesNames
//   "mnf_diff"    - TMapDiff of the two contexts' Features
//   "addons_diff" - TAddonsDiff of context1.Features against context1.Addons
// NOTE(review): the addons section reads only context1; context2 does not take
// part in it — confirm this is intended.
NJson::TJsonValue TClassifiersContext::DiffReport(const TClassifiersContext& context1, const TClassifiersContext& context2) {
    const TVecDiff rulesDiff{context1.RulesNames, context2.RulesNames};
    const TMapDiff featuresDiff{context1.Features, context2.Features};
    const TAddonsDiff addonsDiff{context1.Features, context1.Addons};

    NJson::TJsonValue report;
    report["rules_diff"] = rulesDiff.ToJson();
    report["mnf_diff"] = featuresDiff.ToJson();
    report["addons_diff"] = addonsDiff.ToJson();
    return report;
}

// Binds the per-context rule state to its definition and to the owning
// TRulesContext, and seeds the score from the definition.
TRuleCurrent::TRuleCurrent(const TRuleDef& ruleDef, TRulesContext& master) noexcept
        : ruleDef(ruleDef)
        , Master(master) {
    // Inside the body `ruleDef` names the parameter, which shadows the member.
    SetScore(ruleDef.score);
}

// Prints the rule name; for rules reported as significant the score follows,
// separated by a space and formatted with up to 2 digits after the point
// (PREC_POINT_DIGITS_STRIP_ZEROES mode).
IOutputStream& operator<<(IOutputStream& stream, const TRuleCurrent& rule) {
    stream << rule.GetDef().pRuleName;
    if (!rule.IsSignificant()) {
        return stream;
    }
    return stream << ' ' << Prec(rule.Score, PREC_POINT_DIGITS_STRIP_ZEROES, 2);
}

// Marks this rule as worked by delegating to the owning context.
void TRuleCurrent::Activate() {
    Master.WorkedRule(*this);
}

bool TRuleCurrent::GetBfResult() const {
    int value = 0;

    const TBfDef &bf = std::get<TBfDef>(ruleDef.rules);

    int j = 0;

    for (const TRuleDef& dep : ruleDef.Dependencies) {
        if (Master.IsRuleWorked(dep.id)) {
            value |= 1 << j;
        }
        j ++;
    }

    int ind = value >> 3;
    return ((bf.pResult[ind] >> (value & 7)) & 1) == 1;
}

bool TRuleCurrent::GetArResult() const {
    int c = 0;

    const TArDef &ar = std::get<TArDef>(ruleDef.rules);

    int j = 0;

    for (const TRuleDef& dep : ruleDef.Dependencies) {
        if (Master.IsRuleWorked(dep.id)) {
            if (ar.pSignes[j] == AR_SIGN_MINUS)
                --c;
            else
                ++c;
        }
        j ++;
    }

    if (ar.comp == AR_SIGN_GREAT) {
        if (c > ar.value)
            return true;
    } else if (c < ar.value)
        return true;

    return false;
}

// Marks `rule` as worked (no-op when it already is):
//  - invalidates the computed state of the rule and, through Uncompute, of its masters;
//  - appends the rule to RulesAsAppeared;
//  - if the rule maps to an anti-rule, flags every rule that anti-rule cancels
//    as Cancelled, invalidates their computed state, and records `rule` in
//    AntiRulesAsAppeared.
void TRulesContext::WorkedRule(TRuleCurrent& rule) {
    if (rule.fWorked) {
        return;
    }

    Uncompute(rule);

    RulesAsAppeared.emplace_back(rule);
    rule.fWorked = true;

    const int antiRuleId = RulesHolder.GetRuleToAntirule(rule.ruleDef.id);
    if (antiRuleId == NOT_ANTI_RULE) {
        return;
    }

    const TRuleDef &antiRuleDef = *RulesHolder.RuleById(antiRuleId);
    const auto &antiDef = std::get<TAntiDef>(antiRuleDef.rules);
    for (const TRuleDef& cancelDef : antiDef.pCancelRules) {
        TRuleCurrent &cancelled = AllRules[cancelDef.id];
        cancelled.Cancelled = true;
        Uncompute(cancelled);
    }
    AntiRulesAsAppeared.emplace_back(rule);
}

// Id-based overload: marks the rule stored at index `rid` of AllRules as worked.
// The index is not range-checked here.
void TRulesContext::WorkedRule(int rid) {
    WorkedRule(AllRules[rid]);
}

// Returns whether the rule with the given name is currently worked;
// a name the rules holder cannot resolve yields false.
bool TRulesContext::IsRuleWorked(const TStringBuf &szRuleName) const {
    int rid;
    if (!RulesHolder.AllRulesFindRid(szRuleName, rid)) {
        return false;
    }
    return IsRuleWorked(rid);
}

// Returns the fWorked flag of the rule at index `rid` of AllRules.
// The index is not range-checked here.
bool TRulesContext::IsRuleWorked(int rid) const {
    return AllRules[rid].fWorked;
}

// Marks the rule with the given name as worked.
// Returns false when the name cannot be resolved, true otherwise.
// While `freeze` is set, the rule is additionally appended to
// RulesAppearedAfterFreeze on every call, whether or not it was already worked.
bool TRulesContext::SetRule(const TStringBuf &szRuleName) {
    int rid;
    if (!RulesHolder.AllRulesFindRid(szRuleName, rid)) {
        return false;
    }

    if (freeze) {
        RulesAppearedAfterFreeze.emplace_back(AllRules[rid]);
    }
    SetRule(rid);
    return true;
}

// Id-based overload; simply forwards to WorkedRule(rid).
void TRulesContext::SetRule(int rid) {
    WorkedRule(rid);
}

// Clears the Computed flag of `current` and, recursively, of every rule listed
// in its definition's Masters. A rule that is not marked Computed stops the
// recursion.
void TRulesContext::Uncompute(TRuleCurrent& current) {
    if (!current.Computed) {
        return;
    }

    current.Computed = false;
    for (const TRuleDef& masterDef : current.ruleDef.Masters) {
        Uncompute(AllRules[masterDef.id]);
    }
}

void TRulesContext::UnsetCanceledRules() {
    for (TRuleCurrent &current : RulesAsAppeared) {
        if (current.Cancelled && current.fWorked) {
            current.fWorked = false;

            Uncompute(current);
        }
    }
}

void TRulesContext::Recalc() {
    TRulesRefs filtered;
    for (TRuleCurrent &rule : RulesAsAppeared) {
        if (rule.fWorked)
            filtered.emplace_back(rule);
        else
            CanceledRulesAsAppeared.emplace_back(rule);
    }
    RulesAsAppeared = std::move(filtered);
}

// Sums the scores of all appeared rules that are not Cancelled:
// rules whose definition has fDelivery set go to HitsDlvr, the rest to Hits.
// The score is looked up by rule name through GetScore().
TScores TRulesContext::CalcScores() const {
    auto totals = TScores();
    for (const TRuleCurrent &rule : RulesAsAppeared) {
        if (rule.Cancelled) {
            // cancelled by an anti-rule
            continue;
        }

        const double score = GetScore(rule.ruleDef.pRuleName);
        if (rule.ruleDef.fDelivery) {
            totals.HitsDlvr += score;
        } else {
            totals.Hits += score;
        }
    }
    return totals;
}

// Returns the rules that Recalc() moved out of the appeared list because they
// were no longer worked.
const TRulesContext::TRulesRefs &TRulesContext::GetCanceledRules() const {
    return CanceledRulesAsAppeared;
}
// Returns the worked rules for which WorkedRule() found an anti-rule mapping,
// in the order they were recorded.
const TRulesContext::TRulesRefs &TRulesContext::GetAntiRules() const {
    return AntiRulesAsAppeared;
}

// Returns the list of rules recorded by WorkedRule(), in order of appearance
// (as filtered by the latest Recalc(), if any).
const TRulesContext::TRulesRefs &TRulesContext::GetOccuredRules() const {
    return RulesAsAppeared;
}

// Read-only access to the per-context state of rule `rid`.
// Uses at(), unlike the operator[] accesses elsewhere in this class.
const TRuleCurrent &TRulesContext::GetRuleCurrent(int rid) const {
    return AllRules.at(rid);
}

// Mutable access to the per-context state of rule `rid`.
// Uses at(), unlike the operator[] accesses elsewhere in this class.
TRuleCurrent& TRulesContext::GetRuleCurrentMutable(int rid) {
    return AllRules.at(rid);
}

// Evaluates an expression rule after evaluating its dependencies.
//   ruleCurrent - rule to evaluate; skipped when already Computed
//   base        - rule the top-level evaluation started from; only used to
//                 log a warning when a dependency points back to it
// Only RT_BF and RT_ARITHMETIC rules get a new state; other rule types are
// just marked Computed (after their dependencies were visited).
void TRulesContext::CheckExpressions(TRuleCurrent& ruleCurrent, const TRuleCurrent* base) {
    if (ruleCurrent.Computed) {
        return;
    }

    // Marked before descending into dependencies.
    // NOTE(review): presumably this is what keeps a dependency loop from
    // recursing forever — confirm, since WorkedRule() may clear Computed again
    // through Uncompute().
    ruleCurrent.Computed = true;

    if (ruleCurrent.Cancelled) {
        return;
    }

    for (const TRuleDef& depDef : ruleCurrent.ruleDef.Dependencies) {
        TRuleCurrent& dependency = AllRules[depDef.id];

        if (std::addressof(dependency) == base) {
            // The loop is only reported; evaluation still continues below.
            SysLogger() << TLOG_WARNING << "loop " << ruleCurrent.ruleDef << " -> " << base->ruleDef;
        }

        CheckExpressions(dependency, base);
    }

    bool shouldBeWorked;
    if (ruleCurrent.ruleDef.rt == RT_BF) {
        shouldBeWorked = ruleCurrent.GetBfResult();
    } else if (ruleCurrent.ruleDef.rt == RT_ARITHMETIC) {
        shouldBeWorked = ruleCurrent.GetArResult();
    } else {
        return;
    }

    if (shouldBeWorked == ruleCurrent.fWorked) {
        return;
    }

    if (shouldBeWorked) {
        WorkedRule(ruleCurrent);
    } else {
        ruleCurrent.fWorked = false;
    }
}

// Evaluates every rule returned by RulesHolder.GetExprRulesVector(); each one
// is passed as its own `base` for loop reporting.
void TRulesContext::CheckExpressions() {
    for (const TRuleDef& exprDef : RulesHolder.GetExprRulesVector()) {
        TRuleCurrent& exprRule = AllRules[exprDef.id];
        const TRuleCurrent* const base = std::addressof(exprRule);
        CheckExpressions(exprRule, base);
    }
}

// Looks a rule up by name; returns nullptr when the name cannot be resolved.
const TRuleCurrent* TRulesContext::FindRule(const TStringBuf& szRuleName) const {
    int rid;
    if (!RulesHolder.AllRulesFindRid(szRuleName, rid)) {
        return nullptr;
    }
    return &GetRuleCurrent(rid);
}

// Overrides the score of the named rule in this context.
// An unknown name is logged as an error and otherwise ignored.
void TRulesContext::SetScore(const TStringBuf& szRuleName, double score) {
    int rid;
    if (!RulesHolder.AllRulesFindRid(szRuleName, rid)) {
        Logger << (TLOG_ERR) << "cannot find rule " << szRuleName;
        return;
    }
    GetRuleCurrentMutable(rid).SetScore(score);
}

// Returns the current score of the named rule.
// An unknown name is logged as an error and yields 0.
double TRulesContext::GetScore(const TStringBuf& szRuleName) const {
    int rid;
    if (!RulesHolder.AllRulesFindRid(szRuleName, rid)) {
        Logger << (TLOG_ERR) << "cannot find rule " << szRuleName;
        return 0.;
    }
    return GetRuleCurrent(rid).GetScore();
}

// Read-only NLua::IRule view of the named rule.
// Throws yexception (with backtrace) when the name cannot be resolved.
const NLua::IRule &TRulesContext::Get(const TStringBuf &ruleName) const {
    int rid;
    if (!RulesHolder.AllRulesFindRid(ruleName, rid)) {
        ythrow TWithBackTrace<yexception>() << "cannot find rule " << ruleName;
    }
    return GetRuleCurrent(rid);
}

// Mutable NLua::IRule view of the named rule.
// Throws yexception (with backtrace) when the name cannot be resolved.
NLua::IRule &TRulesContext::GetMutable(const TStringBuf &ruleName) {
    int rid;
    if (!RulesHolder.AllRulesFindRid(ruleName, rid)) {
        ythrow TWithBackTrace<yexception>() << "cannot find rule " << ruleName;
    }
    return GetRuleCurrentMutable(rid);
}

// Returns the rules that SetRule(name) recorded while `freeze` was set.
const TRulesContext::TRulesRefs &TRulesContext::GetRulesAppearedAfterFreeze() const {
    return RulesAppearedAfterFreeze;
}

// Returns the appeared rule whose profiler reports the largest duration,
// or Nothing() when no rule has appeared.
TMaybe<TRuleCurrentRef> TRulesContext::GetMostSlowlyRule() const {
    const auto byDuration = [](const TRuleCurrent &rule) {
        return rule.GetProfiler().GetDuration();
    };

    const TRulesContext::TRulesRefs::const_iterator slowest = MaxElementBy(RulesAsAppeared, byDuration);
    if (slowest == RulesAsAppeared.cend()) {
        return Nothing();
    }

    return MakeMaybe(*slowest);
}

// Creates the per-context rule state: one TRuleCurrent per rule known to
// `rulesHolder`, stored at the index equal to the rule id.
//   rulesHolder - source of rule definitions; kept by reference
//   logger      - taken by value and moved into the Logger member
// The parameter is deliberately non-const here: std::move on a const object
// silently falls back to a copy. Top-level const on a by-value parameter is not
// part of the function type, so this still matches the declaration.
TRulesContext::TRulesContext(const TRulesHolder &rulesHolder, TLog logger) noexcept
: RulesHolder(rulesHolder)
, Logger(std::move(logger)) {

    // Reserve up front so that emplace_back below never reallocates.
    AllRules.reserve(rulesHolder.m_cRules);
    for (size_t rid = 0; rid < static_cast<size_t>(rulesHolder.m_cRules); rid++) {
        const auto &rule = *RulesHolder.RuleById(rid);
        AllRules.emplace_back(rule, *this);
    }
}

// Returns the state of all rules, indexed by rule id (see the constructor).
const TVector<TRuleCurrent> &TRulesContext::GetRules() const {
    return AllRules;
}

// Builds the per-message engine state.
//   rulesHolder           - rule definitions, forwarded to the rules context
//   logger                - copied into rulesContext and MlLogBuilder, moved into Logger
//   mlLogPrefix           - moved into MlLogBuilder
//   messageId             - only the part after the first ':' is stored
//   requiredScore         - stored in m_required
//   requiredDeliveryScore - stored in m_delivery_required
// NOTE(review): members are initialized in declaration order, not in the order
// written below. If Logger is declared before rulesContext or MlLogBuilder,
// those would receive an already moved-from `logger` — verify against the header.
TCurMessageEngine::TCurMessageEngine(
        const TRulesHolder &rulesHolder,
        TLog logger,
        TString mlLogPrefix,
        const TStringBuf &messageId,
        double requiredScore,
        double requiredDeliveryScore)
        : rulesContext(rulesHolder, logger)
        , m_sMessageId(messageId.After(':'))
        , m_required(requiredScore)
        , m_delivery_required(requiredDeliveryScore)
        , MlLogBuilder(std::move(mlLogPrefix), logger)
        , Logger(std::move(logger)) {

    // Pre-populate the placeholder field with a single JSON null value.
    FieldsValues[FD_PLACEHOLDER] = {NJson::TJsonValue(NJson::JSON_NULL)};
}

// Registers a short URL, de-duplicated by its lower-cased form.
// Returns {false, existing entry} when the URL is already known, otherwise
// {true, newly created entry}. A new URL is also added to m_mapurl_reputation
// with the value 1.
// The returned pointer refers to an element of m_url_reputation_vec.
std::pair<bool, TUrlStatistic*> TCurMessageEngine::AddShortUrl(const TStringBuf & url) {
    // Not const: the key is moved into m_mapurl_reputation at the end, and
    // std::move on a const object would silently copy instead.
    auto lowered = to_lower(TString{url});

    if (size_t* index = MapFindPtr(UrlRepBackIndex, lowered)) {
        return {false, &m_url_reputation_vec[*index]};
    }

    // The back index stores the position the new entry is about to take.
    UrlRepBackIndex.emplace(lowered, m_url_reputation_vec.size());
    auto& emplaced = m_url_reputation_vec.emplace_back(url, CODES_KOI8, true, false);
    m_mapurl_reputation.emplace(std::move(lowered), 1);
    return {true, &emplaced};
}

// Convenience overload: registers the URL with doCount = false and
// EUrlStaticticSource::BODY as the source.
std::pair<bool, TUrlStatistic*> TCurMessageEngine::AddUrl(const TStringBuf & url, TVector<TString> aliases) {
    return AddUrl(url, std::move(aliases), false, EUrlStaticticSource::BODY);
}

// Registers a URL, de-duplicated by its lower-cased form.
//   url      - URL text as found in the message
//   aliases  - forwarded to the new TUrlStatistic; unused when the URL is already known
//   doCount  - forwarded to the new TUrlStatistic
//   code     - source of the URL, forwarded to the new TUrlStatistic
// Returns {false, existing entry} when the URL is already known, otherwise
// {true, newly created entry}.
// The returned pointer refers to an element of m_url_reputation_vec.
std::pair<bool, TUrlStatistic*> TCurMessageEngine::AddUrl(const TStringBuf & url, TVector<TString> aliases, bool doCount, EUrlStaticticSource code) {
    // Not const: the key is moved into the back index below, and std::move on
    // a const object would silently copy instead.
    auto lowered = to_lower(TString{url});

    if (size_t* index = MapFindPtr(UrlRepBackIndex, lowered)) {
        return {false, &m_url_reputation_vec[*index]};
    }

    // The back index stores the position the new entry is about to take.
    UrlRepBackIndex.emplace(std::move(lowered), m_url_reputation_vec.size());
    auto& emplaced = m_url_reputation_vec.emplace_back(url, CODES_KOI8, false, false, doCount, std::move(aliases), code);
    return {true, &emplaced};
}
