import collections

import infra.callisto.controllers.utils.const_obj as const_obj
import infra.callisto.controllers.utils.funcs as funcs


class Builder(const_obj.ConstObj):
    """
        state of one builder used when planning shard builds and removals
        hash and equality are defined by `agent` only
    """

    def __init__(self, agent, space, building_shards, prepared_shards, tier):
        """
            :param agent: identifier of the builder, the sole basis of hash/equality
            :param space: disk space guaranteed to the builder (stored as space_guarantee)
            :param building_shards: shards being built on the builder
                (combined with set operators elsewhere in this module)
            :param prepared_shards: shards already prepared on the builder
            :param tier: object exposing shard_size
        """
        self.agent = agent
        self.space_guarantee = space
        self.building_shards = building_shards
        self.prepared_shards = prepared_shards
        self.tier = tier
        # hook not defined in this class; presumably ConstObj freezes
        # the instance here -- confirm in const_obj
        self.on_initialize()

    @property
    def freespace(self):
        """ guaranteed space minus usedspace """
        return self.space_guarantee - self.usedspace

    @property
    def usedspace(self):
        """ tier.shard_size for every building and every prepared shard """
        return self.tier.shard_size * (len(self.building_shards) + len(self.prepared_shards))

    def __hash__(self):
        return hash(self.agent)

    def __eq__(self, other):
        # builders are stored next to plain shard names as keys of one mapping
        # (see _shard_builders_relations), so a hash collision may compare
        # a builder with an object that has no `agent`; let python fall back
        # to its default comparison instead of raising AttributeError
        try:
            other_agent = other.agent
        except AttributeError:
            return NotImplemented
        return other_agent == self.agent


class _ManyToMany(object):
    """
        symmetric many-to-many relation between hashable items
        every item is mapped to the set of items related to it
    """

    def __init__(self):
        self._relations = collections.defaultdict(set)

    def add(self, a, b):
        """ relate a and b to each other (both directions are stored) """
        self._relations[a].add(b)
        self._relations[b].add(a)

    def get(self, a):
        """
            :return: live set of items related to a
                for an unknown item an empty set is created and stored
        """
        return self._relations[a]

    def remove_all(self, a):
        """
            forget a: drop its entry and remove it from every item related to it
            unknown items are ignored
        """
        # detach the entry of a before walking it: discarding from the set
        # being iterated (a related to itself) raises RuntimeError
        for b in self._relations.pop(a, ()):
            related = self._relations.get(b)
            if related is not None:
                related.discard(a)


def _map_prev_shard_to_tasks(tasks):
    """
        group tasks by their prev_shard_name
        :param tasks: iterable of objects exposing prev_shard_name
        :return: plain dict: prev_shard_name -> list of tasks (input order is kept)
    """
    grouped = {}
    for item in tasks:
        grouped.setdefault(item.prev_shard_name, []).append(item)
    return grouped


def _shard_builders_relations(shards, builders):
    """
        relate every given shard to the builders having it prepared
        :param shards: set of shard names of interest
        :param builders: iterable of Builder objects
        :return: _ManyToMany holding shard <-> builder relations
    """
    relations = _ManyToMany()
    for candidate in builders:
        prepared_here = candidate.prepared_shards & shards
        for shard_name in prepared_here:
            relations.add(shard_name, candidate)
    return relations


def _count_ready_shards(shards, builders):
    """
        count builders holding each of the given shards, prepared or building
        :param shards: set of shard names to count
        :param builders: iterable of Builder objects
        :return: collections.Counter: shard -> number of builders having it
            shards found on no builder are absent
    """
    replicas = collections.Counter()
    for candidate in builders:
        on_builder = candidate.prepared_shards | candidate.building_shards
        for shard_name in on_builder & shards:
            replicas[shard_name] += 1
    return replicas


def _sort_tasks_key(task):
    """ sort key for tasks: their task_id """
    task_id = task.task_id
    return task_id


def _sort_tasks(tasks_to_build, builders):
    """
        order tasks for assignment:
        tasks whose prev_shard has the fewest replicas on builders go first,
        tasks whose prev_shard is found on no builder go last
        tasks sharing a prev_shard, and the trailing tasks, are ordered by task_id
        :param tasks_to_build: tasks exposing task_id and prev_shard_name
        :param builders: list of Builder objects
        :return: new list with all the tasks
    """

    tasks_by_prev_shard = _map_prev_shard_to_tasks(tasks_to_build)
    replicas = _count_ready_shards(set(tasks_by_prev_shard), builders)

    ordered = []
    for prev_shard in sorted(replicas, key=replicas.get):
        ordered.extend(sorted(tasks_by_prev_shard[prev_shard], key=_sort_tasks_key))

    # everything not placed yet has a prev_shard without replicas
    placed_ids = set(task.task_id for task in ordered)
    leftovers = sorted(
        (task for task in tasks_to_build if task.task_id not in placed_ids),
        key=_sort_tasks_key,
    )
    return ordered + leftovers


def _sort_builders_key(builder):
    """
        sort key for builders, ascending order puts first the builder with
        the fewest distinct shards, then the fewest building shards,
        then the largest freespace
    """
    building = builder.building_shards
    distinct_shards = len(builder.prepared_shards | building)
    return distinct_shards, len(building), -builder.freespace


def _sort_builders_default(builders):
    """ :return: new list of builders in ascending _sort_builders_key order """
    ordered = list(builders)
    ordered.sort(key=_sort_builders_key)
    return ordered


def _sort_builders_for_incremental(shard_builders_relations, task):
    """
        pick builders related to prev_shard of the task
        :param shard_builders_relations: _ManyToMany of shards and builders
        :param task: task exposing prev_shard_name
        :return: list of those builders in ascending _sort_builders_key order
    """

    holders = list(shard_builders_relations.get(task.prev_shard_name))
    holders.sort(key=_sort_builders_key)
    return holders


def _assign_inc_build_tasks_to_builders(tasks_to_build, builders):
    """
        assigns incremental tasks to builders already having their prev_shard prepared
        a builder receives at most one task; a task without such a builder is skipped
        :param tasks_to_build: pass only tasks having prev_shard_name set
        :param builders: list of Builder objects
            pass builders with disk >= task.space_needed
        :return: (mapping: agent -> task, skipped_tasks)
    """
    assigned, unassigned = {}, []

    prev_shards = set(_map_prev_shard_to_tasks(tasks_to_build))
    relations = _shard_builders_relations(prev_shards, builders)

    for task in _sort_tasks(tasks_to_build, builders):
        candidates = _sort_builders_for_incremental(relations, task)
        if not candidates:
            unassigned.append(task)
            continue
        chosen = candidates[0]
        assigned[chosen.agent] = task
        # the builder is taken: hide it from the remaining tasks
        relations.remove_all(chosen)
    return assigned, unassigned


def _assign_build_tasks_to_builders(tasks_to_build, builders):
    """
        assigns any tasks, at most one task per builder
        tasks are taken in the given order, builders in _sort_builders_default order
        (fewest shards first, then fewest building shards, then most free space)
        :param tasks_to_build: any tasks (incremental or not)
        :param builders: list of Builder objects
            pass builders with disk >= task.space_needed + shard_size if incremental
            otherwise with disk >= task.space_needed
        :return: (mapping: agent -> task, skipped_tasks)
    """
    mapping, skipped_tasks = {}, []
    # consume builders from the head of the sorted list: taking them from
    # the tail would hand tasks to the most loaded builders first, unlike
    # the incremental assignment which prefers the head
    available_builders = iter(_sort_builders_default(builders))
    for task in tasks_to_build:
        builder = next(available_builders, None)
        if builder is None:
            # no builders left
            skipped_tasks.append(task)
        else:
            mapping[builder.agent] = task
    return mapping, skipped_tasks


def _split_tasks_to_inc_and_full(tasks_to_build):
    """
        split tasks by presence of prev_shard_name
        :param tasks_to_build: any iterable of tasks, one-shot iterators included
        :return: (incremental tasks, full tasks), input order is kept in both lists
    """
    inc, full = [], []
    # single pass: iterating the input twice would leave `full` empty
    # whenever a generator is passed
    for task in tasks_to_build:
        if task.prev_shard_name:
            inc.append(task)
        else:
            full.append(task)
    return inc, full


def assign_build_tasks_to_builders(tasks_to_build, builders, space_needed_inc, space_needed_full, shard_size):
    """
    Assign tasks to builders in three passes, every builder gets at most one task:
    incremental tasks go to builders having their prev_shard,
    the remaining incremental tasks go to any builder with extra room of shard_size,
    full tasks go to the builders still unused.
    Tasks left without a builder are not reported.

    :param tasks_to_build: any tasks (incremental or not)
    :param builders: list of Builder objects
    :param space_needed_inc: space needed to build a shard incrementally
    :param space_needed_full: space needed to build a shard fully
    :param shard_size: size of one shard
    :return: mapping: agent -> task
    """
    def _with_freespace(candidates, required):
        return [candidate for candidate in candidates if candidate.freespace >= required]

    def _without_agents(candidates, taken):
        return [candidate for candidate in candidates if candidate.agent not in taken]

    incremental, full = _split_tasks_to_inc_and_full(tasks_to_build)

    on_prev_shard, leftover_incremental = _assign_inc_build_tasks_to_builders(
        incremental,
        _with_freespace(builders, space_needed_inc),
    )

    # presumably the extra shard_size is room for prev_shard
    # missing on the builder -- TODO confirm
    unused = _without_agents(builders, on_prev_shard)
    elsewhere, _ = _assign_build_tasks_to_builders(
        leftover_incremental,
        _with_freespace(unused, space_needed_inc + shard_size),
    )

    unused = _without_agents(unused, elsewhere)
    from_scratch, _ = _assign_build_tasks_to_builders(
        full,
        _with_freespace(unused, space_needed_full),
    )

    return funcs.merge_dicts_no_intersection(on_prev_shard, elsewhere, from_scratch)


def _map_shard_to_task(tasks):
    """
        :param tasks: iterable of objects exposing resource_name
        :return: dict: resource_name -> task, the last task wins on duplicates
    """
    return {item.resource_name: item for item in tasks}


def _find_shards_to_remove(tasks_to_keep, builders, keep_replicas_count):
    """
        find shards of kept tasks having more replicas than needed
        :param tasks_to_keep: tasks exposing resource_name
        :param builders: list of Builder objects
        :param keep_replicas_count: number of replicas to leave untouched
        :return: dict: shard -> number of excess replicas (always positive)
    """
    kept_shards = set(task.resource_name for task in tasks_to_keep)
    replicas = _count_ready_shards(kept_shards, builders)

    excess = {}
    for shard, count in replicas.items():
        if count > keep_replicas_count:
            excess[shard] = count - keep_replicas_count
    return excess


def _is_shard_needed(builder, shard, shard_to_task_mapping):
    """
        check whether the builder is building a task based on the shard
        :param builder: object exposing building_shards
        :param shard: shard name compared to prev_shard_name of the tasks
        :param shard_to_task_mapping: dict: shard -> task
            building shards missing in the mapping are ignored
        :return: True if any known building shard has the shard as prev_shard_name
    """
    return any(
        shard_to_task_mapping[in_progress].prev_shard_name == shard
        for in_progress in builder.building_shards
        if in_progress in shard_to_task_mapping
    )


def _find_builders_with_shards_to_remove(builders, tasks_to_build, shards_to_remove):
    """
        distribute excess replicas over builders
        replicas being built are dropped first, prepared ones afterwards
        a prepared shard stays while the same builder is building a task based on it
        builders are visited in reversed _sort_builders_default order
        :param builders: list of Builder objects
        :param tasks_to_build: tasks used to look up prev_shard_name of shards being built
        :param shards_to_remove: dict: shard -> number of replicas to drop
            modified in place: counts are decreased, exhausted shards are deleted
        :return: defaultdict: agent -> list of shards to remove
    """
    task_by_shard = _map_shard_to_task(tasks_to_build)
    removals = collections.defaultdict(list)

    def _schedule(candidate_, shard_name_):
        removals[candidate_.agent].append(shard_name_)
        left = shards_to_remove[shard_name_] - 1
        if left == 0:
            del shards_to_remove[shard_name_]
        else:
            shards_to_remove[shard_name_] = left

    # first pass: replicas still being built
    for candidate in _sort_builders_default(builders)[::-1]:
        for shard_name in candidate.building_shards & set(shards_to_remove):
            _schedule(candidate, shard_name)

    # second pass: prepared replicas not used by a build in progress
    for candidate in _sort_builders_default(builders)[::-1]:
        for shard_name in candidate.prepared_shards & set(shards_to_remove):
            if _is_shard_needed(candidate, shard_name, task_by_shard):
                continue
            _schedule(candidate, shard_name)
    return removals


def find_tasks_to_remove(tasks_to_build, tasks_to_keep, builders, keep_replicas_count=1):
    """
    Find excess replicas of kept shards which may be removed from builders.
    A prepared shard is left in place while its builder is building a task based on it.

    :param tasks_to_build: all tasks with mode == Build
    :param tasks_to_keep: all tasks with mode == Keep
    :param builders: list of Builder objects
    :param keep_replicas_count: replicas of a shard above this number are candidates for removal
    :return: mapping agent -> [tasks to remove], tasks are taken from tasks_to_keep
    """

    excess = _find_shards_to_remove(tasks_to_keep, builders, keep_replicas_count)
    shards_by_agent = _find_builders_with_shards_to_remove(builders, tasks_to_build, excess)
    task_by_shard = _map_shard_to_task(tasks_to_keep)

    tasks_by_agent = {}
    for agent, shards in shards_by_agent.items():
        tasks_by_agent[agent] = [task_by_shard[shard] for shard in shards]
    return tasks_by_agent
