# -*- coding: utf-8 -*-

import json
import logging
import math
import numpy as np
import os
import tempfile

from collections import defaultdict, deque

from sandbox import sdk2
from sandbox.projects.websearch.begemot.tasks.GetBegemotResponses import GetBegemotResponses
from sandbox.projects.websearch.begemot import parameters as bp
from sandbox.sdk2.helpers import subprocess


def get_template(filename):
    """Load a template stored next to this module and return it as text.

    The file is read in binary mode and decoded as UTF-8, so line endings are
    returned exactly as stored on disk.
    """
    template_path = os.path.join(os.path.dirname(__file__), filename)
    with open(template_path, 'rb') as stream:
        raw_bytes = stream.read()
    return raw_bytes.decode('utf-8')


def render_table_row(row):
    """Return ``row`` rendered as one HTML ``<tr>`` with a ``<td>`` per value.

    Values are formatted with ``str.format``; no HTML escaping is applied.
    """
    cells = ''.join(['<td>{}</td>'.format(value) for value in row])
    return '<tr>{}</tr>'.format(cells)


def render_table(header, rows):
    """Render an HTML table using the ``table.html`` template next to this module.

    ``header`` is rendered as the first row; ``rows`` is an iterable of row
    iterables. Both are substituted into the template's ``{header}`` and
    ``{rows}`` placeholders.
    """
    template = get_template('table.html')
    header_html = render_table_row(header)
    rows_html = ''.join(render_table_row(row) for row in rows)
    return template.format(header=header_html, rows=rows_html)


class GetBegemotResponsesPerf(GetBegemotResponses):
    """
    Get begemot responses with perf calculation.

    Overrides ``_run_begemot`` to send the input queries ``Parameters.iterations``
    times, then runs ``begemot_evlogstat`` over ``begemot.evlog`` to collect
    per-rule timing statistics and the critical path of rule dependencies.
    Results are stored in the context and rendered by the report/header methods.
    """
    class Context(sdk2.Context):
        critical_path_metrics = {}  # summary built by _build_critical_path_metrics
        bgschema = {}  # parsed output of `begemot --print-bgschema`
        statistics = {}  # parsed begemot_evlogstat output; {} when it could not be parsed
        batch_requests = 0  # number of non-empty query lines sent per iteration

    class Parameters(GetBegemotResponses.Parameters):
        begemot_evlogstat = bp.BegemotEvlogstatBinaryResource(required=True)
        iterations = sdk2.parameters.Integer('Number of iterations', default=3)
        no_cache = GetBegemotResponses.Parameters.no_cache(default=True)

    def _parse_evlog(self):
        """
        Run begemot_evlogstat over ``begemot.evlog`` and store its JSON output.

        The bgschema is dumped to a temporary file for evlogstat, and
        ``Context.batch_requests`` is passed as ``--start-frame``. Parsing is
        best effort: if the output is not valid JSON, ``Context.statistics``
        is set to ``{}`` and the error is logged.
        """
        begemot_evlogstat = str(sdk2.ResourceData(self.Parameters.begemot_evlogstat).path)

        self.Context.bgschema = self._get_bgschema()
        # delete=False must be given to the constructor: the file has to survive
        # being closed so evlogstat can open it by name. It is removed below.
        with tempfile.NamedTemporaryFile(mode='w', delete=False) as bgschema_file:
            bgschema_file.write(json.dumps(self.Context.bgschema))
        bgschema_path = bgschema_file.name

        try:
            with open('begemot.evlog', 'r') as log_stream, sdk2.helpers.ProcessLog(self, logger='evlogstat') as pl:
                proc = subprocess.Popen(
                    (
                        begemot_evlogstat,
                        '--bgschema', bgschema_path,
                        '--start-frame', str(self.Context.batch_requests),
                    ),
                    stdin=log_stream,
                    stdout=subprocess.PIPE,
                    stderr=pl.stderr,
                )
                # communicate() drains stdout and waits, so the child is reaped.
                output, _ = proc.communicate()
        finally:
            os.remove(bgschema_path)

        if proc.returncode:
            logging.warning('begemot_evlogstat exited with code %s', proc.returncode)
        try:
            self.Context.statistics = json.loads(output)
        except ValueError:
            # JSON decode errors (and bytes decode errors) are ValueError subclasses.
            logging.exception('Failed to parse begemot_evlogstat output')
            self.Context.statistics = {}

    def _get_bgschema(self):
        """
        Return the parsed JSON printed by ``begemot --print-bgschema``.

        :raises ValueError: if the output is not valid JSON.
        """
        begemot_path = str(sdk2.ResourceData(self.Parameters.begemot_binary).path)

        with sdk2.helpers.ProcessLog(self, logger='bgschema') as pl:
            proc = subprocess.Popen((begemot_path, '--print-bgschema'), stdout=subprocess.PIPE, stderr=pl.stderr)
            # communicate() reads all output and waits, so the child is reaped.
            output, _ = proc.communicate()
        if proc.returncode:
            logging.warning('begemot --print-bgschema exited with code %s', proc.returncode)
        return json.loads(output)

    def _build_critical_path_metrics(self, target_metric):
        """
        Fill ``Context.critical_path_metrics`` from ``Context.statistics``.

        The critical path starts at the rule with the largest ``CriticalTime``
        and follows ``CriticalDependency`` links until a rule has none.
        Stored keys:
          * ``_longest_path``: rule names, starting from the most critical rule;
          * ``perfect_parallel_metric``: the largest ``CriticalTime`` (0.0 if none);
          * ``current_parallel_metric``: ``FrameStat[target_metric]``;
          * ``metric``: sum of ``target_metric`` over all rules in ``RuleStats``.

        :param target_metric: key looked up in ``FrameStat`` and each ``RuleStats`` entry.
        """
        statistics = self.Context.statistics
        critical_info = statistics.get('CriticalInfo') or {}

        crit_path = []
        perfect_parallel_metric = 0.0
        if critical_info:
            # Ties resolve to the first key in iteration order.
            crit_key = max(critical_info, key=lambda rule: critical_info[rule]['CriticalTime'])
            perfect_parallel_metric = critical_info[crit_key]['CriticalTime']
            # Guard against a dependency cycle in malformed input, which would
            # otherwise loop forever.
            visited = set()
            while crit_key is not None and crit_key not in visited:
                visited.add(crit_key)
                crit_path.append(crit_key)
                crit_key = critical_info[crit_key].get('CriticalDependency')

        self.Context.critical_path_metrics = {
            '_longest_path': crit_path,
            'perfect_parallel_metric': perfect_parallel_metric,
            'current_parallel_metric': statistics['FrameStat'][target_metric],
            'metric': sum(metrics[target_metric] for metrics in statistics['RuleStats'].values()),
        }

    @sdk2.report(title='Rules metrics', label='_rules_metrics_report')
    def _rules_metrics_report(self):
        """Render per-rule timing statistics as an HTML table, sorted by mean (descending)."""
        # statistics is {} when evlogstat output could not be parsed; render an empty table.
        rules_metrics = self.Context.statistics.get('RuleStats', {})
        metrics = ('mean', 'median', 'sqrt_variance', 'amount', 'quantile_95', 'quantile_90', 'quantile_75')

        header = ('Rules',) + metrics
        # NOTE(review): 'Quantility' is presumably a sequence indexed by percentile — confirm.
        rows = [
            (
                rule,
                rule_metrics['Average'],
                rule_metrics['Quantility'][50],
                rule_metrics['SqrtVariance'],
                rule_metrics['Amount'],
                rule_metrics['Quantility'][95],
                rule_metrics['Quantility'][90],
                rule_metrics['Quantility'][75],
            )
            for rule, rule_metrics in rules_metrics.items()
        ]
        rows.sort(key=lambda row: row[1], reverse=True)
        return render_table(header, rows)

    @sdk2.report(title='Critical path', label='_critical_path_report')
    def _critical_path_report(self):
        """Render the critical path rules with their rounded ``CriticalTime`` as an HTML table."""
        critical_path = self.Context.critical_path_metrics.get('_longest_path', ())

        header = ('Rule', 'Critical path mean')
        rows = []
        if critical_path:
            critical_info = self.Context.statistics['CriticalInfo']
            rows = [(rule, round(critical_info[rule]['CriticalTime'])) for rule in critical_path]
        return render_table(header, rows)

    @sdk2.header()
    def _header(self):
        """
        Render the task header from ``header.html``.

        It shows the perfect and current parallel metrics and their share (%)
        of the summed per-rule metric. A zero sum is treated as 1.
        """
        perfect_parallel_metric = self.Context.critical_path_metrics.get('perfect_parallel_metric', 0.0)
        current_parallel_metric = self.Context.critical_path_metrics.get('current_parallel_metric', 0.0)
        metric = self.Context.critical_path_metrics.get('metric', 0.0)
        denominator = float(metric or 1)  # avoid division by zero when no stats were collected
        return get_template('header.html').format(
            perfect_parallelism=perfect_parallel_metric,
            perfect_parallelism_prc=round(perfect_parallel_metric * 100 / denominator, 2),
            current_parallelism=current_parallel_metric,
            current_parallelism_prc=round(current_parallel_metric * 100 / denominator, 2),
            rules_summary=metric,
        )

    def _build_metrics(self):
        """Parse the event log and, if statistics were obtained, compute critical path metrics."""
        self._parse_evlog()
        if not self.Context.statistics:
            return
        self._build_critical_path_metrics('Average')

    def _run_begemot(self, args, input_file, output_file, err_file):
        """
        Run begemot, feeding it the input queries ``Parameters.iterations`` times.

        :param args: begemot command line.
        :param input_file: file object with newline-separated queries.
        :param output_file: file object receiving begemot stdout.
        :param err_file: file object receiving begemot stderr.
        :raises sdk2.helpers.ProcessLog.CalledProcessError: if begemot exits with a non-zero code.
        """
        # Read the input before spawning begemot so a read failure cannot leave
        # the process running.
        queries = input_file.read()
        if queries and queries[-1] != '\n':
            queries = queries + '\n'
        # Passed to evlogstat as --start-frame in _parse_evlog (presumably to
        # skip the first, warm-up pass over the queries).
        self.Context.batch_requests = sum(1 for line in queries.split('\n') if line)

        begemot_process = subprocess.Popen(args, stdin=subprocess.PIPE, stdout=output_file, stderr=err_file)
        try:
            for _ in range(int(self.Parameters.iterations)):
                begemot_process.stdin.write(queries)
                begemot_process.stdin.flush()
            begemot_process.stdin.close()
        except IOError:
            # Writing failed (most likely begemot closed stdin or died); stop it and
            # let the exit code decide whether to fail.
            logging.exception('Failed to feed queries to begemot')
            begemot_process.kill()
        code = begemot_process.wait()
        if code:
            raise sdk2.helpers.ProcessLog.CalledProcessError(code, args, log_resource=self.log_resource)

    def on_execute(self):
        """Run the parent task, then build the perf metrics from the event log."""
        super(GetBegemotResponsesPerf, self).on_execute()
        self._build_metrics()
