#pragma once

#include <util/string/join.h>

#include <kernel/dssm_applier/begemot/production_data.h>
#include <kernel/url_text_analyzer/url_analyzer.h>

#include <ysite/yandex/reqanalysis/fast_normalize_request.h>

#include <cmath>

namespace NWebmaster {

/// Holds a DSSM model (`Model`) and applies it to a normalized query to
/// produce an embedding vector.
/// NOTE(review): the constructor and Apply() are defined out of line; which
/// model file is loaded (if any) is not visible from this header.
struct TBoostingXfOneSEDssm {
    TBoostingXfOneSEDssm();

    /// Writes the embedding of `normalizedQuery` into `embeddingVector`.
    /// `normalizedQuery` is presumably the output of NormalizeQuery() — confirm with callers.
    /// NOTE(review): the bool result presumably reports success — confirm in the .cpp.
    bool Apply(const TString &normalizedQuery, TVector<float> &embeddingVector);

public:
    NNeuralNetApplier::TModel Model;
};

/// Holds a second DSSM model variant (`Model`) and applies it to a normalized
/// query to produce an embedding vector; same interface as TBoostingXfOneSEDssm.
/// NOTE(review): the constructor and Apply() are defined out of line; which
/// model file is loaded (if any) is not visible from this header.
struct TBoostingXfOneSeAmSsHardDssm {
    TBoostingXfOneSeAmSsHardDssm();

    /// Writes the embedding of `normalizedQuery` into `embeddingVector`.
    /// NOTE(review): the bool result presumably reports success — confirm in the .cpp.
    bool Apply(const TString &normalizedQuery, TVector<float> &embeddingVector);

public:
    NNeuralNetApplier::TModel Model;
};

/// Relates an embedding vector to a fixed set of k-means clusters stored in `KMeans`.
/// NOTE(review): `KMeans` presumably holds one centroid per cluster, populated by the
/// out-of-line constructor — confirm in the .cpp.
struct TKMeansClusters {
    TKMeansClusters();

    /// Returns a cluster id for `embeddingVector` (presumably the nearest centroid — confirm).
    ui32 GetClusterId(const TVector<float> &embeddingVector);
    /// Fills `clustersL2` with per-cluster values for `embeddingVector`
    /// (presumably L2 distances to each centroid, given the name — confirm).
    void GetClustersL2(const TVector<float> &embeddingVector, TVector<float> &clustersL2);
    /// Fills `clustersL2Top` with ui8 entries limited by `top` (default 10);
    /// presumably the ids of the `top` closest clusters — confirm in the .cpp.
    void GetClustersL2Top(const TVector<float> &embeddingVector, TVector<ui8> &clustersL2Top, const size_t top = 10);

public:
    TVector<TVector<float>> KMeans;
};

/// Shifted and scaled softsign: 0.5 + 0.5 * x / (1 + |x|).
///
/// SoftSign(0) == 0.5, and the function is non-decreasing. For finite `x`
/// the result lies in [0, 1]: it approaches 0 / 1 as x -> -inf / +inf, and it
/// reaches those bounds exactly once float rounding saturates. Infinite
/// inputs produce NaN (inf / inf), and so does NaN.
inline float SoftSign(float x) {
    constexpr float Bias = 0.5f;
    constexpr float Scale = 0.5f;
    constexpr float Constant = 1.0f;
    // std::fabs selects the float overload. Unqualified `fabs` can resolve to the
    // C `double fabs(double)` and silently promote the whole expression to double.
    return Bias + Scale * x / (Constant + std::fabs(x));
}

/// Owns a DSSM model (`Model`) from which smaller sub-models can be extracted.
/// Shared between the TDssmCtr / TDssmCtrNoMiner / TDssmUta wrappers via `Ptr`.
struct TDssmL3Model {
    using Ptr = TAtomicSharedPtr<TDssmL3Model>;

    /// `modelFileName` defaults to "l3_model.dssm".
    /// NOTE(review): presumably the model is loaded from this file — confirm in the .cpp.
    TDssmL3Model(const TString modelFileName = "l3_model.dssm");
    /// Returns a sub-model identified by `outputName`, restricted to `terminalInputs`.
    /// NOTE(review): presumably the subgraph computing `outputName` from those inputs — confirm.
    NNeuralNetApplier::TModelPtr GetSubmodel(const TString& outputName, const TSet<TString>& terminalInputs) const;

public:
    NNeuralNetApplier::TModel Model;
};

/// CTR-flavoured DSSM built on a shared TDssmL3Model: produces query and document
/// embeddings and a joint output for a (doc, query) embedding pair.
struct TDssmCtr {
    // Names used with the L3 model (presumably node names in its graph — confirm in the .cpp).
    // `static constexpr` makes them true compile-time constants shared by all instances:
    // the pointers can no longer be reassigned through an object, they take no space in
    // each instance, and they are reachable as TDssmCtr::V_*. `obj.V_*` still works.
    static constexpr const char *V_INPUT                 = "input";
    static constexpr const char *V_DOC_EMBEDDING_CTR     = "doc_embedding_ctr";
    static constexpr const char *V_QUERY_EMBEDDING_CTR   = "query_embedding_ctr";
    static constexpr const char *V_JOINT_OUTPUT_CTR      = "joint_output_ctr";

public:
    TDssmCtr(TDssmL3Model::Ptr l3ModelPtr);

    /// Writes the embedding of `normalizedQuery` into `embedding`.
    void GetQueryEmbedding(const TString &normalizedQuery, TVector<float> &embedding);
    /// Writes the document embedding built from the URL, its UTA form and the title into `embedding`.
    void GetDocEmbedding(const TString &url, const TString &urlUta, const TString &normalizedTitle, TVector<float> &embedding);
    /// Combines previously computed doc and query embeddings into a single score.
    float GetJointOutput(const TVector<float> &docEmbeddingCtr, const TVector<float> &queryEmbeddingCtr);

public:
    NNeuralNetApplier::TModelPtr ModelEmbeddingsPtr;
    NNeuralNetApplier::TModelPtr ModelOutputPtr;
};

/// "No miner" variant of the CTR DSSM, built on a shared TDssmL3Model. The document
/// embedding is built from URL and title only (no UTA argument).
struct TDssmCtrNoMiner {
    // Names used with the L3 model (presumably node names in its graph — confirm in the .cpp).
    // `static constexpr` makes them immutable, shared compile-time constants instead of
    // reassignable per-instance pointers; `obj.V_*` access keeps working.
    static constexpr const char *V_INPUT                         = "input";
    static constexpr const char *V_DOC_EMBEDDING_CTR_NO_MINER    = "doc_embedding_ctr_no_miner";
    static constexpr const char *V_QUERY_EMBEDDING_CTR_NO_MINER  = "query_embedding_ctr_no_miner";
    static constexpr const char *V_JOINT_OUTPUT_CTR_NO_MINER     = "joint_output_ctr_no_miner";

public:
    TDssmCtrNoMiner(TDssmL3Model::Ptr l3ModelPtr);

    /// Writes the embedding of `normalizedQuery` into `embedding`.
    void GetQueryEmbedding(const TString &normalizedQuery, TVector<float> &embedding);
    /// Writes the document embedding built from the URL and the title into `embedding`.
    void GetDocEmbedding(const TString &url, const TString &normalizedTitle, TVector<float> &embedding);
    /// Combines previously computed doc and query embeddings into a single score.
    float GetJointOutput(const TVector<float> &docEmbeddingCtr, const TVector<float> &queryEmbeddingCtr);

public:
    NNeuralNetApplier::TModelPtr ModelEmbeddingsPtr;
    NNeuralNetApplier::TModelPtr ModelOutputPtr;
};

/// UTA-flavoured DSSM built on a shared TDssmL3Model: produces query and document
/// embeddings and a joint output for a (doc, query) embedding pair.
struct TDssmUta {
    // Names used with the L3 model (presumably node names in its graph — confirm in the .cpp).
    // `static constexpr` makes them immutable, shared compile-time constants instead of
    // reassignable per-instance pointers; `obj.V_*` access keeps working.
    static constexpr const char *V_INPUT                 = "input";
    static constexpr const char *V_DOC_EMBEDDING_UTA     = "doc_embedding_uta";
    static constexpr const char *V_QUERY_EMBEDDING_UTA   = "query_embedding_uta";
    static constexpr const char *V_JOINT_OUTPUT_UTA      = "joint_output_uta";

public:
    TDssmUta(TDssmL3Model::Ptr l3ModelPtr);

    /// Writes the embedding of `normalizedQuery` into `embedding`.
    void GetQueryEmbedding(const TString &normalizedQuery, TVector<float> &embedding);
    /// Writes the document embedding built from the URL, its UTA form and the title into `embedding`.
    void GetDocEmbedding(const TString &url, const TString &urlUta, const TString &normalizedTitle, TVector<float> &embedding);
    /// Combines previously computed doc and query embeddings into a single score.
    float GetJointOutput(const TVector<float> &docEmbeddingUta, const TVector<float> &queryEmbeddingUta);

public:
    NNeuralNetApplier::TModelPtr ModelEmbeddingsPtr;
    NNeuralNetApplier::TModelPtr ModelOutputPtr;
};

/// Converts a URL into a space-separated string using NUta::TSmartUrlAnalyzer.
struct TUrlUTA {
    /// Runs `Analyzer.AnalyzeUrlUTF8(url)` and joins the returned pieces with " ".
    TString Get(const TString &url) const {
        const auto &pieces = Analyzer.AnalyzeUrlUTF8(url);
        return JoinSeq(" ", pieces);
    }

    /// Read-only access to the shared instance returned by Singleton<TUrlUTA>().
    static const TUrlUTA &CInstance() {
        const TUrlUTA *const shared = Singleton<TUrlUTA>();
        return *shared;
    }

public:
    NUta::TSmartUrlAnalyzer Analyzer;
};

/// Returns `query` normalized by FastNormalizeRequest, called with its second
/// argument set to `false`.
/// NOTE(review): the meaning of that flag is defined in fast_normalize_request.h — confirm there.
inline TString NormalizeQuery(const TString &query) {
    TString normalized = FastNormalizeRequest(query, /* second argument */ false);
    return normalized;
}

} //namespace NWebmaster
