from __future__ import absolute_import, print_function

import time
import uuid
import gevent
import bisect
import struct
import datetime as dt
import itertools as it
import functools as ft
import collections

import six
import cython
import msgpack
import gevent.queue

from sandbox import common
import sandbox.common.types.misc as ctm
import sandbox.common.types.task as ctt
import sandbox.common.types.client as ctc

from sandbox.yasandbox.database import mapping

# noinspection PyUnresolvedReferences
if not cython.compiled:
    # Required to avoid errors on not existent cython decorators without compilation.
    # When the module runs as plain Python, `cython` is replaced with an instance of a stub class:
    # any attribute access returns the stub itself, and calling it returns the passed object when no
    # keyword arguments are given (plain `@cython.cclass`) or the stub again when keywords are given
    # (`@cython.locals(other=...)`), so every cython decorator used below degrades to a no-op.
    cython = type(
        "cython",
        (object,),
        dict(__call__=lambda s, o=None, **kws: o if not kws else s, __getattribute__=lambda s, n: s)
    )()

# TODO: remove after SANDBOX-8910
DEFAULT_QUOTA = 2472000  # default quota for owners without the quota


class Serializable(object):
    """
    Minimal (de)serialization protocol mixin.

    `encode` returns the object itself by default; `decode` passes instances of the class through and
    otherwise builds a new instance by unpacking `data` as positional constructor arguments.
    """

    # XXX: against IDE warnings only
    def __init__(self, *args, **kws):
        # Any arguments are accepted (and ignored) so that this mixin can sit in the MRO of tuple subclasses
        # noinspection PyArgumentList
        super(Serializable, self).__init__()

    def encode(self):
        return self

    @classmethod
    def decode(cls, data):
        if isinstance(data, cls):
            return data
        return cls(*data)


class TaskQueueItem(
    collections.namedtuple("TaskQueueItem", "task_id priority hosts task_ref task_info score"),
    Serializable
):
    """
    Entry of the task queue.

    Serialized as a flat 6-tuple in which `task_ref` and `task_info` are replaced by their `encode()` results.
    On decoding only `task_info` is turned back into a `TaskInfo`; `task_ref` is kept exactly as received.
    An item is falsy when its `task_ref` is falsy.
    """

    def __nonzero__(self):
        return bool(self.task_ref)

    def encode(self):
        task_id, priority, hosts, task_ref, task_info, score = self
        encoded_ref = task_ref.encode()
        encoded_info = task_info.encode()
        return task_id, priority, hosts, encoded_ref, encoded_info, score

    @classmethod
    def decode(cls, data):
        if not isinstance(data, cls):
            task_id, priority, hosts, task_ref, task_info, score = data
            data = cls(task_id, priority, hosts, task_ref, TaskInfo.decode(task_info), score)
        return data


class TaskQueueHostsItem(
    collections.namedtuple("TaskQueueHostsItem", "score host"),
    Serializable
):
    """Pair of a `host` and its `score`; uses the default `Serializable` (de)serialization."""


# noinspection PyUnresolvedReferences
@cython.no_gc
@cython.cclass
class HostQueueItem(Serializable):
    """
    Entry of a per-host queue.

    Items are ordered and compared by the ``(score, task_id)`` pair; `task_ref` does not take part in
    comparisons. An item is falsy when its `task_ref` is falsy.
    """
    __slots__ = "score", "task_id", "task_ref"

    # noinspection PyMissingConstructor
    def __init__(self, score, task_id, task_ref):
        self.score = score
        self.task_id = task_id
        self.task_ref = task_ref

    def __nonzero__(self):
        return bool(self.task_ref)

    def __eq__(self, other):
        return (self.score, self.task_id) == (other.score, other.task_id)

    def __ne__(self, other):
        # Python 2 does not derive `!=` from `__eq__`; without this `!=` would compare object identities
        return (self.score, self.task_id) != (other.score, other.task_id)

    def __lt__(self, other):
        return (self.score, self.task_id) < (other.score, other.task_id)

    def __le__(self, other):
        return (self.score, self.task_id) <= (other.score, other.task_id)

    def __gt__(self, other):
        return (self.score, self.task_id) > (other.score, other.task_id)

    def __ge__(self, other):
        return (self.score, self.task_id) >= (other.score, other.task_id)

    def encode(self, priority=-1):
        """
        Serialize the item.

        :param priority: when not -1, it is prepended to the result
        :return: ``(score, task_id, encoded task_ref)`` or ``(priority, score, task_id, encoded task_ref)``
        """
        return (
            (
                self.score,
                self.task_id,
                self.task_ref.encode()
            )
            if priority == -1 else
            (
                priority,
                self.score,
                self.task_id,
                self.task_ref.encode()
            )
        )


class ComputingResources(
    collections.namedtuple("ComputingResources", "disk_space resources cores ram slots"),
    Serializable
):
    """
    Computing resources tuple: `disk_space`, `resources`, `cores`, `ram`, `slots`.

    Defaults: `disk_space` and `resources` are None, `cores` and `ram` are 0, `slots` is 1
    (units are not defined here — NOTE(review): confirm with the producers of these values).
    Decoding ``None`` yields an all-defaults instance.
    """

    @classmethod
    def decode(cls, data):
        if isinstance(data, cls):
            return data
        if data is None:
            # noinspection PyArgumentList
            return cls()
        return super(ComputingResources, cls).decode(data)


ComputingResources.__new__.__defaults__ = (None, None, 0, 0, 1)


class HostInfo(
    collections.namedtuple("HostInfo", "capabilities tags free"),
    Serializable
):
    """
    Host info tuple

    - `capabilities`: host's computing resources, `ComputingResources` (``None`` by default)
    - `tags`: host tags, an iterable (empty tuple by default)
    - `free`: `ComputingResources` or ``None`` (presumably the currently free part of `capabilities` — TODO confirm)
    """

    def encode(self):
        return (
            # `capabilities` defaults to None, so it must be guarded exactly like `free`
            self.capabilities and self.capabilities.encode(),
            self.tags and list(self.tags),
            self.free and self.free.encode(),
        )

    @classmethod
    def decode(cls, data):
        if data is None:
            # noinspection PyArgumentList
            return cls()
        elif isinstance(data, cls):
            return data
        capabilities, tags, free = data
        return cls(
            ComputingResources.decode(capabilities),
            tags,
            free and ComputingResources.decode(free),
        )

    @property
    def multislot(self):
        """Whether the host is tagged with ``ctc.Tag.MULTISLOT``."""
        return bool(self.tags and ctc.Tag.MULTISLOT in self.tags)


HostInfo.__new__.__defaults__ = (None, (), None)


class TaskInfo(
    collections.namedtuple("TaskInfo", "requirements semaphores type owner enqueue_time duration client_tags"),
    Serializable
):
    """
    Task info tuple

    - `requirements`: task's requirements namedtuple, `sandbox.serviceq.types.ComputingResources`
    - `semaphores`: semaphores to be acquired by task, `sandbox.common.types.task.Semaphores`
    - `type`: task type, `str`
    - `owner`: task owner, `str`
    - `enqueue_time`: UNIX timestamp when the task has been added to the queue, `int`
    - `duration`: task execution time estimation, `int`
      (currently is ``kill_timeout`` divided by phi (golden ratio), see
      ``sandbox.yasandbox.controller.task.TaskQueue.add`` for details)
    - `client_tags`: tags expression describing cluster subset to execute the task, `str`

    Defaults: `requirements` is an all-defaults `ComputingResources`, `enqueue_time` and `duration` are 0,
    everything else is ``None``.
    """

    def encode(self):
        return (
            self.requirements and self.requirements.encode(),
            self.semaphores and list(self.semaphores),
            self.type,
            self.owner,
            self.enqueue_time,
            self.duration,
            self.client_tags
        )

    @classmethod
    def decode(cls, data):
        if data is None:
            # noinspection PyArgumentList
            return cls()
        elif isinstance(data, cls):
            return data
        requirements, semaphores, task_type, owner, enqueue_time, duration, client_tags = data
        return cls(
            requirements and ComputingResources.decode(requirements),
            semaphores,
            task_type,
            owner,
            enqueue_time,
            duration,
            client_tags
        )


# One default per field (7). The previous 8-element tuple also covered the implicit `_cls` argument,
# so the intended `ComputingResources()` default never reached `requirements`.
# noinspection PyArgumentList
TaskInfo.__new__.__defaults__ = (ComputingResources(), None, None, None, 0, 0, None)


class Semaphore(
    common.utils.namedlist("Semaphore", "name owner capacity value auto shared public tasks"),
    Serializable
):
    """
    Semaphore record; `value` is the currently acquired amount out of `capacity`.

    Decoded from a single sequence of field values (passed to the constructor as one argument).
    """

    def __repr__(self):
        return "<Semaphore '{}' ({}/{})>".format(self.name, self.value, self.capacity)

    @classmethod
    def decode(cls, data):
        return data if isinstance(data, cls) else cls(data)

    @property
    def free(self):
        """Remaining capacity, never negative."""
        remaining = self.capacity - self.value
        return remaining if remaining > 0 else 0


# Pair of `sem_ids` (semaphore ids) and `groups`; NOTE(review): the exact shape of `groups` is defined by its users
SemaphoreGroup = collections.namedtuple("SemaphoreGroup", "sem_ids groups")


class TaskSemaphores(ctt.Semaphores):
    """
    Task semaphores whose acquires refer to semaphores by id instead of by name.
    """
    # field `name` contains semaphore id
    class Acquire(ctt.Semaphores.Acquire):
        # Callable ``(name, capacity, public) -> semaphore id``; None here, presumably assigned by the owner
        # of the semaphore index before any `Acquire` is constructed — TODO confirm
        semaphore_index = None

        def __new__(cls, name="", weight=1, capacity=1, public=False):
            # Deliberately skips `ctt.Semaphores.Acquire.__new__` (super() of the parent class) and stores
            # the id returned by `semaphore_index` in place of the name
            return super(ctt.Semaphores.Acquire, cls).__new__(
                cls, cls.semaphore_index(name, capacity, public), weight, capacity, public
            )

    @classmethod
    def decode(cls, data):
        """
        Decode an ``(acquires, release)`` pair.

        Acquires and releases are sorted; each acquire is built with the first base class of
        `ctt.Semaphores.Acquire`, and the result with the first base class of `ctt.Semaphores`, so
        neither `__new__` override (including the name-to-id translation above) is invoked.
        """
        if isinstance(data, cls):
            return data
        acquires, release = data
        return ctt.Semaphores.__bases__[0](
            tuple(ctt.Semaphores.Acquire.__bases__[0](*_) for _ in sorted(acquires)),
            tuple(sorted(release))
        )


# Quota rating entry (presumably one per task owner — TODO confirm); every field defaults to None
OwnersRatingItem = collections.namedtuple(
    "OwnersRatingItem",
    "remnant_ratio remnant real_consumption future_consumption quota executing_jobs queue_size is_default_quota"
)
OwnersRatingItem.__new__.__defaults__ = (None,) * len(OwnersRatingItem._fields)


# Quota consumption snapshot: real and future consumption against `limit`; every field defaults to None
QuotaItem = collections.namedtuple(
    "QuotaItem",
    "real_consumption future_consumption limit"
)
QuotaItem.__new__.__defaults__ = (None,) * len(QuotaItem._fields)


class TaskRef(Serializable):
    """
    Mutable reference to a task id, serialized as the bare id.

    The reference becomes empty (falsy) once `clear` resets the id to 0.
    """
    __slots__ = ["__task_id"]

    # noinspection PyMissingConstructor
    def __init__(self, task_id):
        self.__task_id = task_id

    def __repr__(self):
        return "<{}: {}>".format(type(self).__name__, self.__task_id)

    def __nonzero__(self):
        return self.__task_id != 0

    def __iter__(self):
        yield self.__task_id

    # noinspection PyUnresolvedReferences
    @cython.locals(other="TaskRef")
    def __eq__(self, other):
        return self.__task_id == other.__task_id

    # noinspection PyUnresolvedReferences
    @cython.locals(other="TaskRef")
    def __ne__(self, other):
        # Python 2 does not derive `!=` from `__eq__`; without this `!=` would compare object identities
        return self.__task_id != other.__task_id

    def clear(self):
        """Reset the referenced task id to 0, making the reference falsy."""
        if self.__task_id != 0:
            self.__task_id = 0

    def encode(self):
        return self.__task_id


class Status(common.utils.Enum):
    # Service status enumeration. Member values are declared as None here; presumably `common.utils.Enum`
    # fills them in (TODO confirm against common.utils).
    class Group(common.utils.GroupEnum):
        READY = None

    STARTING = None
    TRANSIENT = None
    RESTORING = None
    STOPPING = None

    # Members declared inside this block belong to the READY group
    with Group.READY:
        PRIMARY = None
        SECONDARY = None


class Statistics(object):
    """
    In-memory statistics collector.

    Named counters (declared below with the private ``__counter`` descriptor) are bucketed per whole second
    of ``int(time.time())``: reading a counter first rolls a non-empty bucket from a past second into the
    history. `dump` hands over and resets the history, the "wants"/"got" tag counters, recorded calls and
    done tasks.
    """
    Call = collections.namedtuple("Call", "name start duration")
    Dump = collections.namedtuple("Dump", "counters_history wants_task got_task calls done_tasks")

    __counters = None
    __counters_history = None
    __counters_timestamp = 0  # second of the current counters bucket; 0 means "no bucket yet"
    __wants_task = None
    __got_task = None
    __calls = None
    __done_tasks = None

    # noinspection PyPep8Naming
    class __counter(property):
        """Descriptor turning a method name into a counter stored in the current per-second bucket."""
        known_counters = set()

        def __init__(self, method):
            self.__name = method.__name__
            self.known_counters.add(self.__name)
            property.__init__(self, method)

        def __get__(self, instance, instance_type=None):
            if instance is None:
                # Accessed on the class itself (introspection, mocking): return the descriptor
                return self
            instance.flush_counters()
            return instance.counters[self.__name]

        def __set__(self, instance, value):
            # No flush here: in `stats.x += 1` the value read by `__get__` belongs to the current bucket
            instance.counters[self.__name] = value

    def __init__(self):
        self.dump()
        # Open the first bucket right away, so a counter can be assigned before it has ever been read
        self.flush_counters()

    def flush_counters(self):
        """Start a new counters bucket when the wall-clock second has changed, archiving the old non-empty one."""
        now = int(time.time())
        if now > self.__counters_timestamp:
            if self.__counters:
                self.__counters_history.append((self.__counters_timestamp, self.__counters))
            self.__counters = collections.Counter()
            self.__counters_timestamp = now

    def dump(self):
        """Return accumulated data as a `Dump` and reset it (the current counters bucket is kept)."""
        counters_history, self.__counters_history = self.__counters_history, []
        client_wants_task, self.__wants_task = self.__wants_task, collections.Counter()
        client_got_task, self.__got_task = self.__got_task, collections.Counter()
        calls, self.__calls = self.__calls, []
        done_tasks, self.__done_tasks = self.__done_tasks, []
        return self.Dump(counters_history, client_wants_task, client_got_task, calls, done_tasks)

    @common.utils.classproperty
    def known_counters(self):
        return self.__counter.known_counters

    @property
    def counters(self):
        return self.__counters

    @__counter
    def error(self):
        pass

    @__counter
    def task_conflict(self):
        pass

    @__counter
    def semaphore_conflict(self):
        pass

    @__counter
    def job_conflict(self):
        pass

    @__counter
    def release_semaphore_conflict(self):
        pass

    @__counter
    def primary_not_ready(self):
        pass

    @__counter
    def already_executing(self):
        pass

    @__counter
    def election_error(self):
        pass

    @__counter
    def db_error(self):
        pass

    @__counter
    def critical(self):
        pass

    @__counter
    def skipped_due_disk_space(self):
        pass

    @__counter
    def skipped_due_wait_semaphores(self):
        pass

    @__counter
    def skipped_due_cores(self):
        pass

    @__counter
    def skipped_due_ram(self):
        pass

    @__counter
    def blocked_by_semaphores(self):
        pass

    @__counter
    def total_wants(self):
        pass

    @__counter
    def total_got(self):
        pass

    @__counter
    def task_to_execute_iterations(self):
        pass

    def done_task(self, execution_info):
        self.__done_tasks.append(execution_info)

    def wants_task(self, tags):
        for tag in tags:
            self.__wants_task[tag] += 1

    def got_task(self, tags):
        for tag in tags:
            self.__got_task[tag] += 1

    def call(self, name, start, duration):
        self.__calls.append(self.Call(name, start, duration))


class Execution(object):
    """Single job execution tracked by `Consumption`; serialized as a flat list of its slots."""
    __slots__ = "job_id", "start_time", "finish_time", "qp", "finished", "ram", "cpu", "hdd", "ssd"

    def __init__(self, job_id, start_time, finish_time, qp, finished, ram, cpu, hdd, ssd):
        """
        :param job_id: job identifier, 16 bytes
        :param start_time: start timestamp
        :param finish_time: supposed finish timestamp
        :param qp: consumed computational resources, in microQP
        :param finished: job status: 1 - finished, 0 - running, -1 - running but finish_time in the past
        :param ram: provided RAM, MiB
        :param cpu: provided CPU, cores
        :param hdd: provided HDD, MiB
        :param ssd: provided SSD, MiB
        """
        self.job_id = job_id
        self.start_time = start_time
        self.finish_time = finish_time
        self.qp = qp
        self.finished = finished
        self.ram = ram
        self.cpu = cpu
        self.hdd = hdd
        self.ssd = ssd

    def __repr__(self):
        reprs = "job_id={}, start_time={}, finish_time={}, qp={}, finished={}, ram={}, cpu={}, hdd={}, ssd={}"
        return "{}({})".format(
            type(self).__name__,
            reprs.format(
                uuid.UUID(bytes=self.job_id).hex,
                self.start_time,
                self.finish_time,
                self.qp,
                self.finished,
                common.format.size2str(self.ram),
                self.cpu,
                common.format.size2str(self.hdd),
                common.format.size2str(self.ssd),
            )
        )

    # noinspection PyUnresolvedReferences
    @cython.locals(other="Execution")
    def __eq__(self, other):
        return (
            self.job_id == other.job_id and
            self.start_time == other.start_time and
            self.finish_time == other.finish_time and
            self.qp == other.qp and
            self.finished == other.finished and
            self.ram == other.ram and
            self.cpu == other.cpu and
            self.hdd == other.hdd and
            self.ssd == other.ssd
        )

    # noinspection PyUnresolvedReferences
    @cython.locals(other="Execution")
    def __ne__(self, other):
        # Python 2 does not derive `!=` from `__eq__`; without this `!=` would compare object identities
        return not self.__eq__(other)

    def encode(self):
        return [
            self.job_id,
            self.start_time,
            self.finish_time,
            self.qp,
            self.finished,
            self.ram,
            self.cpu,
            self.hdd,
            self.ssd,
        ]


CONSUMPTION_WINDOW_SIZE = 12 * 60 * 60  # sliding window size in seconds (12 hours)
USAGE_METRICS_GRANULARITY = 3600  # in seconds, used for hour discounts in billing

# Mutable record of usage accumulated over [start_time, finish_time]; constructed from a single
# (start_time, finish_time, usage) sequence, e.g. ``UsageMetric((now, now, 0))``
UsageMetric = common.utils.namedlist("UsageMetric", "start_time finish_time usage")


# noinspection PyUnresolvedReferences
@cython.no_gc
@cython.cclass
class Consumption(Serializable):
    """
    Sliding-window accounting of QP consumed by job executions.

    Maintained incrementally (see `calculate`) and recomputable from scratch (`recalculate_internals`):

    - consumed QP: sum over executions of ``qp * (end - max(now - window_size, start_time))``,
      where `end` is the finish time for finished executions and `now` otherwise;
    - future total QP: sum of ``qp * max(finish_time - now, 0)`` over executions with ``finished == 0``;
    - entering QP: qp of unfinished executions started inside the window (consumed QP grows at this rate);
    - exiting QP: qp of finished executions started before the window (consumed QP shrinks at this rate);
    - future QP: qp of executions with ``finished == 0`` and a non-zero estimated duration
      (future total QP shrinks at this rate).

    Instant resources (ram, cpu, hdd, ssd) of running executions and hourly usage metrics are tracked as well.
    """

    def __init__(self, window_size=CONSUMPTION_WINDOW_SIZE):
        super(Consumption, self).__init__()

        self.__window_size = window_size

        self.__consumed_qp = 0
        self.__executions = {}  # job_id -> Execution

        # Executions ordered by start time (still contributing at the window's lower edge)
        self.__started_executions = collections.deque()
        # Finished executions ordered by finish time (dropped once they leave the window)
        self.__finished_executions = collections.deque()

        self.__entering_qp = 0
        self.__exiting_qp = 0
        self.__future_qp = 0

        self.__future_executions = []  # sorted (estimated finish_time, Execution) pairs
        self.__future_total_qp = 0

        self.__last_time = 0  # timestamp of the last `calculate`; 0 until the first execution starts

        self.__ram = 0
        self.__cpu = 0
        self.__hdd = 0
        self.__ssd = 0

        self.__usage_metrics = []

    def __draw(self):
        lower_bound = self.__last_time - self.__window_size
        upper_bound = self.__last_time + self.__window_size
        output = []
        for e in sorted(self.__executions.itervalues(), key=lambda _: _.job_id):
            line = ""
            for t in xrange(lower_bound, upper_bound):
                if t == self.__last_time:
                    line += "|"
                if e.start_time > t or (e.finish_time if e.finished else max(e.finish_time, self.__last_time)) <= t:
                    line += "."
                elif e.start_time <= t < max(e.finish_time, self.__last_time):
                    line += "<" if t == e.start_time else "-"
            output.append("{} {}".format(uuid.UUID(bytes=e.job_id).hex, line))
        return "\n".join(output)

    def __repr__(self):
        return "Consumption(\n{}\n)\n{}".format(",\n".join((
            "\tlast_time={}".format(self.__last_time),
            "\tconsumed_qp={}".format(self.__consumed_qp),
            "\tfuture_total_qp={}".format(self.__future_total_qp),
            "\tentering_qp={}".format(self.__entering_qp),
            "\texiting_qp={}".format(self.__exiting_qp),
            "\tfuture_qp={}".format(self.__future_qp),
            "\texecutions={}".format(self.__executions),
            "\tstarted_executions={}".format(self.__started_executions),
            "\tfinished_executions={}".format(self.__finished_executions),
            "\tfuture_executions={}".format(self.__future_executions)
        )), self.__draw())

    def __eq__(self, other):
        return all((
            self.__consumed_qp == other.__consumed_qp,
            self.__executions == other.__executions,
            self.__started_executions == other.__started_executions,
            self.__finished_executions == other.__finished_executions,
            self.__entering_qp == other.__entering_qp,
            self.__exiting_qp == other.__exiting_qp,
            self.__future_qp == other.__future_qp,
            self.__future_executions == other.__future_executions,
            self.__future_total_qp == other.__future_total_qp,
            self.__last_time == other.__last_time
        ))

    def __ne__(self, other):
        # Python 2 does not derive `!=` from `__eq__`; without this `!=` would compare object identities
        return not self.__eq__(other)

    @property
    def window_size(self):
        return self.__window_size

    @property
    def ram(self):
        return self.__ram

    @property
    def cpu(self):
        return self.__cpu

    @property
    def hdd(self):
        return self.__hdd

    @property
    def ssd(self):
        return self.__ssd

    def recalculate_internals(self, now=None):
        """
        Recompute aggregates from scratch for moment `now` (defaults to the last calculation time).

        :return: ``(consumed_qp, future_total_qp, entering_qp, exiting_qp, future_qp)``
        """
        if not now:
            now = self.__last_time
        else:
            assert now >= self.__last_time
        lower_bound = now - self.__window_size
        entering_qp = 0
        exiting_qp = 0
        future_qp = 0
        consumed_qp = 0
        future_total_qp = 0
        for execution in self.__executions.itervalues():
            if not (execution.finished == 1 and execution.finish_time <= lower_bound):
                consumed_qp += execution.qp * (
                    (execution.finish_time if execution.finished == 1 else now) - max(
                        lower_bound, execution.start_time
                    )
                )
                if execution.finished == 0:
                    future_total_qp += execution.qp * max(execution.finish_time - now, 0)
            if execution.finish_time <= lower_bound:
                continue
            if execution.finished < 1 and execution.start_time >= lower_bound:
                entering_qp += execution.qp
            if execution.finished == 1 and execution.start_time < lower_bound:
                exiting_qp += execution.qp
            if execution.finished == 0 and execution.start_time != execution.finish_time:
                future_qp += execution.qp
        return consumed_qp, future_total_qp, entering_qp, exiting_qp, future_qp

    def check_invariants(self, raise_assertion=True):
        """
        Compare incremental aggregates with a full recalculation and validate the ordered queues.

        :return: True when all checks pass; False on failure if `raise_assertion` is false
        :raises AssertionError: on failure if `raise_assertion` is true
        """
        consumed_qp, future_total_qp, entering_qp, exiting_qp, future_qp = self.recalculate_internals()
        try:
            assert entering_qp == self.__entering_qp and entering_qp >= 0
            assert exiting_qp == self.__exiting_qp and exiting_qp >= 0
            assert future_qp == self.__future_qp and future_qp >= 0
            assert consumed_qp == self.__consumed_qp and consumed_qp >= 0
            assert future_total_qp == self.__future_total_qp and future_total_qp >= 0

            def check_started(job_ids, prev, curr):
                assert curr.job_id not in job_ids
                job_ids.add(curr.job_id)
                if prev is None:
                    return
                assert prev.start_time <= curr.start_time

            def check_finished(prev, curr):
                assert curr.finished == 1
                if prev is None:
                    return
                assert prev.finish_time <= curr.finish_time

            # Applies `checker` to every adjacent (previous, current) pair; the first previous is None
            check_sequence = lambda sequence, checker: ft.reduce(
                lambda a, b: (checker(a, b), b)[1], sequence, None
            )
            check_sequence(self.__started_executions, ft.partial(check_started, set()))
            check_sequence(self.__finished_executions, check_finished)
        except AssertionError:
            if raise_assertion:
                raise
            return False
        return True

    def fix_invariants(self):
        """Reset aggregates from a full recalculation and drop duplicate jobs from the ordered queues."""
        (
            self.__consumed_qp,
            self.__future_total_qp,
            self.__entering_qp,
            self.__exiting_qp,
            self.__future_qp
        ) = self.recalculate_internals()
        job_ids = set()
        started_executions = collections.deque()
        for execution in self.__started_executions:
            if execution.job_id in job_ids:
                continue
            started_executions.append(execution)
            job_ids.add(execution.job_id)
        self.__started_executions = started_executions
        job_ids = set()
        finished_executions = collections.deque()
        for execution in self.__finished_executions:
            if execution.job_id in job_ids:
                continue
            finished_executions.append(execution)
            job_ids.add(execution.job_id)
        self.__finished_executions = finished_executions

    def _add_instant_consumption(self, execution):
        self.__ram += execution.ram
        self.__cpu += execution.cpu
        self.__hdd += execution.hdd
        self.__ssd += execution.ssd

    def _sub_instant_consumption(self, execution):
        self.__ram -= execution.ram
        self.__cpu -= execution.cpu
        self.__hdd -= execution.hdd
        self.__ssd -= execution.ssd

    def _update_usage_metric(self, now):
        """
        Extend the last usage metric up to `now`, accruing ``entering_qp`` per elapsed second.

        When `now` falls into a later USAGE_METRICS_GRANULARITY-aligned period, the last metric is closed at
        the period boundary minus one second and a new metric is opened at the boundary.
        """
        if not self.__usage_metrics:
            self.__usage_metrics.append(UsageMetric((now, now, 0)))
        else:
            last_metric = self.__usage_metrics[-1]
            start_hour = last_metric.start_time // USAGE_METRICS_GRANULARITY * USAGE_METRICS_GRANULARITY
            current_hour = now // USAGE_METRICS_GRANULARITY * USAGE_METRICS_GRANULARITY
            if start_hour < current_hour:
                prev_finish_time = last_metric.finish_time
                last_metric.finish_time = current_hour - 1
                # +1 covers the second between the closed metric's end (boundary - 1) and the new one's start
                last_metric.usage += self.__entering_qp * (last_metric.finish_time - prev_finish_time + 1)
                last_metric = UsageMetric((current_hour, current_hour, 0))
                self.__usage_metrics.append(last_metric)
            finish_time = last_metric.finish_time
            last_metric.finish_time = now
            last_metric.usage += self.__entering_qp * (last_metric.finish_time - finish_time)

    def started(self, start_time, job_id, qp, duration, ram, cpu, hdd, ssd):
        """
        Register a started execution; `duration` (estimate) is capped by the window size.

        :return: False if the job is already known, True otherwise
        """
        if job_id in self.__executions:
            return False

        self.calculate(start_time)
        self._update_usage_metric(start_time)

        duration = min(duration, self.__window_size)
        finish_time = start_time + duration
        execution = Execution(job_id, start_time, finish_time, qp, 0, ram, cpu, hdd, ssd)

        self.__executions[job_id] = execution
        self.__started_executions.append(execution)
        self.__entering_qp += qp

        self._add_instant_consumption(execution)

        if duration:
            bisect.insort(self.__future_executions, (finish_time, execution))
            self.__future_qp += qp
            self.__future_total_qp += qp * duration

        if not self.__last_time:
            self.__last_time = start_time
        return True

    def finished(self, finish_time, job_id):
        """
        Mark an execution as finished at `finish_time`.

        :return: `ConsumedResources` of the execution, or None if the job is unknown or already finished
        """
        execution = self.__executions.get(job_id)
        if execution is None or execution.finished == 1:
            return

        self.calculate(finish_time)
        self._update_usage_metric(finish_time)

        if execution.finished == 0:
            if execution.finish_time > execution.start_time:
                self.__future_qp -= execution.qp
            if execution.finish_time > finish_time:
                self.__future_total_qp -= execution.qp * (execution.finish_time - finish_time)

        if execution.start_time >= finish_time - self.__window_size:
            self.__entering_qp -= execution.qp
        execution.finished = 1
        execution.finish_time = finish_time

        self._sub_instant_consumption(execution)

        self.__finished_executions.append(execution)
        if execution.start_time < finish_time - self.__window_size:
            self.__exiting_qp += execution.qp
        return ConsumedResources(
            execution.qp * (finish_time - execution.start_time),
            execution.ram, execution.cpu, execution.hdd, execution.ssd
        )

    def calculate(self, now):
        """
        Advance the incremental state to `now` (must not precede the last calculation time).

        :return: ``(consumed_qp, future_total_qp)``; the state is not advanced before the first execution
        """
        if not self.__last_time:
            return self.__consumed_qp, self.__future_total_qp
        assert now >= self.__last_time
        calc_interval = now - self.__last_time
        lower_bound = now - self.__window_size
        prev_lower_bound = self.__last_time - self.__window_size
        # Executions whose start left the window: stop counting them as entering (or start counting as exiting)
        while self.__started_executions:
            execution = self.__started_executions[0]
            if execution.start_time >= lower_bound:
                break
            if execution.finished == 1:
                self.__exiting_qp += execution.qp
            else:
                self.__entering_qp -= execution.qp
            self.__consumed_qp += execution.qp * (execution.start_time - prev_lower_bound)
            self.__started_executions.popleft()

        # Finished executions that left the window completely are forgotten
        while self.__finished_executions:
            execution = self.__finished_executions[0]
            if execution.finish_time > lower_bound:
                break
            self.__exiting_qp -= execution.qp
            self.__consumed_qp -= execution.qp * (execution.finish_time - prev_lower_bound)
            self.__finished_executions.popleft()
            self.__executions.pop(execution.job_id)

        self.__future_total_qp -= self.__future_qp * (now - self.__last_time)

        # Running executions whose estimated finish has passed get `finished = -1`
        lower_index = -1
        for lower_index, (finish_time, execution) in enumerate(self.__future_executions):
            if execution.finished != 0:
                continue
            if finish_time > now:
                break
            self.__future_qp -= execution.qp
            self.__future_total_qp += execution.qp * (now - execution.finish_time)
            execution.finished = -1
        else:
            if self.__future_executions:
                self.__future_executions[:] = []
        if lower_index >= 0 and self.__future_executions:
            self.__future_executions[:] = self.__future_executions[lower_index:]

        self.__consumed_qp += calc_interval * (self.__entering_qp - self.__exiting_qp)

        self.__last_time = now

        return self.__consumed_qp, self.__future_total_qp

    def recalculate(self, now=None):
        """Return ``(consumed_qp, future_total_qp)`` recomputed from scratch, without changing the state."""
        if not now:
            now = self.__last_time
        else:
            assert now >= self.__last_time
        return self.recalculate_internals(now)[:2]

    @property
    def jobs(self):
        return self.__executions.keys()

    @property
    def executing_jobs(self):
        return [job_id for job_id, execution in self.__executions.iteritems() if execution.finished < 1]

    @property
    def qp(self):
        return self.__consumed_qp, self.__future_total_qp

    def calculate_usage_metrics(self, now=None, flush=True):
        """
        Bring usage metrics up to `now` (current time by default) and return them.

        With `flush`, the returned metrics are replaced by a single empty metric starting at `now`.
        """
        if now is None:
            now = int(time.time())
        self._update_usage_metric(now)
        if flush:
            metrics, self.__usage_metrics = self.__usage_metrics, [UsageMetric((now, now, 0))]
        else:
            metrics = self.__usage_metrics
        return metrics

    def encode(self):
        # NOTE: window size is not serialized; `decode` always uses the default one
        return [
            self.__consumed_qp,
            [(job_id, execution.encode()) for job_id, execution in self.__executions.iteritems()],
            map(lambda _: _.job_id, self.__started_executions),
            map(lambda _: _.job_id, self.__finished_executions),
            self.__entering_qp,
            self.__exiting_qp,
            self.__future_qp,
            [
                (finish_time, execution.job_id)
                for finish_time, execution in self.__future_executions
            ],
            self.__future_total_qp,
            self.__last_time,
            self.__ram,
            self.__cpu,
            self.__hdd,
            self.__ssd,
            self.__usage_metrics,
        ]

    @classmethod
    def decode(cls, data):
        """Restore an instance from `encode` output (a 14-element payload without usage metrics is accepted)."""
        if isinstance(data, cls):
            return data
        obj = cls()
        (
            obj.__consumed_qp,
            executions,
            started_executions,
            finished_executions,
            obj.__entering_qp,
            obj.__exiting_qp,
            obj.__future_qp,
            future_executions,
            obj.__future_total_qp,
            obj.__last_time,
            obj.__ram,
            obj.__cpu,
            obj.__hdd,
            obj.__ssd,
            usage_metrics,
        ) = (list(data) + [[]])[:15]  # TODO: remove slicing after SANDBOX-8931
        executions = {job_id: Execution(*fields) for job_id, fields in executions}
        obj.__executions = executions
        obj.__started_executions = collections.deque(
            it.ifilter(None, (executions.get(job_id) for job_id in started_executions))
        )
        obj.__finished_executions = collections.deque(
            it.ifilter(None, (executions.get(job_id) for job_id in finished_executions))
        )
        obj.__future_executions = filter(
            lambda _: _[1],
            ((finish_time, executions.get(job_id)) for finish_time, job_id in future_executions)
        )
        obj.__usage_metrics = map(UsageMetric, usage_metrics)
        return obj


class PreQueue(object):
    """
    FIFO of task ids with de-duplication and a cool-down period.

    A task id can be pushed only while it is neither waiting in the queue nor "settling". Once popped, a task
    settles until `timeout` seconds (by `time.time`) have passed, after which it may be pushed again.
    """
    SettlingQueueItem = collections.namedtuple("SettlingQueueItem", "task_id pop_time")

    def __init__(self):
        # Waiting task ids in FIFO order, mirrored by a set for O(1) membership checks
        self.__queue = collections.deque()
        self.__queue_tasks = set()
        # Popped task ids ordered by pop time, mirrored by a set
        self.__settling_queue = collections.deque()
        self.__settling_tasks = set()

    def __forget_outdated_tasks(self, timeout, logger):
        """Release settling tasks popped at least `timeout` seconds ago."""
        current = time.time()
        settling = self.__settling_queue
        while settling:
            oldest = settling[0]
            if current - oldest.pop_time < timeout:
                break
            settling.popleft()
            if logger:
                logger.info("Task #%s may be pushed to prequeue again", oldest.task_id)
            self.__settling_tasks.discard(oldest.task_id)

    def push(self, task_id, timeout, logger):
        """Enqueue `task_id`; return False if it is already queued or still settling, True otherwise."""
        self.__forget_outdated_tasks(timeout, logger)
        already_known = task_id in self.__settling_tasks or task_id in self.__queue_tasks
        if not already_known:
            self.__queue.append(task_id)
            self.__queue_tasks.add(task_id)
        return not already_known

    def pop(self, timeout, logger):
        """Dequeue the oldest task id and start its settling period; return None when the queue is empty."""
        if self.__queue:
            task_id = self.__queue.popleft()
            self.__queue_tasks.discard(task_id)
            self.__forget_outdated_tasks(timeout, logger)
            self.__settling_queue.append(self.SettlingQueueItem(task_id, time.time()))
            self.__settling_tasks.add(task_id)
            return task_id
        return None


# Information about a finished execution; `id`, `finished`, `consumption` and `pool` default to None,
# the resource fields `ram`, `cpu`, `hdd`, `ssd` default to 0
FinishExecutionInfo = collections.namedtuple(
    "FinishExecutionInfo",
    "id finished consumption ram cpu hdd ssd pool"
)
FinishExecutionInfo.__new__.__defaults__ = (None, None, None, 0, 0, 0, 0, None)
# Resources consumed by a finished execution, as returned by `Consumption.finished`
ConsumedResources = collections.namedtuple("ConsumedResources", "qp ram cpu hdd ssd")


def calc_bits_to_list_map():
    """
    Build a lookup table mapping every 16-bit value to the positions of its set bits.

    Entry ``v`` is a fresh list of bit positions in ascending order, e.g. ``0b1011 -> [0, 1, 3]``;
    entry 0 is an empty list.

    :return: list of 65536 lists
    """
    table = [[]]
    for value in range(1, 1 << 16):
        lowest = value & -value  # isolate the lowest set bit
        # the remaining bits (value ^ lowest) are a smaller, already computed entry, all above ``lowest``
        table.append([lowest.bit_length() - 1] + table[value ^ lowest])
    return table


# precompiled big-endian unsigned 32/16-bit packers used by the bitmask helpers below
ST_UINT32, ST_UINT16 = struct.Struct(">I"), struct.Struct(">H")
# 16-bit value -> ascending list of its set bit positions
BITS_TO_LIST = calc_bits_to_list_map()


# TODO: remove after SANDBOX-9402
def make_bits_from_indexes(indexes):
    """
    Encode a collection of positions as a bitmask in the ``IndexedList.make_bits`` format.

    The result is a sequence of big-endian 32-bit words (``ST_UINT32``), most significant word first.
    Bit ``k`` of the overall number is set if ``k`` is in ``indexes``. Leading all-zero words are
    omitted, but the result always contains at least one word.

    :param indexes: iterable of non-negative integers; duplicates are allowed.
        NOTE(review): when the largest item is not an integer, its element ``[1]`` is used as the top
        position (legacy behaviour kept as-is) -- confirm whether this form is still needed.
    :return: bytes
    """
    indexes = sorted(indexes, reverse=True)
    if not indexes:
        # same as IndexedList.make_bits() for an empty selection: a single zero word
        return ST_UINT32.pack(0)
    top = indexes[0]
    if not isinstance(top, six.integer_types):  # integer_types also covers Python 2 ``long``
        top = top[1]
    # scan positions from the top of the word containing ``top`` down to 0
    start = (top + 32) // 32 * 32 - 1
    count = len(indexes)
    bits = []
    b = 0
    ii = 0
    for i, item_index in enumerate(six.moves.range(start, -1, -1)):
        if i and i % 32 == 0 and (bits or b):
            bits.append(ST_UINT32.pack(b))
            b = 0
        b <<= 1
        # skip duplicates of positions already encoded; a single-step skip would miss
        # the position right after a duplicate (e.g. [5, 5, 4] lost bit 4)
        while ii < count and indexes[ii] > item_index:
            ii += 1
        if ii < count and indexes[ii] == item_index:
            b += 1
            ii += 1
    bits.append(ST_UINT32.pack(b))
    return b"".join(bits)


class IndexedList(Serializable):
    """
    List of hashable items with O(1) reverse lookup from item to its position.

    Items are only ever appended, so positions are stable. Subsets of items can be encoded as bitmasks
    with :meth:`make_bits` and decoded back to positions with :meth:`indexes_from_bits`.
    """

    def __init__(self, list_=None):
        super(IndexedList, self).__init__()
        # NOTE: a non-empty ``list_`` is used as-is (not copied), so appends are visible to the caller
        self.__list = list_ or []
        # for duplicate items the LAST position wins
        self.__index = {item: i for i, item in enumerate(self.__list)}

    def __reduce__(self):
        return IndexedList, (self.__list,)

    def __getitem__(self, index):
        """ Get item by index """
        return self.__list[index]

    def __eq__(self, other):
        # compares equal both to another IndexedList and to a plain list with the same items
        return self.__list == (other.__list if isinstance(other, IndexedList) else other)

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``__eq__``; without this ``!=`` compared identities
        return not self.__eq__(other)

    def __len__(self):
        return len(self.__list)

    def __contains__(self, item):
        return item in self.__index

    @property
    def index(self):
        """Mapping ``item -> position`` (the live internal dict)."""
        return self.__index

    def append(self, item):
        """
        Add ``item`` unless already present.

        :return: position of the item in the list
        """
        index = self.__index.get(item)
        if index is None:
            # the real list length: ``len(self.__index)`` is smaller when ``list_`` contained duplicates
            index = self.__index[item] = len(self.__list)
            self.__list.append(item)
        return index

    def make_bits(self, items):
        """
        Encode the positions of ``items`` as a bitmask.

        The mask is a sequence of big-endian 32-bit words, most significant first; bit ``k`` is set if
        the item at position ``k`` is in ``items``. Leading all-zero words are omitted, but at least one
        word is always emitted.

        :return: tuple ``(bytes, set of items that are not in this list)``
        """
        items = set(items)
        bits = []
        b = 0
        # start the counter so that position 0 is the last bit of a complete 32-bit word
        offset = len(self.__list) % 32
        if offset:
            offset = 32 - offset
        for i, item in enumerate(reversed(self.__list), offset):
            if i and i % 32 == 0 and (bits or b):
                bits.append(ST_UINT32.pack(b))
                b = 0
            b <<= 1
            found = item in items
            if found:
                b += 1
                items.remove(item)
        bits.append(ST_UINT32.pack(b))
        return b"".join(bits), items

    @staticmethod
    def indexes_from_bits(bits):
        """
        Decode a bitmask produced by :meth:`make_bits` into an ascending list of set positions.

        The bytes are read as big-endian 16-bit chunks from the end, i.e. from the least significant one.
        """
        indexes = []
        offset = 0
        for i in six.moves.range(len(bits), 0, -2):
            map_index = ST_UINT16.unpack(bits[i - 2: i])[0]
            if map_index:
                indexes.extend(offset + j for j in BITS_TO_LIST[map_index])
            offset += 16
        return indexes

    def encode(self):
        return self.__list

    @classmethod
    def decode(cls, data):
        return cls(data)


class FixedSizeNumericDeque(object):
    """
    Fixed-size ring buffer of numbers addressed relative to a moving "top" slot.

    ``deque[0]`` is the top slot. Other indexes are taken modulo ``size`` relative to it, so
    ``deque[-1]`` and ``deque[size - 1]`` address the same slot. :meth:`pop` advances the top.
    """

    def __init__(self, size):
        super(FixedSizeNumericDeque, self).__init__()
        self.size = size
        self.__queue = [0] * size
        self.__top = 0  # physical position of relative index 0

    def normalize(self, index):
        """
        Map any integer ``index`` to a physical position in ``[0, size)``.

        Python's ``%`` is non-negative for a positive modulus. It replaces add/subtract loops that cost
        O(|index| / size) and never terminated for ``size == 0``, which now raises ZeroDivisionError.
        """
        return index % self.size

    def __getitem__(self, index):
        return self.__queue[self.normalize(index + self.__top)]

    def __setitem__(self, index, value):
        self.__queue[self.normalize(index + self.__top)] = value

    def pop(self):
        """
        Advance the top by one slot and return the value found there, resetting that slot to 0.

        Relative to the old top this is index ``1``, the same slot as ``-(size - 1)``: the slot
        furthest behind the top.
        """
        new_top = self.normalize(self.__top + 1)
        value = self.__queue[new_top]
        self.__queue[new_top] = 0
        self.__top = new_top
        return value


class ApiConsumption(object):
    """
    Per-login sliding-window accounting of API consumption against a quota.

    The window spans ``window_size`` consecutive integer timestamps. For every login, a
    :class:`FixedSizeNumericDeque` stores consumption per timestamp (slot 0 is ``self.timestamp``), and
    ``consumption[login]`` stores ``(sum over the window, quota)``.
    """

    def __init__(self, window_size, default_quota):
        super(ApiConsumption, self).__init__()
        self.__table = {}  # dict of login: FixedSizeNumericDeque items
        self.consumption = {}  # dict of login: (consumption, quota) items
        self.timestamp = 0  # last table update timestamp
        self.banned_list = []  # logins at or above their quota as of the last rotation
        self.window_size = window_size
        self.default_quota = default_quota

    def check_user(self, login):
        """Return whether ``login`` is currently tracked."""
        return login in self.consumption

    def add_user(self, login, quota=None):
        """Start tracking ``login`` with zero consumption and ``quota`` (``default_quota`` if None)."""
        if quota is None:
            quota = self.default_quota
        self.consumption[login] = (0, quota)
        self.__table[login] = FixedSizeNumericDeque(self.window_size)

    def rotate_table(self, timestamp):
        """
        Advance the window to ``timestamp`` and expire consumption that fell out of it.

        If ``timestamp`` is not newer than the last rotation, nothing changes and :attr:`banned_list` is
        kept. Otherwise each elapsed timestamp expires one slot, and the whole table is dropped once the
        gap reaches ``window_size``. Logins whose consumption drops to zero are forgotten.
        :attr:`banned_list` is rebuilt from logins whose consumption is at or above their quota.
        """
        if self.timestamp >= timestamp:
            # resetting the banned list here would unban everyone until the next real rotation
            return
        self.banned_list = []
        elapsed = timestamp - self.timestamp
        if elapsed >= self.window_size:
            self.__table.clear()
            self.consumption.clear()
        else:
            users_to_remove = []
            for login, row in self.__table.items():
                consumption, quota = self.consumption[login]
                # expire one slot per elapsed timestamp so slots keep matching timestamps after gaps
                for _ in range(elapsed):
                    consumption -= row.pop()
                if consumption == 0:
                    users_to_remove.append(login)
                else:
                    self.consumption[login] = (consumption, quota)
                if consumption >= quota:
                    self.banned_list.append(login)
            for login in users_to_remove:
                self.__table.pop(login)
                self.consumption.pop(login)
        self.timestamp = timestamp

    def add_consumption(self, login, timestamp, delta_consumption, quota=None):
        """
        Add ``delta_consumption`` for ``login`` at ``timestamp``.

        ``quota`` is used only when the login is not tracked yet. Timestamps ``window_size`` or more
        away from the last rotation are ignored, but the login is still registered.
        """
        if not self.check_user(login):
            self.add_user(login, quota)

        index = timestamp - self.timestamp
        # NOTE(review): a positive (future) offset k shares a ring slot with the past offset k - window_size
        if abs(index) < self.window_size:
            self.__table[login][index] += delta_consumption
            consumption, quota = self.consumption[login]
            self.consumption[login] = (consumption + delta_consumption, quota)


class ComplexApiConsumption(object):
    """
    Two :class:`ApiConsumption` trackers, one for the regular API and one for the web API, rotated together.

    After every rotation the pair of banned lists is kept pre-serialized with msgpack in
    ``serialized_banned_list``.
    """
    EMPTY_COMPLEX_BANNED_LIST = msgpack.dumps([[], []], use_bin_type=True)  # serialized empty banned list
    WINDOW_SIZE = 900  # number of time quants for window in seconds
    DEFAULT_QUOTA = 1800000  # default quota for users without the quota in milliseconds

    WEB_WINDOW_SIZE = 180  # number of time quants for window in seconds
    WEB_DEFAULT_QUOTA = 720000  # default quota for users without the quota in milliseconds

    def __init__(self):
        self.timestamp = 0  # last table update timestamp
        self.serialized_banned_list = self.EMPTY_COMPLEX_BANNED_LIST
        self.api_consumption = ApiConsumption(self.WINDOW_SIZE, self.DEFAULT_QUOTA)
        self.web_api_consumption = ApiConsumption(self.WEB_WINDOW_SIZE, self.WEB_DEFAULT_QUOTA)

    def _rotate_table(self, timestamp):
        """Rotate both trackers to ``timestamp``, then re-serialize ``(api banned, web api banned)``."""
        trackers = (self.api_consumption, self.web_api_consumption)
        for tracker in trackers:
            tracker.rotate_table(timestamp)
        banned = tuple(tracker.banned_list for tracker in trackers)
        self.serialized_banned_list = msgpack.dumps(banned, use_bin_type=True)

    def rotate_table(self, timestamp=None):
        """Rotate to the current second, or to ``timestamp`` when it is later than that."""
        now = int(time.time())
        if timestamp is not None:
            now = max(now, timestamp)
        self.timestamp = now
        self._rotate_table(now)


class DBOperations(Serializable):
    """
    Deferred database operations, kept in one gevent queue per target collection.

    Each :class:`Operation` subclass names its collection in ``_db_collection``. :meth:`push` routes an
    operation to the queue keyed by that collection's class name. The queues are exposed through
    ``self[queue_name]`` and :attr:`queues`.
    """
    class Operation(Serializable):
        """
        Base class of a single deferred DB operation.

        Constructor arguments are stored verbatim in ``_args``/``_kws``. This lets an operation be encoded
        as ``(class name, args, kws)`` and rebuilt from the metaclass registry by :meth:`decode`.
        Subclasses set ``_db_collection`` and perform the write in ``__call__``.
        """
        # noinspection PyPep8Naming
        class __metaclass__(type):
            # Python 2 metaclass hook: each subclass of Operation is recorded by class name, and the
            # first class with a given name wins. Operation itself, whose bases are exactly
            # (Serializable,), is not recorded. The registry is reachable as ``Operation.registry``.
            registry = {}

            def __new__(mcs, name, bases, namespace):
                cls = type.__new__(mcs, name, bases, namespace)
                if bases != (Serializable,) and name not in mcs.registry:
                    mcs.registry[name] = cls
                return cls

        _db_collection = None  # target document class; set by subclasses

        def __init__(self, *args, **kws):
            self._args = args
            self._kws = kws
            super(DBOperations.Operation, self).__init__()

        def __repr__(self):
            # "<ClassName>(<positional reprs>, <key>=<value repr>, ...)"
            return "{}({}{})".format(
                type(self).__name__,
                ", ".join(map(repr, self._args)),
                ", {}".format(", ".join("{}={!r}".format(k, v) for k, v in self._kws.viewitems())) if self._kws else ""
            )

        def __call__(self):
            """Perform the operation; a no-op in the base class."""
            pass

        @property
        def type(self):
            """Class name of the target collection; used by :meth:`DBOperations.push` as the queue key."""
            return self._db_collection.__name__

        def encode(self):
            """Return the serializable form ``(class name, args, kws)``."""
            return type(self).__name__, self._args, self._kws

        @classmethod
        def decode(cls, data):
            """Rebuild an operation from :meth:`encode` output via the registry; instances pass through."""
            if isinstance(data, cls):
                return data
            name, args, kws = data
            # noinspection PyUnresolvedReferences
            return DBOperations.Operation.registry[name](*args, **kws)

    class SetQuota(Operation):
        """Set a group's ``quota``, or unset it when ``quota`` is None."""
        _db_collection = mapping.Group

        def __init__(self, name, quota):
            super(DBOperations.SetQuota, self).__init__(name, quota)

        def __call__(self):
            name, quota = self._args
            return self._db_collection.objects(name=name).update(**(
                dict(unset__quota=True)
                if quota is None else
                dict(set__quota=quota)
            ))

    class SetParent(Operation):
        """Set a group's ``parent``, or unset it when ``parent`` is None."""
        _db_collection = mapping.Group

        def __init__(self, name, parent):
            super(DBOperations.SetParent, self).__init__(name, parent)

        def __call__(self):
            name, parent = self._args
            return self._db_collection.objects(name=name).update(**(
                dict(unset__parent=True)
                if parent is None else
                dict(set__parent=parent)
            ))

    class CreateSemaphore(Operation):
        """Insert a new semaphore document; optional fields left as None are not passed to it."""
        _db_collection = mapping.Semaphore

        def __init__(self, sem_id, name, owner, auto=None, capacity=None, shared=None, public=None):
            super(DBOperations.CreateSemaphore, self).__init__(
                sem_id, name, owner, auto=auto, capacity=capacity, shared=shared, public=public
            )

        def __call__(self):
            sem_id, name, owner = self._args
            kws = dict(
                self._kws,
                id=sem_id,
                name=name,
                owner=owner,
                time=self._db_collection.Time()
            )
            doc = self._db_collection(
                **{k: v for k, v in kws.viewitems() if v is not None}
            )
            for _ in range(2):
                try:
                    doc.save(force_insert=True)
                    break
                except mapping.NotUniqueError:
                    # on a uniqueness conflict, delete any semaphore with the same name and retry once.
                    # NOTE(review): if the second insert also conflicts, the error is silently dropped
                    self._db_collection.objects(name=name).delete()

    class SetApiQuota(Operation):
        """Set a user's ``api_quota``."""
        _db_collection = mapping.User

        def __init__(self, login, quota):
            super(DBOperations.SetApiQuota, self).__init__(login, quota)

        def __call__(self):
            login, api_quota = self._args
            self._db_collection.objects(login=login).update(set__api_quota=api_quota)

    class UpdateSemaphore(Operation):
        """Update a semaphore's non-None fields; ``time.updated`` is always set to the current UTC time."""
        _db_collection = mapping.Semaphore

        def __init__(self, sem_id, owner=None, auto=None, capacity=None, shared=None, public=None):
            super(DBOperations.UpdateSemaphore, self).__init__(
                sem_id, owner=owner, auto=auto, capacity=capacity, shared=shared, public=public
            )

        def __call__(self):
            sem_id = self._args[0]
            kws = dict(
                set__time__updated=dt.datetime.utcnow(),
                **{"set__{}".format(k): v for k, v in self._kws.viewitems() if v is not None}
            )
            self._db_collection.objects(id=sem_id).update(**kws)

    class DeleteSemaphore(Operation):
        """Delete semaphores by id."""
        _db_collection = mapping.Semaphore

        def __init__(self, sem_id):
            super(DBOperations.DeleteSemaphore, self).__init__(sem_id)

        def __call__(self):
            # presumably ``common.utils.chain`` accepts a single id or an iterable of ids -- confirm
            self._db_collection.objects(id__in=common.utils.chain(*self._args)).delete()

    class UpdateTask(Operation):
        """Set a task's ``acquired_semaphore`` field."""
        _db_collection = mapping.Task

        def __init__(self, task_id, acquired_semaphore):
            super(DBOperations.UpdateTask, self).__init__(task_id, acquired_semaphore)

        def __call__(self):
            task_id, acquired_semaphore = self._args
            self._db_collection.objects(id=task_id).update(set__acquired_semaphore=acquired_semaphore)

    def __init__(self):
        self.__queues = {}  # NOTE(review): redundant, replaced right below
        # one gevent queue per distinct collection class name among registered operations
        # noinspection PyProtectedMember,PyUnresolvedReferences
        self.__queues = {
            queue_name: gevent.queue.Queue()
            for queue_name in set(op._db_collection.__name__ for op in self.Operation.registry.viewvalues())
        }
        super(DBOperations, self).__init__()

    def __getitem__(self, queue_name):
        """Return the queue for a collection class name."""
        return self.__queues[queue_name]

    def __deepcopy__(self, _):
        # a deep copy is a fresh, EMPTY instance: pending operations are not copied
        return type(self)()

    def push(self, operation):
        """Enqueue ``operation`` on the queue of its target collection."""
        queue = self.__queues[operation.type]
        queue.put(operation)

    @property
    def queues(self):
        """Names of all queues (collection class names)."""
        return self.__queues.keys()

    def encode(self):
        # Flatten every queue into one list of ``Operation.encode()`` tuples. Items are read through each
        # gevent queue's ``.queue`` attribute without being removed. Order is kept within a queue, but
        # not across queues.
        return list(it.chain.from_iterable([op.encode() for op in queue.queue] for queue in self.__queues.viewvalues()))

    @classmethod
    def decode(cls, data):
        """Build a new instance and push every operation decoded from ``data``."""
        obj = cls()
        for item in data:
            obj.push(cls.Operation.decode(item))
        return obj


class QueueIterationResult(common.utils.Enum):
    # Result codes, presumably of one queue-iteration step. Their meaning is defined by the code that
    # returns and consumes them, which is not visible here.
    # NOTE: values are explicit integers; keep them stable if they are persisted or compared numerically.
    ACCEPTED = 0
    SKIP_TASK = 1
    NO_TASKS = 2
    NEXT_TASK = 3
    SKIP_JOB = 4


class SemaphoreBlockers(object):
    """
    Records which semaphores, and with what weight, block each task.

    ``__blockers`` maps ``task_id -> {sem_id: weight}``. ``__sizes_by_sem`` keeps, per semaphore, the
    sum of weights of all tasks recorded as blocked on it.
    """

    class State(object):
        """Scratch state for a series of :meth:`SemaphoreBlockers.blocked` checks (presumably one queue pass)."""

        def __init__(self):
            self.occupied = collections.defaultdict(int)  # sem_id -> weight already taken in this pass
            self.blockers = collections.defaultdict(dict)  # task_id -> {sem_id: weight} found in this pass
            self.blocked_cache = {}  # task acquires -> blocker dict, memoized by ``blocked``
            self.blocked_count = 0

        def block(self, task_id, sem_id, weight):
            self.blocked_count += 1
            self.blockers[task_id][sem_id] = weight

        def occupy(self, sem_id, weight):
            self.occupied[sem_id] += weight

    def __init__(self):
        self.__blockers = collections.defaultdict(dict)
        self.__sizes_by_sem = collections.defaultdict(int)

    def __repr__(self):
        return repr(self.__blockers)

    def __eq__(self, other):
        if not isinstance(other, SemaphoreBlockers):
            # previously raised AttributeError for foreign types
            return NotImplemented
        return self.__blockers == other.__blockers

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``__eq__``
        result = self.__eq__(other)
        return result if result is NotImplemented else not result

    @property
    def tasks(self):
        """Ids of tasks that have recorded blockers."""
        return six.viewkeys(self.__blockers)

    def blocked(self, task_id, task_semaphores, semaphores, state):
        """
        Return whether ``task_id`` is still blocked by one of its previously recorded semaphores.

        A recorded semaphore blocks when its ``free`` amount minus the weight already occupied in
        ``state`` is below the weight the task acquires. Blocking results are memoized in
        ``state.blocked_cache`` by ``task_semaphores.acquires``. Any later task with the same acquires is
        then treated as blocked immediately, with the first task's blocker dict.
        """
        if not task_semaphores:
            return False
        blocked_sem = state.blocked_cache.get(task_semaphores.acquires)
        if blocked_sem is not None:
            # NOTE(review): blocked_count grows on cache hits only, not on the first detection below -- confirm
            state.blocked_count += 1
            # NOTE: the same dict object is shared with ``__blockers`` and the cache (no copy)
            state.blockers[task_id] = blocked_sem
            return True
        blocked_sem = self.__blockers.get(task_id, {})
        for acquire in task_semaphores.acquires:
            sem_id = acquire.name
            sem = semaphores[sem_id]
            if sem_id in blocked_sem and sem.free - state.occupied[sem_id] < acquire.weight:
                state.blockers[task_id] = state.blocked_cache[task_semaphores.acquires] = blocked_sem
                return True
        return False

    def actual_blockers(self, state):
        """Drop tasks already recorded here from ``state.blockers`` and return what remains (mutates ``state``)."""
        for task_id in self.__blockers:
            state.blockers.pop(task_id, None)
        return state.blockers

    def size(self, sem_id):
        """Total weight of tasks recorded as blocked on ``sem_id``."""
        return self.__sizes_by_sem.get(sem_id, 0)

    def add(self, blockers):
        """Merge ``{task_id: {sem_id: weight}}`` into the records, keeping per-semaphore totals in sync."""
        for task_id, weights in six.iteritems(blockers):
            blocked_by_task = self.__blockers[task_id]
            for sem_id, weight in six.iteritems(weights):
                size_diff = weight - blocked_by_task.get(sem_id, 0)
                if size_diff:
                    self.__sizes_by_sem[sem_id] += size_diff
            blocked_by_task.update(weights)

    def remove(self, task_id):
        """Forget all blockers of ``task_id`` and subtract its weights from the per-semaphore totals."""
        blocked_by_task = self.__blockers.pop(task_id, None)
        if blocked_by_task is not None:
            for sem_id, weight in six.iteritems(blocked_by_task):
                self.__sizes_by_sem[sem_id] -= weight


class ResourceLocks(object):
    """
    In-memory resource locks: ``resource_id -> (host, timestamp)``.

    A lock belongs to one host. That host may re-acquire it, which refreshes the timestamp. Locks that
    are not refreshed within :attr:`LOCK_TTL` are dropped by :meth:`clean_resource_locks`.
    """
    LOCK_TTL = 24 * 60 * 60  # ttl in seconds

    def __init__(self, locks=None):
        # NOTE: a non-empty ``locks`` mapping is used as-is (not copied)
        self.__locks = locks or {}

    def __repr__(self):
        return repr(self.__locks)

    def __eq__(self, other):
        if not isinstance(other, ResourceLocks):
            # previously raised AttributeError for foreign types
            return NotImplemented
        return self.__locks == other.__locks

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``__eq__``
        result = self.__eq__(other)
        return result if result is NotImplemented else not result

    def encode(self):
        return self.__locks

    def acquire(self, resource_id, host, timestamp):
        """
        Lock ``resource_id`` for ``host``, or refresh the lock if ``host`` already holds it.

        :return: True on success, False if another host holds the lock
        """
        if resource_id not in self.__locks or self.__locks[resource_id][0] == host:
            self.__locks[resource_id] = (host, timestamp)
            return True
        return False

    def release(self, resource_id, host):
        """
        Release ``resource_id`` if it is locked by ``host``.

        :return: True if a lock was released, False otherwise
        """
        if resource_id in self.__locks and self.__locks[resource_id][0] == host:
            self.__locks.pop(resource_id)
            return True
        return False

    def clean_resource_locks(self, timestamp):
        """Drop every lock whose timestamp is at least :attr:`LOCK_TTL` older than ``timestamp``."""
        expired = [
            res_id for res_id, lock in self.__locks.items()
            if timestamp - lock[1] >= self.LOCK_TTL
        ]
        for res_id in expired:
            del self.__locks[res_id]


class QuotaPools(Serializable):
    """
    Registry of quota pools.

    Each pool has a stable index (its position in an :class:`IndexedList`) and a tag set. A tag set
    matches a pool when it contains all of the pool's tags. :meth:`add` and :meth:`update` reject tag
    sets that are a subset or superset of another pool's tags. Per-pool default quotas are keyed by
    pool index.
    """

    def __init__(self):
        self.__pools = IndexedList()
        self.__tags = {}  # pool index -> set of tags
        self.__quota_pool_cache = {}  # sorted tags tuple -> matched pool index (matches only)
        self.__defaults = {}  # pool index -> default quota
        super(QuotaPools, self).__init__()

    def __eq__(self, other):
        if not isinstance(other, QuotaPools):
            # previously raised AttributeError for foreign types
            return NotImplemented
        # NOTE: default quotas are not compared
        return self.__pools == other.__pools and self.__tags == other.__tags

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``__eq__``
        result = self.__eq__(other)
        return result if result is NotImplemented else not result

    @property
    def pools(self):
        return self.__pools

    def default(self, pool_index):
        """Default quota of a pool: ``DEFAULT_QUOTA`` for None, otherwise the stored value or 0."""
        # TODO: fix after switch to pool quotas
        if pool_index is None:
            return DEFAULT_QUOTA
        return self.__defaults.get(pool_index, 0)

    def __validate_tags(self, pool, tags, exclude_index=None):
        """Raise ValueError if ``tags`` is a subset or superset of any other pool's tags."""
        for index, pool_tags in self.__tags.iteritems():
            if index == exclude_index:
                continue
            if tags <= pool_tags or tags >= pool_tags:
                raise ValueError(
                    "pool {} is a subset or superset of already existed pool {}".format(pool, self.__pools[index])
                )

    def add(self, pool, tags):
        """
        Register a new ``pool`` with ``tags``.

        :return: index of the new pool
        :raises ValueError: if the pool exists or its tags overlap another pool's tags as a subset or superset
        """
        if pool in self.__pools:
            raise ValueError("pool {} is already added".format(pool))
        tags = set(tags)
        self.__validate_tags(pool, tags)
        pool_index = self.__pools.append(pool)
        self.__tags[pool_index] = tags
        # memoized matches were computed without this pool
        self.__quota_pool_cache.clear()
        return pool_index

    def update(self, pool, tags=None, default_quota=None):
        """
        Change tags and/or the default quota of an existing ``pool``.

        ``pool=None`` is accepted for the default quota only; its tags are never changed.

        :raises ValueError: if the pool does not exist or the new tags overlap another pool's tags
        """
        if pool is not None and pool not in self.__pools:
            raise ValueError("pool {} does not exist".format(pool))
        # ``.get`` so that ``pool=None`` (allowed above) does not raise KeyError
        pool_index = self.__pools.index.get(pool)
        if tags is not None and pool is not None:
            tags = set(tags)
            self.__validate_tags(pool, tags, exclude_index=pool_index)
            self.__tags[pool_index] = tags
            # memoized matches may refer to this pool under its old tags
            self.__quota_pool_cache.clear()
        if default_quota is not None:
            self.__defaults[pool_index] = default_quota

    def match_pool(self, tags):
        """
        Return the index of a pool whose tags are all contained in ``tags``.

        :return: pool index, or None for empty ``tags`` or when no pool matches
        """
        if not tags:
            return None
        # materialize once: ``tags`` may be a one-shot iterable
        tags = set(tags)
        cache_key = tuple(sorted(tags))
        pool = self.__quota_pool_cache.get(cache_key, ctm.NotExists)
        if pool is not ctm.NotExists:
            return pool
        for pool, pool_tags in self.__tags.iteritems():
            if tags >= pool_tags:
                self.__quota_pool_cache[cache_key] = pool
                return pool
        return None

    def encode(self):
        return [
            self.__pools.encode(),
            {k: list(v) for k, v in self.__tags.iteritems()},
            self.__defaults,
        ]

    @classmethod
    def decode(cls, data):
        obj = cls()
        obj.__pools = IndexedList.decode(data[0])
        obj.__tags = {k: set(v) for k, v in data[1].iteritems()}
        if len(data) > 2:  # TODO: remove after SANDBOX-8907
            obj.__defaults = data[2]
        return obj
