from os.path import join
from urlparse import urljoin
import datetime
import json
import logging
import os
import re
import shutil
import time

from sandbox import common
from sandbox import sdk2
from sandbox.sandboxsdk.paths import get_logs_folder
from sandbox.sdk2.helpers import subprocess as sp
from sandbox.projects.common import network


class PortHolder(object):
    """Reserves one local port per server name and hands it out on request.

    Every port handed out is remembered, so two different servers never get the
    same port even if ``network.get_free_port`` happens to return it twice.
    """

    def __init__(self):
        self._taken_ports = set()  # every port handed out so far
        self._ports_dict = {}  # server name -> reserved port
        self._retries_count = 50

    def get_port(self, server):
        """Return the port reserved for ``server``, reserving a fresh one on first use.

        :param server: server name
        :returns: port number (the same one on every call with the same name)
        :raises RuntimeError: if no port outside the already reserved ones was found
        """
        if server in self._ports_dict:
            return self._ports_dict[server]

        for _ in range(self._retries_count):
            port = network.get_free_port(max_tries=self._retries_count)
            # NOTE(review): presumably get_free_port only checks that the port is
            # bindable right now, so a port already given to a not-yet-started
            # server can come back; never hand the same port out twice.
            if port not in self._taken_ports:
                break
        else:
            raise RuntimeError('Could not find a free port for server "{}"'.format(server))

        self._taken_ports.add(port)
        self._ports_dict[server] = port
        logging.info('server "%s" has reserved port %d', server, port)
        return port


class ServerManager:
    """Registry that tells servers where their peers can be reached.

    Names registered as remote (by port or by URL) resolve to the registered
    value; every other name gets a local port from an internal PortHolder and
    a ``http://localhost:<port>/`` URL built from it.
    """

    def __init__(self, sandbox_task, env):
        self.sandbox_task = sandbox_task
        self.env = env
        self._remote_urls = {}
        self._remote_ports = {}
        self._port_holder = PortHolder()

    def get_port(self, server):
        """Return the registered remote port of ``server``, else its reserved local port."""
        if server not in self._remote_ports:
            return self._port_holder.get_port(server)
        return self._remote_ports[server]

    def get_url(self, server):
        """Return the registered remote URL of ``server``, else a localhost URL for its port."""
        if server not in self._remote_urls:
            return 'http://localhost:{}/'.format(self.get_port(server))
        return self._remote_urls[server]

    def register_remote_port(self, server, port):
        """Make ``get_port(server)`` return ``port`` from now on."""
        self._remote_ports[server] = port

    def register_remote_url(self, server, url):
        """Make ``get_url(server)`` return ``url`` from now on."""
        self._remote_urls[server] = url


class Server(object):
    """Base class for a helper server process launched locally.

    Subclasses define a ``SERVNAME`` class attribute and implement
    ``_launch_cmd``; implementing ``_status_cmd`` enables health checks.
    """

    def __init__(self, name, server_mgr, working_dir):
        """
        :param name: server name used in log messages
        :param server_mgr: ServerManager supplying ports, URLs, base env and the sandbox task
        :param working_dir: directory the server process is started in
        """
        self._name = name
        self._server_mgr = server_mgr
        # copy, so per-server environment tweaks stay local to this server
        self._env = server_mgr.env.copy()
        self._port = server_mgr.get_port(self.SERVNAME)
        self._url = server_mgr.get_url(self.SERVNAME)
        self._dir = working_dir
        self._process = None

    def _launch_cmd(self):
        """Return the argv list that starts the server."""
        raise NotImplementedError()

    def _status_cmd(self):
        """Return ``(shell_command, expected_output)`` used to ping the server.

        Raising NotImplementedError means the server is never pinged.
        """
        raise NotImplementedError()

    def _ping(self, name, check_status_cmd, expected_answer, timeout=15, raise_if_fail=True):
        """Run ``check_status_cmd`` once and compare its output with ``expected_answer``.

        :param name: server name used in log and error messages
        :param check_status_cmd: shell command to run
        :param expected_answer: exact stdout or stderr expected from the command
        :param timeout: seconds to wait for the command to finish
        :param raise_if_fail: if the launched server process has already exited,
            raise instead of returning False
        :returns: True if the command exited with status 0 and its stdout or stderr
            equals ``expected_answer``
        :raises common.errors.TaskFailure: if the server process exited and
            ``raise_if_fail`` is set
        """
        if self._process is not None and self._process.poll() is not None:
            if raise_if_fail:
                # nothing to do here - the task can't work without server
                raise common.errors.TaskFailure('Server {} failed to start'.format(name))
            return False

        logging.info('Pinging %s', name)
        proc = sp.Popen([check_status_cmd], shell=True, stdout=sp.PIPE, stderr=sp.PIPE)
        try:
            outs, errs = proc.communicate(timeout=timeout)
        except Exception:
            logging.info('Failed to ping %s', name)
            # communicate() does not kill the child when it times out; do it here so
            # repeated pings don't pile up hung status commands.
            try:
                proc.kill()
                proc.wait()
            except Exception:
                logging.debug('Could not clean up ping command for %s', name, exc_info=True)
            return False
        else:
            if proc.returncode is not None and proc.returncode != 0:
                logging.info('Failed to ping %s', name)
                return False
            if expected_answer not in [outs, errs]:
                logging.info('Wrong answer when pinging %s', name)
                return False
        logging.info('Pinging %s successful', name)
        return True

    def try_ping(self, timeout):
        """Ping once without raising.

        :param timeout: seconds to wait for the status command
        :returns: the ping result, or True if the server defines no status command
        """
        try:
            check_status_cmd, expected_answer = self._status_cmd()
            return self._ping(self.SERVNAME, check_status_cmd, expected_answer, timeout=timeout, raise_if_fail=False)
        except NotImplementedError:
            # it's okay, server not needed to be pinged
            return True

    def start(self):
        """Launch the server and block until it answers its status command (if any).

        :raises common.errors.TaskFailure: if the process exits before answering
        """
        logging.info('Starting server "%s"...', self._name)

        cmd = self._launch_cmd()

        with sdk2.helpers.ProcessLog(self._server_mgr.sandbox_task, logger=self.SERVNAME) as pl:
            self._process = sp.Popen(
                cmd,
                cwd=self._dir,
                env=self._env,
                stdout=pl.stdout,
                stderr=pl.stderr
            )

        logging.info('Server "%s" is starting', self._name)
        try:
            check_status_cmd, expected_answer = self._status_cmd()
            # NOTE(review): no overall deadline - loops forever if the process stays
            # alive but never answers.
            while not self._ping(self.SERVNAME, check_status_cmd, expected_answer):
                time.sleep(5)
        except NotImplementedError:
            pass
        logging.info('Server "%s" started', self._name)

    def stop(self):
        """Kill the launched process (if any) and wait for it to exit."""
        if self._process is not None:
            logging.info('Stop server "%s", pid = %d', self._name, self._process.pid)
            # TODO(sparkle): it could be terminate(), but Joker and Megamind won't stop on SIGTERM
            # is it better to call proper functions instead of SIGKILL
            self._process.kill()
            self._process.wait()

    def rtlog_path(self):
        """Return the path of this server's rtlog file, or None if it writes none."""
        return None


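# Main differences from Server:
# - nothing is launched: there is no '_launch_cmd', and 'stop' does nothing because no process is started
# - 'start' pings only once and raises TaskFailure immediately if the ping fails (the remote must already be fully up!)
# - it registers its URL with server_mgr instead of reserving a local port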
class RemoteServer(Server):
    """A server that is already running elsewhere and is only health-checked.

    Nothing is launched, so the inherited ``stop`` is a no-op (``_process`` stays
    None). The given URL is registered with the server manager so that other
    servers resolve this name to it.
    """

    def __init__(self, name, server_mgr, url):
        # Server.__init__ is not called: it would reserve a local port via
        # server_mgr.get_port and needs a working directory.
        self._name = name
        self._server_mgr = server_mgr
        self._url = url
        self._process = None
        self._server_mgr.register_remote_url(name, url)

    def _after_start(self):
        """Hook run after the remote server was checked; does nothing by default."""
        pass

    def start(self):
        """Ping the remote server once, then run ``_after_start``.

        :raises common.errors.TaskFailure: if the remote server does not answer
        """
        logging.info('Checking remote server "%s"...', self._name)
        try:
            check_status_cmd, expected_answer = self._status_cmd()
        except NotImplementedError:
            pass  # no status command: nothing to check
        else:
            if not self._ping(self.SERVNAME, check_status_cmd, expected_answer):
                raise common.errors.TaskFailure('Remote server {} failed to ping'.format(self.SERVNAME))
        logging.info('Remote server "%s" checked', self._name)
        self._after_start()


class RedisServer(Server):
    """Local redis started from ``redis.conf.base`` with the reserved port appended."""
    SERVNAME = 'redis'

    def __init__(self, server_mgr, working_dir):
        super(RedisServer, self).__init__(self.SERVNAME, server_mgr, working_dir)
        self._make_conf_file()

    def _make_conf_file(self):
        """Copy ``redis.conf.base`` to ``redis.conf`` and append a ``port`` directive."""
        base_config_path = join(self._dir, 'redis.conf.base')
        config_path = join(self._dir, 'redis.conf')
        shutil.copyfile(base_config_path, config_path)
        with open(config_path, 'a') as f:
            # Leading newline: the base file may not end with one, and the directive
            # must not be glued onto its last line.
            f.write('\nport {}\n'.format(self._port))

    def _launch_cmd(self):
        """Return the redis-server argv; the config path is relative to the working dir."""
        cmd = [
            join(self._dir, 'redis-server'),
            'redis.conf'
        ]
        return cmd


class BassServer(Server):
    """Local BASS server started with a patched copy of its localhost config."""
    SERVNAME = 'bass'

    def __init__(self, server_mgr, working_dir, default_headers, joker_host_port, ydb_endpoint, ydb_database):
        """
        :param default_headers: mapping of header name -> value, put into FetcherProxy.Headers
        :param joker_host_port: value for the config's FetcherProxy.HostPort
        :param ydb_endpoint: value for the config's YDb.Endpoint
        :param ydb_database: value for the config's YDb.DataBase
        """
        super(BassServer, self).__init__(self.SERVNAME, server_mgr, working_dir)
        self._default_headers = default_headers
        self._joker_host_port = joker_host_port
        self._ydb_endpoint = ydb_endpoint
        self._ydb_database = ydb_database

    def _launch_cmd(self):
        """Write the patched config to the logs folder and return the bass_server argv.

        The patch routes fetches through the Joker proxy with the default headers,
        points YDb at the configured endpoint/database, sets the HTTP/search/setup
        thread counts to 20 and sets every ``Timeout`` key anywhere in the config to
        ``'30s'``.
        """
        with open(join(self._dir, 'bass_configs', 'localhost_config.json'), 'r') as f:
            config = json.load(f)

        config['FetcherProxy'] = {
            'HostPort': self._joker_host_port,
            'Headers': [{'Name': key, 'Value': value} for key, value in self._default_headers.items()],
        }
        config['YDb']['Endpoint'] = self._ydb_endpoint
        config['YDb']['DataBase'] = self._ydb_database
        config['HttpThreads'] = 20
        config['SearchThreads'] = 20
        config['SetupThreads'] = 20

        def set_json_value(obj, key, value):
            """Set ``key`` to ``value`` in every dict nested anywhere inside ``obj``."""
            if isinstance(obj, dict):
                if key in obj:
                    obj[key] = value
                for item in obj.values():
                    set_json_value(item, key, value)
            elif isinstance(obj, (list, tuple)):
                for item in obj:
                    set_json_value(item, key, value)

        set_json_value(config, 'Timeout', '30s')

        new_config_path = join(get_logs_folder(), 'new_bass.json')
        with open(new_config_path, 'w') as f:
            json.dump(config, f)

        cmd = [
            join(self._dir, 'bin', 'bass_server'),
            new_config_path,
            '--port', str(self._port),
            '-V', 'EventLogFile=' + self.rtlog_path(),
            '-V', 'ENV_GEOBASE_PATH={}'.format(join(self._dir, 'geodata6.bin')),
            '--logdir', get_logs_folder()
        ]
        return cmd

    def _status_cmd(self):
        check_status_cmd = 'curl {}ping'.format(self._url)
        return (check_status_cmd, b'pong')

    def rtlog_path(self):
        return join(get_logs_folder(), 'current-bass-rtlog')


class VinsServer(Server):
    """Local VINS started via run-vins.py, configured through environment variables."""
    SERVNAME = 'vins'

    def __init__(self, server_mgr, working_dir, default_headers):
        """
        :param default_headers: JSON-serializable headers, passed as VINS_PROXY_DEFAULT_HEADERS
        """
        super(VinsServer, self).__init__(self.SERVNAME, server_mgr, working_dir)
        self._default_headers = json.dumps(default_headers)

    def _launch_cmd(self):
        """Start vmtouch on the resources, fill the VINS env and return the argv."""
        self._vmtouch_resources()

        bass_url = self._server_mgr.get_url('bass')
        joker_url = self._server_mgr.get_url('joker')

        self._env['VINS_PROXY_URL'] = joker_url
        self._env['VINS_PROXY_SKIP'] = 'http://localhost'  # don't proxy all local services
        self._env['VINS_PROXY_DEFAULT_HEADERS'] = self._default_headers
        self._env['VINS_WORKERS_COUNT'] = str(10)
        # TODO(sparkle): if it will be a case, load this from queries
        # Fixed "now": epoch seconds of local midnight on 2020-02-20. time.mktime
        # yields the same number as strftime('%s'), which is a non-portable libc
        # extension not supported by every platform.
        self._env['VINS_NOW_TIMESTAMP'] = str(int(time.mktime(datetime.date(2020, 2, 20).timetuple())))

        self._env['VINS_DEV_BASS_API_URL'] = bass_url

        redis_port = self._server_mgr.get_port('redis')
        self._env['VINS_REDIS_PORT'] = str(redis_port)

        self._env['VINS_RTLOG_FILE'] = join(get_logs_folder(), 'current-vins-rtlog')
        self._env['VINS_LOG_FILE'] = join(get_logs_folder(), 'vins.push_client.out')

        cmd = [
            join(self._dir, 'run-vins.py'),
            '-p', str(self._port),
            '--conf-dir', 'cit_configs',

            # TODO(sparkle) load this (and other additional args) from files
            '--env', 'shooting-ground',
            '--component', 'speechkit-api-pa',
            '-L'
        ]
        return cmd

    def _vmtouch_resources(self):
        """Run the bundled vmtouch with ``-l -v -f`` over every entry in ``resources``.

        The process handle is not kept and stop() does not kill it.
        NOTE(review): with ``-l`` vmtouch presumably stays alive to hold the pages
        locked - confirm it is cleaned up elsewhere.
        """
        resources_path = join(self._dir, 'resources')
        vmtouch_bin = join(resources_path, 'vmtouch')
        res_list = [join(resources_path, res) for res in os.listdir(resources_path)]
        sp.Popen([vmtouch_bin, '-l', '-v', '-f'] + res_list, stderr=sp.STDOUT)

    def _status_cmd(self):
        check_status_cmd = 'curl {}ping'.format(self._url)
        return (check_status_cmd, b'Ok')

    def rtlog_path(self):
        return join(get_logs_folder(), 'current-vins-rtlog')


class MegamindServer(Server):
    """Local Megamind started with a patched copy of ``megamind_configs/dev``."""
    SERVNAME = 'megamind'

    def __init__(self, server_mgr, working_dir):
        super(MegamindServer, self).__init__(self.SERVNAME, server_mgr, working_dir)

    def _launch_cmd(self):
        """Prepare the patched configs and return the megamind_server argv."""
        vins_url = self._server_mgr.get_url('vins')
        bass_url = self._server_mgr.get_url('bass')
        joker_url = self._server_mgr.get_url('joker')
        sensors_port = self._server_mgr.get_port('sensors')

        main_config_path = self._prepare_configs()

        # NOTE(review): the proxy host/port are sliced out of the Joker URL by its last
        # ':'. For a URL like "http://host:port/" the host keeps the scheme and the
        # port assumes a trailing slash - confirm that is what megamind expects.
        colon_pos = joker_url.rfind(':')
        cmd = [
            join(self._dir, 'bin', 'megamind_server'),
            '-c', main_config_path,
            '-p', str(self._port),
            '--via-proxy-host', joker_url[:colon_pos],  # host from "host:port/"
            '--via-proxy-port', joker_url[(colon_pos + 1):-1]  # port from "host:port/"
        ]
        cmd += [
            '--mon-service-port', str(sensors_port),
            '--service-sources-vins-url', vins_url,
            '--scenarios-sources-bass-url', bass_url,
            '--scenarios-sources-bass-apply-url', urljoin(bass_url, '/megamind/apply'),
            '--scenarios-sources-bass-run-url', urljoin(bass_url, '/megamind/prepare'),
            '--rtlog-filename', self.rtlog_path(),
            '--vins-like-log-file', join(get_logs_folder(), 'current-megamind-vins-like-log'),
            '--geobase-path', join(self._dir, 'geodata6.bin'),
        ]
        return cmd

    def _prepare_configs(self):
        """Build ``megamind_configs/changed_dev`` and return its main config path.

        Copies ``dev`` to ``changed_dev`` (replacing a copy left by an earlier launch),
        appends "00" to every ``TimeoutMs``/``RetryPeriodMs`` value in the main config
        (x100 for plain integers), points local BASS/VINS addresses at the reserved
        ports in the main and scenario configs, and saves a copy to the logs folder.
        """
        configs_root = join(self._dir, 'megamind_configs')
        changed_configs_path = join(configs_root, 'changed_dev')
        if os.path.exists(changed_configs_path):
            # copytree refuses to write into an existing directory
            shutil.rmtree(changed_configs_path)
        shutil.copytree(join(configs_root, 'dev'), changed_configs_path)

        bass_port = self._server_mgr.get_port('bass')
        vins_port = self._server_mgr.get_port('vins')

        main_config_path = join(changed_configs_path, 'megamind.pb.txt')
        with open(main_config_path, 'r') as f:
            config = f.read()
        config = re.sub(r'TimeoutMs:\s*(\S+)', r'TimeoutMs: \g<1>00', config)
        config = re.sub(r'RetryPeriodMs:\s*(\S+)', r'RetryPeriodMs: \g<1>00', config)
        config = self._retarget_local_ports(config, bass_port, vins_port)
        with open(main_config_path, 'w') as f:
            f.write(config)

        for root, _, files in os.walk(join(changed_configs_path, 'scenarios')):
            for file_name in files:
                file_path = join(root, file_name)
                with open(file_path, 'r') as f:
                    data = f.read()
                with open(file_path, 'w') as f:
                    f.write(self._retarget_local_ports(data, bass_port, vins_port))

        self._save_configs_copy(changed_configs_path)
        return main_config_path

    @staticmethod
    def _retarget_local_ports(text, bass_port, vins_port):
        """Rewrite ``localhost:86`` to the BASS port and ``localhost:84`` to the VINS port."""
        # (?!\d) stops longer ports from matching, e.g. "localhost:8443", or a
        # just-inserted BASS port such as 8412 being rewritten as if it were 84.
        text = re.sub(r'localhost:86(?!\d)', 'localhost:{}'.format(bass_port), text)
        text = re.sub(r'localhost:84(?!\d)', 'localhost:{}'.format(vins_port), text)
        return text

    @staticmethod
    def _save_configs_copy(configs_path):
        """Copy ``configs_path`` into the logs folder as the first unused ``changed_dev_<N>``."""
        logs_folder = get_logs_folder()
        num = 0
        saved_path = join(logs_folder, 'changed_dev_{}'.format(num))
        while os.path.exists(saved_path):
            num += 1
            saved_path = join(logs_folder, 'changed_dev_{}'.format(num))
        shutil.copytree(configs_path, saved_path)

    def _status_cmd(self):
        check_status_cmd = 'curl {}ping'.format(self._url)
        return (check_status_cmd, b'pong')

    def rtlog_path(self):
        return join(get_logs_folder(), 'current-megamind-rtlog')


class JokerRemoteServer(RemoteServer):
    """Remote Joker proxy: pinged via its admin endpoint, then a session is started."""
    SERVNAME = 'joker'

    def __init__(self, server_mgr, url, session_id, settings):
        """
        :param url: base URL of the running Joker (``admin`` and ``session`` are appended)
        :param session_id: id of the Joker session to start
        :param settings: mapping of session settings sent as ``key=value`` query parameters
        """
        super(JokerRemoteServer, self).__init__(self.SERVNAME, server_mgr, url)
        self._session_id = session_id
        self._settings = settings

    def _status_cmd(self):
        check_status_cmd = 'curl {}admin?action=ping'.format(self._url)
        return (check_status_cmd, b'pong')

    def _after_start(self):
        """Start the Joker session once the remote answered its ping."""
        self._start_session()

    def _start_session(self):
        """Request ``<url>session?<key>=<value>&...&id=<session_id>`` and log the reply.

        Setting values are inserted verbatim (not URL-encoded).
        """
        logging.info('start Joker session (it may exist already)')

        # format 'aa': 'bb' -> 'aa=bb'
        settings_strings = ['{}={}'.format(key, value) for key, value in self._settings.items()]
        settings_strings.append('id={}'.format(self._session_id))
        session_url = '{}session?{}'.format(self._url, '&'.join(settings_strings))

        # argv instead of a shell string: a quote in a setting value would otherwise
        # break (or inject into) the shell command line.
        ret = sp.check_output(['curl', session_url])
        logging.info('starting session "%s" response: %s', self._session_id, ret)
