#!/usr/bin/python
# -*- coding: utf-8 -*-

import argparse
import os
import json
import sys
import yt.wrapper as yt
from collections import Counter


class CheckFunction:
    """Base class for one check applied to a parsed Begemot answer.

    Subclasses must override ``get_label`` (the key under which the check's
    results are stored and reported) and ``check`` (the verdict for one answer).

    Args:
        required_true_rate: threshold copied into the final report as
            ``required``; presumably the minimum acceptable share of ``True``
            verdicts — TODO confirm with the report consumer. Defaults to 1.0.
    """

    def __init__(self, required_true_rate=1.0):
        self.required_true_rate = required_true_rate

    def get_label(self):
        """Return the label identifying this check. Must be overridden."""
        raise NotImplementedError()

    def check(self, begemot_answer):
        """Return the verdict of this check for one parsed answer. Must be overridden."""
        raise NotImplementedError()


class SampleCheckFunction(CheckFunction):
    """Example check: passes when the answer contains an item of type ``'wizard'``."""

    def get_label(self):
        """Key under which this check's results are stored and reported."""
        return 'sample_check'

    def check(self, begemot_answer):
        """Return True if some item of ``begemot_answer`` has ``type == 'wizard'``.

        Each item is queried with ``.get``, so ``begemot_answer`` is presumably
        the parsed JSON answer as a list of dicts — confirm against input data.
        """
        return any(item.get('type') == 'wizard' for item in begemot_answer)


class AllCheckers:
    """Registry of the checks run by this script.

    The mapper, the reducer and ``build_output`` all read ``self.checkers``,
    so a new check is registered by adding an instance to the list built here.
    """

    def __init__(self):
        # A fresh list per instance; nothing is shared between instances.
        registered = []
        registered.append(SampleCheckFunction())
        self.checkers = registered


class BegemotResponsesChecker(AllCheckers):
    """YT mapper that runs every registered checker over one Begemot answer.

    Each input row must have a JSON-encoded ``begemot_answer`` column. Each
    checker adds a column to the output row, named after its label. The value
    is either the stringified check result (``'True'``/``'False'``) or, if the
    check raised, ``'<ExceptionType>: <message>'``. Every output row gets
    ``reduce_key = 1``, so all rows land in the same reduce group.
    """

    def __call__(self, row):
        answer_row = dict()
        begemot_answer = json.loads(row['begemot_answer'])
        for checker in self.checkers:
            label = checker.get_label()
            try:
                answer_row[label] = str(checker.check(begemot_answer))
            except Exception as e:
                # Include the exception type: a bare str(e) may be empty (a useless
                # example) or literally 'True'/'False', which CheckerReducer would
                # then count as a check verdict instead of an exception.
                answer_row[label] = '{}: {}'.format(type(e).__name__, e)
        answer_row['reduce_key'] = 1
        yield answer_row


class CheckerReducer(AllCheckers):
    """YT reducer that folds the mapper's per-row verdicts into one summary row.

    For every checker label, each record's value is classified as passed
    (``'True'``), failed (``'False'``) or exception (anything else). The first
    exception value seen for a label is kept as an example.
    """

    def __call__(self, key, records):
        passed, failed, exceptions = Counter(), Counter(), Counter()
        first_exception = dict()
        true_str, false_str = str(True), str(False)

        for record in records:
            for checker in self.checkers:
                label = checker.get_label()
                verdict = record[label]
                if verdict == true_str:
                    bucket = passed
                elif verdict == false_str:
                    bucket = failed
                else:
                    bucket = exceptions
                    # Only the first exception per label is stored as the example.
                    first_exception.setdefault(label, verdict)
                bucket[label] += 1

        summary = dict()
        summary['passed'] = dict(passed)
        summary['failed'] = dict(failed)
        summary['exceptions'] = dict(exceptions)
        summary['exceptions_examples'] = first_exception
        yield summary


def run_checker(input_table, output_table, spec=None, yt_client=None):
    """Run all checkers over ``input_table`` as one synchronous YT map-reduce.

    ``BegemotResponsesChecker`` evaluates every checker on each row. Every
    mapped row shares ``reduce_key``, so ``CheckerReducer`` aggregates them
    into a summary written to ``output_table``.

    Args:
        input_table: YT path of the table with a ``begemot_answer`` column.
        output_table: YT path for the summary table.
        spec: operation spec; defaults to the module-level ``common_spec``.
        yt_client: YT client; defaults to the module-level ``client``.
            Both module-level defaults exist only when this file runs as a
            script, so pass them explicitly when importing this module.

    Returns:
        The result of ``yt.run_map_reduce`` (called with ``sync=True``).
    """
    # Fall back to the globals created in the ``__main__`` section.
    if spec is None:
        spec = common_spec
    if yt_client is None:
        yt_client = client
    return yt.run_map_reduce(
        BegemotResponsesChecker(),
        CheckerReducer(),
        source_table=input_table,
        destination_table=output_table,
        reduce_by=['reduce_key'],
        spec=spec,
        client=yt_client,
        sync=True,
    )

def build_output(output_table, yt_client=None):
    """Turn the reducer's summary row into a per-checker report.

    Args:
        output_table: YT path of the table written by ``run_checker``.
        yt_client: YT client to read with; defaults to the module-level
            ``client`` created in the ``__main__`` section.

    Returns:
        dict mapping every checker label to
        ``{'passed': int, 'failed': int, 'exceptions': int, 'required': float}``.
        Counts come from the first row of ``output_table``, because the reducer
        writes a single row. If the table is empty (e.g. no input rows), all
        counts are 0. Previously this case returned None.

    For each checker with exceptions, one example is written to stderr.
    """
    if yt_client is None:
        yt_client = client
    checkers = AllCheckers().checkers

    rows = yt.read_table(output_table, client=yt_client, format='json')
    row = next(iter(rows), None)  # Expecting 1 row in output_table
    if row is None:
        row = {'passed': {}, 'failed': {}, 'exceptions': {}, 'exceptions_examples': {}}

    output = {}
    for checker in checkers:
        label = checker.get_label()
        exceptions_count = row['exceptions'].get(label, 0)
        output[label] = {
            'passed': row['passed'].get(label, 0),
            'failed': row['failed'].get(label, 0),
            'exceptions': exceptions_count,
            'required': float(checker.required_true_rate),
        }
        if exceptions_count:
            sys.stderr.write('Checker {} exception example: {}\n'.format(
                label, row['exceptions_examples'][label]))
    return output

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Checks begemot answers in YT with checkers defined in AllCheckers class'
    )
    parser.add_argument('--answers', dest='answers', help='Begemot answers', required=True)
    parser.add_argument('--output_path', dest='output_path', help='Intermediate tables path', required=True)
    parser.add_argument('--yt_proxy', dest='yt_proxy', required=True)
    parser.add_argument('--yt_pool', dest='yt_pool', required=True)
    args = parser.parse_args()

    # Report a missing token as a usage error rather than a KeyError traceback.
    token = os.environ.get('YT_TOKEN')
    if not token:
        parser.error('YT_TOKEN environment variable is not set')

    # ``client`` and ``common_spec`` must stay module-level globals:
    # run_checker() and build_output() read them by name.
    client = yt.YtClient(
        args.yt_proxy,
        token,
        config=dict(
            pickling={
                # YT pickling module filter (presumably selects which loaded
                # modules are shipped with the job). Excludes .so extensions,
                # hashlib, and modules without a real file. getattr(..., None)
                # matters because namespace packages may have __file__ = None,
                # and None.endswith(...) would raise AttributeError.
                'module_filter': lambda lib: (
                    bool(getattr(lib, '__file__', None))
                    and not lib.__file__.endswith('.so')
                    and 'hashlib' not in getattr(lib, '__name__', '')
                ),
                'force_using_py_instead_of_pyc': True,
            },
            pool=args.yt_pool,
        ),
    )
    # Fail the whole operation on the first failed job.
    common_spec = {'max_failed_job_count': 0}

    out_table = yt.ypath_join(args.output_path, 'out_table')
    run_checker(args.answers, out_table)
    output = build_output(out_table)
    print(json.dumps(output))
