import os
import six
import sys
import time
import select
import atexit
import traceback
import threading
import functools
import weakref
from itertools import chain

from ya.skynet.util.functional import singleton
from ya.skynet.util.sys.user import getUserName, UserPrivileges
try:
    from netlibus import monotime
except ImportError:
    from ya.skynet.util.sys.gettime import monoTime as monotime
from ya.skynet.util.net.pipe import Pipe
try:
    from ya.skynet.library.config import detect_hostname
except ImportError:
    import socket
    detect_hostname = socket.getfqdn


def run_daemon(target, *args, **kwargs):
    """
    Create a daemon thread running ``target(*args, **kwargs)``.

    Two reserved keyword arguments are consumed here and never reach
    ``target``:

    * ``__name``  -- thread name; defaults to ``'DaemonThread <target>'``.
    * ``__start`` -- when true (the default) the thread is started before
      being returned.

    Returns the ``threading.Thread`` object.
    """
    thread_name = kwargs.pop('__name', None)
    autostart = kwargs.pop('__start', True)

    if thread_name is None:
        thread_name = 'DaemonThread {}'.format(target)

    worker = threading.Thread(target=target, name=thread_name, args=args, kwargs=kwargs)
    worker.daemon = True

    if autostart:
        worker.start()
    return worker


# Weak registry of objects exposing join_all(); entries vanish on their own
# once the object is garbage collected.
_threaded_objects = weakref.WeakSet()


def _join_all():
    """Call ``join_all()`` on every registered object still alive, then unregister it."""
    # Iterate over a snapshot: the set is modified inside the loop.
    alive = tuple(_threaded_objects)
    for threaded in alive:
        threaded.join_all()
        _threaded_objects.discard(threaded)


# Run the join pass when the interpreter exits.
atexit.register(_join_all)


class Threaded(object):
    """
    Base class that keeps track of the threads it spawns.

    Every instance registers itself in the module-level ``_threaded_objects``
    weak set, so that its ``join_all()`` is invoked at interpreter exit.
    If the instance defines a ``shutdown()`` method, ``join_all()`` calls it
    before joining the threads.
    """

    def __init__(self):
        self.__threads = set()
        _threaded_objects.add(self)

    def spawn(self, target, *args, **kwargs):
        """
        Start ``target(*args, **kwargs)`` in a new tracked thread.

        The reserved keyword ``_daemon`` (default ``False``) is consumed here;
        when true the thread is marked as a daemon thread. The thread stops
        being tracked as soon as ``target`` returns or raises.

        Returns the started ``threading.Thread``.
        """
        daemon = kwargs.pop('_daemon', False)

        # Callables such as functools.partial objects have no __name__.
        target_name = getattr(target, '__name__', None) or repr(target)
        name = '{} {}.{}-{}'.format(
            'DaemonThread' if daemon else 'Thread',
            self.__module__,
            self.__class__.__name__,
            target_name,
        )

        # Built directly; for daemons this is equivalent to
        # run_daemon(..., __start=False) with the same name.
        t = threading.Thread(
            target=self.__thread_wrapper,
            name=name,
            args=(target,) + args,
            kwargs=kwargs
        )
        if daemon:
            # Non-daemon threads keep the flag inherited from their creator.
            t.daemon = True

        self.__threads.add(t)
        try:
            t.start()
        except BaseException:
            # A thread that never started must not be joined later.
            self.__threads.discard(t)
            raise
        return t

    def __thread_wrapper(self, target, *args, **kwargs):
        """Run ``target`` and untrack the current thread afterwards, even on error."""
        try:
            target(*args, **kwargs)
        finally:
            thr = threading.current_thread()
            self.__threads.discard(thr)

    def join_all(self):
        """
        Call ``self.shutdown()`` if it exists, then wait for all tracked threads.

        The calling thread itself is skipped, because joining the current
        thread raises ``RuntimeError``.
        """
        if hasattr(self, 'shutdown'):
            self.shutdown()

        current = threading.current_thread()
        # Snapshot: finishing threads remove themselves from the set.
        for thr in list(self.__threads):
            if thr is current:
                continue
            thr.join()


class FdHolder(object):
    """
    Data descriptor that owns a raw file descriptor stored on the instance.

    The value is kept in the instance attribute ``'__' + name``. Replacing or
    deleting the value closes the previously held descriptor with
    ``os.close``; reading an unset value yields ``None``.
    """
    __slots__ = ('name',)

    def __init__(self, name):
        self.name = '__' + name

    def __get__(self, obj, objtype=None):
        # Accessed on the class: return the descriptor itself.
        if obj is None:
            return self

        return getattr(obj, self.name, None)

    def __set__(self, obj, val):
        oldfd = getattr(obj, self.name, None)
        setattr(obj, self.name, val)

        # Re-assigning the descriptor already held must not close it,
        # otherwise the holder would end up storing a dead fd.
        if oldfd is not None and oldfd != val:
            os.close(oldfd)

    def __delete__(self, obj):
        oldfd = getattr(obj, self.name, None)
        # The slot is reset to None rather than removed.
        setattr(obj, self.name, None)

        if oldfd is not None:
            os.close(oldfd)


class LRUCache(object):
    """
    Size-bounded mapping with least-recently-used eviction.

    Entries live in a dict (key -> Link) and in a circular doubly linked list
    whose sentinel is ``self.__root``: ``root.next`` is the least recently
    used entry, ``root.prev`` the most recently used one. Every public
    operation runs under a re-entrant lock.
    """
    PREV, NEXT, KEY, VALUE = 0, 1, 2, 3   # names for the link fields

    class Link(object):
        """Node of the circular list: neighbours plus the stored key/value."""
        __slots__ = ['prev', 'next', 'key', 'value']

        def __init__(self, prev=None, next=None, key=None, value=None):
            self.prev = prev
            self.next = next
            self.key = key
            self.value = value

        def __iter__(self):
            # Allows unpacking a link as (prev, next, key, value).
            yield self.prev
            yield self.next
            yield self.key
            yield self.value

        def __repr__(self):
            # Neighbours are shown by key only; the value is never printed.
            return 'Link(prev={}, next={}, key={!r}, value="...")'.format(
                "<Key {!r}>".format(self.prev.key) if self.prev is not self else "Self",
                "<Key {!r}>".format(self.next.key) if self.next is not self else "Self",
                self.key,
            )

    def __init__(self, size=100):
        """Create an empty cache; ``size`` is the entry count at which eviction starts."""
        self.__cache = {}
        self.__size = size
        self.__lock = threading.RLock()
        self.__root = None
        self.__clear()

    def __clear(self):
        """Drop all entries, unlink every node and install a fresh sentinel."""
        with self.__lock:
            self.__cache.clear()
            old = self.__root
            # Walk the old ring and null out every reference so that the
            # nodes no longer point at each other or at cached objects.
            while old is not None:
                old.prev = None
                old.key = None
                old.value = None
                next = old.next
                old.next = None
                # A lone sentinel points at itself: stop instead of looping.
                old = next if next is not old else None
            self.__root = LRUCache.Link()
            self.__root.prev = self.__root.next = self.__root
            self.__full = False

    def __update_cache(self, link):
        """Move ``link`` to the most-recently-used position; no-op for None."""
        if link is not None:
            l_prev, l_next, _key, _value = link
            # Unlink from the current position ...
            l_prev.next = l_next
            l_next.prev = l_prev
            # ... and re-insert right before the sentinel.
            last = self.__root.prev
            last.next = self.__root.prev = link
            link.prev = last
            link.next = self.__root

    def __getitem__(self, key):
        """Return the value for ``key`` and mark it as recently used; KeyError if absent."""
        with self.__lock:
            link = self.__cache[key]
            self.__update_cache(link)
            return link.value

    def get(self, key, default=None):
        """Like ``__getitem__`` but return ``default`` when ``key`` is absent."""
        with self.__lock:
            link = self.__cache.get(key)
            self.__update_cache(link)
            return link.value if link is not None else default

    def __delitem__(self, key):
        """Remove ``key`` if present; a missing key is silently ignored."""
        with self.__lock:
            link = self.__cache.pop(key, None)
            # we don't raise KeyError as usual __delitem__ does,
            # because of uncertain cache contents
            if link is None:
                return

            # (temporary store old item to prevent premature GC collection)
            oldkey, link.key = link.key, None  # noqa
            oldval, link.value = link.value, None  # noqa
            l_prev = link.prev
            l_next = link.next
            l_prev.next = l_next
            l_next.prev = l_prev
            link.prev = link.next = None

            self.__full = (len(self.__cache) >= self.__size)

    def __setitem__(self, key, value):
        """Store ``value`` under ``key``, evicting the least recently used entry when full."""
        with self.__lock:
            self.__update_cache(self.__cache.get(key))
            if key in self.__cache:
                # Existing key: already moved to the MRU position above.
                self.__cache[key].value = value
            elif self.__full:
                # Reuse the sentinel as the node for the new entry and turn
                # the oldest node (root.next) into the new sentinel, which
                # evicts that node's key without allocating a new link.
                oldroot = self.__root
                oldroot.key = key
                oldroot.value = value

                self.__root = oldroot.next
                oldkey = self.__root.key
                # Keep the evicted value referenced until the bookkeeping is done.
                oldvalue = self.__root.value  # noqa
                self.__root.key = self.__root.value = None

                del self.__cache[oldkey]
                self.__cache[key] = oldroot
            else:
                # Room left: append a new node right before the sentinel.
                last = self.__root.prev
                link = LRUCache.Link(last, self.__root, key, value)
                last.next = self.__root.prev = self.__cache[key] = link
                self.__full = (len(self.__cache) >= self.__size)

    def __contains__(self, key):
        """Membership test; does not change the recency order."""
        with self.__lock:
            return key in self.__cache

    def setdefault(self, key, value=None):
        """Return the cached value for ``key``, storing ``value`` first if it is absent."""
        with self.__lock:
            self.__update_cache(self.__cache.get(key))
            if key in self.__cache:
                return self.__cache[key].value

            self[key] = value
            return value

    def __del__(self):
        # Unlink all nodes so the ring does not keep itself alive.
        self.__clear()


def auto_restart(fn):
    """
    Decorator: call ``fn`` again and again until a call completes without raising.

    ``Exception`` subclasses are logged with their traceback; any other
    ``BaseException`` is logged at info level as a plain message. In both
    cases ``fn`` is invoked again with the same arguments. The result of the
    first successful call is returned.
    """
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        while True:
            # noinspection PyBroadException
            try:
                outcome = fn(*args, **kwargs)
            except Exception:
                log().exception("Exception occurred in thread {}".format(fn), exc_info=sys.exc_info())
            except BaseException as e:
                log().info(str(e))
            else:
                return outcome

    return wrapper


def rotate_list(lst, num):
    """
    Rotate ``lst`` left by ``num`` positions.

    A falsy ``num`` returns ``lst`` itself (no copy); otherwise a new
    sequence ``lst[num:] + lst[:num]`` is returned.
    """
    return lst[num:] + lst[:num] if num else lst


def genuuid():
    """Return a freshly generated random (version 4) UUID as a string."""
    from uuid import uuid4

    fresh = uuid4()
    return str(fresh)


@singleton
def configure_log(suffix, level=None, debug=False):
    """
    Configure the cqudp logger by delegating to ``reconfigure_log``.

    Wrapped in ``@singleton`` -- presumably so that the configuration runs
    only once per process; confirm against the decorator's implementation.
    """
    options = dict(level=level, debug=debug)
    reconfigure_log(suffix, **options)


def reconfigure_log(suffix, level=None, debug=False, logger=None, rename_levels=True):
    """
    (Re)apply logging settings to ``logger`` (the cqudp logger by default).

    :param suffix: inserted into the log file name ``cqudp-<suffix>.log``
    :param level: level to apply; ``logging.INFO`` when ``None``
    :param debug: when true, also make sure a stream handler is attached to
                  the logger and to the external debug logger
    :param logger: logger to configure; ``log()`` when falsy
    :param rename_levels: passes ``levelChars=1`` to ``logging.initialize``
                          when true, ``None`` otherwise

    An already attached handler only gets its level updated; a missing one is
    created and installed through ``logging.initialize``.
    """
    from ya.skynet.util import logging

    def first_handler(handlers, accepts):
        # First handler satisfying ``accepts``, or None when there is none.
        return next((h for h in handlers if accepts(h)), None)

    def is_stream_handler(handler):
        return isinstance(handler, logging.StreamHandler)

    level_chars = 1 if rename_levels else None
    if level is None:
        level = logging.INFO

    external_debug_logger = debug_log()

    if not logger:
        logger = log()

    # File handler: matched by class name, so api.logger need not be importable.
    fh = first_handler(logger.handlers, lambda h: type(h).__name__ == 'SkynetLoggingHandler')
    if fh:
        fh.setLevel(level)
    else:
        filename = 'cqudp-%s.log' % (suffix,)
        try:
            from api.logger import SkynetLoggingHandler
            fh = SkynetLoggingHandler(app='cqudp', filename=filename)
            logging.initialize(logger=logger, level=level, handler=fh, formatter=None, levelChars=level_chars)
        except ImportError:
            # No skynet logging API available: fall back to a plain file handler.
            fh = logging.FileHandler(filename)
            fh.setLevel(level)
            logging.initialize(logger=logger, level=level, handler=fh, levelChars=level_chars)

    if not debug:
        return

    # NOTE(review): if this logging module mirrors the stdlib hierarchy,
    # FileHandler subclasses StreamHandler and the fallback file handler
    # above would be picked up here instead of a console handler -- confirm.
    ch = first_handler(logger.handlers, is_stream_handler)
    if ch:
        ch.setLevel(level)
    else:
        ch = logging.StreamHandler()
        ch.setLevel(level)
        logging.initialize(logger=logger, level=level, handler=ch, levelChars=level_chars)

    dch = first_handler(external_debug_logger.handlers, is_stream_handler)
    if dch:
        dch.setLevel(level)
    else:
        # The debug logger shares the stream handler used by ``logger``.
        logging.initialize(logger=external_debug_logger, level=level, handler=ch, levelChars=level_chars)


@singleton
def debug_log():
    """
    Return the ``'ya.skynet.debug'`` logger.

    A ``NullHandler`` is attached when the logger has no handlers yet.
    """
    from ya.skynet.util import logging

    debug_logger = logging.getLogger('ya.skynet.debug')
    if debug_logger.handlers:
        return debug_logger

    debug_logger.addHandler(logging.NullHandler())
    return debug_logger


@singleton
def log():
    """Return the ``'cqudp'`` logger with a ``NullHandler`` attached to it."""
    from ya.skynet.util import logging

    cqudp_logger = logging.getLogger('cqudp')
    null_handler = logging.NullHandler()
    cqudp_logger.addHandler(null_handler)
    return cqudp_logger


@singleton
def fqdn():
    """Return the host name reported by ``detect_hostname()``."""
    hostname = detect_hostname()
    return hostname


def getaddrinfo(*args, **kwargs):
    """
    ``socket.getaddrinfo`` that retries on temporary resolver failures.

    All arguments are forwarded to ``socket.getaddrinfo`` except the keyword
    ``retries`` (default 10): the maximum number of attempts made while the
    resolver keeps failing with ``EAI_AGAIN``. At least one attempt is always
    made, so ``retries=0`` no longer returns ``None`` without resolving.

    :raises socket.gaierror: for any error other than ``EAI_AGAIN``, or for
                             ``EAI_AGAIN`` on the last attempt
    """
    import socket

    attempts_left = max(1, kwargs.pop('retries', 10))
    while True:
        attempts_left -= 1
        try:
            return socket.getaddrinfo(*args, **kwargs)
        except socket.gaierror as e:
            # Only a temporary failure is worth another attempt.
            if e.errno != socket.EAI_AGAIN or attempts_left <= 0:
                raise


def short(uuid):
    """Return the part of ``str(uuid)`` that precedes the first ``'-'``."""
    head, _sep, _rest = str(uuid).partition('-')
    return head


def gencid():
    """Generate a client id, used to distinguish different clients in the log."""
    cid_length = 3
    return genuuid()[:cid_length]


def as_user(_user, _fn, *args, **kwargs):
    """
    Call ``_fn(*args, **kwargs)`` on behalf of ``_user`` and return its result.

    When ``_user`` is ``None`` or equals the current user name the function
    is called directly; otherwise the call is made inside
    ``UserPrivileges(_user, modifyGreenlet=False)``.
    """
    if _user is None or _user == getUserName():
        return _fn(*args, **kwargs)

    with UserPrivileges(_user, modifyGreenlet=False):
        return _fn(*args, **kwargs)


def in_thread():
    """
    Tell whether the caller runs in a thread started through ``threading``.

    The check looks at the file name of the outermost stack frame instead of
    asking the (possibly gevent-patched) threading machinery.
    """
    # dirty hack since we can be monkey-patched and gevent will lie to us
    # so we just check the filename of the top frame
    while True:
        try:
            outermost_frame = traceback.extract_stack()[0]
            frame_file = outermost_frame[0]
            return 'threading.py' in frame_file
        except KeyError:
            # Retry until the stack can be extracted.
            continue


def sleep(seconds):
    """
    Sleep for at least ``seconds``, resuming after early wake-ups.

    Why is it needed? In short: time.sleep doesn't guarantee that it won't
    be interrupted (e.g. procman api constantly wakes every thread and
    interrupts time.sleep), so the remaining time is slept again until the
    monotonic deadline has passed.
    """
    if seconds == 0:
        # Plain yield; nothing to wait for.
        return time.sleep(0)

    deadline = monotime() + seconds
    remaining = seconds
    while remaining > 0:
        time.sleep(remaining)
        remaining = deadline - monotime()


# Pick the select implementation exposed as ``std_select_fun``: gevent's
# ``_select_select`` when available (presumably the original, unpatched
# select.select kept by gevent.monkey -- TODO confirm), else the stdlib one.
try:
    import gevent.monkey
    std_select_fun = gevent.monkey._select_select
except (ImportError, AttributeError):
    std_select_fun = select.select

# Module-level switch read by poll_select(): while it is False, poll_select()
# may delegate to a gevent-patched select.select outside of threads.
in_server = False


def poll_select(rfds, wfds, xfds, timeout=None):
    """
    ``select.select`` replacement built on epoll/poll, so fds above the
    select() limit can be waited on.

    :param rfds: objects (ints or anything with ``fileno()``) to watch for reading
    :param wfds: same, for writing
    :param xfds: same, for error / hang-up conditions
    :param timeout: seconds to wait; ``None`` blocks until an event arrives
    :return: ``(readable, writable, errored)`` lists of **file descriptors**
             (never the original objects), in the order the events were reported
    :raises RuntimeError: when neither ``select.poll`` nor ``select.epoll`` exists
    :raises select.error: when polling itself fails

    When ``select.select`` is gevent-patched, ``in_server`` is false and the
    caller is not inside a thread, the call is simply forwarded to
    ``select.select``.
    """
    max_select_fd = 64 if sys.platform == 'cygwin' else 1024

    # if monkey-patched and main thread, don't deal with poll
    if not in_server and select.select.__module__.startswith('gevent') and not in_thread():
        return select.select(rfds, wfds, xfds, timeout)

    has_epoll = hasattr(select, 'epoll')
    if not hasattr(select, 'poll') and not has_epoll:
        raise RuntimeError("monkey-patched gevent, fd > %s, not main thread (gevent cannot be used). Deal with it." %
                           max_select_fd)

    def as_fd(item):
        return item if isinstance(item, int) else item.fileno()

    # poll() returns fds, never sockets; sets give O(1) membership tests
    read_set = set(as_fd(x) for x in rfds)
    write_set = set(as_fd(x) for x in wfds)
    err_set = set(as_fd(x) for x in xfds)

    if has_epoll:
        # epoll takes seconds, -1 means "wait forever"
        poll_timeout = -1 if timeout is None else timeout
        readflag = select.EPOLLIN
        writeflag = select.EPOLLOUT
        errflag = select.EPOLLERR | select.EPOLLHUP
    else:
        # poll takes milliseconds, None means "wait forever"
        poll_timeout = None if timeout is None else timeout * 1000
        readflag = select.POLLIN
        writeflag = select.POLLOUT
        errflag = select.POLLHUP | select.POLLERR | select.POLLNVAL

    # One registration per fd with the combined mask: registering the same
    # fd twice fails on epoll and overwrites the previous mask on poll.
    masks = {}
    for fds, flag in ((read_set, readflag), (write_set, writeflag), (err_set, errflag)):
        for fd in fds:
            masks[fd] = masks.get(fd, 0) | flag

    p = select.epoll() if has_epoll else select.poll()
    try:
        for fd, mask in masks.items():
            p.register(fd, mask)

        try:
            events = p.poll(poll_timeout)
        except IOError as e:
            # epoll doesn't raise select.error, but IOError
            if isinstance(e, select.error):
                raise
            raise select.error(*e.args)
    finally:
        # An epoll object owns a file descriptor; poll objects do not.
        if has_epoll:
            p.close()

    ret_r, ret_w, ret_x = [], [], []
    for fd, event in events:
        if (event & readflag) and fd in read_set:
            ret_r.append(fd)
        if (event & writeflag) and fd in write_set:
            ret_w.append(fd)
        if (event & errflag) and fd in err_set:
            ret_x.append(fd)

    return ret_r, ret_w, ret_x


# In-memory byte buffer factory: cStringIO on Python 2, BytesIO otherwise.
bytesio = six.moves.cStringIO if six.PY2 else six.BytesIO


class AtFork(object):
    """
    Context manager meant to surround a fork.

    On enter it flushes stdout/stderr, remembers the current pid and acquires
    the logging module lock plus the lock of every live logging handler.
    On exit all of those locks are released; if the pid differs from the one
    seen on enter (i.e. we are in the child) every handler is closed and
    detached from the loggers that had handlers.
    """

    def __init__(self):
        import threading

        self.__handlers = frozenset()
        self.__non_empty_loggers = []
        self.__lock = threading.RLock()

        self.__acquired_handlers = None

        self.__pid = None

    @staticmethod
    def _acquire_logging_lock():
        """Take the logging module lock, whichever API this Python provides."""
        import logging

        acquire = getattr(logging, '_acquireLock', None)
        if acquire is not None:
            acquire()
        else:
            # Interpreters without the helper functions expose the lock itself.
            logging._lock.acquire()

    @staticmethod
    def _release_logging_lock():
        """Counterpart of ``_acquire_logging_lock``."""
        import logging

        release = getattr(logging, '_releaseLock', None)
        if release is not None:
            release()
        else:
            logging._lock.release()

    def __update_cache(self):
        """Refresh the handler snapshot and the list of loggers owning handlers."""
        import logging

        new_handlers = frozenset(logging._handlerList)
        if self.__handlers == new_handlers:
            return

        loggers = chain(list(logging.Logger.manager.loggerDict.values()), [logging.getLogger()])
        # Rebuilt from scratch so repeated refreshes do not pile up duplicates.
        self.__non_empty_loggers = [
            logger for logger in loggers
            if isinstance(logger, logging.Logger) and logger.handlers
        ]
        self.__handlers = new_handlers

    def __enter__(self):
        self.__lock.acquire()
        logging_locked = False
        try:
            self.__update_cache()

            sys.stdout.flush()
            sys.stderr.flush()

            self.__pid = os.getpid()

            self._acquire_logging_lock()
            logging_locked = True

            self.__acquired_handlers = []

            # logging._handlerList holds weak references to handlers.
            for handler in self.__handlers:
                _handler = handler()
                if _handler is None:
                    continue

                _handler.acquire()
                self.__acquired_handlers.append(_handler)
        except BaseException:
            # __exit__ is not called when __enter__ fails: undo everything
            # taken so far so that no lock stays held.
            self._release_handlers()
            if logging_locked:
                self._release_logging_lock()
            self.__pid = None
            self.__lock.release()
            raise

        return self

    def _release_handlers(self):
        """Release every handler lock taken in ``__enter__`` (best effort)."""
        if self.__acquired_handlers:
            for handler in self.__acquired_handlers:
                try:
                    handler.release()
                except Exception:
                    pass
        self.__acquired_handlers = None

    def _close_handlers(self):
        """Close all known handlers and detach them from their loggers (best effort)."""
        for handler in self.__handlers:
            _handler = handler()
            if _handler is None:
                continue

            try:
                _handler.close()
            except Exception:
                pass

        # And empty all loggers
        for logger in self.__non_empty_loggers:
            del logger.handlers[:]

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            self._release_handlers()

            # Child
            if self.__pid != os.getpid():
                self._close_handlers()
        finally:
            self.__pid = None

            try:
                self._release_logging_lock()
            except Exception:
                pass

            try:
                self.__lock.release()
            except Exception:
                pass


class Coverage(object):
    """
    Context manager that records coverage for the enclosed code.

    Measurement is enabled only when the ``COVERAGE_PROCESS_START``
    environment variable is set at construction time; otherwise entering and
    leaving the context does nothing.
    """

    def __init__(self):
        self.measure = os.getenv('COVERAGE_PROCESS_START', None) is not None
        # Coverage object returned by cov_core_init.init(); None until started.
        self.cov = None

    def __enter__(self):
        if self.measure:
            import cov_core_init
            self.cov = cov_core_init.init()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # init() may not have produced a coverage object: nothing to stop then.
        if self.measure and self.cov is not None:
            self.cov.stop()
            self.cov.save()


class AsyncResult(object):
    """
    One-shot result holder that can be waited on through a pipe.

    ``set()`` / ``setException()`` store the outcome, write one byte to the
    internal pipe and close its write end; ``get()`` waits for the read end
    to become readable using the configured select function
    (``poll_select`` by default). ``clear()`` re-opens the pipe so the
    object can be reused.
    """
    __unknownState__ = object()
    __normalResult__ = object()
    __exceptionResult__ = object()

    waitDelay = 0.01

    class Timeout(Exception):
        """Raised by ``get()`` when no result arrived within the timeout."""
        pass

    def __init__(self, select_fun=None):
        super(AsyncResult, self).__init__()
        self.__eventPipe = Pipe()
        self.__selectFun = select_fun or poll_select
        self._init()

    def _init(self):
        """Open the pipe and reset the stored outcome to "unknown"."""
        self.__eventPipe.open()
        self.__result = self.__unknownState__
        self.__resultType = None

    def clear(self):
        """Forget the current outcome and make the object waitable again."""
        self.__eventPipe.close()
        self._init()

    def __publish(self, value, result_type):
        # Readers test __result first, so the kind must be visible before the
        # value: otherwise a concurrent get() could hand out an exception
        # object as if it were a normal result.
        self.__resultType = result_type
        self.__result = value
        self.__eventPipe.writeAll('1')
        self.__eventPipe.closeW()

    def set(self, value):
        """Store ``value`` as the result and wake up waiters."""
        self.__publish(value, self.__normalResult__)

    def setException(self, value):
        """Store exception ``value``; ``get()`` will raise it."""
        self.__publish(value, self.__exceptionResult__)

    set_exception = setException

    def _returnResult(self):
        """Return the result, raise the stored exception, or raise ``Timeout`` if unset."""
        if self.__result is self.__unknownState__:
            raise AsyncResult.Timeout()
        elif self.__resultType is self.__exceptionResult__:
            raise self.__result
        else:
            return self.__result

    def _wait(self, timeout=None):
        """Wait up to ``timeout`` seconds (negative means 0); true if the pipe became readable."""
        if timeout is not None and timeout < 0:
            timeout = 0
        return len(self.__selectFun([self.__eventPipe.r], [], [], timeout)[0]) > 0

    def get(self, timeout=None):
        """
        Return the result, waiting up to ``timeout`` seconds for it.

        :raises AsyncResult.Timeout: when nothing was set in time
        :raises Exception: the exception passed to ``setException``
        """
        if self.__result is self.__unknownState__:
            if self._wait(timeout):
                self.__eventPipe.closeR()
        return self._returnResult()

    wait = get

    def is_set(self):
        """True once ``set()`` or ``setException()`` has stored an outcome."""
        return self.__result is not self.__unknownState__

    isSet = is_set

    def __del__(self):
        # __init__ may have failed before the pipe attribute was created.
        pipe = getattr(self, '_AsyncResult__eventPipe', None)
        if pipe is not None:
            pipe.close()
