#include "BroadMatcher.h"

#include "Index.h"
#include "MultiIndex.h"

#include <util/generic/string.h>
#include <util/string/cast.h>

#include <algorithm>
#include <numeric>

// Longest phrase (in words) the matcher builds sub-phrase masks for; the Match*
// methods only consider the first MAX_PHRASE_WORDS words of their input.
const size_t MAX_PHRASE_WORDS = 7;

// Orders word-subset bitmasks so that masks with more set bits (sub-phrases that
// keep more words) come first. Masks with the same number of set bits compare
// equivalent, so their relative order is left to the sorting algorithm.
class CmpMasks {
public:
    // Returns true when lhs has strictly more set bits than rhs.
    // Declared const so the comparator is usable through a const object
    // (e.g. as the Compare type of an associative container).
    bool operator()(const unsigned lhs, const unsigned rhs) const {
        return CountBits(rhs) < CountBits(lhs);
    }

private:
    // Number of set bits in value; each iteration clears the lowest set bit.
    static unsigned CountBits(unsigned value) {
        unsigned bits = 0;
        for(; value; value &= value - 1) {
            bits++;
        }
        return bits;
    }
};

// Precomputes, for every phrase length 1..MAX_PHRASE_WORDS, all non-empty word
// subsets as bitmasks. Bucket k (0-based) holds the masks for a (k + 1)-word
// phrase, ordered by CmpMasks so that masks keeping more words come first.
BroadMatcher::BroadMatcher() {
    masks.resize(MAX_PHRASE_WORDS);

    for(size_t bucket_index = 0; bucket_index < MAX_PHRASE_WORDS; ++bucket_index) {
        auto& bucket = masks[bucket_index];
        // One past the largest mask for a phrase of (bucket_index + 1) words.
        const unsigned limit = 1u << (bucket_index + 1);

        for(unsigned subset = 1; subset != limit; ++subset) {
            bucket.push_back(subset);
        }

        std::sort(bucket.begin(), bucket.end(), CmpMasks());
    }
}

void BroadMatcher::Match(const MultiIndex& index, const char *dict, size_t max_phrases, int flags, const std::vector<std::string>& words, const std::vector<unsigned>& counts, TStringBuilder& buffer) const {
    // версия с матчингом по всем подфразам и расчётом весов
    if(flags & BM_FULL) {
        MatchFullWeighted(index, dict, max_phrases, flags, words, counts, buffer);
        return;
    }

    size_t count = std::min(MAX_PHRASE_WORDS, words.size());
    size_t current_word_count = count;
    char split[] = " , ";
    size_t phrase_count = 0;
    unsigned mask = (1 << count) - 1;

    if(!words.size()) {
        return;
    }

    while(mask && phrase_count < max_phrases) {
        // построение подфразы
        TStringBuilder subphrase;
        for(size_t word_index = 0; word_index < count; word_index++) {
            if(mask & (1 << word_index)) {
                if(subphrase) {
                    subphrase << ' ';
                }

                subphrase << words[word_index];
            }
        }

        // матчинг подфразы
        IndexValueType value_type;
        const unsigned char* value_ptr = NULL;
        const Index* index_ptr = NULL;
        index.FindValue(dict, TString(subphrase).Data(), value_type, value_ptr, index_ptr);

        // сохранение фраз в топ
        if(value_ptr) {
            if(buffer) {
                buffer << split;
            }
            auto value = index_ptr->GetValue(value_ptr, value_type);
            phrase_count++;
            for (size_t ch = 0; ch < value.Size(); ++ch) {
                if (value[ch] == ',') {
                    phrase_count++;
                    if(phrase_count > max_phrases) {
                        value = TStringBuf(value).SubString(0, ch);
                        break;
                    }
                }
            }
            buffer << value;
        }

        // построение новой маски
        if(phrase_count < max_phrases) {
            size_t worst_word = 0;
            size_t worst_count = 0;

            if(current_word_count <= 2) {
                break;
            }

            for(size_t word_index = 0; word_index < count; word_index++) {
                if((mask & (1 << word_index)) && (!worst_count || counts[word_index] > worst_count)) {
                    worst_count = counts[word_index];
                    worst_word = word_index;
                }
            }

            mask = mask & (~(((size_t) 1) << worst_word));
            current_word_count--;
        }
    }
}

void BroadMatcher::MatchFullWeighted(const MultiIndex& index, const char *dict, size_t max_phrases, int /* flags */, const std::vector<std::string>& words, const std::vector<unsigned>& counts, TStringBuilder& buffer) const {
    size_t count = std::min(MAX_PHRASE_WORDS, words.size());
    size_t mask_index, word_index;
    char split[] = " , ";
    typedef std::pair<std::string, float> PhraseWeight;
    std::vector<PhraseWeight> top_phrases;

    if(!count) {
        return;
    }

    // веса слов
    std::vector<float> word_weights(count);
    std::transform(counts.begin(), counts.begin() + count, word_weights.begin(), [](unsigned n) -> float { return 100000.0f / (n + 1.0f); });
    float sum_weight = std::accumulate(word_weights.begin(), word_weights.end(), 0.0f);
    if(sum_weight < 0.0001f) {
        sum_weight = 1.0f;
    }

    // веса масок
    std::vector<float> mask_weights(masks[count - 1].size());
    std::transform(
        masks[count - 1].begin(),  masks[count - 1].end(), mask_weights.begin(),
        [word_weights, sum_weight](unsigned mask) -> float {
            float w = 0.0f;
            size_t word_index;

            for(word_index = 0; word_index < word_weights.size(); word_index++) {
                if(mask & (1 << word_index)) {
                    w += word_weights[word_index];
                }
            }

            return w / sum_weight;
        }
    );

    // сортируем маски по убыванию веса
    std::vector<size_t> mask_indices(mask_weights.size());
    for(mask_index = 0; mask_index < mask_indices.size(); mask_index++) {
        mask_indices[mask_index] = mask_index;
    }
    std::sort(
        mask_indices.begin(), mask_indices.end(),
        [mask_weights](size_t lhs, size_t rhs){ return mask_weights[lhs] > mask_weights[rhs]; }
    );

    for(auto mask_index: mask_indices) {
        // построение подфразы
        unsigned mask = masks[count - 1][mask_index];
        TStringBuilder subphrase;
        for(word_index = 0; word_index < count; word_index++) {
            if(mask & (1 << word_index)) {
                if(subphrase) {
                    subphrase << ' ';
                }

                subphrase << words[word_index];
            }
        }

        // матчинг подфразы
        IndexValueType value_type;
        const unsigned char* value_ptr = NULL;
        const Index* index_ptr = NULL;
        index.FindValue(dict, TString(subphrase).Data(), value_type, value_ptr, index_ptr);

        // сохранение фраз в топ
        if(value_ptr) {
            TStringBuilder assocs;
            assocs << index_ptr->GetValue(value_ptr, value_type);

            size_t pair_begin = 0;
            const TStringBuf assocs_text = assocs;
            while(pair_begin < assocs.Size()) {
                // выделяем очередную пару "фраза, вес"
                size_t pair_end = pair_begin;
                while(pair_end < assocs.Size() && assocs_text[pair_end] != ',') {
                    pair_end++;
                }

                // парсинг веса
                float weight = 0.0;
                size_t weight_begin = pair_end - 1;
                bool is_int = false;
                int int_mul = 1;
                while(assocs_text[weight_begin] == ' ') weight_begin--;
                while(weight_begin > pair_begin && assocs_text[weight_begin] != ' ') {
                    char digit = assocs_text[weight_begin];

                    if(digit == '.') {
                        is_int = true;
                    } else if(is_int) {
                        weight += (digit - '0') * int_mul;
                        int_mul *= 10;
                    } else {
                        weight = (weight + (digit - '0')) * 0.1;
                    }

                    weight_begin--;
                }

                // парсинг фразы
                size_t phrase_begin = pair_begin;
                while(phrase_begin < weight_begin && assocs_text[phrase_begin] == ' ') {
                    phrase_begin++;
                }
                size_t phrase_end = weight_begin;
                while(phrase_end > phrase_begin && assocs_text[phrase_end] == ' ') {
                    phrase_end--;
                }

                top_phrases.push_back(PhraseWeight(
                    std::string(assocs_text.Data() + phrase_begin, assocs_text.Data() + phrase_end + 1),
                    weight * mask_weights[mask_index]
                ));
                pair_begin = pair_end + 1;
            }
        }
    }

    // сортируем топ по убыванию веса
    std::sort(
        top_phrases.begin(),
        top_phrases.end(),
        [](const PhraseWeight& lhs, const PhraseWeight& rhs) -> bool { return lhs.second > rhs.second; }
    );

    // формирование ответа
    for(size_t phrase_index = 0; phrase_index < top_phrases.size() && phrase_index < max_phrases; phrase_index++) {
        if(buffer) {
            buffer << split;
        }

        buffer << top_phrases[phrase_index].first << " " << FloatToString(top_phrases[phrase_index].second, PREC_POINT_DIGITS, 3);
    }
}

void BroadMatcher::MatchFull(const MultiIndex& index, const char *dict, size_t max_phrases, int /* flags */, const std::vector<std::string>& words, TStringBuilder& buffer) const {
    size_t count = std::min(MAX_PHRASE_WORDS, words.size());
    size_t i, j;
    char split[] = " , ";
    size_t phrase_count = 0;

    if(!words.size()) {
        return;
    }

    for(i = 0; i < masks[count - 1].size() && phrase_count < max_phrases; i++) {
        unsigned mask = masks[count - 1][i];
        TStringBuilder subphrase;
        for(j = 0; j < count; j++) {
            if(mask & (1 << j)) {
                if(subphrase) {
                    subphrase << ' ';
                }

                subphrase << words[j];
            }
        }

        // матчинг подфразы
        IndexValueType value_type;
        const unsigned char* value_ptr = NULL;
        const Index* index_ptr = NULL;
        index.FindValue(dict, TString(subphrase).Data(), value_type, value_ptr, index_ptr);

        if(value_ptr) {
            if(buffer) {
                buffer << split;
            }
            auto value = index_ptr->GetValue(value_ptr, value_type);
            phrase_count++;
            for (size_t ch = 0; ch < value.Size(); ++ch) {
                if (value[ch] == ',') {
                    phrase_count++;
                    if(phrase_count > max_phrases) {
                        value = TStringBuf(value).SubString(0, ch);
                        break;
                    }
                }
            }
            buffer << value;
        }
    }
}
