import functools
import hashlib
import json
import os
import random
import socket
import threading
import time
from collections import defaultdict, deque
from datetime import datetime

import flask
import msgpack
import yaml
from pymongo import MongoClient

from genisys.web.model import MongoStorage
from genisys.web import stats


def nodeinfo_headers_wrapper(node_info, wsgi_app):
    """Wrap a WSGI app so every response carries ``X-Genisys-*`` headers.

    Each ``key: value`` of *node_info* becomes one header named
    ``X-Genisys-<Key-In-Title-Case>`` (underscores turned into dashes) with
    ``str(value)`` as its value.  The header list is computed once, when the
    wrapper is built, and appended to the headers of every response.
    """
    extra_headers = []
    for field, field_value in node_info.items():
        header_name = 'X-Genisys-{}'.format(field.replace('_', '-').title())
        extra_headers.append((header_name, str(field_value)))

    def make_injecting_start_response(start_response):
        # Returns a start_response that appends our headers (in place) before
        # delegating to the server-provided one.
        def start_response_with_node_info(status, response_headers,
                                          exc_info=None):
            response_headers.extend(extra_headers)
            return start_response(status, response_headers, exc_info)
        return start_response_with_node_info

    @functools.wraps(wsgi_app)
    def wrapped(environ, start_response):
        return wsgi_app(environ, make_injecting_start_response(start_response))

    return wrapped


def make_app():
    """Create and configure the Genisys API Flask application.

    Configuration is loaded from ``genisys.web.config`` and then overridden
    by the file named in the ``GENISYS_API_CONFIG`` environment variable
    (``from_envvar`` fails if it is unset).  The returned app has:

    * ``app.storage`` -- a ``MongoStorage`` over ``MONGODB_URI`` /
      ``MONGODB_DB_NAME``;
    * ``app.tree_structure`` -- an already started ``TreeStructure`` cache
      thread;
    * ``app.cleanup`` -- a callable stopping that thread, which the caller is
      expected to invoke on shutdown;
    * a catch-all error handler turning unexpected exceptions into 500s;
    * a WSGI wrapper adding ``X-Genisys-*`` node headers to every response.
    """
    app = flask.Flask('genisys.api')

    app.config.from_object('genisys.web.config')
    app.config.from_envvar('GENISYS_API_CONFIG')

    stats.set_up(app)

    mongoclient = MongoClient(app.config['MONGODB_URI'])
    app.storage = MongoStorage(mongoclient, app.config['MONGODB_DB_NAME'])

    app.tree_structure = TreeStructure(
        app.storage, app.logger.getChild('cache'),
        cache_validity_range=app.config['API_CACHE_VALIDITY_RANGE'],
    )
    app.tree_structure.start()

    # hivemind api emulation
    app.add_url_rule('/v1/hosts/<hostname>/skynet',
                     defaults={'from_path': '', 'hivemind_emulation': True},
                     view_func=host_config, methods=['GET'])
    app.add_url_rule('/v1/hosts/<hostname>/skyversion',
                     view_func=host_skyversion_hivemind, methods=['GET'])
    app.add_url_rule('/v1/rules',
                     view_func=rules_by_path, methods=['GET'])
    app.add_url_rule('/v1/skynet/version/rules',
                     view_func=skyversion_rules_hivemind, methods=['GET'])
    # native api
    app.add_url_rule('/v2/hosts/<hostname>',
                     defaults={'from_path': '', 'hivemind_emulation': False},
                     view_func=host_config, methods=['GET'])
    app.add_url_rule('/v2/hosts/<hostname>/<from_path>',
                     defaults={'hivemind_emulation': False},
                     view_func=host_config, methods=['GET'])
    app.add_url_rule('/v2/hosts-by-path-and-rulename',
                     view_func=hosts_by_path_and_rulename)
    # health check
    app.add_url_rule('/ping', view_func=ping, methods=['GET'])

    node_info = {'hostname': socket.gethostname(),
                 'pid': os.getpid(),
                 'address': app.config['WSGI_LISTEN_ADDRESS']}
    node_info_str = json.dumps(node_info, sort_keys=True)

    @app.errorhandler(Exception)
    def server_error(error):
        # Since Flask 1.0 a handler registered for Exception also receives
        # werkzeug HTTPExceptions raised by flask.abort() (404, 400, ...).
        # Pass those through untouched instead of logging them as crashes
        # and answering 500.  They are recognised by duck typing (an HTTP
        # status ``code`` plus ``get_response``) to avoid importing werkzeug
        # directly.
        if getattr(error, 'code', None) is not None \
                and callable(getattr(error, 'get_response', None)):
            return error
        app.logger.exception(error)
        text = "Internal server error: %s\nNode info: %s" % (error,
                                                             node_info_str)
        return text, 500

    def stop_aside_threads():
        app.tree_structure.stop()

    app.cleanup = stop_aside_threads

    app.wsgi_app = nodeinfo_headers_wrapper(node_info, app.wsgi_app)

    return app


def ping():
    """Health-check endpoint.

    Answers with an empty 200 response while the tree cache is considered
    fresh by ``TreeStructure.is_up_to_date()``, and with 503 otherwise.
    """
    app = flask.current_app
    if app.tree_structure.is_up_to_date():
        return app.response_class(status=200)
    return app.response_class(status=503)


def host_skyversion_hivemind(hostname):
    """Hivemind-compatible view of the skynet version assigned to a host.

    Samples the ``skynet.versions`` section for *hostname* and reshapes the
    native section document into the legacy hivemind ``SKYNET_BINARY``
    layout.

    Aborts with 404 when the section does not exist or no rule of it applies
    to the host. Both cases used to crash with a 500 on ``None[...]``.
    """
    nts = flask.current_app.tree_structure
    section, mtime = nts.sample(hostname, from_path='skynet.versions')
    if section is None or section.get('config') is None:
        flask.abort(404)
    cfg = section['config']
    result = {
        'SKYNET_BINARY': {
            'attrs': cfg['attributes'],
            'http': cfg['http']['links'],
            'md5': cfg['md5'],
            'name': cfg['description'],
            'rsync': cfg['rsync']['links'],
            # ``>> 10`` divides by 1024 (presumably bytes -> KiB; confirm)
            'size': cfg['size'] >> 10,
            'skynet': cfg['skynet_id']},
        'conf_author': section['changed_by'],
        'conf_id': str(section['revision']),
        'conf_mtime': section['mtime'],
        'svn_url': cfg['attributes'].get('svn_url')
    }
    return _response_serialize(result, last_modified=mtime)


def _rules_response(path, resource_key):
    """Render every rule of section *path* in the legacy hivemind format.

    Shared implementation of ``skyversion_rules_hivemind`` and
    ``rules_by_path``. The two endpoints differ only in the key under which
    a rule's ``attributes`` are published (*resource_key*).

    Rules are listed by config id: non-negative ids in ascending order, then
    id ``-1`` (the fallback ``_get_config_id`` uses for unlisted hosts) if
    present.  Each rule carries the sorted list of hosts explicitly assigned
    to it.  The section's ``mtime`` is used as ``Last-Modified`` and is also
    sent as the ``X-Ya-Cms-Conf-Id`` header.

    Aborts with 404 when *path* is unknown to the tree cache.
    """
    rules = flask.current_app.tree_structure.all_rules(path)
    if rules is None:
        # all_rules() returns a bare None (not a 3-tuple) for unknown
        # sections; unpacking it used to crash the request with a 500.
        flask.abort(404)
    node, hosts, configs = rules

    config_ids = sorted(k for k in configs if k >= 0)
    if -1 in configs:
        config_ids.append(-1)

    result = []
    hosts_by_config = {}
    for config_id in config_ids:
        rule = configs[config_id]
        cfg = rule['config']
        hosts_list = hosts_by_config[config_id] = []
        result.append({
            'name': rule['matched_rules'][0],
            'desc': '',
            'sandbox_task_name': cfg['description'],
            'resolved': {
                'resource': {
                    resource_key: {'attrs': cfg['attributes']},
                    'svn_url': cfg['attributes'].get('svn_url'),
                },
                'hosts': hosts_list,
            }
        })
    for hostname in sorted(hosts):
        hosts_list = hosts_by_config.get(hosts[hostname])
        # A host may reference a config id that is not listed above (e.g. a
        # negative id other than -1, or one missing from ``configs``); skip
        # it instead of failing the whole request with a KeyError.
        if hosts_list is not None:
            hosts_list.append(hostname)

    last_modified = node['mtime']
    response = _response_serialize(result, last_modified=last_modified)
    response.headers['X-Ya-Cms-Conf-Id'] = str(last_modified)
    return response


def skyversion_rules_hivemind():
    """Hivemind emulation: all rules of the ``skynet.versions`` section."""
    # TODO: remove for the great good when lacmus2 is out there
    return _rules_response('skynet.versions', 'SKYNET_BINARY')


def rules_by_path():
    """All rules of the section named by the ``path`` query parameter.

    Aborts with 400 when ``path`` is missing or empty.
    """
    path = flask.request.args.get('path')
    if not path:
        flask.abort(400)
    return _rules_response(path, 'UNKNOWN')


def host_config(hostname, from_path, hivemind_emulation):
    """Return the configuration of *hostname* below section *from_path*.

    :param hostname: host to sample the tree for.
    :param from_path: dotted section path to start from ('' = whole tree).
    :param hivemind_emulation: produce the legacy hivemind layout (a dict
        keyed by section path) instead of the native annotated tree.

    Aborts with 404 when *from_path* does not exist.  In native mode a
    ``config_hash`` query parameter equal to the sample's combined hash
    yields an empty 304 response.
    """
    result, mtime = flask.current_app.tree_structure.sample(
        hostname, from_path, hivemind_emulation
    )
    if result is None:
        flask.abort(404)
    # Only native samples carry a top-level 'config_hash'.  Hivemind results
    # are keyed by section path, so looking the key up there raised KeyError
    # (a 500) whenever a client sent ?config_hash= to the v1 endpoint.
    if not hivemind_emulation:
        client_hash = flask.request.args.get('config_hash')
        if client_hash is not None and client_hash == result['config_hash']:
            return flask.current_app.response_class(status=304)
    return _response_serialize(result, last_modified=mtime)


def hosts_by_path_and_rulename():
    """Return ``{'hosts': [...]}``: hosts whose config in section ``path``
    was produced by rule ``rulename`` (both read from the query string).

    Aborts with 400 when either parameter is missing or empty, like
    ``rules_by_path`` does.  Previously a missing parameter silently produced
    ``{'hosts': null}``.  ``hosts`` is still ``None`` when the section is not
    cached or none of its configs was matched by that rule.
    """
    path = flask.request.args.get('path')
    rulename = flask.request.args.get('rulename')
    if not path or not rulename:
        flask.abort(400)
    hosts = flask.current_app.tree_structure.hosts_by_path_and_rulename(
        path, rulename
    )
    return _response_serialize({'hosts': hosts}, None)


class TreeStructure(threading.Thread):
    """Background thread keeping an in-memory snapshot of the section tree.

    The snapshot is built once in ``__init__`` (an exception there aborts
    construction) and then rebuilt by ``run()`` every
    ``random.randint(*cache_validity_range)`` seconds until ``stop()``.
    It consists of:

    * ``_cached_structure`` -- the whole section tree annotated with each
      section's status metadata (see ``_make_structure``), serialized with
      msgpack so every reader ``loads`` its own freely mutable copy;
    * ``_cached_configs`` -- section path -> that section's volatile
      ``value``, a dict with ``'hosts'`` (hostname -> config id) and
      ``'configs'`` (config id -> config entry), as read by
      ``_get_config_id``;
    * ``_cached_timestamp`` -- ``time.time()`` of the last successful
      rebuild.

    ``_cached_configs`` is only ever added to or overwritten, never pruned,
    so it may hold paths that are no longer present in the structure.
    """
    STATUS_KEYS = "ctime etime mtime utime ttime " \
                  "tcount ucount mcount last_status".split()
    NULL_CONFIG_HASH = '0000000000000000000000000000000000000000'

    def __init__(self, storage, logger, cache_validity_range):
        super(TreeStructure, self).__init__()
        # (low, high) bounds in seconds, passed to random.randint
        self.cache_validity_range = cache_validity_range
        # is_up_to_date() tolerates the longest refresh interval + 10s
        self.ping_cache_validity = cache_validity_range[-1] + 10
        self.logger = logger
        self.storage = storage
        self._cached_timestamp = None
        self._cached_structure = None
        self._cached_configs = {}
        # section path -> last status read from storage; gains a 'value'
        # key once the section has been fetched with full=True
        self._sec_volatiles = {}
        self._recache()
        # a fresh Event is already unset
        self._stopped = threading.Event()

    def run(self):
        """Refresh the snapshot at random intervals until ``stop()``."""
        while not self._stopped.is_set():
            time_to_sleep = random.randint(*self.cache_validity_range)
            self.logger.info('sleeping for %ds', time_to_sleep)
            # Event.wait is an interruptible sleep: stop() wakes us early
            self._stopped.wait(timeout=time_to_sleep)
            if self._stopped.is_set():
                break
            try:
                self._recache()
            except Exception:
                # Keep serving the previous snapshot and retry next cycle;
                # is_up_to_date() turns false if failures persist.  A bare
                # ``except:`` would also swallow SystemExit/KeyboardInterrupt.
                self.logger.exception('failed to update cache')

    def stop(self):
        """Ask the refresh loop to exit and wait for the thread to finish."""
        self.logger.info('stopping')
        self._stopped.set()
        self.join()
        self.logger.info('stopped')

    def _recache(self):
        """Rebuild the snapshot, fetching full values only for changed sections."""
        self.logger.info('starting updating cache')
        root = self.storage.get_section_subtree("", structure_only=True)
        all_paths = [section['path'] for section in self._iter_sections(root)]
        sec_volatiles = self.storage.get_volatiles(
            vtype='section', keys=all_paths, strict=False, full=False
        )
        paths_to_fetch = []
        for path, status in sec_volatiles.items():
            if path not in self._sec_volatiles:
                # never seen before: need the full value
                paths_to_fetch.append(path)
                continue
            old_status = self._sec_volatiles[path]
            if status['ttime'] is None:
                continue
            if old_status['ttime'] is not None \
                    and old_status['ttime'] >= status['ttime']:
                # don't allow cached value to go back in time
                continue
            if old_status['mtime'] != status['mtime']:
                # cached value is outdated, need to fetch
                # actual value of the section
                paths_to_fetch.append(path)
            else:
                # need to update metadata only
                old_status.update(status)

        if not paths_to_fetch:
            self.logger.info("no section value has changed")
        else:
            self.logger.info("updating values for sections: %s",
                             paths_to_fetch)
            self._sec_volatiles.update(self.storage.get_volatiles(
                vtype='section', keys=paths_to_fetch, strict=False, full=True
            ))
            # configs are updated before the structure is swapped in below,
            # so a reader holding either structure finds its paths here
            self._cached_configs.update({
                path: self._sec_volatiles[path]['value']
                for path in paths_to_fetch
            })
        structure = self._make_structure(root, self._sec_volatiles)
        self._cached_structure = msgpack.dumps(structure, encoding='utf8')
        self._cached_timestamp = time.time()

    def hosts_by_path_and_rulename(self, path, rulename):
        """Return hosts of section *path* whose config lists *rulename*.

        Returns ``None`` when the section is not cached or none of its
        configs has *rulename* in ``matched_rules``.
        """
        cfg = self._cached_configs.get(path)
        if not cfg:
            return None
        configs = cfg['configs']
        cfg_ids = set(key for key in configs
                      if rulename in configs[key]['matched_rules'])
        if not cfg_ids:
            return None
        return [host for host, cfg_id in cfg['hosts'].items()
                if cfg_id in cfg_ids]

    def _get_config_id(self, status_value, hostname):
        """Return the config id that applies to *hostname*, or ``None``.

        Hosts absent from ``status_value['hosts']`` fall back to id ``-1``;
        an id missing from ``status_value['configs']`` yields ``None``.
        """
        if status_value is None:
            return None
        config_id = status_value['hosts'].get(hostname)
        if config_id is None:
            config_id = -1
        if config_id in status_value['configs']:
            return config_id
        return None

    def _iter_sections(self, root):
        """Yield *root* and all its descendant sections, breadth-first."""
        # deque: popping from the front is O(1), unlike list.pop(0)
        pending = deque([root])
        while pending:
            item = pending.popleft()
            pending.extend(item['subsections'].values())
            yield item

    def is_up_to_date(self):
        """True while the last successful rebuild is recent enough for ping."""
        return time.time() - self._cached_timestamp < self.ping_cache_validity

    def _make_structure(self, root, sec_volatiles):
        """Copy each section's status metadata onto its node of *root*.

        Mutates *root* in place and returns it.
        """
        for section in self._iter_sections(root):
            status = sec_volatiles[section['path']]
            for key in self.STATUS_KEYS:
                section[key] = status[key]
            section['revision'] = status['meta']['revision']
            section['changed_by'] = status['meta']['changed_by']
            section['changed_at'] = status['meta']['mtime']
            section['owners'] = status['meta']['owners']
            section['stype'] = status['source']['stype']
            section['stype_options'] = status['source']['stype_options']
        return root

    def sample(self, hostname, from_path, hivemind_emulation=False):
        """Return ``(sample, latest_mtime)`` for *hostname*.

        :param from_path: dotted path of the subtree to sample ('' or None
            for the whole tree).
        :param hivemind_emulation: build the legacy hivemind layout instead
            of the native annotated tree.
        :returns: ``(None, None)`` if *from_path* does not exist.
            ``latest_mtime`` is the greatest section ``mtime`` in the subtree
            as an int, or ``None`` when no section there has an mtime yet.
        """
        root = msgpack.loads(self._cached_structure, encoding='utf8')
        if from_path:
            for name in from_path.split('.'):
                if name not in root['subsections']:
                    return None, None
                root = root['subsections'][name]
        mtimes = [section['mtime'] for section in self._iter_sections(root)
                  if section['mtime'] is not None]
        # max() of an empty sequence raises ValueError; report "unknown"
        latest_mtime = int(max(mtimes)) if mtimes else None
        if hivemind_emulation:
            res = self._sample_hivemind(root, self._cached_configs, hostname)
        else:
            res = self._sample_native(root, self._cached_configs, hostname)
        return res, latest_mtime

    def all_rules(self, path):
        """Return ``(node, hosts, configs)`` for section *path*, or ``None``.

        ``None`` is returned when the section has no cached value or is not
        present in the cached structure.  The latter happens for sections
        removed from the tree, since ``_cached_configs`` is never pruned.
        """
        cfg = self._cached_configs.get(path)
        if cfg is None:
            return None
        node = msgpack.loads(self._cached_structure, encoding='utf8')
        for name in path.split('.'):
            node = node['subsections'].get(name)
            if node is None:
                return None
        return node, cfg['hosts'], cfg['configs']

    def _sample_native(self, root, configs, hostname):
        """Annotate every section of *root* with *hostname*'s config.

        Sets ``config``, ``matched_rules`` and ``config_hash`` on each
        section.  A section with descendants then gets a combined
        ``config_hash`` covering its own hash and those of all descendants.
        """
        config_hashes = defaultdict(dict)
        for section in self._iter_sections(root):
            config_id = self._get_config_id(configs[section['path']], hostname)
            if config_id is None:
                section['config'] = None
                section['matched_rules'] = []
                section['config_hash'] = self.NULL_CONFIG_HASH
            else:
                config = configs[section['path']]['configs'][config_id]
                section['config'] = config['config']
                section['matched_rules'] = config['matched_rules']
                section['config_hash'] = config['config_hash']

            # record this section's own hash under every ancestor path
            path = section['path']
            while path != '':
                path = path.rpartition('.')[0]
                config_hashes[path][section['path']] = section['config_hash']

        for section in self._iter_sections(root):
            hash_info = config_hashes.get(section['path'])
            if hash_info is None:
                continue
            hash_info[section['path']] = section['config_hash']
            section['config_hash'] = self._get_combined_config_hash(hash_info)
        return root

    def _get_combined_config_hash(self, hash_info):
        """SHA-1 hex of the msgpack-encoded, path-sorted (path, hash) pairs."""
        hash_info = sorted(hash_info.items())
        return hashlib.sha1(msgpack.dumps(hash_info,
                                          encoding='utf8')).hexdigest()

    def _sample_hivemind(self, root, configs, hostname):
        """Map section path -> {'data', 'author', 'mtime'} for *hostname*.

        Sections with no applicable config are omitted.
        """
        result = {}
        for section in self._iter_sections(root):
            config_id = self._get_config_id(configs[section['path']], hostname)
            if config_id is None:
                continue
            config = configs[section['path']]['configs'][config_id]['config']
            result[section['path']] = {
                'data': config,
                'author': section['changed_by'],
                'mtime': section['mtime'],
            }
        return result


def _response_serialize(data, last_modified):
    if flask.request.if_modified_since and last_modified:
        request_modified = flask.request.if_modified_since
        actual_modified = datetime.utcfromtimestamp(int(last_modified))
        if actual_modified <= request_modified:
            return flask.current_app.response_class(status=304)

    fmt = flask.request.args.get('fmt', 'json')
    if fmt not in set(('yaml', 'json', 'msgpack')):
        flask.abort(400)
    response = flask.current_app.response_class
    if fmt == 'yaml':
        resp = response(yaml.dump(data, default_flow_style=False),
                        content_type='text/plain')
    if fmt == 'msgpack':
        resp = response(msgpack.dumps(data, encoding='utf-8'),
                        content_type='application/msgpack')
    if fmt == 'json':
        resp = response(json.dumps(data), content_type='application/json')
    if last_modified:
        resp.last_modified = datetime.utcfromtimestamp(last_modified)
    return resp
