from __future__ import absolute_import

import os
import abc
import json
import time
import uuid
import random
import socket
import select
import struct
import logging
import itertools as it
import threading as th
import functools as ft
import collections

try:
    import setproctitle
except ImportError:
    setproctitle = None


class Message(object):
    """Base class of every datagram exchanged between nodes.

    On the wire a message is one ``FORMAT`` struct: a one-character type tag,
    a 32-byte message id, the sender's node id (``I``), the sender's term
    (``Q``) and three ``Q`` payload fields produced by the subclass's
    ``__iter__``. Subclasses that declare ``__type__`` are registered by the
    metaclass so that :meth:`unpack` can rebuild the right class from bytes.
    """

    # noinspection PyPep8Naming
    class __metaclass__(abc.ABCMeta):
        # Type tag -> Message subclass, filled in as subclasses are created.
        __message_types__ = {}

        def __new__(mcs, name, bases, namespace):
            new_class = abc.ABCMeta.__new__(mcs, name, bases, namespace)
            tag = namespace.get("__type__")
            if tag:
                mcs.__message_types__[tag] = new_class
            return new_class

    __slots__ = ["msg_id", "node_id", "term"]
    __type__ = None
    FORMAT = struct.Struct("c32sIQQQQ")

    @abc.abstractmethod
    def __init__(self, msg_id, node_id, term):
        # A falsy msg_id (e.g. None) means "new message": mint a fresh id.
        if not msg_id:
            msg_id = uuid.uuid4().hex
        self.msg_id = msg_id
        self.node_id = node_id
        self.term = term

    def __repr__(self):
        # Pair slot names with header + payload values; zip drops payload
        # values that have no slot (unused wire fields).
        values = it.chain((self.msg_id, self.node_id, self.term), self)
        rendered = [str(value) for value in values]
        pairs = ["=".join(pair) for pair in zip(self.__slots__, rendered)]
        return "{}({})".format(type(self).__name__, ", ".join(pairs))

    @abc.abstractmethod
    def __iter__(self):
        """Yield exactly the three payload fields packed after ``term``."""

    def pack(self):
        """Serialize this message into ``FORMAT.size`` bytes."""
        header = [self.__type__, self.msg_id, self.node_id, self.term]
        return self.FORMAT.pack(*(header + list(self)))

    @classmethod
    def unpack(cls, data):
        """Decode ``data`` into an instance of the registered subclass.

        Raises TypeError for an unregistered type tag; ``struct.error`` from
        ``FORMAT.unpack`` propagates when ``data`` has the wrong size.
        """
        items = cls.FORMAT.unpack(data)
        tag = items[0]
        # noinspection PyUnresolvedReferences
        message_class = cls.__message_types__.get(tag)
        if message_class is None:
            raise TypeError("Unknown message type: {}".format(tag))
        return message_class(*items[1:])


class VoteRequest(Message):
    """Candidate -> peer: request a vote for ``term``.

    ``last_term``/``last_version`` describe the candidate's current value so
    that the receiver can compare it with its own before granting the vote.
    """

    __slots__ = Message.__slots__ + ["last_term", "last_version"]
    __type__ = "V"

    def __init__(self, msg_id, node_id, term, last_term, last_version, *_):
        # Extra trailing positional fields (the unused wire slot) are ignored.
        super(VoteRequest, self).__init__(msg_id, node_id, term)
        self.last_term, self.last_version = last_term, last_version

    def __iter__(self):
        yield self.last_term
        yield self.last_version
        yield 0  # third payload slot is unused

    @property
    def last(self):
        """The candidate's value as a ``(term, version)`` tuple."""
        return (self.last_term, self.last_version)


class VoteResponse(Message):
    """Peer -> candidate: answer to a :class:`VoteRequest`.

    ``granted`` is truthy when the vote was given.
    """

    __slots__ = Message.__slots__ + ["granted"]
    __type__ = "v"

    def __init__(self, msg_id, node_id, term, granted, *_):
        super(VoteResponse, self).__init__(msg_id, node_id, term)
        self.granted = granted

    def __iter__(self):
        # The flag travels as 0/1 in the first payload slot; the rest is unused.
        yield int(self.granted)
        yield 0
        yield 0


class Request(Message):
    """Leader -> follower: replicate ``version`` on top of ``prev``.

    A ``version`` of 0 carries no update and serves as a heartbeat.
    """

    __slots__ = Message.__slots__ + ["prev_term", "prev_version", "version"]
    __type__ = "R"

    def __init__(self, msg_id, node_id, term, prev_term, prev_version, version):
        super(Request, self).__init__(msg_id, node_id, term)
        self.prev_term, self.prev_version, self.version = prev_term, prev_version, version

    def __iter__(self):
        yield self.prev_term
        yield self.prev_version
        yield self.version

    @property
    def prev(self):
        """The leader's previous value as a ``(term, version)`` tuple."""
        return (self.prev_term, self.prev_version)


class Response(Message):
    """Follower -> leader: outcome of a replication :class:`Request`."""

    __slots__ = Message.__slots__ + ["success"]
    __type__ = "r"

    def __init__(self, msg_id, node_id, term, success, *_):
        super(Response, self).__init__(msg_id, node_id, term)
        self.success = success

    def __iter__(self):
        # The flag travels as 0/1 in the first payload slot; the rest is unused.
        yield int(self.success)
        yield 0
        yield 0


class Node(object):
    """A member of a Raft-style cluster that replicates one version number.

    Peers exchange fixed-size :class:`Message` datagrams over UDP. After
    :meth:`start`, a daemon thread runs the follower/candidate/leader state
    machine. On the leader, assigning :attr:`version` replicates the new
    version and blocks until a quorum has acknowledged it. If leadership is
    lost in the meantime, :class:`NotLeader` is raised instead.
    """

    FOLLOWER = 0
    CANDIDATE = 1
    LEADER = 2

    # A replication request awaiting an answer from one peer: the msg_id the
    # answer must echo, and the time after which the request may be re-sent.
    TrackedPacket = collections.namedtuple("TrackedPacket", "msg_id deadline")

    class Exception(Exception):
        """Base class of the errors raised by :class:`Node`."""

    # ``Exception`` below resolves to Node.Exception, defined just above.
    class NotLeader(Exception):
        """Raised when a write is attempted, or aborted, on a non-leader."""

    class State(object):
        """Persistent per-node state: current term, vote and replicated value.

        When a filename is given, the state is loaded from it on construction.
        It is rewritten (through a temporary file renamed into place) after
        every change made through the setters.
        """

        class Value(collections.namedtuple("Value", "term version")):
            """A replicated value: the term it was written in and its version."""

            Empty = None

            def __nonzero__(self):
                return self != self.Empty

            # Python 3 name of the truth hook. Without it every Value, even
            # Empty, would be truthy as a non-empty tuple.
            __bool__ = __nonzero__

        Value.Empty = Value(0, 0)

        # Class-level defaults, used when no state file exists yet.
        __term = 0
        __voted_for = None
        __prev_value = Value.Empty
        __value = Value.Empty

        def __init__(self, filename, logger):
            self.__filename = filename
            self.__logger = logger
            self.__load()

        def __load(self):
            """Populate the state from the JSON file, if it exists."""
            if self.__filename and os.path.exists(self.__filename):
                self.__logger.debug("Loading state from '%s'", self.__filename)
                try:
                    with open(self.__filename) as f:
                        data = json.load(f)
                    self.__term = data["term"]
                    self.__voted_for = data["voted_for"]
                    self.__prev_value = self.Value(*data["prev_value"])
                    self.__value = self.Value(*data["value"])
                except Exception as ex:
                    self.__logger.error("Error while reading from '%s': %s", self.__filename, ex)
                    raise

        def __save(self):
            """Persist the state, if a filename was configured; errors re-raise."""
            if not self.__filename:
                return
            self.__logger.debug("Saving state to '%s'", self.__filename)
            try:
                tmp_name = self.__filename + "~"
                with open(tmp_name, "w") as f:
                    data = {
                        "term": self.__term,
                        "voted_for": self.__voted_for,
                        "prev_value": tuple(self.__prev_value),
                        "value": tuple(self.__value)
                    }
                    json.dump(data, f)
                    # Flush to disk before the rename publishes the file. Otherwise,
                    # on some filesystems, a crash could expose a truncated file.
                    f.flush()
                    os.fsync(f.fileno())
                os.rename(tmp_name, self.__filename)
            except Exception as ex:
                self.__logger.error("Error while writing to '%s': %s", self.__filename, ex)
                raise

        @property
        def term(self):
            return self.__term

        @term.setter
        def term(self, value):
            if self.__term == value:
                return
            assert self.__term < value, "Can only increase term, current: {}, new: {}".format(self.__term, value)
            self.__term = value
            self.__save()

        @property
        def voted_for(self):
            return self.__voted_for

        @voted_for.setter
        def voted_for(self, value):
            if self.__voted_for == value:
                return
            # Compare against None explicitly: node id 0 is a real vote.
            assert self.__voted_for is None or value is None, "Cannot vote for {}, already voted for {}".format(
                value, self.__voted_for
            )
            self.__voted_for = value
            self.__save()

        @property
        def prev_value(self):
            return self.__prev_value

        @property
        def value(self):
            return self.__value

        @property
        def version(self):
            return self.__value.version

        @version.setter
        def version(self, value):
            # The new value is stamped with the current term; the old one
            # becomes prev_value.
            assert self.__value.version < value, "Can only increase version, current: {}, new: {}".format(
                self.__value.version, value
            )
            self.__prev_value = self.__value
            self.__value = self.Value(self.__term, value)
            self.__save()

    def __init__(
        self, node_id, nodes, state_filename=None, election_timeout=1, heartbeat=.1, packet_ttl=.5, process_prefix=None,
        logger=None
    ):
        """Configure the node. Networking starts only in :meth:`start`.

        :param node_id: this node's id. It must appear in ``nodes``.
        :param nodes: iterable of dicts with ``id``, ``host`` and ``port`` keys,
            one per cluster member (including this node).
        :param state_filename: JSON file persisting term/vote/value, or None.
        :param election_timeout: base seconds without leader contact before a
            follower stands for election (randomized by up to ``heartbeat``).
        :param heartbeat: seconds between leader replication rounds.
        :param packet_ttl: seconds before an unanswered replication request to
            a peer may be re-sent.
        :param process_prefix: if set (and setproctitle is importable), the
            process title becomes ``"<prefix> <ROLE>"``.
        :param logger: logger to use; defaults to the ``logging`` module.
        """
        self.__nodes = {node["id"]: (node["host"], node["port"]) for node in nodes}
        assert node_id in self.__nodes
        self.__quorum = len(self.__nodes) // 2 + 1
        self.__node_id = node_id
        self.__base_election_timeout = election_timeout
        self.__election_timeout = self.__base_election_timeout
        self.__heartbeat = heartbeat
        self.__packet_ttl = packet_ttl
        # From here on __nodes holds peers only; our own entry gives the port.
        self.__port = self.__nodes.pop(node_id)[1]
        self.__logger = logger or logging
        self.__process_prefix = process_prefix

        self.__ro_socket = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
        # One connected, non-blocking send socket per peer.
        self.__wo_sockets = {}
        for peer_id, addr in self.__nodes.items():
            sock = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
            sock.setblocking(0)
            sock.connect(addr)
            self.__wo_sockets[peer_id] = sock

        self.__running = False
        self.__thread = None
        self.__lock = th.Lock()
        self.__role = None
        self.__last_time = None
        self.__current_leader = None

        self.__state = self.State(state_filename, self.__logger)

        # volatile state
        self.__votes = None
        # leader
        self.__tracked_packets = None
        self.__acked = None
        self.__commit_event = th.Event()

    def __repr__(self):
        # <Node id: role-letter current-term value-term version>
        return "<Node {}: {} {} {} {}>".format(
            self.__node_id,
            {self.FOLLOWER: "F", self.CANDIDATE: "C", self.LEADER: "L"}.get(self.__role, "N"),
            self.__state.term, self.__state.value.term, self.__state.version
        )

    def start(self):
        """Bind the receive socket, become a follower and start the loop thread."""
        self.__running = True
        self.__ro_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.__ro_socket.bind(("", self.__port))
        self._become(self.FOLLOWER)
        self.__thread = th.Thread(target=self._loop)
        self.__thread.daemon = True
        self.__thread.start()

    def stop(self):
        """Ask the loop thread to exit after its current iteration."""
        self.__running = False

    @property
    def role(self):
        return self.__role

    @property
    def version(self):
        return self.__state.version

    @version.setter
    def version(self, value):
        """Leader only: store ``value`` and block until a quorum acknowledged it.

        Raises NotLeader if this node is not the leader when called, or stops
        being the leader before the version is committed.
        NOTE(review): waits without a timeout. A leader that neither reaches a
        quorum nor sees a higher term blocks here indefinitely.
        """
        with self.__lock:
            if self.__role != self.LEADER:
                raise self.NotLeader
            self.__state.version = value
            self.__commit_event.clear()
            self.__acked = {self.__node_id}
            # Forget requests still in flight for the previous version. Their
            # late answers would otherwise match the tracked msg_id and count
            # as acks of *this* version. It also lets _replicate send now.
            self.__tracked_packets = {node_id: None for node_id in self.__nodes}
            if len(self.__acked) >= self.__quorum:
                # Single-node cluster: the leader alone is a quorum.
                self.__commit_event.set()
            self._replicate()
        self.__commit_event.wait()
        if self.__role != self.LEADER:
            raise self.NotLeader

    def _sendto(self, message, node_id):
        """Send ``message`` to peer ``node_id``; socket errors are only logged.

        While requests are tracked (leader), a Request carrying a version is
        skipped if the previous one to that peer is younger than ``packet_ttl``.
        """
        try:
            sock = self.__wo_sockets[node_id]
            if self.__tracked_packets and isinstance(message, Request) and message.version:
                now = time.time()
                tp = self.__tracked_packets[node_id]
                if tp and tp.deadline > now:
                    self.__logger.debug("%s Skip sending %s to <Node %s>", self, message, node_id)
                    return
                self.__tracked_packets[node_id] = self.TrackedPacket(
                    message.msg_id, now + self.__packet_ttl
                )
            self.__logger.debug("%s Sending %s to <Node %s>", self, message, node_id)
            sock.send(message.pack())
        except socket.error as ex:
            self.__logger.error("%s Error while sending message %s to <Node %s>: %s", self, message, node_id, ex)

    def _sendtoall(self, message):
        """Send ``message`` to every peer."""
        for node_id in self.__nodes:
            self._sendto(message, node_id)

    def __check_message(self, message):
        """Return whether ``message`` may be processed by the current role.

        With nothing tracked (non-leaders), everything passes. Otherwise only
        an answer echoing the msg_id tracked for its sender passes, and that
        entry is cleared; senders without a tracked request are rejected.
        """
        if not self.__tracked_packets:
            return True
        tracked_message = self.__tracked_packets.get(message.node_id)
        accepted = tracked_message and tracked_message.msg_id == message.msg_id
        if accepted:
            self.__tracked_packets[message.node_id] = None
        return accepted

    def __receive(self, sock):
        """Read and decode one datagram; return None if it must be dropped."""
        try:
            data = sock.recv(Message.FORMAT.size)
            message = Message.unpack(data)
        except (socket.error, struct.error, TypeError) as ex:
            # A single bad datagram must not kill the loop thread.
            self.__logger.warning("%s Dropping undecodable packet: %s", self, ex)
            return None
        if message.node_id not in self.__nodes:
            self.__logger.warning("%s Dropping %s from unknown node", self, message)
            return None
        return message

    def _loop(self):
        """Thread body: dispatch incoming messages and drive the role timers.

        Waits up to a tenth of the heartbeat for a datagram. If one is received
        it is processed under the lock. Otherwise (nothing arrived or it was
        dropped), followers/candidates switch role once the election timeout
        elapses, and the leader replicates every heartbeat.
        """
        self.__logger.info("%s Starting distributed node loop", self)
        last_replicate = time.time()
        while self.__running:
            # 10.0, not 10: an int heartbeat would floor to 0 on Python 2 and spin.
            readable = select.select([self.__ro_socket], [], [], self.__heartbeat / 10.0)[0]
            with self.__lock:
                for sock in readable:
                    message = self.__receive(sock)
                    if message is None:
                        continue  # falls through to the timer checks below
                    self.__logger.debug("%s Received %s", self, message)
                    self._process_common(message)
                    if not self.__check_message(message):
                        self.__logger.debug("%s Skip processing %s", self, message)
                        break
                    if self.__role == self.FOLLOWER:
                        self._process_follower(message)
                    elif self.__role == self.CANDIDATE:
                        self._process_candidate(message)
                    elif self.__role == self.LEADER:
                        self._process_leader(message)
                    break
                else:
                    if self.__role == self.FOLLOWER:
                        if self._timeout_elapsed:
                            self._become(self.CANDIDATE)
                    elif self.__role == self.CANDIDATE:
                        if self._timeout_elapsed:
                            self._become(self.FOLLOWER)
                    elif self.__role == self.LEADER:
                        if time.time() - last_replicate >= self.__heartbeat:
                            last_replicate = time.time()
                            self._replicate()
        self.__logger.info("%s Distributed node loop stopped", self)

    def _reset_timeout(self):
        """Restart the election timer with base timeout + up to one heartbeat."""
        self.__last_time = time.time()
        self.__election_timeout = self.__base_election_timeout + random.uniform(0, self.__heartbeat)

    @property
    def _timeout_elapsed(self):
        return time.time() - self.__last_time >= self.__election_timeout

    @ft.partial(property, None)
    def _process_title(self, title):
        """Write-only property: log the role change and set the process title."""
        self.__logger.info("%s Become %s", self, title)
        if setproctitle and self.__process_prefix:
            setproctitle.setproctitle("{} {}".format(self.__process_prefix, title))

    def _become(self, role):
        """Switch to ``role`` and run its entry actions.

        Every transition restarts the election timer. Followers and candidates
        release writers blocked in the ``version`` setter (which then raise
        NotLeader) and stop tracking requests; a follower also clears its
        vote. A candidate starts a new term and asks all peers for votes. A
        leader starts tracking requests and replicates immediately.
        """
        self._reset_timeout()
        self.__role = role
        if role == self.FOLLOWER:
            self.__commit_event.set()
            self.__state.voted_for = None
            self.__tracked_packets = None
            self._process_title = "FOLLOWER"
        elif role == self.CANDIDATE:
            self.__commit_event.set()
            self.__votes = {self.__node_id}
            self.__state.term += 1
            self._process_title = "CANDIDATE"
            self.__tracked_packets = None
            # noinspection PyTypeChecker
            self._sendtoall(
                VoteRequest(None, self.__node_id, self.__state.term, *self.__state.value)
            )
            if len(self.__votes) >= self.__quorum:
                # Single-node cluster: our own vote is already a majority.
                self._become(self.LEADER)
        elif role == self.LEADER:
            self._process_title = "LEADER"
            self.__tracked_packets = {node_id: None for node_id in self.__nodes}
            self.__acked = set()
            self._replicate()
        else:
            raise ValueError("Wrong role: {}".format(role))

    def _process_common(self, message):
        """Any non-follower seeing a higher term adopts it and steps down."""
        if self.__role != self.FOLLOWER and message.term > self.__state.term:
            self.__state.term = message.term
            self._become(self.FOLLOWER)

    def __advance_term(self, term):
        """Adopt ``term`` if newer; a vote cast in an older term no longer binds."""
        if term > self.__state.term:
            self.__state.term = term
            self.__state.voted_for = None

    def _process_follower(self, message):
        """Handle a message as follower: answer votes, apply replication."""
        if isinstance(message, VoteRequest):
            granted = False
            if message.term >= self.__state.term:
                self.__advance_term(message.term)
                # At most one vote per term, only for an up-to-date candidate.
                if (
                    self.__state.voted_for in (None, message.node_id) and
                    self.__state.value <= message.last
                ):
                    self._reset_timeout()
                    self.__state.voted_for = message.node_id
                    self.__logger.debug("%s Voting for %s", self, self.__state.voted_for)
                    granted = True
            # noinspection PyTypeChecker
            self._sendto(VoteResponse(message.msg_id, self.__node_id, self.__state.term, granted), message.node_id)
        elif isinstance(message, Request):
            if message.term < self.__state.term:
                # Stale leader (heartbeat or not): reply with our newer term so
                # that it steps down, and do not let it reset our timer.
                # noinspection PyTypeChecker
                self._sendto(Response(message.msg_id, self.__node_id, self.__state.term, False), message.node_id)
                return
            self.__advance_term(message.term)
            if self.__current_leader != message.node_id:
                self.__current_leader = message.node_id
                self.__logger.info("%s Current leader: <Node %s>", self, self.__current_leader)
            self._reset_timeout()
            if not message.version:
                return  # heartbeat: nothing to apply or acknowledge
            if not self.__state.value or self.__state.version <= message.prev_version:
                self.__state.version = message.version
                # noinspection PyTypeChecker
                self._sendto(Response(message.msg_id, self.__node_id, self.__state.term, True), message.node_id)
            elif self.__state.value == (message.term, message.version):
                # Re-sent request we already applied in this term (our earlier
                # Response was lost): acknowledge again, otherwise the leader
                # would keep retrying and never count us.
                # noinspection PyTypeChecker
                self._sendto(Response(message.msg_id, self.__node_id, self.__state.term, True), message.node_id)

    def _process_candidate(self, message):
        """Handle a message as candidate: count votes, yield to a leader."""
        if isinstance(message, VoteResponse):
            # Only grants for the current term count; late grants from an
            # earlier candidacy must not help win this one.
            if message.granted and message.term == self.__state.term:
                self.__votes.add(message.node_id)
                if len(self.__votes) >= self.__quorum:
                    self._become(self.LEADER)
        elif isinstance(message, Request) and message.term >= self.__state.term:
            # A leader exists for our term (higher terms were already handled
            # by _process_common); requests from stale leaders are ignored.
            self._become(self.FOLLOWER)

    def _process_leader(self, message):
        """Record successful acks, commit on quorum, then replicate again."""
        if isinstance(message, Response):
            if message.success:
                self.__acked.add(message.node_id)
                if not self.__commit_event.is_set():
                    if len(self.__acked) >= self.__quorum:
                        self.__commit_event.set()
        self._replicate()

    def _replicate(self):
        """Send each peer the current version, or a heartbeat once it has acked."""
        for node_id in self.__nodes:
            acked = node_id in self.__acked
            message = Request(
                None,
                self.__node_id,
                self.__state.term,
                self.__state.prev_value.term,
                self.__state.prev_value.version,
                0 if acked else self.__state.version
            )
            self._sendto(message, node_id)
