from os.path import join
from threading import Thread
import Queue
import json
import logging
import os
import random
import time
import traceback

from sandbox import sdk2
from sandbox.sandboxsdk.environments import PipEnvironment
from sandbox.sandboxsdk.paths import get_logs_folder
import sandbox.common.types.task as ctt

from sandbox.projects.vins.common.resources import (VinsPackage, MegamindHelperBinaries, MegamindRequests, MegamindPerformanceDiff)
import sandbox.projects.vins.common.constants as consts
import sandbox.projects.vins.common.engine as engine


class MegamindComparePerformance(sdk2.Task):
    ''' Compare Megamind performance on various inputs on different revisions '''

    class Requirements(sdk2.Task.Requirements):
        # Host/environment requirements of the task. Apart from the pip
        # environment, every value is taken from the shared `consts` module.

        # PyYAML is requested through PipEnvironment; it is not imported in
        # this file (presumably needed by a helper module - TODO confirm).
        environments = [
            PipEnvironment('PyYAML')
        ]

        # Warning - disable this on dev machines
        client_tags = consts.CLIENT_TAGS
        cores = consts.CORES
        dns = consts.DNS
        disk_space = consts.DISK_SPACE
        # The base RAM constant is doubled because on_execute runs two engines
        # (old and new package) at the same time.
        ram = consts.RAM * 2  # two engines
        privileged = consts.PRIVILEGED

    class Parameters(sdk2.Task.Parameters):
        # Input parameters of the task, split into groups.

        with sdk2.parameters.Group('Common parameters') as common_parameters:
            # Number of worker threads started by on_execute; each thread runs
            # the `work` method.
            thread_count = sdk2.parameters.Integer(
                'Working threads count',
                default=10,
                required=True
            )
        with sdk2.parameters.Group('Packages and resources') as packages_parameters:
            # The two packages whose performance is compared; on_execute starts
            # one engine per package.
            vins_package_old = sdk2.parameters.Resource(
                'Old VINS package',
                resource_type=VinsPackage,
                required=True
            )
            vins_package_new = sdk2.parameters.Resource(
                'New VINS package',
                resource_type=VinsPackage,
                required=True
            )
            # Package directory is passed to both engines; analyze_performance
            # runs the `perf_analyzer` binary from it.
            helper_binaries = sdk2.parameters.Resource(
                'Helper binaries for functionality outside of Sandbox functions',
                resource_type=MegamindHelperBinaries,
                required=True
            )
            # Directory resource: its `*.json` files are the requests to replay
            # and `session_id.txt` is read by get_session_id.
            megamind_requests = sdk2.parameters.Resource(
                'Megamind requests',
                resource_type=MegamindRequests,
                required=True
            )
        with sdk2.parameters.Group('Joker server parameters') as joker_parameters:
            # Cluster name and endpoint set id are forwarded to both engines.
            joker_cluster_name = sdk2.parameters.String(
                'Joker server cluster name',
                default='vla',
                required=True
            )
            joker_endpoint_set_id = sdk2.parameters.String(
                'Joker server endpoint set id',
                default='alice-joker-mocker-vla',
                required=True
            )
            # Optional: when set, on_enqueue acquires this semaphore with weight 2.
            joker_semaphore_name = sdk2.parameters.String('Joker semaphore name', default=None)
        with sdk2.parameters.Group('Tokens and databases') as tokens_and_databases_parameters:
            # Vault record passed to the engines as `secrets_token`.
            vault_bass_owner = sdk2.parameters.String(
                'BASS Vault token owner',
                default='BASS',
                required=True
            )
            vault_bass_name = sdk2.parameters.String(
                'BASS Vault token name',
                default='robot-bassist_vault_token',
                required=True
            )
            # Vault record with the OAuth token that is inserted into every
            # request (`additional_options.oauth_token`).
            vault_oauth_owner = sdk2.parameters.String(
                'OAuth Vault token owner',
                default='BASS',
                required=True
            )
            vault_oauth_name = sdk2.parameters.String(
                'OAuth Vault token name',
                default='alice-diff-test-token',
                required=True
            )
            # YDb endpoint and database are forwarded to both engines.
            ydb_endpoint = sdk2.parameters.String(
                'YDb endpoint',
                default='ydb-ru.yandex.net:2135',
                required=True
            )
            ydb_database = sdk2.parameters.String(
                'YDb database',
                default='/ru/alice/prod/mocker',
                required=True
            )
            # Extra experiments merged over the request's own experiments:
            # exps1 for requests sent to the old engine, exps2 for the new one.
            # NOTE(review): both are declared inside the 'Tokens and databases'
            # group - confirm this placement is intended.
            exps1 = sdk2.parameters.JSON('New experiments for old package')
            exps2 = sdk2.parameters.JSON('New experiments for new package')

        kill_timeout = 10 * 60 * 60  # 10 hours

    def on_enqueue(self):
        if self.Parameters.joker_semaphore_name:
            self.Requirements.semaphores = ctt.Semaphores(
                acquires=[
                    # acquire with weight 2, because we have two engines
                    ctt.Semaphores.Acquire(name=self.Parameters.joker_semaphore_name, weight=2)
                ],
                release=(ctt.Status.Group.BREAK, ctt.Status.Group.FINISH)
            )

    def analyze_performance(self, responses):
        ''' Dump the measured timings to a log file and run the analyzer binary on it.

        responses -- queue of (old_time, new_time) pairs; it is drained here.

        The timings go to `perf.txt` in the logs folder, one "old new" pair per
        line. The analyzer is then run through engine.HelperBinary with the
        directory of a new MegamindPerformanceDiff resource as second argument
        (presumably its working directory - TODO confirm), and the resource is
        marked ready afterwards.
        '''
        # flush to the file
        timings_path = join(get_logs_folder(), 'perf.txt')
        with open(timings_path, 'w') as timings_file:
            while not responses.empty():
                timing_pair = responses.get()
                old_perf, new_perf = timing_pair
                timings_file.write('{} {}\n'.format(old_perf, new_perf))

        # make resource
        diff_resource = MegamindPerformanceDiff(
            task=self,
            description='Difference in megamind performance',
            path='diff_dir'
        )
        resource_data = sdk2.ResourceData(diff_resource)
        output_dir = str(resource_data.path)
        os.mkdir(output_dir)

        # run analyzer
        analyzer_binary = join(self.binaries_package_dir, 'perf_analyzer')
        analyzer_cmd = [analyzer_binary, '--perf-file', timings_path, '--analyze-file', 'analyze.txt']
        engine.HelperBinary(analyzer_cmd, output_dir).call()

        resource_data.ready()

    def get_session_id(self, megamind_requests_dir):
        with open(os.path.join(megamind_requests_dir, 'session_id.txt'), 'r') as f:
            session_id = f.read().rstrip('\n')
        return session_id + '_perf'

    def work_with_request_filename(self, request_filename):
        with open(request_filename, 'rb') as f:
            request_text = f.read()
        request = json.loads(request_text)

        req_id = request['header']['request_id']
        logging.debug('Checking request on file "%s"', request_filename)

        # insert oauth and rebuild query
        add_opts = request['request'].get('additional_options', None)
        if add_opts is None:
            add_opts = {}
            request['request']['additional_options'] = add_opts
        add_opts['oauth_token'] = self.oauth_token

        prev_exps = request['request'].get('experiments', {})
        old_exps = dict(prev_exps)
        new_exps = dict(prev_exps)

        if self.Parameters.exps1:
            for exp in self.Parameters.exps1:
                old_exps[exp] = self.Parameters.exps1[exp]

        if self.Parameters.exps2:
            for exp in self.Parameters.exps2:
                new_exps[exp] = self.Parameters.exps2[exp]

        request['request']['experiments'] = old_exps
        old_request_text = json.dumps(request)
        old_request_length = len(old_request_text)
        old_headers = self.engine_old.build_query_headers(request, req_id, old_request_length)

        request['request']['experiments'] = new_exps
        new_request_text = json.dumps(request)
        new_request_length = len(new_request_text)
        new_headers = self.engine_new.build_query_headers(request, req_id, new_request_length)

        if bool(random.getrandbits(1)):
            old_response_time = self.engine_old.estimate_request_time(old_headers, old_request_text)
            new_response_time = self.engine_new.estimate_request_time(new_headers, new_request_text)
        else:
            new_response_time = self.engine_new.estimate_request_time(new_headers, new_request_text)
            old_response_time = self.engine_old.estimate_request_time(old_headers, old_request_text)

        logging.debug('Old time: %s, new time: %s', str(old_response_time), str(new_response_time))

        return (old_response_time, new_response_time)

    def work(self, requests_filenames_queue, results_queue):
        logging.info('A thread started his work')
        while True:
            exit_thread = False
            try:
                request_filename = requests_filenames_queue.get_nowait()
            except Queue.Empty:
                exit_thread = True

            # exit thread or work with request
            if exit_thread:
                logging.info('A thread ended his work')
                break

            try:
                results_queue.put(self.work_with_request_filename(request_filename))
            except:
                logging.error('A thread has unexpected error: %s', traceback.format_exc())

            # take break before next request
            time.sleep(1)

    def on_execute(self):
        ''' Run the comparison: start both engines, replay all requests, analyze timings.

        Steps:
          * read the OAuth token and prepare all input resources;
          * start one engine for the old package and one for the new package;
          * put every `*.json` file of the requests directory into a queue and
            let `thread_count` threads process it (see `work`);
          * pass the collected timings to analyze_performance;
          * stop every engine that has been started, even if one of the steps
            above raised.
        '''
        self.requests_filenames_queue = Queue.Queue()
        self.oauth_token = str(sdk2.Vault.data(self.Parameters.vault_oauth_owner, self.Parameters.vault_oauth_name))

        # load resources
        self.binaries_package_dir = engine.prepare_resource(self.Parameters.helper_binaries, 'helper_binaries_local')
        self.vins_package_dir_old = engine.prepare_resource(self.Parameters.vins_package_old, 'vins_package_old_local')
        self.vins_package_dir_new = engine.prepare_resource(self.Parameters.vins_package_new, 'vins_package_new_local')
        megamind_requests_dir = engine.prepare_resource(self.Parameters.megamind_requests, 'megamind_requests_local')

        # values shared by both engines are fetched only once
        secrets_token = sdk2.Vault.data(self.Parameters.vault_bass_owner, self.Parameters.vault_bass_name)
        joker_session_id = self.get_session_id(megamind_requests_dir)

        # init engine
        def build_kwargs(vins_package_dir):
            return {
                # uncomment to use local Joker
                # 'joker_host': 'localhost',
                # 'joker_port': '13000',
                'vins_package_dir': vins_package_dir,
                'binaries_package_dir': self.binaries_package_dir,
                'secrets_token': secrets_token,
                'joker_cluster_name': self.Parameters.joker_cluster_name,
                'joker_endpoint_set_id': self.Parameters.joker_endpoint_set_id,
                'joker_session_id': joker_session_id,
                'joker_settings': {'fetch_if_not_exists': 1, 'imitate_delay': 1},
                'ydb_endpoint': self.Parameters.ydb_endpoint,
                'ydb_database': self.Parameters.ydb_database
            }

        # engines whose start() has returned; exactly these are stopped at the end
        started_engines = []
        try:
            self.engine_old = engine.Engine(self, **build_kwargs(self.vins_package_dir_old))
            self.engine_old.start()
            started_engines.append(self.engine_old)

            self.engine_new = engine.Engine(self, **build_kwargs(self.vins_package_dir_new))
            self.engine_new.start()
            started_engines.append(self.engine_new)

            # send requests in multiple threads and save responses
            requests_files = [
                join(megamind_requests_dir, fname) for fname in os.listdir(megamind_requests_dir)
                if os.path.isfile(join(megamind_requests_dir, fname)) and fname.endswith('.json')
            ]

            results_queue = Queue.Queue()
            for fname in requests_files:
                self.requests_filenames_queue.put(fname)

            threads = [
                Thread(target=self.work, args=(self.requests_filenames_queue, results_queue))
                for _ in range(self.Parameters.thread_count)
            ]
            for t in threads:
                t.start()
            for t in threads:
                t.join()

            # analyze requests answers
            self.analyze_performance(results_queue)
        finally:
            # graceful shutdown (old engine first, then the new one), also when
            # a step above failed, so that no started engine is left running
            for started_engine in started_engines:
                started_engine.stop()

    def on_terminate(self):
        logging.info("Terminating task")
        self.requests_filenames_queue.queue.clear()
