import collections
import itertools
import logging


# Logger for this module; messages are emitted under the module's own name.
logger = logging.getLogger(__name__)


# Immutable pair of a tracker-assigned id (``key``) and the caller's payload (``task``).
RegisteredTask = collections.namedtuple("RegisteredTask", ("key", "task"))


class InflightTracker(object):
    """Track batches of tasks and fire a callback once a whole batch is done.

    Each call to ``register`` creates a batch. Every item in the batch is
    wrapped in a ``RegisteredTask`` that carries a unique key. Once
    ``complete`` has been called for every task of a batch, that batch's
    callback is invoked and all bookkeeping for the batch is discarded.
    """

    def __init__(self):
        # task key -> id of the batch the task belongs to
        self.task_to_batch = {}
        # batch id -> zero-argument callable to run when the batch drains
        self.on_complete_callbacks = {}
        # batch id -> number of tasks in that batch not yet completed
        self.inflight_tasks = {}
        # Independent id sources; ids are never reused within a tracker.
        self.task_ids = itertools.count()
        self.batch_ids = itertools.count()

    def _get_new_task_id(self):
        """Return a task key that has not been issued before."""
        return next(self.task_ids)

    def _get_new_batch_id(self):
        """Return a batch id that has not been issued before."""
        return next(self.batch_ids)

    def register(self, collection, on_complete_callback):
        """Register a batch of tasks.

        :param collection: iterable of task payloads. It is iterated exactly
            once, so generators and other one-shot iterables are accepted.
        :param on_complete_callback: zero-argument callable invoked once every
            registered task has been passed to ``complete``. If ``collection``
            is empty, it is invoked immediately, before this method returns.
        :returns: list of ``RegisteredTask``, one per payload, in iteration
            order. The list is empty for an empty ``collection``.
        """
        # Materialize once: the emptiness check and the task count must both
        # come from the same snapshot. This lets one-shot iterables work,
        # where truthiness/len() would misbehave or fail.
        payloads = list(collection)
        if not payloads:
            on_complete_callback()
            return []

        batch_id = self._get_new_batch_id()
        self.on_complete_callbacks[batch_id] = on_complete_callback
        self.inflight_tasks[batch_id] = len(payloads)

        registered_tasks = [
            RegisteredTask(self._get_new_task_id(), task) for task in payloads
        ]
        for registered_task in registered_tasks:
            self.task_to_batch[registered_task.key] = batch_id
        return registered_tasks

    def complete(self, task):
        """Mark a registered task as finished.

        When this was the last outstanding task of its batch, the batch's
        callback is invoked and the batch is forgotten.

        :param task: a ``RegisteredTask`` previously returned by ``register``.
        :raises KeyError: if ``task.key`` is unknown, e.g. the task was never
            registered or was already completed.
        """
        # pop() looks up and forgets the task in one step, so completing the
        # same task twice raises KeyError.
        batch_id = self.task_to_batch.pop(task.key)

        self.inflight_tasks[batch_id] -= 1
        if self.inflight_tasks[batch_id]:
            return

        logger.debug("Batch completed: %s", batch_id)
        del self.inflight_tasks[batch_id]
        # Remove the callback before calling it. Finished batches then do not
        # accumulate (and keep their closures alive) for the tracker's
        # lifetime, and a raising callback leaves no stale entry behind.
        callback = self.on_complete_callbacks.pop(batch_id)
        callback()
