# -*- coding: utf-8 -*-

import os
import copy
import logging
from collections import OrderedDict

import jinja2
import sandbox.sandboxsdk.task as sdk_task
from sandbox.sandboxsdk.channel import channel
from sandbox.sandboxsdk.errors import SandboxTaskFailureError
from sandbox.sandboxsdk.parameters import SandboxStringParameter
from sandbox.projects.common.geosearch import task as addrs_task
from sandbox.projects.AddrsTestBasesearchPerformance import AddrsTestBasesearchPerformance


class ShardMap(SandboxStringParameter):
    """Required input parameter holding the id of the shardmap resource to test."""

    required = True
    description = 'Shardmap to test'
    name = 'shardmap'


class AddrsShardedBasesearchPerformance(sdk_task.SandboxTask):
    """Run a basesearch performance test for every shard listed in a shardmap.

    First execution: the shardmap resource is parsed, one
    AddrsTestBasesearchPerformance subtask is spawned per shard rbtorrent, and
    the task waits for the subtasks.  Next execution (detected by the
    ``_PERFORMANCE_TASKS`` context key): each subtask's statistics are
    collected into ``ctx['result']``, which ``footer`` renders.
    """

    type = 'ADDRS_SHARDED_BASESEARCH_PERFORMANCE'
    basesearch_params = addrs_task.AddrsBasesearchTask.basesearch_common_parameters
    # Every subtask receives its own shard, so the child task's
    # 'shard_rbtorrent' parameter is not exposed on this task.
    database_param = [p for p in basesearch_params if 'shard_rbtorrent' in p.name]
    child_parameters = [p for p in AddrsTestBasesearchPerformance.input_parameters if p not in database_param]
    # De-duplicate parameters by name while keeping their order.
    input_parameters = OrderedDict([(p.name, p) for p in child_parameters]).values()
    input_parameters.append(ShardMap)

    @property
    def footer(self):
        """Render ``footer.html`` (next to this module) with ``ctx['result']`` if present."""
        template_path = os.path.dirname(os.path.abspath(__file__))
        env = jinja2.Environment(loader=jinja2.FileSystemLoader(template_path), extensions=['jinja2.ext.do'])
        if 'result' in self.ctx:
            data_to_render = {'result': self.ctx['result']}
        else:
            data_to_render = {}
        return env.get_template("footer.html").render(data_to_render)

    def parse_shardmap(self):
        """Return the shard rbtorrents listed in the shardmap resource.

        The resource whose id is stored under the ``shardmap`` parameter is
        synced locally.  For each non-blank line, the second
        whitespace-separated field is taken, and everything from its first
        '(' onwards is dropped.

        Raises:
            SandboxTaskFailureError: if a non-blank line has fewer than two fields.
        """
        shardmap_id = self.ctx.get(ShardMap.name)
        shardmap_file = self.sync_resource(shardmap_id)
        rbtorrents = []
        with open(shardmap_file) as f:
            for line_no, line in enumerate(f, 1):
                fields = line.split()
                if not fields:
                    # Blank lines (e.g. a trailing newline) describe no shard.
                    continue
                if len(fields) < 2:
                    raise SandboxTaskFailureError(
                        'Malformed shardmap line %d: %r' % (line_no, line))
                rbtorrents.append(fields[1].split('(')[0])
        return rbtorrents

    def create_performance_task(self, rbtorrent):
        """Create one AddrsTestBasesearchPerformance subtask for ``rbtorrent``.

        The subtask gets a deep copy of this task's context, with
        ``shard_rbtorrent`` overridden.
        """
        input_parameters = copy.deepcopy(self.ctx)
        input_parameters['shard_rbtorrent'] = rbtorrent
        return self.create_subtask(task_type=AddrsTestBasesearchPerformance.type,
                                   description='Testing performance on shard %s' % rbtorrent,
                                   input_parameters=input_parameters)

    def spawn_performance_task(self, rbtorrents):
        """Create a subtask per rbtorrent and remember their ids in the context.

        Returns:
            The list of created subtasks.
        """
        subtasks = [self.create_performance_task(rbtorrent) for rbtorrent in rbtorrents]
        # The presence of this key marks that subtasks were already spawned.
        self.ctx['_PERFORMANCE_TASKS'] = [str(s.id) for s in subtasks]
        return subtasks

    def find_worst_params(self, performance_task):
        """Aggregate the statistics reported by one finished performance subtask.

        Returns:
            dict with keys 'dumper.rps', 'latency_0.95', 'latency_0.99'
            (aggregated over the subtask's ``results`` sessions),
            'memory_rss', 'memory_vsz' (maxima of the subtask's lists) and
            'task_id'.

        Raises:
            SandboxTaskFailureError: if the subtask reported no sessions.
        """
        sessions = performance_task.ctx.get('results')
        if not sessions:
            raise SandboxTaskFailureError(
                'Performance task %s has no results' % performance_task.id)
        # NOTE(review): max rps / min latency pick the *best* session, despite
        # the method name; kept as-is because the footer template may rely on
        # it. Confirm the intended aggregation.
        return {
            'dumper.rps': max([session.get('dumper.rps') for session in sessions]),
            'latency_0.95': min([session.get('latency_0.95') for session in sessions]),
            'latency_0.99': min([session.get('latency_0.99') for session in sessions]),
            'memory_rss': max(performance_task.ctx.get('memory_rss')),
            'memory_vsz': max(performance_task.ctx.get('memory_vsz')),
            'task_id': performance_task.id,
        }

    def on_execute(self):
        """Spawn per-shard subtasks, or collect their statistics once they finished.

        Raises:
            SandboxTaskFailureError: if the shardmap lists no shards, or a
                subtask ended in FAILURE, EXCEPTION or TIMEOUT.
        """
        bad_statuses = [self.Status.FAILURE,
                        self.Status.EXCEPTION,
                        self.Status.TIMEOUT]
        if '_PERFORMANCE_TASKS' not in self.ctx:
            # The shardmap is only needed to spawn the subtasks, so it is not
            # re-synced on the collecting run.
            rbtorrents = self.parse_shardmap()
            if not rbtorrents:
                raise SandboxTaskFailureError('Shardmap contains no shards')
            self.spawn_performance_task(rbtorrents)
            self.wait_tasks(
                self.list_subtasks(load=False),
                self.Status.Group.SUCCEED +
                self.Status.Group.FAIL_ON_ANY_ERROR +
                self.Status.Group.SCHEDULER_FAILURE +
                self.Status.Group.WAIT,
                wait_all=True,
                state='Waiting for subtasks to complete')
        else:
            stats = []
            for performance_task_id in self.ctx['_PERFORMANCE_TASKS']:
                performance_task = channel.sandbox.get_task(performance_task_id)
                logging.info('%s status is %s',
                             performance_task_id,
                             performance_task.new_status)
                if performance_task.new_status in bad_statuses:
                    raise SandboxTaskFailureError('Performance test failed')
                task_stats = self.find_worst_params(performance_task)
                logging.info('Statistics for performance task: %s', task_stats)
                stats.append(task_stats)
            self.ctx['result'] = stats


# Export the task class defined in this module, not the imported per-shard
# child task it spawns.
__Task__ = AddrsShardedBasesearchPerformance
