import os
import time
import logging
import threading
from collections import defaultdict

from . import NotifyException, EmailNotifier, TelegramNotifier  # noqa

import msgpack
from concurrent import futures
from kikimr.public.sdk.python.persqueue.grpc_pq_streaming_api import ConsumerMessageType
from kikimr.public.sdk.python.persqueue.errors import ActorTerminatedException

# Default timeout, in seconds, for each futures.wait() call made by Worker.
TIMEOUT = 60
# NOTE(review): not referenced by the code below; presumably intended for
# waiting on commit acknowledgements.
COMMIT_TIMEOUT = 1800  # high timeout because commit is important


# TODO: when the reader reconnects, save the messages that were processed but
# not yet committed so they can be auto-committed on the new session.
class Worker(threading.Thread):
    """Thread that reads notification requests from PQ and delivers them.

    Every data message is expected to be a msgpack-encoded dict with the keys
    ``type``, ``attempt``, ``address``, ``message`` and ``options``. ``type``
    selects one of the notifiers built from ``config['notifiers']``. Messages
    that cannot be delivered are written back through the PQ writer with an
    incremented ``attempt`` counter, until ``MAX_DELIVERY_ATTEMPTS`` is
    exceeded.

    Lifecycle: ``start()`` runs :meth:`run`. :meth:`shutdown` requests a stop
    and returns the ``stopped`` future. Note that :meth:`run` ends the whole
    process with ``os._exit`` once it finishes.
    """

    # Messages whose ``attempt`` counter exceeds this value are dropped by
    # push_back() instead of being written back.
    MAX_DELIVERY_ATTEMPTS = 100
    # Number of write attempts push_back() makes for one message.
    PUSH_BACK_WRITE_RETRIES = 10

    def __init__(self, config, pq_client, log=None, unistats=None, *args, **kwargs):
        """
        :param config: mapping with a ``notifiers`` list. Every item has a
            ``classname`` (a name visible in this module's globals) and
            ``options`` that are passed to that class's ``create()``
        :param pq_client: object providing ``create_reader()`` and
            ``create_writer()``
        :param log: logger to use (defaults to the ``worker`` logger)
        :param unistats: mutable counter mapping; a ``defaultdict(int)`` is
            created when not given
        :raise TypeError: if a configured notifier class is unknown
        :raise ValueError: if a notifier cannot be created
        Remaining arguments are passed to :class:`threading.Thread`.
        """
        super().__init__(*args, **kwargs)
        self.log = log or logging.getLogger('worker')
        self.pq_client = pq_client
        # Both futures are moved to RUNNING immediately, so they can no longer
        # be cancelled; they are only ever completed through set_result().
        self.should_stop = futures.Future()
        self.should_stop.set_running_or_notify_cancel()
        self.stopped = futures.Future()
        self.stopped.set_running_or_notify_cancel()
        self.config = config
        self.notifiers = {}
        self._set_notifiers(config['notifiers'])
        self.reader = None
        self.writer = None
        self.unistats = unistats if unistats is not None else defaultdict(int)
        # Sequence numbers passed to writer.write(); seeded from the wall clock,
        # presumably so they keep increasing across process restarts.
        self.seq = int(time.time())

    def _set_notifiers(self, config):
        """Instantiate the configured notifiers and register them by ``key()``.

        :param config: iterable of ``{'classname': ..., 'options': ...}`` items
        :raise TypeError: if ``classname`` is not a name in this module's globals
        :raise ValueError: if ``<class>.create(options)`` raises (chained)
        """
        for item in config:
            klass = item['classname']
            params = item['options']
            # Classes are resolved by name among this module's globals
            # (presumably why the notifier classes are imported with `noqa`).
            source_type = globals().get(klass)
            if source_type is None:
                raise TypeError("notifiers config has invalid class: %r" % (klass,))

            try:
                source = source_type.create(params)
            except Exception as e:
                raise ValueError("notifier source %r cannot be created: %s" % (klass, e)) from e

            self.log.info("initialized: %s", source)
            self.notifiers[source.key()] = source

    def shutdown(self):
        """Request the worker to stop and return the ``stopped`` future.

        Sets the should_stop flag and stops the current reader, if any.
        Repeated calls do nothing except return the same future.
        """
        if self.should_stop.done():
            # A stop has already been requested.
            return self.stopped

        self.should_stop.set_result(True)

        if self.reader is not None:
            self.reader.stop()

        return self.stopped

    def run(self):
        """Thread entry point: connect, serve the reader until stopped, then exit.

        The writer is always stopped on the way out. The process is then
        terminated with ``os._exit``: status 0 after a normal stop, status 1
        if an unexpected exception occurred (it is logged first).
        """
        exit_code = 0
        try:
            try:
                self.writer = self.pq_client.create_writer()
                try:
                    # create_reader() may also observe a stop request
                    # (shutdown during connect), so it is covered here too.
                    self.create_reader()
                    self.serve_reader()
                except futures.CancelledError:
                    # Normal stop path: the should_stop flag was set.
                    pass
            finally:
                if self.writer is not None:
                    self.writer.stop()
        except Exception:
            exit_code = 1
            self.log.exception("worker failed:")
        finally:
            self.stopped.set_result(True)
            # Hard exit: no exception raised here could reach a caller anyway.
            os._exit(exit_code)

    def wait_futures(self, future_list, timeout=TIMEOUT, return_when=futures.FIRST_COMPLETED):
        """Wait for *future_list* together with this instance's should_stop flag.

        :param future_list: iterable of futures to wait for
        :param timeout: seconds to wait; on timeout the method just returns
        :param return_when: passed to :func:`concurrent.futures.wait`
        :raise futures.CancelledError: if the should_stop flag is set
        """
        futures.wait((self.should_stop,) + tuple(future_list),
                     timeout=timeout, return_when=return_when)
        if self.should_stop.done():
            raise futures.CancelledError()  # FIXME use some special exception type

    def create_reader(self):
        """(Re)create ``self.reader`` and block until its session is initialized.

        A reader that stops, or whose start result lacks an ``init`` field, is
        replaced with a fresh one and the wait is repeated.

        :raise futures.CancelledError: if the should_stop flag is set before a
            reader is successfully started
        """
        self.unistats["pq_reconnect_dmmm"] += 1
        self.reader = self.pq_client.create_reader()
        while not self.should_stop.done():
            self.wait_futures((self.reader.start_future, self.reader.stop_future))
            if self.reader.stop_future.done():
                result = self.reader.stop_future.result(timeout=0)
                self.log.warning("reader connect failed: %s", result)
                self.reader = self.pq_client.create_reader()
            elif self.reader.start_future.done():
                result = self.reader.start_future.result(timeout=0)
                if not result.HasField('init'):
                    self.log.warning("reader connect failed: %s", result)
                    self.reader.stop()
                    self.reader = self.pq_client.create_reader()
                    continue
                self.log.info("reader successfully created")
                return
        # Stop was requested before the reader came up: don't let callers go
        # on with an unstarted reader.
        raise futures.CancelledError()

    def serve_reader(self):
        """Main loop: handle reader events until the should_stop flag is set.

        Partition lock offers are accepted immediately, and other non-data
        events are only logged. Data events are processed message by message;
        failures are counted and logged. The event's cookie is committed
        afterwards in every case.
        """
        while not self.should_stop.done():
            event = self.get_message()
            self.unistats["events_read_dmmm"] += 1

            if event.type == ConsumerMessageType.MSG_LOCK:
                self.log.debug("partition offered to read: %s:%s",
                               event.message.lock.topic,
                               event.message.lock.partition)
                event.ready_to_read()
                continue
            elif event.type != ConsumerMessageType.MSG_DATA:
                self.log.debug("got message type: %s", event.type)
                continue

            for batch in event.message.data.message_batch:
                for message in batch.message:
                    self.unistats["messages_read_dmmm"] += 1
                    try:
                        self.log.debug("processing message: %s", message)
                        self.process_message(message.data)
                        self.unistats["messages_processed_dmmm"] += 1
                    except Exception:
                        self.log.exception("failed to process message")
                        self.unistats["messages_failed_dmmm"] += 1
            self.reader.commit([event.message.data.cookie])

    def get_message(self):
        """Return the next reader event, reconnecting the reader when it dies.

        :raise futures.CancelledError: if the should_stop flag is set
        """
        try:
            event_future = self.reader.next_event()
        except ActorTerminatedException as e:
            self.log.warning("reader connection is dead: %s", e)
            self.create_reader()
            event_future = self.reader.next_event()

        while not self.should_stop.done():
            self.wait_futures((event_future, self.reader.stop_future))

            if self.reader.stop_future.done():
                result = self.reader.stop_future.result(timeout=0)
                self.log.warning("reader connection terminated: %s", result.reason)
                event_future.cancel()
                self.create_reader()
                continue

            if event_future.done():
                return event_future.result(timeout=0)

        # Stop was requested between iterations. Raise instead of returning
        # None, which serve_reader() would fail on (event.type).
        raise futures.CancelledError()

    def process_message(self, msg):
        """Decode one message payload and deliver it with the matching notifier.

        Payloads that are not msgpack-encoded dicts containing all of ``type``,
        ``attempt``, ``address``, ``message`` and ``options`` are logged and
        skipped. Otherwise ``attempt`` is incremented. If the notifier type is
        unknown, the address cannot be parsed, or delivery raises
        ``NotifyException``, the message is handed to :meth:`push_back`.
        ``notifier.notify_user`` may return a follow-up message, which is
        delivered in turn until it returns ``None``.
        """
        try:
            msg = msgpack.loads(msg, raw=False)
        except Exception:
            msg = None

        if (
            not isinstance(msg, dict)
            or not all(key in msg for key in (
                'type', 'attempt', 'address', 'message', 'options')
            )
        ):
            self.log.info("got corrupted message, skipping")
            return

        self.log.debug("got message: %s", msg)
        msg['attempt'] += 1

        notifier = self.notifiers.get(msg['type'])
        if notifier is None:
            self.log.info("unknown notifier type: %r (attempt #%s)", msg['type'], msg['attempt'])
            self.push_back(msg)
            return

        address = notifier.read_user_address(msg['address'])
        if address is None:
            self.log.info("cannot parse address: %r (attempt #%s)", msg['address'], msg['attempt'])
            self.push_back(msg)
            return

        try:
            while msg is not None:
                msg = notifier.notify_user(address, msg)
                self.unistats["notifications_sent_dmmm"] += 1
        except NotifyException:
            # `msg` still holds the message whose delivery raised.
            self.unistats["notifications_failed_dmmm"] += 1
            self.log.exception("failed to notify (type %r, attempt #%s)", msg['type'], msg['attempt'])
            self.push_back(msg)

    def check_writer(self):
        """Replace ``self.writer`` with a new one if the current one has stopped."""
        if self.writer.stop_future.done():
            result = self.writer.stop_future.result(timeout=0)
            self.log.warning("writer connection terminated: %s", result.reason)
            self.writer = self.pq_client.create_writer()

    @staticmethod
    def _is_acked(write_future):
        """Return True only if *write_future* completed normally with an ``ack``.

        Pending, cancelled and failed futures count as not acknowledged.
        """
        if not write_future.done() or write_future.cancelled():
            return False
        if write_future.exception(timeout=0) is not None:
            return False
        return write_future.result(timeout=0).HasField('ack')

    def push_back(self, msg):
        """Write *msg* back through the writer so that delivery is retried later.

        Messages whose ``attempt`` exceeds ``MAX_DELIVERY_ATTEMPTS`` are dropped.
        Otherwise the write is tried up to ``PUSH_BACK_WRITE_RETRIES`` times,
        each with a new sequence number and, if the writer has stopped, a
        recreated writer. A write that is not acknowledged within ``TIMEOUT``,
        fails, or is cancelled counts as a failed attempt. Unlike
        :meth:`wait_futures`, the waits here do not include the should_stop flag.
        """
        if msg['attempt'] > self.MAX_DELIVERY_ATTEMPTS:
            self.log.warning("failed to deliver message (type %r, attempt #%s), dropping",
                             msg['type'], msg['attempt'])
            return

        message = msgpack.dumps(msg)
        self.check_writer()

        for _ in range(self.PUSH_BACK_WRITE_RETRIES):
            seq, self.seq = self.seq, self.seq + 1
            write_future = self.writer.write(seq, message)
            futures.wait((write_future, self.writer.stop_future),
                         timeout=TIMEOUT, return_when=futures.FIRST_COMPLETED)
            if self._is_acked(write_future):
                return

            self.log.warning("failed to push back message (type %r, attempt #%s)", msg['type'], msg['attempt'])

            self.check_writer()

        self.log.error("completely failed to push back message (type %r, attempt #%s)", msg['type'], msg['attempt'])
