#!/usr/bin/env python

from collections import defaultdict
import yt.wrapper as yt
import argparse
import json
import itertools
import libmxnet
import sys

import ocr_factors

def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    Returns the ``argparse.Namespace`` with attributes ``yt_cluster``,
    ``formula_path``, ``log_results``, ``input_unknown``, ``output_table``,
    ``word_lists`` (list of opened files, empty by default) and
    ``memory_limit`` (int, 1 GiB by default).
    """
    ap = argparse.ArgumentParser(
        description='Select best answers for OCR captcha using trained formula')
    add = ap.add_argument
    add('-c', '--yt-cluster', default='hahn', help='YT cluster')
    add('-f', '--formula-path', required=True, help='Path to formula on YT')
    add('-l', '--log-results', required=True, help='Table with raw CAPTCHA answers')
    add('-i', '--input-unknown', required=True, help='Table with OCR answers')
    add('-o', '--output-table', required=True, help='Output table')
    add('-w', '--word-lists', nargs='*', default=[],
        type=argparse.FileType('r'), help='Word lists for factors')
    add('--memory-limit', default=1024**3, type=int,
        help='Memory limit for jobs calculating factors')
    namespace = ap.parse_args()
    return namespace

def map_input_unknown(record):
    """Project a raw OCR row onto the columns consumed by the reducer.

    Rows without a ``Language`` column produce no output. Every other row
    produces exactly one dict with keys ``unique_name``, ``ocr_answer``,
    ``ocr_language`` and ``ocr_confidence``, copied from ``Unique_name``,
    ``Recognition``, ``Language`` and ``Confidence`` respectively.
    """
    if 'Language' in record:
        yield dict(
            unique_name=record['Unique_name'],
            ocr_answer=record['Recognition'],
            ocr_language=record['Language'],
            ocr_confidence=record['Confidence'],
        )

@yt.reduce_aggregator
class ReduceBestAnswer(object):
    """YT reduce aggregator that picks the best human CAPTCHA answer per item.

    The reduce input is grouped by ``unique_name``. Rows with
    ``@table_index`` 0 carry OCR data (one such row is expected per key).
    All other rows are raw CAPTCHA answers. Each distinct answer is scored
    by the MXNet formula shipped to the job as file ``formula``. The
    highest-scoring answer is emitted when its score reaches the formula's
    ``selected-threshold`` property.
    """

    def __init__(self, word_lists):
        # Passed unchanged to ocr_factors.calculate_factors for every group.
        self.word_lists = word_lists

    @staticmethod
    def _collect_group(recs):
        """Split one reduce group into OCR data and grouped human answers.

        Returns ``(ocr_data, unique_answers, total_answers)``:

        * ``ocr_data`` is a dict with the UTF-8-decoded ``ocr_answer``,
          ``ocr_confidence`` and ``ocr_language``. It is None when the group
          has no OCR row, or when that row's answer is not valid UTF-8.
        * ``unique_answers`` maps each decoded answer text to the list of
          raw log rows that gave it.
        * ``total_answers`` is the number of rows stored in ``unique_answers``.

        Rows whose text cannot be decoded are reported on stderr and skipped.
        """
        ocr_data = None
        seen_ocr_row = False
        unique_answers = defaultdict(list)
        total_answers = 0
        for rec in recs:
            if rec['@table_index'] == 0:
                # At most one OCR row per unique_name is expected.
                assert not seen_ocr_row
                seen_ocr_row = True
                try:
                    ocr_answer = rec["ocr_answer"].decode('utf-8')
                except UnicodeDecodeError:
                    # Without this, a single bad OCR row would raise out of
                    # the reducer and fail the whole job. Skip the group
                    # instead, as is done for undecodable human answers.
                    sys.stderr.write('Failed to decode OCR answer for record %s\n' % (rec,))
                    continue
                ocr_data = {
                    "ocr_answer": ocr_answer,
                    "ocr_confidence": rec["ocr_confidence"],
                    "ocr_language": rec["ocr_language"],
                }
                continue

            if not isinstance(rec["timestamp"], str):
                continue  # workaround for a bug where timestamp may be int for some reason
            try:
                answer = rec["unique_answer"].decode('utf-8')
            except UnicodeDecodeError:
                sys.stderr.write('Failed to decode answer for record %s\n' % (rec,))
                continue
            unique_answers[answer].append(rec)
            total_answers += 1
        return ocr_data, unique_answers, total_answers

    def __call__(self, row_groups):
        """Yield one output row per ``unique_name`` whose best answer passes the threshold.

        A group is skipped when it has no usable OCR data, has no answers, or
        has fewer answers than the formula's ``min-total-answers`` property.
        Each yielded row contains ``unique_name``, ``all_answers`` (answer ->
        number of log rows), ``answer`` (the best answer) and ``confidence``
        (its formula score).
        """
        # Load the formula and its tuning properties once, before iterating
        # over all groups passed to this call.
        formula = libmxnet.TMXNetInfo('formula')
        threshold = float(formula.GetProperty('selected-threshold'))
        factor_ids = json.loads(formula.GetProperty('factor-ids'))
        min_total_answers = int(formula.GetProperty('min-total-answers'))

        for key, recs in row_groups:
            ocr_data, unique_answers, total_answers = self._collect_group(recs)
            if ocr_data is None or not unique_answers:
                continue
            if total_answers < min_total_answers:
                continue

            unique_answers_ordered = []
            pool = []
            for unique_answer, answer_factors in ocr_factors.calculate_factors(unique_answers, ocr_data, self.word_lists):
                unique_answers_ordered.append(unique_answer)
                # Factor vector ordered by the ids stored in the formula.
                pool.append([float(answer_factors[fac_id]) for fac_id in factor_ids])
            if not pool:
                # Nothing to score. Guard against max() on an empty sequence.
                continue

            weights = formula.Calculate(pool)
            # Tuple comparison: ties on weight are broken by the larger answer.
            max_weight, best_answer = max(itertools.izip(weights, unique_answers_ordered))
            if max_weight >= threshold:
                yield {
                    "unique_name": key["unique_name"],
                    "all_answers": {k: len(v) for k, v in unique_answers.iteritems()},
                    "answer": best_answer,
                    "confidence": max_weight,
                }

def main():
    """Entry point: prepare sorted inputs and run ReduceBestAnswer on YT.

    The raw answer log is sorted into a temporary table. The OCR table is
    mapped through ``map_input_unknown`` into a second temporary table,
    which is then sorted in place. Both are reduced by ``unique_name`` into
    ``args.output_table``. The formula file is attached to the job under the
    name ``formula``.
    """
    args = parse_args()
    word_lists = list(map(ocr_factors.load_word_list, args.word_lists))

    yt.config['proxy']['url'] = args.yt_cluster
    with yt.TempTable() as sorted_log_table:
        yt.run_sort(args.log_results, sorted_log_table, sort_by=['unique_name'])
        with yt.TempTable() as ocr_table:
            yt.run_map(map_input_unknown, args.input_unknown, ocr_table)
            yt.run_sort(ocr_table, ocr_table, sort_by=['unique_name'])
            reducer = ReduceBestAnswer(word_lists)
            # OCR rows come first so they get @table_index 0 in the reducer.
            reduce_inputs = [ocr_table, sorted_log_table]
            yt.run_reduce(
                reducer,
                reduce_inputs,
                args.output_table,
                reduce_by=['unique_name'],
                format=yt.YsonFormat(control_attributes_mode="row_fields"),
                yt_files=[yt.FilePath(args.formula_path, file_name='formula')],
                memory_limit=args.memory_limit,
            )


# Launch the pipeline only when this file is executed as a script.
if '__main__' == __name__:
    main()
