#!/usr/bin/env python3

import os
import socket
import random
import hashlib
from urllib import parse as urlparse
from kazoo.client import KazooClient, KazooState
from flask import Flask, request
import requests
import time
import logging
import sys

# Emit TSKV-formatted log lines to stdout. basicConfig sets the root logger to
# INFO (and is a no-op if the root logger already has handlers). This module's
# logger is set to DEBUG; records propagated from it are not filtered by the
# root logger's level, so its debug messages still reach the stdout handler.
# NOTE(review): %(created)s is the record creation time in float *seconds*,
# although the field is named unixtime_ms — confirm what the log consumer expects.
logging.basicConfig(
    format='\t'.join((
        'tskv',
        'tskv_format=telemost-zkservice',
        'timestamp=%(asctime)s',
        'unixtime_ms=%(created)s',
        'level=%(levelname)s',
        'message=%(message)s',
    )),
    level=logging.INFO,
    stream=sys.stdout,
)
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


class ZKService:
    """Pick a backend host for a room using ZooKeeper buckets or the mediator.

    A room id is hashed into one of ``buckets`` buckets. Each bucket's backend
    is read from the ZooKeeper node ``<zk_prefix>/<bucket>``. A bucket whose
    value is ``'mediator'`` delegates the lookup to the mediator HTTP API.
    Some results are sentinel strings (``'generate_400'``, ``'generate_410'``,
    ``'generate_504'``) rather than host names. They are presumably interpreted
    by the proxy calling this service (TODO confirm).
    """

    QUERY_PARAM = 'room'

    def __init__(
        self,
        zk_hosts,
        buckets=20, zk_prefix='/telemost/backend',
        mediator_url='https://mediator.rtc.yandex.net:8443'
    ):
        self.zk_hosts = zk_hosts
        self.buckets = buckets
        self.zk_prefix = zk_prefix
        self.mediator_url = mediator_url
        # This class only reads from ZooKeeper, so read-only servers are acceptable.
        self.zk = KazooClient(hosts=self.zk_hosts, read_only=True)
        self.zk.start()
        # backend -> unix time of its last successful /ping (0 = not cached).
        self.alive_backends = {}
        # bucket -> last backend value read from ZooKeeper. It is served as a
        # fallback whenever ZooKeeper cannot be read.
        self.zk_backends = {}
        for bucket in range(self.buckets):
            self.zk_backends[bucket] = self.zk_get(bucket)

    def get_zk_path(self, bucket):
        """Return the ZooKeeper node path that stores ``bucket``'s backend."""
        return '%s/%s' % (self.zk_prefix, bucket)

    def zk_get(self, bucket):
        """Return the backend for ``bucket``, refreshing it from ZooKeeper.

        When the client is connected, the node is re-read and a non-empty value
        replaces the cached one. If not connected, the read fails, or the node
        is empty, the last cached value is returned (None if never seen).
        """
        zk_path = self.get_zk_path(bucket)
        if self.zk.state == KazooState.CONNECTED:
            try:
                val, _stat = self.zk.get(zk_path)
            except Exception as e:
                # Best effort: fall back to the cached value below, but leave a trace.
                logger.warning('zk read failed for %s: %s' % (zk_path, e))
                val = b''
            if val:
                backend = val.decode('utf-8')
                logger.debug('got from zk: %s' % backend)
                self.zk_backends[bucket] = backend
        return self.zk_backends.get(bucket)

    def _get_hash(self, val):
        """Return a process-independent integer hash of string ``val``."""
        # sha256 rather than hash(): str hashing is randomized per process.
        return int(hashlib.sha256(val.encode('utf-8')).hexdigest(), 16)

    def _get_query_param(self, query):
        """Return the first ``room`` value in ``query``, or None if absent/blank.

        ``query`` is parsed with ``parse_qs`` as a bare query string.
        NOTE(review): a full URI ("/path?room=x") would make the first key
        "/path?room" — confirm callers pass only the query part.
        """
        params = urlparse.parse_qs(query)
        try:
            return params[self.QUERY_PARAM][0]
        except (KeyError, IndexError):
            return None

    def get_hash_qs(self, query):
        """Hash the room in ``query``; if there is no room, return a random int in 0..99."""
        room = self._get_query_param(query)
        if room:
            return self._get_hash(room)
        return random.randint(0, 99)

    def get_bucket(self, uri):
        """Map the room in ``uri`` to a bucket index in ``range(self.buckets)``."""
        return self.get_hash_qs(uri) % self.buckets

    def _mediator_request_url(self, room):
        """Build the mediator sfu_endpoint URL for ``room`` (URL-encoded)."""
        # room comes from the client's query string: encode it so characters
        # like '&', '#' or '=' cannot alter the mediator request.
        return '%s/room/sfu_endpoint?room_id=%s' % (
            self.mediator_url, urlparse.quote(room, safe=''))

    def get_from_mediator(self, room):
        """Ask the mediator which host serves ``room``.

        Returns:
            The host name obtained by reverse-resolving the ``host`` field of
            the mediator's JSON reply; ``'generate_410'`` if the mediator
            answers 404; or None if ``room`` is empty or all 3 attempts fail
            with a connection/timeout/HTTP error.

        Errors from JSON decoding or name resolution propagate to the caller.
        """
        if not room:
            logger.warning('room: "%s"' % room)
            return None
        url = self._mediator_request_url(room)
        for attempt in range(1, 4):
            try:
                logger.info('make request: %s' % url)
                resp = requests.get(url, timeout=(0.2, 1))
                if resp.status_code == 404:
                    logger.info('mediator %s reply: "%s" (time: %s)' % (
                        url, resp.status_code, resp.elapsed.total_seconds()
                    ))
                    return 'generate_410'
                resp.raise_for_status()
                break
            except (
                requests.exceptions.ConnectionError,
                requests.exceptions.ConnectTimeout,
                requests.exceptions.ReadTimeout,
                requests.exceptions.HTTPError,
            ) as e:
                logger.error('exception while request to mediator (room: %s) (try %d): %s' % (room, attempt, e))
        else:
            # Every attempt raised; there is no usable response.
            return None
        payload = resp.json()
        logger.info('mediator %s reply: "%s" (time: %s)' % (
            url, payload, resp.elapsed.total_seconds()
        ))
        address = payload.get('host')
        # setdefaulttimeout is process-global. Always restore the previous value,
        # even when resolution fails.
        # NOTE(review): gethostbyaddr uses the system resolver and may not honor
        # this timeout at all — confirm it actually bounds the lookup.
        previous_timeout = socket.getdefaulttimeout()
        socket.setdefaulttimeout(1)
        try:
            hostname = socket.gethostbyaddr(address)[0]
        finally:
            socket.setdefaulttimeout(previous_timeout)
        return hostname

    def get_backend_mediator(self, uri):
        """Return the mediator-chosen backend for ``uri``.

        Any error from the lookup is logged. A missing result maps to
        ``'generate_504'``.
        """
        room = self._get_query_param(uri)
        logger.info('call get_from_mediator: uri: "%s"; room: "%s"' % (uri, room))
        try:
            backend = self.get_from_mediator(room)
        except Exception as e:
            backend = None
            logger.error('mediator error: %s' % e)
        if backend:
            logger.info('return %s (mediator) for uri "%s"' % (
                backend, uri))
            return backend
        logger.warning('return default backend for uri "%s": generate_504' % uri)
        return 'generate_504'

    def get_backend(self, uri):
        """Choose the backend for the room in ``uri``.

        Returns ``'generate_400'`` when there is no room. If the room's bucket
        is delegated to ``'mediator'``, returns the mediator's answer. Otherwise
        it returns the first alive backend, probing the room's bucket and then
        the following buckets (wrapping around). If none is alive, it returns
        the backend of the room's own bucket.
        """
        room = self._get_query_param(uri)
        if not room:
            logger.warning('return generate_400 for uri "%s"' % uri)
            return 'generate_400'
        bucket = self.get_bucket(uri)
        backend = self.zk_get(bucket)
        if backend == 'mediator':
            # Query the mediator only when its answer is actually used: it is a
            # network call with retries.
            mediator_backend = self.get_backend_mediator(uri)
            logger.info('return %s (bucket: %s) for uri "%s" (from mediator)' % (
                    mediator_backend, bucket, uri))
            return mediator_backend
        dead = set()
        for _ in range(self.buckets):
            # Several buckets may share a backend. Do not re-ping one that
            # already failed during this request.
            if backend not in dead:
                if self.check_alive(backend):
                    logger.info('return %s (bucket: %s) for uri "%s"' % (
                        backend, bucket, uri))
                    return backend
                dead.add(backend)
                logger.warning('dead backend: %s' % backend)
            bucket = (bucket + 1) % self.buckets
            backend = self.zk_get(bucket)
        # After self.buckets steps the index has wrapped back to the original bucket.
        logger.info(
            'return last %s (bucket: %s) for uri "%s"' % (backend, bucket, uri))
        return backend

    def check_alive(self, backend):
        """Return True if ``backend`` answered ``GET /ping`` recently or now.

        A success is cached for 20 seconds. Otherwise up to 3 pings are made
        (0.2 s timeout each, 0.05 s pause between attempts).
        """
        if time.time() - self.alive_backends.get(backend, 0) < 20:
            logger.debug('%s: check alive cached True' % backend)
            return True
        url = 'http://%s/ping' % backend
        attempts = 3
        for attempt in range(attempts):
            try:
                r = requests.get(url, timeout=0.2)
                r.raise_for_status()
            except Exception:
                # Any failure (network, HTTP status, bad host) means "not alive
                # this try". Exception (not bare except) keeps Ctrl-C/SystemExit working.
                self.alive_backends[backend] = 0
                logger.info("try %d failed for %s" % (attempt, backend))
                if attempt < attempts - 1:
                    time.sleep(0.05)
                continue
            self.alive_backends[backend] = time.time()
            return True
        return False


# Module-level singleton built at import time (ZKService.__init__ starts the
# ZooKeeper client). ZK_HOSTS is mandatory (KeyError if unset). ZK_PREFIX and
# MEDIATOR_URL fall back to the same defaults as ZKService.__init__.
service = ZKService(
    zk_hosts=os.environ['ZK_HOSTS'],
    zk_prefix=os.environ.get('ZK_PREFIX', '/telemost/backend'),
    mediator_url=os.environ.get(
        'MEDIATOR_URL', 'https://mediator.rtc.yandex.net:8443'
    ),
)

app = Flask(__name__)


@app.route('/')
def backend():
    """Respond with the backend chosen for the ``uri`` query argument.

    A missing ``uri`` argument is passed to ``get_backend`` as None.
    NOTE(review): ``get_backend`` can return None when ZooKeeper never had a
    value for the buckets. Flask rejects None as a view result, which gives a
    500 — confirm whether a sentinel such as generate_504 is wanted instead.
    """
    requested_uri = request.args.get('uri')
    chosen = service.get_backend(requested_uri)
    return chosen
