# -*- encoding: utf-8 -*-
"""
выполнение задач в несколько параллельных процессов
"""

import os, time, logging
from multiprocessing import Process, cpu_count
from collections import deque, defaultdict
from direct.tools import MyTrace


# Module-level logger used by ParallelPool to report worker starts,
# progress, failures and timeouts.
logger = logging.getLogger(__name__)


class ParallelWorker(Process):
    """Child process that calls ``func(**params)`` inside a tracing span.

    The current ``MyTrace`` is captured when the worker object is constructed,
    i.e. before the child process is started, so that the span opened in the
    child can be linked to it as its parent.
    """

    def __init__(self, func, params):
        """
        :param func: callable executed in the child process; its ``__name__``
            is appended to the span tags.
        :param params: dict of keyword arguments passed to ``func``.
        """
        self.func, self.params = func, params
        # Creation timestamp; ParallelPool compares it with worker_timeout.
        self.start_time = time.time()
        # Snapshot of the caller's trace, used as the parent of the child span.
        self.parent_trace = MyTrace.current()
        super(ParallelWorker, self).__init__()

    def run(self):
        """Run ``func(**params)`` within a new span of the captured parent trace."""
        parent = self.parent_trace
        span = MyTrace(
            service=parent.service,
            method=parent.method,
            trace_id=parent.trace_id,
            parent_id=parent.span_id,
            span_id=MyTrace.generate_id(),
            tags=parent.tags + ',' + self.func.__name__,
        )
        with span:
            self.func(**self.params)


class ParallelPool(object):
    """Run a function over a list of tasks in several parallel child processes.

    Every task is executed in its own ``ParallelWorker`` process and at most
    ``workers_cnt`` such processes are alive at the same time.  A task whose
    process exits with a non-zero exit code (including a worker terminated for
    exceeding ``worker_timeout``) is re-queued; once it has been re-queued
    ``max_tries`` times, its next failure aborts the whole run.
    """

    def __init__(self, workers_cnt=cpu_count(), max_tries=1, timeout=None, worker_timeout=None):
        """
        :param workers_cnt: maximum number of simultaneously running worker
            processes; the default is the CPU count, evaluated once when the
            class is defined.
        :param max_tries: how many times a failed task is re-queued before the
            run is aborted (a task may therefore be attempted
            ``max_tries + 1`` times).
        :param timeout: overall time limit in seconds for ``process``;
            ``None`` means no limit.
        :param worker_timeout: per-worker time limit in seconds after which
            the worker is terminated (and treated as a failure); ``None``
            means no limit.
        :raises ValueError: if ``workers_cnt`` is less than 1, since no task
            could ever be started and ``process`` would loop forever.
        """
        if workers_cnt < 1:
            raise ValueError("workers_cnt must be >= 1, got %r" % (workers_cnt,))
        self.workers_cnt = workers_cnt
        self.max_tries = max_tries
        self.timeout = timeout
        self.worker_timeout = worker_timeout

    @staticmethod
    def _kill_all(processes):
        """Send signal 9 (SIGKILL) to every worker pid in ``processes`` (best effort)."""
        for pid in processes:
            try:
                os.kill(pid, 9)
            except OSError:
                # The worker already exited and was reaped (e.g. by
                # multiprocessing's housekeeping in Process.start()); nothing to kill.
                pass

    def process(self, func, tasks):
        """Run ``func(**task['params'])`` for every task, in parallel.

        Blocks until all tasks have succeeded, polling workers once a second.

        :param func: callable executed in each worker process.
        :param tasks: iterable of dicts with at least the keys ``'params'``
            (keyword arguments for ``func``) and ``'desc'`` (a description
            used in log messages).
        :raises Exception: on global timeout, or when a task keeps failing
            after ``max_tries`` retries; all still-running workers are
            killed first.
        """
        tasks_queue = deque(enumerate(tasks))
        # Count from the queue rather than len(tasks) so that any iterable
        # (e.g. a generator) is accepted.
        total = len(tasks_queue)

        processes = {}             # pid -> (proc, task_id, task)
        tries = defaultdict(int)   # task_id -> number of retries so far
        start_time = time.time()
        successes, fails = 0, 0    # fails counts only the retried failures
        old_state = ''
        while processes or tasks_queue:
            # timeout=None means "no global limit"; comparing a float with
            # None is always true on Python 2 and a TypeError on Python 3.
            if self.timeout is not None and time.time() - start_time > self.timeout:
                logger.critical("global timeout")
                self._kill_all(processes)
                raise Exception("Global timeout")

            # Fill free worker slots from the queue.
            while len(processes) < self.workers_cnt and tasks_queue:
                task_id, task = tasks_queue.popleft()
                proc = ParallelWorker(func, task['params'])
                proc.start()
                processes[proc.pid] = (proc, task_id, task)
                logger.warning("start proc %d for task %s", proc.pid, task['desc'])

            state = "completed: %d/%d, processes: %d, fails: %d" % (successes, total, len(processes), fails)
            if old_state != state:
                logger.warning(state)
                old_state = state

            time.sleep(1)

            # Iterate over a snapshot: finished workers are deleted from the
            # dict inside the loop (mutating a dict during iteration is a
            # RuntimeError on Python 3).
            for pid, (proc, task_id, task) in list(processes.items()):
                task_desc = task['desc']
                if proc.is_alive():
                    if self.worker_timeout is not None and time.time() - proc.start_time > self.worker_timeout:
                        logger.error("task %s(%d): timeout", task_desc, proc.pid)
                        proc.terminate()
                    continue

                del processes[pid]
                if proc.exitcode == 0:
                    logger.warning("task %s: ok", task_desc)
                    successes += 1
                elif tries[task_id] < self.max_tries:
                    logger.warning("task %s: fail: %d", task_desc, proc.exitcode)
                    tries[task_id] += 1
                    fails += 1
                    tasks_queue.append((task_id, task))
                else:
                    logger.critical("task %s: fail: %d, too many fails", task_desc, proc.exitcode)
                    self._kill_all(processes)
                    raise Exception("Too many fails for %s" % task_desc)

