#! /usr/bin/env python
# Don't forget to run deploy.sh after the PR is merged.

import logging as log

from sandbox import sdk2
from sandbox.common.errors import TaskFailure
from sandbox.common.types import task as ctt
from sandbox.projects.common.binary_task import deprecated as binary_task
from sandbox.projects.tank.ShootViaTankapi import ShootViaTankapi
from sandbox.projects.tank.lunapark_client.client import LunaparkApiClient

# Repository path of the tank shooting config (LoadTestSetup.yaml) for this task.
SHOOT_CONFIG_ARC_PATH = (
    'sandbox/projects/market/contentApi/MarketAntifraudLoadTest/'
    'LoadTestSetup.yaml'
)

# Per-service settings, keyed by service name.
# Solomon settings that were left disabled in the original definition:
#   solomon_project_id: 'market-mstat'
#   solomon_service_name: 'market-mstat_antifraud_orders'
SERVICES = {
    'antifraud-orders': dict(
        name='antifraud-orders',
        nanny_service_prefix='testing_market_mstat_antifraud_orders_',
    ),
}


class MarketAntifraudLoadTest(binary_task.LastBinaryTaskRelease, sdk2.Task):
    """Run tank shootings and push the average imbalance RPS to Solomon.

    For every label in ``Parameters.labels`` the task shoots the config with
    the same index in ``Parameters.config_resources`` ``Parameters.repeats``
    times, averages the imbalance RPS taken from the Lunapark summaries and
    starts a ``PushToSolomon`` subtask with that average.

    NOTE(review): the flow relies on sdk2 re-entering ``on_execute`` after each
    ``sdk2.WaitTask`` and on ``memoize_stage`` skipping stages that were
    already passed -- confirm against the sdk2 documentation.
    """

    class Parameters(sdk2.Task.Parameters):
        ext_params = binary_task.binary_release_parameters(stable=True)
        config_resources = sdk2.parameters.List(label='Tank config resource id', default=[])
        labels = sdk2.parameters.List('List of solomon labels', default=[],
                                      description='Solomon metrics labels')
        repeats = sdk2.parameters.Integer('Amount of shoots', default_value=3)
        # get_imbalance_rps() reads these two parameters; they were never
        # declared before, so the zero-imbalance fallback could not work.
        # With the defaults (0) the fallback stays disabled.
        start_rps = sdk2.parameters.Integer(
            'Start RPS', default=0,
            description='If > 0, a shooting with imbalance RPS == 0 is counted as target RPS')
        target_rps = sdk2.parameters.Integer(
            'Target RPS', default=0,
            description='Value used instead of imbalance RPS == 0 when start RPS > 0')

    def on_execute(self):
        """Shoot every (label, config) pair and push the averaged result.

        The index of the label being processed is kept in
        ``Context.label_i``; it is advanced right before waiting for the push
        subtask, so the next entry continues with the following label.

        Raises:
            TaskFailure: no imbalance RPS was collected for a label.
            sdk2.WaitTask: to wait for the shooting / push subtasks.
        """
        log.info('Starting on_execute')
        self.lunaparkClient = LunaparkApiClient()
        self.target_service = 'antifraud-orders'
        labels = self.Parameters.labels
        configs = self.Parameters.config_resources
        with self.memoize_stage.init_iters:
            self.Context.label_i = 0
        while self.Context.label_i < len(labels):
            label = labels[self.Context.label_i]
            self.set_info('Starting shooting {}'.format(label))
            if self.Context.label_i >= len(configs):
                # A label without a matching config is skipped, not failed.
                log.warning('Amount of labels not match amount of configs')
                self.Context.label_i += 1
                continue
            config = configs[self.Context.label_i]
            with self.memoize_stage['Init rps list: ' + label]:
                self.Context.imbalance_rps_list = []
            for shoot_index in range(self.Parameters.repeats):
                self.shoot(config, shoot_index)
            self.set_info('Resuming service actions...')
            all_rps = self.Context.imbalance_rps_list
            self.set_info('Shootings finished. Imbalance RPS: {}'.format(all_rps))
            if not all_rps:
                # Fail explicitly instead of crashing with ZeroDivisionError
                # (e.g. when repeats <= 0).
                raise TaskFailure('No imbalance RPS collected for label {}'.format(label))
            # float() keeps the fractional part under Python 2 integer division.
            avg_imbalance_rps = float(sum(all_rps)) / len(all_rps)
            self.set_info('Average imbalance RPS: {}'.format(avg_imbalance_rps))
            prod_max_rps = avg_imbalance_rps
            self.set_info("Prod max RPS: {}".format(prod_max_rps))
            prod_instance_max_rps = int(avg_imbalance_rps)
            self.set_info("Prod instance max RPS: {}".format(prod_instance_max_rps))
            push_task_ids = [self.push_to_solomon(label, prod_max_rps)]
            # Advance before waiting so the next entry starts with the next label.
            self.Context.label_i += 1
            raise sdk2.WaitTask(push_task_ids, ctt.Status.Group.SUCCEED, wait_all=True, timeout=3600)

    def shoot(self, config, i):
        """Run shooting number ``i`` for ``config`` and record its imbalance RPS.

        The first stage starts the shooting subtask (and waits for it), the
        second one appends the imbalance RPS of that subtask to
        ``Context.imbalance_rps_list``.

        NOTE(review): stage names depend only on ``config`` and ``i``; if the
        same config resource is listed for two labels, the later label
        presumably reuses the memoized stages and collects no results --
        confirm.
        """
        with self.memoize_stage['Shoot {} {}'.format(config, i)]:
            self.set_info('Shooting ' + str(config) + ' ' + str(i))
            self.shoot_via_tank_api(config)
        with self.memoize_stage['Push results {} {}'.format(config, i)]:
            shoot_task = self.get_task(self.Context.last_shoot_task_id)
            log.info('shoot lunapark link: {}'.format(shoot_task.Parameters.lunapark_link))
            log.info('shoot lunapark job id: {}'.format(shoot_task.Parameters.lunapark_job_id))
            self.Context.imbalance_rps_list.append(self.get_imbalance_rps(shoot_task))

    def get_imbalance_rps(self, shoot_task):
        """Return the imbalance RPS of a finished shooting task.

        The value is read from the Lunapark summary of the job whose id is
        stored in ``shoot_task.Parameters.lunapark_job_id``. When the summary
        reports 0 and ``Parameters.start_rps`` is positive,
        ``Parameters.target_rps`` is returned instead.

        Raises:
            TaskFailure: the shooting task has no lunapark job id.
        """
        lunapark_job_id = shoot_task.Parameters.lunapark_job_id
        if not lunapark_job_id:
            raise TaskFailure('No lunapark job id')
        summary = self.lunaparkClient.get_summary(lunapark_job_id)
        log.info('shoot summary: {}'.format(summary))
        imbalance_rps = summary['imbalance_rps']
        if imbalance_rps == 0 and self.Parameters.start_rps > 0:
            self.set_info('!!WARNING!!: imbalance RPS == 0. Seems that shooting ends before '
                          'the service is unstable.')
            imbalance_rps = int(self.Parameters.target_rps)
        return imbalance_rps

    def shoot_via_tank_api(self, config):
        """Start a ``ShootViaTankapi`` subtask for ``config`` and wait for it.

        The subtask id is saved (as a string) in
        ``Context.last_shoot_task_id`` before waiting.

        Raises:
            sdk2.WaitTask: always, to wait for the shooting subtask.
        """
        log.info('starting shoot_via_tank_api')
        shoot_task = ShootViaTankapi(
            self,
            use_public_tanks=False,
            config_source='resource',
            tanks=['sas2-7111-25e-all-rcloud-tanks-30169.gencfg-c.yandex.net:30169',
                   'sas1-8786-a4e-all-rcloud-tanks-30169.gencfg-c.yandex.net:30169',
                   'sas1-0147-all-rcloud-tanks-30169.gencfg-c.yandex.net:30169',
                   'sas1-0021-all-rcloud-tanks-30169.gencfg-c.yandex.net:30169'],
            config_resource=config,
            ammo_source='in_config'
        )
        shoot_task.enqueue()
        log.debug('saving last_shoot_task_id={}'.format(shoot_task.id))
        self.Context.last_shoot_task_id = str(shoot_task.id)
        log.debug('saved last_shoot_task_id={}'.format(self.Context.last_shoot_task_id))
        raise sdk2.WaitTask([shoot_task.id], ctt.Status.Group.SUCCEED, wait_all=True,
                            timeout=3600)

    def push_to_solomon(self, sensor_label, value):
        """Enqueue a ``PushToSolomon`` subtask and return its id.

        ``value`` is converted to float and sent with ``sensor_label``.
        """
        # Imported here, as in the original code, rather than at module level.
        from sandbox.projects.market.checkout.PushToSolomon import PushToSolomon
        log.info('starting push_to_solomon')
        push_task = PushToSolomon(
            self,
            project_id='market-content-api',
            service_name='antifraud_orders',
            cluster_name='production',
            sensor_label=sensor_label,
            oauth_token_vault_key='CONTENT-API-SOLOMON-TOKEN',
            value=float(value)
        )
        push_task.enqueue()
        return push_task.id

    def get_task_resource(self, task_id, resource_type):
        """Return the first resource of ``resource_type`` produced by task ``task_id``.

        Raises:
            TaskFailure: no such resource was found.
        """
        resource = sdk2.Resource.find(resource_type=resource_type, task_id=task_id).first()
        if not resource:
            raise TaskFailure('resource {} of task {} not found'.format(resource_type, task_id))
        return resource

    def get_task(self, task_id):
        """Return the task with id ``task_id`` if it is in status SUCCESS.

        Raises:
            TaskFailure: no task with that id and status SUCCESS was found.
        """
        task = sdk2.Task.find(id=task_id, status=ctt.Status.SUCCESS).first()
        if not task:
            raise TaskFailure('Task id={} with status SUCCESS not found'.format(task_id))
        return task
