import os
import signal

import cv2
import torch
import torch.multiprocessing as mp
import torchvision
import torchvision.transforms as transforms

from bran.train.nets import model_from_config
from bran.models.pytorch_resnet import PyTorchResNet
from bran.models.pytorch_svhn import PyTorchSVHN

# Per-process detector registry, keyed by detector name. Every process gets
# its own copy of this global: init_models() fills it in (it runs as the pool
# initializer, i.e. inside each worker), and detect() reads it in that worker.
detectors: "dict | None" = None


def _binary_resnet(weights):
    """Build a two-class ("no"/"yes") PyTorchResNet detector.

    Args:
        weights: path of the weights file handed to PyTorchResNet.

    Returns:
        A new PyTorchResNet instance.
    """
    # A fresh config dict on every call, so no two detectors share one object.
    return PyTorchResNet({
        "n_classes": 2,
        "labels": ["no", "yes"],
        "weights": weights,
    })


def init_models():
    """Load every detector into this process's ``detectors`` registry.

    This is passed as the ``initializer`` of the module-level pool, so it runs
    once in each worker process. Every worker therefore holds its own copy of
    all four models.
    """
    global detectors
    # Workers ignore SIGINT, so a Ctrl-C does not interrupt them mid-task;
    # interrupt handling is left to the parent process.
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    # NOTE(review): weight paths are relative, so they resolve against the
    # worker's current working directory. Confirm that is intended.
    detectors = {
        'kill': _binary_resnet('kill/weights.pt'),
        'victory': _binary_resnet('victory/weights.pt'),
        'apex-kill': PyTorchSVHN({'weights': 'apex-kill/weights.pt'}),
        'apex-victory': _binary_resnet('apex-victory/weights.pt'),
    }


def detect(name, image):
    """Run the detector registered under *name* on *image* in this process.

    Meant to execute inside a pool worker, where init_models() has populated
    the per-process ``detectors`` registry.

    Args:
        name: registry key, e.g. 'kill', 'victory', 'apex-kill' or
            'apex-victory' as registered by init_models().
        image: passed unchanged to the selected detector's ``detect`` method.

    Returns:
        Whatever the selected detector's ``detect`` method returns.

    Raises:
        RuntimeError: if the registry was never initialised in this process.
        KeyError: if *name* is not a registered detector.
    """
    # Only reading the module global here, so no ``global`` statement is needed.
    if detectors is None:
        raise RuntimeError(
            "detectors are not initialised in this process; detect() must "
            "run in a pool worker started with init_models() as initializer"
        )
    return detectors[name].detect(image)


# Shared worker pool: one process per CPU, each running init_models() once at
# start-up so it owns a private set of detectors.
# NOTE(review): the pool is created as an import side effect. Under the
# "spawn" start method, child processes re-import this module and would try
# to start another pool while bootstrapping. Confirm only "fork" is used, or
# move creation behind a lazy accessor.
pool = mp.Pool(processes=mp.cpu_count(), initializer=init_models)


class PyTorchPool:
    """Parent-side handle to one named detector that lives in the worker pool.

    The handle stores only the detector's name. The model itself is loaded in
    every pool worker by init_models().
    """

    def __init__(self, name):
        # Registry key looked up by the module-level detect() in the worker.
        self.name = name

    def detect(self, image):
        """Run this detector on *image* in a pool worker and return the result.

        ``Pool.apply`` blocks until the worker has finished. Any exception
        raised in the worker is re-raised here.
        """
        call_args = (self.name, image)
        # The bare name ``detect`` resolves to the module-level function,
        # not to this method, so the worker runs the registry lookup.
        return pool.apply(detect, call_args)
