from os.path import join
from threading import (Thread, Lock)
import json
import logging
import os
import requests
import shutil
import sys
import time

from sandbox import common
from sandbox import sdk2
from sandbox.projects.common.decorators import retries
from sandbox.sdk2.helpers import subprocess as sp

import sandbox.projects.vins.common.servers as servers


# we should copy the package resource folder to a temporary folder (which has 'current work dir' as a parent)
# because Sandbox won't allow write to files outside 'current work dir'
def prepare_resource(resource, dir_name):
    """Copy a resource's data into a fresh, writable directory under the cwd.

    Any existing ``<cwd>/<dir_name>`` is removed first. Symlinks are copied as
    symlinks. Every copied directory and regular file is then chmod-ed to
    0o777 so later steps may modify the package in place.

    :param resource: resource whose data path is taken from ``sdk2.ResourceData``
    :param dir_name: name of the directory to (re)create in the current working dir
    :return: absolute path of the prepared directory
    """
    logging.debug('preparing temporary directory "%s" for a package', dir_name)
    working_dir = join(os.getcwd(), dir_name)

    package_dir = str(sdk2.ResourceData(resource).path)
    shutil.rmtree(working_dir, ignore_errors=True)
    shutil.copytree(package_dir, working_dir, symlinks=True)

    # hack read-only mode: the copied files keep the resource's permissions
    for root, _, files in os.walk(working_dir):
        os.chmod(root, 0o777)
        for name in files:
            path = join(root, name)
            # os.chmod follows symlinks. A link pointing outside the copy
            # (e.g. back into the original resource) or a dangling link would
            # make chmod fail, so skip links. Links to files inside the copy
            # lose nothing: their targets are chmod-ed when visited directly.
            if not os.path.islink(path):
                os.chmod(path, 0o777)

    logging.debug('preparing directory "%s" for a package finished', dir_name)
    return working_dir


class HelperBinary(object):
    """Run a helper command and return the first line it prints to stdout.

    The child's stderr goes to this process's ``sys.stderr`` and its stdin is
    inherited (``stdin=None``).
    """

    def __init__(self, cmd, cwd=None):
        # cmd: argument list passed unchanged to Popen
        # cwd: optional working directory for the child; falsy means "inherit"
        self._cmd = cmd
        self._cwd = cwd

    def call(self):
        """Run the command to completion and return its first stdout line.

        The result keeps its trailing newline (if any), as ``readline()``
        would. It is empty when the child printed nothing.
        """
        kwargs = {}
        if self._cwd:
            kwargs['cwd'] = self._cwd
        proc = sp.Popen(
            self._cmd,
            stdout=sp.PIPE,
            stderr=sys.stderr,
            stdin=None,
            **kwargs
        )
        # communicate() drains the whole pipe, waits for exit and closes the
        # pipe. Reading a single line and then calling wait() deadlocks if the
        # child writes more than the OS pipe buffer, and leaks the pipe fd.
        out, _ = proc.communicate()
        newline = b'\n' if isinstance(out, bytes) else '\n'
        first, sep, _ = out.partition(newline)
        return first + sep


class Engine(object):
    """Run and supervise the local server stack used to answer requests.

    ``start()`` launches the servers built by ``_prepare_servers`` in this
    order: Joker (remote), Redis, BASS, VINS, Megamind. It also spawns a
    heartbeat thread that restarts the whole stack when a server keeps
    failing pings. ``_alive_lock`` is held for the duration of such a restart.
    The request helpers take it to read server URLs, so they wait until a
    restart is over.
    """

    def __init__(self, task, **kwargs):
        """Store the configuration; servers are only launched by ``start()``.

        :param task: Sandbox task handed to ``servers.ServerManager``
        :param kwargs: required keys ``vins_package_dir``,
            ``binaries_package_dir``, ``secrets_token``, ``joker_session_id``,
            ``joker_settings``, ``ydb_endpoint``, ``ydb_database``.
            Also either ``joker_host`` + ``joker_port``, or
            ``joker_cluster_name`` + ``joker_endpoint_set_id``, which are
            resolved through the ``endpoint_discoverer`` helper binary.
        :raises KeyError: if a required key is missing
        :raises common.errors.TaskFailure: if endpoint discovery reports failure
        """
        self._task = task

        # packages
        self._vins_package_dir = kwargs['vins_package_dir']
        self._binaries_package_dir = kwargs['binaries_package_dir']

        # tokens and links
        self._secrets_token = kwargs['secrets_token']

        if 'joker_host' in kwargs and 'joker_port' in kwargs:
            self._joker_host = kwargs['joker_host']
            self._joker_port = kwargs['joker_port']
        else:
            # _get_joker_endpoint needs _binaries_package_dir, set above
            cluster_name = kwargs['joker_cluster_name']
            endpoint_set_id = kwargs['joker_endpoint_set_id']
            self._joker_host, self._joker_port = self._get_joker_endpoint(cluster_name, endpoint_set_id)

        self._joker_session_id = kwargs['joker_session_id']
        self._joker_settings = kwargs['joker_settings']
        self._ydb_endpoint = kwargs['ydb_endpoint']
        self._ydb_database = kwargs['ydb_database']

        self._joker_host_port = '{}:{}'.format(self._joker_host, self._joker_port)
        self._joker_url = 'http://{}/'.format(self._joker_host_port)

        # accept requests only if servers are alive
        self._alive = False
        self._alive_lock = Lock()

        # filled in by start(); initialised here so stop() is safe before start()
        self._servers_list = []
        self._heartbeat_thread = None

    @staticmethod
    def _check_2xx(response):
        """Raise ``Exception`` unless ``response.status_code`` is 2xx."""
        # floor division: under Python 3 true division, 201 / 100 == 2.01
        if response.status_code // 100 != 2:
            raise Exception('Status code is not 2xx, it\'s {}'.format(response.status_code))

    def _prepare_env(self):
        """Build the environment for the servers.

        The environment is a copy of ``os.environ`` with ``BASS_AUTH_TOKEN``
        set, the JSON object printed by ``secrets_discoverer`` merged in, and
        ``MONGO_PASSWORD`` copied from ``VINS_MONGO_PASSWORD``. That key must
        come from the secrets, otherwise ``KeyError`` is raised.
        """
        env = os.environ.copy()
        env['BASS_AUTH_TOKEN'] = self._secrets_token

        # add secrets for services
        cmd = [
            join(self._binaries_package_dir, 'secrets_discoverer'),
            '--token', env['BASS_AUTH_TOKEN'],
            '--key', 'sec-01cnbk6vvm6mfrhdyzamjhm4cm'
        ]
        secrets_json = HelperBinary(cmd).call()

        env.update(json.loads(secrets_json))
        env['MONGO_PASSWORD'] = env['VINS_MONGO_PASSWORD']

        return env

    def build_proxy_headers(self, test_id, timestamp=None, group_id=None):
        """Return the Joker proxy headers for one test/request id.

        ``x-yandex-via-proxy``, ``x-yandex-joker`` and, when given,
        ``x-yandex-fake-time`` are emitted under their bare names and under
        one, two and three nested ``x-yandex-proxy-header-`` prefixes.
        ``x-yandex-via-proxy-skip`` is added once.
        """
        yandex_joker = 'prj={}&sess={}&test={}'.format('megamind', self._joker_session_id, test_id)
        if group_id:
            yandex_joker = '{}&group_id={}'.format(yandex_joker, group_id)

        proxy_values = {
            'x-yandex-via-proxy': self._joker_host_port,
            'x-yandex-joker': yandex_joker
        }
        if timestamp:
            proxy_values['x-yandex-fake-time'] = timestamp

        headers = {}
        for key, value in proxy_values.items():
            # NOTE(review): the nested prefixes presumably let the value be
            # forwarded across several proxy hops; confirm with the servers.
            prefixed = key
            headers[prefixed] = value
            for _ in range(3):
                prefixed = 'x-yandex-proxy-header-' + prefixed
                headers[prefixed] = value

        headers['x-yandex-via-proxy-skip'] = 'Vins Bass BassRun BassApply'

        return headers

    # headers that will be passed with query
    def build_query_headers(self, request, req_id, content_length, group_id=None):
        """Return HTTP headers for sending ``request`` (a dict) to Megamind.

        The headers combine common HTTP/RTLog headers with the proxy headers
        from ``build_proxy_headers``. ``request['application']['timestamp']``,
        if present, becomes the fake time.
        """
        headers = {}

        # Common headers
        headers['Content-type'] = 'application/json'
        headers['x-alice-client-reqid'] = req_id
        # microsecond timestamp + request id, '$'-separated
        headers['X-RTLog-Token'] = '{}${}${}'.format(str(int(time.time() * 10 ** 6)), req_id, req_id)
        headers['Host'] = self._joker_host_port
        headers['Connection'] = 'Keep-Alive'
        headers['Content-Length'] = str(content_length)

        # Joker headers
        timestamp = None
        if 'application' in request:
            timestamp = request['application'].get('timestamp', None)
        headers.update(self.build_proxy_headers(req_id, timestamp, group_id))

        return headers

    def _prepare_servers(self):
        """Create a new ServerManager and return un-started servers in launch order."""
        env = self._prepare_env()
        self._server_mgr = servers.ServerManager(self._task, env)

        default_proxy_headers = self.build_proxy_headers('default')

        joker_server = servers.JokerRemoteServer(
            self._server_mgr, self._joker_url, self._joker_session_id, self._joker_settings
        )
        redis_server = servers.RedisServer(self._server_mgr, self._vins_package_dir)

        bass_server = servers.BassServer(
            self._server_mgr, self._vins_package_dir, default_proxy_headers,
            self._joker_host_port, self._ydb_endpoint, self._ydb_database
        )
        vins_server = servers.VinsServer(self._server_mgr, self._vins_package_dir, default_proxy_headers)
        megamind_server = servers.MegamindServer(self._server_mgr, self._vins_package_dir)

        # set the correct order to launch
        return [joker_server, redis_server, bass_server, vins_server, megamind_server]

    def _get_joker_endpoint(self, cluster_name, endpoint_set_id):
        """Resolve the Joker (host, port) via the ``endpoint_discoverer`` binary.

        :raises common.errors.TaskFailure: if the reported status is not 'OK'
        """
        cmd = [
            join(self._binaries_package_dir, 'endpoint_discoverer'),
            '--cluster-name', cluster_name,
            '--endpoint-set-id', endpoint_set_id
        ]
        response = HelperBinary(cmd).call()

        resp_json = json.loads(response)
        if resp_json['status'] != 'OK':
            raise common.errors.TaskFailure('Can\'t get Joker server endpoint - check the balancer')
        return (resp_json['host'], resp_json['port'])

    def _restart_servers(self):
        """Stop every current server, then build and start a fresh set.

        Called by the heartbeat thread while it holds ``_alive_lock``.
        """
        # stop old servers
        for server in self._servers_list:
            server.stop()

        # init and run new servers
        new_servers_list = self._prepare_servers()
        for server in new_servers_list:
            server.start()

        self._servers_list = new_servers_list

    def _any_server_silent(self):
        """Return True as soon as one server fails a 10-second ping."""
        for server in self._servers_list:
            if not server.try_ping(timeout=10):
                logging.warning('Server %s is silent!', server.SERVNAME)
                return True
        logging.info('All servers are alive!')
        return False

    def _check_heartbeat(self):
        """Heartbeat loop: every 30 s, ping servers and restart them if needed.

        A failed ping is re-checked after 60 s, with ``_alive_lock`` held so
        requests wait. Only then are the servers restarted.
        """
        while self._alive:
            if self._any_server_silent():
                logging.warning('Servers need to be restarted! But waiting 1 minute...')

                with self._alive_lock:
                    # stop() may have run while we were pinging (its stopped
                    # servers are why the ping failed); never restart then.
                    if not self._alive:
                        break

                    time.sleep(60)

                    logging.info('Checking server pings again...')
                    if self._any_server_silent():
                        logging.warning('Still some server is silent, restart servers...')
                        self._restart_servers()
                    else:
                        logging.info('Servers are alive, don\'t restart servers...')

            # next heartbeat in 30 seconds
            time.sleep(30)

    def start(self):
        """Launch all servers in order and start the heartbeat thread."""
        self._servers_list = self._prepare_servers()
        for server in self._servers_list:
            server.start()

        self._alive = True
        self._heartbeat_thread = Thread(target=self._check_heartbeat)
        self._heartbeat_thread.start()

    def stop(self):
        """Stop all servers and wait for the heartbeat thread to finish.

        The wait may last up to the heartbeat's 30-second sleep.
        """
        with self._alive_lock:
            self._alive = False
            for server in self._servers_list:
                server.stop()
        # Join OUTSIDE the lock: the heartbeat thread may be blocked acquiring
        # it, and holding it here while joining would deadlock.
        if self._heartbeat_thread is not None:
            self._heartbeat_thread.join()

    @retries(max_tries=2, delay=1, default_instead_of_raise=True, default_value='{"error": "No 2xx HTTP answer"}')
    def send_request(self, headers, data):
        """POST ``data`` to Megamind and return the UTF-8 response text.

        A non-2xx status raises. Per the ``retries`` settings this is
        presumably retried and then replaced by ``default_value``; the
        decorator is defined elsewhere, so confirm there.
        """
        # prevent sending when restarting servers
        with self._alive_lock:
            url = '{}speechkit/app/pa/'.format(self._server_mgr.get_url('megamind'))

        r = requests.post(url, data=data, headers=headers, timeout=15)
        self._check_2xx(r)
        r.encoding = 'utf-8'
        return r.text

    def get_group_id_history(self, group_id):
        """Return Joker's history text for ``group_id``, or "" on any failure."""
        # prevent sending when restarting servers
        with self._alive_lock:
            url = '{}history'.format(self._server_mgr.get_url('joker'))

        try:
            # params= URL-encodes group_id instead of pasting it into the URL
            r = requests.get(url, params={'group_id': group_id}, timeout=15)
            self._check_2xx(r)
            r.encoding = 'utf-8'
            return r.text
        except Exception as e:
            # deliberate best effort: callers get "" instead of an exception
            logging.warning('Failed to get Joker history for group_id %s: %s', group_id, e)
            return ""

    @retries(max_tries=1, delay=1, default_instead_of_raise=True, default_value=float('nan'))
    def estimate_request_time(self, headers, data):
        """Return the wall-clock seconds of one Megamind POST (5 s timeout).

        A non-2xx status raises, which the ``retries`` settings presumably turn
        into ``nan``; confirm against the decorator.
        """
        # prevent sending when restarting servers
        with self._alive_lock:
            url = '{}speechkit/app/pa/'.format(self._server_mgr.get_url('megamind'))

        started = time.time()
        r = requests.post(url, data=data, headers=headers, timeout=5)
        passed_time = time.time() - started

        self._check_2xx(r)

        return passed_time

    def get_sensors_data(self):
        """Return the parsed ``counters/json`` sensors payload, or {} if not HTTP 200."""
        # same lock as the other helpers: _server_mgr is replaced during restarts
        with self._alive_lock:
            sensors_url = '{}counters/json'.format(self._server_mgr.get_url('sensors'))
        r = requests.get(sensors_url, timeout=5)
        if r.status_code == 200:
            return json.loads(r.text)
        return {}
