from collections import defaultdict, namedtuple, Counter
from concurrent.futures import ProcessPoolExecutor
from datetime import datetime, timedelta
import json
import msgpack
import sys

from kernel.util.console import setProcTitle
from kernel.util import logging
from library.sky.hosts import braceExpansion

import tornado.gen
import tornado.ioloop
import tornado.web

from libraries.resolver import resolve_hosts, resolve_instances
from utils import configure_log


class WebApp(tornado.web.Application):
    """Tornado application that resolves host/instance expressions on demand.

    Resolutions are submitted to process pools; each finished result is
    msgpack-encoded and stored in ``self.cache`` under the requested term.
    A periodic job re-resolves ``warmup_terms`` plus the most frequently
    requested ``C@`` terms so that their cache entries stay fresh.
    """

    # Terms re-resolved on every warm-up cycle; the list was taken from
    # request logs.
    warmup_terms = [
        'I@MSK_L7_BALANCER',
        'I@production_upper_msk_web . I@MSK_WEB_NMETA',
        'I@production_noapache_msk_web . I@MSK_WEB_NOAPACHE',
        'I@production_mmeta',
        'I@production_int',
        'I@production_base - I@nostart',
        'I@SAS_L7_BALANCER',
        'I@production_upper_sas_web',
        'I@production_noapache_sas_web',
        'I@production_sas_mmeta',
        'I@production_sas_int',
        'I@production_sas_base',
        'I@production_upper_man_web',
        'I@production_noapache_man_web',
        'I@production_man_mmeta',
        'I@production_man_int',
        'I@production_man_base',
        'I@AMS_WEB_BALANCER',
        'I@production_upper_msk_imgs',
        'I@production_imgmmeta . I@a_geo_msk',
        'I@production_imgsint . I@a_geo_msk',
        'I@production_imgsbase . I@a_geo_msk',
        'I@production_upper_sas_imgs',
        'I@production_imgmmeta . I@a_geo_sas',
        'I@production_imgsint . I@a_geo_sas',
        'I@production_imgsbase . I@a_geo_sas',
        'C@HEAD',
        'I@ALL_SEARCH',
    ]

    # Request counts for 'C@' terms that are not in warmup_terms; the top 10
    # are also refreshed by _warmup. NOTE(review): this is a class attribute,
    # so it is shared by every WebApp instance.
    most_frequently_asked_terms = Counter()

    def __init__(self, port):
        super(WebApp, self).__init__(handlers=handlers)

        self.cache = Cache()
        self.port = port

        # _warmup runs once immediately on start() and then every 250000 ms.
        self.periodic_update = PeriodicCall(self._warmup, 250000)
        # Separate pools: warm-up jobs and user requests do not share workers.
        self.warmup_executor = ProcessPoolExecutor(max_workers=4)
        self.request_executor = ProcessPoolExecutor(max_workers=8)

        self.log = logging.getLogger('resolver_app')

    def start(self):
        """Kick off warm-up, bind the port and block in the IOLoop.

        Ctrl-C (KeyboardInterrupt) stops the warm-up timer and waits for both
        process pools to shut down.
        """
        self.periodic_update.start()
        self.log.info('listening on %s', self.port)
        self.listen(self.port, address='0.0.0.0')

        loop = tornado.ioloop.IOLoop.current()
        try:
            loop.start()
        except KeyboardInterrupt:
            self.log.info('shutting down')
            self.periodic_update.stop()
            for pool in (self.warmup_executor, self.request_executor):
                pool.shutdown(wait=True)
            loop.stop()

    def status(self):
        """Return the cache's {term: stored-at timestamp string} report."""
        return self.cache.status()

    def resolve_async(self, term):
        """Submit *term* for resolution on the request pool.

        'C@' terms outside warmup_terms are counted so that popular ones get
        refreshed by the periodic warm-up.

        :return: the executor future; the cache is filled by _resolve_done.
        """
        if 'C@' in term and term not in self.warmup_terms:
            self.most_frequently_asked_terms[term] += 1
        return self._resolve_async(term, self.request_executor)

    def _resolve_async(self, term, executor):
        """Submit _resolve(term) to *executor*, caching the result when done."""
        future = executor.submit(_resolve, term)
        # Stored on the future so that _resolve_done knows the cache key.
        future.term = term
        future.add_done_callback(self._resolve_done)
        return future

    def _warmup(self):
        """Re-resolve warmup_terms plus up to 10 most requested terms."""
        counts = self.most_frequently_asked_terms
        top = counts.most_common(min(len(counts), 10))
        for term in self.warmup_terms + [name for name, _ in top]:
            self._resolve_async(term, self.warmup_executor)

    def _resolve_done(self, future):
        """Serialize a finished resolution with msgpack and store it in the cache.

        NOTE(review): concurrent.futures runs done-callbacks in the thread
        that completes the future, so this likely executes off the IOLoop
        thread — confirm Cache tolerates that.
        """
        packed = msgpack.dumps(future.result())
        self.cache.set(future.term, packed)


def _resolve(term):
    """Resolve *term* into normalized hosts and per-host instances.

    Submitted to a ProcessPoolExecutor by WebApp, so it runs in a worker
    process. Any ``Exception`` raised by the lookup is logged and replaced by
    an empty result, so the caller always receives a value.

    :param term: resolver expression, e.g. ``'I@production_int'``.
    :return: ``(hosts, instances)``:
        - ``hosts`` is a list of host names;
        - ``instances`` maps host -> {instance name: value}.
        In both, the ``.search.yandex.net`` and ``.yandex.ru`` substrings are
        removed from host names.
    """
    log = logging.getLogger('resolver')
    # noinspection PyBroadException
    try:
        start = datetime.now()
        hosts, instances = _do_resolve(term)
        # total_seconds(): timedelta.microseconds is only the sub-second
        # component, so any resolve taking >= 1 s used to be under-reported.
        log.info('resolved [%s] in %.2f sec', term, (datetime.now() - start).total_seconds())
    except Exception:
        # Exception rather than a bare except, so KeyboardInterrupt and
        # SystemExit are no longer swallowed inside the worker.
        log.exception('resolve [%s] failed', term, exc_info=sys.exc_info())
        hosts, instances = set(), {}

    def norm(name):
        return name.replace('.search.yandex.net', '').replace('.yandex.ru', '')

    norm_instances = defaultdict(dict)
    # NOTE(review): assumes every host maps to an iterable of
    # (value, instance) pairs whose instance part may carry an '@' suffix.
    # Confirm this against resolve_instances.
    for host, pairs in instances.items():
        for value, instance in pairs:
            norm_instances[norm(host)][instance.split('@', 1)[0]] = value

    # A real list rather than map(): on Python 3, map() is a lazy iterator
    # that cannot be serialized for the caller.
    norm_hosts = [norm(host) for host in hosts]
    return norm_hosts, norm_instances


def _do_resolve(term):
    """Expand *term* and look up its hosts and (when applicable) instances.

    Terms with an 'H@', 'h@', 'd@' or 'dc@' prefix skip the instance lookup
    and get an empty instance mapping.

    :return: ``(hosts, instances)`` as produced by resolve_hosts /
        resolve_instances.
    """
    command = braceExpansion([term], True)

    hosts_only = term.startswith(('H@', 'h@', 'd@', 'dc@'))
    # Instances are looked up before hosts, matching the original call order.
    instances = {} if hosts_only else resolve_instances(command)

    hosts = resolve_hosts(command)
    return hosts, instances


class Cache(object):
    """Key/value store whose entries expire after a fixed timeout.

    Expired entries are treated as missing by find() but are never removed:
    the store has no size limit and no garbage collection.
    """

    Item = namedtuple('Item', ['time', 'value'])

    def __init__(self, timeout=timedelta(minutes=5)):
        """:param timeout: age after which an entry is reported as missing."""
        self.cache = {}
        self.timeout = timeout

    def find(self, key):
        """Return the value stored for *key*, or None if absent or expired."""
        entry = self.cache.get(key)
        if entry is None:
            return None
        if datetime.now() > entry.time + self.timeout:
            return None
        return entry.value

    def set(self, key, value):
        """Store *value* under *key*, stamping it with the current time."""
        self.cache[key] = Cache.Item(time=datetime.now(), value=value)

    def status(self):
        """Return {key: str(stored-at time)} for every entry, expired or not."""
        return dict((key, str(entry.time)) for key, entry in self.cache.items())

    def __len__(self):
        return len(self.cache)


class PeriodicCall(tornado.ioloop.PeriodicCallback):
    """Run the callback now, and then run it periodically if not stopped.

    Unlike the base class, ``start()`` invokes the callback once,
    synchronously, before scheduling the periodic timer.
    """

    def __init__(self, callback, callback_time):
        """
        :param callback: zero-argument callable to run.
        :param callback_time: period passed unchanged to tornado's
            PeriodicCallback (milliseconds there).
        """
        super(PeriodicCall, self).__init__(callback, callback_time)
        self.stopped = False

    def start(self):
        """Run the callback immediately, then schedule it unless stopped."""
        self.callback()
        # The callback may itself have called stop(); honour that.
        if not self.stopped:
            super(PeriodicCall, self).start()

    def stop(self):
        """Prevent further runs, including an already-scheduled timer.

        Previously this only set ``stopped``. Once ``start()`` had scheduled
        the base-class timer, the callback kept firing indefinitely. The base
        ``stop()`` is safe to call before ``start()``.
        """
        self.stopped = True
        super(PeriodicCall, self).stop()


class MainHandler(tornado.web.RequestHandler):
    """Serve the cached (msgpack-encoded) resolution of the ``term`` argument.

    On a cache miss, the term is resolved through the application and the
    freshly cached bytes are written.
    """

    @tornado.gen.coroutine
    def get(self, *args, **kwargs):
        app = self.application
        term = self.get_argument('term', '', True)

        payload = app.cache.find(term)
        if payload is None:
            # The application's done-callback fills the cache before this
            # coroutine resumes, so re-read the entry afterwards.
            yield app.resolve_async(term)
            payload = app.cache.find(term)

        self.write(payload)


class StatusHandler(tornado.web.RequestHandler):
    """Report the application's cache status as indented JSON."""

    @tornado.gen.coroutine
    def get(self):
        snapshot = self.application.status()
        body = json.dumps(snapshot, indent=4)
        self.write(body)


# Routing table passed to the WebApp constructor: (path, handler class).
handlers = [('/', MainHandler), ('/status', StatusHandler)]


def main():
    """Parse CLI options, configure logging, and run the resolver web app."""
    args = parse_args()
    configure_log(app='resolver', debug=args.debug, beta=args.beta)
    setProcTitle('resolver [{}]'.format(args.port))

    app = WebApp(args.port)
    if not args.warmup:
        # --no-warmup used to be parsed and then ignored. WebApp reads its
        # warm-up list through ``self.warmup_terms``, so an empty
        # instance-level list shadows the class default. Frequently requested
        # terms are still refreshed periodically once they have been asked for.
        app.warmup_terms = []
    app.start()


def parse_args():
    """Parse the resolver service's command-line options.

    :return: argparse.Namespace with ``port`` (int), ``warmup``, ``debug`` and
        ``beta`` attributes.
    """
    import argparse

    parser = argparse.ArgumentParser(description='Clusterstate resolver service')
    # type=int: without it, an explicit ``-p 8080`` arrived as the string
    # '8080' (only the default was an int) and non-numeric values were not
    # rejected at parse time.
    parser.add_argument('-p', '--port', type=int, default=20214, help='listen port')
    parser.add_argument('--warmup', dest='warmup', action='store_true', help='[default] warm up before port opening')
    parser.add_argument('--no-warmup', dest='warmup', action='store_false', help="don't warm up")
    parser.add_argument('--debug', dest='debug', default=False, action='store_true', help='enable fancy logging to stdout')
    parser.add_argument('--beta', dest='beta', default=False, action='store_true', help='run as beta (affects logging)')
    parser.set_defaults(warmup=True)
    return parser.parse_args()


if __name__ == '__main__':
    # Start the service only when run as a script, not when imported as a module.
    main()
