import logging
import os

import balancer
import mapping as mapping_module
import storage as storage_module
import slot as slot_module

import utils

import v2.plugin as plugin


class Plugin(plugin.BasePlugin):
    """Plugin that runs search instances in port-addressed slots, optionally behind a balancer.

    Resources referenced by the config are fetched through a ``storage_module.DownloadQueue``.
    Once every resource of an instance is ready, the instance config is applied to its slot.
    When ``_is_balanced()`` is true, a ``balancer.Balancer`` is created and kept in sync with a
    ``'<slot>_<timestamp>' -> port`` mapping (see ``_generate_mapping``).
    """

    @classmethod
    def add_args(cls, parser):
        """Register this plugin's command-line options on ``parser``."""
        parser.add_argument('--shard', help='Shard', required=False)
        parser.add_argument('--db-timestamp', help='DB timestamp', required=False)
        parser.add_argument('--cache-dir', help='Fast storage (for caches)', required=False)
        parser.add_argument('--storage-root', help='Storage root', required=True)  # used to be shard-root
        parser.add_argument('--extra-arguments', help='extra args to daemon', required=False)
        parser.add_argument('--balancer-bin', help='Balancer', required=False)
        parser.add_argument('--balancer-cfg', help='Balancer config', required=False)

    def __init__(self, host, port, tags, args):
        """Set up storage, the download queue, the instance factory and the optional balancer.

        :param host: host name of this plugin.
        :param port: base port; converted with ``int`` before use.
        :param tags: passed through to ``BasePlugin``.
        :param args: parsed arguments (see ``add_args``). ``args.log_dir`` is also read,
            but it is not registered here. It is presumably added by the base plugin
            (TODO confirm).
        """
        port = int(port)
        super(Plugin, self).__init__(host, port, tags)
        # Storage lives under <storage-root>/<port>.
        self._storage = storage_module.Storage(os.path.join(args.storage_root, str(port)))
        self._download_queue = storage_module.DownloadQueue(self._storage, self._on_resource_downloaded)

        self._slots = {}  # slot name -> slot_module.Slot
        self._config = None  # last config passed to apply_config(), None until then

        # Alias needed: inside the class body below, ``host = host`` would resolve the
        # right-hand side in module globals rather than in this function's scope.
        host_ = host

        class InstanceFactory(mapping_module.get_search_class(utils.get_itype(), utils.get_metaprj())):
            host = host_
            shard = args.shard
            db_timestamp = args.db_timestamp
            cache_root = args.cache_dir
            log_dir = args.log_dir
            extra_arguments = args.extra_arguments

        self._httpsearch_class = InstanceFactory

        if _is_balanced():
            self._balancer = balancer.Balancer(
                host, port, port + (1000 if self.mtn else 7),
                args.balancer_bin, args.balancer_cfg, args.log_dir,
                self._httpsearch_class.ports_required
            )
        else:
            self._balancer = None

        _log.debug('%s port required for each %s instance',
                   self._httpsearch_class.ports_required,
                   self._httpsearch_class.__base__)

        # Started only after every attribute used by _on_resource_downloaded exists.
        self._download_queue.start()

    @property
    def mtn(self):
        """Host-name heuristic: True if ``self.host`` contains ``'gencfg-c'`` or ``'.yp-c.'``."""
        return 'gencfg-c' in self.host or '.yp-c.' in self.host

    def collect_status(self):
        """Return ``{'slots': {name: slot status}, 'balancer': balancer status or None}``."""
        return {
            'slots': {
                slot_name: slot.collect_status()
                for slot_name, slot in self._slots.iteritems()
            },
            'balancer': self._balancer.collect_status() if self._balancer is not None else None,
        }

    def _get_next_slot_baseport(self):
        """Return the base port for a newly created slot.

        The first slot starts at ``self.port``. That port is shifted by one instance's
        ``ports_required`` when balanced, and by 100 on MTN hosts. Each later slot starts
        ``ports_required`` above the highest existing slot base port.
        """
        if self._slots:
            port = max(slot.base_port for slot in self._slots.itervalues()) + self._httpsearch_class.ports_required
            _log.debug('Next slot, base port %s', port)
        else:
            port = self.port + (self._httpsearch_class.ports_required if _is_balanced() else 0) + \
                (100 if self.mtn else 0)
            _log.debug('First slot, base port %s', port)
        return port

    def apply_config(self, config):
        """Apply a new plugin config.

        Creates missing slots, stops outdated instances, enqueues downloads for
        not-yet-started instances and refreshes the balancer.

        :param config: dict whose ``'slots'`` maps slot name -> slot config. Each slot
            config has ``'instances'``; each instance has ``'conf_hash'`` and
            ``'resources'`` (each resource with a ``'url'``).
        """
        slots_config = config['slots']
        self._ensure_slots(slots_config.keys())

        for slot_name, slot in slots_config.iteritems():
            # Stop changed/gone instances on all slots before enqueueing new resources
            # to avoid a race.
            self._slots[slot_name].reset_instances(slot['instances'])

        # Publish the config BEFORE enqueueing downloads. The queue (started in
        # __init__) invokes _on_resource_downloaded, which reads self._config. On the
        # first call it would otherwise still be None (or stale on later calls).
        self._config = config

        for slot_name, slot in slots_config.iteritems():
            for instance in slot['instances']:
                if instance['conf_hash'] not in self._slots[slot_name].started:
                    for resource in instance['resources']:
                        self._download_queue.push(resource['url'])

        self._reconfigure_balancer()

    # NB: obsolete slots are not being removed.
    def _ensure_slots(self, slots):
        """Create a ``slot_module.Slot`` for every name in ``slots`` that has none yet."""
        for slot in set(slots) - set(self._slots):
            _log.info('Add new slot %s', slot)
            self._slots[slot] = slot_module.Slot(self._httpsearch_class, self._get_next_slot_baseport(), self.port)

    def _reconfigure_balancer(self):
        """Push a freshly generated mapping to the balancer, if there is one."""
        if self._balancer is not None:
            self._balancer.apply_config(self._generate_mapping())

    def _needed_resources(self):
        """Return the set of resource URLs referenced by the current config."""
        needed_resources = set()
        for slot in self._config['slots'].itervalues():
            for instance_config in slot['instances']:
                for resource in instance_config['resources']:
                    needed_resources.add(resource['url'])
        return needed_resources

    def _on_resource_downloaded(self, resource):
        """Download-queue callback.

        Applies every not-yet-started instance whose resources are all ready. Then it
        refreshes the balancer (if any slot accepted a config) and removes stored
        resources that the current config no longer references.
        """
        _log.debug('on_resource_downloaded [%s]', resource)
        if self._config is None:
            # Nothing applied yet: no instances to start and no basis for cleanup.
            return

        regenerate_balancer = False
        for slot_name, slot_config in self._config['slots'].iteritems():
            for instance_config in slot_config['instances']:
                if instance_config['conf_hash'] in self._slots[slot_name].started:
                    continue

                if all(self._storage.is_ready(res['url']) for res in instance_config['resources']):
                    _log.debug('All resources are ready for %s', slot_name)

                    # Best effort: one failing slot must not block the others.
                    # noinspection PyBroadException
                    try:
                        self._slots[slot_name].apply_config(
                            instance_config,
                            self._get_resources_for_instance(
                                slot_name,
                                instance_config['resources'],
                                self._httpsearch_class.resources_required
                            ),
                            '{}_{}'.format(slot_name, instance_config['revision']),
                        )
                        regenerate_balancer = True
                    except Exception:
                        _log.exception('Exception on apply_config for %s; will continue', slot_name)
                else:
                    _log.debug('Not all resources ready for %s conf_hash %s', slot_name, instance_config['conf_hash'])

        if regenerate_balancer:
            self._reconfigure_balancer()

        self._storage.cleanup(self._needed_resources())

    def _generate_mapping(self):
        """Build the balancer mapping ``'<slot>_<timestamp>' -> port`` plus ``'default'``.

        ``'default'`` gets the port of the timestamp named by the slot config's
        ``'default'`` key. Slots missing from the current config (obsolete slots are
        never removed) still contribute their ports, but never the default.
        """
        mapping = {}
        slots_config = self._config['slots']

        for slot_name, slot in self._slots.iteritems():
            slot_config = slots_config.get(slot_name)
            for timestamp, port in slot.mapping.iteritems():
                mapping['{}_{}'.format(slot_name, timestamp)] = port
                if slot_config is not None and slot_config.get('default') == timestamp:
                    # Single slot scheme, TODO multi slot mmeta !!!
                    # Warn only when a second default is about to replace the first.
                    if 'default' in mapping:
                        _log.warning('Got default on multi slot')
                    mapping['default'] = port

        _log.debug(mapping)

        return mapping

    def _get_resources_for_instance(self, slot_name, downloaded_resources, resources_required):
        """Map each required resource type to its local path (see ``_get_resource_path``)."""
        return {
            resource: _get_resource_path(slot_name, self._storage, downloaded_resources, resource)
            for resource in resources_required
        }


def _extracted_resources_dir():
    """Return the directory under which extracted resource files are unpacked.

    This is the ``resources`` subdirectory of the current working directory.
    """
    # TODO: do not put anything to CWD
    working_dir = os.getcwd()
    return os.path.join(working_dir, 'resources')


def _get_resource_path(slot_name, storage, resources, resource_type):
    """Return the local path of the first resource whose ``'name'`` equals ``resource_type``.

    :param slot_name: slot the resource is for; used to namespace extracted files.
    :param storage: object providing ``path_to_resource(url)``.
    :param resources: iterable of resource dicts with ``'name'``, ``'url'``, ``'file_name'``
        and optionally ``'extract_file'``.
    :param resource_type: resource name to look up.
    :returns: one of the following:

        * ``'<resource dir>/<file_name>'``;
        * with ``'extract_file'``, whatever ``utils.extract_file`` returns after
          extracting into a per-slot directory under ``_extracted_resources_dir()``;
        * ``None`` when no resource matches.
    """
    for resource in resources:
        if resource['name'] != resource_type:
            continue

        resource_dir = storage.path_to_resource(resource['url'])
        # Joined with '/' rather than os.path.join, so a leading '/' in file_name
        # cannot silently discard resource_dir.
        path = resource_dir + '/' + resource['file_name']
        if 'extract_file' not in resource:
            return path

        _log.debug('extracting %s from %s', resource['extract_file'], resource['file_name'])
        # One directory per slot and per stored resource ('/' flattened to '_').
        target_dir = os.path.join(
            _extracted_resources_dir(),
            slot_name,
            resource_dir.replace('/', '_'),
        )
        utils.ensure_dir(target_dir)
        return utils.extract_file(path, resource['extract_file'], target_dir)

    # No resource of the requested type: make the None result explicit and traceable.
    _log.debug('no resource named %s for slot %s', resource_type, slot_name)
    return None


def _is_balanced():
    """Return True unless the current (itype, metaprj) pair is ('fusion', 'saas').

    Only that single combination runs without a balancer.
    """
    current = (utils.get_itype(), utils.get_metaprj())
    return current != ('fusion', 'saas')


# Module logger. It is defined at the bottom of the file, but is only dereferenced
# at call time, so every function above may use it.
_log = logging.getLogger(name=__name__)
