# coding: utf-8
from __future__ import print_function

import time
import collections
import platform
import logging

import cyordereddict
import tornado.gen
import tornado.ioloop

from infra.netmon.agent.idl import tasks_pb2

from . import application
from . import ticker
from . import rpc
from . import encoding
from . import const
from . import utils
from .settings import Settings
from .diagnostic.flow import diagnostic_task

# Lifetime of a telemetry (version) result: one day, scaled by utils.US so it
# can be added to utils.timestamp() values (presumably microseconds - confirm).
TELEMETRY_TTL = 24 * 3600 * utils.US
# Minimum pause between telemetry resends: four hours, presumably in seconds
# (it is compared against a time.time()-based timestamp).
TELEMETRY_RESEND_INTERVAL = 4 * 3600


class ExecutingTask(object):
    """A single server task being executed by the agent.

    Owns the ``TTaskResult`` message that is reported back to the server
    through ``TaskDispatcher``.
    """

    def __init__(self, app, enqueued_task, on_finished):
        """
        :param app: application object; ``app[TaskDispatcher]`` receives results.
        :param enqueued_task: task message received from the server.
        :param on_finished: zero-argument callable invoked after the task completes.
        """
        self.app = app
        self._enqueued_task = enqueued_task
        self._on_finished = on_finished
        self._result_task = tasks_pb2.TTaskResult(
            ParentKey=self._enqueued_task.Key,
            Deadline=self._enqueued_task.Deadline,
            Generated=utils.timestamp()
        )

    @property
    def finished(self):
        """Whether execution has completed (successfully or not)."""
        return self._result_task.Finished

    def report(self):
        """Queue the current state of the result for sending to the server."""
        self.app[TaskDispatcher].enqueue([self._result_task])

    @tornado.gen.coroutine
    def start(self):
        """Run the task, record any error, report the final result and notify."""
        logging.info("Start task %r", self._enqueued_task)

        self._result_task.Finished = False
        try:
            yield self._execute()
        except Exception as exc:
            # An exception with an empty message (e.g. ``raise ValueError()``)
            # would leave Error empty, and the task would be reported as a
            # success. Fall back to the exception's repr in that case.
            self._result_task.Error = encoding.safe_str(exc) or encoding.safe_str(repr(exc))
        finally:
            self._result_task.Finished = True

        if self._result_task.Error:
            logging.error("Task %r failed with %r", self._enqueued_task, self._result_task.Error)
        else:
            logging.info("Task %r finished with %r", self._enqueued_task, self._result_task)

        try:
            self.report()
        finally:
            # Wake the executor even if reporting failed, so the freed slot is reused.
            self._on_finished()

    @tornado.gen.coroutine
    def _execute(self):
        """Dispatch on the task's ``Arguments`` oneof; unknown types set ``Error``."""
        task_type = self._enqueued_task.WhichOneof(b"Arguments")
        if task_type == "Diagnostic":
            # The handler receives this object, presumably so it can call report()
            # with intermediate progress - confirm in diagnostic.flow.
            yield diagnostic_task(self, self._enqueued_task.Diagnostic, self._result_task.Diagnostic)
        else:
            self._result_task.Error = b"unknown task type given"


class TaskExecutor(application.AppMixin, application.Service):
    """Runs tasks fetched from the server, at most ``max_parallel_tasks`` at once.

    Tasks are queued with :meth:`enqueue` and started by a periodic loop
    (``ticker.LoopingCall`` with interval 5), which ``enqueue`` and finishing
    tasks wake early via ``spin()``.
    """

    def __init__(self, maxlen=100, max_parallel_tasks=1):
        """
        :param maxlen: capacity of the pending-task queue; also bounds how many
            task keys are remembered for de-duplication.
        :param max_parallel_tasks: maximum number of simultaneously running tasks.
        """
        self._max_parallel_tasks = max_parallel_tasks
        # Keys of recently accepted tasks, oldest first (used as an ordered set).
        self._last_seen_tasks = cyordereddict.OrderedDict()
        self._queue = collections.deque(maxlen=maxlen)
        self._loop = ticker.LoopingCall("task_executor", self._execute, 5)
        self._executing_tasks = []

    def enqueue(self, enqueued_tasks):
        """Queue the tasks that were not accepted recently.

        Tasks that do not fit into the queue are not marked as seen, so they
        are accepted on a later fetch instead of being silently dropped.

        :param enqueued_tasks: iterable of task messages with a ``Key`` field.
        :returns: number of tasks actually taken into the queue.
        """
        maxlen = self._queue.maxlen
        enqueued_counter = 0
        for task in enqueued_tasks:
            # Tasks stay in the server queue until they are reported for the
            # first time, so the same task is fetched repeatedly. Skip the ones
            # already accepted, including duplicates within this batch.
            if task.Key in self._last_seen_tasks:
                continue
            if maxlen is not None and len(self._queue) >= maxlen:
                # Appending would make the deque evict its oldest pending task.
                # That task is already marked as seen, so it would never run.
                break
            self._last_seen_tasks[task.Key] = None
            self._queue.append(task)
            enqueued_counter += 1
        if maxlen is not None:
            # Remember at most as many keys as the queue can hold.
            while len(self._last_seen_tasks) > maxlen:
                self._last_seen_tasks.popitem(last=False)
        if enqueued_counter:
            self._loop.spin()
        return enqueued_counter

    @tornado.gen.coroutine
    def _execute(self):
        """Loop body: drop finished tasks and start queued ones up to the limit."""
        if self._executing_tasks:
            # filter out finished tasks first
            self._executing_tasks = [task for task in self._executing_tasks if not task.finished]

        tasks_to_start = []
        while self._queue and len(self._executing_tasks) + len(tasks_to_start) < self._max_parallel_tasks:
            tasks_to_start.append(self._queue.popleft())

        for enqueued_task in tasks_to_start:
            task = ExecutingTask(self._app, enqueued_task, on_finished=self._loop.spin)
            tornado.ioloop.IOLoop.current().add_callback(task.start)
            self._executing_tasks.append(task)

    @tornado.gen.coroutine
    def cancel(self):
        """Stop the executor loop."""
        yield self._loop.cancel()


class TaskDispatcher(application.AppMixin, application.Service):
    """Periodically exchanges tasks with the server.

    Each tick sends queued task results, plus a version/interfaces record
    when due, via ``rpc.RpcClient.finish_tasks``. It then hands newly fetched
    tasks to ``TaskExecutor``.
    """

    _loop_interval = 30

    def __init__(self, maxlen=100):
        """
        :param maxlen: capacity of the outgoing result queue; once full, the
            deque evicts the oldest results.
        """
        self._hostname = encoding.safe_str(Settings.current().hostname)
        self._queue = collections.deque(maxlen=maxlen)
        self._loop = ticker.LoopingCall("task_dispatcher", self._dispatch, self._loop_interval, round_by_interval=True)
        # time.time() of the last successful telemetry send, None if never sent.
        self._last_telemetry_timestamp = None

    def enqueue(self, finished_tasks):
        """Queue task results for sending and wake the dispatch loop."""
        self._queue.extend(finished_tasks)
        if self._queue:
            self._loop.spin()

    def _should_telemetry_be_updated(self):
        """True if telemetry was never sent, or ``utils.should_be_updated`` says it is due."""
        return (
            self._last_telemetry_timestamp is None
            or utils.should_be_updated(
                TELEMETRY_RESEND_INTERVAL,
                self._last_telemetry_timestamp,
                self._hostname
            )
        )

    @tornado.gen.coroutine
    def _create_telemetry_tasks(self):
        """Return ``[TTaskResult]`` with version/interface info if due, else ``[]``."""
        generated = utils.timestamp()
        telemetry_tasks = []
        if self._should_telemetry_be_updated():
            local_interfaces = yield self._app[application.IfaceService].get_interfaces()
            telemetry_tasks.append(tasks_pb2.TTaskResult(
                Generated=generated,
                Deadline=generated + TELEMETRY_TTL,
                Version=tasks_pb2.TVersionResult(
                    Host=self._hostname,
                    Version=const.FULL_VERSION,
                    Platform=encoding.safe_str(platform.platform()),
                    LocalInterfaces=[tasks_pb2.TLocalInterface(
                        FQDN=encoding.safe_str(iface.fqdn),
                        Address=encoding.safe_str(iface.address),
                        Mask=encoding.safe_str(iface.mask),
                        MAC=encoding.safe_str(iface.mac)
                    ) for iface in local_interfaces]
                ),
                Finished=True
            ))
        raise tornado.gen.Return(telemetry_tasks)

    def _drain_queue(self):
        """Pop every queued result, keeping only the latest one per ParentKey.

        Each ParentKey keeps the position of its first occurrence.
        """
        position_by_parent = {}
        latest_tasks = []
        while self._queue:
            task = self._queue.popleft()
            if task.ParentKey in position_by_parent:
                latest_tasks[position_by_parent[task.ParentKey]] = task
            else:
                position_by_parent[task.ParentKey] = len(latest_tasks)
                latest_tasks.append(task)
        return latest_tasks

    def _requeue(self, tasks):
        """Put unsent results back at the head of the queue, keeping the last ones that fit."""
        maxlen = self._queue.maxlen
        if maxlen is None:
            self._queue.extendleft(reversed(tasks))
            return
        slots = maxlen - len(self._queue)
        if slots > 0:
            self._queue.extendleft(reversed(tasks[-slots:]))

    @tornado.gen.coroutine
    def _dispatch(self):
        """Loop body: send results and telemetry, then fetch new tasks."""
        # Collect telemetry BEFORE draining the result queue. If interface
        # collection raises, queued results must stay queued and not be lost.
        telemetry_tasks = yield self._create_telemetry_tasks()
        # send only last state into server
        finished_tasks = self._drain_queue()

        if telemetry_tasks or finished_tasks:
            logging.info("Sending %d tasks into server", len(telemetry_tasks) + len(finished_tasks))
            try:
                yield self._app[rpc.RpcClient].finish_tasks(finished_tasks + telemetry_tasks)
            except BaseException:
                # Any failure (including cancellation): keep results for the next tick.
                self._requeue(finished_tasks)
                raise
            else:
                if telemetry_tasks:
                    self._last_telemetry_timestamp = time.time()

        enqueued_tasks = yield self._app[rpc.RpcClient].enqueued_tasks(self._hostname)
        if enqueued_tasks:
            enqueued = self._app[TaskExecutor].enqueue(enqueued_tasks)
            logging.info("%d tasks fetched from server, %d of them taken to queue",
                         len(enqueued_tasks), enqueued)

    @tornado.gen.coroutine
    def cancel(self):
        """Stop the dispatch loop."""
        yield self._loop.cancel()
