# encoding: UTF-8

import contextlib
import re

import gevent.baseserver
import gevent.socket
import gevent.ssl
from ws_properties.utils.logs import get_logger_for_instance


# Matches a leading run of digits together with the whitespace that follows
# it.  NSDControl.zonestatus() uses it to drop that numeric prefix from the
# 'wait' and 'transfer' values before storing them.
NUM = re.compile(r'^\d+\s+')


def split_pair(s):
    """Split a ``key: value`` line into a normalised ``(key, value)`` tuple.

    The line is split on the first colon only, so the value may itself
    contain colons.  Dashes in the key are turned into underscores and
    surrounding whitespace is removed from both parts.

    :param s: text line containing at least one ``:``
    :return: ``(key, value)`` tuple of stripped strings
    :raises ValueError: if the line contains no colon
    """
    raw_key, raw_value = s.split(':', 1)
    key = raw_key.strip().replace('-', '_')
    value = raw_value.strip()
    return key, value


class NSDControl(object):
    """Client for the NSD remote-control channel (``NSDCT1`` protocol).

    Every public method opens a fresh connection, sends a single command
    prefixed with :attr:`PROTOCOL_VERSION` and reads the reply.  All public
    methods accept an optional ``timeout`` keyword (seconds, ``None`` means
    no limit) which bounds the exchange through :class:`gevent.Timeout`.
    """

    PROTOCOL_VERSION = 'NSDCT1'

    # While streaming a zone list, the client flushes and waits for the
    # server's acknowledgement of the current zone every SYNC_INTERVAL
    # entries instead of writing the whole list blindly.
    SYNC_INTERVAL = 10000

    def __init__(self, address, **kwargs):
        """Remember where and how to connect; no connection is opened here.

        :param address: server address in any form understood by
            ``gevent.baseserver.parse_address``
        :param kwargs: when non-empty, the socket is wrapped with
            ``gevent.ssl.wrap_socket(socket, **kwargs)``; when empty, a plain
            (unencrypted) stream socket is used
        """
        self._family, self._address = gevent.baseserver.parse_address(address)
        self._kwargs = kwargs
        self._logger = get_logger_for_instance(self)

    @contextlib.contextmanager
    def _make_connection(self):
        """Yield a file-like stream connected to the server.

        The stream and the underlying socket are always closed on exit,
        including when wrapping the socket or connecting fails.
        """
        sock = gevent.socket.socket(
            self._family,
            gevent.socket.SOCK_STREAM,
        )
        try:
            # Wrapping happens inside the ``try`` so that the raw socket is
            # closed if the SSL wrap itself raises.
            if self._kwargs:
                sock = gevent.ssl.wrap_socket(
                    sock,
                    **self._kwargs
                )  # type: gevent.socket.socket
            sock.connect(self._address)
            # Commands are written to and replies read from the same stream.
            # On Python 3 the default mode ('r') yields a read-only file
            # object, so the read/write mode must be requested explicitly;
            # Python 2 accepts the same mode string.
            stream = sock.makefile('rw')
            try:
                yield stream
            finally:
                stream.close()
        finally:
            sock.close()

    @staticmethod
    def _await_ack(stream, expected):
        """Read reply lines until one contains ``expected``.

        :raises IOError: if the stream reaches end-of-file first.  Without
            this check ``readline()`` would keep returning an empty string
            and the loop would spin forever.
        """
        while True:
            line = stream.readline()
            if not line:
                raise IOError(
                    'connection closed before %r was received' % expected
                )
            if expected in line:
                return

    def _bulk_zone_command(self, command, zones, ack_prefix,
                           line_suffix='', timeout=None):
        """Send ``command`` followed by one line per zone, then drain the reply.

        :param command: command name written after the protocol version
        :param zones: iterable of zone names (strings)
        :param ack_prefix: text the server puts in front of the zone name
            when acknowledging it (e.g. ``'added: '``)
        :param line_suffix: text appended to every zone line before the
            newline
        :param timeout: seconds allowed for the whole exchange, or ``None``
        :raises IOError: if the server closes the connection while an
            acknowledgement is awaited
        """
        with self._make_connection() as stream:
            with gevent.Timeout(timeout):
                stream.write(self.PROTOCOL_VERSION + ' ' + command + '\n')
                stream.flush()
                for i, zone in enumerate(zones):
                    stream.write(zone + line_suffix + '\n')
                    if i > 0 and i % self.SYNC_INTERVAL == 0:
                        # Periodically wait for the server to catch up.
                        # NOTE(review): if the server never acknowledges
                        # this particular zone (e.g. it rejects it), this
                        # blocks until the timeout fires or the peer closes
                        # the connection -- confirm the reply format.
                        stream.flush()
                        self._await_ack(stream, ack_prefix + zone)
                # A line holding only EOT (0x04) terminates the zone list.
                stream.write('\x04\n')
                stream.flush()
                stream.read()

    def _command(self, cmd, *args, **kwargs):
        """Send a one-line command and return the complete reply as a string.

        :param cmd: command text; ``%``-formatted with ``args``
        :param kwargs: only ``timeout`` is used
        """
        timeout = kwargs.get('timeout')
        with self._make_connection() as stream:
            with gevent.Timeout(timeout):
                cmd = cmd % args
                stream.write(self.PROTOCOL_VERSION + ' ' + cmd + '\n')
                stream.flush()
                return stream.read()

    def stats(self, **kwargs):
        """Return the raw reply to the ``stats`` command."""
        return self._command('stats', **kwargs)

    def addzones(self, zones, pattern='default', **kwargs):
        """Send ``addzones`` with one ``<zone> <pattern>`` line per zone.

        :param zones: iterable of zone names
        :param pattern: pattern name written after every zone
        :param kwargs: only ``timeout`` is used
        :raises IOError: if the connection is closed while waiting for an
            ``added: <zone>`` acknowledgement
        """
        self._bulk_zone_command(
            'addzones',
            zones,
            ack_prefix='added: ',
            line_suffix=' ' + pattern,
            timeout=kwargs.get('timeout'),
        )

    def retransfer(self, zone, pattern='default', **kwargs):
        """Send the ``transfer`` command for a single zone.

        :param zone: zone name
        :param pattern: unused; kept for interface compatibility
        :param kwargs: only ``timeout`` is used
        """
        timeout = kwargs.get('timeout')
        with self._make_connection() as stream:
            with gevent.Timeout(timeout):
                stream.write(self.PROTOCOL_VERSION + ' transfer ' + zone + '\n')
                stream.flush()
                stream.write('\x04\n')
                stream.flush()
                stream.read()

    def delzones(self, zones, **kwargs):
        """Send ``delzones`` with one zone name per line.

        :param zones: iterable of zone names
        :param kwargs: only ``timeout`` is used
        :raises IOError: if the connection is closed while waiting for a
            ``removed: <zone>`` acknowledgement
        """
        self._bulk_zone_command(
            'delzones',
            zones,
            ack_prefix='removed: ',
            timeout=kwargs.get('timeout'),
        )

    def zonestatus(self, **kwargs):
        """Send ``zonestatus`` and lazily yield one dict per reported zone.

        The reply is parsed as ``key: value`` lines.  A ``zone`` line starts
        a new record; ``pattern`` lines are skipped.  For ``served_serial``,
        ``commit_serial`` and ``notified_serial`` the record gets the serial
        as ``int`` under the key and the text after ``' since '`` under
        ``<key>_at`` (both ``None`` when the value is ``none``).  ``wait``
        and ``transfer`` values lose their quotes and leading number
        (``None`` when the remainder is ``none``).  Any other key is stored
        verbatim.

        NOTE(review): this is a generator, so the connection and the timeout
        stay active while the caller consumes it; time spent by the consumer
        between items therefore counts against ``timeout`` -- confirm this is
        intended.

        :param kwargs: only ``timeout`` is used
        :raises ValueError: if a non-empty reply line contains no colon
        """
        timeout = kwargs.get('timeout')
        with self._make_connection() as stream:
            with gevent.Timeout(timeout):
                stream.write(self.PROTOCOL_VERSION + ' zonestatus\n')
                stream.flush()

                # Record currently being assembled; None until the first
                # 'zone' line has been seen.
                zonedata = None
                for line in stream:
                    line = line.strip()
                    if not line:
                        continue

                    key, value = split_pair(line)
                    if key == 'zone':
                        # A new zone header completes the previous record.
                        if zonedata:
                            yield zonedata
                        zonedata = dict(zone=value)
                    elif key == 'pattern':
                        continue
                    elif key in ('served_serial', 'commit_serial', 'notified_serial'):
                        value = value.replace('"', '')
                        if value == 'none':
                            zonedata[key] = None
                            zonedata[key + '_at'] = None
                        else:
                            serial, dt = value.split(' since ', 1)
                            zonedata[key] = int(serial)
                            zonedata[key + '_at'] = dt
                    elif key in ('wait', 'transfer'):
                        value = value.replace('"', '')
                        value = NUM.sub('', value)
                        if value == 'none':
                            zonedata[key] = None
                        else:
                            zonedata[key] = value
                    else:
                        zonedata[key] = value

                # Emit the last record, which no further 'zone' line closes.
                if zonedata:
                    yield zonedata
