#include "HuffmanTree.h"

#include <algorithm>
#include <cstring>
#include <fstream>
#include <string>
#include <vector>

#include <library/cpp/logger/global/global.h>

#include <util/datetime/cputimer.h>
#include <util/string/split.h>
#include <util/string/join.h>
#include <util/system/mem_info.h>

// Number of worker threads GenerateFromFileParallel hands to WordCountProcessor.
const size_t NUM_THREADS = 16;
// Buffer size GenerateFromFileParallel hands to WordCountProcessor (1 MiB);
// presumably an upper bound on the length of one input line — TODO confirm in FileProcessor.
const size_t BUFFER_SIZE = 1024 * 1024; // 1 Mb per line

// Flag bits stored in the first byte of every packed node (see PackNode).
// Set for leaf (word) nodes; internal nodes store 0 in that byte.
const unsigned char PACKED_NODE_LEAF = 1;
// Set, together with PACKED_NODE_LEAF, on the leaf that holds HUFFMAN_END_MARKER.
const unsigned char PACKED_NODE_END_MARKER = 2;

// Heap comparator used while building the Huffman tree.
// With std::make_heap/push_heap/pop_heap, ordering by "greater frequency compares
// less" puts the node with the LOWEST frequency at the front, i.e. a min-heap.
class CompareNodes {
public:
    // const: a comparison must not modify the comparator, and the standard
    // Compare requirements allow calling it through a const object.
    bool operator()(const HuffmanTree::Node* lhs, const HuffmanTree::Node* rhs) const {
        return lhs->freq > rhs->freq;
    }
};

// Strict weak ordering of word codes by their NUL-terminated text, byte-wise as
// strcmp. word_codes is sorted with it so that FindWord can binary-search.
bool CompareWordCodes(const HuffmanTree::WordCode& lhs, const HuffmanTree::WordCode& rhs) {
    const int order = std::strcmp(lhs.text, rhs.text);
    return order < 0;
}

void HuffmanTree::Generate(const DictFrequencies& freqs) {
    std::vector<Node*> q;

    INFO_LOG << "HuffmanTree::Generate started (mem: " << NMemInfo::GetMemInfo().VMS << ")";

    // создание начального списка узлов
    for(DictFrequencies::const_iterator freqs_it = freqs.begin(); freqs_it != freqs.end(); freqs_it++) {
        Node node;

        node.text = freqs_it->first;
        node.freq = freqs_it->second;

        nodes.push_back(node);
        q.push_back(&nodes.back());
    }
    std::make_heap(q.begin(), q.end(), CompareNodes());

    // построение дерева
    while(q.size() > 1) {
        std::pop_heap(q.begin(), q.end(), CompareNodes());
        Node* n1 = q.back();
        q.pop_back();
        std::pop_heap(q.begin(), q.end(), CompareNodes());
        Node* n2 = q.back();
        q.pop_back();

        // создание узла, объединяющего n1 и n2
        Node node;
        node.freq = n1->freq + n2->freq;
        node.left = n1;
        node.right = n2;
        nodes.push_back(node);
        n1->parent = n2->parent = &nodes.back();
        q.push_back(&nodes.back());
        std::push_heap(q.begin(), q.end(), CompareNodes());
    }

    // корневой узел
    root_node = q[0];

    INFO_LOG << " " << nodes.size() << " huffman nodes";

    // упаковка дерева
    PackTree();

    INFO_LOG << " packed tree is " << packed_nodes.size() << " bytes";

    // построение обратного словаря
    BuildWordCodes();
    INFO_LOG << "HuffmanTree::Generate finished (mem: " << NMemInfo::GetMemInfo().VMS << ")";
}

void HuffmanTree::BuildWordCodes() {
    word_codes.reserve (num_leaf_nodes);

    BuildWordCodes(0, 0, 0);

    std::sort(word_codes.begin(), word_codes.end(), CompareWordCodes);

    for(size_t i = 0; i < word_codes.size(); i++) {
        if(!strcmp(word_codes[i].text, HUFFMAN_END_MARKER)) {
            wc_end_marker = &word_codes[i];
        }
    }
}

// Serializes the subtree rooted at `node` into packed_nodes starting at byte
// offset `shift`; returns the offset just past the written data. Leaves are
// counted in num_leaf_nodes.
//
// Record layout:
//   internal node: [0][size_t offset of left child][size_t offset of right child]
//   leaf:          [PACKED_NODE_LEAF (| PACKED_NODE_END_MARKER)][text bytes][NUL]
// The left child's record immediately follows its parent's record.
//
// The offsets sit at arbitrary (unaligned) byte positions of the char buffer,
// so they are copied in with memcpy rather than stored through a size_t*
// (a misaligned, strict-aliasing-violating access). The byte layout is unchanged.
size_t HuffmanTree::PackNode(const Node* node, size_t shift) {
    if (node->left) {
        const size_t left_shift = shift + 1 + sizeof(size_t) * 2;
        packed_nodes[shift] = 0;

        // pack both children; the right subtree starts where the left one ends
        const size_t right_shift = PackNode(node->left, left_shift);
        std::memcpy(&packed_nodes[shift + 1], &left_shift, sizeof(left_shift));
        std::memcpy(&packed_nodes[shift + 1 + sizeof(size_t)], &right_shift, sizeof(right_shift));
        return PackNode(node->right, right_shift);
    }

    packed_nodes[shift] = PACKED_NODE_LEAF;

    if (node->text == HUFFMAN_END_MARKER) {
        packed_nodes[shift] |= PACKED_NODE_END_MARKER;
    }

    // leaf node: store the NUL-terminated word text
    std::copy(node->text.begin(), node->text.end(), &packed_nodes[shift + 1]);
    packed_nodes[shift + 1 + node->text.size()] = 0;
    num_leaf_nodes++;

    return shift + 2 + node->text.size();
}

void HuffmanTree::PackTree() {
    num_leaf_nodes = 0;

    // вычисление размера упакованных данных
    size_t full_size = 1;
    for(std::list<Node>::iterator node = nodes.begin(); node != nodes.end(); node++) {
        if(node->left) {
            full_size += 1 + sizeof(size_t) * 2;
        } else {
            full_size += 2 + node->text.size();
        }
    }
    packed_nodes.resize(full_size);

    // упаковка
    PackNode(root_node, 0);
}

void HuffmanTree::BuildWordCodes(size_t shift, WordBits bits, unsigned char num_bits) {
    if(packed_nodes[shift] & PACKED_NODE_LEAF) {
        WordCode wc;

        wc.text = &packed_nodes[shift + 1];
        wc.bits = bits;
        wc.num_bits = num_bits;

        word_codes.push_back(wc);
    } else {
        size_t *ptr = (size_t*) &packed_nodes[shift + 1];

        BuildWordCodes(*ptr, (bits | (((WordBits)1) << num_bits)), num_bits + 1);
        BuildWordCodes(*(ptr + 1), bits, num_bits + 1);
    }
}

bool HuffmanTree::GenerateFromFile(const char *file_name, const DictsProperties& dp) {
    std::ifstream stream(file_name);

    if(stream.fail()) {
        return false;
    }

    // вычисление частот слов
    const unsigned BUFFER_SIZE = 1024 * 1024;
    char buffer[BUFFER_SIZE];
    DictFrequencies freqs;
    TSimpleTimer timer;

    while(stream.good()) {
        stream.getline(buffer, BUFFER_SIZE);

        if(buffer[0]) {
            char *dict_name = &buffer[0];
            char *ptr = dict_name;

            // пропускаем первое поле
            while(*ptr && *ptr != '\t') ptr++;
            if(!*ptr) {
                continue;
            }

            *ptr = 0;
            ptr++;

            // вычисление частот слов
            DictsProperties::const_iterator it = dp.find(dict_name);
            bool is_key_compressed = (it == dp.end() || it->second.key_type == IKT_COMPRESSED);
            bool is_value_compressed = (it == dp.end() || it->second.value_type == IVT_COMPRESSED);
            bool is_key = true;
            while(*ptr) {
                char *begin = ptr;
                while(*ptr && *ptr != ' ' && *ptr != '\t') ptr++;
                bool is_tab = *ptr == '\t';
                char *next_ptr = *ptr ? ptr + 1 : ptr;

                // учитываются только те слова, которые нужно сжимать
                if(ptr > begin && ((is_key && is_key_compressed) || (!is_key && is_value_compressed))) {
                    *ptr = 0;
                    freqs[begin]++;
                }

                // проверяем, закончился ли ключ
                if(is_tab) {
                    is_key = false;

                    if(!is_value_compressed) {
                        break;
                    }
                }

                ptr = next_ptr;
            }

            if(is_key_compressed) {
                freqs[HUFFMAN_END_MARKER]++;
            }

            if(is_value_compressed) {
                freqs[HUFFMAN_END_MARKER]++;
            }
        }
    }

    INFO_LOG << " " << freqs.size() << " unique words have been loaded for "
             << timer.Get().Seconds() << " seconds";

    Generate(freqs);

    return true;
}

bool HuffmanTree::GenerateFromFileParallel(
            const char *file_name,
            const DictsProperties& dict_prop) {

    /* вычисление частот слов */
    TSimpleTimer timer_all;

    // заводим класс для параллельной обработки файла
    WordCountProcessor words_counter(&dict_prop, NUM_THREADS, BUFFER_SIZE);
    // запускаем обработку

    TSimpleTimer timer;
    words_counter.process_file(file_name);
    INFO_LOG << "file " << file_name << " processed in "
             << timer.Get().Seconds() << " seconds";

    // мержим результаты
    timer.Reset();
    DictFrequencies freqs = words_counter.join_results();
    INFO_LOG << "results merged in "
             << timer.Get().Seconds() << " seconds";

    INFO_LOG << " " << freqs.size() << " unique words have been loaded in "
           << timer_all.Get().Seconds() << " seconds";

    /* построение дерева */
    Generate(freqs);

    return true;
}


unsigned char* HuffmanTree::WordCode::Write(unsigned char* output, unsigned char* end, unsigned& bit_count) const {
    unsigned i;
    for(i = 0; i < num_bits && output < end; i++) {
        if(!bit_count) {
            *output = 0;
        }

        unsigned char mask = (unsigned char)(((bits >> i) & 1) << bit_count);
        *output |= mask;

        bit_count++;
        if(bit_count >= 8) {
            bit_count = 0;
            output++;
        }
    }

    return output;
}

// Appends this code's num_bits bits (bit 0 of `bits` first) to `dst`, packing
// them LSB-first into bytes.
// `bit_count` is the number of bits already used in dst.back(); a new zero byte
// is appended whenever the current byte is full (or none was started).
void HuffmanTree::WordCode::Write(
    std::vector<char>& dst,
    unsigned& bit_count) const {
    size_t bit_index = 0;
    while (bit_index < num_bits) {
        if (bit_count == 0) {
            dst.push_back(0);
        }

        const unsigned char bit_mask = static_cast<unsigned char>(((bits >> bit_index) & 1) << bit_count);
        dst.back() |= bit_mask;

        if (++bit_count == 8) {
            bit_count = 0;
        }
        ++bit_index;
    }
}

// Binary-searches word_codes (kept sorted by CompareWordCodes) for `word`.
// Returns a pointer to its code, or a null pointer when the word is not in the dictionary.
const HuffmanTree::WordCode* HuffmanTree::FindWord(const char* word) const {
    WordCode probe;
    probe.text = word;

    const std::vector<WordCode>::const_iterator found =
        std::lower_bound(word_codes.begin(), word_codes.end(), probe, CompareWordCodes);

    const bool hit = found != word_codes.end() && std::strcmp(word, found->text) == 0;
    return hit ? &*found : nullptr;
}

unsigned HuffmanTree::CompressSize(const TStringBuf& data) const {
    const char *ptr = data.data();
    unsigned bit_count = wc_end_marker->num_bits;

    while(*ptr) {
        const char *begin = ptr;
        while(*ptr && *ptr != ' ') ptr++;

        std::string word;
        word.assign(begin, ptr);
        const WordCode* wc = FindWord(word.c_str());
        if(wc) {
            bit_count += wc->num_bits;
        }

        while(*ptr && *ptr == ' ') ptr++;
    }

    return bit_count / 8 + (bit_count % 8 ? 1 : 0);
}

unsigned HuffmanTree::Compress(const TStringBuf& data, unsigned char* output, unsigned char* output_end) const {
    const char *ptr = data.data();
    unsigned char* curr_output = output;
    unsigned bit_count = 0;
    std::string word;

    while(*ptr) {
        const char *begin = ptr;
        while(*ptr && *ptr != ' ') ptr++;

        word.assign(begin, ptr);
        const WordCode* wc = FindWord(word.c_str());
        if(wc) {
            curr_output = wc->Write(curr_output, output_end, bit_count);
        } else {
            return 0; // во фразе есть слово, которого нет в словаре
        }

        while(*ptr && *ptr == ' ') ptr++;
    }

    curr_output = wc_end_marker->Write(curr_output, output_end, bit_count);

    return (curr_output - output) + (bit_count ? 1 : 0);
}

bool HuffmanTree::Encode(
    const std::vector<std::string>& data,
    std::vector<char>& encoded_data,
    bool add_end_marker) const {
  encoded_data.clear();
  unsigned bit_count = 0;
  for (const std::string& word : data) {
    const WordCode* word_code = FindWord(word.c_str());
    if (!word_code) {
      return false;
    }
    word_code->Write(encoded_data, bit_count);
  }
  if (add_end_marker) {
    wc_end_marker->Write(encoded_data, bit_count);
  }
  return true;
}

// Decodes bits produced by Encode() (LSB-first within each byte) back into words
// joined with `words_delim`, stopping at the end marker.
//
// Reading is bounded by encoded_data.size(). If the data runs out before an end
// marker (it was encoded with add_end_marker == false, or is truncated), decoding
// stops and the words decoded so far are returned. In that case the zero padding
// bits of the last byte may still complete a spurious trailing word.
// Child offsets are read with memcpy because they sit at unaligned positions in
// packed_nodes.
std::string HuffmanTree::Decode(
    const std::vector<char>& encoded_data,
    char words_delim) const {
    std::string result;

    // An empty tree has nothing to decode. A single-leaf tree spends zero bits
    // per word, so walking it would never consume input or terminate.
    if (packed_nodes.empty() || (packed_nodes[0] & PACKED_NODE_LEAF)) {
        return result;
    }

    const unsigned char* ptr = reinterpret_cast<const unsigned char*>(encoded_data.data());
    const unsigned char* const end = ptr + encoded_data.size();
    unsigned bit_count = 0;

    while (true) {
        size_t node_shift = 0;

        // walk from the root to a leaf: bit 1 -> left child, bit 0 -> right child
        while (!(packed_nodes[node_shift] & PACKED_NODE_LEAF)) {
            if (ptr == end) {
                return result; // input exhausted before an end marker
            }

            const size_t child_field = ((*ptr >> bit_count) & 1) ? 0 : sizeof(size_t);
            size_t child_shift;
            std::memcpy(&child_shift, &packed_nodes[node_shift + 1 + child_field], sizeof(child_shift));
            node_shift = child_shift;

            if (++bit_count >= 8) {
                bit_count = 0;
                ++ptr;
            }
        }

        if (packed_nodes[node_shift] & PACKED_NODE_END_MARKER) {
            break;
        }

        if (!result.empty()) {
            result.push_back(words_delim);
        }
        result.append(&packed_nodes[node_shift + 1]);
    }

    return result;
}

// Decodes a bit stream produced by Compress() (LSB-first within each byte) back
// into its words joined with single spaces, stopping at the end marker.
//
// `data` carries no length, so the stream MUST contain the end marker; otherwise
// decoding reads past the buffer. Child offsets are read with memcpy because they
// sit at unaligned positions in packed_nodes.
TString HuffmanTree::Decompress(const unsigned char* data) const {
    // No tree, or a single-leaf tree whose codes are zero bits long: nothing can
    // be read from `data`, and walking the tree would never terminate.
    if (packed_nodes.empty() || (packed_nodes[0] & PACKED_NODE_LEAF)) {
        return TString();
    }

    // Views into packed_nodes, one per decoded word (named `words` so as not to
    // shadow the `nodes` member).
    TVector<TStringBuf> words;

    const unsigned char* ptr = data;
    unsigned bit_count = 0;

    while (true) {
        size_t node_shift = 0;

        // walk from the root to a leaf: bit 1 -> left child, bit 0 -> right child
        while (!(packed_nodes[node_shift] & PACKED_NODE_LEAF)) {
            const size_t child_field = ((*ptr >> bit_count) & 1) ? 0 : sizeof(size_t);
            size_t child_shift;
            std::memcpy(&child_shift, &packed_nodes[node_shift + 1 + child_field], sizeof(child_shift));
            node_shift = child_shift;

            if (++bit_count >= 8) {
                bit_count = 0;
                ++ptr;
            }
        }

        if (packed_nodes[node_shift] & PACKED_NODE_END_MARKER) {
            break;
        }
        words.push_back(&packed_nodes[node_shift + 1]);
    }

    return JoinSeq(" ", words);
}

bool HuffmanTree::Load(std::istream& stream, bool verbose) {
    Y_UNUSED(verbose);
    if(stream.fail()) {
        return false;
    }

    // заголовок
    DictFileHeader header;
    stream.read((char*) &header, sizeof(header));

    if(stream.fail() || strcmp(header.signature, "DICT") || header.version != 0) {
        return false;
    }

    // упакованное дерево
    packed_nodes.resize(header.nodes_block_size);
    stream.read((char*) &packed_nodes[0], packed_nodes.size());

    // обратный словарь
    word_codes.resize(header.num_words);
    for(size_t i = 0; i < header.num_words; i++) {
        size_t shift;
        stream.read((char*) &shift, sizeof(shift));
        word_codes[i].text = &packed_nodes[shift];
        stream.read((char*) &word_codes[i].bits, sizeof(word_codes[i].bits));
        stream.read((char*) &word_codes[i].num_bits, sizeof(word_codes[i].num_bits));

        // маркер конца строки
        if(packed_nodes[shift-1] & PACKED_NODE_END_MARKER) {
            wc_end_marker = &word_codes[i];
        }
    }

    unsigned max_bits = 0;
    for(unsigned i = 0; i < word_codes.size(); i++) {
        max_bits = std::max((unsigned)word_codes[i].num_bits, max_bits);
    }
    DEBUG_LOG << word_codes.size() << " unique words totally, " << max_bits << " bits max size";
    return true;
}

bool HuffmanTree::Load(const TStringBuf& file_name, bool verbose) {
    std::ifstream stream(file_name.data(), std::ios::binary);
    return Load(stream, verbose);
}

// Writes the dictionary in the format read by Load():
//   1. a raw DictFileHeader,
//   2. the packed tree,
//   3. one {size_t text offset, bits, num_bits} record per word code.
// Sizes and offsets are written in the host's native size_t representation.
// Returns false if the stream is unusable before writing or any write failed.
bool HuffmanTree::Save(std::ostream& stream) const {
    if (stream.fail()) {
        return false;
    }

    // header. Value-initialized so that unused signature bytes and padding are
    // zero rather than indeterminate bytes in the file (this holds as long as
    // DictFileHeader has no user-provided constructor).
    DictFileHeader header = DictFileHeader();
    std::strcpy(header.signature, "DICT");
    header.version = 0;
    header.nodes_block_size = packed_nodes.size();
    header.num_words = word_codes.size();
    stream.write((const char*)&header, sizeof(header));

    // packed tree
    if (!packed_nodes.empty()) {
        stream.write((const char*)&packed_nodes[0], packed_nodes.size());
    }

    // word codes: each text pointer is replaced with its offset from the start of the packed tree
    for (size_t i = 0; i < word_codes.size(); i++) {
        size_t shift = word_codes[i].text - &packed_nodes[0];
        stream.write((const char*)&shift, sizeof(shift));
        stream.write((const char*)&word_codes[i].bits, sizeof(word_codes[i].bits));
        stream.write((const char*)&word_codes[i].num_bits, sizeof(word_codes[i].num_bits));
    }

    // report write errors instead of unconditionally claiming success
    return !stream.fail();
}

bool HuffmanTree::Save(const TStringBuf& file_name) const {
    std::ofstream stream(file_name.data(), std::ios::binary);
    return Save(stream);
}


///////////////////////////////////////
/* WordCountProcessor methods        */
///////////////////////////////////////

// Empties the word-frequency table of worker `worker_i`.
void WordCountProcessor::_clear_data(const size_t worker_i) {
    auto& worker_freqs = _worker_df[worker_i];
    worker_freqs.clear();
}

// Creates a word counter that reads with FileProcessor(num_threads, buffer_size)
// and looks up per-dictionary compression settings in `dict_prop`. The pointer
// is stored as-is, so the properties must outlive this object.
WordCountProcessor::WordCountProcessor(
            const DictsProperties* dict_prop,
            const size_t num_threads, const size_t buffer_size)
    : FileProcessor(num_threads, buffer_size)
    , _dict_prop(dict_prop)
{
    // one frequency table per worker; process_line indexes it with worker_i
    _worker_df.resize(_num_threads);
}

// Sums the per-worker frequency tables, in worker order, into a single table
// and returns it. The per-worker tables are left unchanged.
DictFrequencies WordCountProcessor::join_results() {
    DictFrequencies merged;

    for (const auto& worker_freqs : _worker_df) {
        for (const auto& entry : worker_freqs) {
            merged[entry.first] += entry.second;
        }
    }
    return merged;
}

void WordCountProcessor::process_line(char* line, const size_t worker_i) {
    // строку <название_словаря\tключ\tзначение> разбиваем на 3 части по табам
    TStringBuf lineBuffer(line, strlen(line));
    TVector<TStringBuf> parts;
    Split(lineBuffer, "\t", parts);

    if(parts.size() != 3) return;

    // оптимизация, создаем вектор для данных и резервируем память
    TVector<TStringBuf> data;
    data.reserve(32);

    // получаем свойства словаря
    DictsProperties::const_iterator it = _dict_prop->find(parts[0]);
    bool is_key_compressed = (it == _dict_prop->end() ||
                              it->second.key_type == IKT_COMPRESSED);
    bool is_val_compressed = (it == _dict_prop->end() ||
                              it->second.value_type == IVT_COMPRESSED);

    // сжимаем ключ, если нужно
    if (is_key_compressed) {
        // разбиваем на слова
        Split(parts[1], " ", data);
        for(const auto& word: data)
            _worker_df[worker_i][TString(word)]++;
    }

    // сжимаем значение, если нужно
    if (is_val_compressed) {
        // разбиваем на слова
        Split(parts[2], " ", data);
        for(const auto& word: data)
            _worker_df[worker_i][TString(word)]++;
    }

    // учитываем частоту маркера конца лексемы
    _worker_df[worker_i][HUFFMAN_END_MARKER] += is_key_compressed + is_val_compressed;
}

