from pomegranate import DiscreteDistribution, State, HiddenMarkovModel
import cv2
import torch
import numpy

from bran.detectors.base_detector import BaseDetector
from bran.models.pytorch_resnet import PyTorchResNet

# Observation symbols that the hidden states below can emit.
O_OTHER = "other"
O_GAME = "game"
O_LOBBY = "lobby"
O_BUS = "bus"
O_SPECTATING = "spectating"

# Fixed symbol order shared by every emission table.
_OBSERVATIONS = (O_OTHER, O_GAME, O_LOBBY, O_BUS, O_SPECTATING)


def _make_state(dominant):
    """Build the hidden state whose most likely observation is *dominant*.

    The state emits *dominant* with probability 0.9 and each of the four
    remaining observation symbols with probability 0.025.  The state is
    named after its dominant observation.
    """
    emissions = {
        symbol: 0.9 if symbol == dominant else 0.025
        for symbol in _OBSERVATIONS
    }
    return State(DiscreteDistribution(emissions), name=dominant)


# One hidden state per observation symbol.
s_other = _make_state(O_OTHER)
s_game = _make_state(O_GAME)
s_lobby = _make_state(O_LOBBY)
s_bus = _make_state(O_BUS)
s_spectating = _make_state(O_SPECTATING)

# Hidden Markov model over the five game states; `s_other` is passed as the
# model's start state.
#
# Each add_transitions() call lists the outgoing weights of one state in the
# order [other, game, lobby, bus, spectating].  Every row is meant to be a
# probability distribution, i.e. to sum to 1.0.
model = HiddenMarkovModel('gamestate', s_other)
model.add_states([s_other, s_game, s_lobby, s_bus, s_spectating])
# The last weight used to be 0.3, which made the row sum to 1.27; 0.03 makes
# it sum to exactly 1.0.  (pomegranate is expected to renormalise outgoing
# edges in bake(), so the typo would have silently skewed the whole row --
# TODO confirm against the installed pomegranate version.)
model.add_transitions(s_other,
                      [s_other, s_game, s_lobby, s_bus, s_spectating],
                      [0.8, 0.05, 0.08, 0.04, 0.03])
model.add_transitions(s_game,
                      [s_other, s_game, s_lobby, s_bus, s_spectating],
                      [0.05, 0.7, 0.05, 0.1, 0.1])
# Same 0.3 -> 0.03 fix as for `s_other` so the row sums to 1.0.
# NOTE(review): this row is identical to the `s_other` row (lobby -> lobby is
# only 0.08 while lobby -> other is 0.8), which looks like a copy-paste;
# confirm the intended lobby transition probabilities.
model.add_transitions(s_lobby,
                      [s_other, s_game, s_lobby, s_bus, s_spectating],
                      [0.8, 0.05, 0.08, 0.04, 0.03])
model.add_transitions(s_bus,
                      [s_other, s_game, s_lobby, s_bus, s_spectating],
                      [0.05, 0.35, 0.05, 0.5, 0.05])
model.add_transitions(s_spectating,
                      [s_other, s_game, s_lobby, s_bus, s_spectating],
                      [0.1, 0.05, 0.4, 0.05, 0.4])
model.bake()

class GameStateDetector(BaseDetector):
    """Detector that reports when the classified game state of a frame changes.

    Frames are classified by a ``PyTorchResNet`` with the labels
    ``"lobby"``, ``"game"`` and ``"no"``.  A detection is reported only when
    the classifier is confident and its label differs from the last label
    that was reported.
    """

    def __init__(self, *args, **kwargs):
        """Create the detector and load the classifier.

        All positional and keyword arguments are forwarded to
        ``BaseDetector``; ``name`` is always forced to ``'gamestate'`` and
        ``save_event`` to ``True``, overriding any caller-supplied values.
        """
        kwargs.update(name='gamestate', save_event=True)
        super().__init__(*args, **kwargs)

        classifier_config = {
            "n_classes": 3,
            "labels": ["lobby", "game", "no"],
            'weights': 'gamestate/weights.pt'
        }
        self.model = PyTorchResNet(classifier_config)

        # Not read anywhere in this class; kept for any external users.
        self.dedupe_timestamp = None
        # Label of the most recent confident detection, None until the first.
        self.last_state = None

    def do_detect(self, user_id, detection_id, frame, timestamp, **kwargs):
        """Classify *frame* and report whether the game state changed.

        Args:
            user_id: Not used by this method.
            detection_id: Not used by this method.
            frame: Image array; its first two axes are swapped before it is
                handed to the classifier.
            timestamp: Not used by this method.
            **kwargs: Ignored.

        Returns:
            A dict with ``"detected"`` (True only when the scaled
            probability exceeds 90 and the label differs from the previously
            stored one), ``"prob"`` (the classifier probability multiplied
            by 100) and ``"label"`` (the classifier's label).
        """
        # TODO(delete me) the old model is trained with wrong axes
        swapped_frame = numpy.swapaxes(frame, 0, 1)
        label, prob = self.model.detect(swapped_frame)
        percent = prob * 100

        # Only a confident prediction of a *different* label counts as a
        # detection; remember it so the same state is not reported twice.
        state_changed = bool(percent > 90 and self.last_state != label)
        if state_changed:
            self.last_state = label

        return {"detected": state_changed, "prob": percent, "label": label}
