from flask import current_app as app

import multiprocessing
import logging

# Module-level logger shared by the Consumer workers and process_tasks.
logger = logging.getLogger(__name__)

# Special value for ``process_tasks(n_jobs=...)``: start one consumer per CPU
# as reported by ``multiprocessing.cpu_count()``.
USE_ALL_CORES = -1


class Consumer(multiprocessing.Process):
    """Worker process that applies ``func`` to every task pulled from a queue.

    Tasks are read from ``task_queue`` until a ``None`` poison pill arrives.
    Each task is passed to :meth:`call_func`. If ``result_queue`` is not
    ``None``, the return value is ``put`` onto it. An ordinary exception raised
    while processing a task is logged, and the worker moves on to the next task.
    ``task_done()`` is called exactly once per item taken from ``task_queue``
    (including the poison pill), so a producer blocked in
    ``task_queue.join()`` is always released.
    """

    def __init__(self, func, task_queue, result_queue=None):
        """
        :param func: callable applied to each task (see :meth:`call_func`).
        :param task_queue: queue offering ``get()`` and ``task_done()``,
            e.g. a ``multiprocessing.JoinableQueue``.
        :param result_queue: optional queue offering ``put()``; results are
            discarded when it is ``None``.
        """
        multiprocessing.Process.__init__(self)
        self.func = func
        self.task_queue = task_queue
        self.result_queue = result_queue

    def call_func(self, task):
        """Apply ``func`` to a single task; subclasses may override."""
        return self.func(task)

    def run(self):
        """Process tasks until the ``None`` poison pill is received."""
        while True:
            next_task = self.task_queue.get()
            if next_task is None:
                # Poison pill means shutdown
                logger.info("Exiting process %s", self.name)
                self.task_queue.task_done()
                break
            logger.info("Process %s got task %s", self.name, next_task)

            try:
                result = self.call_func(next_task)
                # Compare against None explicitly. A queue-like object that
                # defines __len__ is falsy while empty and would otherwise
                # never receive a single result.
                if self.result_queue is not None:
                    self.result_queue.put(result)
            except Exception:
                # Best effort: one failing task must not kill the worker.
                # KeyboardInterrupt / SystemExit are deliberately not caught.
                logger.exception('Unexpected error processing task %s', next_task)
            finally:
                # Always acknowledge the task, even if a BaseException escapes,
                # so the producer's task_queue.join() cannot hang on it.
                self.task_queue.task_done()


class CountryCategoryConsumer(Consumer):
    """Consumer for legacy callables that take ``country``/``category`` kwargs.

    NOTE: this is mostly legacy and will be removed eventually. Some old code
    built around `process_countries_categories` expects to be called as
    ``func(country=..., category=...)`` rather than ``func(task)``. Each task
    is therefore unpacked into those two keyword arguments here.
    """

    def call_func(self, task):
        """Call ``func`` with the task's ``country`` and ``category`` as kwargs."""
        legacy_kwargs = dict(
            country=task.country,
            category=task.category,
        )
        return self.func(**legacy_kwargs)


class Task(object):
    """A (country, category) unit of work handed to a consumer process."""

    def __init__(self, country, category):
        self.country = country
        self.category = category

    def __str__(self):
        return 'Task {}, {}'.format(self.country, self.category)

    def __repr__(self):
        # Unambiguous form for debuggers and containers of tasks (e.g. a
        # logged list of tasks); __str__ remains the human-readable form.
        return '{}({!r}, {!r})'.format(
            type(self).__name__, self.country, self.category)


def process_tasks(func, task_generator, result_queue=None, n_jobs=1, consumer_class=Consumer):
    """Run ``func`` over every task from ``task_generator`` in worker processes.

    Starts ``n_jobs`` instances of ``consumer_class`` that share a bounded
    ``multiprocessing.JoinableQueue`` and feeds them every task. It then sends
    one ``None`` poison pill per worker and blocks until every queued item has
    been acknowledged with ``task_done()``.

    :param func: callable handed to each consumer.
    :param task_generator: iterable of tasks. It is consumed lazily: the queue
        holds at most ``n_jobs`` pending items, and ``put`` blocks while full.
    :param result_queue: optional queue passed to the consumers; results are
        collected there when it is not ``None``. Read it after this returns.
    :param n_jobs: number of worker processes, or ``USE_ALL_CORES`` for
        ``multiprocessing.cpu_count()``.
    :param consumer_class: class instantiated as
        ``consumer_class(func, task_queue, result_queue)`` for each worker.
    :raises ValueError: if ``n_jobs`` does not resolve to a positive integer.
    """
    if n_jobs == USE_ALL_CORES:
        n_jobs = multiprocessing.cpu_count()
    if n_jobs < 1:
        # With zero consumers nobody would ever call task_done(), so the
        # task_queue.join() below would block forever.
        raise ValueError(
            'n_jobs must be a positive integer or USE_ALL_CORES, got %r' % (n_jobs,))

    task_queue = multiprocessing.JoinableQueue(maxsize=n_jobs)
    consumers = [
        consumer_class(func, task_queue, result_queue)
        for _ in range(n_jobs)
    ]

    # start consumers
    for consumer in consumers:
        consumer.start()

    try:
        # enqueue jobs
        for task in task_generator:
            task_queue.put(task, block=True)
    finally:
        # Tell workers to finish. Do this even when task_generator raised;
        # otherwise the workers would wait in get() forever and keep the
        # parent process alive.
        for _ in range(n_jobs):
            task_queue.put(None)

    task_queue.join()

    logger.info('All workers are done')
    # Results will be collected in result_queue if it's not None.
    # NOTE: the consumers are intentionally not join()ed here. Per the
    # multiprocessing programming guidelines, a process that has put items on
    # a multiprocessing queue may not exit until those items are read. If
    # result_queue is such a queue, joining before the caller drains it could
    # deadlock.
