# -*- coding: utf-8 -*-

import os
import time
import copy
import logging
import requests
import sandbox.sandboxsdk.paths as paths
import sandbox.sandboxsdk.channel as sdk_channel
from threading import Thread, Condition
from sandbox.projects import resource_types as rt
from sandbox.projects.common.fusion.runner import FusionRunner
from sandbox.projects.common.fusion.task import FusionTestTask
from sandbox.projects.common.fusion.task import FusionParamsDescription
from sandbox.projects.common.fusion.distributor import DefaultDistributorParams as DistributorParams
from sandbox.projects.common.search.components import get_fusion_search, FusionSearch, DefaultFusionParams as FusionParams
from sandbox.projects.common.search.components import FUSION_DB_SOURCE_TITLE, DEFAULT_BASESEARCH_PORT
from sandbox.sandboxsdk.parameters import SandboxUrlParameter, SandboxIntegerParameter, SandboxBoolParameter
from sandbox.sandboxsdk.errors import SandboxTaskFailureError
from sandbox.common.types.task import LogName


class RunRefreshParams(object):
    """
        Sandbox input parameters of the refresh task.

        `params` holds the task's own parameters; all_params() merges them with the
        inherited Fusion and distributor parameters.
    """

    class DistributorsConfig(SandboxUrlParameter):
        group = 'Refresh params'
        name = 'runrefresh_distributors_config_url'
        description = 'distributors.conf'
        required = False

    class ShardId(SandboxIntegerParameter):
        group = 'Refresh params'
        name = 'runrefresh_shardid'
        description = 'Shard id (leave empty to run all shards)'
        required = False

    class ShardsNum(SandboxIntegerParameter):
        group = 'Refresh params'
        name = 'runrefresh_shardnum'
        description = 'Number of shards'
        required = False

    class SaveDb(SandboxBoolParameter):
        group = 'Refresh params'
        name = 'runrefresh_savedb'
        description = 'Save generated index as a resource'
        default_value = True
        required = False

    # NOTE(review): ShardSave is not listed in `params` below, so it is not exposed as
    # a task input; confirm whether that is intentional.
    class ShardSave(SandboxIntegerParameter):
        group = 'Refresh params'
        name = 'runrefresh_shardsave'
        description = 'Shard to save'
        required = False

    class SingleHost(SandboxBoolParameter):
        group = 'Refresh params'
        name = 'runrefresh_singlehost'
        description = 'Run all instances on the same host'
        default_value = False
        required = False

    class IdleTime(SandboxIntegerParameter):
        group = 'Refresh params'
        name = 'runrefresh_idletime'
        description = 'Extra sleep'
        default_value = 0
        required = False

    # Shares its name with an inherited parameter; all_params() keeps only this one.
    class MaxDocs(SandboxIntegerParameter):
        group = FUSION_DB_SOURCE_TITLE
        name = "db_max_docs"
        description = 'Required number of documents'  # description is different from 'db_max_docs' in TestFusion
        default_value = 20000

    params = (DistributorsConfig, ShardId, ShardsNum, IdleTime, SaveDb, SingleHost, MaxDocs, )
    _inherited_params = FusionParams.params + DistributorParams.params

    @classmethod
    def all_params(cls):
        """
            Return the task's own and inherited parameters as one tuple.

            The parameters are sorted by `group` in descending order (stable sort), and
            the first parameter named "db_max_docs" other than RunRefreshParams.MaxDocs
            is dropped so that MaxDocs replaces it.
        """
        merged = list(cls.params)
        merged.extend(cls._inherited_params)
        merged.sort(key=lambda param: param.group, reverse=True)

        shadowed = next(
            (param for param in merged
             if param.name == "db_max_docs" and param is not RunRefreshParams.MaxDocs),
            None,
        )
        if shadowed is not None:
            merged.remove(shadowed)
        return tuple(merged)


def fetch_http_file(url, path):
    """
        Download `url` into the local file `path`.

        Any failure (connection error, HTTP error status, file write error) is
        re-raised as SandboxTaskFailureError.
    """
    try:
        # NOTE(review): TLS certificate verification is disabled; presumably intentional
        # for internal hosts -- confirm before using this with untrusted URLs.
        response = requests.get(url, verify=False, stream=True, timeout=10)
        try:
            response.raise_for_status()  # Throws on http error
            with open(path, 'wb') as fd:
                for chunk in response.iter_content(64 * 1024):
                    fd.write(chunk)
        finally:
            # A streamed response keeps its connection until the body is fully read or
            # the response is closed; close it so error paths do not leak it.
            response.close()

    except Exception as e:
        raise SandboxTaskFailureError(e)


class RunRefreshInstance(FusionTestTask):
    """
        Runs a RTYServer instance, which is controlled through remote commands (shutdown etc)
    """
    type = 'RUN_REFRESH_INSTANCE_00'

    input_parameters = RunRefreshParams.all_params()

    execution_space = 30 * 1024

    def __init__(self, *args, **kwargs):
        FusionTestTask.__init__(self, *args, **kwargs)
        # Last error reported by a shard thread in single-host cluster mode; a non-empty
        # value fails the task after all threads have been joined.
        self._last_error = ""

    #
    # Working with parameters
    #
    def use_external_distributors(self):
        """Return True when a distributors.conf URL was supplied."""
        return bool(self.ctx.get(RunRefreshParams.DistributorsConfig.name))

    def get_shard_id(self, default=0):
        """Return the configured shard id as int, or `default` when it is unset/None."""
        vl = self.ctx.get(RunRefreshParams.ShardId.name, default)
        return int(vl if vl is not None else default)

    def get_shards_num(self):
        """Return the number of shards; when unset, one more than the shard id."""
        vl = self.ctx.get(RunRefreshParams.ShardsNum.name)
        if vl is None:
            return self.get_shard_id() + 1
        return int(vl)

    def get_shard_save(self):
        """Return the id of the shard whose index is saved (0 when unset)."""
        vl = self.ctx.get(RunRefreshParams.ShardSave.name)
        return vl if vl is not None else 0

    def get_save_db(self):
        """Return True when the generated index should be saved as a resource."""
        vl = self.ctx.get(RunRefreshParams.SaveDb.name, False)
        return bool(vl)

    def get_single_host_opt(self):
        """Return True when all shard instances should run inside this one task."""
        vl = self.ctx.get(RunRefreshParams.SingleHost.name, False)
        return bool(vl)

    def is_cluster_mode(self):
        """Return True when more than one shard is requested and no shard id is set."""
        shards_num = self.get_shards_num()
        shard_id = self.get_shard_id(-1)
        return shards_num > 1 and shard_id < 0

    def get_duration(self):
        """Return the 'Extra sleep' value passed to FusionRunner.run() (0 when unset)."""
        duration = self.ctx.get(RunRefreshParams.IdleTime.name, 0)
        return int(duration or 0)

    # Distance between the ports of neighbouring shards.
    _shard_port_shift = 4

    def get_port_for_shard(self, shard_id):
        """Return the search port of the instance serving `shard_id`."""
        return DEFAULT_BASESEARCH_PORT + self._shard_port_shift * (shard_id + 2)  # "+2" is a trick to avoid firewall blocking(?) of port 17175

    @staticmethod
    def get_full_host(host):
        """
            Return a fully-qualified host name for logging.

            Empty or None hosts yield "<unknown>"; names without a dot get
            ".search.yandex.net" appended.
        """
        # Check for an empty/None host first: `'.' in None` would raise TypeError.
        if not host:
            return "<unknown>"
        if '.' in host:
            return host
        return host + ".search.yandex.net"

    #
    # Helper routines
    #
    def create_distributors_conf(self, path):
        """
            Write distributors.conf to `path`.

            With an external config URL the file is downloaded (this cannot be combined
            with the local distributor). Otherwise a localhost config is written when
            use_distributor() is true, and an empty file when it is not.
        """
        if self.use_external_distributors():
            if self.use_distributor():
                raise SandboxTaskFailureError("You can't use both '{}' and '{}' parameters at the same time"
                                              .format(DistributorParams.EnableDistributor.description,
                                                      RunRefreshParams.DistributorsConfig.description))
            url = self.ctx[RunRefreshParams.DistributorsConfig.name]
            fetch_http_file(url, path)

        else:
            # The ${...} placeholders are presumably expanded by the server's config
            # templating -- confirm.
            local_dist_conf = """
                <Replica>
                    Servers: localhost:${DISTRIBUTOR_PORT and DISTRIBUTOR_PORT or 20000}
                    Id: 0
                    Priority: 0
                </Replica>
                DistributorServers : localhost:${DISTRIBUTOR_PORT and DISTRIBUTOR_PORT or 20000}\n"""
            with open(path, 'wb') as fd:
                if self.use_distributor():
                    fd.write(local_dist_conf)

    def create_db_resource(self, index_dir):
        """Register `index_dir` as an RTYSERVER_INDEX_DIR resource and return it."""
        resource = self.create_resource(
            self.descr,
            index_dir,
            rt.RTYSERVER_INDEX_DIR,
            attributes={"test_fusion": "yes"}
        )
        return resource

    def _run_mem_profiler(self):
        pass  # Workaround for multiple profilers in a task

    class StartCond(object):
        """
            Boolean flag guarded by a condition variable.

            Lets the spawning thread wait until a shard thread reports that its
            instance has been created.
        """
        def __init__(self, state=False):
            self._cv = Condition()
            self.state = state

        def __enter__(self):
            self._cv.acquire()

        def __exit__(self, exc_type, exc_value, exc_tb):
            self._cv.release()

        def wait(self, timeout):
            # Must be called while holding the condition (inside `with cond:`).
            # NOTE: `timeout` bounds each individual wake-up only; the loop keeps
            # waiting until `state` becomes truthy, so the total wait is unbounded.
            while not self.state:
                self._cv.wait(timeout)

        def set(self, state):
            self._cv.acquire()
            try:
                if self.state != state:
                    self.state = state
                    self._cv.notifyAll()
            finally:
                self._cv.release()

    def get_fusion(self, shard_id, port, get_db, max_mem_documents=None, default_wait=None, run=True, event_log=False):
        """
            Create the FusionSearch component for `shard_id` listening on `port`.

            `run` is accepted for interface compatibility but is not used here.
        """
        # Replaces the profiler hook on the FusionSearch class itself (affects every
        # instance) with this task's no-op.
        FusionSearch.run_mem_profiler = self._run_mem_profiler

        fusion = get_fusion_search(get_db=get_db,
                                   shard=shard_id,
                                   shards_count=self.get_shards_num(),
                                   max_documents=max_mem_documents,
                                   default_wait=default_wait,
                                   event_log=event_log,
                                   port=port)

        return fusion

    #
    # Main (single instance) routine
    #

    def run_instance(self, shard_id, save_db, signal=None):
        """
            Start one instance for `shard_id`, run it and optionally save its index.

            `signal` (a StartCond), when given, is set right after the FusionSearch
            component is created so the spawning thread can continue.
        """
        max_docs = self.get_max_docs()
        if self.use_external_distributors():
            # Clamp the limits when indexing from external distributors.
            max_mem_documents = min(max(max_docs, 100), 30000)
            max_disk_documents = min(max(max_docs * 2, 1000), 50000)
        else:
            max_mem_documents = max_docs
            max_disk_documents = 100000

        # Process-global settings, presumably read by the server config templates;
        # every shard thread writes the same values.
        os.environ["DAYS"] = str(3)
        os.environ["LOCATION"] = "MSK"
        os.environ["TimeToLiveSec"] = str(300)
        os.environ["MaxDiskDocuments"] = str(max_disk_documents)

        get_db = self.is_db_resource_required()
        fusion = self.get_fusion(shard_id=shard_id,
                                 port=self.get_port_for_shard(shard_id),
                                 get_db=get_db,
                                 max_mem_documents=max_mem_documents)
        if signal is not None:
            # Let the spawning thread start the next shard.
            signal.set(True)

        if get_db:
            self.init_fusion_with_db(fusion)  # Waits until the db is loaded
        else:
            self.init_fusion_empty(fusion)  # Waits until server is started

        runner = FusionRunner(fusion,
                              self.use_distributor(),
                              self.use_external_distributors(),
                              shard_id,
                              max_docs,
                              FusionRunner.ShardBuilderParams())

        with runner:
            runner.run(self.get_duration())

            if save_db:
                db_path = runner.detach_db()
                self.create_db_resource(db_path)

    def run_instance_routine(self, shard_id, save_db, signal):
        """
            Thread body for single-host cluster mode: run_instance() with up to 3 attempts.

            `signal` is set on the first failure so the spawning thread is never left
            blocked. After the last failed attempt the error is stored in
            self._last_error, which is checked once all threads are joined.
        """
        attempts = 3
        while attempts:
            try:
                logging.info("[%r] Starting thread (retries left: %r)", shard_id, attempts)
                self.run_instance(shard_id, save_db, signal)
                return
            except Exception as ex:
                # Release the calling thread anyways. The signal is dropped afterwards,
                # so a failing retry must not dereference it again.
                if signal is not None:
                    signal.set(True)
                    signal = None
                attempts -= 1
                logging.exception("instance routine {} failed with an exception".format(shard_id))
                if not attempts:
                    self.set_info("Unhandled exception: {0}".format(str(ex)))
                    self._last_error = str(ex)
                else:
                    logging.info("instance routine {} will be restarted".format(shard_id))
                    time.sleep(30)

    def on_execute_instance(self):
        """
            Run shard instance(s) inside this task.

            Prepares the config directory and distributors.conf, then runs either one
            shard in the current thread or, in single-host cluster mode, one thread per
            shard. The distributor from get_distributor() is stopped even on failure.
        """
        distributor = None
        if self.use_distributor():
            distributor = self.get_distributor()

        try:
            extra_cfg_dir = self.abs_path("config")
            paths.make_folder(extra_cfg_dir)
            os.environ["CONFIG_PATH"] = extra_cfg_dir

            distributors_file = os.path.join(extra_cfg_dir, "distributors.conf")
            self.create_distributors_conf(distributors_file)
            os.environ["DistributorsConfig"] = distributors_file

            if not self.is_cluster_mode():
                self.run_instance(self.get_shard_id(), self.get_save_db())
            else:
                self._run_shard_threads()
        finally:
            if distributor is not None:
                distributor.stop()

    def _run_shard_threads(self):
        """Run every shard in its own thread, started one after another, and join them all."""
        save_db = self.get_save_db()
        shard_to_save = self.get_shard_save()

        threads = []
        for shard_id in reversed(xrange(self.get_shards_num())):
            cv = self.StartCond()
            t = Thread(target=self.run_instance_routine,
                       args=(shard_id, save_db and shard_id == shard_to_save, cv))
            t.refresh_shard = shard_id
            t.refresh_cv = cv
            threads.append(t)

        # Each thread is started only after the previous one has signalled that its
        # instance was created (or that it failed).
        for t in threads:
            t.start()
            with t.refresh_cv:
                t.refresh_cv.wait(62)

        for t in threads:
            t.join()

        if self._last_error:
            raise SandboxTaskFailureError(self._last_error)

    #
    # Cluster routines
    #
    @staticmethod
    def get_task_log(task_id):
        """
            Gets a log for an incomplete task (differs from the base class method)

            Returns the URL of the common log inside the first TASK_LOGS resource
            returned for `task_id`, or None when it cannot be obtained (yet).
        """
        try:
            current_log = sdk_channel.channel.rest.server.resource.read(
                task_id=task_id,
                type=str(rt.TASK_LOGS),
                limit=1
            )["items"][0]
            return '/'.join((current_log["http"]["proxy"], LogName.COMMON))
        except Exception:
            # Presumably the resource does not exist yet; callers poll again later.
            return None

    @staticmethod
    def grep_log(url, substring):
        """Return True if the text at `url` contains `substring`; False if it cannot be fetched."""
        try:
            log = requests.get(url, verify=False, timeout=5)
        except Exception:
            return False
        return substring in log.text

    def log_task_started(self, task):
        """Log the shard id, task id and host:port of a started subtask."""
        srv_task = sdk_channel.channel.sandbox.server.get_task(task.id)
        host = srv_task.get('host')
        logging.info("Started {} ({}) # {}:{} ".format(task.inst_shard_id, task.id, self.get_full_host(host), task.inst_port))

    def wait_instances_ready(self, tasks, timeout, log_message):
        """
            Poll subtask logs every 30 seconds until each contains `log_message`.

            Gives up once `timeout` seconds have elapsed. Returns True if every
            subtask became ready in time.
        """
        start_time = time.time()
        waiting_list = [{'task': t, 'log': self.get_task_log(t.id)} for t in tasks]
        while waiting_list and (time.time() - start_time) <= timeout:
            still_waiting = []
            for w in waiting_list:
                ready = False
                if not w['log']:
                    url = self.get_task_log(w['task'].id)
                    if url:
                        self.log_task_started(w['task'])
                        logging.debug("subtask {} log is {}".format(w['task'].id, url))
                    w['log'] = url

                if w['log']:
                    ready = self.grep_log(w['log'], log_message)
                if ready:
                    logging.info("Subtask {} is ready".format(w['task'].id))
                else:
                    still_waiting.append(w)
            waiting_list = still_waiting
            if waiting_list:
                time.sleep(30)

        return not waiting_list

    def on_execute_cluster(self):
        """
             Executes multiple instances as the subtasks

             Creates one subtask per shard (only shard get_shard_save() may save its
             index), waits up to 12 hours for all of them to report readiness, logs
             where each one runs, then waits for all of them to finish.
        """
        save_db = self.get_save_db()
        shard_to_save = self.get_shard_save()

        instances = []
        for shard_id in xrange(self.get_shards_num()):
            sub_ctx = copy.deepcopy(self.ctx)
            sub_ctx["notify_via"] = ""
            sub_ctx[RunRefreshParams.ShardId.name] = shard_id
            sub_ctx[RunRefreshParams.SaveDb.name] = save_db if shard_id == shard_to_save else False

            descr = "Shard {} for task {}({}) '{}'".format(shard_id, self.id, self.type, self.descr)

            subtask = self.create_subtask(
                task_type=RunRefreshInstance.type,
                description=descr,
                input_parameters=sub_ctx,
                host=None,  # no host pinning
                arch=self.arch,
                important=self.important
            )

            # Kept on the subtask object for logging only.
            subtask.inst_shard_id = shard_id
            subtask.inst_port = self.get_port_for_shard(shard_id)
            instances.append(subtask)

        logging.info("Waiting for the instances to start")
        started = self.wait_instances_ready(instances, 12 * 3600, FusionRunner.REFRESH_IS_READY_MSG)
        if started:
            logging.info("The child tasks have started successfully")
        else:
            logging.error("Some child tasks are not ready, timed out")

        for subtask in instances:
            srv_task = sdk_channel.channel.sandbox.server.get_task(subtask.id)
            host = self.get_full_host(srv_task.get('host'))
            logging.info("Instance %s is running at %s:%s", subtask.inst_shard_id, host, subtask.inst_port)

        logging.info("Waiting for the instances to finish")
        self.wait_tasks(instances, tuple(self.Status.Group.FINISH) + tuple(self.Status.Group.BREAK), True)

        logging.info("All child instances are gone. Finishing.")

    def on_execute(self):
        """Run shards as subtasks in cluster mode (unless single-host), otherwise in-process."""
        if self.is_cluster_mode() and not self.get_single_host_opt():
            self.on_execute_cluster()
        else:
            self.on_execute_instance()


# Extend the task docstring with the shared Fusion parameter description.
RunRefreshInstance.__doc__ = RunRefreshInstance.__doc__ + FusionParamsDescription

# Module-level task class, presumably the name the Sandbox loader looks up -- confirm.
__Task__ = RunRefreshInstance
