import logging
import collections

import infra.callisto.controllers.sdk as sdk
import infra.callisto.controllers.slots.state as slots_state


class SoftController(sdk.Controller):
    """Controller that drives the slot controllers of one contour.

    It owns a `TopologyGenerator` holding the target ``slot -> state`` mapping
    and pushes targets to the registered slot controllers. Slot controllers
    are expected to return two-element states from `get_target_state` and
    `get_observed_state`: index 0 is iterated as a collection (named the
    deployer state in this code), index 1 is a single state (named the
    searcher state in this code).
    """

    path = 'soft'

    def __init__(self, contour, slot_controllers):
        """
        :param contour: contour object; `name` and `tier` are read from it.
        :param slot_controllers: mapping ``slot -> slot controller``; every
            controller is registered as a child of this controller.
        """
        super(SoftController, self).__init__()
        self._contour = contour
        self._slots = slot_controllers
        self._topology = TopologyGenerator(contour)

        self.register(*slot_controllers.values())

    @property
    def id(self):
        """Identifier built from the contour name."""
        return 'soft_' + self._contour.name

    def set_context(self, context):
        """Replace the topology generator with one seeded from `context`."""
        restored = slots_state.mapping_from_dict(context)
        self._topology = TopologyGenerator(self._contour, restored)

    def get_context(self):
        """Return the current target mapping converted by `mapping_to_dict`."""
        return slots_state.mapping_to_dict(self._topology.target)

    def update_targets(self, built_shards, unused_slots):
        """Recompute the topology and push new targets to the slot controllers.

        :param built_shards: shards handed to `_newest_groups`.
        :param unused_slots: slots that may be switched to their new state.
        """
        fresh = _newest_groups(self._contour.tier, built_shards)

        # The topology is only recomputed when there is something fresh.
        if fresh:
            self._topology.update(self.slots_state(), fresh)

        order = self._topology.target
        self._prepare_ready_groups(fresh, order)
        self._switch(unused_slots, order)

    def _prepare_ready_groups(self, fresh_groups, new_order):
        """Set each slot's deployer target to its searcher state plus its fresh planned state."""
        for slot, slot_ctrl in self._slots.items():
            _, searcher = slot_ctrl.get_target_state()
            _, observed = slot_ctrl.get_observed_state()
            # Fall back to the observed searcher state when no target is set.
            searcher = searcher or observed

            deployer = set()
            if searcher:
                deployer.add(searcher)
            planned = new_order[slot]
            if planned in fresh_groups:
                deployer.add(planned)

            slot_ctrl.set_target_state(deployer, searcher)

    def _switch(self, unused_slots, new_order):
        """Switch unused slots to their planned state once it is observed on the deployer side."""
        for slot in unused_slots:
            slot_ctrl = self._slots[slot]
            for state in slot_ctrl.get_observed_state()[0]:
                wanted = new_order[slot]
                if wanted and state == wanted:
                    slot_ctrl.set_target_state({state}, state)

    def slots_state(self):
        """Return ``slot -> observed searcher state`` for every slot."""
        observed = {}
        for slot, slot_ctrl in self._slots.items():
            observed[slot] = slot_ctrl.get_observed_state()[1]
        return observed

    def searchers_state(self, slot_name_to_find=None):
        """Return ``slot name -> searchers_state()``, for one slot or for all when None."""
        return {
            name: slot_ctrl.searchers_state()
            for name, slot_ctrl in self._slots.items()
            if slot_name_to_find in (name, None)
        }

    def json_view(self):
        """Return ``slot -> json_view()`` of every slot controller."""
        view = {}
        for group_number, slot_ctrl in self._slots.items():
            view[group_number] = slot_ctrl.json_view()
        return view


class TopologyGenerator(object):
    """Keeps the target ``slot -> state`` mapping of a contour and recomputes it.

    States are expected to be hashable and to expose `group_number`; a slot
    without a state is represented by a falsy value in the mapping.
    """

    def __init__(self, contour, target=None):
        """
        :param contour: contour handed to `slots_state.empty_mapping`.
        :param target: initial target mapping; a falsy value is replaced by
            an empty mapping for the contour.
        """
        self._contour = contour
        self._target = target or slots_state.empty_mapping(self._contour)

    @classmethod
    def _find_group(cls, mappings, to_find):
        """Return the first slot (in sorted order) holding group `to_find`, or None."""
        for slot in sorted(mappings):
            state = mappings[slot]
            if state and state.group_number == to_find:
                return slot
        return None

    @classmethod
    def _iter_null(cls, mappings):
        """Yield, in sorted order, the slots whose state is falsy."""
        for slot in sorted(mappings):
            if not mappings[slot]:
                yield slot

    @classmethod
    def _iter_next_slot(cls, target, current_observed):
        """Yield slots that can still receive a state in `target`.

        Starting from every empty observed slot, follow the chain "this slot
        is already planned for group G -> the slot currently observed with
        group G" until a slot with no planned state is reached. A chain that
        leads to a group that is not observed anywhere yields nothing.
        `target` is read lazily, so changes made between iterations are seen.
        """
        for slot in cls._iter_null(current_observed):
            while slot is not None:
                planned = target[slot]
                if not planned:
                    yield slot
                    break
                slot = cls._find_group(current_observed, planned.group_number)

    def _copy_target(self, target, observed, fresh_states):
        """Build a new target mapping.

        1. Fresh states that are already observed in a slot stay in that slot.
        2. Remaining fresh states go to the free slot whose previous target
           had the same group number; repeated until the list of free slots
           stops changing.
        3. What is still unplaced is assigned to free slots in ascending
           group-number order. States with no free slot left are not placed.

        :param target: previous target mapping.
        :param observed: observed mapping (see `update`).
        :param fresh_states: set of fresh states.
        :return: the new target mapping.
        """
        new_target = slots_state.empty_mapping(self._contour)

        for slot, state in observed.items():
            if state in fresh_states:
                new_target[slot] = state

        unplaced = fresh_states.difference(set(new_target.values()))
        unplaced = {state.group_number: state for state in unplaced}

        prev, current = None, list(self._iter_next_slot(new_target, observed))
        while prev != current:
            for slot in current:
                previous = target[slot]
                # A slot with no previous target has no group to match.
                # Without this guard an empty previous target (presumably
                # None entries - confirm against slots_state.empty_mapping)
                # raises AttributeError; such slots are served by step 3.
                if not previous:
                    continue
                group_number = previous.group_number
                if group_number in unplaced:
                    new_target[slot] = unplaced.pop(group_number)
            prev, current = current, list(self._iter_next_slot(new_target, observed))

        for fresh_state in sorted(unplaced.values(), key=lambda x: x.group_number):
            # Take only the first free slot; the generator is re-created for
            # every state because new_target has just changed.
            for slot in self._iter_next_slot(new_target, observed):
                new_target[slot] = fresh_state
                break

        return new_target

    def update(self, current_observed, fresh_states):
        """Recompute the stored target from the observed states and the fresh states."""
        observed = _keep_fresh(current_observed, self._contour)
        self._target = self._copy_target(self.target, observed, fresh_states)

    @property
    def target(self):
        """A copy of the current target mapping."""
        return self._target.copy()


def _newest_groups(tier, shards):
    """Return the newest complete state of every group found in `shards`.

    Shards are bucketed by their (timestamp, group_number) state. A state is
    complete when the number of distinct shards in its bucket equals
    `tier.shards_in_group`. For every group number that has at least one
    complete state, the state with the greatest timestamp is returned.

    :param tier: object exposing `shards_in_group`.
    :param shards: iterable of shards exposing `timestamp` and `group_number`.
    :return: set of `slots_state.SlotState`, at most one per group number.
    """
    shards_by_state = collections.defaultdict(set)
    for shard in shards:
        state = slots_state.SlotState(timestamp=shard.timestamp, group_number=shard.group_number)
        shards_by_state[state].add(shard)

    complete = collections.defaultdict(set)
    # Separate loop names: the original rebound the `shards` parameter here.
    for state, state_shards in shards_by_state.items():
        if len(state_shards) == tier.shards_in_group:
            complete[state.group_number].add(state)

    # Only the newest state per group is needed, so max() replaces a full sort.
    return {max(states, key=lambda state: state.timestamp) for states in complete.values()}


def _keep_fresh(mappings, contour):
    """Keep, for every group, only the slot holding its newest state.

    Slots are visited in sorted order. When several slots hold the same group
    number, the one with the strictly greatest timestamp wins, so on equal
    timestamps the first slot in sorted order is kept. All other slots are
    left as they are in a fresh empty mapping for `contour`.

    :param mappings: mapping ``slot -> state``; falsy states are ignored.
    :param contour: contour handed to `slots_state.empty_mapping`.
    :return: new mapping with at most one slot per group number filled.
    """
    newest_by_group = {}
    for slot in sorted(mappings):
        state = mappings[slot]
        if not state:
            continue
        known = newest_by_group.get(state.group_number)
        if known is None or known[1].timestamp < state.timestamp:
            newest_by_group[state.group_number] = (slot, state)

    result = slots_state.empty_mapping(contour)
    for kept_slot, kept_state in newest_by_group.values():
        result[kept_slot] = kept_state
    return result


# Module-level logger named after this module.
_log = logging.getLogger(__name__)
