import sys
import json
import pprint
import logging as builtin_logging
from datetime import datetime, timedelta
from collections import defaultdict, namedtuple

import msgpack
import tornado.gen
import tornado.ioloop
import tornado.web
from concurrent.futures import ProcessPoolExecutor

from utils import fqdn
from kernel.util import logging
from library.sky.hosts import braceExpansion


from libraries.resolver import resolve_hosts, resolve_instances_with_meta


class Service(tornado.web.Application):
    """Tornado application that resolves terms in a process pool and caches the results."""

    def __init__(self, port):
        super(Service, self).__init__(handlers=[
            ('/', MainHandler),
            ('/status', StatusHandler),
            ('/v1/?', MainHandler_V1),
            ('/v1.0/?', MainHandler_V1),
        ])

        # term -> Cache.Item(time, (hosts, instances)); filled by _resolve_done.
        self.cache = Cache()
        self.port = port

        # Resolution (_resolve) runs in separate worker processes.
        self.request_executor = ProcessPoolExecutor(max_workers=8)

        self.log = logging.getLogger('resolver_app')

    def start(self):
        """Listen on ``self.port`` and run the current IOLoop until interrupted."""
        self.log.info('listening on %s', self.port)
        self.listen(self.port)

        try:
            tornado.ioloop.IOLoop.current().start()
        except KeyboardInterrupt:
            self.log.info('shutting down')
            self.request_executor.shutdown(wait=True)
            tornado.ioloop.IOLoop.current().stop()

    def status(self):
        """Return the cache's ``{term: str(set_time)}`` snapshot."""
        return self.cache.status()

    def resolve_async(self, term):
        """Submit *term* for resolution in the worker pool.

        :returns: the executor future; the cache is refreshed when it completes
            successfully.
        """
        return self._resolve_async(term, self.request_executor)

    def _resolve_async(self, term, executor):
        """Submit ``_resolve(term)`` to *executor* and hook cache population."""
        future = executor.submit(_resolve, term)
        # Remember the term so the done-callback knows which cache key to fill.
        future.term = term
        future.add_done_callback(self._resolve_done)
        return future

    def _resolve_done(self, future):
        """Store a finished resolve result in the cache.

        Invoked as a ``concurrent.futures`` done-callback (possibly on an
        executor-internal thread rather than the IOLoop thread). Cancelled or
        failed futures are not cached. Callers awaiting the future still see
        the error; here it is only logged under this service's logger instead
        of being raised inside the callback machinery.
        """
        if future.cancelled():
            return
        error = future.exception()
        if error is not None:
            self.log.error('resolve [%s] failed in worker: %r', future.term, error)
            return
        self.cache.set(future.term, future.result())


def _resolve(term):
    """Resolve *term* into hosts and normalized instances.

    Submitted to the service's process pool, so it runs in a worker process.
    Errors raised by the resolver are logged and replaced by an empty result,
    so the caller always receives a cacheable value.

    :param term: resolver expression as received from the client.
    :returns: ``(hosts, instances)``. ``hosts`` is a list. ``instances`` is the
        resolver's dict with its ``'instances'`` entry replaced by the output of
        ``_normalize_instances``. That entry is empty when the resolver
        returned none, e.g. for host-only terms or on failure.
    """
    # noinspection PyBroadException
    try:
        start = datetime.now()
        hosts, instances = _do_resolve(term)
        # total_seconds() covers the whole duration; .microseconds is only the
        # sub-second component and under-reports any resolve longer than 1 s.
        logging.getLogger('resolver').info(
            'resolved [%s] in %.2f sec',
            term,
            (datetime.now() - start).total_seconds()
        )
    except Exception:
        # Exception (not bare except) so KeyboardInterrupt/SystemExit in the
        # worker are not turned into a cached empty result.
        logging.getLogger('resolver').exception('resolve [%s] failed', term, exc_info=sys.exc_info())
        hosts, instances = set(), {}

    # Host-only terms and the failure fallback carry no 'instances' key; treat
    # that as "no instances" instead of raising KeyError here.
    instances['instances'] = _normalize_instances(instances.get('instances', {}))

    return list(hosts), instances


def _normalize_instances(instances):
    norm_instances = defaultdict(dict)
    for h, v in instances.iteritems():
        for (s, i) in v:
            i = i.split('@', 1)[0]
            norm_instances[h][i] = s
    return norm_instances


def _do_resolve(term):
    """Brace-expand *term* and resolve it to hosts and, when applicable, instances.

    Terms starting with ``H@``, ``h@``, ``d@`` or ``dc@`` skip instance
    resolution and get an empty dict for instances.

    :returns: ``(hosts, instances)`` from ``resolve_hosts`` and
        ``resolve_instances_with_meta`` (or ``{}``).
    """
    expanded = braceExpansion([term], True)
    skip_instances = term.startswith(('H@', 'h@', 'd@', 'dc@'))
    # Instances are resolved before hosts, matching the original call order.
    instances = {} if skip_instances else resolve_instances_with_meta(expanded)
    return resolve_hosts(expanded), instances


class Cache(object):
    """Dict with expiration. No size limit, no garbage collecting.

    Entries older than ``timeout`` are reported as missing by the lookup
    methods, but they are never removed; a later ``set`` overwrites them.
    """

    Item = namedtuple('Item', ['time', 'value'])

    def __init__(self, timeout=timedelta(minutes=5)):
        # key -> Item(time=datetime.now() at set(), value=stored value)
        self.cache = {}
        self.timeout = timeout

    def find(self, key):
        """Return the value stored under *key*, or None if absent or expired."""
        item = self.find2(key)
        return None if item is None else item.value

    def find2(self, key):
        """Return the full ``Item`` (time and value) for *key*, or None if absent or expired."""
        # Single lookup; stored entries are always Items, never None.
        item = self.cache.get(key)
        if item is None or item.time + self.timeout < datetime.now():
            return None
        return item

    def set(self, key, value):
        """Store *value* under *key*, stamped with the current local time."""
        self.cache[key] = Cache.Item(value=value, time=datetime.now())

    def status(self):
        """Return ``{key: str(set_time)}`` for every entry, expired ones included."""
        return {item[0]: str(item[1].time) for item in self.cache.iteritems()}

    def __len__(self):
        """Number of stored entries, expired ones included."""
        return len(self.cache)


class PeriodicCall(tornado.ioloop.PeriodicCallback):
    """Run the callback now, and then run it periodically if not stopped.

    ``start()`` invokes the callback once synchronously. The periodic schedule
    is set up afterwards only if ``stop()`` was not called before or during
    that first run.
    """

    def __init__(self, callback, callback_time):
        super(PeriodicCall, self).__init__(callback, callback_time)
        self.stopped = False

    def start(self):
        """Run the callback immediately, then schedule it unless stopped."""
        self.callback()
        if not self.stopped:
            super(PeriodicCall, self).start()

    def stop(self):
        """Prevent further runs, whether or not periodic scheduling has begun."""
        self.stopped = True
        # The flag alone only prevents scheduling; once the base class is
        # running, it must be stopped explicitly or the callback keeps firing.
        super(PeriodicCall, self).stop()


class MainHandler(tornado.web.RequestHandler):
    """Legacy endpoint: return the cached ``(hosts, instances)`` tuple for ``term`` as msgpack."""

    @tornado.gen.coroutine
    def get(self, *args, **kwargs):
        """Handle ``GET /?term=...``, resolving the term first when it is not cached."""
        app = self.application
        term = self.get_argument('term', '', True)

        cached = app.cache.find(term)
        if not cached:
            # Miss or expired: wait for the worker pool; its completion refreshes the cache.
            yield app.resolve_async(term)
            cached = app.cache.find(term)

        self.write(msgpack.dumps(cached))


class MainHandler_V1(tornado.web.RequestHandler):
    """v1 API: return the resolved instances for ``term`` plus backend metadata."""

    @tornado.gen.coroutine
    def get(self, *args, **kwargs):
        """Handle ``GET /v1?term=...&format=...``.

        The response has the shape
        ``{'instances': [{'host', 'port', 'extra': {'shard'}}],
        'meta': {'backend', 'versions', 'timestamp'}}``. It is serialized as
        JSON when ``format`` is ``'json'`` (the default) and as msgpack
        otherwise.
        """
        import time  # local import: the module only imports names from datetime

        term = self.get_argument('term', '', True)
        fmt = self.get_argument('format', 'json', True)
        data = self.application.cache.find2(term)
        if data is None:
            # Miss or expired: resolve now; the completion callback refreshes the cache.
            future = self.application.resolve_async(term)
            yield future
            data = self.application.cache.find2(term)

        result = {'instances': []}
        _, instances = data.value
        # The full payload is debugging detail, not an error. Log it at DEBUG
        # and skip the pformat cost unless DEBUG is enabled.
        if builtin_logging.getLogger().isEnabledFor(builtin_logging.DEBUG):
            builtin_logging.debug(pprint.pformat(data.value))

        # Host-only terms and failed resolves may lack 'instances'/'meta'.
        for host, instance in instances.get('instances', {}).iteritems():
            for instance_str, shard in instance.iteritems():
                result['instances'].append({
                    'host': host,
                    'port': int(instance_str.split(':')[1]),
                    'extra': {
                        'shard': shard,
                    },
                })

        result['meta'] = {
            'backend': {
                'host': fqdn(),
                'port': self.application.port,
            },
            'versions': instances.get('meta'),
            # Portable replacement for strftime('%s'), which is a glibc-only
            # extension: interpret the naive local cache time as epoch seconds.
            'timestamp': int(time.mktime(data.time.timetuple())),
        }

        self.write((json if fmt == 'json' else msgpack).dumps(result))


class StatusHandler(tornado.web.RequestHandler):
    """Report the application's ``status()`` mapping as indented JSON."""

    @tornado.gen.coroutine
    def get(self):
        """Handle ``GET /status``."""
        snapshot = self.application.status()
        body = json.dumps(snapshot, indent=4)
        self.write(body)
