import collections

from ...utils import entities


class ShardObserved(object):
    """Agents of one shard grouped by the state they were observed in.

    ``building``, ``prepared``, ``idle``, ``failed`` and ``dead`` are sets
    of agents. ``total_fails_count`` is a running counter that starts at 0.
    """

    def __init__(self):
        # One independent set per observed agent state.
        (self.building, self.prepared, self.idle,
         self.failed, self.dead) = (set(), set(), set(), set(), set())
        self.total_fails_count = 0

    @property
    def all_agents(self):
        """Return a new set holding every agent from all five buckets."""
        return set().union(self.building, self.prepared, self.idle, self.dead, self.failed)

    @property
    def status(self):
        """Return the aggregate shard status.

        The first non-empty bucket wins, in the order prepared, building,
        idle (reported as ``NotStarted``), failed. When all of them are
        empty, the result is ``ShardStatus.Null``. Dead agents never
        influence the status.
        """
        precedence = (
            (self.prepared, entities.ShardStatus.Prepared),
            (self.building, entities.ShardStatus.Building),
            (self.idle, entities.ShardStatus.NotStarted),
            (self.failed, entities.ShardStatus.Failed),
        )
        for bucket, shard_status in precedence:
            if bucket:
                return shard_status
        return entities.ShardStatus.Null

    def json(self):
        """Return a dict describing this shard.

        Each bucket is serialized with ``entities.serialize_agents``. The
        dict also carries the aggregate ``status`` and the ``fails_count``.
        """
        buckets = (
            ('building', self.building),
            ('prepared', self.prepared),
            ('idle', self.idle),
            ('failed', self.failed),
            ('dead', self.dead),
        )
        payload = {name: entities.serialize_agents(agents) for name, agents in buckets}
        payload['status'] = self.status
        payload['fails_count'] = self.total_fails_count
        return payload


def calculate_observed_v2(targets, reports):
    """Build the observed state of every task from the agents' reports.

    :param targets: mapping of task id -> target. Each target exposes
        ``task.resource_name`` and an iterable ``all_agents``.
    :param reports: mapping of agent -> report, queried with ``.get``. An
        agent with no entry is passed a ``None`` report.
    :returns: ``collections.defaultdict(ShardObserved)`` keyed by task id.
        Looking up an unknown task id yields a fresh, empty
        ``ShardObserved``.
    """
    tasks_observed = collections.defaultdict(ShardObserved)

    # ``.items()`` rather than the Python-2-only ``.iteritems()``, so this
    # runs on both Python 2 and Python 3.
    for task_id, target in targets.items():
        resource_name = target.task.resource_name
        observed = ShardObserved()

        for agent in target.all_agents:
            _set_observed_state(observed, resource_name, agent, reports.get(agent))
        tasks_observed[task_id] = observed

    return tasks_observed


def _set_observed_state(observed, resource_name, agent, report):
    if report:
        if resource_name in report.prepared:
            observed.prepared.add(agent)
        elif resource_name in report.building:
            observed.building.add(agent)
        elif resource_name in report.idle:
            observed.idle.add(agent)
        elif resource_name in report.failed:
            observed.failed.add(agent)
        else:
            observed.idle.add(agent)
        _set_stats(observed, resource_name, report)
    else:
        observed.dead.add(agent)


def _set_stats(observed, resource_name, report):
    task_stats = report.all_tasks.get(resource_name)
    if task_stats:
        observed.total_fails_count += task_stats.fails_cnt
