from os.path import join
from threading import Thread
import Queue
import json
import logging
import os
import time
import traceback
import uuid

from sandbox import sdk2
from sandbox.sandboxsdk.environments import PipEnvironment
from sandbox.sdk2.helpers import subprocess as sp
import sandbox.common.types.task as ctt

from sandbox.projects.vins.common.resources import (VinsPackage, MegamindHelperBinaries, MegamindRequests, MegamindResponses)
import sandbox.projects.vins.common.constants as consts
import sandbox.projects.vins.common.engine as engine


class MegamindGetResponses(sdk2.Task):
    ''' Get Megamind responses on various inputs

    Starts the engine (``engine.Engine``) from the given packages, replays
    every ``*.json`` file of the ``megamind_requests`` resource through it
    from several worker threads and, if enabled, publishes the requests,
    responses, per-request histories, split rtlogs and engine sensors as a
    ``MegamindResponses`` resource.
    '''

    class Requirements(sdk2.Task.Requirements):
        environments = [
            PipEnvironment('PyYAML')
        ]

        # Warning - disable this on dev machines
        client_tags = consts.CLIENT_TAGS
        cores = consts.CORES
        dns = consts.DNS
        disk_space = consts.DISK_SPACE
        ram = consts.RAM
        privileged = consts.PRIVILEGED

    class Parameters(sdk2.Task.Parameters):
        with sdk2.parameters.Group('Common parameters') as common_parameters:
            publish_responses_and_logs = sdk2.parameters.Bool(
                'Publish responses and logs',
                default=True,
                required=True
            )
            thread_count = sdk2.parameters.Integer(
                'Working threads count',
                default=10,
                required=True
            )
        with sdk2.parameters.Group('Packages and resources') as packages_parameters:
            vins_package = sdk2.parameters.Resource(
                'VINS package',
                resource_type=VinsPackage,
                required=True
            )
            helper_binaries = sdk2.parameters.Resource(
                'Helper binaries for functionality outside of Sandbox functions',
                resource_type=MegamindHelperBinaries,
                required=True
            )
            megamind_requests = sdk2.parameters.Resource(
                'Megamind requests',
                resource_type=MegamindRequests,
                required=True
            )
        with sdk2.parameters.Group('Joker server parameters') as joker_parameters:
            joker_cluster_name = sdk2.parameters.String(
                'Joker server cluster name',
                default='vla',
                required=True
            )
            joker_endpoint_set_id = sdk2.parameters.String(
                'Joker server endpoint set id',
                default='alice-joker-mocker-vla',
                required=True
            )
            joker_semaphore_name = sdk2.parameters.String('Joker semaphore name', default=None)
        with sdk2.parameters.Group('Tokens and databases') as tokens_and_databases_parameters:
            vault_bass_owner = sdk2.parameters.String(
                'BASS Vault token owner',
                default='BASS',
                required=True
            )
            vault_bass_name = sdk2.parameters.String(
                'BASS Vault token name',
                default='robot-bassist_vault_token',
                required=True
            )
            vault_oauth_owner = sdk2.parameters.String(
                'OAuth Vault token owner',
                default='BASS',
                required=True
            )
            vault_oauth_name = sdk2.parameters.String(
                'OAuth Vault token name',
                default='alice-diff-test-token',
                required=True
            )
            ydb_endpoint = sdk2.parameters.String(
                'YDb endpoint',
                default='ydb-ru.yandex.net:2135',
                required=True
            )
            ydb_database = sdk2.parameters.String(
                'YDb database',
                default='/ru/alice/prod/mocker',
                required=True
            )

        kill_timeout = 5 * 60 * 60  # 5 hours

    def on_enqueue(self):
        '''Require the Joker semaphore, if a semaphore name is configured.

        The semaphore is released when the task enters the BREAK or FINISH
        status groups.
        '''
        if self.Parameters.joker_semaphore_name:
            self.Requirements.semaphores = ctt.Semaphores(
                acquires=[
                    ctt.Semaphores.Acquire(name=self.Parameters.joker_semaphore_name)
                ],
                release=(ctt.Status.Group.BREAK, ctt.Status.Group.FINISH)
            )

    @staticmethod
    def _write_json(path, obj):
        '''Write ``obj`` to ``path`` as pretty-printed, key-sorted UTF-8 JSON.

        With ``ensure_ascii=False`` Python 2's ``json.dumps`` returns
        ``unicode`` whenever the data holds non-ASCII text. Writing that to a
        text-mode file would implicitly encode it with the default codec
        (ASCII unless overridden), so encode to UTF-8 explicitly and write
        bytes.
        '''
        text = json.dumps(obj, indent=4, ensure_ascii=False, sort_keys=True)
        if not isinstance(text, bytes):
            text = text.encode('utf-8')
        with open(path, 'wb') as f:
            f.write(text)

    def publish_responses_and_logs(self, requests, responses, histories):
        '''Save per-request artifacts and publish them as a resource.

        Runs the ``split_by_reqid`` helper binary over the rtlogs of all
        engine servers (output goes to ``responses_folder`` in the current
        directory). Then writes ``request.json``, ``response.json`` and
        ``histories.json`` into a subfolder per request id, writes engine
        sensors into ``sensors.json``, and registers the folder as a
        ``MegamindResponses`` resource.

        :param requests: request dicts (OAuth token already removed) keyed by request id
        :param responses: response dicts keyed by request id
        :param histories: history data keyed by request id
        '''
        folder_name = 'responses_folder'
        folder_dir = join(os.getcwd(), folder_name)

        binary_path = join(self.binaries_package_dir, 'split_by_reqid')
        logs_splitter_cmd = [binary_path, '--save', folder_name]
        for server in self.engine._servers_list:
            rtlog_path = server.rtlog_path()
            if rtlog_path is not None:
                logs_splitter_cmd.append(rtlog_path)

        return_code = sp.Popen(
            logs_splitter_cmd,
            cwd=os.getcwd()
        ).wait()
        if return_code != 0:
            logging.error('split_by_reqid exited with code %s', return_code)

        # The splitter is not guaranteed (from this code's point of view) to
        # have created the output folder or a subfolder for every request id,
        # e.g. when no server reported an rtlog path -- create them if missing.
        if not os.path.isdir(folder_dir):
            os.makedirs(folder_dir)

        for req_id in responses:
            req_dir = join(folder_dir, req_id)
            if not os.path.isdir(req_dir):
                os.makedirs(req_dir)
            self._write_json(join(req_dir, 'response.json'), responses[req_id])
            self._write_json(join(req_dir, 'request.json'), requests[req_id])
            self._write_json(join(req_dir, 'histories.json'), histories[req_id])

        # save sensors data (fetched first so a failure leaves no empty file)
        sensors_data = self.engine.get_sensors_data()
        self._write_json(join(folder_dir, 'sensors.json'), sensors_data)

        # create resources
        resource = sdk2.ResourceData(MegamindResponses(
            task=self,
            description='Responses from MegamindGetResponses task #{}'.format(self.id),
            path=folder_dir
        ))
        resource.ready()

    def get_session_id(self, megamind_requests_dir):
        '''Return the Joker session id for this run.

        Reads ``session_id.txt`` from the requests resource directory, strips
        trailing newlines and appends the ``_diff`` suffix.
        '''
        with open(os.path.join(megamind_requests_dir, 'session_id.txt'), 'r') as f:
            session_id = f.read().rstrip('\n')
        return session_id + '_diff'

    def work_with_request_filename(self, request_filename):
        '''Send one request file to Megamind and return what came back.

        The OAuth token is put into ``request.additional_options`` only for
        the duration of the call. It is always removed afterwards, even on
        failure, so it never ends up in the returned (and later published)
        request.

        :param request_filename: path to a JSON request file
        :return: tuple ``(req_id, request, response, history)``
        :raises: whatever reading/parsing the file, sending the request or
            fetching the history raises; ``work`` logs it and skips the
            request.
        '''
        with open(request_filename, 'rb') as f:
            request = json.loads(f.read())

        req_id = request['header']['request_id']
        logging.debug('Checking request on file "%s"', request_filename)

        add_opts = request['request'].get('additional_options')
        if add_opts is None:
            add_opts = {}
            request['request']['additional_options'] = add_opts

        group_id = uuid.uuid4()
        # insert oauth and rebuild query
        add_opts['oauth_token'] = self.oauth_token
        try:
            request_text = json.dumps(request)
            headers = self.engine.build_query_headers(request, req_id, len(request_text), group_id)
            response = json.loads(self.engine.send_request(headers, request_text))
            history = json.loads(self.engine.get_group_id_history(group_id))
        finally:
            # erase oauth so it cannot leak into published artifacts
            add_opts.pop('oauth_token', None)

        return (req_id, request, response, history)

    def work(self, requests_filenames_queue, results_queue):
        '''Worker loop: take request files from the queue and push results.

        File names are taken without blocking and the worker exits as soon
        as the queue is empty. The queue is fully filled before workers
        start, and ``on_terminate`` clears it to stop workers early. A
        failure on one request is logged and does not stop the worker.
        '''
        logging.info('A thread started his work')
        while True:
            try:
                request_filename = requests_filenames_queue.get_nowait()
            except Queue.Empty:
                logging.info('A thread ended his work')
                break

            try:
                results_queue.put(self.work_with_request_filename(request_filename))
            except Exception:
                logging.exception('A thread has unexpected error on file "%s"', request_filename)

            # take break before next request
            time.sleep(1)

    def _process_requests(self, megamind_requests_dir):
        '''Replay every ``*.json`` file of ``megamind_requests_dir`` with
        ``thread_count`` workers and return the queue of their results.'''
        results_queue = Queue.Queue()
        for fname in os.listdir(megamind_requests_dir):
            path = join(megamind_requests_dir, fname)
            if os.path.isfile(path) and fname.endswith('.json'):
                self.requests_filenames_queue.put(path)

        threads = [
            Thread(target=self.work, args=(self.requests_filenames_queue, results_queue))
            for _ in range(self.Parameters.thread_count)
        ]
        for t in threads:
            t.start()
        for t in threads:
            t.join()
        return results_queue

    @staticmethod
    def _collect_results(results_queue):
        '''Split worker results into ``(requests, responses, histories)``
        dicts keyed by request id. Call only after all workers have joined.'''
        requests = {}
        responses = {}
        histories = {}
        while not results_queue.empty():
            req_id, request, response, history = results_queue.get()
            requests[req_id] = request
            responses[req_id] = response
            histories[req_id] = history
        return requests, responses, histories

    def on_execute(self):
        '''Prepare resources, start the engine, replay all requests and,
        if enabled, publish the results.

        The engine is stopped even if processing or publishing fails.
        '''
        self.requests_filenames_queue = Queue.Queue()
        self.oauth_token = str(sdk2.Vault.data(self.Parameters.vault_oauth_owner, self.Parameters.vault_oauth_name))

        # load resources
        self.binaries_package_dir = engine.prepare_resource(self.Parameters.helper_binaries, 'helper_binaries_local')
        self.vins_package_dir = engine.prepare_resource(self.Parameters.vins_package, 'vins_package_local')
        megamind_requests_dir = engine.prepare_resource(self.Parameters.megamind_requests, 'megamind_requests_local')

        # init engine
        kwargs = {
            # uncomment to use local Joker
            # 'joker_host': 'localhost',
            # 'joker_port': '13000',
            'vins_package_dir': self.vins_package_dir,
            'binaries_package_dir': self.binaries_package_dir,
            'secrets_token': sdk2.Vault.data(self.Parameters.vault_bass_owner, self.Parameters.vault_bass_name),
            'joker_cluster_name': self.Parameters.joker_cluster_name,
            'joker_endpoint_set_id': self.Parameters.joker_endpoint_set_id,
            'joker_session_id': self.get_session_id(megamind_requests_dir),
            'joker_settings': {'fetch_if_not_exists': 1},
            'ydb_endpoint': self.Parameters.ydb_endpoint,
            'ydb_database': self.Parameters.ydb_database
        }

        self.engine = engine.Engine(self, **kwargs)
        self.engine.start()
        try:
            # send requests in multiple threads and save responses
            results_queue = self._process_requests(megamind_requests_dir)

            # publish responses if needed
            if self.Parameters.publish_responses_and_logs:
                requests, responses, histories = self._collect_results(results_queue)
                self.publish_responses_and_logs(requests, responses, histories)
        finally:
            # graceful shutdown
            self.engine.stop()

    def on_terminate(self):
        '''Stop the workers early by dropping all request files not yet taken.

        Workers exit once the queue is empty, so requests already in flight
        finish but no new ones are started.
        '''
        logging.info("Terminating task")
        pending = getattr(self, 'requests_filenames_queue', None)
        if pending is None:
            # on_execute has not created the queue yet: nothing to drain
            return
        # Queue.Queue has no public clear(); mutate the underlying deque while
        # holding the queue's own lock so it cannot race with get_nowait().
        with pending.mutex:
            pending.queue.clear()
