from __future__ import absolute_import

import sys
import random
import threading

import six

from ya.skynet.util.pickle import dumps, loads
from ya.skynet.util.errors import saveTraceback
from ya.skynet.util.logging import MessageAdapter
from ya.skynet.services.cqudp import cfg

from .processhandle import ProcessHandle, Timeout, Signalled, CommunicationError
from ..utils import short, genuuid, run_daemon, sleep, monotime
from ..utils import reconfigure_log, log as root
from ..window import IndexedWindow, IncomingWindow, WindowTimeout
from ..transport.protocol import Protocol


# Capacity of the outgoing (task -> client) IndexedWindow created in TaskHandle.
WINSIZE = 16
# Lower bound, in seconds, for the effective orphan timeout (params may only raise it).
MIN_ORPHAN_TIMEOUT = 15.0
# Lower bounds, in seconds, for the randomized wait on each child-process recv.
MIN_WAIT_MIN = 1.0
MIN_WAIT_MAX = 2.0


class TaskHandle(object):
    """Server-side handle of a single task.

    Forks the task into a child process (``ProcessHandle``), starts a
    ``Protocol`` message bus and relays messages in both directions:

    * items produced by the child are stored in an outgoing ``IndexedWindow``
      and routed back to the client (and can be re-sent on request);
    * RPC messages from the client are buffered in an ``IncomingWindow`` and
      forwarded to the child once it has reported ``ready``.

    The task keeps running while it is neither stopped (``stop`` message)
    nor orphaned (no heartbeat from the client within ``_orphan_timeout``
    seconds).
    """

    def __init__(self, uuid, hostid, task, path, notify_path=None, log=None):
        """
        :param uuid: task id; incoming messages with another ``taskid`` are ignored.
        :param hostid: host id passed to every protocol routing call.
        :param task: task description dict; its ``params`` and ``options`` are read.
        :param path: initial route back to the client; ``path[0]`` is also kept
            as ``final_addr``, a direct fallback route.
        :param notify_path: if given, a ``task_started`` message is routed there
            once the bus is up.
        :param log: logger to use; a per-task child logger is created when omitted.
        """
        self.uuid = uuid
        self.hostid = hostid
        self.task = task
        self.paths = {tuple(path)}  # known routes back to the client
        self.final_addr = (path[0],)  # direct route, used when about to orphan
        self.notify_path = notify_path
        self.touch_time = monotime()  # refreshed by every accepted heartbeat
        self._stopped = False
        # uuids of already dispatched messages, used to drop duplicates
        # NOTE(review): never pruned, grows for the task's lifetime
        self._msg_uuids = set()
        self._params = task.get('params', {})
        if self._params.get('loglevel') is not None:
            reconfigure_log('child', level=self._params.get('loglevel'), logger=log)

        self._window = IndexedWindow(WINSIZE,
                                     break_fun=lambda: not self._running())  # outgoing messages task -> client
        self._incoming_window = IncomingWindow()  # incoming messages client -> task
        self.task_ready = threading.Event()  # set when task is ready to accept rpc messages
        self.protocol = None  # created lazily by _init_protocol() after fork
        self.log = log or MessageAdapter(
            root().getChild('task'),
            fmt="[%(uuid)s] %(message)s",
            data={'uuid': short(self.uuid)},
        )
        self.log.info('task starting params: %s', self._params)

    def _init_protocol(self):
        """Create the message bus (only once) and announce the task start.

        When ``options['port_range']`` is a ``(first, last)`` tuple/list, the
        ports are tried in order (``last`` inclusive). The first port for which
        ``Protocol`` construction succeeds is used.

        :raises RuntimeError: if no port in ``port_range`` could be used.
        """
        if self.protocol is not None:
            return

        options = self.task.get('options', {})
        port_range = options.get('port_range', None)
        send_own_ip = options.get('send_own_ip', False)
        if port_range and isinstance(port_range, (tuple, list)):
            for port in six.moves.xrange(port_range[0], port_range[1] + 1):
                try:
                    self.protocol = Protocol(impl=self._transport,
                                             port=port,
                                             send_own_ip=send_own_ip,
                                             reuse=False)
                except Exception as e:
                    self.log.debug("port %d is probably in use, cannot create protocol: %s", port, e)
                else:
                    break
            if self.protocol is None:
                # exception constructors do not %-interpolate their args: format explicitly
                raise RuntimeError("No free ports in desired range: %s" % (port_range,))
        else:
            self.protocol = Protocol(impl=self._transport, send_own_ip=send_own_ip)

        self.log.info('started bus on %s', self.protocol.listenaddr())

        if self.notify_path is not None:
            self.log.debug("notifying local server about start")
            self.protocol.route_direct(notify_started_msg(self.uuid), self.notify_path, self.hostid)

    def run(self):
        """Fork the task process and serve it until it finishes, stops or is orphaned.

        Items produced by the child are enqueued for delivery. After the child
        finishes, waits while the task is still running until the outgoing
        window is empty. The bus is shut down on exit if it was created.
        Every exception is logged and re-raised.
        """
        try:
            process = ProcessHandle(
                taskid=self.uuid,
                task=self.task,
                tempdir=cfg.server.TasksTempDir,
                loglevel=self._params.get('loglevel'),
                log=self.log,
            )
            process.fork()
            # any threads should be spawned only after fork
            # to prevent fork-in-multithread-app side-effects
            self._init_protocol()
            with process:
                run_daemon(self._receiver_thread, process.send)
                run_daemon(process.join, None)

                for res in self._yield_results(process):
                    self._enqueue_response(res[0], res[1])

                while self._running() and not self._window.empty():
                    sleep(1)

                if self._window.empty():
                    self.log.info('all messages delivered')

                if self._is_orphaned():
                    self.log.info('orphaned')

        except WindowTimeout:
            self.log.error("orphaned")
            raise
        except BaseException as e:
            self.log.exception("exception: %s", e, exc_info=sys.exc_info())
            raise
        finally:
            # the bus does not exist if ProcessHandle/fork/_init_protocol failed;
            # calling shutdown() on None would replace the original exception
            if self.protocol is not None:
                self.protocol.shutdown()

    def _yield_results(self, process):
        """Yield ``(datatype, data)`` items received from the child process.

        Control messages are handled here and not yielded:

        * ``_finish`` ends the iteration;
        * ``init`` re-sends the task description to the child;
        * ``ready`` flushes buffered RPCs to the child and sets ``task_ready``.

        A ``Signalled`` or ``CommunicationError`` becomes a final ``result``
        item carrying the exception. The iteration also stops as soon as the
        task is no longer running.
        """
        first_call = True
        while self._running():
            try:
                # params can be updated by heartbeats, so re-read them each time
                a = max(MIN_WAIT_MIN, self._params.get('wait_min', 0))
                b = max(MIN_WAIT_MAX, self._params.get('wait_max', 0))
                datatype, r = process.recv(random.uniform(a, b))

                if datatype == '_finish':
                    self.log.info('all results received')
                    break
                elif datatype == 'init':
                    self.log.debug('process requested task again')
                    process.send('init', self.task)
                    continue
                elif datatype == 'ready':
                    self.log.debug('process is ready to accept rpc')
                    for _, (initiator, msg) in self._incoming_window.pop():
                        process.send('rpc', (initiator, msg))
                    self.task_ready.set()
                    self._send_heartbeat()
                    continue

                yield datatype, r

            except Timeout:
                # nothing from the child yet: tell the client we are alive
                if first_call:
                    self._send_heartbeat()
                continue

            except (Signalled, CommunicationError) as e:
                # TODO: encapsulate dumps(None, e) and saveTraceback into exception
                saveTraceback(e)
                yield 'result', dumps((None, e))
                break

            finally:
                first_call = False

    def _enqueue_response(self, datatype, response):
        """Deliver one item produced by the child.

        In pipeline mode every item is stored in the outgoing window and sent.
        In legacy mode only ``result`` items are. Windowed items can be re-sent
        when a heartbeat requests them. Other legacy items are deserialized
        with ``loads`` and routed as-is, without windowing.
        """
        if self._pipeline or datatype == 'result':
            index = self._window.enqueue((datatype, response))
            self._send_response(index, datatype, response)
        else:
            response = loads(response)  # backward compatibility
            self.log.info('send %s', datatype)
            self._route(response)

    def _running(self):
        """Return True while the task is neither orphaned nor stopped."""
        return not self._is_orphaned() and not self._stopped

    def _receiver_thread(self, send_func):
        """Receive bus messages forever and dispatch those addressed to this task.

        Started from ``run`` via ``run_daemon``. ``send_func`` delivers RPCs to
        the child process. Only delivered ``route`` envelopes are processed.
        Each message uuid is dispatched at most once. Errors are logged and never
        stop the loop.
        """
        while True:
            try:
                kind, envelope, iface = self.protocol.receive()
                msgs = envelope.msgs
            except Exception as e:
                self.log.exception('cannot receive a message: %s', e, exc_info=sys.exc_info())
                continue

            if kind != 'route' or not envelope.delivered():
                continue

            for msg in msgs:
                uuid = msg['uuid']
                if uuid in self._msg_uuids:
                    continue
                self._msg_uuids.add(uuid)

                try:
                    if msg.get('taskid') != self.uuid:
                        self.log.debug('unexpected %r for task %s',
                                       msg.get('type', "unknown message type"),
                                       msg.get('taskid'))
                        continue

                    elif msg['type'] == 'heartbeat':
                        self._heartbeat(msg, envelope.data, envelope.path)
                    elif msg['type'] == 'addpath':
                        self._addpath(msg)
                    elif msg['type'] == 'stop':
                        self.log.info('task is requested to die')
                        self._stop()
                    else:
                        # anything else is treated as an RPC for the child
                        self._rpc(envelope.data, msg, envelope.initiator(), send_func)
                except Exception as e:
                    self.log.exception('exception on dispatching: %s', e, exc_info=sys.exc_info())

    def _addpath(self, msg):
        """Remember ``msg['path']`` as an additional route back to the client."""
        path = tuple(msg['path'])
        self.paths.add(path)

    def _stop(self):
        """Mark the task as stopped so that ``_running()`` becomes False."""
        self._stopped = True

    def _is_orphaned(self):
        """Return True when no heartbeat was accepted for ``_orphan_timeout`` seconds."""
        return monotime() - self.touch_time > self._orphan_timeout

    def _is_about_to_orphan(self):
        """Return True once 2/3 of ``_orphan_timeout`` passed without a heartbeat."""
        return monotime() - self.touch_time > 2. * self._orphan_timeout / 3

    def _send_heartbeat(self, last_sent_rpcid=None):
        """Route a heartbeat to the client.

        ``last_sent_rpcid`` is sent as the heartbeat's ``request_rpc`` field.
        """
        self.log.info('send heartbeat')
        self._route(heartbeat(self.uuid, last_sent_rpcid))

    def _touch(self):
        """Refresh ``touch_time`` (never moving it backwards)."""
        self.touch_time = max(monotime(), self.touch_time)

    def _rpc(self, index, msg, initiator, send_func):
        """Forward a client RPC message to the child process.

        :param index: envelope data: an RPC index, an ``(index, data)`` tuple
            or ``None``. Indexed messages are put into the incoming window and
            only popped to the child after it reported ready. If the index is
            ahead of the next expected one, a heartbeat requesting the expected
            index is sent. ``None``-indexed messages are sent immediately.
        :param msg: the RPC message dict.
        :param initiator: envelope initiator, forwarded to the child with ``msg``.
        :param send_func: callable ``(kind, payload)`` delivering to the child
            (``process.send`` as passed by ``run``).
        """
        if isinstance(index, tuple):
            index, data = index  # FIXME data in RPC is currently ignored, should we allow it?

        self.log.info("pushing #%s %s %s to child",
                      index,
                      msg['type'],
                      short(msg['uuid']))
        if index is not None:
            if index > self._incoming_window.index:
                # gap detected: ask the client for the next expected rpc
                self.log.info("requesting rpc #%s", self._incoming_window.index)
                self._send_heartbeat(self._incoming_window.index)
            self._incoming_window.put(index, (initiator, msg))

            if self.task_ready.is_set():
                for _, (initiator, msg) in self._incoming_window.pop():
                    send_func('rpc', (initiator, msg))
        else:
            send_func('rpc', (initiator, msg))

    def _heartbeat(self, hb, next_id, path):
        """Handle a heartbeat from the client.

        Refreshes ``touch_time``, merges ``hb['params']`` into the task params
        and remembers ``path`` as a route back. It then either re-sends the
        outgoing window entry found at ``next_id`` or replies with a heartbeat.

        :param hb: heartbeat message dict.
        :param next_id: envelope data. New format: a ``(next_id, last_sent_rpcid)``
            tuple. Old format: just ``next_id``, the outgoing index requested
            by the client.
        :param path: envelope path the heartbeat arrived by.
        """
        # FIXME backward compatibility for old format, after 15.6.0 it should always be tuple
        last_sent_rpcid = None
        if isinstance(next_id, tuple):
            next_id, last_sent_rpcid = next_id
            # everything the client sent is already here -> request nothing;
            # otherwise ask for our next expected rpc index
            # NOTE(review): a None from the client also triggers a request; confirm intended
            if last_sent_rpcid is not None and last_sent_rpcid < self._incoming_window.index:
                last_sent_rpcid = None
            else:
                last_sent_rpcid = self._incoming_window.index

        if hb['taskid'] != self.uuid:
            self.log.info('unexpected heartbeat for task %s', short(hb['taskid']))
            return
        self._touch()
        self._params.update(hb.get('params', {}))

        self.log.info('got heartbeat %s (#%s requested)', short(hb['uuid']), next_id)

        self.paths.add(tuple(path))
        x = self._window.seek(next_id)
        if x:
            # x is (index, (datatype, data)) as stored by _enqueue_response
            self._send_response(x[0], x[1][0], x[1][1])
        else:
            self._send_heartbeat(last_sent_rpcid)

    def _send_response(self, idx, datatype, data):
        """Route outgoing window entry ``idx`` to the client.

        Pipeline mode sends a ``response`` message. Legacy mode sends a
        ``result`` message.
        """
        if self._pipeline:
            if isinstance(datatype, tuple) and datatype[0] == 'rpc':
                # FIXME compatibility hack, remove after all clients update
                datatype = 'rpc'
            self.log.debug('send response/%s #%s', datatype, idx)
            self._route(response_msg(self.uuid, idx, datatype, data))
        else:
            self.log.debug('send result #%s', idx)
            self._route(result(self.uuid, idx, data))

    def _route(self, msg):
        """Send ``msg`` to the client over one of the known paths.

        When the task is about to orphan and ``final_addr`` is not a known path
        yet, it is added and used. Otherwise a random known path is chosen.
        In aggregate mode the message is flagged ``aggr`` and routed back along
        the path. Otherwise it is sent directly to ``path[0]``.
        """
        if self._is_about_to_orphan() and self.final_addr not in self.paths:
            self.paths.add(self.final_addr)
            path = self.final_addr
        else:
            path = random.choice(tuple(self.paths))

        if self._aggregate:
            if isinstance(msg, dict):
                msg['aggr'] = True
            else:
                setattr(msg, 'aggr', True)
            self.log.debug('aggregated route %s via %s', msg['type'], list(reversed(path)))
            self.protocol.route_back(msg, self.hostid, path)
        else:
            self.log.debug('direct route %s to %s', msg['type'], path[0])
            self.protocol.route_direct(msg, path[0], self.hostid)

    @property
    def _pipeline(self):
        """Whether the client uses pipeline (``response``) messages; param ``pipeline``."""
        return self._params.get('pipeline', False)

    @property
    def _orphan_timeout(self):
        """Orphan timeout in seconds: param ``orphan_timeout``, at least MIN_ORPHAN_TIMEOUT."""
        return max(float(self._params.get('orphan_timeout', 0)), MIN_ORPHAN_TIMEOUT)

    @property
    def _aggregate(self):
        """Whether outgoing messages are routed back along the path; param ``aggregate``."""
        return self._params.get('aggregate', False)

    @property
    def _transport(self):
        """Protocol implementation name: ``netlibus`` if requested, else ``msgpack``."""
        if self._params.get('netlibus', False):
            return 'netlibus'
        return 'msgpack'


def result(taskid, index, res):
    """Build a ``result`` message for task ``taskid`` at outgoing index ``index``.

    ``res`` is passed through unchanged when it is already bytes; anything
    else is serialized with ``dumps`` first.
    """
    msg_uuid = genuuid()
    payload = res
    if not isinstance(payload, six.binary_type):
        payload = dumps(payload)
    return dict(
        uuid=msg_uuid,
        taskid=taskid,
        index=index,
        result=payload,
        type='result',
    )


def heartbeat(taskid, request_rpcid=0):
    """Build a ``heartbeat`` message for task ``taskid``.

    ``request_rpcid`` goes into the ``request_rpc`` field: the index of the
    next incoming RPC message the task asks the client for.
    """
    msg = dict(uuid=genuuid(), taskid=taskid, type='heartbeat')
    msg['request_rpc'] = request_rpcid  # next incoming message id to ask client for
    return msg


def response_msg(taskid, index, datatype, data):
    """Build a ``response`` message carrying ``data`` of kind ``datatype``.

    The payload is nested under ``content`` as ``{'type': ..., 'data': ...}``.
    """
    msg_uuid = genuuid()
    content = dict(type=datatype, data=data)
    return dict(
        uuid=msg_uuid,
        taskid=taskid,
        type='response',
        index=index,
        content=content,
    )


def addpath_msg(taskid, path):
    """Build an ``addpath`` message announcing an extra route ``path`` for task ``taskid``."""
    return dict(uuid=genuuid(), taskid=taskid, type='addpath', path=path)


def notify_started_msg(uuid):
    """Build a ``task_started`` notification for the task identified by ``uuid``.

    The argument becomes the message's ``taskid``; the message itself gets
    a freshly generated ``uuid``.
    """
    return dict(uuid=genuuid(), taskid=uuid, type='task_started')
