import cv2
import torch
import numpy as np
import sentry_sdk

from bran.detectors.base_detector import BaseDetector
from bran.shared.config import config
from bran.shared.logger import log
from bran.train.nets import model_from_config

def get_max(digits_net):
    """Decode the predicted class index of each of the first four digit heads.

    Args:
        digits_net: Indexable collection holding at least four logit tensors;
            each is reduced with ``torch.max`` over ``dim=1``.

    Returns:
        The per-head arg-max indices stacked along ``dim=1`` (presumably shape
        ``(batch, 4)`` when each head is ``(batch, num_classes)``).
    """
    # ``torch.max(..., dim=...)`` yields (values, indices); keep the indices.
    picked = [torch.max(digits_net[head], dim=1)[1] for head in range(4)]
    return torch.stack(picked, dim=1)

def get_probs(digits_net):
    """Return the top softmax probability of each of the first four digit heads.

    Args:
        digits_net: Indexable collection holding at least four logit tensors;
            each is softmax-normalised over ``dim=1`` before taking the max.

    Returns:
        The per-head maximum probabilities stacked along ``dim=1``.
    """
    confidences = []
    for head in range(4):
        distribution = torch.softmax(digits_net[head], dim=1)
        best, _ = distribution.max(dim=1)
        confidences.append(best)
    return torch.stack(confidences, dim=1)

def get_num(digits):
    """Join the decoded digit indices into an integer.

    Entries equal to 10 are skipped (presumably the model's "no digit" class).

    Args:
        digits: Iterable of scalar tensors holding class indices.

    Returns:
        The remaining digits read as a base-10 integer, or ``-1`` when every
        entry was 10.
    """
    kept = "".join(str(d.item()) for d in digits if d != 10)
    if not kept:
        return -1
    return int(kept)

def get_prob(prob):
    """Combine per-digit probabilities into one integer percentage.

    Args:
        prob: 1-D tensor of probabilities (may require grad or live on GPU).

    Returns:
        ``int(100 * product(prob))``, truncated toward zero.
    """
    values = prob.detach().cpu().numpy()
    # The last element of the running product is the product of all values.
    running = np.cumprod(values)
    return int(running[-1] * 100)

def img_to_num(net, img):
    """Read a number from a single-channel image crop with the digit network.

    Args:
        net: Callable model whose output is indexed as four per-digit logit
            heads (see ``get_max`` / ``get_probs``).
        img: Single-channel image as a numpy array (the reshape below requires
            exactly one channel).

    Returns:
        ``(number, confidence)``:
        - ``number`` is the value ``get_num`` decodes from the first batch row,
          ``-1`` if every digit position was class 10.
        - ``confidence`` is ``get_prob`` of the per-digit top probabilities, as
          an integer percent.
    """
    img2 = cv2.resize(img, (32, 32))
    img2 = img2.reshape([1, 32, 32, 1])
    # (1, H, W, 1) -> (1, 1, W, H): channel-first, but note H and W are also
    # transposed. Kept as-is; the weights were presumably trained with the same
    # preprocessing, so "fixing" it here could break recognition. TODO confirm.
    img2 = img2.swapaxes(1, 3)
    # Inference only: skip autograd bookkeeping. Outputs are numerically identical.
    with torch.no_grad():
        output = net(torch.FloatTensor(img2))
    maxs = get_max(output)
    probs = get_probs(output)
    return get_num(maxs[0]), get_prob(probs[0])

def grouper(iterable, interval=2):
    """Split a sequence into runs of items whose ``item[1]`` values are close.

    An item joins the current group when its ``item[1]`` differs from the
    previous item's ``item[1]`` by at most ``interval``. The comparison is
    chained against the previous item, not the group's first item. Callers are
    expected to pre-sort by ``item[1]``.

    Args:
        iterable: Indexable items, e.g. ``[x1, y1, x2, y2]`` boxes or numpy rows.
        interval: Maximum allowed gap between neighbouring ``item[1]`` values.

    Yields:
        Non-empty lists of consecutive items. Nothing is yielded for empty input.
    """
    prev = None
    group = []
    for item in iterable:
        # Compare against None rather than testing truthiness. ``not prev``
        # raises for numpy-array items, and a falsy previous item would skip
        # the distance check entirely.
        if prev is None or abs(item[1] - prev[1]) <= interval:
            group.append(item)
        else:
            yield group
            group = [item]
        prev = item
    if group:
        yield group

def find_text(net, img):
    """Locate single-line text strips in ``img`` and read three numbers from each.

    Pipeline:
    1. Convert to grayscale and dilate.
    2. Detect MSER regions and keep the short ones (height <= 20 px).
    3. Group the kept boxes into rows by their top y coordinate.
    4. For each row that passes the size/coverage filters, crop it, resize it
       to 150x20, and run three fixed column slices through ``img_to_num``.

    Args:
        net: Digit-recognition model passed through to ``img_to_num``.
        img: BGR image (converted with ``cv2.COLOR_BGR2GRAY``).

    Returns:
        List of ``(timep, leftp, killp)`` triples. Each element is the
        ``(number, confidence)`` tuple returned by ``img_to_num``. A row is
        dropped when both ``leftp`` and ``killp`` confidences are below 80.
        Exceptions raised while processing one row are sent to Sentry, logged,
        and that row is skipped.
    """
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    # NOTE(review): cv2.dilate's Python signature is
    # (src, kernel[, dst[, anchor[, iterations...]]]), so the positional 10
    # binds to ``dst``, not ``iterations``. If 10 iterations were intended,
    # confirm before changing: the thresholds below may be tuned to the
    # current behaviour. ``None`` kernel = OpenCV's default 3x3 element.
    dia = cv2.dilate(gray, None, 10)

    mser = cv2.MSER_create()
    _, bboxes = mser.detectRegions(dia)

    # Binary coverage map: 1 wherever a kept MSER box was drawn.
    mask = np.zeros((img.shape[0], img.shape[1], 1), dtype=np.uint8)

    h2 = []
    for bbox in bboxes:
        x, y, w, h = bbox

        if h > 20: # skip boxes taller than 20 px (too big to be a glyph here)
            continue

        # Filled rectangle (thickness -1) of value 1.
        cv2.rectangle(mask, (x, y), (x + w, y + h), 1, -1)
        h2.append([x, y, x + w, y + h])

    bboxes_list = sorted(h2, key=lambda k: k[1])  # Sort boxes by y1 (top edge) so grouper sees neighbours in order
    combined_bboxes = grouper(bboxes_list, 2)  # Group boxes whose consecutive top edges are within 2 px (one text row)
    res = []
    for group in combined_bboxes:
        try:
            if len(group) < 5: # skip rows with fewer than 5 boxes
                continue

            x_min = min(group, key=lambda k: k[0])[0]  # Find min of x1
            x_max = max(group, key=lambda k: k[2])[2]  # Find max of x2
            y_min = min(group, key=lambda k: k[1])[1]  # Find min of y1
            y_max = max(group, key=lambda k: k[3])[3]  # Find max of y2

            if x_max - x_min < 80: # skip rows narrower than 80 px
                continue

            mask_area = np.sum(mask[y_min:y_max, x_min:x_max])
            box_area = (x_max - x_min) * (y_max - y_min)

            if mask_area / box_area < 0.5: # skip rows where under half the union box is covered by MSER boxes
                continue

            # Crop the row from the grayscale image, padded 3 px left/right
            # and clamped to the image bounds.
            text = gray[y_min:y_max, max(x_min - 3, 0): min(x_max + 3, gray.shape[1])]
            # dsize is (width, height): normalise the strip to 150 wide x 20 tall.
            text = cv2.resize(text, (150, 20))

            # Fixed column windows in the normalised strip. The names suggest
            # time / players-left / kills fields; unverified, layout-specific.
            time = text[:, 26:61]
            left = text[:, 85:109]
            kill = text[:, 134:]

            timep = img_to_num(net, time)
            leftp = img_to_num(net, left)
            killp = img_to_num(net, kill)

            # Keep the row if at least one of left/kill was read with >= 80% confidence.
            if leftp[1] < 80 and killp[1] < 80:
                continue

            res.append((timep, leftp, killp))
        except Exception:
            # Best-effort: report and move on to the next row.
            sentry_sdk.capture_exception()
            log.exception("Failed to get text from box")
    return res

class TorchCounterDetector(BaseDetector):
    """Experimental counter reader for Fortnite frames backed by a PyTorch digit model.

    The detector configuration is hard-coded in ``__init__`` and replaces any
    ``config`` the caller passes. Detection is currently switched off via
    ``_ENABLED``. When enabled, ``do_detect`` only logs what it reads and always
    reports ``detected: False``.
    """

    # Kill switch. The previous code returned unconditionally at the top of
    # do_detect, leaving the reading logic unreachable; this keeps that exact
    # behaviour while making it explicit. Set to True to re-enable the
    # (log-only) counter reading.
    _ENABLED = False

    def __init__(self, *args, **kwargs):
        # Hard-coded detector config; overrides whatever the caller passed.
        # NOTE(review): the crop transform is presumably applied by
        # BaseDetector to produce ``preprocessed_image``; confirm.
        kwargs['config'] = {
            "class": "pytorch_counter",
            "game": "Fortnite",
            "type": "counter",
            "save_event": False,
            "config": {
                "transforms": [
                    ["crop", {"x": 0.83, "y": 0.13, "width": 0.17, "height": 0.22}],
                ],
                "model": {
                    "arch": "resnet18_svhn",
                },
            },
        }
        super().__init__(*args, **kwargs)
        # Process-wide side effect: limits torch intra-op threads for all users
        # of torch in this process, not just this detector.
        torch.set_num_threads(1)
        # NOTE(review): assumes BaseDetector exposes the inner "config" dict as
        # self.config (so self.config['model'] is the "model" entry); confirm.
        self.model_config = self.config['model']
        self.model = self.load_model()
        # Frame counter: only every second frame is processed when enabled.
        self.cnt = -1

    def do_detect(self, stream_id, detection_id, preprocessed_image, frame, timestamp, track, **kwargs):
        """Read and log the counters from ``preprocessed_image``.

        Returns ``{'detected': False}`` immediately while ``_ENABLED`` is False,
        and on skipped (odd) frames. Otherwise it logs the readings at debug
        level and returns ``{"detected": False, "prob": 0}``; this detector
        never reports a detection.
        """
        if not self._ENABLED:
            return {'detected': False}

        self.cnt += 1
        if self.cnt % 2 != 0:
            return {'detected': False}

        readings = find_text(self.model, preprocessed_image)
        # Each reading is a triple of (number, confidence); log the numbers only.
        out = ["{}, {}, {}".format(t[0][0], t[1][0], t[2][0]) for t in readings]
        o = " | ".join(out) if out else "Counter Not Found"

        # self.game / self.type are provided by BaseDetector (not visible here).
        log.debug("{}_{}: {}".format(self.game, self.type, o))

        return {"detected": False, "prob": 0}

    def load_model(self):
        """Build the configured architecture on CPU and load ``<file_dir>/weights.pt``.

        Returns:
            The model in eval mode with the saved state dict loaded.
        """
        model_file = "{}/weights.pt".format(self.file_dir)
        model = model_from_config(self.model_config, torch.device('cpu'))
        model.eval()
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code. The file is a local artifact here; consider
        # weights_only=True if the installed torch supports it (>= 1.13).
        model.load_state_dict(torch.load(model_file, map_location='cpu'))
        return model
