from util.generic.string cimport TStringBuf

import dns.message
import dns.name
import dns.query
import dns.rdataclass
import dns.rdatatype

import json
import logging
import multiprocessing
import psutil
import requests
import time


# Native entry point of the YP DNS daemon, declared in
# infra/yp_dns/libs/daemon/main.h (namespace NYP::DNS).
# Arguments: the daemon config (passed from _run_daemon as JSON bytes),
# then argc and an argv-style array of C strings; returns an int.
cdef extern from "infra/yp_dns/libs/daemon/main.h" namespace "NYP::DNS":
    int RunDaemon(TStringBuf, int, const char**)


def _run_daemon(config, args):
    """Run the YP DNS daemon in the current process (used as a Process target).

    :param config: JSON daemon config encoded as ``bytes``; converted to
        ``TStringBuf`` for ``RunDaemon``.
    :param args: list of ``bytes`` command-line arguments. They are handed to
        ``RunDaemon`` as a C ``argv`` array whose pointers borrow the buffers
        of the ``bytes`` objects held by ``args``.
    :returns: the value returned by ``RunDaemon``.
    :raises ValueError: if ``args`` has more than 1024 items (the capacity of
        the fixed-size ``argv`` buffer).
    """
    cdef int argc = len(args)
    cdef char* argv[1024]
    # argv is a fixed-size C local array; writing past its end would corrupt
    # memory, so reject oversized argument lists up front.
    if argc > 1024:
        raise ValueError(
            "too many daemon arguments: {} (at most 1024 supported)".format(argc))
    for idx in range(argc):
        argv[idx] = args[idx]

    return RunDaemon(config, argc, <const char**>argv)


def parse_time(value):
    """Parse a duration written in whole seconds, e.g. ``"10s"`` -> ``10``.

    Only the ``s`` suffix is supported.

    :param value: duration string ending in ``s`` preceded by a digit.
    :returns: the number of seconds as ``int``.
    :raises ValueError: if ``value`` is not of the supported ``<number>s`` form.
    """
    # Check the length first: value[-2] would raise IndexError for "" or "s".
    if len(value) < 2 or not value.endswith('s') or not value[-2].isdigit():
        raise ValueError("implement time parsing to use literals other than 's'")
    return int(value[:-1])


class YpDnsDaemon(object):
    """YP DNS daemon running in a child process, plus helpers to control it.

    The constructor starts the daemon and blocks until its service HTTP
    endpoint answers ``/ping``.
    """

    # Seconds between /ping attempts (and per-request timeout) while waiting
    # for the daemon to start.
    _START_POLL_INTERVAL = 5
    # Seconds to wait for the child to exit after terminate() before killing it.
    _STOP_TIMEOUT = 10

    def __init__(self, config, pdns_args):
        """Start the daemon and wait until it is ready.

        :param config: daemon config dict; serialized to JSON for the daemon.
            Optional ``YPClusterConfigs`` entries may carry
            ``UpdatingFrequency`` / ``Timeout`` durations such as ``"5s"``.
        :param pdns_args: dict of daemon options rendered as ``--key=value``;
            must contain ``local-port``. ``service-port`` defaults to
            ``local-port + 1``.
        :raises RuntimeError: if the daemon process exits before it is ready.
        """
        self._pdns_args = pdns_args

        args = ['run', '--launch=YP_DNS'] + [
            '--{}={}'.format(key, value) for key, value in self._pdns_args.items()
        ]
        b_args = [arg.encode('utf8') for arg in args]

        self._address = '127.0.0.1'
        self._port = self._pdns_args['local-port']
        if 'service-port' in self._pdns_args:
            self._service_port = self._pdns_args['service-port']
        else:
            # Only compute the default when it is actually needed.
            self._service_port = self._port + 1

        # Keep the largest value across all clusters (never below 1 second).
        self._updating_frequency = 1
        self._master_request_timeout = 1
        for cluster_config in config.get('YPClusterConfigs', []):
            if 'UpdatingFrequency' in cluster_config:
                self._updating_frequency = max(
                    self._updating_frequency, parse_time(cluster_config['UpdatingFrequency']))
            if 'Timeout' in cluster_config:
                self._master_request_timeout = max(
                    self._master_request_timeout, parse_time(cluster_config['Timeout']))

        config_string = json.dumps(config).encode('utf8')

        self._process = multiprocessing.Process(target=_run_daemon, args=(config_string, b_args))
        self._process.start()
        self._wait_start()

    @property
    def address(self):
        """Address of the DNS server: ``local-address`` option, else 127.0.0.1."""
        return self._pdns_args.get('local-address', self._address)

    @property
    def port(self):
        """DNS port (the ``local-port`` option)."""
        return self._port

    @property
    def updating_frequency(self):
        """Largest ``UpdatingFrequency`` among cluster configs, in seconds (>= 1)."""
        return self._updating_frequency

    @property
    def master_request_timeout(self):
        """Largest ``Timeout`` among cluster configs, in seconds (>= 1)."""
        return self._master_request_timeout

    @property
    def positive_cache_ttl(self):
        """The ``query-cache-ttl`` option."""
        return self._pdns_args['query-cache-ttl']

    @property
    def negative_cache_ttl(self):
        """The ``negquery-cache-ttl`` option."""
        return self._pdns_args['negquery-cache-ttl']

    def _service_url(self, path):
        """Build the URL of ``path`` on the daemon's service HTTP port."""
        return "http://localhost:{}/{}".format(self._service_port, path)

    def _wait_start(self):
        """Block until ``/ping`` succeeds.

        :raises RuntimeError: if the daemon process exits before answering.
        """
        url = self._service_url('ping')
        while True:
            # Without this check a crashed daemon would make us poll forever.
            if not self._process.is_alive():
                raise RuntimeError(
                    "YP DNS daemon exited with code {} before becoming ready".format(
                        self._process.exitcode))
            try:
                r = requests.get(url, timeout=self._START_POLL_INTERVAL)
                r.raise_for_status()
                return
            except requests.RequestException:
                time.sleep(self._START_POLL_INTERVAL)

    def suspend(self):
        """Suspend the daemon process via psutil; no-op if it no longer exists."""
        try:
            psutil.Process(self._process.pid).suspend()
        except psutil.NoSuchProcess:
            pass

    def resume(self):
        """Resume the daemon process via psutil; no-op if it no longer exists."""
        try:
            psutil.Process(self._process.pid).resume()
        except psutil.NoSuchProcess:
            pass

    def sensors(self):
        """GET ``/sensors`` from the service port and return the response."""
        return requests.get(self._service_url('sensors'))

    def ping(self):
        """GET ``/ping`` from the service port and return the response."""
        return requests.get(self._service_url('ping'))

    def reopen_log(self):
        """GET ``/reopen_log`` from the service port and return the response."""
        return requests.get(self._service_url('reopen_log'))

    def stop(self):
        """Terminate the daemon process and reap it.

        If the process has not exited ``_STOP_TIMEOUT`` seconds after
        ``terminate()`` (for example because it is suspended and cannot handle
        the signal), it is killed.
        """
        self._process.terminate()
        self._process.join(self._STOP_TIMEOUT)
        if self._process.is_alive():
            self._process.kill()
            self._process.join()


class DnsClient(object):
    """dnspython-based helper for sending test queries to a single DNS server.

    ``updating_frequency`` and ``updating_timeout`` (seconds) are only used by
    ``wait_update``, which sleeps for their sum plus one second.
    """

    # Maps the ``protocol`` argument of the query helpers to the dnspython
    # function that sends the query.
    PROTOCOL_TO_METHOD = {
        'udp': dns.query.udp,
        'tcp': dns.query.tcp,
    }

    def __init__(self, address, port=53, updating_frequency=0, updating_timeout=0):
        """
        :param address: server address, passed as ``where`` to dnspython.
        :param port: server port.
        :param updating_frequency: seconds, used by ``wait_update``.
        :param updating_timeout: seconds, used by ``wait_update``.
        """
        self._address = address
        self._port = port
        self._updating_frequency = updating_frequency
        self._updating_timeout = updating_timeout

    def wait_update(self):
        """Sleep ``updating_frequency + updating_timeout + 1`` seconds."""
        time.sleep(self._updating_frequency + self._updating_timeout + 1)

    def _do_query(self, query, protocol='udp', wait_update=False, timeout=None):
        """Send a prepared ``dns.message.Message`` and return the response.

        :param protocol: ``'udp'`` or ``'tcp'``; other values raise ``KeyError``.
        :param wait_update: call ``wait_update`` before sending.
        :param timeout: passed through to dnspython (``None``: no timeout).
        """
        query_method = self.PROTOCOL_TO_METHOD[protocol]

        if wait_update:
            self.wait_update()

        return query_method(query, where=self._address, port=self._port, timeout=timeout)

    def query(self, qname, rdtype, rdclass='IN', protocol='udp', use_edns=None, wait_update=False, timeout=None):
        """Build a query for ``qname``/``rdtype``/``rdclass`` and send it over ``protocol``."""
        query = dns.message.make_query(qname, rdtype, rdclass, use_edns=use_edns)
        return self._do_query(query, protocol, wait_update, timeout)

    def udp(self, qname, rdtype, rdclass='IN', use_edns=None, wait_update=False, timeout=None):
        """Build and send a query over UDP."""
        return self.query(qname, rdtype, rdclass, protocol='udp', use_edns=use_edns, wait_update=wait_update, timeout=timeout)

    def udp_query(self, query, wait_update=False, timeout=None):
        """Send a prepared message over UDP."""
        return self._do_query(query, protocol='udp', wait_update=wait_update, timeout=timeout)

    def tcp(self, qname, rdtype, rdclass='IN', use_edns=None, wait_update=False, timeout=None):
        """Build and send a query over TCP."""
        return self.query(qname, rdtype, rdclass, protocol='tcp', use_edns=use_edns, wait_update=wait_update, timeout=timeout)

    def tcp_query(self, query, wait_update=False, timeout=None):
        """Send a prepared message over TCP."""
        return self._do_query(query, protocol='tcp', wait_update=wait_update, timeout=timeout)

    def udp_tcp(self, qname, rdtype, rdclass='IN', use_edns=None, wait_update=False, timeout=None):
        """Send the same query over UDP and TCP and assert the responses are equal.

        :returns: the UDP response.
        """
        if wait_update:
            self.wait_update()
        # One message object for both transports, so both carry the same id.
        query = dns.message.make_query(qname, rdtype, rdclass, use_edns=use_edns)
        resp_udp = self.udp_query(query, wait_update=False, timeout=timeout)
        resp_tcp = self.tcp_query(query, wait_update=False, timeout=timeout)

        assert resp_udp == resp_tcp
        return resp_udp

    def get_rrset_in_section(self, section, resp, domain, rdtype, rdclass='IN'):
        """Return the RRset for ``domain``/``rdtype``/``rdclass`` in ``section`` of ``resp``.

        ``rdtype`` and ``rdclass`` may be given as text (e.g. ``'A'``, ``'IN'``)
        or as dnspython enum values.
        """
        qname = self.make_dns_name(domain)
        if isinstance(rdtype, str):
            rdtype = dns.rdatatype.from_text(rdtype)
        if isinstance(rdclass, str):
            rdclass = dns.rdataclass.from_text(rdclass)
        return resp.get_rrset(section, qname, rdclass=rdclass, rdtype=rdtype)

    def get_answer(self, resp, domain, rdtype, rdclass='IN'):
        """Look up an RRset in the ANSWER section."""
        return self.get_rrset_in_section(dns.message.ANSWER, resp, domain, rdtype, rdclass)

    def get_authority(self, resp, domain, rdtype, rdclass='IN'):
        """Look up an RRset in the AUTHORITY section."""
        return self.get_rrset_in_section(dns.message.AUTHORITY, resp, domain, rdtype, rdclass)

    def get_additional(self, resp, domain, rdtype, rdclass='IN'):
        """Look up an RRset in the ADDITIONAL section."""
        return self.get_rrset_in_section(dns.message.ADDITIONAL, resp, domain, rdtype, rdclass)

    def make_dns_name(self, qname):
        """Convert ``qname`` to an absolute ``dns.name.Name``.

        A trailing dot is appended when missing, so the name is always
        absolute. Parsing is delegated to ``dns.name.from_text``, which also
        handles the root name ``"."``, escaped characters and IDNA encoding.
        Splitting on ``"."`` by hand fails on the root name.
        """
        domain = qname if qname.endswith('.') else qname + '.'
        return dns.name.from_text(domain)


class YpDns(YpDnsDaemon, DnsClient):
    """A started YP DNS daemon that is also a DNS client pointed at itself."""

    def __init__(self, config, pdns_args):
        # The daemon part goes first: the client settings are read back from it.
        YpDnsDaemon.__init__(self, config, pdns_args)
        DnsClient.__init__(
            self,
            address=self.address,
            port=self.port,
            updating_frequency=self.updating_frequency,
            updating_timeout=self.master_request_timeout,
        )

    def wait_positive_cache_dropped(self):
        """Sleep one second longer than ``query-cache-ttl``."""
        ttl = self.positive_cache_ttl
        time.sleep(ttl + 1)

    def wait_negative_cache_dropped(self):
        """Sleep one second longer than ``negquery-cache-ttl``."""
        ttl = self.negative_cache_ttl
        time.sleep(ttl + 1)
