#!/usr/bin/env python
# -*- coding: utf-8 -*-

import yt.wrapper as yt
import argparse
import contextlib
import struct
import hashlib
from collections import defaultdict

import ocr_factors

def parse_args():
    """Parse the command-line options of the pool builder from ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``yt_cluster``,
        ``min_total_answers``, ``log_results``, ``toloka_answers``,
        ``output_dir``, ``word_lists`` (list of files opened for reading),
        ``join_memory_limit`` and ``memory_limit``.
    """
    one_gib = 1024 ** 3
    # (flags, keyword settings) pairs, registered in this exact order.
    option_specs = [
        (('-c', '--yt-cluster'),
         {'default': 'hahn', 'help': 'YT cluster'}),
        (('-m', '--min-total-answers'),
         {'default': 100, 'type': int, 'help': 'Min total answers for a picture to use it'}),
        (('-l', '--log-results'),
         {'required': True, 'help': 'Table with raw CAPTCHA answers'}),
        (('-t', '--toloka-answers'),
         {'required': True, 'help': 'Table with toloka answers'}),
        (('-o', '--output-dir'),
         {'required': True, 'help': 'Output dir'}),
        (('-w', '--word-lists'),
         {'nargs': '*', 'type': argparse.FileType('r'), 'default': [], 'help': 'Word lists for factors'}),
        (('--join-memory-limit',),
         {'default': one_gib, 'type': int, 'help': 'Memory limit for join jobs'}),
        (('--memory-limit',),
         {'default': one_gib, 'type': int, 'help': 'Memory limit for jobs calculating factors'}),
    ]

    cli = argparse.ArgumentParser(description='Make pool for correct answers formula')
    for flags, settings in option_specs:
        cli.add_argument(*flags, **settings)
    return cli.parse_args()

@contextlib.contextmanager
def get_correct_answers(input_table):
    """Yield a temporary YT table with the usable toloka answers, sorted by ``unique_name``.

    Only rows whose ``Type`` is ``"OK"`` or ``"DIFF"`` are kept. Source columns
    are renamed as follows:

        Unique_name -> unique_name, WordGT -> toloka_answer,
        Recognition -> ocr_answer, Language -> ocr_language,
        Confidence -> ocr_confidence.

    The yielded table lives inside a ``yt.TempTable`` context.
    """
    def mapper(row):
        # Skip rows that are neither an exact match nor a diff.
        if row["Type"] not in ("OK", "DIFF"):
            return
        yield {
            "unique_name": row["Unique_name"],
            "toloka_answer": row["WordGT"],
            "ocr_answer": row["Recognition"],
            "ocr_language": row["Language"],
            "ocr_confidence": row["Confidence"],
        }

    with yt.TempTable() as filtered_answers:
        yt.run_map(mapper, input_table, filtered_answers)
        # The join-reduce downstream needs its inputs sorted by the join key.
        yt.run_sort(filtered_answers, sort_by="unique_name")
        yield filtered_answers

def joiner(key, records):
    """Join-reduce callback: attach the answer row to every log row of one key.

    Rows with ``@table_index`` 0 are treated as the answer row; if several are
    present, the last one wins. All other rows are log rows. Nothing is yielded
    unless both an answer row and at least one log row are present.

    Each output row is the answer row overlaid with one log row's fields. It
    also gets ``date`` (the first whitespace-separated token of ``timestamp``)
    and ``@table_index`` set to 0.
    """
    correct_answer = None
    raw_results = []
    for row in records:
        if row["@table_index"] != 0:
            raw_results.append(row)
            continue
        correct_answer = row

    if not correct_answer or not raw_results:
        return

    for raw in raw_results:
        merged = dict(correct_answer)
        merged.update(raw)
        merged["date"] = raw["timestamp"].split()[0]
        # Route every joined row to the single output table.
        merged["@table_index"] = 0
        yield merged

class Reducer(object):
    """YT reducer that turns one ``(unique_name, date)`` group into pool rows.

    Each yielded row has the columns ``key``, ``subkey`` and ``value``:

    * ``key`` is the first 4 bytes of the MD5 of ``"<date>#####<unique_name>"``,
      read as a little-endian unsigned int and rendered in decimal.
    * ``subkey`` is always the empty string.
    * ``value`` is ``label<TAB>tag<TAB>0`` followed by one ``<TAB>%f`` per factor.
      ``label`` is 1 when the candidate answer equals the toloka answer, else 0.
      ``tag`` is the ``repr`` of ``"<date>#####<unique_name>#####<answer>"``.
    """

    def __init__(self, min_total_answers, word_lists):
        """
        Args:
            min_total_answers: groups with fewer rows than this are skipped.
            word_lists: word lists forwarded to ``ocr_factors`` when computing
                and naming factors.
        """
        self.min_total_answers = min_total_answers
        self.word_lists = word_lists

    def __call__(self, key, records):
        """Compute factor rows for one group; yields nothing for small groups."""
        # Group the raw answers by their decoded text.
        unique_answers = defaultdict(list)
        total_answers = 0
        last_record = None
        for record in records:
            unique_answers[record["unique_answer"].decode('utf-8')].append(record)
            total_answers += 1
            last_record = record

        # An empty group has no answer data to read (previously a NameError on
        # the leaked loop variable when min_total_answers <= 0).
        if last_record is None or total_answers < self.min_total_answers:
            return

        key_str = key["date"] + "#####" + key["unique_name"]
        # The toloka/OCR columns are read from the last row of the group, as
        # before. NOTE(review): presumably identical across the group since they
        # come from the joined answer row — confirm.
        toloka_answer = last_record["toloka_answer"].decode('utf-8')
        ocr_data = {
            "ocr_answer": last_record["ocr_answer"].decode('utf-8'),
            "ocr_confidence": last_record["ocr_confidence"],
            "ocr_language": last_record["ocr_language"]
        }

        # Loop-invariant values: the pool key depends only on the group, and the
        # factor order only on the word lists. list() guards against
        # factor_names returning a one-shot iterator.
        pool_key = str(struct.unpack("<L", hashlib.md5(key_str).digest()[:4])[0])
        factor_names = list(ocr_factors.factor_names(self.word_lists))
        # Decode explicitly instead of relying on Python 2's implicit ASCII
        # decode when mixing the byte-string key with the unicode answer.
        tag_prefix = key_str.decode('utf-8') if isinstance(key_str, bytes) else key_str

        for unique_answer, factors in ocr_factors.calculate_factors(unique_answers, ocr_data, self.word_lists):
            label = int(unique_answer == toloka_answer)
            tag = repr(tag_prefix + "#####" + unique_answer)
            value = "%d\t%s\t0" % (label, tag) + "".join("\t%f" % factors[fname] for fname in factor_names)

            yield {
                "key": pool_key,
                "subkey": "",
                "value": value,
            }

@contextlib.contextmanager
def join_input_tables(log_results, toloka_answers, min_total_answers, memory_limit, join_memory_limit, word_lists):
    """Build the factor pool and yield the temporary table that holds it.

    The steps are:

    1. Sort ``log_results`` by ``unique_name`` into a temp table.
    2. Build the filtered, sorted answers table (``get_correct_answers``).
    3. Join-reduce the two tables with ``joiner`` (answers as the foreign table).
    4. Sort the joined rows by ``(unique_name, date)``.
    5. Reduce each group with ``Reducer``, writing back into the same table.

    ``join_memory_limit`` applies to the join job; ``memory_limit`` applies to
    the factor reduce.
    """
    with yt.TempTable() as logs_by_name:
        yt.run_sort(log_results, logs_by_name, sort_by="unique_name")
        with get_correct_answers(toloka_answers) as answers_by_name:
            with yt.TempTable() as pool_table:
                # The answers table is marked foreign so its rows are attached
                # to the matching log rows of each key.
                foreign_answers = yt.TablePath(answers_by_name, attributes={'foreign': True})
                yt.run_join_reduce(
                    joiner,
                    [foreign_answers, logs_by_name],
                    pool_table,
                    join_by="unique_name",
                    format=yt.JsonFormat(control_attributes_mode="row_fields"),
                    memory_limit=join_memory_limit,
                )
                yt.run_sort(pool_table, sort_by=["unique_name", "date"])
                factor_reducer = Reducer(min_total_answers, word_lists)
                yt.run_reduce(
                    factor_reducer,
                    pool_table,
                    pool_table,
                    reduce_by=["unique_name", "date"],
                    format=yt.YsonFormat(control_attributes_mode="row_fields"),
                    memory_limit=memory_limit,
                )
                yield pool_table

def main():
    """Script entry point: build the pool and publish it under ``--output-dir``.

    Inside a single YT transaction this writes two things:

    * ``<output_dir>/features`` — a copy of the reducer output.
    * ``<output_dir>/factor_names`` — rows ``{"key": index, "value": name}``
      in ``ocr_factors.factor_names`` order.
    """
    args = parse_args()

    word_lists = []
    for word_list_file in args.word_lists:
        # argparse.FileType opens the files eagerly; close each one once it has
        # been loaded instead of leaving the handle open for the whole run.
        # NOTE(review): assumes load_word_list reads the file eagerly (its
        # result is handed to Reducer and used in YT jobs) — confirm.
        try:
            word_lists.append(ocr_factors.load_word_list(word_list_file))
        finally:
            word_list_file.close()

    yt.config["proxy"]["url"] = args.yt_cluster

    with join_input_tables(args.log_results, args.toloka_answers, args.min_total_answers, args.memory_limit, args.join_memory_limit, word_lists) as features:
        with yt.Transaction():
            # NOTE(review): force=True presumably recreates an existing output
            # node, replacing its contents — confirm this is intended.
            yt.create("map_node", args.output_dir, recursive=True, force=True)
            yt.copy(features, args.output_dir + "/features", force=True)
            factor_names_generator = ({"key": str(i), "value": fname} for i, fname in enumerate(ocr_factors.factor_names(word_lists)))
            yt.write_table(args.output_dir + "/factor_names", factor_names_generator, raw=False)

if __name__ == "__main__":
    # Run the pool builder only when executed as a script, not when imported.
    main()
