import time
import cPickle
import logging
import datetime
import operator as op
import itertools as it

from sandbox import common
import sandbox.common.types.misc as ctm
import sandbox.common.types.task as ctt
import sandbox.common.types.user as ctu
import sandbox.common.types.client as ctc
import sandbox.common.types.statistics as ctst

from sandbox.yasandbox.database import mapping as mp


class Client(object):
    # Database model this controller operates on
    Model = mp.Client
    # Class-wide logger; stays `None` until `initialize()` is called
    logger = None

    # Cache read and filled by `perform_tags_by_platform` (which is called from `match_tags`):
    # maps `(task_platforms, tags)` to `(task_platform_tags, resulting_tags)`
    PLATFORM_TAGS_CACHE = {}

    class TagsOp(common.utils.Enum):
        # Kinds of tags modification accepted by `update_tags_impl` / `update_tags`.
        # NOTE(review): item values are declared as `None`; presumably `common.utils.Enum` fills them in -- confirm

        # Replace tags with the given ones, keeping previous tags of SERVICE, CUSTOM and USER groups
        SET = None
        # Merge the given tags into the previous ones
        ADD = None
        # Drop the given tags from the previous ones
        REMOVE = None

    @staticmethod
    def status_after_host_reset(current):
        """
        Pick the status a task is switched to when its host is being reset

        :param current: current task status
        :return: `ENQUEUED` for `ASSIGNED`, `EXCEPTION` for `STOPPING`, `NOT_RELEASED` for `RELEASING`
                 and `TEMPORARY` for any other status
        """
        transitions = {
            ctt.Status.ASSIGNED: ctt.Status.ENQUEUED,
            ctt.Status.STOPPING: ctt.Status.EXCEPTION,
            ctt.Status.RELEASING: ctt.Status.NOT_RELEASED,
        }
        fallback = ctt.Status.TEMPORARY
        return transitions.get(current, fallback)

    @classmethod
    def initialize(cls):
        """
        Prepare the controller for use: ensure indexes of the database model and set up the class logger
        """
        model = cls.Model
        model.ensure_indexes()
        cls.logger = logging.getLogger(__name__)

    @classmethod
    def create(
        cls, hostname="", disk_free=0, updated=0, platform="", cpu_model="", cpu_cores=0, ram=0,
        info=None, pending_commands=None
    ):
        """
        Build a new client model filled with the given values; the model is not saved here

        :param hostname: host name, stored in lower case
        :param disk_free: free disk space, negative values are stored as zero
        :param updated: value for the `updated` field
        :param platform: host platform, stored in lower case
        :param cpu_model: cpu model, stored in lower case
        :param cpu_cores: number of cpu cores
        :param ram: amount of ram
        :param info: client info to pickle into the context, `{"system": {}}` if not given
        :param pending_commands: list of pending commands, empty list if not given
        :return: new mapping.Client model
        """
        model = cls.Model()
        if hostname:
            model.hostname = hostname.lower()
        else:
            model.hostname = ""
        model.updated = updated

        # Make sure the hardware sub-documents exist before filling them
        if not model.hardware:
            model.hardware = cls.Model.Hardware()
            model.hardware.cpu = cls.Model.Hardware.CPU()

        model.hardware.disk_free = max(int(disk_free), 0)
        model.hardware.ram = ram
        model.hardware.cpu.model = cpu_model.lower()
        model.hardware.cpu.cores = cpu_cores
        model.pending_commands = pending_commands or []
        model.platform = platform.lower()

        if info is None:
            info = {"system": {}}
        model.context = cPickle.dumps(info)
        return model

    @staticmethod
    def id_from_fqdn(fqdn):
        """
        Derive client id from a fully qualified domain name

        The id is the first label of the name. For names with more than two labels whose first label starts
        with `bootstrap-sandbox` (clients deployed to YP) the second label is inserted before the last
        dash-separated part of the first one,
        example: bootstrap-sandbox2-1.sas.yp-c.yandex.net => bootstrap-sandbox2-sas-1

        :param fqdn: fully qualified domain name of the client
        :return: client id
        """
        labels = fqdn.split(".")
        head = labels[0]
        if len(labels) <= 2 or not head.startswith("bootstrap-sandbox"):
            return head

        chunks = head.split("-")
        return "-".join(chunks[:-1] + [labels[1], chunks[-1]])

    @classmethod
    def get(cls, id_, create=False):
        """
        Load client model by its id

        :param id_: client id
        :param create: if set and there is no client with `id_`, build a new model via `create(hostname=id_)`
        :return: mapping.Client object with `id_`; if it does not exist, a newly created model
                 when `create` is set and None otherwise
        """
        found = cls.Model.objects.with_id(id_)
        if found is not None or not create:
            return found
        return cls.create(hostname=id_)

    @classmethod
    def load_list(cls, client_ids):
        """
        Load clients by their ids with a single bulk request

        :param client_ids: list of ids of clients to load
        :return: list of mapping.Client models
        """
        by_id = cls.Model.objects.in_bulk(client_ids)
        return by_id.values()

    @classmethod
    def list_query(
        cls,
        hostname=None, freespace=None, update_ts=None, arch=None, model=None, ncpu=None, ram=None,
        tags=None, search_query=None
    ):
        """
        Build database query for clients matching all the given filters

        Empty (falsy) parameters add no filter. For the threshold parameters (`freespace`, `update_ts`, `ncpu`,
        `ram`) the sign selects the comparison: a negative value means "lower than the absolute value",
        a positive one means "greater than or equal to the value".

        :param hostname: single host name (string) or a collection of host names
        :param freespace: threshold for free disk space
        :param update_ts: threshold for last update time, absolute value is treated as UTC timestamp
        :param arch: host platform
        :param model: single cpu model (string) or a collection of cpu models
        :param ncpu: threshold for number of cpu cores
        :param ram: threshold for amount of ram
        :param tags: tags query
        :param search_query: case-insensitive substring of the host name
        :return: database queryset
        """
        filters = {}

        if hostname:
            key = "hostname" if isinstance(hostname, basestring) else "hostname__in"
            filters[key] = hostname
        if arch:
            filters["platform"] = arch
        if model:
            key = "hardware__cpu__model" if isinstance(model, basestring) else "hardware__cpu__model__in"
            filters[key] = model

        # (key template, threshold, converter applied to the absolute threshold value)
        thresholds = (
            ("hardware__disk_free__{0}", freespace, None),
            ("updated__{0}", update_ts, datetime.datetime.utcfromtimestamp),
            ("hardware__cpu__cores__{0}", ncpu, None),
            ("hardware__ram__{0}", ram, None),
        )
        for template, threshold, convert in thresholds:
            if not threshold:
                continue
            comparison = "lt" if threshold < 0 else "gte"
            bound = abs(threshold)
            filters[template.format(comparison)] = bound if convert is None else convert(bound)

        conditions = [cls.Model.tags_query(tags)] if tags else []

        if search_query:
            filters["hostname__icontains"] = search_query

        return cls.Model.objects(*conditions, **filters)

    @classmethod
    def list(
        cls,
        hostname=None, freespace=None, update_ts=None, arch=None, model=None, ncpu=None, ram=None, count=False,
        tags=None, search_query=None
    ):
        """
        Return clients matching the given filters or just their number

        All the filter parameters are passed to `list_query` as is, see it for details.

        :param hostname: single host name (string) or a collection of host names
        :param freespace: threshold for free disk space
        :param update_ts: threshold for last update time
        :param arch: host platform
        :param model: single cpu model (string) or a collection of cpu models
        :param ncpu: threshold for number of cpu cores
        :param ram: threshold for amount of ram
        :param count: if true, return the number of matching clients instead of the clients themselves
        :param tags: tags query
        :param search_query: case-insensitive substring of the host name
        :return: number of matching clients if `count` is true, otherwise list of models ordered by host name
        """
        queryset = cls.list_query(
            hostname=hostname, freespace=freespace, update_ts=update_ts, arch=arch, model=model,
            ncpu=ncpu, ram=ram, tags=tags, search_query=search_query
        )
        if count:
            return queryset.count()
        return list(queryset.order_by("hostname"))

    @classmethod
    def pending_commands(cls, pending_commands, request_commands=None):
        """
        Collect names of the given pending commands, optionally restricted to the requested ones

        :param pending_commands: list of mapping.Client.Command of client commands, may be empty or None
        :param request_commands: list of available commands; None means no restriction
        :return: set of pending commands
        """
        names = {item.command for item in pending_commands or ()}
        if request_commands is None:
            return names
        names &= set(request_commands)
        return names

    @classmethod
    def pending_service_commands(cls, client, request_commands=None):
        """
        Collect names of service commands pending for the client

        :param client: mapping.Client model
        :param request_commands: list of available commands; None means no restriction
        :return: set of pending service commands
        """
        queued = client.pending_commands
        return cls.pending_commands(queued, request_commands)

    @classmethod
    def next_service_command(cls, client, reset=False, request_commands=None):
        """
        Select the next service command the client has to perform

        Commands are tried in the order of `ctc.ReloadCommand`; the first one pending for the client wins.
        For `SHUTDOWN` command the `updated` field of the client is moved back by
        `server.web.mark_client_as_dead_after` seconds from the current time.

        :param client: mapping.Client model
        :param reset: if set, remove the selected command from the client's pending commands and save the client
        :param request_commands: list of available commands
        :return: the selected command or None if there is nothing to perform
        :rtype: yasandbox.database.mapping.client.Client.Command
        """
        pending = cls.pending_service_commands(client, request_commands)

        chosen = None
        for candidate in ctc.ReloadCommand:
            if candidate in pending:
                chosen = candidate
                break
        if chosen is None:
            return None

        position = None
        for idx, entry in enumerate(client.pending_commands):
            if entry.command == chosen:
                position = idx
                break

        if position is None:
            result = cls.Model.Command(command=chosen)
        elif reset:
            result = client.pending_commands.pop(position)
        else:
            result = client.pending_commands[position]

        if chosen == cls.Model.Reloading.SHUTDOWN:
            now = time.time()
            dead_after = common.config.Registry().server.web.mark_client_as_dead_after
            client.updated = datetime.datetime.utcfromtimestamp(now - dead_after)

        if reset:
            cls.update(client)
        return result

    @classmethod
    def update(cls, model, data=None, merge=False):
        """
        Apply fresh client data to the model, dump its info into the context and save the model

        Hardware fields are taken from `data["system"]`: `free_space` (negative values are stored as zero),
        `arch`, `cpu_model`, `ncpu` and `physmem` (integer-divided by 1024 ** 2). After applying the data
        the `msg` key of the info is cleared. Without `data` only the context and update time are refreshed.

        :param model: mapping.Client model with unpickled context
        :param data: data for update, must contain `system` key if given
        :param merge: if True make recursive merge of info and data dicts
        :return: result of saving the model
        """
        if data is not None:
            system = data["system"]

            # Make sure the hardware sub-documents exist before filling them
            if not model.hardware:
                model.hardware = cls.Model.Hardware()
                model.hardware.cpu = cls.Model.Hardware.CPU()

            if "free_space" in system:
                free_space = int(system["free_space"])
                model.hardware.disk_free = max(free_space, 0)
            if "arch" in system:
                model.platform = system["arch"]
            if "cpu_model" in system:
                model.hardware.cpu.model = system["cpu_model"]
            if "ncpu" in system:
                model.hardware.cpu.cores = int(system["ncpu"])
            if "physmem" in system:
                model.hardware.ram = int(system["physmem"]) // (1024 ** 2)

            patch = common.utils.merge_dicts(model.info, data) if merge else data
            model.info.update(patch)
            model.info["msg"] = ""

        model.context = cPickle.dumps(model.info)
        model.updated = datetime.datetime.utcnow()
        return model.save()

    @classmethod
    def perform_tags_by_platform(cls, tags, task_platform=None):
        """
        Narrow the tags query down to the platform(s) requested by the task

        Platforms are translated to tags via `common.platform`; if any tag is found, the query is intersected
        with the union of these tags. Results are memoized in `PLATFORM_TAGS_CACHE`.

        :param tags: tags query, either a string or a query object
        :param task_platform: single platform or tuple of platforms; if empty, `tags` are returned as is
        :return: resulting tags query
        """
        if not task_platform:
            return tags

        if isinstance(task_platform, tuple):
            platforms = task_platform
        else:
            platforms = (task_platform,)

        key = (platforms, tags)
        cached = cls.PLATFORM_TAGS_CACHE.get(key)
        if cached is not None:
            _, tags = cached
            return tags

        platform_tags = []
        for name in platforms:
            tag = common.platform.PLATFORM_TO_TAG.get(common.platform.get_platform_alias(name))
            if tag:
                platform_tags.append(tag)

        if platform_tags:
            if isinstance(tags, basestring):
                tags = ctc.Tag.Query(tags)
            tags &= reduce(op.or_, platform_tags)

        cls.PLATFORM_TAGS_CACHE[key] = (platform_tags, tags)
        return tags

    @classmethod
    @common.utils.ttl_cache(300)
    def cached_match_tags(cls, client_tags, client_lxc, tags, only_detect_platform):
        """
        Find the first client platform whose tag set satisfies the tags query

        The method is wrapped with `common.utils.ttl_cache(300)`; presumably results are cached per arguments
        for 300 seconds and the arguments have to be hashable (`match_tags` passes client tags
        as a sorted tuple) -- TODO confirm against `ttl_cache`.

        :param client_tags: iterable of client tags (compared against `str()` of tags, so presumably strings)
        :param client_lxc: if true, `common.platform.LXC_PLATFORMS` are tried instead of the platforms
                           derived from the client's own platform tags
        :param tags: tags query; if empty, the first platform with a suitable tag is returned
        :param only_detect_platform: if true, only tags of the `OS` group are kept in the query predicates
        :return: matching platform or None if no platform matches
        """
        # Passed as `ignore_defaults` to the predicates builder; true for local installation only
        ignore_defaults = common.config.Registry().common.installation == ctm.Installation.LOCAL
        # Each predicate is unpacked below as a pair `(p, n)`: tags required to be present / to be absent
        predicates = ctc.Tag.Query.predicates(tags, ignore_defaults=ignore_defaults) if tags else []
        if only_detect_platform:
            os_tags = ctc.Tag.filter(ctc.Tag.Group.OS)
            predicates = [(p & os_tags, n & os_tags) for p, n in predicates]
        client_tags = set(client_tags)
        # A client carrying the `CUSTOM` group tag loses everything `PURPOSE` group expands to
        if str(ctc.Tag.Group.CUSTOM) in client_tags:
            client_tags = client_tags - ctc.Tag.Group.PURPOSE.expand()
        tags_without_platform = client_tags - common.platform.PLATFORM_TAGS
        own_platform_tags = client_tags & common.platform.PLATFORM_TAGS
        client_platforms = (
            common.platform.LXC_PLATFORMS
            if client_lxc else
            map(common.platform.TAG_TO_PLATFORM.get, own_platform_tags)
        )
        matching_platform = None
        for platform in client_platforms:
            platform_tag = common.platform.PLATFORM_TO_TAG.get(platform)
            # Skip platforms without a tag and platforms whose tag is a member of `ctc.Tag.Group`
            if not platform_tag or platform_tag in ctc.Tag.Group:
                continue
            # Candidate tag set: non-platform client tags plus whatever `Tag.filter` yields for the platform tag
            client_tags = tags_without_platform | set(it.imap(str, ctc.Tag.filter(platform_tag)))
            # Match if there is no query at all or any predicate is satisfied:
            # every tag of `p` is present and there is no common tag with `n`
            if not tags or any(client_tags >= p and not client_tags & n for p, n in predicates):
                matching_platform = platform
                break
        return matching_platform

    @classmethod
    def match_tags(cls, client, tags, task_platform=None, only_detect_platform=False):
        """
        Check whether the client suits the tags query, taking the task platform into account

        :param client: mapping.Client model
        :param tags: tags query
        :param task_platform: platform(s) requested by the task, see `perform_tags_by_platform`
        :param only_detect_platform: passed to `cached_match_tags` as is
        :return: result of `cached_match_tags`: matching platform or None
        """
        effective_tags = cls.perform_tags_by_platform(tags, task_platform=task_platform)
        # Containers of any platform are considered for LXC clients and for multislot Porto clients
        containers = client.lxc or (client.porto and client.multislot)
        client_tags = tuple(sorted(client.tags))
        return cls.cached_match_tags(client_tags, containers, effective_tags, only_detect_platform)

    @classmethod
    def container(cls, client, platform):
        """
        Returns best client's container resource from the given platform selector (`__host_chooser_os` context key).

        LXC resources are used for LXC clients and for multislot Porto clients, Porto layers resources are used
        for other Porto clients. If the platform is not set, is `ANY` or `LINUX`, or is absent among
        the resources, the first available resource is taken.

        :param client: mapping.Client model
        :param platform: platform selector
        :return: pair `(platform, resource)` or `(None, None)` if there are no resources at all
        :raises Exception: if the client supports neither LXC nor Porto
        """
        from . import resource as resource_controller

        def choose(available):
            wildcard = not platform or platform in (ctm.OSFamily.ANY, ctm.OSFamily.LINUX)
            if wildcard or platform not in available:
                # Any resource will do: take the first one, if any
                for item in available.iteritems():
                    return item
                return None, None
            return platform, available.get(platform)

        if client.lxc or (client.porto and client.multislot):
            pool = resource_controller.Resource.lxc_resources
        elif client.porto:
            pool = resource_controller.Resource.porto_layers_resources
        else:
            raise Exception(
                "Can't get container for client {}, "
                "because it doesn't support any of: [lxc, porto]".format(client.hostname)
            )
        return choose(pool)

    @classmethod
    def update_tags_impl(cls, hostname, tags, op):
        """
        Implementation of updating client tags

        Makes up to three attempts. Each attempt reads the current tags, computes the new ones and writes them
        only if the stored tags are still unset or equal to the ones just read; otherwise the attempt is
        counted as a conflict and is repeated after a short pause.

        :param hostname: client name
        :param tags: tags to set, add or remove, depending on `op`
        :param op: operation type, one of cls.TagsOp
        :return: list of resulting tags as strings or None if every attempt ended with a conflict
        :raises ValueError: if `op` is not one of cls.TagsOp
        """
        # NOTE: parameter `op` shadows the module-level alias of `operator` inside this method
        for _ in xrange(3):
            prev_tags = cls.Model.objects.with_id(hostname).tags
            # Tags of these groups survive `SET` operation
            manual_tags = [
                tag for tag in prev_tags
                if (
                    tag in ctc.Tag.Group.SERVICE or
                    tag in ctc.Tag.Group.CUSTOM or
                    tag in ctc.Tag.Group.USER
                )
            ]
            if op == cls.TagsOp.SET:
                new_tags = ctc.Tag.filter(it.chain(manual_tags, tags))
            elif op == cls.TagsOp.ADD:
                new_tags = ctc.Tag.filter(it.chain(prev_tags, tags))
            elif op == cls.TagsOp.REMOVE:
                new_tags = ctc.Tag.filter(set(prev_tags) - set(tags))
            else:
                # The message has to be formatted explicitly: exceptions do not interpolate their arguments
                raise ValueError(
                    "Unknown operation updating tags of the client {!r}: {!r}".format(hostname, op)
                )
            new_tags = map(str, new_tags)
            pure_tags = ctc.Tag.filter(new_tags, include_groups=False)
            # Conditional update: nothing is written if the tags were changed by someone else in the meantime
            updated = cls.Model.objects(
                mp.Q(tags=None) | mp.Q(tags=prev_tags), hostname=hostname
            ).update(
                set__tags=new_tags,
                set__pure_tags=pure_tags
            )
            if updated:
                return list(new_tags)
            # Logger arguments are interpolated with `%`, so `%r` placeholder is required here
            cls.logger.warning("Conflict occurred while updating tags for client %r", hostname)
            time.sleep(0.1)

        cls.logger.error("Tags for client %r are not updated", hostname)
        return None

    @classmethod
    def update_tags(cls, client, tags, op):
        """
        Update tags of the client in the database and, on success, in the given model object

        :param client: mapping.Client object
        :param tags: new tags
        :param op: operation type, one of cls.TagsOp
        :return: list of resulting tags or None if the tags were not updated
        """
        updated = cls.update_tags_impl(client.hostname, tags, op)
        if updated is None:
            return None
        client.tags = updated
        return updated

    @classmethod
    def reload(cls, client, cmd=Model.Reloading.RESTART, author=ctu.ANONYMOUS_LOGIN, comment=""):
        """
        Store service command for client and save the client

        Reboot and power off commands are refused (with a warning in the log) for clients tagged as `SERVER`.
        A command that is already pending is not added again. For `SHUTDOWN` command a message with
        the author and the current UTC time is put into `msg` key of the client info.

        :param client: mapping.Client object
        :param cmd: string with appropriate command. See Model.Reloading enum for available values
        :param author: user started reload
        :param comment: comment for command
        """
        reloading = cls.Model.Reloading
        if ctc.Tag.SERVER in client.tags and cmd in (reloading.REBOOT, reloading.POWEROFF):
            cls.logger.warning("Attempt to reboot or power off a server %r", client.hostname)
            return

        already_pending = any(entry.command == cmd for entry in client.pending_commands)
        if not already_pending:
            client.pending_commands.append(cls.Model.Command(command=cmd, author=author, comment=comment))

        if cmd == reloading.SHUTDOWN:
            notice = "<b style='color: red'>Turned off by {author} at {datetime}</b>".format(
                author=author,
                datetime=datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S")
            )
            client.info.update(msg=notice)

        client.context = cPickle.dumps(client.info)
        client.save()

    @classmethod
    def reset_host_tasks(cls, logger, hosts):
        """
        Reset tasks from hosts (use tasks code)

        For every given host a `RESET` command is stored. Then every session issued for these hosts is dropped:
        its execution is reported as completed to the task queue, statistics signals are pushed for
        the completed execution, and the task is switched to the status given by `status_after_host_reset`
        if it is still executing (or is being stopped without an author).

        :param logger: logger
        :param hosts: hosts to reset
        :return: list of restarted tasks, empty list if no hosts are given
        """
        from sandbox.yasandbox.controller import TaskQueue, TaskWrapper, OAuthCache, Task
        if not hosts:
            return []
        hosts = sorted(hosts)

        # Statuses treated as "executing": the whole EXECUTE group except TEMPORARY, plus RELEASING
        _TG_EXECUTE = list(ctt.Status.Group.EXECUTE)
        _TG_EXECUTE.remove(ctt.Status.TEMPORARY)
        _TG_EXECUTE.append(ctt.Status.RELEASING)

        restarted = []
        logger.info("Reset tasks assigned to hosts %r", hosts)

        def need_restart(task):
            if task.status in _TG_EXECUTE:
                return True

            # TODO: Use `session.state` instead of audit after SANDBOX-4747
            if task.status == ctt.Status.STOPPING:
                # Only the latest audit record is checked: STOPPING status set without an author
                query = mp.Audit.objects(task_id=task.id).order_by("-date").limit(1)
                for status, author in query.scalar("status", "author"):
                    if status == ctt.Status.STOPPING and not author:
                        return True

            return False

        for client in cls.load_list(hosts):
            cls.reload(client, mp.Client.Reloading.RESET)

        sources = ["{}:{}".format(ctu.TokenSource.CLIENT, host) for host in hosts]
        sessions = mp.OAuthCache.objects(source__in=sources)
        for session in sessions:
            logger.info(
                "Drop session %r for task #%s and client %r",
                session.token, session.task_id, session.client_id
            )
            completed = TaskQueue.qclient.execution_completed(session.token)
            OAuthCache.expire(session)

            # Load task object after dropping session to minimize the possibility of getting stale state
            model = mp.Task.objects.with_id(session.task_id)
            if model is not None:
                task = TaskWrapper(model)
                if completed:
                    logger.info("Completed task #%s at %s", completed[0].id, completed[0].finished)
                    Task.close_all_intervals(
                        task.model, update=True, consumption=completed[0].consumption
                    )

                    execution_host = task.model.execution.host
                    if execution_host:
                        execution_host_tags = mp.Client.objects.fast_scalar("tags").with_id(execution_host) or []
                    else:
                        execution_host_tags = []

                    # Ram drive size is accounted as a part of the required ram
                    ram = task.model.requirements.ram or 0
                    if task.model.requirements.ramdrive:
                        ram += (task.model.requirements.ramdrive.size or 0)

                    if task.model.execution.intervals.execute:
                        # Signals describe the last execution interval only
                        execution_interval = task.model.execution.intervals.execute[-1]
                        common.statistics.Signaler().push(dict(
                            type=ctst.SignalType.TASK_SESSION_COMPLETION,
                            start=execution_interval.start,
                            finish=execution_interval.finish,
                            task_id=task.id,
                            host=task.host
                        ))
                        now = datetime.datetime.utcnow()
                        common.statistics.Signaler().push(dict(
                            type=ctst.SignalType.TASK_INTERVALS,
                            date=now,
                            timestamp=now,
                            task_id=task.id,
                            task_owner=task.model.owner,
                            task_type=task.model.type,
                            consumption=execution_interval.consumption,
                            duration=execution_interval.duration,
                            start=execution_interval.start,
                            pool=str(completed[0].pool),
                            client_id=execution_host,
                            client_tags=execution_host_tags,
                            privileged=int(task.model.requirements.privileged or 0),
                            cores=task.model.requirements.cores,
                            ram=ram,
                            caches=int(task.model.requirements.caches != []),
                            disk=int(task.model.requirements.disk_space or 0)
                        ))

                if need_restart(task):
                    logger.info("Restart task #%s because it is executing right now", task.id)

                    # Retry until the status is switched without a conflict; the target status is
                    # recalculated on every attempt because the task state is reloaded after a conflict
                    while True:
                        new_status = cls.status_after_host_reset(task.status)
                        try:
                            task.set_status(
                                new_status,
                                event="Clear tasks on dead host {}".format(task.host),
                                force=True, reset_resources_on_failure=True,
                            )
                        except common.errors.UpdateConflict as e:
                            logger.info("Error during task %s reset: %s. Try again.", task.id, str(e))
                            # Reload task object to get up-to-date state
                            task.model.reload()
                        else:
                            break

                    # The task goes back to the queue only once its status has actually been switched
                    logger.info("Pushing task #%s back to the queue.", task.id)
                    if new_status == ctt.Status.ENQUEUED:
                        TaskQueue.enqueue_task(task)

                    restarted.append(task)

            session.delete()

        return restarted
