# encoding: UTF-8

import time

import flask
import gevent
import gevent.pywsgi
from ws_properties.utils.logs import get_logger_for_instance

from dns_hosting.services.dnsserver.nsd.ctl import NSDControl
from dns_hosting.services.dnsserver.properties import NsdServerProperties
from dns_hosting.services.support.retry import StandardRetryPolicy, retry


class NsdStatServer(gevent.pywsgi.WSGIServer):
    """WSGI server that publishes periodically collected NSD statistics.

    While the server is running, three background greenlets refresh
    in-memory snapshots:

    * counters returned by ``NSDControl.stats()``;
    * the number of "stucked" zones derived from ``NSDControl.zonestatus()``
      (such zones are additionally re-created in the background);
    * the number of entries in ``/proc/net/tcp6``.

    The snapshots are served as a single JSON document on ``/unistat/``.
    Every snapshot is a list of ``[name, value]`` pairs; a snapshot is reset
    to an empty list whenever its collection fails.
    """

    def __init__(self, properties):
        # type: (NsdServerProperties) -> None
        """Create the server and the NSD control client.

        :param properties: server settings; ``properties.stat`` provides the
            listen address and polling intervals, ``properties.control``
            provides the address and TLS key/certificate files passed to
            ``NSDControl``.
        """
        self._properties = properties
        self._logger = get_logger_for_instance(self)

        super(NsdStatServer, self).__init__(
            listener=(properties.stat.host, properties.stat.port),
            log=self._logger,
            error_log=self._logger,
        )

        self._control = NSDControl(
            address=(properties.control.host, properties.control.port),
            keyfile=properties.control.keyfile,
            certfile=properties.control.certfile,
        )

        self._application = None
        # Checked by the collector loops; cleared by stop() so that they
        # exit after finishing their current iteration.
        self._running = False
        self._current_stats = []
        self._current_zonestatus = []
        self._current_tcp6_fd_stats = []

    def start(self):
        """Spawn the collector greenlets and start accepting connections."""
        self._running = True
        gevent.spawn(self._set_current_stats)
        gevent.spawn(self._set_current_zonestatus)
        gevent.spawn(self._set_current_tcp6_fd_stats)
        super(NsdStatServer, self).start()

    def stop(self, timeout=None):
        """Ask the collector loops to finish and stop the WSGI server.

        :param timeout: forwarded unchanged to the parent ``stop()``.
        """
        self._running = False
        super(NsdStatServer, self).stop(timeout)

    @property
    def application(self):
        """Lazily built Flask application exposing the ``/unistat/`` view."""
        if self._application is None:
            app = flask.Flask(__name__)
            app.add_url_rule('/unistat/', view_func=self._get_stats_view)
            self._application = app

        return self._application

    @staticmethod
    def _parse_stats(data):
        """Convert raw ``key=value`` lines into ``[name, value]`` pairs.

        Blank lines are skipped. Dots in a key are replaced with underscores
        and a suffix is appended depending on the key: ``_avvv`` (value
        parsed as float) for keys containing ``time_``, ``_ammv`` for
        ``size_``, ``_annx`` for ``zone_`` and ``_ammm`` otherwise (values
        parsed as int).
        NOTE(review): the suffixes look like aggregation signatures expected
        by the consumer of ``/unistat/`` -- confirm before changing them.

        :param data: text returned by ``NSDControl.stats()``.
        :returns: list of ``[name, value]`` pairs in input order.
        :raises ValueError: if a line has no ``=`` or its value is not a
            number of the expected kind.
        """
        stats = []
        for line in data.splitlines():
            line = line.strip()
            if not line:
                continue

            key, value = line.split('=', 1)
            key = key.replace('.', '_')

            if 'time_' in key:
                stats.append([key + '_avvv', float(value)])
            elif 'size_' in key:
                stats.append([key + '_ammv', int(value)])
            elif 'zone_' in key:
                stats.append([key + '_annx', int(value)])
            else:
                stats.append([key + '_ammm', int(value)])

        return stats

    def _set_current_stats(self):
        """Refresh ``_current_stats`` every ``stat.interval`` seconds.

        Runs until ``_running`` is cleared. Any failure, either of the
        control call or of parsing its output, is logged and resets the
        snapshot to an empty list; the loop then continues with the next
        cycle.
        """
        while self._running:
            next_start = time.time() + self._properties.stat.interval
            try:
                # Parsing is kept inside the ``try`` on purpose: an
                # unexpected line must not terminate this greenlet.
                stats = self._parse_stats(self._control.stats())
            except (gevent.Timeout, Exception):
                # gevent.Timeout does not derive from Exception, hence it is
                # listed explicitly.
                self._logger.exception('Stat aggregation failed.')
                self._current_stats = []
            else:
                self._current_stats = stats
                self._logger.info('Stat loaded.')
            finally:
                gevent.sleep(max(0, next_start - time.time()))

    def _set_current_zonestatus(self):
        """Refresh ``_current_zonestatus`` every ``stat.zonestatus_interval``.

        A zone is counted as stucked when its ``transfer`` field equals
        ``'waiting-for-TCP-fd'`` or when both ``commit_serial`` and
        ``served_serial`` are missing. When any stucked zones are found,
        ``_fix_stucked_zones`` is spawned for them in a separate greenlet.
        On failure the snapshot is reset to an empty list; after an
        ``IOError`` the next attempt is made in 5 seconds instead of the
        regular interval.
        """
        while self._running:
            next_start = time.time() + self._properties.stat.zonestatus_interval
            try:
                self._logger.info('Zonestatus loading...')
                stucked_zones = set()
                for count, zonedata in enumerate(self._control.zonestatus(), 1):
                    waiting_for_fd = \
                        zonedata.get('transfer') == 'waiting-for-TCP-fd'
                    never_loaded = (
                        zonedata.get('commit_serial') is None
                        and zonedata.get('served_serial') is None
                    )
                    if waiting_for_fd or never_loaded:
                        stucked_zones.add(zonedata['zone'])

                    # Let other greenlets run once per 50000 zones instead of
                    # switching to the hub on (almost) every iteration.
                    if count % 50000 == 0:
                        gevent.sleep()
            except (gevent.Timeout, Exception) as e:
                self._logger.exception('Zonestatus aggregation failed.')
                self._current_zonestatus = []
                if isinstance(e, IOError):
                    # Retry sooner after an I/O level failure.
                    next_start = time.time() + 5
            else:
                self._current_zonestatus = [
                    ['stucked_zones_ammx', len(stucked_zones)],
                ]
                self._logger.info('Zonestatus loaded.')

                if stucked_zones:
                    gevent.spawn(self._fix_stucked_zones, stucked_zones)
            finally:
                gevent.sleep(max(0, next_start - time.time()))

    def _fix_stucked_zones(self, stucked_zones):
        """Re-create the given zones through the control client.

        Zones are processed in chunks of 100. Every chunk is deleted and
        added back with the ``'default'`` argument, retried according to
        ``StandardRetryPolicy(delay=1, max_attempts=3)``. Failures are logged
        and never propagated to the caller.

        :param stucked_zones: iterable of zone names.
        """
        stucked_zones = list(stucked_zones)
        retry_policy = StandardRetryPolicy(delay=1, max_attempts=3)
        chunk_size = 100
        try:
            self._logger.info('Fixing %d stucked zones...', len(stucked_zones))
            # Alternative kept for reference: a per-zone
            # ``self._control.retransfer(zone)`` under the same retry policy.
            for offset in range(0, len(stucked_zones), chunk_size):
                chunk = stucked_zones[offset:offset + chunk_size]
                for attempt in retry(retry_policy):
                    with attempt:
                        # NOTE(review): a retry after a failed addzones()
                        # repeats delzones() for zones that may already be
                        # deleted -- confirm that delzones() tolerates this.
                        self._control.delzones(chunk)
                        self._control.addzones(chunk, 'default')
        except Exception:
            self._logger.exception('Failed to fix stucked zones.')
        else:
            # Success path: there is no active exception to attach, so this
            # is a plain informational record.
            self._logger.info('Stucked zones fixed.')

    def _set_current_tcp6_fd_stats(self):
        """Refresh ``_current_tcp6_fd_stats`` every ``stat.interval`` seconds.

        The value is the number of lines in ``/proc/net/tcp6`` minus one
        (the header line). On failure the snapshot is reset to an empty
        list.
        """
        while self._running:
            next_start = time.time() + self._properties.stat.interval
            try:
                with open('/proc/net/tcp6') as f:
                    tcp6_fd_count = sum(1 for _ in f) - 1
            except (gevent.Timeout, Exception):
                self._logger.exception('TCP6 FD stat aggregation failed.')
                self._current_tcp6_fd_stats = []
            else:
                self._current_tcp6_fd_stats = [
                    ['tcp6_fd_count_ammx', tcp6_fd_count],
                ]
                self._logger.info('TCP6 FD stat loaded.')
            finally:
                gevent.sleep(max(0, next_start - time.time()))

    def _get_stats_view(self):
        """Flask view: return all current snapshots as one JSON response.

        The snapshots are concatenated in the order zonestatus, tcp6, stats.
        """
        return flask.jsonify(
            self._current_zonestatus +
            self._current_tcp6_fd_stats +
            self._current_stats
        )
