import concurrent.futures
import logging
import multiprocessing
import os
import shutil
from collections import defaultdict, OrderedDict

from sandbox.sdk2.helpers import subprocess, ProcessRegistry, ProcessLog

from sandbox.projects.yabs.qa.module_base import ModuleBase


# Module-level logger used by unpack() and BinBasesProvider; named explicitly
# after the provider class rather than derived from __name__.
logger = logging.getLogger('BinBasesProvider')


def unpack(source_path, dest_path, transport_executable_path, use_packed, core_count):
    """Materialize a base file at ``dest_path`` and return its absolute path.

    When ``use_packed`` is true, the file at ``source_path`` is copied as is.
    Otherwise it is decompressed by running
    ``<transport_executable_path> decompress <source> <target> <core_count>``.

    In both cases the data is first written to ``<dest_path>.unpacking`` and
    then renamed onto ``dest_path``. Renaming replaces the directory entry
    instead of writing into an existing file. This has two effects:

    - other hard links to a previous file at ``dest_path`` keep their old
      content;
    - a failed copy or decompression never leaves a partial ``dest_path``.

    The temporary file is removed on failure and the original error propagates.

    :param source_path: path to the (packed) base resource
    :param dest_path: final location of the base file
    :param transport_executable_path: binary used for decompression
    :param use_packed: copy the source instead of decompressing it
    :param core_count: thread count passed to the decompressor
    :return: ``os.path.abspath(dest_path)``
    """
    unpacking_path = '{}.unpacking'.format(dest_path)
    succeeded = False
    try:
        if use_packed:
            shutil.copy(source_path, unpacking_path)
        else:
            cmd = [
                transport_executable_path,
                'decompress',
                source_path,
                unpacking_path,
                str(core_count),
            ]
            logger.debug('Call %s', " ".join(cmd))
            with ProcessRegistry, ProcessLog(logger=logging.getLogger('unpacker')) as pl:
                subprocess.check_call(cmd, stderr=pl.stderr, stdout=pl.stdout)
        os.rename(unpacking_path, dest_path)
        succeeded = True
    finally:
        # Best-effort cleanup of a partial file. try/finally (not except + bare
        # raise) keeps the original exception intact on Python 2 as well.
        if not succeeded and os.path.lexists(unpacking_path):
            try:
                os.remove(unpacking_path)
            except OSError:
                logger.warning('Failed to remove partial file %s', unpacking_path)
    return os.path.abspath(dest_path)


class BinBasesProvider(ModuleBase):
    """Provides unpacked binary bases to servers, reusing already unpacked files.

    State layout: ``base_state[tag][identity_key] -> set of absolute paths``.
    The identity key comes from ``adapter.get_base_identity_key_by_tag``.
    All paths in one set are hard links to the same unpacked file (see
    ``link_base_to_dirs``).

    Within a tag, keys are culled in iteration order (insertion order for an
    ``OrderedDict``). Culling starts once there are more than
    ``maximum_keys_per_tag`` keys, and the files of culled keys are deleted.
    """

    def __init__(self, adapter, shared_base_state=None, maximum_keys_per_tag=10):
        """
        :param adapter: object supplying base resources, identity keys and the transport binary
        :param shared_base_state: state object to reuse and update in place; it is validated
            with ``check_state`` first. ``None`` creates a fresh state.
        :param maximum_keys_per_tag: how many identity keys to keep per tag
        """
        ModuleBase.__init__(self, adapter)
        self._maximum_keys_per_tag = maximum_keys_per_tag
        # Compare against None: an empty shared state is still a shared object
        # that the caller expects to be updated in place.
        if shared_base_state is not None:
            self._base_state = shared_base_state
            self.check_state(self._maximum_keys_per_tag)
        else:
            self._base_state = defaultdict(OrderedDict)

    @property
    def base_state(self):
        """The mutable state mapping (see the class docstring for its layout)."""
        return self._base_state

    def check_state(self, maximum_keys_per_tag):
        """Drop entries whose files no longer exist and cull every tag to the key limit."""
        logger.info('Starting shared state check')
        # Iterate over snapshots: entries are deleted inside the loops.
        for tag, tag_dict in list(self.base_state.items()):
            for identity_key, base_path_set in list(tag_dict.items()):
                for base_path in list(base_path_set):
                    if not os.path.isfile(base_path):
                        logger.info('Dropping non-existant path %s from state', base_path)
                        base_path_set.remove(base_path)
                if not base_path_set:
                    logger.info('Dropping empty base path set for identity key %s', str(identity_key))
                    del tag_dict[identity_key]
            if not tag_dict:
                self.remove_empty_tag(self.base_state, tag)
            else:
                self.cull_tag(tag_dict, maximum_keys_per_tag)
        if not self.base_state:
            logger.warning('Check resulted in an empty base state')
        logger.info('Base provider state check complete')

    def remove_empty_tag(self, base_state, tag):
        """Delete ``tag`` from ``base_state``; the tag is expected to have no keys left."""
        assert not base_state.get(tag, None)
        logger.info('Dropping empty tag %s from state', tag)
        del base_state[tag]

    def flush_state(self):
        """Remove every key (deleting its files) and every tag from the state."""
        logger.info('Flushing base provider state')
        for tag, tag_dict in list(self.base_state.items()):
            for key in list(tag_dict):
                self.remove_key(tag_dict, key)
            self.remove_empty_tag(self.base_state, tag)

    def provide(self, servers, sync_resource_workers=8, unpack_workers=4):
        """Make every bin base listed by ``servers`` available in their bin db dirs.

        Base names have the form ``version.tag.yabs``. For each base:

        - If its identity key is already in the state, an existing file is reused.
        - Otherwise the resource is synced (``sync_resource_workers`` threads)
          and unpacked into one of the target dirs (``unpack_workers`` threads).
          Older keys of the tag are then culled and the new file is registered.

        Finally the file is hard-linked into every dir that needs it.

        :param servers: objects exposing ``get_bin_db_dir()`` and ``get_bin_db_list()``
        """
        bases_to_dirs_dict = defaultdict(set)
        for server in servers:
            data_dir = server.get_bin_db_dir()
            for base in set(server.get_bin_db_list()):
                bases_to_dirs_dict[base].add(data_dir)
        transport_executable_path = self.adapter.get_transport_resource_path()
        core_count = multiprocessing.cpu_count()

        unpacked_bases = {}

        # The executors are shut down (their threads joined) on exit, including on errors.
        with concurrent.futures.ThreadPoolExecutor(max_workers=sync_resource_workers) as sync_resource_executor, \
                concurrent.futures.ThreadPoolExecutor(max_workers=unpack_workers) as unpack_base_executor:
            sync_resource_futures = {}  # future -> tag
            for base_name, dir_set in bases_to_dirs_dict.items():
                _, tag, _ = base_name.split('.')  # version.tag.yabs
                base_identity_key = self.adapter.get_base_identity_key_by_tag(tag)
                unpacked_bases[tag] = {
                    "base_name": base_name,
                    "dir_set": dir_set,
                    "identity_key": base_identity_key,
                }
                # setdefault also works when a plain dict was passed as shared state.
                tag_dict = self.base_state.setdefault(tag, OrderedDict())
                if base_identity_key in tag_dict:
                    logger.debug('Base %s with equal identity key found in provider state', tag)
                    unpacked_bases[tag]["source_path"] = next(iter(tag_dict[base_identity_key]))
                    continue

                future = sync_resource_executor.submit(self.adapter.get_base_resource_path_by_tag, tag)
                sync_resource_futures[future] = tag

            # Start unpacking each base as soon as its resource is synced.
            unpack_base_futures = {}  # future -> tag
            for future in concurrent.futures.as_completed(sync_resource_futures):
                packed_base_path = future.result()
                tag = sync_resource_futures[future]
                data = unpacked_bases[tag]
                unpack_future = unpack_base_executor.submit(
                    unpack,
                    packed_base_path,
                    os.path.join(next(iter(data["dir_set"])), data["base_name"]),
                    transport_executable_path,
                    self.adapter.get_use_packed_base_by_tag(tag),
                    core_count=core_count,
                )
                unpack_base_futures[unpack_future] = tag

            for future in concurrent.futures.as_completed(unpack_base_futures):
                source_path = future.result()
                tag = unpack_base_futures[future]
                base_identity_key = unpacked_bases[tag]["identity_key"]
                # The file at source_path was just replaced, so no older key may own it.
                self.drop_invalidated_paths(tag, source_path)
                # Leave room for the key registered right below.
                self.cull_tag(self.base_state[tag], self._maximum_keys_per_tag - 1)
                self.base_state[tag][base_identity_key] = {source_path}
                unpacked_bases[tag]["source_path"] = source_path

        for tag, data in unpacked_bases.items():
            self.link_base_to_dirs(
                data["source_path"],
                data["base_name"],
                data["dir_set"],
                self.base_state[tag][data["identity_key"]],
                tag=tag,
            )
            logger.info('Base %s is done: %s', tag, data["source_path"])

    def remove_key(self, tag_dict, key):
        """Delete the files of ``key`` from disk and drop ``key`` from ``tag_dict``.

        Files that are already missing are tolerated.
        """
        logger.info('Removing key %s from state', str(key))
        for base_path in tag_dict[key]:
            try:
                os.remove(base_path)
            except OSError:
                if os.path.lexists(base_path):
                    raise
                logger.warning('Base path %s is already missing', base_path)
        del tag_dict[key]

    def cull_tag(self, tag_dict, maximum_keys_per_tag):
        """Remove leading keys of ``tag_dict`` until at most ``maximum_keys_per_tag`` remain."""
        # Clamp at 0: a negative limit (e.g. a limit of 0 minus 1 in provide) would
        # otherwise try to remove more keys than exist and raise StopIteration.
        excess = len(tag_dict) - max(maximum_keys_per_tag, 0)
        for _ in range(excess):
            self.remove_key(tag_dict, next(iter(tag_dict)))

    def drop_invalidated_paths(self, tag, source_path):
        """Forget ``source_path`` under every identity key of ``tag``.

        A key whose path set becomes empty is removed from the state. No files are deleted.
        """
        tag_dict = self.base_state.get(tag)
        if not tag_dict:
            return
        for key, path_set in list(tag_dict.items()):
            if source_path in path_set:
                logger.info('Base %s with identity key %s invalidated by incoming path %s, dropping old key from state', tag, str(key), source_path)
                path_set.remove(source_path)
                if not path_set:
                    logger.info('Key %s emptied, removing from base', str(key))
                    del tag_dict[key]

    def link_base_to_dirs(self, source_path, base_name, dir_set, base_state_path_set, tag=None):
        """Hard-link ``source_path`` as ``base_name`` into every directory of ``dir_set``.

        Each new link is added to ``base_state_path_set``.

        :param tag: when given, handles destinations that are already occupied:

            - the destination is first dropped from other identity keys of this tag;
            - any file already there is replaced.

            This is the case when the same base name comes back with a new identity
            key. Without ``tag``, an existing destination makes ``os.link`` fail
            as before.
        """
        for directory in dir_set:
            dest_path = os.path.abspath(os.path.join(directory, base_name))
            if dest_path in base_state_path_set:
                continue
            if tag is None:
                os.link(source_path, dest_path)
            else:
                self.drop_invalidated_paths(tag, dest_path)
                self._replace_with_link(source_path, dest_path)
            base_state_path_set.add(dest_path)

    @staticmethod
    def _replace_with_link(source_path, dest_path):
        """Make ``dest_path`` a hard link to ``source_path``, replacing any other file there."""
        if os.path.exists(dest_path) and os.path.samefile(source_path, dest_path):
            # Already the same file: removing it could delete source_path itself.
            return
        if os.path.lexists(dest_path):
            logger.info('Replacing existing file %s with a link to %s', dest_path, source_path)
            os.remove(dest_path)
        os.link(source_path, dest_path)