import logging
import collections

import infra.callisto.controllers.deployer2.proxy as proxies
import infra.callisto.controllers.sdk as sdk
import infra.callisto.controllers.sdk.notify as notify
import infra.callisto.controllers.slots.state as slot_state
import infra.callisto.controllers.utils.gencfg_api as gencfg_api
import infra.callisto.controllers.utils.sandbox_utils as sandbox_utils

import allocator as allocation

_FREE_SPACE_RESERVE = 5 * 1024 ** 3


class _ChunksDeployProxy(proxies.ChunksProxy):
    """Chunks proxy that allocates remote-storage chunks onto hosts per target generation.

    For every target generation (timestamp) it makes sure an allocation mapping
    exists: it is generated through ``allocator`` when the allocator reports it
    missing, and loaded from the allocator otherwise. The resulting per-host chunk
    sets are exposed through ``host_chunks_target``. Deploy statistics are derived
    from the chunks observed on hosts.
    """

    def __init__(
        self,
        name,
        deploy_ctrl,
        namespace_prefix,
        allocator,
        chunks_generator,
        remote_storage_instances,
        remote_storage_host_agent_map,
        generation_max_share,
        chunk_resources=('',),
        configs_generator=None,
        enable_deploy_multi_level_cache_replicas=False,
        is_optional=False,
    ):
        """Store configuration and initialise the base ``ChunksProxy``.

        Parameter roles as used in this class (types are not validated here):
            name: proxy name; used as ``path``, in the logger name and in ``__str__``.
            deploy_ctrl, namespace_prefix, chunk_resources: forwarded to ``ChunksProxy``.
            allocator: provides ``mapping_exists``, ``ensure_mapping_removed``,
                ``generate_mapping`` and ``load_mapping``.
            chunks_generator: lists chunks and reports their size/replication/parts.
            remote_storage_instances: iterable of dicts with at least 'hostname'
                and 'rack' keys (passed to ``gencfg_api.storage_guarantee``).
            remote_storage_host_agent_map: iterated for host names.
            generation_max_share: multiplier capping the space one generation may
                claim on a host (see ``_actual_freespace``).
            configs_generator: optional callable ``(db_timestamp=...)`` returning a
                sandbox task dict or a falsy value.
            enable_deploy_multi_level_cache_replicas: also deploy replica chunks.
            is_optional: when true, every timestamp with stats counts as observed.
        """
        self._chunk_resources = chunk_resources
        self._generation_max_share = generation_max_share
        self._name = name
        self._allocator = allocator
        self._chunks_generator = chunks_generator
        self._remote_storage_host_agent_map = remote_storage_host_agent_map
        self._remote_storage_instances = remote_storage_instances
        self._targets = set()
        # generation -> {chunk: hosts} cache of allocator mappings
        self._generation_to_mappings = {}
        self._logger = logging.getLogger('{} chunks-deploy'.format(self._name))
        self._configs_generator = configs_generator
        self._notifications = notify.NotificationsAggregator()
        self._enable_deploy_multi_level_cache_replicas = enable_deploy_multi_level_cache_replicas
        self._is_optional = is_optional
        super(_ChunksDeployProxy, self).__init__(deploy_ctrl, namespace_prefix, chunk_resources)

    @property
    def path(self):
        """Proxy name."""
        return self._name

    def chunk_specs(self, generation):
        """Return a list of ``ResourceSpec`` for all chunks of ``generation``.

        Erasure chunks are always included; replica chunks only when
        multi-level cache replicas are enabled.
        """
        specs = self._erasure_chunks_specs(generation)
        if self._enable_deploy_multi_level_cache_replicas:
            specs += self._replicas_chunks_specs(generation)
        return specs

    def execute(self):
        """Sync cached allocation mappings with the targets, then run the base proxy.

        Mappings of generations that are no longer targeted are dropped from the
        cache. For each target generation (in sorted order), a mapping is generated
        when the allocator has none; otherwise the optional configs generator is
        polled and a failed task produces a WARNING notification. Any target
        mapping still missing from the cache is then loaded from the allocator.
        """
        # Iterate over a snapshot of the keys: deleting from a dict while
        # iterating its live key view raises RuntimeError on Python 3.
        for generation in list(self._generation_to_mappings):
            if generation not in self._targets:
                del self._generation_to_mappings[generation]
        for generation in sorted(self._targets):
            self._logger.info('target generation %s', generation)
            if not self._allocator.mapping_exists(generation):
                # Clean up any leftovers before generating a fresh mapping.
                self._allocator.ensure_mapping_removed(generation)
                self._logger.info("mapping for [%s] doesn't exist, generating", generation)

                specs = self.chunk_specs(generation)
                self._logger.info("sum chunks size %d", sum(spec.size for spec in specs))

                self._generation_to_mappings[generation] = self._allocator.generate_mapping(
                    generation,
                    specs,
                    self._hosts_specs()
                )
                self._log_changes()
            elif self._configs_generator:
                sb_task = self._configs_generator(db_timestamp=generation)
                if not sb_task:
                    self._logger.debug('Configs generation task is not created')
                elif sandbox_utils.is_task_failed(sb_task):
                    self._logger.debug('There is failed %s task id %s', sb_task['type'], sb_task['id'])
                    self._notifications.add_notification(notify.TextNotification(
                        'Check failed {} task {} and run it manually'.format(sb_task['type'], sb_task['url']),
                        notify.NotifyLevels.WARNING,
                    ))
                else:
                    self._logger.debug('%s task found %s', sb_task['type'], sb_task['url'])

            if generation not in self._generation_to_mappings:
                self._generation_to_mappings[generation] = self._allocator.load_mapping(generation)

        super(_ChunksDeployProxy, self).execute()

    def notifications(self):
        """Return notifications accumulated by this proxy."""
        return self._notifications.get_notifications()

    @property
    def targets(self):
        """Set of target generations (timestamps)."""
        return self._targets

    def set_targets(self, targets):
        """Replace the target generations, logging the change when they differ."""
        if self._targets != targets:
            self._logger.info('%s: %s -> %s', self, self._targets, targets)
            self._targets = targets

    def host_chunks_target(self):
        """Return ``defaultdict(set)`` mapping host -> chunks planned for it.

        Combines the cached mappings of all target generations; generations
        without a cached mapping contribute nothing.
        """
        result = collections.defaultdict(set)
        for generation in self._targets:
            for chunk, hosts in self._generations_mapping(generation).iteritems():
                for host in hosts:
                    result[host].add(chunk)
        return result

    def _log_changes(self):
        """Log how many parts changed hosts between the two cached generations.

        Only done when exactly two generations are cached. Parts are matched by
        substituting the older timestamp into the newer part's name.
        """
        if len(self._generation_to_mappings) != 2:
            self._logger.info('Could not compare parts movement')
            return
        prev_gen, next_gen = sorted(self._generation_to_mappings)
        prev_mapping = self._generation_to_mappings[prev_gen]
        next_mapping = self._generation_to_mappings[next_gen]
        moved_number = 0
        total_number = len(next_mapping)
        for part, target in next_mapping.iteritems():
            prev_target = prev_mapping.get(
                allocation._substitute_timestamp(prev_gen, allocation._drop_timestamp(part)),
                []
            )
            if target != prev_target:
                moved_number += 1
        # An empty new mapping would otherwise raise ZeroDivisionError here.
        moved_percent = moved_number * 100. / total_number if total_number else 0.
        self._logger.info(
            'Compare %s and %s parts allocations, parts moved: %s / %s (%0.2f%%)',
            prev_gen, next_gen, moved_number, total_number, moved_percent
        )

    def _generations_mapping(self, generation):
        """Return the cached mapping for ``generation`` or an empty dict."""
        return self._generation_to_mappings.get(generation, {})

    def _hosts_specs(self):
        """Build ``HostSpec`` entries for alive hosts of the host-agent map.

        Free space per host comes from ``_actual_freespace``, using the summed
        storage guarantees of the remote-storage instances on that host. Dead
        hosts are skipped (their free space is only logged). Hosts present in
        only one of instances / agent map are logged as errors.
        """
        hosts_specs = []
        target = self.host_chunks_target()
        guarantees = collections.defaultdict(int)
        slots = set()
        racks = {}
        for instance in self._remote_storage_instances:
            host = instance['hostname']
            slots.add(host)
            guarantees[host] += gencfg_api.storage_guarantee(instance)
            racks[host] = instance['rack']

        self._logger.debug("sum guarantee %d", sum(guarantees.values()))
        sum_freespace = 0
        dead_hosts_freespace = 0
        for host in self._remote_storage_host_agent_map:
            freespace = self._actual_freespace(host, target[host], guarantees[host])
            # NOTE(review): _deploy_ctrl is presumably set by the ChunksProxy base class.
            if self._deploy_ctrl.host_is_alive(host):
                sum_freespace += freespace
                # Hosts without a known rack are treated as their own rack.
                hosts_specs.append(allocation.HostSpec(host, freespace, racks.get(host, host)))
            else:
                dead_hosts_freespace += freespace
                self._logger.debug("dead host %s", host)

        slots.symmetric_difference_update(self._remote_storage_host_agent_map)
        for host in slots:
            self._logger.error("diff in slots/agentmap %s", host)

        self._logger.debug("sum freespace %d (%d on dead hosts)", sum_freespace, dead_hosts_freespace)
        return hosts_specs

    def _actual_freespace(self, host, target, guarantee):
        """Return the space still available for new chunks on a host, never negative.

        ``guarantee`` minus ``_FREE_SPACE_RESERVE`` minus the sizes of the chunks
        in ``target``. The result is capped at ``generation_max_share`` times the
        reserve-adjusted guarantee. ``host`` is not used in the computation.
        """
        # subtract the fixed free-space reserve
        guarantee -= _FREE_SPACE_RESERVE
        # cap computed before planned chunks are subtracted
        limit = guarantee * self._generation_max_share
        # subtract all planned chunks from guarantee
        for chunk in set(target):
            guarantee -= self._chunks_generator.chunk_size(chunk)
        return max(min(limit, guarantee), 0)

    def _erasure_chunks_specs(self, generation):
        """Return a list of ``ResourceSpec`` for the erasure chunks of ``generation``."""
        # Explicit list: callers concatenate with += and iterate the result twice.
        return [self._chunk_spec(chunk) for chunk in self._chunks_generator.list_erasure_chunks(generation)]

    def _replicas_chunks_specs(self, generation):
        """Return a list of ``ResourceSpec`` for the replica chunks of ``generation``."""
        return [self._chunk_spec(chunk) for chunk in self._chunks_generator.list_replicas_chunks(generation)]

    def _chunk_spec(self, chunk):
        """Build an allocator ``ResourceSpec`` (size, replication, parts count) for ``chunk``."""
        return allocation.ResourceSpec(
            chunk,
            self._chunks_generator.chunk_size(chunk),
            self._chunks_generator.chunk_replication(chunk),
            self._chunks_generator.chunk_parts_count(chunk)
        )

    def _observed_timestamps(self):
        """Return timestamps whose 'done' count reached 'total' (all of them if optional).

        NOTE(review): ``self._stats`` is not set in this class; presumably the base
        class fills it from ``_eval_stats`` — confirm.
        """
        observed = set()
        for timestamp, state in self._stats.items():
            if state['done'] >= state['total'] or self._is_optional:
                observed.add(timestamp)
        return observed

    def _observed_chunks(self):
        """Group observed target chunks by timestamp and (shard, number).

        Returns ``(erasure_chunks, replicas_chunks)``:
            erasure_chunks[timestamp][(shard, number)] -> set of observed parts;
            replicas_chunks[timestamp][(shard, number)][part] -> number of hosts
            where that replica part was observed.
        Only chunks that are also planned for the host are counted.
        """
        erasure_chunks = collections.defaultdict(lambda: collections.defaultdict(set))
        replicas_chunks = collections.defaultdict(lambda: collections.defaultdict(lambda: collections.defaultdict(int)))

        for host, target_chunks in self.host_chunks_target().iteritems():
            for chunk in self.chunks_observed_on_host(host):
                if chunk in target_chunks:
                    try:
                        if chunk.path.startswith('remote_storage'):
                            erasure_chunks[chunk.timestamp][(chunk.shard, chunk.number)].add(chunk.part)
                        elif self._enable_deploy_multi_level_cache_replicas and chunk.path.startswith('multi_level_cache_replicas'):
                            replicas_chunks[chunk.timestamp][(chunk.shard, chunk.number)][chunk.part] += 1
                    except IndexError:
                        # NOTE(review): presumably raised by chunk attributes parsed
                        # from a malformed path; such chunks are only logged.
                        self._logger.debug('%s, %s', host, chunk)

        return erasure_chunks, replicas_chunks

    def _target_stats_template(self):
        """Return a fresh per-timestamp stats dict with expected totals and zero progress.

        'total' counts chunks (erasure plus, if enabled, replica chunks);
        'total_replicas' counts parts (replica parts multiplied by their replication).
        """
        erasure_chunks_count = self._chunks_generator.shards_count() * self._chunks_generator.erasure_chunks_cnt
        erasure_parts_count = erasure_chunks_count * self._chunks_generator.erasure_parts_cnt

        replicas_chunks_count = 0
        replicas_parts_count = 0
        if self._enable_deploy_multi_level_cache_replicas:
            replicas_chunks_count = self._chunks_generator.shards_count() * self._chunks_generator.replicas_chunks_cnt
            replicas_parts_count = replicas_chunks_count * self._chunks_generator.replicas_parts_cnt * self._chunks_generator.replicas_parts_replication

        return {
            'done': 0,
            'total': erasure_chunks_count + replicas_chunks_count,
            'done_replicas': 0,
            'total_replicas': erasure_parts_count + replicas_parts_count,
        }

    def _eval_stats(self):
        """Return ``{timestamp: stats}`` progress computed from observed chunks."""
        stats = {}

        erasure_observed_chunks, replicas_observed_chunks = self._observed_chunks()

        # Check count of observed ERASURE parts
        for timestamp, chunks in erasure_observed_chunks.iteritems():
            if timestamp not in stats:
                stats[timestamp] = self._target_stats_template()
            # A chunk counts as done when at most 3 of its erasure parts are missing
            # (presumably the erasure code's recovery tolerance — confirm).
            stats[timestamp]['done'] += sum(1 for parts in chunks.itervalues() if len(parts) >= (self._chunks_generator.erasure_parts_cnt - 3))
            stats[timestamp]['done_replicas'] += sum(len(parts) for parts in chunks.itervalues())

        # Check count of observed REPLICAS parts
        for timestamp, chunks in replicas_observed_chunks.iteritems():
            if timestamp not in stats:
                stats[timestamp] = self._target_stats_template()
            for parts in chunks.values():
                for replic_count in parts.values():
                    # NOTE(review): 'done' grows once per replica *part* with at least
                    # one copy, while 'total' counts replica *chunks*; these differ when
                    # replicas_parts_cnt > 1 — confirm intended.
                    if replic_count > 0:
                        stats[timestamp]['done'] += 1
                    stats[timestamp]['done_replicas'] += replic_count

        return stats

    def __str__(self):
        return '_ChunksDeployProxy({})'.format(self._name)


class SlotCtrlAdapter(_ChunksDeployProxy):
    """Exposes ``_ChunksDeployProxy`` through the slot-controller interface.

    Deployer target states are reduced to their timestamps. The searcher target
    state is stored unchanged and reported back as the observed searcher state.
    """

    def __init__(self, *args, **kwargs):
        super(SlotCtrlAdapter, self).__init__(*args, **kwargs)
        # No searcher target until set_target_state() is called.
        self._searcher_target = slot_state.NullSlotState

    def __str__(self):
        return 'SlotCtrlAdapter({})'.format(self._name)

    def set_target_state(self, deployer_target_states, searcher_target_state):
        """Use the timestamps of ``deployer_target_states`` as targets and store the searcher target."""
        wanted_timestamps = set(item.timestamp for item in deployer_target_states)
        self.set_targets(wanted_timestamps)
        self._searcher_target = searcher_target_state

    def get_observed_state(self):
        """Return ``(deployer_states, searcher_state)``.

        ``deployer_states`` is a set of ``SlotState`` built from the observed
        timestamps; ``searcher_state`` is the stored searcher target.
        """
        deployer_states = set(slot_state.SlotState(ts) for ts in self._observed_timestamps())
        return deployer_states, self._searcher_target

    def searchers_state(self, slot_name_to_find=None):
        """No searcher information is available; always an empty dict."""
        return dict()

    def deploy_progress(self, timestamp=None):
        """No per-timestamp progress is reported; always an empty dict."""
        return dict()

    def html_view(self):
        """Wrap the base proxy's HTML view in a ``SlotView`` with an empty searcher block."""
        slot_path = self.path
        deployer_block = super(SlotCtrlAdapter, self).html_view()
        return sdk.blocks.SlotView(slot_path, deployer_block, sdk.blocks.Block())

    def json_view(self):
        """Return observed and target states of deployer and searcher as JSON-ready dicts."""
        observed_deployer, observed_searcher = self.get_observed_state()
        deployer_section = {
            'observed': [item.json() for item in observed_deployer],
            'target': [slot_state.SlotState(ts).json() for ts in self.targets],
        }
        searcher_section = {
            'observed': observed_searcher.json(),
            'target': self._searcher_target.json(),
        }
        return {
            'deployer': deployer_section,
            'searcher': searcher_section,
        }
