# -*- coding: utf-8 -*-

import base64
import logging
import random

from sandbox.sandboxsdk import task, parameters, errors, channel
from sandbox.common import rest
from sandbox.common.types.task import Status
from sandbox.projects.common import utils
from sandbox.projects.common.gencfg import api_client as gencfg_api

from sandbox.projects.LbGetSaasResponses import LbGetSaasResponses as getresp

from collections import defaultdict


def data_from_response(response_base64):
    """
        Extract the "data" attribute values from a serialized search report.

        ``response_base64`` is base64-decoded and parsed as a ``meta_pb2.TReport``
        message. Every ``GtaRelatedAttribute`` of every document in every group of
        every grouping is visited; the values of attributes whose key is "data"
        are collected and returned as a sorted list.

        Raises ``errors.SandboxTaskFailureError`` as soon as an attribute with key
        "s_q" or "s_i" is met (such responses are reported as an unsupported format).
    """
    # The protobuf module is imported at call time rather than at module load.
    from sandbox.projects.common.base_search_quality.tree import meta_pb2

    report = meta_pb2.TReport()
    report.ParseFromString(base64.b64decode(response_base64))

    def iter_attributes():
        # Flatten the grouping -> group -> document -> attribute hierarchy,
        # keeping the original traversal order.
        for grouping in report.Grouping:
            for group in grouping.Group:
                for document in group.Document:
                    for attribute in document.ArchiveInfo.GtaRelatedAttribute:
                        yield attribute

    values = []
    for attribute in iter_attributes():
        attr_key = attribute.Key
        if attr_key in ("s_q", "s_i"):
            raise errors.SandboxTaskFailureError('Unsupported format')
        if attr_key == "data":
            values.append(attribute.Value)

    # Sorting makes the result independent of the document order in the report.
    values.sort()
    return values


class LbCheckSaasConsistency(task.SandboxTask):
    """
        Checks that SaaS hosts from different groups (DCs) give consistent responses.
        SEARCH-2302

        For every given gencfg group one randomly chosen instance is shot by a
        child task of type ``getresp.type``. The per-query responses dumped by
        the children are then compared with each other and the task fails when
        the number of mismatching queries exceeds the configured limit.
    """

    type = 'LB_CHECK_SAAS_CONSISTENCY'

    cores = 1

    class GencfgGroupsParam(parameters.ListRepeater, parameters.SandboxStringParameter):
        name = 'gencfg_groups'
        description = 'Group names'
        required = True

    class MaxMismatchesParam(parameters.SandboxStringParameter):
        # Either an absolute number ("3") or a percentage of the requests
        # answered by all hosts ("5%"); see parse_results.
        name = 'max_mismatches'
        description = 'Maximum allowed mismatched responses'
        default_value = '0'

    class MaxAttemptsInGroupParam(parameters.SandboxIntegerParameter):
        # The task fails once this many instances of one group have failed.
        name = 'max_attempts_in_group'
        description = 'Retry count per group if host is not responding'
        default_value = 3

    input_parameters = (
        GencfgGroupsParam,
        MaxMismatchesParam,
        MaxAttemptsInGroupParam,
        getresp.InputQueries,
        getresp.NumQueriesParam,
        getresp.UseNoregRequests,
        getresp.UseRegionRequests,
        getresp.RegionDispatchConfig,
        getresp.GeoaC2P,
        getresp.SaaSServiceId,
        getresp.SaaSIndexPrefix,
        getresp.TimeoutParam,
        getresp.RetryCountParam,
        getresp.MaxErrorsParam,
        getresp.ThreadCountParam,
    )

    @property
    def footer(self):
        """
            Return the footer with comparison statistics.

            The statistics are taken from the task context; an empty list is
            returned until parse_results has stored them there.
        """
        if 'num_requests' in self.ctx:
            return [{
                'helperName': '',
                'content':
                    'Requests total: {num_requests}<br/>'
                    'Requests with responses from all hosts: {num_responses}<br/>'
                    'Requests with isomorphic responses: {num_matches}<br/>'
                    'Requests with non-empty isomorphic responses: {num_nonempty_matches}<br/>'
                    .format(**self.ctx)
            }]
        else:
            return []

    def on_execute(self):
        """
            Entry point of the task.

            First pass ('subtask_ids' is not in the context yet): fetch the
            instances of every group, start one child per group and wait for them.
            Later passes: restart children of the groups whose child has failed on
            another instance, or, when nothing has failed, compare the results.

            NOTE(review): presumably wait_all_tasks_completed suspends this task and
            on_execute is re-entered once the children finish - confirm against
            the Sandbox SDK.
        """
        if 'subtask_ids' in self.ctx:
            success_subtasks, failed_subtasks = self.get_failed_subtasks()
            if failed_subtasks:
                self.ctx['subtask_ids'] = success_subtasks + self.start_another_instances(failed_subtasks)
                self.wait_all_tasks_completed(self.ctx['subtask_ids'])
            # The candidate instance lists are not needed once all children are done.
            del self.ctx['good_instances']
            self.parse_results()
        else:
            self.ctx['subtask_info'] = {}
            self.ctx['good_instances'] = {}
            self.ctx['bad_instances'] = {}
            groups = utils.get_or_default(self.ctx, self.GencfgGroupsParam)
            if not groups:
                raise errors.SandboxTaskFailureError("No groups given")
            self.fetch_gencfg_groups(groups)
            self.ctx['subtask_ids'] = self.start_subtasks(groups)
            self.wait_all_tasks_completed(self.ctx['subtask_ids'])

    def fetch_gencfg_groups(self, groups):
        """
            Ask gencfg for the instances of every group in ``groups``.

            For each group the "hostname:port" strings are stored in
            ctx['good_instances'][group] and ctx['bad_instances'][group] is
            reset to an empty list.
        """
        gencfg = gencfg_api.GencfgApiClient('http://api.gencfg.yandex-team.ru')
        for group in groups:
            gencfg_answer = gencfg.get_group_instances(group)
            instances = [inst['hostname'] + ':' + str(inst['port']) for inst in gencfg_answer['instances']]
            logging.debug("instances for %s: %s", group, instances)
            self.ctx['good_instances'][group] = instances
            self.ctx['bad_instances'][group] = []

    def start_subtasks(self, groups):
        """
            Start one child task per group against a randomly chosen instance.

            The (group, instance) pair of every child is remembered in
            ctx['subtask_info'] under the child id.

            Returns the list of ids of the created children.
            Raises SandboxTaskFailureError when a group has no instances left.
        """
        # Choose all instances first so that nothing is started if some group is empty.
        selected_instances = []
        for group in groups:
            if not self.ctx['good_instances'][group]:
                raise errors.SandboxTaskFailureError('No instances for group ' + group)
            selected_instances.append((random.choice(self.ctx['good_instances'][group]), group))
        logging.debug("selected_instances: %s", selected_instances)

        ids = []
        for instance, group in selected_instances:
            subtask_descr = "shoot %s, %s" % (group, self.descr)
            # The shooting parameters are forwarded to the child unchanged;
            # only the target host differs between the children.
            subtask_ctx = {
                "notify_via": "",
                getresp.InputQueries.name: utils.get_or_default(self.ctx, getresp.InputQueries),
                getresp.NumQueriesParam.name: utils.get_or_default(self.ctx, getresp.NumQueriesParam),
                getresp.UseNoregRequests.name: utils.get_or_default(self.ctx, getresp.UseNoregRequests),
                getresp.UseRegionRequests.name: utils.get_or_default(self.ctx, getresp.UseRegionRequests),
                getresp.RegionDispatchConfig.name: utils.get_or_default(self.ctx, getresp.RegionDispatchConfig),
                getresp.GeoaC2P.name: utils.get_or_default(self.ctx, getresp.GeoaC2P),
                getresp.LingboostSaasHost.name: instance,
                getresp.SaaSServiceId.name: utils.get_or_default(self.ctx, getresp.SaaSServiceId),
                getresp.SaaSIndexPrefix.name: utils.get_or_default(self.ctx, getresp.SaaSIndexPrefix),
                getresp.TimeoutParam.name: utils.get_or_default(self.ctx, getresp.TimeoutParam),
                getresp.RetryCountParam.name: utils.get_or_default(self.ctx, getresp.RetryCountParam),
                getresp.MaxErrorsParam.name: utils.get_or_default(self.ctx, getresp.MaxErrorsParam),
                getresp.ThreadCountParam.name: utils.get_or_default(self.ctx, getresp.ThreadCountParam),
            }
            subtask_id = self.create_subtask(
                task_type=getresp.type,
                description=subtask_descr,
                input_parameters=subtask_ctx
            ).id
            ids.append(subtask_id)
            self.ctx['subtask_info'][subtask_id] = (group, instance)
        return ids

    def get_failed_subtasks(self):
        """
            Split the current children into succeeded and failed ones.

            Returns a (success_subtasks, failed_subtasks) pair of id lists.
            Raises SandboxTaskFailureError when some child is neither a failure
            nor in the SUCCESS status.
        """
        success_subtasks, failed_subtasks, broken_subtasks = [], [], []
        for task_id in self.ctx['subtask_ids']:
            # Named `subtask` so that the imported `task` module is not shadowed.
            subtask = channel.channel.sandbox.get_task(task_id)
            if subtask.is_failure():
                failed_subtasks.append(task_id)
            elif subtask.new_status == Status.SUCCESS:
                success_subtasks.append(task_id)
            else:
                broken_subtasks.append(task_id)
        if broken_subtasks:
            raise errors.SandboxTaskFailureError('Child tasks have not finished correctly: ' + str(broken_subtasks))
        return success_subtasks, failed_subtasks

    def start_another_instances(self, failed_subtasks):
        """
            Restart the failed children on other instances of the same groups.

            The instance of every failed child is moved from the good list of its
            group to the bad one, then new children are started for these groups.

            Returns the list of ids of the newly created children.
            Raises SandboxTaskFailureError when the number of bad instances in a
            group reaches MaxAttemptsInGroupParam.
        """
        failed_groups = []
        for task_id in failed_subtasks:
            group, bad_instance = self.ctx['subtask_info'][task_id]
            self.set_info('Warning: subtask for group {}, instance {} has failed'.format(group, bad_instance))
            self.ctx['good_instances'][group].remove(bad_instance)
            self.ctx['bad_instances'][group].append(bad_instance)
            if len(self.ctx['bad_instances'][group]) >= utils.get_or_default(self.ctx, self.MaxAttemptsInGroupParam):
                raise errors.SandboxTaskFailureError('Too many failed hosts in ' + group)
            failed_groups.append(group)
        return self.start_subtasks(failed_groups)

    def parse_results(self):
        """
            Compare the result dumps of all children and publish the statistics.

            The counters are stored in the task context (they are shown by
            ``footer``) and written to the log.

            Raises SandboxTaskFailureError when the number of requests answered
            by all hosts but with differing responses exceeds MaxMismatchesParam.
        """
        result_resources = []
        result_sources = []
        try:
            for subtask_id in self.ctx['subtask_ids']:
                subtask_ctx = rest.Client().task[subtask_id].context.read()
                subtask_result_filename = self.sync_resource(subtask_ctx['saasdump_id'])
                result_resources.append(open(subtask_result_filename, 'r'))
                result_sources.append(self.ctx['subtask_info'][subtask_id][0])
            (
                num_requests,
                num_responses,
                num_matches,
                num_nonempty_matches,
                source2error,
            ) = self._compare_responses(result_resources, result_sources)
        finally:
            # Close every dump that has been opened, even if the comparison failed.
            for resource in result_resources:
                resource.close()

        error_text = ''.join(
            'Error: majority vote decided {} errors of source {}\n'.format(error_count, bad_source)
            for bad_source, error_count in source2error.items()
        )
        if error_text:
            self.set_info(error_text + 'See log for details')
        self.ctx['num_requests'] = num_requests
        self.ctx['num_responses'] = num_responses
        self.ctx['num_matches'] = num_matches
        self.ctx['num_nonempty_matches'] = num_nonempty_matches
        logging.info(
            '{} requests total, {} with responses from all hosts, {} with isomorphic responses, {} with non-empty isomorphic responses'.format(
                num_requests,
                num_responses,
                num_matches,
                num_nonempty_matches
            )
        )
        # Surrounding whitespace is dropped so that a value like "5% " is still
        # recognised as a percentage.
        max_mismatches = utils.get_or_default(self.ctx, self.MaxMismatchesParam).strip()
        if max_mismatches.endswith('%'):
            # A percentage is taken of the requests answered by all hosts.
            max_mismatches = float(max_mismatches[:-1]) * num_responses / 100
        else:
            max_mismatches = float(max_mismatches)
        if num_responses - num_matches > max_mismatches:
            raise errors.SandboxTaskFailureError('Too many mismatched responses')

    def _compare_responses(self, result_resources, result_sources):
        """
            Walk the opened dumps in lockstep and count (mis)matching responses.

            ``result_resources`` are the opened dump files and ``result_sources``
            the group names in the same order. In every line, the text up to and
            including the last tab is the query key and the rest is the response;
            line N of every dump must start with the key of line N of the first
            dump. A response starting with '[' is counted as "not answered"
            (presumably an error marker written by the child - confirm).

            Returns (num_requests, num_responses, num_matches,
            num_nonempty_matches, source2error), where source2error maps a group
            name to the number of requests on which it alone disagreed.

            Raises SandboxTaskFailureError when the keys of the dumps differ.
        """
        num_requests = 0
        num_responses = 0
        num_matches = 0
        num_nonempty_matches = 0
        source2error = defaultdict(int)
        # NOTE(review): lines that other dumps have beyond the end of the first
        # dump are not looked at.
        for line in result_resources[0]:
            line = line.rstrip('\r\n')
            key = line[:line.rfind('\t') + 1]
            all_answered = True
            # Maps a tuple of sorted document data to the groups that returned it.
            bundle2source = defaultdict(list)
            for index, other in enumerate(result_resources):
                # Line ends are stripped from every dump, not only from the first one.
                otherline = line if index == 0 else other.readline().rstrip('\r\n')
                if not otherline.startswith(key):
                    raise errors.SandboxTaskFailureError('Inconsistent query keys between children')
                response_base64 = otherline[otherline.rfind('\t') + 1:]
                if response_base64.startswith('['):
                    all_answered = False
                else:
                    docdata = data_from_response(response_base64)
                    bundle2source[tuple(docdata)].append(result_sources[index])
            num_requests += 1
            if not all_answered:
                continue
            num_responses += 1
            if len(bundle2source) == 1:
                num_matches += 1
                # The only key is the document data shared by all the groups.
                if next(iter(bundle2source)):
                    num_nonempty_matches += 1
                continue
            logging.error('--- mismatched responses for request {}: ---'.format(key))
            for bundle, sources in bundle2source.items():
                logging.error('{} say(s) {}'.format(sources, bundle))
            bad_source = self._single_bad_source(bundle2source)
            if bad_source is not None:
                source2error[bad_source] += 1
        return num_requests, num_responses, num_matches, num_nonempty_matches, source2error

    def _single_bad_source(self, bundle2source):
        """
            Return the group that alone disagrees with all the others, or None.

            A group is blamed only when the responses fall into exactly two
            bundles: one returned by a single group and one returned by several.
        """
        lonely_sources = []
        shared_bundles = 0
        for sources in bundle2source.values():
            if len(sources) == 1:
                lonely_sources.append(sources[0])
            else:
                shared_bundles += 1
        if len(lonely_sources) == 1 and shared_bundles == 1:
            return lonely_sources[0]
        return None


# Module-level alias of the task class.
# NOTE(review): presumably looked up by the Sandbox task loader under this name - confirm.
__Task__ = LbCheckSaasConsistency
