from kernel.util.net import dns
from kernel.util.gevent.threadsafe import AsyncResult
from kernel.util.unittest import TestCase, main
from library.sky.hostresolver import Resolver

import errno
import gevent
import os


class TestDns(TestCase):
    """Stress and failure-mode tests for ``dns.RetriedGeventResolver``."""

    def setUp(self):
        # Shared resolver for the resolution tests; stopped in tearDown.
        self.dnsResolver = dns.RetriedGeventResolver(workers=20, queueSize=40, retries=3)
        # Used to expand host expressions such as 'H@ALL' into host names.
        self.cmsResolver = Resolver()
        super(TestDns, self).setUp()

    def tearDown(self):
        # Run the base-class teardown even if stopping the resolver fails.
        try:
            self.dnsResolver.stop()
        finally:
            super(TestDns, self).tearDown()

    def resolve(self, host, tolerateNoAddress=False):
        """Resolve ``host`` (port 10000) through the shared resolver.

        :param host: host name to resolve.
        :param tolerateNoAddress: when true, an error whose text contains
            "Name or service not known" yields ``[]`` instead of propagating.
        :return: the value returned by ``getaddrinfo`` (the tests expect a list).
        """
        try:
            return self.dnsResolver.getaddrinfo(host, 10000, timeout=60)
        except BaseException as e:
            # Everything except the tolerated "unknown name" error is re-raised unchanged.
            if tolerateNoAddress and "Name or service not known" in str(e):
                return []
            raise

    def _assertAllResolve(self, hosts, tolerateNoAddress, raiseError=False):
        """Resolve every host in its own greenlet and check each result is a list."""
        lets = [gevent.spawn(self.resolve, host, tolerateNoAddress) for host in hosts]
        gevent.joinall(lets, raise_error=raiseError)
        for let in lets:
            self.assertIsInstance(let.get(), list)

    def testResolve(self):
        """Resolve the same name 500 times concurrently, ``iterations`` times."""
        for _ in range(iterations):
            self._assertAllResolve(['yandex.ru'] * 500, tolerateNoAddress=False)

    def testResolverFail(self):
        """A resolver whose results cannot be stored times out, then recovers.

        With ``FailingResult`` the lookup is expected to time out and leave no
        live workers; after switching back to ``AsyncResult`` the same resolver
        must return a truthy result.
        """
        resolver = dns.RetriedGeventResolver(workers=1, queueSize=1, retries=2, ResultClass=FailingResult)
        try:
            self.assertRaises(AsyncResult.Timeout, resolver.getaddrinfo, "localhost", 0, 0, 0, timeout=1.0)
            self.assertEqual(0, resolver.workers)
            resolver.ResultClass = AsyncResult
            self.assertTrue(resolver.getaddrinfo("localhost", 0, 0, 0, timeout=1.5))
        finally:
            # This resolver is local to the test, so tearDown does not stop it.
            resolver.stop()

    # NOTE: requires the Yandex network; consider @skip("Only in Yandex network") elsewhere.
    def testResolveHALL(self):
        """Resolve every host of the 'H@ALL' group, tolerating unknown names."""
        for _ in range(iterations):
            hosts = self.cmsResolver.resolveHosts('H@ALL')
            self._assertAllResolve(hosts, tolerateNoAddress=True, raiseError=True)


class FailingResult(AsyncResult):
    """AsyncResult whose setters always fail, simulating a broken pipe.

    ``set``, ``setException`` and ``set_exception`` all raise
    ``IOError(EPIPE)`` regardless of the arguments they are called with.
    """

    def set(self, *args, **kwargs):
        # Accept any call signature so that e.g. set_exception(exc, exc_info)
        # or a bare set() hits the simulated EPIPE, not an incidental TypeError.
        raise IOError(errno.EPIPE, os.strerror(errno.EPIPE))

    # The exception setters share the same always-failing implementation.
    setException = set
    set_exception = set


# Number of repetitions for the stress tests: 20 passes on POSIX, one elsewhere.
iterations = 20 if os.name == 'posix' else 1


if __name__ == '__main__':
    # Executed as a script: hand control to the project's unittest entry point.
    main()
