# -*- coding: utf-8 -*-

import os
import csv
import requests

from collections import OrderedDict
import jinja2
from sandbox import sdk2

from sandbox.projects.common import utils2
from sandbox.projects.websearch.performance_report.resources import LatencyDiffHtml

# Regexp handed to the StrictString date parameters below. It matches text that
# begins with "YYYY-MM-DD HH:MM:SS"; there is no end anchor in the pattern itself.
DATE_PATTERN = r'^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}'


def isnumeric(s):
    """Return True if ``float(s)`` succeeds, False otherwise.

    Used as a template helper to tell numeric cells from textual ones.
    Values that ``float`` rejects with ``TypeError`` (e.g. ``None`` or a
    list) are reported as non-numeric instead of letting the error escape.
    """
    try:
        float(s)
    except (TypeError, ValueError):
        return False
    return True


class CalculateAppHostLatencyDiff(sdk2.Task):
    """
        Build a top of sources ranked by latency increase.

        For every selected data center and graph the task downloads the
        Statface report as CSV for two date ranges, averages each latency
        field per source, computes the relative difference between the two
        ranges and renders the result as an HTML report resource.
    """

    class Parameters(sdk2.Task.Parameters):
        # Data centers to iterate over; each checked value is sent as the `dc` request field.
        with sdk2.parameters.CheckGroup('Data centers') as dc:
            dc.values.ALL = dc.Value('ALL', checked=True)
            dc.values.man = 'man'
            dc.values.sas = 'sas'
            dc.values.vla = 'vla'
        # Graph names; each one is sent as the `graph` request field.
        graph = sdk2.parameters.List('List of graphs')
        # Report fields requested via `_incl_fields`.
        with sdk2.parameters.CheckGroup('Fields') as fields:
            fields.values.p50 = fields.Value('p50', checked=True)
            fields.values.p95 = 'p95'
            fields.values.p99 = 'p99'
        # The two date ranges being compared (first = baseline, second = candidate).
        with sdk2.parameters.Group('Dates'):
            first_date_min = sdk2.parameters.StrictString('First range start', regexp=DATE_PATTERN)
            first_date_max = sdk2.parameters.StrictString('First range end', regexp=DATE_PATTERN)
            second_date_min = sdk2.parameters.StrictString('Second range start', regexp=DATE_PATTERN)
            second_date_max = sdk2.parameters.StrictString('Second range end', regexp=DATE_PATTERN)
        # Rows whose relative difference exceeds +/- this percentage get a CSS class.
        diff_threshold = sdk2.parameters.Float('Max allowed diff, %', default=5)
        statface_report_path = sdk2.parameters.String(
            'Statface report path',
            default='Yandex/Others/AppHostPerformanceReportFast',
            required=True,
        )

    def on_execute(self):
        """Download both ranges for every (dc, graph) pair and publish the HTML diff report."""
        statface_token = sdk2.Vault.data(self.Parameters.owner, name='STATFACE_TOKEN')
        headers = {'Authorization': 'OAuth {}'.format(statface_token)}
        url = 'https://stat.yandex-team.ru/_api/statreport/csv/{}'.format(self.Parameters.statface_report_path)
        # Fields shared by every request; dc/graph/dates are filled in per iteration.
        params = {
            '_incl_fields': self.Parameters.fields,
            'scale': 'h',
            '_allow_transpose': 0,
        }
        css_classes = OrderedDict()
        css_classes['good'] = {'background-color': '#afa'}
        css_classes['bad'] = {'background-color': '#faa'}
        tabs = OrderedDict()
        tabs['field'] = self.Parameters.fields
        tabs['dc'] = self.Parameters.dc
        tabs['graph'] = self.Parameters.graph
        table_labels = ['Source', 'First latency', 'Second latency', 'Relative difference, %']
        tables = []

        def classifier(row):
            """Map a merged-table row to a CSS class name, or None when within the threshold."""
            diff = float(row[3])
            if diff > self.Parameters.diff_threshold:
                return 'bad'
            if diff < -self.Parameters.diff_threshold:
                return 'good'
            return None

        for dc in self.Parameters.dc:
            for graph in self.Parameters.graph:
                self.set_info('dc = {}, graph = {}'.format(dc, graph))
                params['dc'] = dc
                params['graph'] = graph
                self.set_info('Download first range data')
                first_csv = self._download_csv(
                    url, headers, params,
                    self.Parameters.first_date_min, self.Parameters.first_date_max,
                )
                self.set_info('Download second range data')
                second_csv = self._download_csv(
                    url, headers, params,
                    self.Parameters.second_date_min, self.Parameters.second_date_max,
                )
                self.set_info('Calculate average latencies')
                first_tables = self._calculate_average_latencies(first_csv)
                second_tables = self._calculate_average_latencies(second_csv)
                self.set_info('Merge tables')
                merged_tables = self._merge_and_sort(first_tables, second_tables)
                for field, table in merged_tables.items():
                    tables.append({
                        'keys': {'field': field, 'dc': dc, 'graph': graph},
                        'table': table,
                        'labels': table_labels,
                        'classifier': classifier,
                    })
                self.set_info('Success!')
        self.set_info('Build report')
        report = self._build_report(tabs, tables, css_classes)
        self._log_results(report)
        self.set_info('All done!')

    def _download_csv(self, url, headers, base_params, date_min, date_max):
        """Fetch the report CSV for one date range and return the raw response body.

        `base_params` is copied, so the caller's dict is left untouched.
        Raises whatever `raise_for_status()` raises on a non-success response.
        """
        params = dict(base_params, date_min=date_min, date_max=date_max)
        # NOTE(review): the fields are passed via `data=` (request body) rather than
        # `params=` (query string), and no timeout is set. Kept exactly as the task
        # has always done it -- confirm against the Statface API before changing.
        response = requests.get(url, data=params, headers=headers)
        response.raise_for_status()
        return response.content

    def _log_results(self, results):
        """Write the rendered report to results.html, register it as a resource and link it."""
        results_path = os.path.abspath('results.html')
        # NOTE(review): the file is opened in plain 'w' mode; presumably `results`
        # is ASCII-safe -- verify if the template ever emits non-ASCII text.
        with open(results_path, 'w') as results_file:
            results_file.write(results)
        results_resource_id = LatencyDiffHtml(
            self,
            self.Parameters.description,
            results_path,
            ttl=90, arch='any',
        ).id
        self.Parameters.description += '\nResults: {}'.format(
            utils2.resource_redirect_link(results_resource_id, 'link')
        )

    def _calculate_average_latencies(self, csv_table):
        """Average every latency column of a ';'-separated CSV per source.

        The first header cell must be the date label and the second the source
        label; all remaining columns are treated as latency fields. Empty cells
        are ignored. Rows with fewer than two cells are skipped, and cells
        beyond the header width are ignored.

        Returns ``{field_label: {source: average}}``; a field with no values
        maps to an empty dict.

        Raises ValueError when the table is empty or the header layout is
        unexpected.
        """
        reader = csv.reader(csv_table.splitlines(), delimiter=';')
        labels = next(reader, None)
        if labels is None:
            raise ValueError('CSV table is empty')
        # Header names are compared as raw bytes: these are the cp1251 encodings of
        # the Russian words for "Date" and "Source".
        date_label = '\xc4\xe0\xf2\xe0'
        source_label = '\xc8\xf1\xf2\xee\xf7\xed\xe8\xea'
        if not labels or labels[0] != date_label:
            raise ValueError('Date is not in the first column')
        if len(labels) < 2 or labels[1] != source_label:
            raise ValueError('Source is not in the second column')

        field_labels = labels[2:]
        sum_of_latencies = {label: {} for label in field_labels}
        total_entries = {label: {} for label in field_labels}
        for row in reader:
            if len(row) < 2:
                # Blank or truncated line: there is no source to attribute values to.
                continue
            source = row[1]  # row[0] is the date, which is not needed for averaging
            for label, value in zip(field_labels, row[2:]):
                if not value:
                    continue
                sums = sum_of_latencies[label]
                counts = total_entries[label]
                sums[source] = sums.get(source, 0) + float(value)
                counts[source] = counts.get(source, 0) + 1

        average_latencies = {label: {} for label in field_labels}
        for label, source_entries in total_entries.items():
            for source, entries in source_entries.items():
                average_latencies[label][source] = sum_of_latencies[label][source] / entries
        return average_latencies

    def _merge_and_sort(self, first_tables, second_tables):
        """Join per-field averages of both ranges and sort by relative difference.

        Only sources present in both ranges are kept. Sources whose first-range
        average is zero are skipped because the relative difference is
        undefined for them. A field missing from `second_tables` produces an
        empty table.

        Returns ``{field: [[source, first, second, relative_diff_percent], ...]}``
        with numbers formatted to one decimal place and rows ordered from the
        largest increase to the largest decrease.
        """
        merged_tables = {}
        for field, first_table in first_tables.items():
            second_table = second_tables.get(field, {})
            merged_table = []
            for source, first_latency in first_table.items():
                if source not in second_table:
                    continue
                if first_latency == 0:
                    # Cannot express a change relative to a zero baseline.
                    continue
                second_latency = second_table[source]
                relative_diff = (second_latency - first_latency) / first_latency
                merged_table.append([
                    source,
                    '{:.1f}'.format(first_latency),
                    '{:.1f}'.format(second_latency),
                    '{:.1f}'.format(relative_diff * 100),
                ])
            # Sort on the displayed (rounded) percentage, biggest increase first.
            merged_table.sort(key=lambda row: float(row[3]), reverse=True)
            merged_tables[field] = merged_table
        return merged_tables

    def _build_report(self, tabs_data, tables, custom_css_classes):
        """Render template.html (located next to this module) and return the HTML."""
        jinja_template = jinja2.Environment(
            loader=jinja2.FileSystemLoader(os.path.dirname(__file__)),
        ).get_template('template.html')
        html_report = jinja_template.render(
            tabs_data=tabs_data,
            tables=tables,
            custom_css_classes=custom_css_classes,
            isnumeric=isnumeric,  # TODO: find a way to stop passing isnumeric into the template
        )
        return html_report
