"""Gevent-compatible DNS resolver"""

__author__ = 'torkve'

import gevent.coros
from Queue import Queue
import threading
import socket
import weakref

from ..sys.gettime import monoTime
from ..gevent.threadsafe import AsyncResult

import six


class Resolver(object):
    """
    Raw resolver that resolves DNS requests in separate threads.

    Each call is wrapped in a ``ResultClass`` instance, put on a ``Queue``
    and picked up by one of the daemon ``Worker`` threads, which stores the
    task's return value (or exception) on that instance.  The caller then
    waits on ``ResultClass.get(timeout)``.

    Worker threads are spawned lazily: none exist until the first call, and
    ``_adjustWorkers`` grows or shrinks the pool on every call based on the
    current queue length (never above ``maxWorkers``, and never shrinking
    below one worker).
    """

    class Worker(threading.Thread):
        """Thread that executes queued requests until it receives ``None``."""

        def __init__(self, queue, semaphore, parent):
            """
            :param queue: the resolver's shared request ``Queue``
            :param semaphore: ``threading.Semaphore`` released once for every
                              item put on ``queue``
            :param parent: ``weakref.ref`` to the owning ``Resolver``; being a
                           weak reference, running workers do not keep the
                           resolver alive
            """
            self.queue = queue
            self.semaphore = semaphore
            self.parent = parent
            super(Resolver.Worker, self).__init__(name="Resolver.Worker-{}".format(id(self)))

        def run(self):
            """Process requests from the queue until a ``None`` sentinel arrives.

            Exceptions raised by a task are stored on its result object via
            ``set_exception``.  Any other unexpected error decrements the
            parent's worker count and is re-raised, ending this thread.
            """
            try:
                while True:
                    # One release() is done per item put on the queue
                    # (see Resolver._call and Resolver._releaseWorker).
                    # NOTE(review): Queue.get() below blocks on its own, so
                    # this semaphore looks redundant -- confirm before removing.
                    self.semaphore.acquire()

                    obj = self.queue.get()

                    # None is the shutdown sentinel enqueued by _releaseWorker.
                    if obj is None:
                        break

                    task = obj.task
                    args = obj.args
                    try:
                        result = task(*args)
                        obj.set(result)
                    except Exception as e:
                        obj.set_exception(e)
                    finally:
                        # Drop local references so this thread does not keep
                        # the last request/result object alive while it waits
                        # for the next item.
                        del obj
                        del task
                        del args
            except Exception:
                # The worker is about to die unexpectedly: give its slot back
                # to the parent (if it still exists) so that _adjustWorkers
                # may spawn a replacement.
                parent = self.parent()
                if parent is not None:
                    parent.workers -= 1
                raise

    def __init__(self, workers=5, queueSize=None, ResultClass=AsyncResult):
        """
        :param int workers: maximum number of worker threads; they are spawned
                            on demand, not up front
        :param int queueSize: since synchronization uses pipes, it's recommended
                              to limit queue size in case of high load, or you
                              can exceed your fd limit otherwise.  The default
                              ``None`` is passed as-is to ``Queue``, which (on
                              Python 2) treats it as unbounded
        :param ResultClass: class instantiated for every call to carry the
                            request to a worker and the result back; it must
                            accept ``task``/``args`` attributes and provide
                            ``set``, ``set_exception`` and ``get(timeout)``
        """
        self.queue = Queue(queueSize)
        # Released once per queued item; consumed by Worker.run.
        self.semaphore = threading.Semaphore(0)
        # Number of live workers as tracked by this object (updated on
        # spawn/release and by dying workers).
        self.workers = 0
        self.maxWorkers = workers
        self.ResultClass = ResultClass
        self.stopped = True
        self.start()

    def start(self):
        """Mark the resolver as running.

        No threads are created here; workers are spawned lazily by
        ``_adjustWorkers`` when calls arrive.
        """
        if not self.stopped:
            return

        self.stopped = False

    def _releaseWorker(self):
        """Ask one worker to exit by enqueueing a ``None`` sentinel.

        The worker count is decremented immediately, before any thread has
        actually consumed the sentinel.  On a bounded queue ``put`` blocks
        while the queue is full.
        """
        self.queue.put(None)
        self.semaphore.release()
        self.workers -= 1

    def _spawnWorker(self):
        """Start one more daemon worker thread.

        The count is incremented up front and rolled back if creating or
        starting the thread raises.
        """
        self.workers += 1
        try:
            # Pass a weak reference so the worker does not keep us alive.
            thr = Resolver.Worker(self.queue, self.semaphore, weakref.ref(self))
            thr.daemon = True
            thr.start()
        except:
            self.workers -= 1
            raise

    def _adjustWorkers(self):
        """Grow or shrink the worker pool based on the current queue length.

        Spawns workers while there are more than two queued items per worker
        (up to ``maxWorkers``); releases workers while there are more than
        twice as many workers as queued items (keeping at least one).

        NOTE(review): ``qsize()`` also counts ``None`` sentinels enqueued by
        ``_releaseWorker``, and every release adds one, so shrinking stops
        earlier than the formula alone suggests.
        """
        while self.workers < self.maxWorkers and self.queue.qsize() > (self.workers * 2):
            self._spawnWorker()
        while self.workers > (self.queue.qsize() * 2) and self.workers > 1:
            self._releaseWorker()

    def stop(self):
        """Release every worker thread and mark the resolver as stopped.

        Returns without waiting for the threads to exit.
        NOTE(review): ``_call`` does not consult ``stopped``, so a call made
        after ``stop()`` still spawns new workers.
        """
        self.stopped = True
        # xrange() is evaluated once, before _releaseWorker decrements
        # self.workers, so every tracked worker gets a sentinel.
        for worker in six.moves.xrange(self.workers):
            self._releaseWorker()

    def _call(self, task, args, timeout=None):
        """Run ``task(*args)`` on a worker thread and wait for its result.

        :param task: callable executed in a worker thread
        :param tuple args: positional arguments for ``task``
        :param timeout: overall deadline in seconds, or ``None``; time spent
                        enqueueing is deducted before it is passed to
                        ``ResultClass.get``
        :return: whatever ``ResultClass.get`` returns (the task's result)
        :raises: whatever ``ResultClass.get`` raises -- presumably the task's
                 exception or a timeout error; TODO confirm against
                 ``..gevent.threadsafe.AsyncResult``

        A request whose wait times out is not cancelled: it stays queued and
        is still executed by a worker.
        """
        start = monoTime()

        obj = self.ResultClass()
        obj.task = task
        obj.args = args
        # Semaphore is released for the item before it is put on the queue.
        self.semaphore.release()
        self.queue.put(obj)

        self._adjustWorkers()

        if timeout is not None:
            # NOTE(review): may become negative if enqueueing was slow.
            timeout -= (monoTime() - start)

        result = obj.get(timeout)
        return result

    def getaddrinfo(self, *args, **kwargs):
        """Thread-offloaded ``socket.getaddrinfo``.

        Positional arguments are forwarded to ``socket.getaddrinfo``; the
        only keyword argument honoured is ``timeout``.
        NOTE(review): any other keyword arguments are silently dropped.
        """
        timeout = kwargs.pop('timeout', None)

        return self._call(socket.getaddrinfo, args, timeout=timeout)

    def __enter__(self):
        """Start the resolver (if stopped) and return it."""
        self.start()
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """Stop the worker threads; exceptions are not suppressed."""
        self.stop()

    def __del__(self):
        """Release worker threads when the resolver is garbage-collected."""
        self.stop()


class RetriedResolverMixin(object):
    """Resolver mixin that retries each call in case of exception.

    An attempt is retried when the underlying ``_call`` raises an
    ``Exception`` or returns a falsy result.  ``BaseException`` subclasses
    that are not ``Exception`` (``KeyboardInterrupt``, ``GreenletExit`` used
    to kill a greenlet, ...) propagate immediately instead of being retried.
    If a ``timeout`` is given it is a budget shared by all attempts: time
    spent in earlier attempts is deducted from it.
    """

    def __init__(self, *args, **kwargs):
        """
        :param int retries: how many times to try the task (keyword argument,
                            default 1; values below 1 still make one attempt)
        """
        self.retries = kwargs.pop('retries', 1)
        super(RetriedResolverMixin, self).__init__(*args, **kwargs)

    def _call(self, task, args, timeout=None):
        """Call the next ``_call`` in the MRO, retrying on failure.

        :param task: callable to execute
        :param tuple args: positional arguments for ``task``
        :param timeout: total time budget in seconds for all attempts, or
                        ``None`` for no budget
        :return: the first truthy result, or the last result when no attempt
                 produced a truthy one
        :raises: the last attempt's exception, or the current attempt's
                 exception as soon as the ``timeout`` budget is exhausted
        """
        # Always make at least one attempt; previously retries < 1 returned
        # None without ever running the task.
        attempts = max(1, self.retries)
        start = monoTime()

        for attempt in six.moves.xrange(attempts):
            isLast = attempt == attempts - 1
            try:
                result = super(RetriedResolverMixin, self)._call(task, args, timeout=timeout)
            except Exception:
                # Only ordinary errors are retried; a bare except here would
                # also swallow greenlet kills and a caller's own timeouts.
                if isLast:
                    raise
                if timeout is not None:
                    timeout -= monoTime() - start
                    if timeout <= 0:
                        raise
                    start = monoTime()
                continue

            if result or isLast:
                return result

            # Falsy result: retry, charging the elapsed time to the budget
            # exactly like a failed attempt.
            if timeout is not None:
                timeout -= monoTime() - start
                if timeout <= 0:
                    return result
                start = monoTime()


class GeventResolverMixin(object):
    """Resolver used in single gevent event loop and limited by semaphore.

    When the underlying request queue is bounded (``queueSize`` > 0), at
    most ``queueSize`` calls may be in flight at once; additional callers
    wait on a gevent semaphore (presumably so that greenlets yield instead
    of blocking the event loop on a full ``Queue.put`` -- confirm).  With an
    unbounded queue no limit is applied.
    """

    def __init__(self, *args, **kwargs):
        super(GeventResolverMixin, self).__init__(*args, **kwargs)
        limit = self.queue.maxsize
        # Queue treats a maxsize of None/0 as unbounded.  The old
        # Semaphore(maxsize or 0) became a zero-valued semaphore in that case
        # and every call blocked forever, so only gate bounded queues.
        if limit and limit > 0:
            self.geventSemaphore = gevent.coros.Semaphore(limit)
        else:
            self.geventSemaphore = None

    def _call(self, task, args, timeout=None):
        """Forward to the next ``_call`` in the MRO, gated by the semaphore."""
        call = super(GeventResolverMixin, self)._call
        if self.geventSemaphore is None:
            return call(task, args, timeout=timeout)
        with self.geventSemaphore:
            return call(task, args, timeout=timeout)


class RetriedGeventResolver(GeventResolverMixin, RetriedResolverMixin, Resolver):
    """Thread-pool resolver for gevent code with per-call retries.

    Method resolution order: ``GeventResolverMixin._call`` wraps
    ``RetriedResolverMixin._call``, which wraps ``Resolver._call``; any
    gevent-semaphore gating is therefore applied once around all retries of
    a single call.  The constructor accepts the ``Resolver`` arguments plus
    the ``retries`` keyword argument.
    """
