#!/usr/bin/env python

import os
import select
import traceback

from sandbox.sandboxsdk.errors import SandboxTaskFailureError


class ProcessPool(object):
    """Run callables in forked child processes, at most ``nproc`` at a time.

    Every task runs in a child created with ``os.fork``.  The child sends
    ``str(result)`` (or a formatted traceback if the callable raised) back to
    the parent through a pipe, and the parent stores that text under an
    integer task id.

    Attributes:
        nproc: maximum number of concurrently running children (at least 1).
        kids: maps the pid of each running child to ``(reader, task_id)``.
        counter: the task id most recently handed out by ``run``.
        data: maps finished task ids to the text captured from the child.
        error: description of the most recent failed child, or None.  The
            pool never clears it: once set, ``run`` refuses new tasks and
            ``map`` raises.
    """

    def __init__(self, nproc):
        """Create a pool allowing ``nproc`` children (values below 1 mean 1)."""
        self.nproc = max(int(nproc), 1)
        self.kids = {}
        self.counter = 1
        self.data = {}
        self.error = None

    def have_kids(self):
        """Return True while at least one child has not been collected."""
        return bool(self.kids)

    def num_kids(self):
        """Return the number of children that have not been collected."""
        return len(self.kids)

    def wait_for_one(self):
        """Block until one pool child finishes, then record its outcome.

        The child's output is stored in ``self.data`` under its task id.  If
        the child exited with a non-zero code or was terminated by a signal,
        ``self.error`` is set to a description of the failure.  Returns
        immediately when the pool has no running children.
        """
        if not self.kids:
            return

        # Wait on the pipes rather than on process exit: a child whose output
        # exceeds the pipe buffer cannot exit until the parent starts reading,
        # so waiting for the exit first would deadlock.
        pid_by_fd = dict(
            (reader.fileno(), pid) for pid, (reader, _) in self.kids.items()
        )
        ready, _, _ = select.select(list(pid_by_fd), [], [])
        pid = pid_by_fd[ready[0]]
        (reader, c) = self.kids.pop(pid)

        try:
            output = reader.read()  # returns once the child closed its end
        finally:
            reader.close()
            # Reap exactly this child; a bare os.wait() could also collect
            # children that belong to unrelated code in the same process.
            (_, status) = os.waitpid(pid, 0)

        if os.WIFSIGNALED(status):
            exit_code, term_signal = 0, os.WTERMSIG(status)
        else:
            exit_code, term_signal = os.WEXITSTATUS(status), 0

        self.data[c] = output
        # A child killed by a signal carries no exit code, so the signal has
        # to be checked too or a killed worker would look successful.
        if exit_code != 0 or term_signal != 0:
            self.error = "Process with pid {} return status code {}, killed by {}. Captured error: {}".format(pid, exit_code, term_signal, output)

    def get_data(self, c):
        """Return the captured output of task ``c``, waiting if it still runs.

        Raises:
            ValueError: if ``c`` is neither a finished nor a running task.
        """
        if c not in self.data:
            if not any(task_id == c for (_, task_id) in self.kids.values()):
                raise ValueError("no such process")
            while c not in self.data:
                self.wait_for_one()
        return self.data[c]

    def wait_for_slot(self):
        """Block until fewer than ``nproc`` children are running."""
        while self.num_kids() >= self.nproc:
            self.wait_for_one()

    def run(self, code):
        """Start ``code()`` in a forked child and return its task id.

        Blocks until a slot is free.  Returns None without starting anything
        if a previous child has already failed (``self.error`` is set).
        """
        self.wait_for_slot()
        if self.error:
            return
        (r, w) = os.pipe()
        r = os.fdopen(r, 'r')
        w = os.fdopen(w, 'w')
        try:
            pid = os.fork()
        except OSError:
            # No child was created, so nobody else will ever close the pipe.
            r.close()
            w.close()
            raise

        if pid == 0:
            # Child.  Whatever happens it must leave through os._exit, never
            # by returning or raising into the parent's call stack.
            out_code = 1
            try:
                r.close()
                try:
                    out = str(code())
                    out_code = 0
                except BaseException:
                    # Deliberately broad: even SystemExit raised by the task
                    # is reported as a failure instead of unwinding further.
                    out = "Exception raised: " + traceback.format_exc()
                w.write(out)
                w.close()
            except BaseException:
                out_code = 1
            finally:
                os._exit(out_code)

        # Parent.  Close our copy of the write end so that the reader sees
        # EOF as soon as the child is done with it.
        w.close()
        self.counter += 1
        self.kids[pid] = (r, self.counter)
        return self.counter

    def finish(self):
        """Block until every running child has been collected."""
        while self.have_kids():
            self.wait_for_one()

    def map(self, code, items):
        """Run ``code(item)`` for every item and return the outputs in order.

        Each output is the text captured from the child, i.e. ``str`` of the
        value returned by ``code``.

        Raises:
            SandboxTaskFailureError: if any child failed.
        """
        # i=i binds the current item explicitly instead of relying on the
        # closure being evaluated before the loop variable changes.
        ids = [self.run(lambda i=i: code(i)) for i in items]
        self.finish()
        if self.error:
            raise SandboxTaskFailureError("Error executing some of the subprocesses: "+self.error)
        return [self.get_data(c) for c in ids]


if __name__ == "__main__":
    # Manual timing demo for ProcessPool; prints the wall-clock time of each
    # map() call.  This file targets Python 2, hence xrange.
    import time

    pool = ProcessPool(3)

    # Two identical rounds: three 2-second sleeps spread over three workers.
    for _ in xrange(2):
        started = time.time()
        pool.map(lambda i: time.sleep(2), xrange(3))
        print(time.time() - started)

    # One 4-second task followed by seven 1-second tasks; each task returns
    # its own sleep duration, so the collected outputs are printed as well.
    started = time.time()
    results = pool.map(lambda i: time.sleep(i) or i, [4] + [1] * 7)
    print(results)
    print(time.time() - started)
