#include <search/web/core/rule.h>
#include <search/web/util/grouping/grouping.h>
#include <search/web/util/common/common.h>

#include <search/grouping/name.h>

#include <util/generic/algorithm.h>
#include <util/generic/vector.h>
#include <util/stream/output.h>
#include <library/cpp/dot_product/dot_product.h>
#include <extsearch/images/kernel/dssm_applier/dssm_applier.h>


class TCollectionsDssmRule : public IRearrangeRule {
public:
    TCollectionsDssmRule(const TString & /*config*/, const TSearchConfig & searchConfig) {
        TFsPath rearrangeDataDir(
            !!searchConfig.RearrangeIndexDir ? searchConfig.RearrangeIndexDir : searchConfig.RearrangeDataDir
        );
        DssmApplier.Init(rearrangeDataDir.Child("images").Child("dssm_q2t_query.cfg").GetPath());
        DocDssmApplier.Init(rearrangeDataDir.Child("images").Child("dssm_q2t_doc.cfg").GetPath());
    }

    virtual ~TCollectionsDssmRule() noexcept {}

    IRearrangeRuleContext* DoConstructContext() const override;

private:
    NImages::TDssmApplier DssmApplier;
    NImages::TDssmApplier DocDssmApplier;
};

class TCollectionsDssmContext : public INoAdjustRearrangeRuleContext {
public:
    TCollectionsDssmContext() = default;

    TCollectionsDssmContext(const NImages::TDssmApplier &dssmApplier, const NImages::TDssmApplier &docDssmApplier)
    : DssmApplier(dssmApplier), DocDssmApplier(docDssmApplier) {}

    void RegisterPrettyParams() {
        AddRuleParamToKuka("Enabled", KPT_BOOL, "Enable collections dssm wizard");
        AddRuleParamToKuka("Mxnet", KPT_NUMBER, "Mxnet coef in mixture");
        AddRuleParamToKuka("Dssm", KPT_NUMBER, "Dssm coef in mixture");
        AddRuleParamToKuka("Beauty", KPT_NUMBER, "Beauty coef in mixture");
        AddRuleParamToKuka("Rate", KPT_NUMBER, "Rate coef in mixture");
        AddRuleParamToKuka("BoundaryRelevance", KPT_NUMBER, "Min board relevance to be accepted");
    }

    void DoPrepareRearrangeParams(const TAdjustRuleParams& ap) override {
        Y_UNUSED(ap);
        RegisterPrettyParams();
    }

    void DoRearrangeAfterMerge(TRearrangeParams& rearrangeParams) override {
        if (!LocalScheme()["Enabled"].IsTrue()) {
            return;
        }

        if (rearrangeParams.Current.Params.Attr != NSearchGrouping::NName::s_board_id) {
            return;
        }

        TMetaGrouping* collectionsGrouping = rearrangeParams.Current.Grouping;
        if (!collectionsGrouping) {
            return;
        }

        const TString query = rearrangeParams.RP.RelevParams.Get("norm", "");
        const TVector<i8> queryFeatures = DssmApplier.GetPackedFeaturesFromLayer(query, QUERY_T2T_LAYER);
        const float bound = LocalScheme()["BoundaryRelevance"].GetNumber(DefaultBound);
        const float mxnet = LocalScheme()["Mxnet"].GetNumber(0.5);
        const float dssm = LocalScheme()["Dssm"].GetNumber(0.5);
        const float beauty = LocalScheme()["Beauty"].GetNumber(0.0);
        const float rate = LocalScheme()["Rate"].GetNumber(0.0);

        for (size_t i = 0; i < collectionsGrouping->Size(); ++i) {
            TMetaGroup& currentGroup = collectionsGrouping->GetMetaGroup(i);
            RelevanceRefresh(&currentGroup, queryFeatures, bound, mxnet, dssm, beauty, rate);
        }

        TEmptyGroupFilter filter;
        const size_t diff = collectionsGrouping->FilterGroups(filter);
        rearrangeParams.InsertWorkedRule("BoundaryRelevanceValue.debug", ToString(bound));
        collectionsGrouping->SortGroups();
        const size_t size = collectionsGrouping->Size();
        rearrangeParams.InsertWorkedRule("GroupingSize.debug", ToString(size));
        rearrangeParams.InsertWorkedRule("GroupingSizeDiff.debug", ToString(diff));
    }

private:

    void RelevanceRefresh(TMetaGroup* group, const TVector<i8>& queryFeatures, float bound, float mxnetCoef, float dssmCoef, float beautyCoef, float rateCoef) {
        TMetaGroup::TDocs& groupDocs = group->MetaDocs;
        if (groupDocs.empty()) {
            return;
        }
        const TMergedDoc& doc = groupDocs.front();
        const float fmlRelev = 1. * (group->GetRelevance() - 100000000) / 10000000;
        TMsString boardTitle;
        doc.Attributes().GetOneValue(boardTitle, "board_title");
        const TVector<i8> titleFeatures = DocDssmApplier.GetPackedFeaturesFromLayer(TString{boardTitle}, DOC_T2T_LAYER);
        const float titleDssmRelev = CalculateSimilarity(queryFeatures, titleFeatures);
        if (titleDssmRelev < bound) {
            groupDocs.clear();
        } else {
            const float averageBeauty = GetRelevFactorFromDoc(doc, "average_beauty");
            const float smartRate = GetRelevFactorFromDoc(doc, "smart_rate");
            const float mixinRelev = mxnetCoef * fmlRelev + dssmCoef * titleDssmRelev + beautyCoef * averageBeauty + rateCoef * smartRate;
            group->SetRelevance((i64)(mixinRelev * 10000000 + 100000000));
        }
    }

    float CalculateSimilarity(const TVector<i8>& queryFeatures, const TVector<i8>& docFeatures) {
        if (queryFeatures.empty() || docFeatures.empty()) {
            return 0;
        }
        const i32 similarity = DotProduct(&docFeatures[0], &queryFeatures[0], 300);
        const float dssmRelev = Min(Max(similarity * Normalizer + Shift, 0.f), 1.f);
        return dssmRelev;
    }

    const NImages::TDssmApplier& DssmApplier;
    const NImages::TDssmApplier& DocDssmApplier;

    const TStringBuf QUERY_T2T_LAYER = "features_query";
    const TStringBuf DOC_T2T_LAYER = "features_doc";
    static constexpr float DefaultBound = 0.55f;
    static constexpr float Normalizer = 1.183e-06f;
    static constexpr float Shift = 0.384f;
};

///////////////////////////////////////////////////////////////////////////////

/// Creates a per-request context that borrows this rule's DSSM appliers by reference.
/// The returned object is heap-allocated.
/// NOTE(review): ownership presumably passes to the rearrange framework. Confirm.
IRearrangeRuleContext* TCollectionsDssmRule::DoConstructContext() const {
    auto* const context = new TCollectionsDssmContext(DssmApplier, DocDssmApplier);
    return context;
}

///////////////////////////////////////////////////////////////////////////////

/// Factory for the CollectionsDssm rearrange rule.
/// Forwards both arguments to the TCollectionsDssmRule constructor and returns
/// a heap-allocated rule.
/// NOTE(review): the caller presumably takes ownership. Confirm.
IRearrangeRule* CreateCollectionsDssmRule(const TString& config, const TSearchConfig& searchConfig) {
    IRearrangeRule* const rule = new TCollectionsDssmRule(config, searchConfig);
    return rule;
}

///////////////////////////////////////////////////////////////////////////////

// Binds the rule name CollectionsDssm to the CreateCollectionsDssmRule factory.
// NOTE(review): the macro is defined outside this file. It presumably registers
// the factory during static initialization. Confirm.
REGISTER_REARRANGE_RULE(CollectionsDssm, CreateCollectionsDssmRule);

///////////////////////////////////////////////////////////////////////////////
