import os

import cv2

import torch
import torchvision
import torchvision.transforms as transforms

from bran.train.nets import model_from_config

class PyTorchResNet:
    '''ResNet-18 image classifier built from a fine-tuned state dict.

    The stock torchvision ResNet-18 has its final fully-connected layer
    replaced by one with ``n_classes`` outputs before the weights are
    loaded, so the checkpoint must come from a ResNet-18 with that head.
    '''

    # Directory the ``weights`` filename is resolved against unless the
    # config supplies ``model_dir``.
    DEFAULT_MODEL_DIR = '/var/bran/gamesense/detectors'

    def __init__(self, config):
        '''
            config: dict with keys
                'labels'    -- class-index -> label lookup (list, or dict
                               keyed by int)
                'n_classes' -- number of outputs of the replaced fc layer
                'weights'   -- state-dict filename, joined onto the model dir
                'gpu'       -- optional; move model/inputs to CUDA when
                               truthy (default False)
                'model_dir' -- optional; overrides DEFAULT_MODEL_DIR
        '''
        self.labels = config['labels']
        self.n_classes = config['n_classes']
        self.weights = config['weights']
        self.use_gpu = config.get('gpu', False)

        model_dir = config.get('model_dir', self.DEFAULT_MODEL_DIR)
        model_file = os.path.join(model_dir, self.weights)

        self.model = torchvision.models.resnet18()
        num_features = self.model.fc.in_features
        self.model.fc = torch.nn.Linear(num_features, self.n_classes)
        # Deserialize onto the CPU first so a checkpoint saved on a GPU also
        # loads on CPU-only hosts; move to CUDA afterwards if requested.
        self.model.load_state_dict(torch.load(model_file, map_location='cpu'))
        if self.use_gpu:
            self.model = self.model.cuda()
        # Inference mode for layers such as batch-norm (use running stats).
        self.model.eval()

        self.transforms = transforms.Compose([
            # cv2.resize takes (width, height); the target is square anyway.
            lambda img: cv2.resize(img, (224, 224)),
            # HxWxC ndarray -> CxHxW float tensor; uint8 input is scaled to
            # [0, 1] (float input is not rescaled).
            transforms.ToTensor(),
            # Standard ImageNet per-channel mean/std, in RGB order.
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]),
        ])

    def detect(self, image):
        '''
            Classify a single image.

            image: numpy array [h, w, c] in RGB channel order (callers
                   reading with OpenCV must convert from BGR first);
                   presumably uint8 -- see the ToTensor note in __init__
            return: (label, prob) -- label of the highest-scoring class and
                    its softmax probability
        '''
        tensor = self.transforms(image).unsqueeze(0)  # add batch dimension
        if self.use_gpu:
            tensor = tensor.cuda()

        # Inference only: skip building an autograd graph on every call.
        with torch.no_grad():
            output = self.model(tensor)[0]
            _, index = torch.max(output, dim=0)
            probs = torch.softmax(output, dim=0).cpu().numpy()

        # Plain int, so ``labels`` may be a list or an int-keyed dict
        # (a 0-d tensor hashes by identity and never matches a dict key).
        index = int(index)
        return self.labels[index], probs[index]
