#!/usr/bin/env python
# -*- coding: utf-8 -*-

import sys
import os
import json
import requests
import hashlib
import logging
import subprocess
import socket
import six

logger = logging.getLogger(__name__)
HAPROXY_SOCKET = '/run/haproxy-admin.sock'


class ServiceDiscoveryError(Exception):
    """Raised when a service-discovery reply does not describe the requested endpoint set.

    The raw decoded reply is passed as the exception argument for diagnostics.
    """


class ServiceDiscovery:
    """Resolve one endpoint set in one cluster through the SD HTTP API."""

    BASE_URL = 'http://sd.yandex.net:8080'
    CLIENT_NAME = 'telemost-sd'

    def __init__(self, endpoint, dc, only_ready=False):
        # endpoint: endpoint set id to resolve; dc: cluster name to query.
        # only_ready: when True, resolve() drops endpoints not marked ready.
        self.endpoint = endpoint
        self.dc = dc
        self.only_ready = only_ready

    def get_data(self):
        """Return the JSON payload for the resolve_endpoints request."""
        return {
            'cluster_name': self.dc,
            'endpoint_set_id': self.endpoint,
            'client_name': self.CLIENT_NAME,
        }

    def request(self):
        """Query the SD service and return the decoded JSON reply.

        Raises whatever ``requests`` raises on network errors, and
        ``requests.HTTPError`` for non-2xx responses (via raise_for_status).
        """
        url = self.BASE_URL + '/resolve_endpoints/json'
        resp = requests.get(url, data=json.dumps(self.get_data()), timeout=2)
        resp.raise_for_status()
        return resp.json()

    def resolve(self):
        """Return the list of endpoint dicts for this endpoint set.

        Raises ServiceDiscoveryError (carrying the raw reply) when the reply
        has no endpoint set or describes a different endpoint set id.
        """
        sd = self.request()
        endpoint_set = sd.get('endpoint_set')
        if not endpoint_set or endpoint_set.get('endpoint_set_id') != self.endpoint:
            raise ServiceDiscoveryError(sd)
        # An endpoint set with no endpoints may omit the key entirely, and an
        # endpoint may omit 'ready'; treat both as empty / not ready rather
        # than crashing with KeyError.
        endpoints = endpoint_set.get('endpoints') or []
        if self.only_ready:
            return [endpoint for endpoint in endpoints if endpoint.get('ready')]
        return list(endpoints)


class HAProxy:
    """Minimal client for the HAProxy runtime API exposed on a UNIX socket."""

    def __init__(self, socket_path=HAPROXY_SOCKET):
        # Filesystem path of the admin UNIX socket.
        self.socket_path = socket_path

    def _execute(self, command):
        """Send one command and return the reply split into lines.

        Raises socket.timeout / socket.error (OSError on Python 3) when the
        socket cannot be reached or the reply does not arrive within 0.1s.
        """
        # Create the socket outside the try so the finally below never sees
        # an unbound name if socket creation itself fails.
        unix_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        try:
            unix_socket.settimeout(0.1)
            unix_socket.connect(self.socket_path)
            # sendall: send() may transmit only part of the buffer.
            unix_socket.sendall(six.b(command + '\n'))
            file_handle = unix_socket.makefile()
            try:
                return file_handle.read().splitlines()
            finally:
                file_handle.close()
        finally:
            unix_socket.close()

    def show_backend(self):
        """Return backend names reported by ``show backend``.

        Blank lines and '#' header lines are dropped. Returns [] when the
        socket cannot be used.
        """
        try:
            lines = self._execute('show backend')
        except socket.error:
            return []
        backends = [line for line in lines if line and not line.startswith('#')]
        logger.info('backends from haproxy: %s', ','.join(backends))
        return backends


class App:
    """Keep the generated HAProxy backend config in sync with service discovery.

    Endpoint sets are resolved for every DC at construction time; run()
    rewrites the config file and restarts HAProxy when needed.
    """

    def __init__(self, dcs, enpoints, config_path):
        # dcs: cluster names; enpoints: endpoint set ids (attribute name kept
        # as-is for backward compatibility).
        self.dcs = dcs
        self.enpoints = enpoints
        # Resolved once here; every later step works from this snapshot.
        self.services = self.get_services()
        self.config_path = config_path
        self.haproxy = HAProxy()

    def get_services(self):
        """Resolve every endpoint set in every DC and return all endpoints.

        Endpoint sets failing with ServiceDiscoveryError are skipped (and
        logged) so the remaining ones are still used.
        """
        services = []
        for dc in self.dcs:
            for endpoint in self.enpoints:
                sd = ServiceDiscovery(endpoint, dc)
                try:
                    services.extend(sd.resolve())
                except ServiceDiscoveryError as err:
                    logger.warning('cannot resolve %s in %s: %s', endpoint, dc, err)
        logger.debug('got services: %s', services)
        return services

    def _expected_backends(self):
        """Return the backend names generate_config() would write right now.

        Unresolvable fqdns are excluded exactly as generate_config() excludes
        them; comparing against all fqdns would report a change (and restart
        HAProxy) on every run while any host fails to resolve.
        """
        return {service['fqdn'] for service in self.services
                if self.check_dns(service['fqdn'])}

    @property
    def backend_changed(self):
        """True if HAProxy's backends differ from the expected backend names.

        Backends whose name starts with 'generate_' are ignored.
        """
        current = {name for name in self.haproxy.show_backend()
                   if not name.startswith('generate_')}
        return current != self._expected_backends()

    def write(self):
        """Overwrite the config file at config_path with a freshly generated config."""
        logger.info('update config')
        with open(self.config_path, 'w') as fp:
            fp.write(self.generate_config())

    def restart(self):
        """Restart the 'backend' supervisor program and log its exit status."""
        cmd = ["supervisorctl", "restart", "backend"]
        r = subprocess.call(cmd)
        logger.info('restart retval: %d', r)

    def need_run(self):
        """Decide whether the config must be rewritten and HAProxy restarted."""
        if not self.services:
            logger.error('no services from endpoint')
            return False
        # A missing admin socket is treated as "HAProxy is not running".
        if not os.path.exists(self.haproxy.socket_path):
            return True
        if self.backend_changed:
            logger.warning('backends changed')
            return True
        return False

    def run(self):
        """Rewrite the config and restart HAProxy if need_run() says so."""
        if self.need_run():
            self.write()
            self.restart()
        else:
            logger.info('up to date')

    @staticmethod
    def check_dns(hostname):
        """Return True if *hostname* resolves via getaddrinfo, else False."""
        # No socket timeout is set here: socket timeouts do not apply to
        # getaddrinfo, and setdefaulttimeout would change a process-wide default.
        try:
            socket.getaddrinfo(hostname, 0)
        except (socket.error, ValueError, TypeError):
            # socket.error covers gaierror; ValueError covers IDNA encoding
            # failures (UnicodeError) for malformed names.
            return False
        return True

    def generate_config(self):
        """Build one HAProxy backend section per service whose fqdn resolves.

        NOTE(review): unresolvable names are presumably skipped because
        HAProxy cannot load a server address that does not resolve - confirm.
        """
        template = """
backend %(server)s
        mode http
        balance roundrobin
        option httpclose
        option forwardfor
        server pod %(server)s:80 resolvers dns
"""
        sections = []
        for service in self.services:
            fqdn = service['fqdn']
            if not self.check_dns(fqdn):
                continue
            sections.append(template % {'server': fqdn})
        return ''.join(sections)


def env_list(name, default=None):
    """Return the comma-separated environment variable *name* as a list.

    Items are stripped and empty items dropped. *default* is used when the
    variable is unset; an unset variable with no default yields [].
    """
    raw = os.environ.get(name, default)
    if not raw:
        return []
    return [item.strip() for item in raw.split(',') if item.strip()]


if __name__ == '__main__':
    logging.basicConfig(
        stream=sys.stdout, level=logging.INFO,
        format='%(asctime)s [%(levelname)s] [%(module)s:%(funcName)s] %(message)s'
    )
    logging.getLogger(__name__).setLevel(logging.DEBUG)
    DCS = env_list('XMPP_DCS', 'sas,vla,man')
    ENDPOINTS = env_list('XMPP_ENDPOINTS')
    CONFIG_PATH = os.environ.get('HAPROXY_CONFIG', '').strip()
    if not ENDPOINTS or not CONFIG_PATH:
        sys.exit('XMPP_ENDPOINTS and HAPROXY_CONFIG environment variables are required')

    app = App(DCS, ENDPOINTS, CONFIG_PATH)
    app.run()
