import yt.wrapper as yt

import logging
import socket
import time
import os

from copy import deepcopy
from threading import Thread, RLock

# Initial delay (seconds) between polls; also used as the fixed poll interval
# in the places of this file that sleep without growing the delay.
BACKOFF_SLEEP_SECONDS = 1.0
# Upper bound (seconds) the growing delay is clamped to.
BACKOFF_SLEEP_SECONDS_MAX = 10.0
# Factor the delay is multiplied by after every unsuccessful poll.
BACKOFF_SLEEP_SECONDS_MULTIPLIER = 1.2

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


class _LockPathThread(Thread):
    """Daemon thread that takes a shared lock on a path and then idles forever.

    The lock is requested inside a transaction opened by the thread's own
    client; ``run`` never leaves that ``with`` block on its own, it only
    sleeps in a loop after the lock call has returned.
    """

    def __init__(self, yt_client, host_path):
        """Prepare the thread.

        :param yt_client: client whose ``config`` is deep-copied to build a
            separate ``yt.YtClient`` used only by this thread.
        :param host_path: path to lock in "shared" mode.
        """
        super(_LockPathThread, self).__init__()
        self.daemon = True

        # Guards the flag that is written by run() and read by other threads.
        self._lock = RLock()
        self._is_path_locked = False
        self._host_path = host_path
        self._yt_client = yt.YtClient(config=deepcopy(yt_client.config))

    def is_path_locked(self):
        """Return True once ``run`` has got past the lock call."""
        with self._lock:
            locked = self._is_path_locked
        return locked

    def _mark_locked(self):
        """Flip the "path is locked" flag under the guard lock."""
        with self._lock:
            self._is_path_locked = True

    def run(self):
        """Open a transaction, lock the path, publish the flag and idle."""
        client = self._yt_client
        with client.Transaction():
            client.lock(self._host_path, "shared")
            self._mark_locked()
            # Stay inside the transaction block indefinitely.
            while True:
                time.sleep(BACKOFF_SLEEP_SECONDS)


class Synchronizer:
    """Rendezvous helper for a fixed number of processes sharing a Cypress directory.

    Every participant creates its own node under ``path``, writes its SSH
    coordinates and role as attributes of that node and takes a shared lock
    on it.  A participant counts as "active" while its node has a non-empty
    ``locks`` attribute and all of the registration attributes set.

    Attributes set outside ``__init__``:
        active_hosts: list of dicts (master first) filled by
            ``_wait_full_house``.
        lock_thread: the ``_LockPathThread`` started by ``register_master``.
    """

    # Attributes (besides "locks") a node must carry to count as an active host.
    _HOST_ATTRIBUTES = ("sshd_hostname", "sshd_user", "sshd_port", "role")

    def __init__(self, yt_client, path, sshd_port, participants_count):
        """Remember the parameters and derive this participant's node paths.

        :param yt_client: client object used for all Cypress requests.
        :param path: directory under which participant nodes live.
        :param sshd_port: port of the local sshd; also part of the node name.
        :param participants_count: number of participants to wait for.
        :raises RuntimeError: if the ``YT_JOB_INDEX`` environment variable
            is not set (it is part of the node name).
        """
        self.yt_client = yt_client
        self.path = path
        self.sshd_port = sshd_port
        self.participants_count = participants_count

        self.hostname = socket.gethostname()

        # Fail with an explicit message instead of a TypeError from
        # concatenating None into the node name.
        job_index = os.environ.get('YT_JOB_INDEX')
        if job_index is None:
            raise RuntimeError("Environment variable YT_JOB_INDEX is not set; "
                               "it is required to build the synchronizer node name")

        node_name = "-".join([self.hostname, job_index, str(self.sshd_port)])
        self.host_path = yt.ypath_join(self.path, node_name)
        self.completed_path = yt.ypath_join(self.path, "completed")

        logger.info("Synchronizer for %s created", self.hostname)

    @staticmethod
    def _backoff_delays():
        """Yield sleep intervals forever: start at BACKOFF_SLEEP_SECONDS, grow
        by BACKOFF_SLEEP_SECONDS_MULTIPLIER, clamp at BACKOFF_SLEEP_SECONDS_MAX."""
        delay = BACKOFF_SLEEP_SECONDS
        while True:
            yield delay
            delay = min(delay * BACKOFF_SLEEP_SECONDS_MULTIPLIER, BACKOFF_SLEEP_SECONDS_MAX)

    def _list_nodes(self, attributes):
        """List children of ``self.path`` together with the given attributes."""
        # NOTE(review): transaction id "0-0-0-0" presumably makes the request run
        # outside any enclosing transaction -- confirm against the YT docs.
        with self.yt_client.Transaction(transaction_id="0-0-0-0"):
            return self.yt_client.list(self.path, attributes=attributes)

    def _register(self, sshd_user, role):
        """Create this participant's node and write its registration attributes."""
        self.yt_client.create("map_node", self.host_path, recursive=True)
        self.write_attribute("sshd_hostname", self.hostname)
        self.write_attribute("sshd_user", sshd_user)
        self.write_attribute("sshd_port", self.sshd_port)
        self.write_attribute("role", role)

    def get_nodes_attr_val_dict(self, attr):
        """Return ``{node key: value}`` for every child node whose ``attr`` is truthy.

        Nodes where the attribute is missing or falsy are left out.
        """
        nodes_attr_dict = {}
        for node in self._list_nodes(["key", attr]):
            val = node.attributes.get(attr)
            if val:
                nodes_attr_dict[node.attributes["key"]] = val
        return nodes_attr_dict

    def _get_active_hosts(self):
        """Return child nodes that have a truthy ``locks`` attribute and all
        registration attributes set to truthy values."""
        required = ["locks"] + list(self._HOST_ATTRIBUTES)
        return [node for node in self._list_nodes(required)
                if all(node.attributes.get(name) for name in required)]

    def _wait_full_house(self):
        """Block until exactly ``participants_count`` hosts are active, then
        store their descriptions in ``self.active_hosts`` (master first)."""
        logger.info("Waiting all participants to register in cypress")

        for delay in self._backoff_delays():
            active_hosts = self._get_active_hosts()
            if len(active_hosts) == self.participants_count:
                break
            time.sleep(delay)

        ordered_hosts = []
        for host in active_hosts:
            host_kv = {'hostname': str(host)}
            for k in self._HOST_ATTRIBUTES:
                host_kv[k] = host.attributes[k]
            # Hosts with the master role go to the front of the list.
            if host_kv['role'] == 'master':
                ordered_hosts.insert(0, host_kv)
            else:
                ordered_hosts.append(host_kv)

        self.active_hosts = ordered_hosts
        logger.info("All participants have registered in cypress")

    def wait_attribute_set(self, attr):
        """Block until exactly ``participants_count`` child nodes have a truthy ``attr``."""
        logger.info("Waiting attribute '%s' at all nodes", attr)

        for delay in self._backoff_delays():
            nodes_with_attr = [node for node in self._list_nodes([attr])
                               if node.attributes.get(attr)]
            if len(nodes_with_attr) == self.participants_count:
                break
            time.sleep(delay)

        logger.info("Attribute '%s' set at all nodes", attr)

    def write_attribute(self, attr, val):
        """Set attribute ``attr`` of this participant's node to ``val``."""
        path = self.host_path + "/@" + attr
        logger.info("Write attribute (path: %s, val: %s)", path, val)
        with self.yt_client.Transaction(transaction_id="0-0-0-0"):
            self.yt_client.set(path, val)

    def set_attribute(self, attr):
        """Set attribute ``attr`` of this participant's node to ``True``."""
        self.write_attribute(attr, True)

    def set_completed(self):
        """Create the completion marker node that ``register_slave`` polls for."""
        logger.info("Set completed")
        self.yt_client.create("map_node", self.completed_path)

    def register_master(self, sshd_user):
        """ Register myself, start master transaction and return control

        The lock is held by a background ``_LockPathThread``; this method
        returns once the lock is taken and all participants are active.

        :raises Exception: if the lock thread dies before reporting the lock.
        """
        self._register(sshd_user, "master")

        self.lock_thread = _LockPathThread(self.yt_client, self.host_path)
        self.lock_thread.start()
        while not self.lock_thread.is_path_locked():
            if not self.lock_thread.is_alive():
                raise Exception("Lock thread failed")
            time.sleep(BACKOFF_SLEEP_SECONDS)

        self._wait_full_house()

        logger.info("Master process registered")

    def register_slave(self, sshd_user, post_action):
        """ Register myself and wait master completion or abort

        :param sshd_user: user name published in the ``sshd_user`` attribute.
        :param post_action: callable invoked with ``self.active_hosts`` once
            all participants are active.
        :raises Exception: if the number of active hosts stops matching
            ``participants_count`` before the completion path appears.
        """
        self._register(sshd_user, "slave")

        with self.yt_client.Transaction():
            self.yt_client.lock(self.host_path, "shared")
            self._wait_full_house()
            logger.info("Slave process registered")

            post_action(self.active_hosts)

            for delay in self._backoff_delays():
                logger.debug("Check completion path")

                # Snapshot the hosts BEFORE looking at the completion path, and
                # let completion win: a participant that leaves after creating
                # the marker must not be reported as a failure.
                active_hosts = self._get_active_hosts()
                if self.yt_client.exists(self.completed_path):
                    logger.info("Completion path has appeared")
                    return
                if len(active_hosts) != self.participants_count:
                    raise Exception("Some participants have left the house (active: {}, expected: {})"
                                    .format(len(active_hosts), self.participants_count))

                time.sleep(delay)
