import os

import cv2

import torch
import torchvision
import torchvision.transforms as transforms

from bran.train.nets.svhn import SVHNNet


def get_max(digits_net):
    """Select the most likely class for each of the first three digit heads.

    Args:
        digits_net: indexable collection whose elements 0, 1 and 2 are logit
            tensors. Softmax and max are both taken along dim 1 of each one
            (presumably each head is (batch, num_classes) — confirm against
            SVHNNet's forward output).

    Returns:
        Tuple ``(probs, digits)``. ``probs`` holds the maximum softmax
        probability of each head and ``digits`` the matching argmax indices.
        Both are stacked along a new dim 1, one column per head.
    """
    per_head = [
        torch.max(torch.softmax(digits_net[head], dim=1), dim=1)
        for head in range(3)
    ]
    probs = torch.stack([best[0] for best in per_head], dim=1)
    digits = torch.stack([best[1] for best in per_head], dim=1)
    return probs, digits


def get_num(digits):
    """Decode rows of predicted digit classes into integers.

    Within each row, entries equal to 10 are skipped. Presumably 10 is the
    "no digit" class; verify against SVHNNet's label encoding. The remaining
    digits are concatenated in order and parsed as a base-10 integer. A row
    with no kept digits decodes to -1.

    Args:
        digits: iterable of rows. Each row is an iterable of scalar tensors
            (anything exposing ``.item()``) holding class indices.

    Returns:
        list of int, one entry per row.
    """
    def _decode(row):
        text = "".join(str(d.item()) for d in row if d != 10)
        if not text:
            return -1
        return int(text)

    return [_decode(row) for row in digits]


class PyTorchSVHN:
    """Digit-sequence recognizer backed by an ``SVHNNet`` model.

    The network is built and its weights are loaded on CPU once, in
    ``__init__``. The model is then switched to eval mode.
    """

    # Directory that ``config['weights']`` is resolved against. It is a class
    # attribute so subclasses/tests can point elsewhere without changing the
    # constructor signature.
    DETECTORS_DIR = '/var/bran/gamesense/detectors'

    def __init__(self, config):
        """Build the network and load its weights.

        Args:
            config: mapping with key ``'weights'``, the weight-file name
                relative to ``DETECTORS_DIR``.
        """
        self.weights = config['weights']

        model_file = os.path.join(self.DETECTORS_DIR, self.weights)

        self.model = SVHNNet()
        # NOTE(review): torch.load unpickles the file, so only point this at
        # trusted weight files. Consider passing weights_only=True once the
        # minimum supported torch version accepts it.
        self.model.load_state_dict(torch.load(model_file, map_location='cpu'))
        self.model.eval()

    def detect(self, image):
        """Recognize the digit sequence in a single image.

        Args:
            image: single-channel (grayscale) numpy image, shaped
                (height, width) as cv2 produces. It is resized to 32x32 and
                divided by 255, so 8-bit input in [0, 255] is presumably
                expected (confirm with callers).

        Returns:
            ``(label, prob)``. ``label`` is the decoded integer from
            ``get_num`` (-1 when no head predicts a digit other than class 10).
            ``prob`` is a 0-d numpy array holding the product of the three
            heads' maximum softmax probabilities.
        """
        image = cv2.resize(image, (32, 32))
        img = image.reshape([1, 1, 32, 32]) / 255.0

        # Pure inference: disable autograd so no graph is built per call.
        with torch.no_grad():
            output = self.model(torch.as_tensor(img, dtype=torch.float32))
            probs, maxs = get_max(output)
            # Sequence confidence = product of the per-head confidences.
            probs = torch.prod(probs, dim=1)

        labels = get_num(maxs)

        return labels[0], probs[0].cpu().detach().numpy()
