import collections
import datetime
import logging

from ...build import task as _task_module
from ...utils import entities
from ...utils import disjoint_sets


class Mode(object):
    """String constants naming a TaskTarget's mode.

    ``Build`` and ``Keep`` also serve as the class names of the agent sets in
    ``TaskTarget`` and as keys in its JSON form.
    """

    Build, Keep, Remove = 'build', 'keep', 'remove'


class Target(object):
    """Desired cluster state: a collection of TaskTarget objects keyed by task id.

    Wraps the ``tasks`` dict passed to the constructor (stored by reference,
    not copied) and exposes dict-like read access plus mutators that change
    per-task modes and agent assignments.
    """

    def __init__(self, tasks):
        # dict: task_id -> TaskTarget
        self._tasks = tasks

    def build(self, task):
        """Ensure ``task`` has a TaskTarget (creating one if needed) in Build mode."""
        if task.task_id not in self._tasks:
            self._tasks[task.task_id] = TaskTarget(task)
        self._tasks[task.task_id].mode = Mode.Build

    def keep(self, task_id):
        """Switch an existing task to Keep mode. Raises KeyError if unknown."""
        self._tasks[task_id].mode = Mode.Keep

    def remove(self, task_id):
        """Forget ``task_id`` entirely. Raises KeyError if unknown."""
        del self._tasks[task_id]
        _log.info('removed %s', task_id)

    def shards_on_agents(self, agents, building=False, prepared=False):
        """Map each agent in ``agents`` to the resource names it holds.

        Args:
            agents: the agents to report on. It is intersected with each
                target's agent sets via ``&``, so it must be set-like.
            building: include resources the agent is building.
            prepared: include resources the agent is keeping.

        Returns:
            dict agent -> set of ``task.resource_name``. Every agent in
            ``agents`` is present, possibly with an empty set.
        """
        result = {agent: set() for agent in agents}
        for target in self.task_targets():
            if building:
                for agent in target.building & agents:
                    result[agent].add(target.task.resource_name)
            if prepared:
                for agent in target.keeping & agents:
                    result[agent].add(target.task.resource_name)
        return result

    def _active_tasks_of_agents(self):
        """Return a defaultdict(set) mapping agent -> ids of tasks it is building."""
        result = collections.defaultdict(set)
        # items() rather than iteritems(): same result on Python 2, and also
        # valid on Python 3.
        for task_id, target in self._tasks.items():
            for agent in target.building:
                result[agent].add(task_id)
        return result

    def assign_tasks(self, mapping):
        """Make each agent build its assigned task and stop building any other.

        Args:
            mapping: iterable of ``(agent, task_id)`` pairs. If an agent
                appears more than once, its last pair wins.
        """
        active_tasks_of_agents = self._active_tasks_of_agents()
        for agent, task_id in mapping:
            for active_task_id in active_tasks_of_agents[agent]:
                # Leave the agent on the task it is being (re)assigned to;
                # removing and re-adding it would only churn logs and reset
                # that task's last-modified time.
                if active_task_id != task_id:
                    self._tasks[active_task_id].ensure_remove(agent)
            self._tasks[task_id].ensure_build(agent)
            # Keep the index current so a repeated agent in `mapping` is
            # taken off this task rather than left building two tasks.
            active_tasks_of_agents[agent] = {task_id}

    @classmethod
    def from_json(cls, json_):
        """Build an instance from the dict produced by :meth:`json`."""
        tasks = {}
        for task in json_['tasks']:
            task_target = TaskTarget.from_json(task)
            tasks[task_target.task.task_id] = task_target
        # Use cls (not a hard-coded class) so subclasses round-trip correctly.
        return cls(tasks)

    def json(self):
        """Serialize to ``{'tasks': [TaskTarget.json(), ...]}``."""
        tasks = [target.json() for target in self._tasks.values()]
        return {'tasks': tasks}

    def iteritems(self):
        """Yield ``(task_id, TaskTarget)`` pairs."""
        for task_id, target in self._tasks.items():
            yield task_id, target

    def task_ids(self):
        """Return the ids of all tracked tasks."""
        return self._tasks.keys()

    def task_targets(self):
        """Return all tracked TaskTarget objects."""
        return self._tasks.values()

    def __getitem__(self, task_id):
        return self._tasks[task_id]

    def __contains__(self, task_id):
        return task_id in self._tasks


class TaskTarget(object):
    """Desired state of a single task: its Mode plus the agents building/keeping it.

    Agent membership lives in a ``disjoint_sets.DisjointSets`` whose classes
    are ``Mode.Build`` and ``Mode.Keep`` (presumably an agent is in at most one
    class at a time -- confirm against ``disjoint_sets``).
    """

    def __init__(self, task, mode=Mode.Remove):
        self._task = task
        self._mode = mode
        self._state = disjoint_sets.DisjointSets(task, [Mode.Build, Mode.Keep])
        # datetime.min acts as "never modified": since_last_modified is then huge.
        self._last_modified = datetime.datetime.min

    @property
    def task(self):
        return self._task

    @property
    def since_last_modified(self):
        """timedelta since an agent was last added to Build/Keep (naive local time)."""
        return datetime.datetime.now() - self._last_modified

    @property
    def all_agents(self):
        """All agents in any class of the underlying DisjointSets."""
        return self._state.all

    @property
    def building(self):
        """Agents in the Mode.Build class."""
        return self._state.get_class(Mode.Build)

    @property
    def keeping(self):
        """Agents in the Mode.Keep class."""
        return self._state.get_class(Mode.Keep)

    def ensure_keep(self, agent):
        """Put ``agent`` in Keep; bump last-modified if the state changed."""
        if self._state.ensure(agent, Mode.Keep, log_level=logging.INFO):
            self._last_modified = datetime.datetime.now()

    def ensure_build(self, agent):
        """Put ``agent`` in Build; bump last-modified if the state changed."""
        if self._state.ensure(agent, Mode.Build, log_level=logging.INFO):
            self._last_modified = datetime.datetime.now()

    def ensure_remove(self, agent):
        """Take ``agent`` out of every class. Never touches self._last_modified here."""
        # NOTE(review): set_last_modified is forwarded to DisjointSets.ensure_none;
        # its meaning is defined there, not in this file.
        self._state.ensure_none(agent, log_level=logging.INFO, set_last_modified=False)

    @property
    def mode(self):
        return self._mode

    @mode.setter
    def mode(self, value):
        """Set the mode (logging real changes); Remove also detaches every agent."""
        if value != self._mode:
            self._mode, old_mode = value, self._mode
            _log.info('%s mode: %s -> %s', self._task, old_mode, value)
        if value == Mode.Remove:
            # Snapshot the agents first: ensure_remove mutates the state that
            # all_agents reads, and if that is a live set, iterating it while
            # it shrinks raises RuntimeError.
            for agent in list(self.all_agents):
                self.ensure_remove(agent)

    def json(self):
        """Serialize to ``{'target': {...}, 'task': task.json()}``."""
        return {
            'target': {
                Mode.Build: entities.serialize_agents(self.building),
                Mode.Keep: entities.serialize_agents(self.keeping),
                'mode': self._mode,
                # NOTE(review): a raw datetime, not a JSON-native type; this
                # relies on whatever encoder/decoder the caller uses.
                'last_modified': self._last_modified,
            },
            'task': self.task.json()
        }

    @classmethod
    def from_json(cls, data):
        """Rebuild a TaskTarget from the dict produced by :meth:`json`.

        Agent entries are dicts with ``'host'`` and ``'port'`` keys. A missing
        ``'last_modified'`` falls back to ``datetime.min``.
        """
        new_target = cls(_task_module.Task.from_json(data['task']))
        data = data['target']
        new_target._mode = data['mode']
        # Build agents first, then Keep agents -- same order as before.
        for mode in (Mode.Build, Mode.Keep):
            for agent in data[mode]:
                new_target._state.ensure(entities.Agent(agent['host'], agent['port']), mode)
        new_target._last_modified = data.get('last_modified', datetime.datetime.min)
        return new_target


_log = logging.getLogger(__name__)
