#include "NormDict.h"

#include "Tokens.h"

#include <util/generic/string.h>
#include <util/generic/vector.h>
#include <util/stream/file.h>
#include <util/string/join.h>
#include <util/string/split.h>
#include <util/string/split.h>

#include <algorithm>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

bool NormDict::LoadConfig(const char *file_name) {
    const size_t BUFFER_SIZE = 1024 * 1024;
    char buffer[BUFFER_SIZE];
    std::ifstream stream(file_name);
    if(stream.fail()) {
        return false;
    }

    while(stream.good()) {
        stream.getline(buffer, BUFFER_SIZE);
        TVector<TString> cmd = StringSplitter(TStringBuf(buffer, strlen(buffer))).Split('\t');
        // обработка команды
        if(!cmd.size() || !cmd[0].size()) {
            continue;
        } else if(cmd[0] == "norm") {
            if(cmd.size() != 2) {
                std::cout << "bad cmd: " << buffer << std::endl;
            } else if(!LoadNorm(cmd[1].c_str())) {
                std::cout << "can't load norm dict from " << cmd[1] << std::endl;
            }
        } else if(cmd[0] == "wordcount") {
            if(cmd.size() != 2) {
                std::cout << "bad cmd: " << buffer << std::endl;
            } else if(!LoadWordCount(cmd[1].c_str())) {
                std::cout << "can't load wordcount from " << cmd[1] << std::endl;
            }
        } else if(cmd[0] == "stops") {
            if(cmd.size() != 3) {
                std::cout << "bad cmd: " << buffer << std::endl;
            } else if(!LoadStops(cmd[1].c_str(), cmd[2].c_str())) {
                std::cout << "can't load stops from " << cmd[1] << std::endl;
            }
        } else if(cmd[0] == "syn_cells") {
            if(cmd.size() != 3) {
                std::cout << "bad cmd: " << buffer << std::endl;
            } else if(!LoadSynCells(cmd[1].c_str(), cmd[2].c_str())) {
                std::cout << "can't load syn_cells from " << cmd[1] << std::endl;
            }
        } else if(cmd[0] == "bad") {
            if(cmd.size() != 3) {
                std::cout << "bad cmd: " << buffer << std::endl;
            } else if(!LoadWordSet(cmd[1].c_str(), cmd[2].c_str(), bad_words)) {
                std::cout << "can't load bad words from " << cmd[1] << std::endl;
            }
        } else {
            std::cout << "unknown cmd: " << cmd[0] << std::endl;
        }
    }

    return true;
}

bool NormDict::LoadNorm(const char *filename) {
    std::ifstream stream(filename);

    const unsigned BUFFER_SIZE = 1024 * 1024;
    char buffer[BUFFER_SIZE];

    if(stream.fail()) {
        return false;
    }

    while(stream.good()) {
        stream.getline(buffer, BUFFER_SIZE);

        if(buffer[0]) {
            const char *p;

            for(p = buffer; *p && *p != '\t'; p++) ;

            if(!*p) {
                continue;
            }

            std::string word(buffer, p - buffer);
            p++;

            if(*p == '*') {
                // нормализация для разных языков
                TVector<TString> parts = StringSplitter(TStringBuf(p + 1, strlen(p + 1))).Split(',');

                for(size_t i = 0; i < parts.size(); i++) {
                    size_t j = parts[i].find(':');

                    if(j == std::string::npos) {
                        word2norm[word] = parts[i];
                    } else {
                        std::string lang = parts[i].substr(0, j);
                        languages[lang].SetNorm(TString(word), TString(parts[i].substr(j + 1)));
                    }
                }
            } else {
                word2norm[word] = p;
            }
        }
    }

    return true;
}

// Loads normal form -> count pairs from a tab-separated file.
//
// Each useful line has the form "<word>\t<count>"; lines without a tab are
// ignored. The count is parsed as a decimal number and stored in norm2count,
// overwriting any earlier value for the same word.
//
// Returns false when the file cannot be opened, true otherwise.
bool NormDict::LoadWordCount(const char *filename) {
    std::ifstream stream(filename);
    if(stream.fail()) {
        return false;
    }

    // A std::string line buffer replaces the former 1 MiB stack array, so an
    // over-long line no longer stops loading.
    std::string line;
    while(std::getline(stream, line)) {
        const size_t tab = line.find('\t');
        if(tab == std::string::npos) {
            // No separator on this line (this also covers empty lines).
            continue;
        }

        // strtoul is used instead of atoi because atoi has undefined behaviour
        // when the number does not fit in an int; the conversion to unsigned
        // is a well-defined modular narrowing.
        const unsigned long count = strtoul(line.c_str() + tab + 1, nullptr, 10);
        norm2count[line.substr(0, tab)] = static_cast<unsigned>(count);
    }

    return true;
}

// Loads stop words for the language `lang` from a text file.
//
// Lines starting with '#' are skipped. Every other line is split with
// Tokens::GetTokens and each resulting token is registered as a stop word of
// languages[lang].
//
// Returns false when the file cannot be opened, true otherwise.
bool NormDict::LoadStops(const char *filename, const char *lang) {
    std::ifstream stream(filename);
    if(stream.fail()) {
        return false;
    }

    // NOTE(review): the token vector is reused across lines exactly as before;
    // this presumes Tokens::GetTokens resets its output argument - confirm.
    std::vector<std::string> tokens;

    // A std::string line buffer replaces the former 1 MiB stack array, so an
    // over-long line no longer stops loading.
    std::string line;
    while(std::getline(stream, line)) {
        if(!line.empty() && line[0] == '#') {
            continue;
        }

        Tokens::GetTokens(line.c_str(), tokens);

        for(const std::string& token : tokens) {
            languages[lang].AddStopWord(TString(token));
        }
    }

    return true;
}

bool NormDict::LoadSynCells(const char *filename, const char *lang) {
    try {
        TFileInput input(filename);
        TString buffer;

        TVector<TString> cluster;
        TVector<TString> data;
        while (input.ReadLine(buffer)) {
            ui16 semicolons = 0;

            for (auto& ch: buffer) {
                if (ch == ';') {
                    ++semicolons;
                    ch = ',';
                }
            }

            if(semicolons != 2) {
                // конец текущего кластера
                for (const auto& str: cluster) {
                    if (str) {
                        languages[lang].SetSnorm(str, cluster[0]);
                    }
                }
                cluster.clear();
            } else {
                const TVector<TString> data = StringSplitter(buffer).Split(',');
                cluster.reserve(cluster.size() + data.size());
                cluster.insert(cluster.end(),data.begin(),data.end());
            }
        }

        return true;
    } catch (TIoException& exception) {
        Cerr << "Error opening file: `" << filename << "` exception: " << exception.what() << Endl;
        return false;
    }
}

// Fills `words` with normalized single-word entries read from a text file.
//
// Every line is normalized with GetNormWords for the language `lang`; the
// result is inserted into `words` only when the line yields exactly one
// normalized word. Lines yielding zero or several words are ignored.
//
// Returns false when the file cannot be opened, true otherwise.
bool NormDict::LoadWordSet(const char *filename, const char *lang, WordSet& words) {
    std::ifstream stream(filename);
    if(stream.fail()) {
        return false;
    }

    std::vector<std::string> norms;

    // A std::string line buffer replaces the former 1 MiB stack array, so an
    // over-long line no longer stops loading.
    std::string line;
    while(std::getline(stream, line)) {
        GetNormWords(line.c_str(), lang, norms);

        if(norms.size() == 1) {
            words.insert(norms[0]);
        }
    }

    return true;
}

const char* NormDict::WordToNorm(const char *word, const char *lang) const {
    std::map<std::string, NIRT::TLanguage>::const_iterator it_lang = languages.find(lang);

    // нормализация для конкрентного языка
    if(it_lang != languages.end()) {
        if (it_lang->second.HasNorm(word)) {
            return it_lang->second.GetNorm(word).data();
        }
    }

    // нормализация, общая для всех языков
    std::unordered_map<std::string, std::string>::const_iterator it = word2norm.find(word);
    if(it == word2norm.end()) {
        return word;
    }

    return it->second.c_str();
}

// Returns the count stored in norm2count for `word`, or 0 when the word has
// no entry.
unsigned NormDict::GetNormCount(const char *word) const {
    const auto entry = norm2count.find(word);
    return entry != norm2count.end() ? entry->second : 0;
}

void NormDict::GetNormWords(const char *text, const char *lang, std::vector<std::string>& words) const {
    std::vector<std::string> tokens;
    size_t i;
    std::map<std::string, NIRT::TLanguage>::const_iterator it_lang = languages.find(lang);

    Tokens::GetTokens(text, tokens);

    words.clear();
    for(i = 0; i < tokens.size(); i++) {
        if(!tokens[i].size() || tokens[i][0] == '-') {
            continue;
        }

        std::string word = WordToNorm(tokens[i].c_str(), lang);
        if(it_lang == languages.end() || !it_lang->second.IsStopWord(TString(word))) {
            words.push_back(word);
        }
    }
}

// Builds the normalized form of `text` for the language `lang`: the
// normalized words from GetNormWords, sorted and joined with single spaces.
// When `uniq` is true, duplicate words are removed after sorting.
std::string NormDict::Normalize(const char *text, const char *lang, bool uniq) const {
    std::vector<std::string> norms;
    GetNormWords(text, lang, norms);

    std::sort(norms.begin(), norms.end());
    if(uniq) {
        norms.erase(std::unique(norms.begin(), norms.end()), norms.end());
    }

    return JoinSeq(" ", norms);
}

std::string NormDict::Snormalize(const char *text, const char *lang) const {
    std::vector<std::string> words;
    size_t i;
    std::map<std::string, NIRT::TLanguage>::const_iterator it_lang = languages.find(lang);

    GetNormWords(text, lang, words);

    if(it_lang != languages.end()) {
        for(i = 0; i < words.size(); i++) {
            words[i] = it_lang->second.GetSnorm(TString(words[i]));
        }
    }
    std::sort(words.begin(), words.end());

    return JoinSeq(" ", words);
}
