from __future__ import print_function
import collections
import os
import random
import socket
import struct
import subprocess
import sys
import threading
import time
import cPickle as pickle

from .nomoresky import ExecutionTreeNode
from .nomoresky import log
from .nomoresky import TimeoutException
from .nomoresky import HostMessage
from .nomoresky import WriteError


# Number of AgentAsker threads each NoskyTask starts to relay forwarded
# ssh-agent requests; AgentQueue also uses it to decide when to notify.
AGENT_THREADS_COUNT = 20


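# Shuffle slist in place into a random cyclic permutation (Sattolo's
# algorithm): with two or more elements, every element changes position.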
def shuffle(slist):
    """Permute the list *slist* in place using Sattolo's algorithm.

    Each position i is swapped with a random later position, which produces
    a single random cycle: for lists of two or more items no element stays
    where it was.  Returns None.

    :raises ValueError: if *slist* is not a list.
    """
    if not isinstance(slist, list):
        raise ValueError("Shuffle lists inplace")
    size = len(slist)
    for idx in range(size - 1):
        # strictly greater index -> element idx is guaranteed to move
        other = random.randrange(idx + 1, size)
        slist[idx], slist[other] = slist[other], slist[idx]


def resolveHostsBySkyList(clause):
    """Run ``sky list`` with the given arguments and return its output lines.

    :param clause: ``sky list`` arguments -- a list, a whitespace-separated
        string (str or unicode), or any other iterable of arguments.
    :returns: set of the stripped lines printed by ``sky list`` (the exit
        status is not checked).
    """
    if isinstance(clause, list):
        skyListArgs = clause
    elif type(clause) in (str, unicode):
        skyListArgs = clause.split()
    else:
        # last resort: try to make list from clause
        skyListArgs = list(clause)

    # os.devnull opened read-only: the child only needs an empty stdin, and
    # "rw" is not a valid open() mode on Python 3.  The with-statement closes
    # the descriptor, which previously leaked on every call.
    with open(os.devnull, "r") as devnull:
        sp = subprocess.Popen(["sky", "list"] + skyListArgs,
                              stdout=subprocess.PIPE, stdin=devnull,
                              close_fds=True, bufsize=4096)
        try:
            hosts = set([line.strip() for line in sp.stdout])
        finally:
            sp.stdout.close()
            # reap the child so it does not remain as a zombie process
            sp.wait()
    return hosts


# stolen from skynet
def formatHosts(hosts, useGrouping=True, addDomain=None):
    """Render host names as a compact, space-separated string.

    With grouping, each host is split by ``splitHostName`` into
    (prefix, number, domain).  Hosts sharing prefix and domain are grouped:
    runs of four or more consecutive numbers become
    ``prefix{first..last}domain``; the remaining numbers become
    ``prefix{n1,n2,...}domain`` (or ``prefixNdomain`` for a single one).
    Hosts for which ``splitHostName`` returns number -1 (presumably names
    without a numeric part) are appended, unsorted, after the sorted groups.

    :param hosts: iterable of host names (list, set, ...); empty names are
        skipped when grouping.
    :param useGrouping: if false, return the sorted names joined by spaces.
    :param addDomain: when None (default) the '.yandex.ru' domain is omitted
        from the output; with any other value all domains are kept.
    :returns: space-separated description of *hosts*.
    """
    if not useGrouping:
        # sorted() rather than hosts.sort(): also accepts sets/tuples (this
        # file passes sets) and does not reorder the caller's list.
        return ' '.join(sorted(hosts))

    # Project helper; imported lazily because only the grouping path uses it.
    from library.format import splitHostName

    defaultDomains = ['.yandex.ru'] if addDomain is None else []
    results = []

    def domainSuffix(domain):
        # Default domains are dropped to keep the output short.
        return domain if domain not in defaultDomains else ''

    def flushSingles(prefix, domain, singles):
        # Emit the collected non-range numbers as one entry, then reset them.
        if len(singles) > 1:
            results.append('%s{%s}%s' % (
                prefix, ','.join(singles), domainSuffix(domain)))
        else:
            results.append('%s%s%s' % (
                prefix, singles[0], domainSuffix(domain)))
        del singles[:]

    groups = {}
    named_hosts = []

    for host in hosts:
        if not host:
            continue
        prefix, number, domain = splitHostName(host)
        if number == -1:
            named_hosts.append((prefix, domain))
        else:
            # numeric value for ordering/adjacency, original text for output
            groups.setdefault(
                (prefix, domain), []).append((int(number), number))

    for (prefix, domain), suffixes in groups.items():
        suffixes.sort()
        # Sentinel that never continues a run, so the loop below also
        # closes the last real run.
        suffixes.append((99999999999999, 0))
        start = 0
        singles = []
        for i in range(1, len(suffixes)):
            if suffixes[i - 1][0] + 1 == suffixes[i][0]:
                continue
            # suffixes[start:i] is a maximal run of consecutive numbers
            if i - start > 3:
                if singles:
                    flushSingles(prefix, domain, singles)
                results.append('%s{%s..%s}%s' % (
                    prefix, suffixes[start][1], suffixes[i - 1][1],
                    domainSuffix(domain)))
            else:
                singles.extend(suffixes[j][1] for j in range(start, i))
            start = i
        if singles:
            flushSingles(prefix, domain, singles)
    results.sort()
    for (prefix, domain) in named_hosts:
        results.append('%s%s' % (prefix, domainSuffix(domain)))
    return ' '.join(results)


# Results receiver for sdel.py
# Implements simple timeout logic for sdel.py
class ResultDictEater(dict):
    """Collect host results grouped by identical output.

    The dict maps a result key -- the tuple of (data, stream) pairs received
    so far for a host, in arrival order -- to the set of hosts that produced
    exactly that sequence.  In "error"/"ssh error" data the host's own name
    is replaced by "%host%" so that identical errors group together.

    The instance can also be passed as the ``timeout`` object of
    ``NoskyTask.waitForAllDone``: ``wait`` raises ``TimeoutException`` once
    more than ``timeout`` seconds have passed since the last ``up()`` call,
    but only after some host has reported exit code 0.

    :raises ValueError: if *timeout* is neither None, an int nor a float.
    """

    def __init__(self, timeout=None):
        # Checked unconditionally so the validation also runs under -O.
        if timeout is not None and type(timeout) not in (int, float):
            raise ValueError("timeout should be integer or float")

        super(ResultDictEater, self).__init__()
        self.__lock = threading.RLock()
        self.__timeout = timeout
        # time.time()-based deadline set by up(); None until the first up()
        self.__nextResultTime = None
        # host -> key under which the host is currently stored in self
        self.__hostToResult = dict()
        # set once any host reports exit code 0; the timeout applies only
        # after that
        self.__firstPositiveCompletion = False

    def __call__(self, hr):
        """Record one HostResult; None (end-of-task signal) is ignored."""
        if hr is None:
            return
        if __debug__:
            log("receive stream %s for host %s" % (hr.stream, hr.host))
        if hr.stream in ("error", "ssh error"):
            hr.data = hr.data.replace(hr.host, "%host%")
        hrkey = ((hr.data, hr.stream), )
        with self.__lock:
            if hr.stream == 'exitcode' and hr.data == 0:
                self.__firstPositiveCompletion = True
            previous = self.__hostToResult.pop(hr.host, None)
            if previous is not None:
                # move the host from its old group to the extended key
                hrkey = previous + hrkey
                oldHosts = self[previous]
                oldHosts.discard(hr.host)
                if not oldHosts:
                    del self[previous]
            self.setdefault(hrkey, set()).add(hr.host)
            self.__hostToResult[hr.host] = hrkey

    def up(self):
        """Push the timeout deadline to ``timeout`` seconds from now."""
        if self.__timeout is not None:
            self.__nextResultTime = time.time() + self.__timeout

    def wait(self, condlock):
        """Wait on *condlock* for at most 0.5 s or until the deadline.

        :raises TimeoutException: when the deadline has passed after a
            successful completion was seen.
        """
        # 0.5 is for keyboard interrupt
        if self.__nextResultTime is None or not self.__firstPositiveCompletion:
            waittime = 0.5
        else:
            waittime = min(0.5, self.__nextResultTime - time.time())
        if waittime <= 0:
            raise TimeoutException
        if __debug__:
            log("wait for result with timeout")
        condlock.wait(waittime)


# Simple unbuffered receiver of results: it holds one result at a time,
# blocking pushers until a popper has taken the previous one.
# Implements simple timeout logic.
# It is used in cpu-load.py.
class OneByOneEater(object):
    """One-slot hand-off of results between a pushing and a popping thread.

    ``__call__`` blocks until the slot is empty, then stores the result;
    ``pop`` blocks until a result is available and takes it.  Pushing None
    marks the end of results: every later ``pop`` returns None (and further
    pushes block).  With a *timeout*, each "exitcode" result sets a deadline
    *timeout* seconds ahead, ``up()`` moves an existing deadline forward,
    and ``pop`` raises ``TimeoutException`` when the deadline passes while
    it is waiting.
    """

    def __init__(self, timeout=None):
        self.__lock = threading.Condition()
        # False means "slot empty"; None is the end-of-results sentinel
        self.__result = False
        self.__nextResultTime = None
        self.__timeout = timeout

    def __call__(self, hr):
        """Store *hr* (a HostResult, or None for end of results)."""
        with self.__lock:
            while self.__result is not False:
                self.__lock.wait()
            self.__result = hr
            # hr is None for the end-of-results signal and must not be
            # dereferenced; compare the stream by equality, not substring.
            if (self.__timeout is not None and hr is not None
                    and hr.stream == "exitcode"):
                self.__nextResultTime = time.time() + self.__timeout
            self.__lock.notifyAll()

    def up(self):
        """Move an already-set deadline to *timeout* seconds from now."""
        if self.__timeout is not None and self.__nextResultTime is not None:
            self.__nextResultTime = time.time() + self.__timeout

    def pop(self):
        """Return the next result, or None once the end has been pushed.

        :raises TimeoutException: if the deadline passes while waiting.
        """
        with self.__lock:
            while self.__result is False:
                # 0.5 is for keyboard interrupt
                if self.__nextResultTime is None:
                    waittime = 0.5
                else:
                    waittime = min(0.5, self.__nextResultTime - time.time())
                if waittime <= 0:
                    raise TimeoutException
                self.__lock.wait(waittime)
            self.__lock.notifyAll()
            result = self.__result
            if result is None:
                # keep the sentinel so later pops also return None
                return
            self.__result = False
            return result


# Print results of ResultDictEater
# It is used in sdel.py
def printCompactOutput(eater, task):
    """Wait for *task* to finish, then print the results gathered by *eater*.

    Ctrl-C while waiting stops the wait and prints what has been collected
    so far.  Hosts with unknown status and unfinished hosts are reported on
    stderr.  Result groups are printed largest first, each with its host
    list: "exitcode"/"stdout" entries go to stdout, all other streams to
    stderr.

    :param eater: mapping of result key (tuple of (data, stream) pairs) to a
        set of hosts, e.g. a ResultDictEater; also passed to
        ``task.waitForAllDone``.
    :param task: object providing waitForAllDone, getSilentHosts and
        getUnfinishedHosts (e.g. NoskyTask).
    """
    try:
        task.waitForAllDone(eater)
    except KeyboardInterrupt:
        # The module uses print_function, so the print-statement form
        # ("print >> sys.stderr") would raise TypeError here.
        print("user asks for earlier results report", file=sys.stderr)

    hostSet = task.getSilentHosts()
    if hostSet:
        print(file=sys.stderr)
        print("%d hosts with unknown status: %s" % (
            len(hostSet), formatHosts(hostSet)), file=sys.stderr)

    hostSet = task.getUnfinishedHosts()
    if hostSet:
        print(file=sys.stderr)
        print("%d unfinished hosts: " % len(hostSet) + formatHosts(hostSet),
              file=sys.stderr)
    sys.stderr.flush()

    mainStreams = ("exitcode", "stdout")
    for hrkey, hostset in sorted(eater.items(), key=lambda x: len(x[1]),
                                 reverse=True):
        prettyHostList = formatHosts(hostset)
        _printResultGroup(len(hostset), prettyHostList,
                          [entry for entry in hrkey
                           if entry[1] in mainStreams],
                          sys.stdout)
        _printResultGroup(len(hostset), prettyHostList,
                          [entry for entry in hrkey
                           if entry[1] not in mainStreams],
                          sys.stderr)


def _printResultGroup(hostCount, prettyHostList, entries, out):
    """Print one group's (data, stream) *entries* to *out*, then flush it.

    The "<N> hosts: <list>" header is printed only if there are entries.
    """
    if entries:
        print(file=out)
        print(hostCount, "hosts:", prettyHostList, file=out)
        for data, stream in entries:
            print("type of result:", stream, file=out)
            print(str(data).rstrip(), file=out)
    out.flush()


# NoskyTask pushes objects of this class into its eater
class HostResult(object):
    """A single piece of output from one host, as handed to an eater.

    ``stream`` names the kind of output (NoskyTask passes the message type,
    e.g. "stdout" or "exitcode") and ``data`` carries its payload.
    """
    __slots__ = ['host', 'stream', 'data']

    def __init__(self, host, data, stream):
        self.host, self.stream, self.data = host, stream, data

    def __repr__(self):
        # rendered as "<host>: <data>"
        return "%s: %s" % (self.host, self.data)


# Implements parent node of the execution tree
class NoskyTask(ExecutionTreeNode):
    """Parent node of the execution tree for one command run on many hosts.

    Tracks which hosts started, failed to start and finished.  Host output
    goes to ``eater``, which is any callable taking a HostResult, and is
    called with None once the task is done.  Input broadcast to '*' is
    replayed to hosts that are re-added after a link failure.  Forwarded
    ssh-agent requests are served by a pool of AgentAsker threads.
    """

    def __init__(self, hosts=None, eater=None, cmd=None, flags=None):
        self.__startedHosts = set()      # hosts that reported 'started'
        self.__cantStartHosts = set()    # hosts that reported 'not started'
        self.__finishedHosts = set()     # hosts that sent an exit code
        self.__startingLock = threading.Condition()
        self.__totalHosts = set()        # every host ever added
        # (data, stream) pairs broadcast to '*' while some hosts had not yet
        # reported their start state; replayed to re-added hosts
        self.__broadcastList = []
        self.__broadcastLock = threading.RLock()
        self.__resultCallback = None
        self.__eater = eater
        # how many pending hosts are handed along with each forwarded host
        self.__forwardChunkSize = 0
        # hosts of the current addNewHost() batch not yet forwarded
        self.__hostListToForward = set()

        self.__agentQueue = AgentQueue()
        for _ in range(AGENT_THREADS_COUNT):
            # AgentAsker threads start themselves in their constructor
            AgentAsker(self, self.__agentQueue)

        super(NoskyTask, self).__init__(cmd=cmd, flags=flags)

        self.addNewHost(hosts)

    def _forwardToHost(self, newhost):
        """Forward *newhost* together with a chunk of pending batch hosts.

        Without pending hosts, the call goes to the base class unchanged.
        Otherwise up to ``__forwardChunkSize`` hosts from the current batch
        are passed along with *newhost* as one set.  Presumably the base
        class hands that set to one child node -- confirm against
        ExecutionTreeNode.  On WriteError the borrowed hosts are returned
        to the batch and the error is re-raised.
        """
        itemsCount = min(self.__forwardChunkSize,
                         len(self.__hostListToForward))
        if itemsCount <= 0:
            super(NoskyTask, self)._forwardToHost(newhost)
            return

        if __debug__:
            log("itemsCount == %d" % itemsCount)
        nextChunk = set([newhost])
        for _ in range(itemsCount):
            nextChunk.add(self.__hostListToForward.pop())
        try:
            super(NoskyTask, self)._forwardToHost(nextChunk)
        except WriteError:
            # give the borrowed hosts back so they are forwarded later
            self.__hostListToForward |= nextChunk - set([newhost])
            raise

    def _addNewHostWithDoneLock(self, newhost):
        """Add *newhost* unless the task is done; otherwise report a link error."""
        with self._doneLock:
            if not self._doneFlag:
                self._addNewHostUnderDoneLock(newhost)
            else:
                self.linkError(newhost)

    def addNewHost(self, newhost):
        """Add a single host or a set/frozenset/tuple/list of hosts."""
        if type(newhost) in (set, frozenset, tuple, list):
            self.__totalHosts |= set(newhost)
            self.__hostListToForward = set(newhost)
            # ceil(len / CONNECTIONS_PER_NODE); '//' keeps this an integer
            # regardless of the division semantics in effect
            self.__forwardChunkSize =\
                (len(newhost) - 1) // self.CONNECTIONS_PER_NODE + 1
            while len(self.__hostListToForward) > 0:
                nextHost = self.__hostListToForward.pop()
                if __debug__:
                    log("next host " + str(nextHost))
                self._addNewHostWithDoneLock(nextHost)
        else:
            self.__totalHosts.add(newhost)
            self._addNewHostWithDoneLock(newhost)

    def allowExit(self):
        """Tell all connections that they may exit (only the first call acts)."""
        if self._allowedToExit:
            return
        self._pushAllowExitToAllConns()
        self._allowedToExit = True
        if __debug__:
            log("allow exit")

    def __hostStarted(self, host, success=True):
        if success:
            self.__startedHosts.add(host)
        else:
            self.__cantStartHosts.add(host)

    def __treatPushedMessage(self, msg):
        """Dispatch one HostMessage received from the tree."""
        if self.__eater is not None:
            self.__eater.up()

        if msg.msgtype in ("stdout", "stderr", "exitcode",
                           "error", "ssh error"):
            if msg.msgtype == "exitcode":
                self.__finishedHosts.add(msg.host)
            if self.__eater is not None:
                try:
                    self.__eater(HostResult(msg.host, msg.data, msg.msgtype))
                except Exception:
                    # A failing eater must not break message handling; keep
                    # going, but leave a trace in debug runs.
                    if __debug__:
                        log("eater failed on message from %s" % msg.host)
            else:
                if __debug__:
                    log("eater is None")
        else:
            if __debug__:
                log("message received: (host=%s, type=%s, data=%s)" %
                    (msg.host, msg.msgtype, msg.data))

        if msg.msgtype == 'cmd':
            if msg.data == 'started':
                self.__hostStarted(msg.host)
            elif msg.data == 'not started':
                self.__hostStarted(msg.host, False)
        elif msg.msgtype == 'link failed':
            affectedHosts = msg.data
            if type(affectedHosts) not in (list, tuple, set, frozenset):
                affectedHosts = [affectedHosts]
            for affectedHost in affectedHosts:
                if affectedHost not in self.__startedHosts:
                    # re-add the host and replay earlier broadcast input
                    with self.__broadcastLock:
                        self.addNewHost(affectedHost)
                        for data, stream in self.__broadcastList:
                            downmsg = HostMessage(affectedHost, data, stream)
                            self._pushDown(downmsg)
        elif msg.msgtype == 'agent':
            self.__agentQueue.push(msg)

    def pushThroughCache(self, msg):
        self._push(msg)

    def _push(self, msg):
        """Handle a HostMessage or a list of them while the node is active."""
        if not self.isActive:
            return
        if isinstance(msg, list):
            for item in msg:
                self.__treatPushedMessage(item)
        elif isinstance(msg, HostMessage):
            self.__treatPushedMessage(msg)
        if self.__eater is not None:
            self.__eater.up()

    def _hasBeenDone(self):
        # None tells the eater that no more results will come
        if self.__eater is not None:
            self.__eater(None)

    def __pushDownRemembered(self, host, data, stream):
        """Send a message down; remember '*' broadcasts for late hosts.

        A broadcast is stored for replay while some hosts have not reported
        whether they started; once all have, the replay list is dropped.
        """
        msg = HostMessage(host, data, stream)
        if host != '*':
            self._pushDown(msg)
            return
        with self.__broadcastLock:
            if self.__totalHosts - self.__startedHosts -\
                    self.__cantStartHosts:
                self.__broadcastList.append((data, stream))
            else:
                self.__broadcastList = []
            self._pushDown(msg)

    # Send the data to the host
    def pushDown(self, host, data):
        """Send *data* as stdin to *host* ('*' for all hosts)."""
        self.__pushDownRemembered(host, data, 'stdin')

    def pushDownCloseStdin(self, host):
        """Ask *host* ('*' for all hosts) to close its stdin."""
        self.__pushDownRemembered(host, 'close stdin', 'cmd')

    @property
    def eater(self):
        return self.__eater

    @eater.setter
    def eater(self, value):
        self.__eater = value
        self.allowExit()

    def __markDone(self, reason):
        # Caller holds _doneLock.
        self._doneFlag = True
        if __debug__:
            log(reason)
        self._doneLock.notifyAll()
        self._hasBeenDone()

    def waitForAllDone(self, timeout=None):
        """Block until the task is done.

        :param timeout: optional object whose ``wait(condlock)`` waits on
            ``_doneLock`` and raises TimeoutException when waiting should
            stop (e.g. ResultDictEater); None waits until done.
        Any other exception marks the task done and is re-raised.
        """
        if __debug__:
            log("wait for all done")
        self.allowExit()
        with self._doneLock:
            while not self._doneFlag:
                try:
                    if timeout is not None:
                        try:
                            timeout.wait(self._doneLock)
                        except TimeoutException:
                            self.__markDone("done = True, by timeout")
                    else:
                        # bounded wait so the loop wakes up periodically
                        self._doneLock.wait(2)
                except Exception:
                    # (the previous "nofityAll" typo raised AttributeError
                    # here and hid the original exception)
                    self.__markDone("done = True, by exception")
                    raise

    def getSilentHosts(self):
        """Hosts that reported neither 'started' nor 'not started'."""
        return self.__totalHosts - self.__startedHosts - self.__cantStartHosts

    def getUnfinishedHosts(self):
        """Started hosts that have not sent an exit code."""
        return self.__startedHosts - self.__cantStartHosts\
            - self.__finishedHosts


class AgentQueue(object):
    """Blocking FIFO of forwarded ssh-agent request messages."""

    def __init__(self):
        self.__lock = threading.Condition()
        # deque: popping from the left is O(1), unlike list.pop(0)
        self.__queue = collections.deque()

    def push(self, msg):
        """Append *msg*, waking a waiting consumer when useful."""
        with self.__lock:
            # Consumers wait only while the queue is empty, and every push
            # into a queue shorter than AGENT_THREADS_COUNT wakes one of
            # them.  With at most AGENT_THREADS_COUNT consumers (NoskyTask
            # starts exactly that many), they are all awake by the time the
            # queue is that long, so the notify can be skipped.
            if len(self.__queue) < AGENT_THREADS_COUNT:
                self.__lock.notify()
            self.__queue.append(msg)

    def pop(self):
        """Block until a message is available, then return the oldest."""
        with self.__lock:
            while not self.__queue:
                self.__lock.wait()
            return self.__queue.popleft()


class AgentAsker(threading.Thread):
    """Daemon thread that relays forwarded ssh-agent requests.

    Each instance connects to the agent socket named by $SSH_AUTH_SOCK,
    takes request messages from the shared AgentQueue, and pushes the
    agent's replies back down through *receiver*.  The thread starts itself
    in the constructor and stops at the first error.
    """

    def __init__(self, receiver, agentQueue):
        super(AgentAsker, self).__init__()
        self.__queue = agentQueue
        self.__receiver = receiver
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        sock.connect(os.getenv("SSH_AUTH_SOCK"))
        if __debug__:
            log("connected to ssh-agent")
        # explicit binary read/write file over the agent socket (the
        # default mode is read-only text on Python 3)
        self.__sock = sock.makefile("rwb")
        self.daemon = True
        self.start()

    def run(self):
        try:
            while True:
                self.__forwardAgentMsg(self.__queue.pop())
        except Exception:
            # Lost agent connection or a malformed request: this worker
            # stops; the message being processed gets no reply.
            if __debug__:
                log("exception in writeloop")

    def __forwardAgentMsg(self, msg):
        """Send one request to ssh-agent and push the reply back down.

        msg.data is a pickled (port, request_bytes) pair.  The reply is read
        as a 4-byte big-endian length plus that many bytes, and sent to
        msg.host as an 'agent' message holding pickled (port, reply_bytes).

        :raises Exception: if the agent connection returns a short read.
        """
        if __debug__:
            log("let's treat agent forwarded request for %s" %
                msg.host)
        # SECURITY NOTE(review): msg.data comes from a remote node; unpickling
        # it lets a compromised host execute arbitrary code here.  Consider a
        # non-executable encoding (e.g. struct length-prefixed fields).
        (port, data) = pickle.loads(msg.data)
        self.__sock.write(data)
        self.__sock.flush()
        if __debug__:
            log("wrote ssh-agent request")
        reply = self.__sock.read(4)
        if len(reply) < 4:
            raise Exception("connection to the ssh-agent is lost")
        (dataLen, ) = struct.unpack("!I", reply)
        reply += self.__sock.read(dataLen)
        if dataLen + 4 > len(reply):
            raise Exception("connection to the ssh-agent is lost")
        if __debug__:
            log("read ssh-agent reply: %s" % msg.host)
        rdata = pickle.dumps((port, reply), pickle.HIGHEST_PROTOCOL)
        downMsg = HostMessage(msg.host, rdata, 'agent')
        self.__receiver._pushDown(downMsg)
