#include <library/cpp/l2_distance/l2_distance.h>
#include <library/cpp/protobuf/util/pb_io.h>

#include <kernel/embeddings_info/embedding_traits.h>

#include <robot/library/dssm/utils/title_normalization.h>

#include <wmconsole/version3/library/dssm/query_classify.pb.h>

#include "dssm_utils.h"

namespace NWebmaster {

// Annotation (input field) names placed into NNeuralNetApplier::TSample objects by
// GetQueryEmbeddingImpl / GetDocWithUtaEmbeddingImpl / GetDocWithoutUtaEmbeddingImpl.
// Presumably these must match the annotation names the L3 DSSM model expects — confirm against the model.
// NOTE(review): these are mutable pointers with external linkage; `const char* const` would be safer,
// but only together with any matching declarations in dssm_utils.h.
const char *A_QUERY         = "query";
const char *A_DOC_URL       = "doc_url";
const char *A_DOC_UTA_URL   = "doc_uta_url";
const char *A_DOC_TITLE     = "doc_title";

// Runs a query-only model on a single normalized query and copies its "query_embedding"
// output into embeddingVector.
// Returns true unconditionally: this function has no path that returns false, so failures can only
// surface from the applier calls themselves (e.g. at() on a missing output, VerifyDynamicCast on a
// non-matrix output).
bool ApplyQueryEmbeddingModel(const NNeuralNetApplier::TModel &model, const TString &normalizedQuery, TVector<float> &embeddingVector) {
    // The sample carries exactly one annotation, "query", holding the caller's text.
    const TVector<TString> sampleFields = {"query"};
    const TVector<TString> sampleValues = {normalizedQuery};
    const auto sample = MakeAtomicShared<NNeuralNetApplier::TSample>(sampleFields, sampleValues);

    NNeuralNetApplier::TEvalContext context;
    model.FillEvalContextFromSample(sample, context);
    model.Apply(context);

    // Look the output up by name and check it is a matrix before flattening it.
    auto *const queryEmbedding = VerifyDynamicCast<NNeuralNetApplier::TMatrix*>(context.at("query_embedding").Get());
    embeddingVector = queryEmbedding->RepresentAsArray();

    return true;
}

// Loads and initializes the "DssmBoostingXfOneSE" query model.
// The path is relative, so the file is resolved against the current working directory.
TBoostingXfOneSEDssm::TBoostingXfOneSEDssm() {
    // Keep the blob in a named local so it outlives Init(), matching the original lifetime.
    const TBlob modelBlob = TBlob::FromFile("DssmBoostingXfOneSE.dssm");
    Model.Load(modelBlob);
    Model.Init();
}

// Computes the query embedding for normalizedQuery with the XfOneSE model.
// Returns whatever ApplyQueryEmbeddingModel returns (currently always true).
bool TBoostingXfOneSEDssm::Apply(const TString &normalizedQuery, TVector<float> &embeddingVector) {
    const bool applied = ApplyQueryEmbeddingModel(Model, normalizedQuery, embeddingVector);
    return applied;
}

// Loads and initializes the "DssmBoostingXfOneSeAmSsHard" query model.
// The path is relative, so the file is resolved against the current working directory.
TBoostingXfOneSeAmSsHardDssm::TBoostingXfOneSeAmSsHardDssm() {
    // Keep the blob in a named local so it outlives Init(), matching the original lifetime.
    const TBlob modelBlob = TBlob::FromFile("DssmBoostingXfOneSeAmSsHard.dssm");
    Model.Load(modelBlob);
    Model.Init();
}

// Computes the query embedding for normalizedQuery with the XfOneSeAmSsHard model.
// Returns whatever ApplyQueryEmbeddingModel returns (currently always true).
bool TBoostingXfOneSeAmSsHardDssm::Apply(const TString &normalizedQuery, TVector<float> &embeddingVector) {
    const bool applied = ApplyQueryEmbeddingModel(Model, normalizedQuery, embeddingVector);
    return applied;
}

// Loads k-means centroids from the text-format protobuf "kmeans-centers.pb.txt"
// (path relative to the current working directory) into KMeans.
// Aborts (Y_VERIFY) if the file contains no clusters, or if the centroids do not all share one dimension.
TKMeansClusters::TKMeansClusters() {
    const TBlob blobKMeans = TBlob::FromFile("kmeans-centers.pb.txt");
    TMemoryInput clusterInput(blobKMeans.Data(), blobKMeans.Length());
    const auto clusterList = ParseFromTextFormat<NWebmaster::NProto::TDssmQueryClassifyConfig>(clusterInput);
    const auto &clusters = clusterList.GetClusters();

    KMeans.reserve(clusters.size());
    for (const auto &cluster : clusters) {
        KMeans.emplace_back(cluster.GetVector().begin(), cluster.GetVector().end());
    }

    // GetClusterId() dereferences KMeans.front() unconditionally. An empty config must fail here,
    // not later as undefined behaviour.
    Y_VERIFY(!KMeans.empty(), "kmeans-centers.pb.txt contains no clusters");

    // Every centroid must have the same length so that a single embedding dimension applies to all.
    const size_t dimension = KMeans.front().size();
    for (const auto &centroid : KMeans) {
        Y_VERIFY(centroid.size() == dimension, "all k-means centroids must have the same dimension");
    }
}

// Returns the index of the centroid nearest to embeddingVector by squared L2 distance.
// Ties resolve to the lowest index, because only a strictly smaller distance replaces the current best.
// Aborts (Y_VERIFY) if no centroids are loaded, or if the embedding dimension differs from the centroid
// dimension. L2SqrDistance reads embeddingVector.size() floats from each centroid, so a longer embedding
// would read out of bounds.
ui32 TKMeansClusters::GetClusterId(const TVector<float> &embeddingVector) {
    Y_VERIFY(!KMeans.empty(), "no k-means centroids loaded");
    // The constructor guarantees all centroids share one size, so checking the first one is enough.
    Y_VERIFY(embeddingVector.size() == KMeans.front().size(), "embedding dimension does not match k-means centroid dimension");

    const int dimension = embeddingVector.ysize();
    ui32 clusterId = 0;
    float nearestDistance = L2SqrDistance(embeddingVector.data(), KMeans.front().data(), dimension);
    for (size_t i = 1; i < KMeans.size(); ++i) {
        const float currentDistance = L2SqrDistance(embeddingVector.data(), KMeans[i].data(), dimension);
        if (currentDistance < nearestDistance) {
            nearestDistance = currentDistance;
            clusterId = static_cast<ui32>(i);
        }
    }
    return clusterId;
}

void TKMeansClusters::GetClustersL2(const TVector<float> &embeddingVector, TVector<float> &clustersL2) {
    clustersL2.resize(KMeans.size());
    for (size_t i = 0; i < KMeans.size(); i++) {
        clustersL2[i] = L2SqrDistance(embeddingVector.data(), KMeans[i].data(), embeddingVector.ysize());
    }
}

void TKMeansClusters::GetClustersL2Top(const TVector<float> &embeddingVector, TVector<ui8> &clustersL2Top, const size_t top) {
    TVector<float> clustersL2;
    GetClustersL2(embeddingVector, clustersL2);
    TVector<std::pair<float, ui8>> clustersL2Idx;
    for (size_t i = 0; i < clustersL2.size(); i++) {
        clustersL2Idx.emplace_back(clustersL2[i], i);
    }

    std::sort(clustersL2Idx.begin(), clustersL2Idx.end());
    for (size_t i = 0; i < top; i++) {
        clustersL2Top.emplace_back(clustersL2Idx[i].second);
    }
}

// Evaluates `outputVariables` of `model` on a sample whose query field is normalizedQuery.
// The document fields (url, uta url, title) are present but empty. The result is written into `embedding`.
void GetQueryEmbeddingImpl(const NNeuralNetApplier::TModel &model, const TString &normalizedQuery, const TVector<TString> &outputVariables, TVector<float> &embedding) {
    static const TVector<TString> annotationNames = { A_QUERY, A_DOC_URL, A_DOC_UTA_URL, A_DOC_TITLE };
    // Positional: one value per entry of annotationNames.
    const TVector<TString> annotationValues = { normalizedQuery, TString(), TString(), TString() };
    const auto sample = MakeAtomicShared<NNeuralNetApplier::TSample>(annotationNames, annotationValues);
    model.Apply(sample, outputVariables, embedding);
}

// Evaluates `outputVariables` of `model` on a document sample (url, uta url, title) with an empty query.
// The result is written into `embedding`.
void GetDocWithUtaEmbeddingImpl(const NNeuralNetApplier::TModel &model, const TString &url, const TString &urlUta, const TString &normalizedTitle, const TVector<TString> &outputVariables, TVector<float> &embedding) {
    static const TVector<TString> annotationNames = { A_QUERY, A_DOC_URL, A_DOC_UTA_URL, A_DOC_TITLE };
    // Positional: one value per entry of annotationNames; the query slot stays empty.
    const TVector<TString> annotationValues = { TString(), url, urlUta, normalizedTitle };
    const auto sample = MakeAtomicShared<NNeuralNetApplier::TSample>(annotationNames, annotationValues);
    model.Apply(sample, outputVariables, embedding);
}

// Evaluates `outputVariables` of `model` on a document sample (url, title) with an empty query.
// Unlike GetDocWithUtaEmbeddingImpl, no uta-url annotation is supplied at all.
void GetDocWithoutUtaEmbeddingImpl(const NNeuralNetApplier::TModel &model, const TString &url, const TString &normalizedTitle, const TVector<TString> &outputVariables, TVector<float> &embedding) {
    static const TVector<TString> annotationNames = { A_QUERY, A_DOC_URL, A_DOC_TITLE };
    // Positional: one value per entry of annotationNames; the query slot stays empty.
    const TVector<TString> annotationValues = { TString(), url, normalizedTitle };
    const auto sample = MakeAtomicShared<NNeuralNetApplier::TSample>(annotationNames, annotationValues);
    model.Apply(sample, outputVariables, embedding);
}

// Loads and initializes the full L3 DSSM model from modelFileName.
// Task-specific submodels are later obtained through GetSubmodel().
TDssmL3Model::TDssmL3Model(const TString modelFileName) {
    const TBlob modelBlob = TBlob::FromFile(modelFileName);
    Model.Load(modelBlob);
    Model.Init();
}

// Forwards to the wrapped model's GetSubmodel(outputName, terminalInputs).
// The exact semantics of how the graph is cut at terminalInputs are defined by NNeuralNetApplier.
NNeuralNetApplier::TModelPtr TDssmL3Model::GetSubmodel(const TString& outputName, const TSet<TString>& terminalInputs) const {
    NNeuralNetApplier::TModelPtr submodel = Model.GetSubmodel(outputName, terminalInputs);
    return submodel;
}

// Extracts two submodels of the shared L3 model, both rooted at V_JOINT_OUTPUT_CTR:
//  - ModelEmbeddingsPtr: terminal input V_INPUT. GetQueryEmbedding()/GetDocEmbedding() use it to
//    compute V_QUERY_EMBEDDING_CTR / V_DOC_EMBEDDING_CTR from text fields.
//  - ModelOutputPtr: terminal inputs are the doc/query CTR embeddings. GetJointOutput() feeds it
//    precomputed embeddings directly.
// NOTE(review): the exact graph-cutting semantics of GetSubmodel() live in NNeuralNetApplier — confirm there.
TDssmCtr::TDssmCtr(TDssmL3Model::Ptr l3ModelPtr)
    : ModelEmbeddingsPtr(l3ModelPtr->GetSubmodel(V_JOINT_OUTPUT_CTR, { V_INPUT }))
    , ModelOutputPtr(l3ModelPtr->GetSubmodel(V_JOINT_OUTPUT_CTR, { V_DOC_EMBEDDING_CTR, V_QUERY_EMBEDDING_CTR }))
{
}

void TDssmCtr::GetQueryEmbedding(const TString &normalizedQuery, TVector<float> &embedding) {
    const static TVector<TString> outputVariables = { V_QUERY_EMBEDDING_CTR };
    GetQueryEmbeddingImpl(*ModelEmbeddingsPtr, normalizedQuery, outputVariables, embedding);
}

void TDssmCtr::GetDocEmbedding(const TString &url, const TString &urlUta, const TString &normalizedTitle, TVector<float> &embedding) {
    const static TVector<TString> outputVariables = { V_DOC_EMBEDDING_CTR };
    GetDocWithUtaEmbeddingImpl(*ModelEmbeddingsPtr, url, urlUta, normalizedTitle, outputVariables, embedding);
}

float TDssmCtr::GetJointOutput(const TVector<float> &docEmbeddingCtr, const TVector<float> &queryEmbeddingCtr) {
    NNeuralNetApplier::TEvalContext evalContext;
    evalContext[V_DOC_EMBEDDING_CTR] = new NNeuralNetApplier::TMatrix(1, docEmbeddingCtr.size(), docEmbeddingCtr);
    evalContext[V_QUERY_EMBEDDING_CTR] = new NNeuralNetApplier::TMatrix(1, queryEmbeddingCtr.size(), queryEmbeddingCtr);
    TVector<float> result;
    ModelOutputPtr->Apply(evalContext, {V_JOINT_OUTPUT_CTR}, result);
    return result[0];
}

// Extracts two submodels of the shared L3 model, both rooted at V_JOINT_OUTPUT_CTR_NO_MINER:
//  - ModelEmbeddingsPtr: terminal input V_INPUT. GetQueryEmbedding()/GetDocEmbedding() use it to
//    compute the NO_MINER query/doc embeddings from text fields.
//  - ModelOutputPtr: terminal inputs are the NO_MINER doc/query embeddings. GetJointOutput() feeds it
//    precomputed embeddings directly.
// NOTE(review): the exact graph-cutting semantics of GetSubmodel() live in NNeuralNetApplier — confirm there.
TDssmCtrNoMiner::TDssmCtrNoMiner(TDssmL3Model::Ptr l3ModelPtr)
    : ModelEmbeddingsPtr(l3ModelPtr->GetSubmodel(V_JOINT_OUTPUT_CTR_NO_MINER, { V_INPUT }))
    , ModelOutputPtr(l3ModelPtr->GetSubmodel(V_JOINT_OUTPUT_CTR_NO_MINER, { V_DOC_EMBEDDING_CTR_NO_MINER, V_QUERY_EMBEDDING_CTR_NO_MINER }))
{
}

void TDssmCtrNoMiner::GetQueryEmbedding(const TString &normalizedQuery, TVector<float> &embedding) {
    const static TVector<TString> outputVariables = { V_QUERY_EMBEDDING_CTR_NO_MINER };
    GetQueryEmbeddingImpl(*ModelEmbeddingsPtr, normalizedQuery, outputVariables, embedding);
}

void TDssmCtrNoMiner::GetDocEmbedding(const TString &url, const TString &normalizedTitle, TVector<float> &embedding) {
    const static TVector<TString> outputVariables = { V_DOC_EMBEDDING_CTR_NO_MINER };
    GetDocWithoutUtaEmbeddingImpl(*ModelEmbeddingsPtr, url, normalizedTitle, outputVariables, embedding);
}

float TDssmCtrNoMiner::GetJointOutput(const TVector<float> &docEmbeddingCtr, const TVector<float> &queryEmbeddingCtr) {
    NNeuralNetApplier::TEvalContext evalContext;
    evalContext[V_DOC_EMBEDDING_CTR_NO_MINER] = new NNeuralNetApplier::TMatrix(1, docEmbeddingCtr.size(), docEmbeddingCtr);
    evalContext[V_QUERY_EMBEDDING_CTR_NO_MINER] = new NNeuralNetApplier::TMatrix(1, queryEmbeddingCtr.size(), queryEmbeddingCtr);
    TVector<float> result;
    ModelOutputPtr->Apply(evalContext, { V_JOINT_OUTPUT_CTR_NO_MINER }, result);
    return result[0];
}

// Extracts two submodels of the shared L3 model, both rooted at V_JOINT_OUTPUT_UTA:
//  - ModelEmbeddingsPtr: terminal input V_INPUT. GetQueryEmbedding()/GetDocEmbedding() use it to
//    compute V_QUERY_EMBEDDING_UTA / V_DOC_EMBEDDING_UTA from text fields.
//  - ModelOutputPtr: terminal inputs are the UTA doc/query embeddings. GetJointOutput() feeds it
//    precomputed embeddings directly.
// NOTE(review): the exact graph-cutting semantics of GetSubmodel() live in NNeuralNetApplier — confirm there.
TDssmUta::TDssmUta(TDssmL3Model::Ptr l3ModelPtr)
    : ModelEmbeddingsPtr(l3ModelPtr->GetSubmodel( V_JOINT_OUTPUT_UTA, { V_INPUT }))
    , ModelOutputPtr(l3ModelPtr->GetSubmodel(V_JOINT_OUTPUT_UTA, { V_DOC_EMBEDDING_UTA, V_QUERY_EMBEDDING_UTA }))
{
}

void TDssmUta::GetQueryEmbedding(const TString &normalizedQuery, TVector<float> &embedding) {
    const static TVector<TString> outputVariables = { V_QUERY_EMBEDDING_UTA };
    GetQueryEmbeddingImpl(*ModelEmbeddingsPtr, normalizedQuery, outputVariables, embedding);
}

void TDssmUta::GetDocEmbedding(const TString &url, const TString &urlUta, const TString &normalizedTitle, TVector<float> &embedding) {
    const static TVector<TString> outputVariables = { V_DOC_EMBEDDING_UTA };
    GetDocWithUtaEmbeddingImpl(*ModelEmbeddingsPtr, url, urlUta, normalizedTitle, outputVariables, embedding);
}

float TDssmUta::GetJointOutput(const TVector<float> &docEmbeddingUta, const TVector<float> &queryEmbeddingUta) {
    NNeuralNetApplier::TEvalContext evalContext;
    evalContext[V_DOC_EMBEDDING_UTA] = new NNeuralNetApplier::TMatrix(1, docEmbeddingUta.size(), docEmbeddingUta);
    evalContext[V_QUERY_EMBEDDING_UTA] = new NNeuralNetApplier::TMatrix(1, queryEmbeddingUta.size(), queryEmbeddingUta);
    TVector<float> result;
    ModelOutputPtr->Apply(evalContext, {V_JOINT_OUTPUT_UTA}, result);
    return result[0];
}

} //namespace NWebmaster
