#include <mail/so/libs/fast_text/fasttext.h>
#include <util/string/split.h>
#include <util/generic/vector.h>
#include <util/generic/strbuf.h>
#include <contrib/libs/intel/mkl/include/mkl_cblas.h>
#include "wmd.h"

#include <cctype>

namespace fasttext {

    // Builds the per-document word matrix used by MakeDistances.
    //
    // `text` is split on characters classified by isspace() (empty tokens
    // are skipped). For every token, in order, the model's word vector is
    // looked up, scaled to unit L2 norm and copied into the next row of
    // Distances; consecutive rows are Distances.Cols floats apart.
    //
    // NOTE(review): the matrix is constructed as
    // NWmd::TMatrix(getDimension(), words.size()), yet rows are advanced by
    // Cols and each row receives getDimension() floats. That (and the sgemm
    // call in MakeDistances) is only consistent if the constructor takes
    // (cols, rows) — TODO confirm against wmd.h.
    TDoc::TDoc(const TStringBuf &text, const FastText &model) {
        // isspace() has undefined behavior for negative arguments other than
        // EOF; bytes of non-ASCII (e.g. UTF-8) text are negative when char is
        // signed, so widen through unsigned char first.
        const TVector<TStringBuf> words = StringSplitter(text)
            .SplitByFunc([](char c) {
                return std::isspace(static_cast<unsigned char>(c)) != 0;
            })
            .SkipEmpty();
        Distances = NWmd::TMatrix(model.getDimension(), words.size());
        float* targetRow = Distances.Data;
        fasttext::Vector wordVector(model.getDimension());
        for (const TStringBuf& word : words) {
            model.getWordVector(wordVector, std::string(word));
            // The looked-up vector can be all zeros (upstream fastText returns
            // one for an out-of-vocabulary word with no known subword
            // n-grams). Dividing by a zero norm would fill the row with
            // Inf/NaN and poison every distance involving this word, so such
            // rows are left as zeros instead.
            const float norm = wordVector.norm();
            if (norm > 0.f) {
                wordVector.mul(1.f / norm);
            }
            CopyN(wordVector.data(), wordVector.size(), targetRow);
            targetRow += Distances.Cols;
        }
    }

    // Builds the pairwise word-distance matrix between two documents.
    //
    // Rows of TDoc::Distances are L2-normalized word vectors (see the TDoc
    // constructor), so A * B^T holds cosine similarities s between words of
    // the two documents. The result is
    //     C = 0.5 - 0.5 * A * B^T,   i.e. (1 - s) / 2,
    // which is 0 for identical directions and 1 for opposite ones.
    NWmd::TMatrix MakeDistances(const TDoc &doc1, const TDoc &doc2) {
        const NWmd::TMatrix& A = doc1.Distances;
        const NWmd::TMatrix& B = doc2.Distances;
        // Pre-filled with 0.5: sgemm below runs with beta = 1, so this fill
        // is the constant term of the formula above.
        NWmd::TMatrix C(B.Rows, A.Rows, 0.5f);
        // Nothing to multiply when either document has no words (or the
        // vectors have no components). Calling sgemm anyway would pass
        // ldc == 0 (or lda == 0), violating the row-major requirement
        // ld >= max(1, n) and making MKL report a parameter error. The
        // 0.5 fill is already the exact result in the K == 0 case.
        if (C.Rows == 0 || C.Cols == 0 || A.Cols == 0) {
            return C;
        }
        cblas_sgemm(
                CblasRowMajor,
                CblasNoTrans, CblasTrans,
                C.Rows, C.Cols, A.Cols,
                -0.5f,
                A.Data, A.Cols,
                B.Data, B.Cols,
                1.f,
                C.Data, C.Cols
        );

        return C;
    }
}  // namespace fasttext
