from collections import defaultdict, namedtuple, Counter
from datetime import datetime, timedelta
import json
import msgpack
import sys
import subprocess

# noinspection PyUnresolvedReferences
from kernel.util.console import setProcTitle
# noinspection PyUnresolvedReferences
from kernel.util import logging
# noinspection PyUnresolvedReferences
from library.sky.hosts import braceExpansion

from libraries.resolver import resolve_hosts, resolve_instances

import flask
import gevent
import gevent.monkey
import gevent.event
import gevent.pywsgi

# Make blocking stdlib primitives cooperative under gevent (sockets, sleep,
# subprocess, ...). It runs before `import utils` below.
# NOTE(review): gevent recommends patching as early as possible; several
# modules above are imported before this call — confirm that is safe.
gevent.monkey.patch_all()

import utils

# Flask application. The routes below are registered on it, and main_web()
# serves it with gevent's WSGIServer.
app = flask.Flask(__name__)


class ResolverFutures(object):
    """Coalesces concurrent on-demand resolutions of the same term.

    The first caller of ``wait(term)`` spawns a ``_resolve`` greenlet. Callers
    arriving while it runs block on the same ``AsyncResult`` instead of starting
    another resolution. A successful result is stored msgpack-encoded in
    ``cache``.
    """

    def __init__(self, cache):
        super(ResolverFutures, self).__init__()
        self.futures = {}  # term -> AsyncResult of the in-flight _resolve greenlet
        self.cache = cache

    def wait(self, term):
        """Block until ``term`` has been resolved, then store the result in the cache.

        Returns nothing; callers read the result back via ``self.cache.find``.
        If the resolving greenlet raised, nothing is cached, so a later call
        retries the resolution.
        """
        future = self.futures.get(term)
        if future is None:
            future = gevent.event.AsyncResult()
            self.futures[term] = future
            gevent.spawn(_resolve, term).link(future)
        # AsyncResult.wait() returns the value, or None if the greenlet raised.
        res = future.wait()
        # Always forget the finished future, unless a newer one has already
        # replaced it. A leftover finished future would be reused once the
        # cache entry expires, serving its stale result instead of resolving.
        if self.futures.get(term) is future:
            del self.futures[term]
        # Do not cache a failure as if msgpack(None) were a valid payload.
        if future.successful() and self.cache.find(term) is None:
            self.cache.set(term, msgpack.dumps(res))


class Updater(gevent.Greenlet):
    """Background greenlet that keeps frequently requested terms warm in the cache.

    Each cycle re-resolves ``warmup_terms`` plus up to ``MAX_CACHE_SIZE`` of the
    most frequently requested ``C@`` terms. Each term runs in a child process,
    and the child's raw stdout is stored in ``self.cache``. After a cycle the
    greenlet sleeps 250 seconds. On-demand cache misses go through
    ``self.futures`` instead.
    """
    warmup_terms = [
        # placeholder
    ]
    MAX_CACHE_SIZE = 30  # max number of popular (non-warmup) terms re-resolved per cycle

    def __init__(self):
        super(Updater, self).__init__()
        # 'C@' term -> request count; counts are never decayed or pruned.
        self.most_frequently_asked_terms = Counter()
        self.cache = Cache()
        self.futures = ResolverFutures(self.cache)

    def update_frequent(self, term):
        """Count a request for ``term`` if it contains 'C@' and is not a warmup term."""
        if term not in self.warmup_terms and 'C@' in term:
            self.most_frequently_asked_terms.update((term,))

    def resolve(self, term):
        """Return the cached payload for ``term``, resolving it first on a miss.

        May return None if nothing usable ended up in the cache.
        """
        if self.cache.find(term) is None:
            self.futures.wait(term)
        return self.cache.find(term)

    def status(self):
        """Return the cache status mapping (term -> stringified write time)."""
        return self.cache.status()

    @staticmethod
    def async_resolve(term):
        """Resolve ``term`` in a child process and return the child's stdout.

        NOTE(review): this runs ``sys.executable --resolve-term TERM``, which
        presumably relies on ``sys.executable`` being this program's own bundled
        binary rather than a bare interpreter. Confirm.

        Raises:
            RuntimeError: if the child exits with a non-zero status.
        """
        proc = subprocess.Popen(
            [sys.executable, '--resolve-term', term],
            stdout=subprocess.PIPE,
        )
        # communicate() drains stdout while waiting for exit. Calling wait()
        # first and read() afterwards deadlocks once the child's output exceeds
        # the OS pipe buffer.
        output, _ = proc.communicate()
        if proc.returncode == 0:
            logging.info('resolved {}'.format(term))
            return output
        raise RuntimeError('Unable to resolve [{}]'.format(term))

    def _warmup(self):
        """Re-resolve the warmup terms and the most popular requested terms."""
        # most_common(n) already caps n at the number of distinct terms.
        popular = [t for t, _ in self.most_frequently_asked_terms.most_common(self.MAX_CACHE_SIZE)]
        for term in self.warmup_terms + popular:
            try:
                self.cache.set(term, self.async_resolve(term))
            except RuntimeError:
                # Keep any previous cache entry and advance to the next term.
                logging.exception('Unable to resolve [{}]'.format(term))

    def _run(self):
        """Greenlet body: run a warmup cycle, then sleep 250 seconds, forever."""
        while True:
            try:
                self._warmup()
            except Exception as ex:
                # A failed cycle must not kill the background greenlet.
                logging.getLogger('resolver').error(ex)
            finally:
                gevent.sleep(250)


def _resolve(term):
    """Resolve ``term`` via ``_do_resolve`` and normalize the host names.

    Host names have '.search.yandex.net' and '.yandex.ru' removed.

    Returns:
        ``(hosts, instances)``. ``hosts`` is a sorted list of normalized host
        names. ``instances`` maps each normalized host to
        ``{i.split('@', 1)[0]: s}`` for every ``(s, i)`` pair that
        ``_do_resolve`` reported for that host.

    Any exception from ``_do_resolve`` is logged, and the result is then empty.
    """
    try:
        start = datetime.now()
        hosts, instances = _do_resolve(term)
        # total_seconds(): timedelta.microseconds is only the sub-second
        # component, so it under-reports any resolve taking >= 1 second.
        logging.info('resolved [%s] in %.2f sec',
                     term, (datetime.now() - start).total_seconds())
    except Exception:  # TODO(okats): re-raise when clst is ready for 503s
        logging.exception('resolve [%s] failed', term)
        hosts, instances = set(), dict()

    def norm(x):
        return x.replace('.search.yandex.net', '').replace('.yandex.ru', '')

    norm_instances = defaultdict(dict)
    for h, v in instances.iteritems():
        for (s, i) in v:
            i = i.split('@', 1)[0]  # keep only the part before the first '@'
            norm_instances[norm(h)][i] = s

    norm_hosts = sorted(map(norm, hosts))
    return norm_hosts, norm_instances


def _do_resolve(term):
    """Resolve ``term`` into a raw ``(hosts, instances)`` pair.

    The term is first passed through ``braceExpansion``. 'C@' terms other than
    'C@ONLINE...' go to ``hq_resolve``. Everything else uses
    ``resolve_instances`` and ``resolve_hosts``. Instance resolution is skipped
    (empty dict) for terms starting with 'H@', 'h@', 'd@' or 'dc@'.
    """
    expanded = braceExpansion([term], True)

    # eligible for hq resolving
    if term.startswith('C@') and not term.startswith('C@ONLINE'):
        return hq_resolve(expanded)

    no_instance_prefixes = ('H@', 'h@', 'd@', 'dc@')
    instances = {} if term.startswith(no_instance_prefixes) else resolve_instances(expanded)
    # Yield to other greenlets between the two resolver calls.
    gevent.sleep()
    hosts = resolve_hosts(expanded)
    return hosts, instances


def hq_resolve(term):
    """Resolve ``term`` with the sky ``Resolver`` in HQ mode.

    ``term`` is the ``braceExpansion`` output that ``_do_resolve`` passes in.

    Returns:
        ``(hosts, instances)``. ``hosts`` is ``Resolver.resolveHosts(term)``
        unchanged. ``instances`` maps every key of
        ``Resolver.resolveInstances(term)`` to a set of pairs. Both members of
        each pair are shortened with ``utils.shortname``, and ``None`` becomes
        the string 'none'.
    """
    # noinspection PyUnresolvedReferences
    from library.sky.hostresolver.resolver import Resolver

    def shorten(name):
        if name is None:
            return 'none'
        return utils.shortname(name)

    resolver = Resolver(False, False, use_hq=True)
    hosts = resolver.resolveHosts(term)
    raw_instances = resolver.resolveInstances(term)

    instances = {}
    for key, pairs in raw_instances.iteritems():
        instances[key] = {(shorten(left), shorten(right)) for left, right in pairs}
    return hosts, instances


class Cache(object):
    """Key/value store whose entries go stale ``timeout`` after being set.

    There is no size bound, and stale entries are never removed. ``find``
    ignores them, but ``__len__`` and ``status`` still count them.
    """

    Item = namedtuple('Item', ['time', 'value'])

    def __init__(self, timeout=timedelta(minutes=5)):
        self.cache = {}         # key -> Item(time written, stored value)
        self.timeout = timeout  # how long an entry stays fresh after set()

    def find(self, key):
        """Return the value stored under ``key``, or None if it is absent or stale."""
        entry = self.cache.get(key)
        if entry is None:
            return None
        expiry = entry.time + self.timeout
        return None if expiry < datetime.now() else entry.value

    def set(self, key, value):
        """Store ``value`` under ``key``, stamped with the current local time."""
        self.cache[key] = Cache.Item(time=datetime.now(), value=value)

    def status(self):
        """Return ``{key: str(write time)}`` for every entry, stale ones included."""
        return dict((key, str(entry.time)) for key, entry in self.cache.iteritems())

    def __len__(self):
        """Number of stored entries, stale ones included."""
        return len(self.cache)


# Module-level Updater created at import time and shared by the HTTP handlers.
# Its background warmup greenlet is started in main_web().
updater = Updater()


@app.route('/')
def get_main():
    """Serve the cached resolver payload for the ``term`` query argument.

    A cache miss triggers an on-demand resolution. If that still leaves
    nothing in the cache, a 503 is returned.
    """
    term = flask.request.args.get('term', '')
    updater.update_frequent(term)
    data = updater.cache.find(term)
    if data is None:
        data = updater.resolve(term)
        if data is None:
            logging.warn('resolve [%s] left nothing in the cache', term)
            return flask.Response('Unable to resolve. Probably configuration does not exist', status=503)
        # Log only the size: the payload is binary and can be large.
        logging.info('resolved [%s] on demand (%d bytes)', term, len(data))
    return flask.Response(data)


@app.route('/status')
def get_status():
    """Return ``updater.status()`` as pretty-printed (4-space indented) JSON."""
    snapshot = updater.status()
    body = json.dumps(snapshot, indent=4)
    return flask.Response(body)


def main_term(term):
    """Resolve ``term`` synchronously and return the result msgpack-encoded."""
    resolved = _resolve(term)
    return msgpack.dumps(resolved)


def main_web(port):
    """Run the resolver HTTP service on ``port``, listening on '::'.

    Registers an access-log hook on ``app``, sets the process title, starts
    the background ``updater`` greenlet, and then serves ``app`` forever.
    """
    def log_request(response):
        # One access-log line per request: client, method, URL, status code.
        fields = (
            flask.request.remote_addr,
            flask.request.method,
            flask.request.url,
            str(response.status_code),
        )
        app.logger.info(' '.join(fields))
        return response

    app.after_request(log_request)
    setProcTitle('resolver [{}]'.format(port))

    updater.start()
    gevent.pywsgi.WSGIServer(('::', port), app).serve_forever()
