from __future__ import print_function
import os
import sys
import time
import errno
import fcntl
import signal
import socket
import threading
import contextlib
from itertools import chain

from ..utils import genuuid, fqdn, getaddrinfo, monotime, log as root
from ..transport import protocol, envelope

from ya.skynet.util.pickle import dumps
from ya.skynet.util.functional import singleton
from ya.skynet.util.net.getifaddrs import getIfAddrs, Flags


class Watcher(object):
    """Periodically health-check servers and persist the results to a file.

    Every 30 seconds ``run`` probes each registered server with
    ``check_server`` (and, when ``watch_ips`` is enabled, re-validates the
    host's addresses with ``check_ips``), then writes a pickled status dict to
    ``path`` while holding ``force_filelock`` on that same path.

    :param path: state file path; ``<path>.new`` is used as a temporary file.
    :param servers: initial servers to watch; each must expose ``impl`` and
        ``protocol.listenaddr()``.
    :param watch_ips: keyword argument (default ``True``). When set, the
        interface addresses seen at construction time become the reference,
        and ``check_ips`` SIGINTs the process if they later change.
        NOTE(review): any other keyword arguments are silently ignored.
    """

    def __init__(self, path, *servers, **kwargs):
        watch_ips = kwargs.pop('watch_ips', True)

        self.log = log()
        self.lock = None
        self.path = path
        self.watch_ips = watch_ips
        # Reference snapshot of this host's addresses; later checks compare to it.
        self.original_addrs = get_ips() if watch_ips else None
        # Last successful FQDN resolution, reused by check_ips if DNS fails later.
        self.cached_ipaddr = (
            check_ips(self.log, self.original_addrs, None) if watch_ips else None
        )
        self.servers = []
        for server in servers:
            self.add(server)
        self.running = threading.Event()

    def add(self, server):
        """Register *server* for periodic checks."""
        self.log.debug("%s server on port %s will be watched",
                       server.impl, server.protocol.listenaddr()[1])
        self.servers.append(server)

    def set_lock(self, lock):
        """Set an optional in-process lock held around every state-file access."""
        self.lock = lock

    def run(self):
        """Check all servers every 30 seconds until ``stop`` is called."""
        # Idle until there is something to watch, or until we are stopped.
        while not self.servers and not self.running.is_set():
            self.running.wait(0.5)

        # immediately check if some other server is running
        # NOTE(review): _write_stats renames a fresh file over self.path while
        # the flock is held on the old inode, so this lock only excludes another
        # watcher that happens to lock during the same short window -- confirm
        # that is sufficient.
        with force_filelock(self.log, self.path, self.lock):
            pass

        while not self.running.is_set():
            if self.watch_ips:
                try:
                    self.cached_ipaddr = check_ips(self.log, self.original_addrs, self.cached_ipaddr)
                except socket.gaierror:
                    self.log.exception('self check failed:')

            results = {'status': [check_server(self.log, server) for server in self.servers]}
            results['timestamp'] = monotime()
            results['unix_timestamp'] = time.time()

            with force_filelock(self.log, self.path, self.lock):
                self._write_stats(results)

            self.running.wait(30)

    def _write_stats(self, stats):
        """Replace ``self.path`` with the pickled *stats* via write-then-rename.

        The data goes to ``<path>.new``, is flushed and fsync'ed, and is then
        renamed over ``path``. If any OS-level error occurs while writing, it
        is logged, the temporary file is removed, and the existing state file
        is left untouched.
        """
        tmp_path = self.path + '.new'
        try:
            with open(tmp_path, 'wb') as newf:
                newf.write(dumps(stats))
                newf.flush()
                os.fsync(newf.fileno())
        except EnvironmentError as e:
            if e.errno == errno.ENOSPC:
                self.log.warning("No space left on %r, cannot write state", os.path.dirname(self.path))
            else:
                self.log.exception("Unexpected error on %r, cannot write state",
                                   os.path.dirname(self.path),
                                   exc_info=sys.exc_info())
            # Never rename a partially written file over the last good snapshot.
            try:
                os.unlink(tmp_path)
            except OSError:
                pass  # best effort; the next successful write overwrites it anyway
            return

        os.rename(tmp_path, self.path)
        self.log.debug("wrote stats to %r", self.path)

    def stop(self):
        """Ask ``run`` to exit after its current iteration."""
        self.log.debug("watcher is being stopped")
        self.running.set()


@contextlib.contextmanager
def optional_lock(lock):
    """Hold *lock* for the duration of the ``with`` body, or do nothing if ``None``.

    :param lock: any object usable as a context manager (e.g. a
        ``threading.Lock``), or ``None`` to skip locking entirely.
    """
    if lock is None:
        yield
        return
    with lock:
        yield


@contextlib.contextmanager
def force_filelock(log, path, user_lock):
    """Hold an exclusive, non-blocking ``flock`` on *path* for the ``with`` body.

    :param log: logger used for the "already started" warning.
    :param path: file used as the lock. It is created if missing: via append
        mode normally, or created explicitly before opening read-only on cygwin.
    :param user_lock: optional in-process lock (anything usable in ``with``)
        held around the whole operation; ``None`` to skip.
    :raises RuntimeError: if the ``flock`` call fails, which is treated as
        another watcher holding the lock. SIGINT is sent to the current
        process before raising.
    """
    with optional_lock(user_lock):
        if sys.platform == 'cygwin':
            if not os.path.exists(path):
                # create file first
                open(path, 'w').close()
            mode = 'rb'
        else:
            mode = 'ab'

        with open(path, mode) as f:
            try:
                fcntl.flock(f.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
            except EnvironmentError:
                log.warning("Another server with this watcher is already started. Dying.")
                os.kill(os.getpid(), signal.SIGINT)
                # Raise explicitly: a @contextmanager generator that returns
                # without yielding surfaces as an opaque "generator didn't
                # yield" RuntimeError.
                raise RuntimeError("cannot lock %r: another watcher is running" % (path,))
            try:
                yield
            finally:
                fcntl.flock(f.fileno(), fcntl.LOCK_UN)


def check_server(log, server):
    """Probe *server* by routing a test message through it and back to us.

    A temporary ``protocol.Protocol`` with the same ``impl`` is created. A
    ``'test'`` message with a fresh uuid is routed along a two-node tree: the
    watched server on ``localhost:<port>``, then the temporary protocol's own
    listen address. The reply is awaited for up to 5 seconds. (The leading
    0/1 in the tree nodes are presumably hop indices; confirm against
    ``envelope.Envelope``.)

    :param log: logger for diagnostics; a ``proto`` child is given to the probe.
    :param server: object exposing ``impl`` and ``protocol.listenaddr()``.
    :return: ``dict(impl=..., port=..., running=True)`` on success, or
        ``dict(impl=..., port=..., running=False, error=str(exc))`` if any
        exception occurs during the probe.
    """
    port = server.protocol.listenaddr()[1]
    impl = server.impl

    log.debug("checking %s on port %s", impl, port)
    proto = None
    try:
        proto = protocol.Protocol(impl=impl, log=log.getChild('proto'))
        tree = [(
            (0, ('localhost', port)),
            [((1, proto.listenaddr()), None)]
        )]
        msg = {
            'type': 'test',
            'uuid': genuuid(),
        }
        for addr, next_envelope in envelope.Envelope(msg, tree).next():
            proto.route(next_envelope, addr)

        cmd, reply, _iface = proto.receive(timeout=5.0)
        # Explicit checks rather than `assert`: asserts are stripped under
        # `python -O`, which would make this health check always succeed.
        if cmd != 'route':
            raise RuntimeError("unexpected reply command %r" % (cmd,))
        if reply.msgs[0]['uuid'] != msg['uuid']:
            raise RuntimeError("reply uuid %r does not match sent uuid %r"
                               % (reply.msgs[0]['uuid'], msg['uuid']))
    except Exception as e:
        log.error("%s server on port %s is not running: %s %s", impl, port, type(e), e)
        return dict(impl=impl, port=port, running=False, error=str(e))
    finally:
        if proto is not None:
            proto.shutdown()

    log.debug("%s server on port %s is OK", impl, port)
    return dict(impl=impl, port=port, running=True)


def _is_link_local(iface):
    """Tell whether *iface*'s address is link-local (or not IPv4/IPv6 at all).

    IPv4: ``169.254.0.0/16`` prefix. IPv6: first 16-bit group within
    ``fe80::/10``. Any other address family is reported as link-local, so the
    caller drops it. An IPv6 address starting with ``::`` yields a falsy
    empty string.
    """
    family = iface.family
    if family == socket.AF_INET6:
        head = iface.addr.partition(':')[0]
        return head and (int(head, 16) & 0xffc0) == 0xfe80
    if family == socket.AF_INET:
        return iface.addr.startswith('169.254.')
    return True


def _filter_iface(iface):
    """Return truthy if *iface* carries an address that should be watched.

    Rejected, in this order: entries without an address; interfaces whose name
    starts with ``docker`` or ``br-``; link-local or non-IP addresses (per
    ``_is_link_local``); loopback or NOARP interfaces; and interfaces that are
    not both UP and LOWER_UP.
    """
    if not iface.addr:
        return False
    if iface.name and iface.name.startswith(('docker', 'br-')):
        return False
    if _is_link_local(iface):
        return False
    flags = iface.flags
    return (
        all(flags & bit == 0 for bit in (Flags.IFF_LOOPBACK, Flags.IFF_NOARP))
        and all(flags & bit == bit for bit in (Flags.IFF_UP, Flags.IFF_LOWER_UP))
    )


def get_ips():
    """Return the set of addresses on this host that pass ``_filter_iface``.

    ``getIfAddrs()`` is treated as a mapping whose values are iterables of
    interface-address records (each with ``addr``, ``name``, ``family`` and
    ``flags``). ``.values()`` is used instead of py2-only ``.itervalues()``
    so this also runs on Python 3.
    """
    records = chain.from_iterable(getIfAddrs().values())
    return {iface.addr for iface in records if _filter_iface(iface)}


def check_ips(log, reference, cached_ipaddr):
    """Verify the host's network setup; SIGINT this process if it has drifted.

    Two checks are made:

    1. the current ``get_ips()`` set must equal *reference*;
    2. every IPv4/IPv6 address this host's FQDN resolves to must be in
       *reference*.

    On any failure, errors are logged (and misconfigured IPs are also printed
    to stderr), and SIGINT is sent to the current process at most once per
    call.

    :param log: logger for diagnostics.
    :param reference: address set captured earlier by ``get_ips()``.
    :param cached_ipaddr: return value of a previous call, used when resolving
        the FQDN fails; ``None`` if there is none.
    :return: list of ``getaddrinfo`` result tuples (AF_INET/AF_INET6 only), or
        *cached_ipaddr* if resolution failed.
    :raises: whatever ``getaddrinfo`` raised, if resolution fails and
        *cached_ipaddr* is ``None``.
    """
    killed = False

    current_ips = get_ips()
    if reference != current_ips:
        log.debug("reference ips: %s, current ips: %s", reference, current_ips)
        log.error("network configuration has been changed, restart required, killing self right now!")
        os.kill(os.getpid(), signal.SIGINT)
        killed = True

    hostname = fqdn()
    try:
        # A list, not filter(): on Python 3 filter() is a one-shot iterator
        # that the loop below would exhaust, leaving an empty cached value.
        ips = [
            info for info in getaddrinfo(hostname, 0, socket.AF_UNSPEC, socket.SOCK_STREAM)
            if info[0] in (socket.AF_INET, socket.AF_INET6)
        ]
    except Exception:
        if cached_ipaddr is None:
            raise
        ips = cached_ipaddr

    message = \
        "this host (%r) ip %r isn't configured on this machine, " \
        "host is misconfigured, will die immediately"
    for addr in ips:
        ip = addr[4][0]
        if ip not in reference:
            log.error(message, hostname, ip)
            print(message % (hostname, ip), file=sys.stderr)
            # Signal once: repeated SIGINTs could interrupt shutdown handling.
            if not killed:
                os.kill(os.getpid(), signal.SIGINT)
                killed = True

    return ips


@singleton
def log():
    """Return the ``watcher`` child of the package root logger.

    Presumably created only once thanks to ``@singleton``; confirm against
    ``ya.skynet.util.functional.singleton``.
    """
    parent = root()
    return parent.getChild('watcher')
