import sys
import socket
import random
import logging
import contextlib

from sandbox import common
import sandbox.common.types.misc as ctm

import sandbox.common.joint.client as joint_client
import sandbox.common.joint.errors as joint_errors
import sandbox.common.joint.socket as joint_socket

import sandbox.serviceq.types as qtypes
import sandbox.serviceq.config as qconfig
import sandbox.serviceq.errors as qerrors


class Client(object):
    """
    Service Q client

    Most public methods take a `secondary` argument that selects the server the call goes to:

    * True  -- a secondary replica chosen by `_secondary_rpc`; the remote method is called as "secondary_<name>";
    * False -- the primary server; redirects reported by the server (QRedirect) are followed;
    * None  -- the local server ("localhost" on the configured port); also called as "secondary_<name>".
    """
    # how many QRedirect responses a single call follows before giving up with QRetry
    REDIRECTS_LIMIT = 1
    RPCClient = joint_client.RPCClient

    class MethodType(common.utils.Enum):
        """ How the result of a remote call is consumed (see `Client._rpc_call`) """
        SIMPLE = None  # wait for a single result
        GENERATOR = None  # return the iterator over streamed results
        DUPGENERATOR = None  # return a `Client.Generator` wrapper supporting `next`/`send`

    class Generator(object):
        """
        Wrapper of duplex generators for support of redirects

        Wraps a call object exposing `.generator` and `.wait()`. When the underlying generator raises
        QRedirect, `redirect` is invoked (from inside the except block, as `Client._redirect` requires)
        and QRetry is raised so the caller can retry against the new PRIMARY.
        """

        def __init__(self, call, redirect):
            self.__call = call
            self.__generator = call.generator
            self.__redirect = redirect

        def __iter__(self):
            return self

        def next(self):
            try:
                return self.__generator.next()
            except qerrors.QRedirect:
                self.__redirect()
                raise qerrors.QRetry("Redirect to new PRIMARY")

        # Python 3 spelling of the iterator protocol; same behaviour as `next`
        __next__ = next

        def send(self, value):
            """ Send `value` into the generator; returns None (instead of raising) if it is exhausted """
            try:
                return self.__generator.send(value)
            except qerrors.QRedirect:
                self.__redirect()
                raise qerrors.QRetry("Redirect to new PRIMARY")
            except StopIteration:
                pass

        def wait(self, timeout=None):
            return self.__call.wait(timeout=timeout)

    def __init__(self, config=None, log_level=None):
        """
        Constructor

        :param config: settings object (defaults to `qconfig.Registry()`)
        :param log_level: if given, level set on the "qclient" logger
        """
        self._config = config or qconfig.Registry()
        self._server_port = self._config.serviceq.server.server.port
        # (host, port) of the primary learned from redirects; None means "not known yet"
        self._primary_server_addr = None
        self._logger = logging.getLogger("qclient")
        if log_level is not None:
            self._logger.setLevel(log_level)
        self.__rpc_config = self._config.serviceq.joint.rpc
        self.__retry_sleep = self._config.serviceq.client.retry.sleep
        self.__timeout = self._config.serviceq.client.timeout

    def _rpc_client(self, server_addr=None):
        """ Create and connect an RPC client to `server_addr` ((host, port)), or to localhost if it is not given """
        client = self.RPCClient(
            self.__rpc_config,
            *(server_addr if server_addr else ("localhost", self._server_port))
        )
        client.connect()
        return client

    @common.utils.singleton_property
    def primary_address(self):
        """
        "host:port" of the primary server

        Issues a `status` call to the primary so any redirect is followed first; if no primary address
        was learned, the current host name with the configured port is used.
        """
        self.status(secondary=False)
        addr = self._primary_server_addr
        if addr is None:
            addr = (socket.gethostname(), self._server_port)
        return ":".join(map(str, addr))

    @common.utils.singleton_property
    def _local_rpc(self):
        """ Connection to the server on localhost """
        return self._rpc_client()

    @common.utils.singleton_property
    def _secondary_rpc(self):
        """
        Connection to a secondary replica

        Asks the primary for "replication_info", sorts replicas by the reported value in descending order
        and picks a random one from the top half (plus one). Falls back to the local server when
        the primary call fails with CallError, when no replication info is available, or when
        redirects are exhausted.
        """
        replication_info = None
        for _ in range(self.REDIRECTS_LIMIT + 1):
            try:
                replication_info = sorted(
                    self._primary_rpc.call("replication_info").wait(timeout=self.__timeout).items(),
                    key=lambda _: _[1],
                    reverse=True
                )
                break
            except qerrors.QRedirect:
                self._redirect()
            except joint_errors.CallError:
                return self._rpc_client()
        if not replication_info:
            return self._rpc_client()
        secondary_addr = self.__parse_addr(random.choice(replication_info[:len(replication_info) // 2 + 1])[0])
        return self._rpc_client(secondary_addr)

    @common.utils.singleton_property
    def _primary_rpc(self):
        """ Connection to the known primary, or to localhost while the primary address is unknown """
        return self._rpc_client(self._primary_server_addr)

    @staticmethod
    def __parse_addr(addr):
        """ Split "host:port" into a (host, int(port)) tuple """
        host, _, port = addr.partition(":")
        return host, int(port)

    def _redirect(self):
        """
        Switch to the primary named by the QRedirect exception currently being handled

        The exception is read via `sys.exc_info()`, so this must be called from inside an
        `except qerrors.QRedirect` block. The cached primary connection is reset
        (presumably re-created on next access by `singleton_property`).
        """
        ex = sys.exc_info()[1]
        self._logger.info("Received redirect to %s", ex.message)
        self._primary_server_addr = self.__parse_addr(ex.message)
        del self._primary_rpc

    def _rpc_call(self, method_name, method_type, secondary, *args, **kws):
        """
        Call remote method `method_name` on the server selected by `secondary` (see the class docstring)

        :param method_type: `MethodType` value defining what is returned:
            SIMPLE -- the waited result; GENERATOR -- the result iterator; DUPGENERATOR -- a `Generator` wrapper
        :param kws: forwarded to the remote method, except `timeout` which is consumed here
            (defaults to the client timeout from config)
        :raises QRetry: when the connection cannot be established or redirects limit is exceeded
        The transport errors listed below are logged and re-raised; for primary calls (`secondary is False`)
        every error except ServerError also resets the known primary address and connection.
        """
        timeout = kws.pop("timeout", self.__timeout)
        if secondary or secondary is None:
            method_name = "secondary_{}".format(method_name)
        for _ in range(self.REDIRECTS_LIMIT + 1):
            # reset per attempt, so errors are never logged with the sid/jid of a previous attempt
            ret = None
            try:
                try:
                    rpc = (
                        self._secondary_rpc
                        if secondary else (
                            self._local_rpc
                            if secondary is None else
                            self._primary_rpc
                        )
                    )
                except (socket.error, joint_errors.HandshakeTimeout) as ex:
                    raise qerrors.QRetry(str(ex))

                ret = rpc.call(method_name, *args, **kws)
                self._logger.info("{%s:%s} Calling %s method %r", ret.sid, ret.jid, method_type, method_name)
                if method_type == self.MethodType.SIMPLE:
                    return ret.wait(timeout=timeout)
                elif method_type == self.MethodType.GENERATOR:
                    return ret.iter(timeout=timeout)
                elif method_type == self.MethodType.DUPGENERATOR:
                    return self.Generator(ret.iter(timeout=timeout), self._redirect)
                else:
                    raise qerrors.QException(
                        "Unknown method type '{}' for method '{}'".format(method_type, method_name)
                    )
            except qerrors.QRedirect:
                self._redirect()
            except (
                qerrors.QRetry, joint_errors.RPCError, joint_errors.ServerError, joint_socket.EOF, socket.error
            ) as ex:
                self._logger.error(
                    "{%s.%s} Error calling %s method %s: %s",
                    ret and ret.sid, ret and ret.jid, method_type, method_name, ex
                )
                if not isinstance(ex, joint_errors.ServerError) and secondary is False:
                    # the primary may have changed or died: forget it so the next call reconnects
                    self._primary_server_addr = None
                    del self._primary_rpc
                raise
        raise qerrors.QRetry("Number of redirects is exceeded ({})".format(self.REDIRECTS_LIMIT))

    def get_hosts(self, secondary=True):
        return qtypes.IndexedList.decode(self._rpc_call("get_hosts", self.MethodType.SIMPLE, secondary))

    def add_hosts(self, hosts):
        return qtypes.IndexedList.decode(self._rpc_call("add_hosts", self.MethodType.SIMPLE, False, hosts))

    def push(self, task_id, priority, hosts, task_info=None, score=None):
        """ Push a task to the queue; list/tuple `hosts` are sorted by their first element, descending """
        if isinstance(hosts, (list, tuple)):
            hosts = sorted((h for h in hosts), key=lambda _: -_[0])
        return self._rpc_call(
            "push", self.MethodType.SIMPLE, False, task_id, priority, hosts,
            task_info=task_info, score=score
        )

    def sync(self, data, reset=False):
        """
        Synchronize the queue with `data`

        Each item of `data` is (task_id, priority, hosts[, task_info]); missing task_info becomes None,
        non-None `hosts` are sorted by their first element, descending.
        """
        return self._rpc_call(
            "sync", self.MethodType.SIMPLE, False,
            tuple(
                (
                    task_id, priority,
                    sorted((h for h in hosts), key=lambda _: -_[0]) if hosts is not None else None,
                    task_info
                )
                for task_id, priority, hosts, task_info in (
                    (list(item) + [None])[:4]
                    for item in data
                )
            ),
            reset=reset
        )

    @staticmethod
    def _decode_host_scores(hosts, host_scores, common_score):
        """
        Build `TaskQueueHostsItem` list from packed host scores

        Without `common_score`, `host_scores` is a sequence of (score, host index) pairs; with it,
        `host_scores` is either a list of host indexes or a bit set decoded by `hosts.indexes_from_bits`.
        """
        if common_score is None:
            task_hosts = [
                qtypes.TaskQueueHostsItem(score=score, host=hosts[host])
                for score, host in host_scores
            ]
        else:
            task_hosts = [
                qtypes.TaskQueueHostsItem(score=common_score, host=hosts[host])
                for host in (
                    host_scores
                    if isinstance(host_scores, list) else
                    hosts.indexes_from_bits(host_scores)
                )
            ]
        return task_hosts

    @classmethod
    def _unpack_queue(cls, data):
        """
        Decode the packed result of the "queue" call into a list of `TaskQueueItem`

        `data` is consumed as a stream: queue items up to a None sentinel, then the four lookup tables
        (hosts, owners, task_types, client_tags) — presumably an iterator as returned for GENERATOR calls.
        Returns [] if the sentinel never appears.
        """
        items = []
        for item in data:
            if item is None:
                break
            items.append(item)
        else:
            return []
        hosts, owners, task_types, client_tags = data
        hosts = qtypes.IndexedList.decode(hosts)
        queue = []
        for (
            task_id, priority, host_scores, task_ref,
            (_, _, task_type, owner, enqueue_time, duration, tags),
            common_score
        ) in items:
            queue.append(
                qtypes.TaskQueueItem(
                    task_id=task_id,
                    priority=priority,
                    hosts=cls._decode_host_scores(hosts, host_scores, common_score),
                    task_ref=None,
                    task_info=qtypes.TaskInfo(
                        type=task_types[task_type],
                        owner=owners[owner],
                        enqueue_time=enqueue_time,
                        duration=duration,
                        client_tags=client_tags[tags] if tags is not None else None
                    ),
                    score=common_score,
                )
            )
        return queue

    @staticmethod
    def _unpack_queue_by_host(data):
        """ (task_id, priority, -score) for each (priority, score, task_id, task_ref) entry with a truthy task_ref """
        return [(task_id, priority, -score) for priority, score, task_id, task_ref in data if task_ref]

    @classmethod
    def _unpack_queue_by_task(cls, data):
        """ Return (priority, hosts) for the task, or (None, []) if it is absent or has no task_ref """
        hosts, queue = data
        if hosts and queue:
            task_id, priority, host_scores, task_ref, _, common_score = queue
            if task_ref:
                return priority, cls._decode_host_scores(qtypes.IndexedList.decode(hosts), host_scores, common_score)
        return None, []

    def push_api_quota(self, delta):
        return self._rpc_call("push_api_quota", self.MethodType.SIMPLE, False, delta)

    def acquire_resource_lock(self, resource_id, host):
        return self._rpc_call("acquire_resource_lock", self.MethodType.SIMPLE, False, resource_id, host)

    def release_resource_lock(self, resource_id, host):
        return self._rpc_call("release_resource_lock", self.MethodType.SIMPLE, False, resource_id, host)

    def reset_api_consumption(self):
        return self._rpc_call("reset_api_consumption", self.MethodType.SIMPLE, False)

    def queue(self, raw=False, secondary=False):
        data = self._rpc_call("queue", self.MethodType.GENERATOR, secondary)
        return data if raw else self._unpack_queue(data)

    def queue_by_host(self, host, pool=None, raw=False, secondary=False):
        data = self._rpc_call("queue_by_host", self.MethodType.SIMPLE, secondary, host, pool)
        return data if raw else self._unpack_queue_by_host(data)

    def queue_by_task(self, task_id, raw=False, secondary=False):
        data = self._rpc_call("queue_by_task", self.MethodType.SIMPLE, secondary, task_id)
        return data if raw else self._unpack_queue_by_task(data)

    def task_queue(self, task_id, pool=None, secondary=False):
        return self._rpc_call("task_queue", self.MethodType.SIMPLE, secondary, task_id, pool)

    def validate(self, secondary=False):
        return self._rpc_call("validate", self.MethodType.SIMPLE, secondary)

    def ping(self, value, secondary=True):
        return self._rpc_call("ping", self.MethodType.SIMPLE, secondary, value)

    def resources(self, secondary=True):
        return self._rpc_call("resources", self.MethodType.SIMPLE, secondary)

    def create_semaphore(self, fields):
        sem_id, data = self._rpc_call("create_semaphore", self.MethodType.SIMPLE, False, fields)
        return sem_id, qtypes.Semaphore.decode(data)

    def set_api_quota(self, login, api_quota):
        return self._rpc_call("set_api_quota", self.MethodType.SIMPLE, False, login, api_quota)

    def get_api_quota(self, login):
        return self._rpc_call("get_api_quota", self.MethodType.SIMPLE, False, login)

    def get_api_quotas_table(self):
        return self._rpc_call("get_api_quotas_table", self.MethodType.SIMPLE, False)

    def get_web_api_quota(self):
        return self._rpc_call("get_web_api_quota", self.MethodType.SIMPLE, False)

    def set_web_api_quota(self, api_quota):
        return self._rpc_call("set_web_api_quota", self.MethodType.SIMPLE, False, api_quota)

    def get_api_consumption(self, login):
        return self._rpc_call("get_api_consumption", self.MethodType.SIMPLE, False, login)

    def get_web_api_consumption(self, login):
        return self._rpc_call("get_web_api_consumption", self.MethodType.SIMPLE, False, login)

    def update_semaphore(self, sem_id, fields):
        data = self._rpc_call("update_semaphore", self.MethodType.SIMPLE, False, sem_id, fields)
        return qtypes.Semaphore.decode(data)

    def delete_semaphore(self, sem_id):
        return self._rpc_call("delete_semaphore", self.MethodType.SIMPLE, False, sem_id)

    def release_semaphores(self, task_id, prev_status, status):
        return self._rpc_call("release_semaphores", self.MethodType.SIMPLE, False, task_id, prev_status, status)

    def semaphore_values(self, sem_ids):
        return self._rpc_call("semaphore_values", self.MethodType.SIMPLE, False, sem_ids)

    def semaphore_tasks(self, sem_id):
        return self._rpc_call("semaphore_tasks", self.MethodType.SIMPLE, False, sem_id)

    def semaphore_group(self, name):
        return self._rpc_call("semaphore_group", self.MethodType.SIMPLE, False, name)

    def snapshot(self, operation_id=None, operations_checksum=None):
        return self._rpc_call(
            "snapshot", self.MethodType.GENERATOR, False, operation_id, operations_checksum, timeout=float("inf")
        )

    def oplog(self, snapshot_id, node_addr):
        return self._rpc_call(
            "oplog", self.MethodType.DUPGENERATOR, False, snapshot_id, node_addr, timeout=float("inf")
        )

    def status(self, secondary=True):
        return self._rpc_call("status", self.MethodType.SIMPLE, secondary)

    def execution_completed(self, job_id):
        result = self._rpc_call("execution_completed", self.MethodType.SIMPLE, False, job_id)
        return [qtypes.FinishExecutionInfo(*info) for info in result]

    def current_consumptions(self, pool=None, secondary=True):
        return self._rpc_call("current_consumptions", self.MethodType.GENERATOR, secondary, pool)

    def recalculate_consumptions(self, pool=None, secondary=True):
        return self._rpc_call("recalculate_consumptions", self.MethodType.GENERATOR, secondary, pool)

    def dump_consumptions(self, pool=None, secondary=True):
        return self._rpc_call("dump_consumptions", self.MethodType.GENERATOR, secondary, pool)

    def reset_consumptions(self):
        return self._rpc_call("reset_consumptions", self.MethodType.SIMPLE, False)

    def calculate_consumptions(self):
        return self._rpc_call("calculate_consumptions", self.MethodType.SIMPLE, False)

    def set_quota(self, owner, quota, pool=None, use_cores=False):
        return self._rpc_call("set_quota", self.MethodType.SIMPLE, False, owner, quota, pool, use_cores)

    def quota(self, owner, pool=None, use_cores=False, return_defaults=True):
        return self._rpc_call("quota", self.MethodType.SIMPLE, False, owner, pool, use_cores, return_defaults)

    def owners_rating(self, owner=None, secondary=False, pool=None):
        data = self._rpc_call("owners_rating", self.MethodType.SIMPLE, secondary, owner, pool)
        return [
            [rating_owner, qtypes.OwnersRatingItem(*item)]
            for rating_owner, item in data
        ]

    def owners_rating_by_pools(self, owner=None, secondary=False):
        data = self._rpc_call("owners_rating_by_pools", self.MethodType.SIMPLE, secondary, owner)
        return {
            pool: [
                [rating_owner, qtypes.OwnersRatingItem(*item)]
                for rating_owner, item in rating
            ]
            for pool, rating in data.items()
        }

    def multiple_owners_quota(self, owners=None, pool=None, use_cores=False, return_defaults=True, secondary=False):
        data = self._rpc_call(
            "multiple_owners_quota", self.MethodType.SIMPLE, secondary, owners, pool, use_cores, return_defaults
        )
        return [
            [item[0], qtypes.QuotaItem(*item[1])]
            for item in data
        ]

    def multiple_owners_quota_by_pools(self, owners=None, use_cores=False, return_defaults=True, secondary=False):
        data = self._rpc_call(
            "multiple_owners_quota_by_pools", self.MethodType.SIMPLE, secondary, owners, use_cores, return_defaults
        )
        return {
            pool: [
                [item[0], qtypes.QuotaItem(*item[1])]
                for item in quotas
            ]
            for pool, quotas in data.items()
        }

    def set_parent_owner(self, owner, parent):
        return self._rpc_call("set_parent_owner", self.MethodType.SIMPLE, False, owner, parent)

    def parent_owners(self, owner=None, secondary=False):
        return self._rpc_call("parent_owners", self.MethodType.SIMPLE, secondary, owner)

    def last_quota_remnants(self):
        return self._rpc_call("last_quota_remnants", self.MethodType.SIMPLE, False)

    def start_profiler(self, secondary=False):
        return self._rpc_call("start_profiler", self.MethodType.SIMPLE, secondary)

    def stop_profiler(self, profile_format="CALLGRIND", secondary=False):
        return self._rpc_call("stop_profiler", self.MethodType.SIMPLE, secondary, profile_format)

    def semaphore_waiters(self, task_id, secondary=True):
        return self._rpc_call("semaphore_waiters", self.MethodType.SIMPLE, secondary, task_id)

    def semaphore_wanting(self, sem_ids=None, secondary=True):
        return self._rpc_call("semaphore_wanting", self.MethodType.GENERATOR, secondary, sem_ids)

    def semaphores(self, sem_ids=None, secondary=True):
        """ Lazily yield (sem_id, decoded `Semaphore`) pairs """
        for sem_id, data in self._rpc_call("semaphores", self.MethodType.GENERATOR, secondary, sem_ids):
            yield sem_id, qtypes.Semaphore.decode(data)

    def prequeue_push(self, task_id):
        return self._rpc_call("prequeue_push", self.MethodType.SIMPLE, False, task_id)

    def prequeue_pop(self):
        return self._rpc_call("prequeue_pop", self.MethodType.SIMPLE, False)

    def contenders(self):
        return self._rpc_call("contenders", self.MethodType.SIMPLE, False)

    def get_runtime_option(self, name):
        return self._rpc_call("get_runtime_option", self.MethodType.SIMPLE, None, name)

    def set_runtime_option(self, name, value):
        return self._rpc_call("set_runtime_option", self.MethodType.SIMPLE, None, name, value)

    def task_to_execute_it(self, host, host_info, secondary=True):
        """
        Start the duplex "task_to_execute_it" call

        With `secondary=None` a secondary is used, but the cached secondary connection is dropped first
        when (on non-local installations) it reports itself as PRIMARY, or when its applied operation id lags
        behind the primary by more than `serviceq.client.max_replication_delay`.
        """
        if secondary is None:
            if (
                self._config.common.installation not in ctm.Installation.Group.LOCAL and
                self.status(secondary=True) == qtypes.Status.PRIMARY
            ):
                del self._secondary_rpc
            delay = self.operation_id(secondary=False) - self.operation_id(secondary=True)
            if delay > self._config.serviceq.client.max_replication_delay:
                del self._secondary_rpc
            secondary = True
        return self._rpc_call("task_to_execute_it", self.MethodType.DUPGENERATOR, secondary, host, host_info)

    @contextlib.contextmanager
    def lock(self, lock_name):
        """
        Hold the server-side lock `lock_name` for the duration of the `with` body

        Acquisition is attempted twice: QRetry on the first attempt (e.g. after a redirect to a new primary)
        causes one more attempt, QRetry on the second one is propagated. The value produced by the server on
        acquisition is bound to the `with` target. Release is always attempted, even when the body raises;
        release failures are only logged.
        """
        attempts = 2
        for attempt in range(attempts):
            try:
                gen = self._rpc_call("lock", self.MethodType.DUPGENERATOR, False, lock_name=lock_name)
                acquired = gen.next()
                break
            except qerrors.QRetry:
                if attempt == attempts - 1:
                    raise
        # The yield is kept outside the retry loop: an exception raised by the `with` body must never
        # be mistaken for an acquisition failure (which would re-acquire and yield twice).
        try:
            yield acquired
        finally:
            try:
                gen.next()
            except StopIteration:
                # the underlying generator is exhausted; must not leak out, as contextlib would treat it
                # as a normal generator exit and could swallow the body's exception
                pass
            except (qerrors.QTimeout, qerrors.QRetry, joint_errors.RPCError, joint_socket.EOF, socket.error) as ex:
                self._logger.warning("ServiceQ exception on releasing lock %s: %s", lock_name, ex)

    def task_to_execute(self, host, host_info):
        return self._rpc_call("task_to_execute", self.MethodType.DUPGENERATOR, False, host, host_info)

    def operation_id(self, only_applied=True, secondary=True):
        return self._rpc_call("operation_id", self.MethodType.SIMPLE, secondary, only_applied)

    def lock_jobs(self, jobs_ids):
        """ Start the duplex "lock_jobs" call, advance it once and return the generator to the caller """
        gen = self._rpc_call("lock_jobs", self.MethodType.DUPGENERATOR, False, jobs_ids)
        gen.next()
        return gen

    def replication_info(self):
        return self._rpc_call("replication_info", self.MethodType.SIMPLE, False)

    def get_unwanted_contenders(self, secondary=False):
        return self._rpc_call("get_unwanted_contenders", self.MethodType.SIMPLE, secondary)

    def set_unwanted_contenders(self, unwanted_contenders):
        return self._rpc_call("set_unwanted_contenders", self.MethodType.SIMPLE, False, unwanted_contenders)

    def add_quota_pool(self, pool, tags):
        return self._rpc_call("add_quota_pool", self.MethodType.SIMPLE, False, pool, tags)

    def update_quota_pool(self, pool, tags=None, default=None):
        return self._rpc_call("update_quota_pool", self.MethodType.SIMPLE, False, pool, tags, default)

    def quota_pools(self, secondary=False):
        return qtypes.QuotaPools.decode(self._rpc_call("quota_pools", self.MethodType.SIMPLE, secondary))
