#include "fast_text_vector_calculator.h"

#include <mail/so/libs/fast_text/fasttext.h>

#include <util/stream/input.h>
#include <util/stream/output.h>
#include <util/string/split.h>

#include <cctype>
#include <cmath>

// In-place element-wise sum: dst[k] += src[k] for every k in [0, size).
// Both buffers must hold at least `size` floats.
static void AddVector(float* dst, const float* src, size_t size) {
    for (size_t k = 0; k != size; ++k) {
        dst[k] += src[k];
    }
}

// In-place element-wise difference: dst[k] -= src[k] for every k in [0, size).
// Both buffers must hold at least `size` floats.
static void DiffVector(float* dst, const float* src, size_t size) {
    for (size_t k = 0; k != size; ++k) {
        dst[k] -= src[k];
    }
}

// Euclidean (L2) norm of the first `size` elements of `v`.
// Squares are accumulated in double to limit rounding error on long vectors;
// the final square root is narrowed back to float.
// Returns 0 for size == 0.
static float Norm(const float* v, size_t size) {
    double sum = 0.0;
    for (size_t i = 0; i < size; ++i) {
        const double x = static_cast<double>(v[i]);
        sum += x * x;
    }
    return static_cast<float>(std::sqrt(sum));
}

int main(int argc, char** argv) {
    if (argc < 3) {
        Cerr << "Usage: " << argv[0] << " <model type> <model>" << Endl;
        return 1;
    }
    TString type{argv[1]};
    THolder<NWLO::IVectorCalculator> calculator;
    if (type == "fasttext") {
        THolder<fasttext::FastText> model(new fasttext::FastText);
        Cout << "Loading fasttext model..." << Endl;
        model->loadModel(argv[2]);
        calculator = MakeHolder<NWLO::TFastTextVectorCalculator>(std::move(model));
        Cout << "Loading done." << Endl;
    } else {
        Cerr << "Unknown model type <" << type << '>' << Endl;
    }
    const size_t vectorSize = calculator->GetVectorSize();
    TVector<float> tmp(vectorSize, 0.f);
    TString line;
    while (Cin.ReadLine(line)) {
        Cout << "Vector size: " << vectorSize << ", line read <" << line << '>' << Endl;
        TVector<float> main(vectorSize, 0.f);
        TVector<TStringBuf> words = StringSplitter(line).SplitByFunc(isspace).SkipEmpty();
        if (!words.empty()) {
            Cout << "Word <" << words[0] << '>' << Endl;
            calculator->GetWordVector(words[0].Data(), words[0].Size(), main.data());
        }
        for (size_t i = 1; i < words.size(); ++i) {
            TString word{words[i]};
            if (word == "+") {
                word = words[++i];
                Cout << "Add word <" << word << '>' << Endl;
                calculator->GetWordVector(word.Data(), word.Size(), tmp.data());
                AddVector(main.data(), tmp.data(), vectorSize);
            } else if (word == "-") {
                word = words[++i];
                Cout << "Minus word <" << word << '>' << Endl;
                calculator->GetWordVector(word.Data(), word.Size(), tmp.data());
                DiffVector(main.data(), tmp.data(), vectorSize);
            } else {
                Cerr << "Bad op at pos #" << i << Endl;
                break;
            }
        }
        Cout << "Main norm: " << Norm(main.data(), vectorSize) << Endl;
        for (size_t i = 0; i < vectorSize; ++i) {
            if (i) {
                Cout << " ";
            }
            Cout << FloatToString(main[i], PREC_POINT_DIGITS_STRIP_ZEROES, 3);
        }
        Cout << Endl << Endl;
    }
}

