import argparse
import collections
import json
import logging
import math
import os
import pyuwsgi
import random
import sys
import time
import threading

from flask import Flask, request, request, jsonify, Response
from multiprocessing import Manager, Lock, Value
from threading import RLock

from config import Config
from datetime import timedelta
from mongo import QloudClient, MongoClient
from mongo_discovery import connect_to_location
from queue import Empty, Full
from threading import current_thread


# Prefix marking per-host entries in the shared cache dict, which keeps them
# apart from bookkeeping keys such as 'stats' or 'segment_cache_ready'.
FQDN_KEY_PREFIX = 'f_'
# Property names a client may request via /v1/host(s)/<properties>.
SUPPORTED_KEYS = {'installation', 'segment', 'loc'}


class ShardedData(object):
    """State shared between the worker processes of the service.

    Attributes:
        shared_dict: mapping created through ``manager.dict``; starts with the
            bookkeeping keys ``stats``, ``segment_cache_ready`` and
            ``last_successful_update``.
        http_errors: ``multiprocessing.Value`` integer counter (starts at 0).
    """

    def __init__(self, manager, queue_size=100):
        """Create the shared containers.

        :param manager: object providing ``dict(initial)``, e.g. a
            ``multiprocessing.Manager()``.
        :param queue_size: currently unused by this class.
        """
        initial_state = {
            "stats": {},
            'segment_cache_ready': False,
            'last_successful_update': None,
        }
        self.shared_dict = manager.dict(initial_state)
        # Signed C int with its own lock (Value's default lock=True).
        self.http_errors = Value('i', 0)


class Server(object):
    """Flask front-end for the fqdn -> ``<installation>.<segment>`` cache.

    Host entries live in ``shared_data.shared_dict`` under keys of the form
    ``FQDN_KEY_PREFIX + fqdn``. The same dict also carries bookkeeping keys
    (``stats``, ``segment_cache_ready``, ``last_successful_update``). The
    cache is (re)filled by ``_calculate_segment_mapping``, which ``make_app``
    registers as a uWSGI timer signal handler.
    """

    app = Flask(__name__)

    def __init__(self, config, clients, shared_data):
        """Store dependencies and read tunables from ``config``.

        :param config: object exposing ``app`` and ``db`` mappings.
        :param clients: iterable of installation clients; each must expose
            ``installation`` and ``client.db.box`` (see
            ``_fetch_segment_mapping``).
        :param shared_data: object exposing ``shared_dict`` and
            ``http_errors`` (see ``ShardedData``).
        """
        self.clients = clients
        self.shared_dict = shared_data.shared_dict
        self.shared_data = shared_data

        self.app.logger.setLevel(getattr(logging,
            config.app.get('loglevel', 'WARN')))
        self.logger = self.app.logger

        cache_settings = config.app.get('segment_cache', {})
        # Period of the uWSGI reload timer and upper bound of the random
        # sleep done by each reload, both in seconds.
        self.update_period = cache_settings.get('update_period', 600)
        self.update_jitter = cache_settings.get('update_jitter', 60)
        self.batch_size = config.db.get('batch_size', 100)
        stat_settings = config.app.get('stat', {})
        self.enable_stat = stat_settings.get('enabled', False)
        # Max number of latency samples kept per thread and per endpoint.
        self.request_queue_size = stat_settings.get('request_queue_size', 100)

    def _configure_routes(self):
        """Register request hooks and HTTP routes on ``self.app``.

        When stats are enabled, request latencies of ``/v1/host*`` endpoints
        are kept in bounded per-thread deques, and responses other than
        200/404 are counted in ``shared_data.http_errors``. ``/unistat``
        reports both and resets the error counter.
        """
        if self.enable_stat:
            def _put_thread_local_queue(request, name):
                # Samples are kept in a bounded deque stored on the current
                # thread object; /unistat collects them via threading.enumerate().
                thread_local = current_thread().__dict__
                queue_name = name + '_queue'
                if queue_name not in thread_local:
                    thread_local[queue_name] = collections.deque(
                        maxlen=self.request_queue_size)
                request.queue = thread_local[queue_name]

            def _register_stat_data():
                request.start_time = time.time()
                # '/v1/hosts' must be tested first: it also starts with '/v1/host'.
                if request.path.startswith('/v1/hosts'):
                    _put_thread_local_queue(request, 'hosts')
                elif request.path.startswith('/v1/host'):
                    _put_thread_local_queue(request, 'host')
                else:
                    request.queue = None

            def _process_stat_data(resp):
                time_spent = time.time() - request.start_time

                # TODO: performance bottleneck
                stat_queue = request.queue
                if stat_queue is not None:
                    stat_queue.append(time_spent)

                if resp.status_code not in (200, 404):
                    counter = self.shared_data.http_errors
                    with counter.get_lock():
                        counter.value += 1
                return resp

            self.app.before_request(_register_stat_data)
            self.app.after_request(_process_stat_data)

        def _percentile(N, percent, key=lambda x: x):
            """Linearly interpolated percentile of the sorted sequence ``N``.

            ``percent`` is a fraction in [0, 1]. Returns None for empty ``N``.
            """
            if not N:
                return None
            k = (len(N) - 1) * percent
            f = math.floor(k)
            c = math.ceil(k)
            if f == c:
                return key(N[int(k)])
            d0 = key(N[int(f)]) * (c - k)
            d1 = key(N[int(c)]) * (k - f)
            return d0 + d1

        def _latency_stat(messages, req_type, percentiles):
            """Append latency percentiles (ms) of ``req_type`` requests to ``messages``."""
            recent_data = []

            # TODO: unreliable
            queue_name = req_type + '_queue'
            for thread in threading.enumerate():
                samples = thread.__dict__.get(queue_name)
                if samples:
                    recent_data += list(samples)

            if recent_data:
                latencies = sorted(recent_data)
                for percentile in percentiles:
                    messages.append('["application.req_latency_ms.{}_p{}_axxx", {:.3f}]'.format(
                        req_type, percentile, 1000 * _percentile(latencies, percentile / 100.)))

        @self.app.route('/ping', methods=['GET'])
        def get_ping():
            # Return 200 as we alive
            return "pong", 200

        @self.app.route('/v1/hosts', methods=['GET'])
        @self.app.route('/v1/hosts/<properties>', methods=['GET'])
        def get_hosts_data(properties=None):
            requested = self._process_properties(properties)
            if not requested:
                # Report what the client asked for, not the (empty) filtered set.
                return self._error("unsupported properties: %s" % (properties, ))

            if not self.shared_dict['segment_cache_ready']:
                return self._error("cache is not ready yet", status=500)

            snapshot = {}
            try:
                for k, v in self.shared_dict.copy().items():
                    if k.startswith(FQDN_KEY_PREFIX):
                        result = self._generate_response(v, requested)
                        if result:
                            snapshot[k[len(FQDN_KEY_PREFIX):]] = result
                return jsonify(snapshot)
            except Exception as error:
                self.logger.exception('Fatal error "%s" while listing hosts', error)
                return self._error(str(error), status=500)

        @self.app.route('/v1/host/<fqdn>', methods=['GET'])
        @self.app.route('/v1/host/<fqdn>/<properties>', methods=['GET'])
        def get_host_data(fqdn, properties=None):
            requested = self._process_properties(properties)
            if not requested:
                # Report what the client asked for, not the (empty) filtered set.
                return self._error("unsupported properties: %s" % (properties, ), status=404)

            if not self.shared_dict['segment_cache_ready']:
                return self._error("cache is not ready yet", status=500)

            try:
                # Cache keys are stored lower-cased.
                fqdn = fqdn.lower()
                location = self.shared_dict.get(FQDN_KEY_PREFIX + fqdn, None)
                if location is None:
                    return self._error("no data for %s" % fqdn, status=404)

                return jsonify(self._generate_response(location, requested))
            except Exception as error:
                self.logger.exception('Fatal error "%s" while looking up host %s', error, fqdn)
                return self._error(str(error), status=500)

        @self.app.route('/unistat', methods=['GET'])
        def get_unistat():
            messages = [
                '["{}", {:.3f}]'.format(name, count or 0)
                for name, count in self.shared_dict['stats'].items()
            ]

            last_successful_update = self.shared_dict['last_successful_update']
            if last_successful_update is not None:
                messages.append(
                    '["segment_cache.last_successful_update_axxx", {:.3f}]'.format(
                        time.time() - last_successful_update))

            if self.enable_stat:
                _latency_stat(messages, 'host', [50, 75, 95, 99, 100])
                _latency_stat(messages, 'hosts', [50, 75, 95, 99, 100])

                counter = self.shared_data.http_errors
                with counter.get_lock():
                    # Report errors counted since the previous /unistat call.
                    messages.append('["application.http_errors_mxxx", {}]'.format(
                        counter.value))
                    counter.value = 0

            if not messages:
                return "[]"
            return "[\n" + ",\n".join(messages) + "\n]"

    def _process_properties(self, properties):
        """Turn a ``':'``-separated property string into a set of supported keys.

        ``None`` means "all supported keys". Unsupported names are dropped, so
        an empty result means nothing usable was requested.
        """
        if properties is None:
            return SUPPORTED_KEYS
        # split() on a string without ':' yields the string itself.
        return set(properties.split(':')) & SUPPORTED_KEYS

    def _generate_response(self, loc, properties):
        """Build the response for one cached ``"<installation>.<segment>"`` value.

        With a single requested property the bare value is returned; otherwise
        a dict with the requested keys is returned. Returns None when
        ``properties`` is empty.
        """
        if not properties:
            return None

        immediate = len(properties) == 1
        if immediate and 'loc' in properties:
            return loc

        result = {}
        if 'loc' in properties:
            result['loc'] = loc

        # Segments never contain '.' (enforced when filling the cache), so the
        # last dot separates installation from segment.
        installation, segment = loc.rsplit('.', 1)
        if 'segment' in properties:
            if immediate:
                return segment
            result['segment'] = segment
        if 'installation' in properties:
            if immediate:
                return installation
            result['installation'] = installation
        return result

    def _error(self, description, status=400):
        """Return a JSON ``{"error": description}`` response with ``status``."""
        return jsonify({'error': description}), status

    def _fetch_segment_mapping(self):
        """Read ``hardwareSegment`` of every box from every installation.

        :returns: tuple ``(mapping, failed_installations, invalid, conflicts)``.
            ``mapping`` maps ``FQDN_KEY_PREFIX + fqdn`` to
            ``"<installation>.<segment>"``. ``failed_installations`` lists
            installations whose query raised. ``invalid`` and ``conflicts``
            count the boxes that were skipped.
        """
        mapping = {}
        failed_installations = []
        invalid = 0
        conflicts = 0
        for client in self.clients:
            try:
                boxes = client.client.db.box \
                    .find({}, {'_id': 1, 'hardwareSegment': 1}) \
                    .batch_size(self.batch_size)

                for box in boxes:
                    fqdn = box['_id'].lower()
                    key = FQDN_KEY_PREFIX + fqdn
                    segment = box.get('hardwareSegment')
                    if not segment or '.' in segment:
                        # Stored values are split on '.', so a segment must
                        # not contain one.
                        invalid += 1
                        self.logger.warning(
                            'Invalid segment `%s` found for fqdn %s', segment, fqdn)
                    elif key in mapping:
                        # mapping is keyed by the prefixed fqdn; the first
                        # installation that reported the host wins.
                        conflicts += 1
                        old_installation = mapping[key].rsplit('.', 1)[0]
                        self.logger.warning(
                            'Decline value from installation %s for fqdn %s. '
                            'Fqdn already registered in %s',
                            client.installation, fqdn, old_installation)
                    else:
                        mapping[key] = '%s.%s' % (client.installation, segment)
            except Exception as e:
                # Keep going with the other installations; entries already
                # cached for this one are preserved by the caller.
                failed_installations.append(client.installation)
                self.logger.error("Unable to connect to %s: %s",
                                  client.installation, e, exc_info=True)
        return mapping, failed_installations, invalid, conflicts

    def _calculate_segment_mapping(self, _id):
        """uWSGI signal handler: refresh the shared fqdn -> segment cache.

        Loads the mapping from all installations and publishes it together
        with ``segment_cache.*`` stats. It then deletes cached fqdns that
        disappeared, except those whose value belongs to an installation that
        failed this round. The cache is flagged ready after the first round in
        which no installation failed, and stays ready afterwards.

        :param _id: signal number passed by uWSGI (unused).
        """
        start_time = time.time()
        result, failed_installations, segment_errors, segment_conflicts = \
            self._fetch_segment_mapping()

        stats = dict(self.shared_dict['stats'])
        stats['segment_cache.size_axxx'] = len(result)
        stats['segment_cache.conflicts_axxx'] = segment_conflicts
        stats['segment_cache.invalid_axxx'] = segment_errors
        stats['segment_cache.fill_duration_axxx'] = time.time() - start_time
        stats['segment_cache.failed_installations_axxx'] = len(failed_installations)

        update = dict(result)
        update['stats'] = stats
        if not failed_installations:
            update['last_successful_update'] = time.time()
        update['segment_cache_ready'] = bool(
            self.shared_dict['segment_cache_ready'] or not failed_installations)
        # Write only the changed keys instead of copying the whole shared dict
        # back. Writing it all back costs a round-trip per entry and could
        # overwrite concurrent changes with stale values.
        self.shared_dict.update(update)

        failed_prefixes = tuple(i + '.' for i in failed_installations)
        for key, value in self.shared_dict.copy().items():
            if (key.startswith(FQDN_KEY_PREFIX) and key not in result
                    and not value.startswith(failed_prefixes)):
                # TODO: too many potential locks?
                # pop() tolerates the key having been removed meanwhile.
                self.shared_dict.pop(key, None)

        # NOTE(review): this sleep runs after the cache was already written,
        # so it only delays this handler's return. If the intent was to spread
        # reloads across instances, it probably belongs before the fetch.
        jitter_time = random.uniform(0, self.update_jitter)
        time.sleep(jitter_time)
        self.logger.debug('%s fqdns loaded in %.2fs (%.2fs jitter)',
                          len(result), time.time() - start_time, jitter_time)

    def make_app(self):
        """Register routes and the cache-reload timer; return the Flask app.

        Must run inside uWSGI: it imports ``uwsgi`` and binds signal 55 to
        ``_calculate_segment_mapping``. The signal fires every
        ``update_period`` seconds and once more via a one-shot rb timer.
        """
        import uwsgi
        import uwsgidecorators

        self._configure_routes()
        self.logger.info('Creating app. Segment cache update period: %s, jitter: %s',
                         self.update_period, self.update_jitter)

        uwsgi.register_signal(55, "worker", self._calculate_segment_mapping)
        uwsgi.add_timer(55, self.update_period)
        uwsgi.add_rb_timer(55, 0, 1)
        return self.app


def _str_to_bool(value):
    """Convert an optional command-line flag value to ``bool``.

    ``type=bool`` cannot be used here because ``bool('False')`` is ``True``.
    The usual "false" spellings (and the empty string) map to False; anything
    else maps to True.
    """
    return value.strip().lower() not in ('', '0', 'false', 'f', 'no', 'n', 'off')


def parse_args(argv=None):
    """Parse the server's command-line options.

    :param argv: argument list to parse; ``None`` (default) means
        ``sys.argv[1:]``.
    :returns: ``argparse.Namespace`` with ``config`` (str) and ``debug`` (bool).
    """
    parser = argparse.ArgumentParser(
        description='Serve the fqdn -> hardware segment mapping.')
    parser.add_argument('-c', '--config', type=str,
                        default='config.yaml',
                        help='Config file (default config.yaml)')
    # A bare ``-d`` enables debug; ``-d false`` / ``-d 0`` explicitly disables it.
    parser.add_argument('-d', '--debug', type=_str_to_bool, nargs='?',
                        const=True, default=False,
                        help='Enable debug mode in Flask')
    return parser.parse_args(argv)


def app():
    """uWSGI entry point (``--module server:app()``): build the Flask app.

    Re-parses the command line and config, installs a stdout log handler on
    the root logger, and creates the ``Manager()``-backed shared state. It
    then builds one ``QloudClient`` per configured installation and returns
    ``Server.make_app()``. ``MONGO_PASSWORD`` and ``NANNY_TOKEN`` must be set
    in the environment (``KeyError`` otherwise).
    """
    config = Config(parse_args())

    log_level = getattr(logging, config.app.get('loglevel', 'WARN'))
    handler = logging.StreamHandler(sys.stdout)
    handler.setLevel(log_level)
    handler.setFormatter(
        logging.Formatter('%(levelname)-8s [%(asctime)-15s] %(message)s'))
    logging.getLogger().addHandler(handler)

    stat_settings = config.app.get('stat', {})
    shared_data = ShardedData(
        Manager(), stat_settings.get('request_queue_size', 100))

    db_password = os.environ['MONGO_PASSWORD']
    nanny_token = os.environ['NANNY_TOKEN']
    reconnect_timeout = config.db.get('reconnect_timeout')

    def make_connector(installation):
        # Bind ``installation`` now; the db user is looked up each time the
        # returned callable runs.
        def connect():
            return connect_to_location(
                installation, nanny_token, config.db.get('user'), db_password)
        return connect

    clients = [
        QloudClient(installation,
                    MongoClient(make_connector(installation), reconnect_timeout))
        for installation in config.db.get('installations')
    ]

    return Server(config, clients, shared_data).make_app()


def main():
    """Start the embedded uWSGI server with this module as the application.

    Our own command-line options are forwarded to the app through
    ``--pyargv`` so that ``app()`` can re-parse them. Extra ``uwsgi_args``
    from the config file are appended after them.
    """
    config = Config(parse_args())

    # NOTE(review): options are joined with plain spaces, so an argument that
    # itself contains a space would be split when forwarded — confirm.
    uwsgi_argv = ['--module', 'server:app()',
                  '--pyargv', ' '.join(sys.argv[1:])]
    uwsgi_argv.extend(config.app.get('uwsgi_args', []))
    pyuwsgi.run(uwsgi_argv)


# Script entry: launch uWSGI, which then loads this module via
# '--module server:app()' (see main()).
if __name__ == '__main__':
    main()
