import os
import grpc
import cv2
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms

import tensorflow as tf
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc

# Shared gRPC connection to TensorFlow Serving, used by every TfServingModel
# in this module. The address can be overridden with TF_SERVING_GRPC_ADDR.
# NOTE(review): TensorFlow Serving's stock ports are 8500 for gRPC and 8501
# for REST. 8501 is kept as the default for backward compatibility, but
# confirm the deployment really serves gRPC on that port.
channel = grpc.insecure_channel(
    os.environ.get("TF_SERVING_GRPC_ADDR", "localhost:8501"))
stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)


class TfServingModel:
    """Image classifier backed by a model hosted on TensorFlow Serving.

    Images are preprocessed locally (resize to 224x224 plus per-channel
    normalization) and sent through the module-level gRPC ``stub``. The
    returned scores are reduced to a single (label, probability) pair.

    Config keys:
        name (str): model name; lower-cased when sent as ``model_spec.name``.
        labels (Sequence[str]): class labels, indexed by output position.
        version (int, optional): model version to request. Default 1.
        input_name (str, optional): input tensor key. Default ``'input'``.
        output_name (str, optional): output tensor key. Default ``'add_9'``.
        timeout (float, optional): gRPC deadline in seconds. Default 1.
    """

    def __init__(self, config):
        self.name = config['name']
        self.labels = config['labels']
        self.version = config.get('version', 1)
        # These used to be hard-coded; the defaults keep the old behaviour.
        self.input_name = config.get('input_name', 'input')
        self.output_name = config.get('output_name', 'add_9')
        self.timeout = config.get('timeout', 1)

        self.transforms = transforms.Compose([
            lambda img: cv2.resize(img, (224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225])
        ])

    def detect(self, image):
        """Classify one image via TensorFlow Serving.

        Args:
            image: numpy array laid out as [h, w, c] (height, width,
                channels) in RGB order, which is the layout ``cv2.resize``
                and ``transforms.ToTensor`` operate on. ``c`` must be 3 to
                match the 3-channel normalization and the request shape.

        Returns:
            (label, prob): the top-1 label from ``self.labels`` and its
            softmax probability.

        Raises:
            ValueError: if the response holds no float scores under
                ``self.output_name``.
        """
        # ToTensor turns the resized HWC image into a CHW tensor; the request
        # shape below adds the batch dimension.
        pixels = self.transforms(image).cpu().detach().numpy()

        request = predict_pb2.PredictRequest()
        request.model_spec.name = self.name.lower()
        request.model_spec.version.value = self.version
        request.inputs[self.input_name].CopyFrom(tf.make_tensor_proto(
            pixels, dtype=tf.float32, shape=[1, 3, 224, 224]))
        result = stub.Predict(request, timeout=self.timeout)

        scores = np.array(result.outputs[self.output_name].float_val)
        return self._top1(scores, self.labels)

    @staticmethod
    def _top1(logits, labels):
        """Pick the highest-scoring label and its softmax probability.

        Args:
            logits: 1-D sequence of raw scores, one per label.
            labels: indexable labels aligned position-by-position with logits.

        Returns:
            (label, prob) for the arg-max score; prob is a numpy float.

        Raises:
            ValueError: if ``logits`` is empty. This may happen if the output
                name does not match the served signature, or if the output
                is not a float tensor.
        """
        logits = np.asarray(logits, dtype=np.float64)
        if logits.size == 0:
            raise ValueError("model response contained no float scores")
        index = int(np.argmax(logits))
        # torch.softmax needs a Tensor; the original passed the numpy array
        # directly, which raises TypeError. torch.tensor copies, so read-only
        # inputs are fine.
        probs = torch.softmax(torch.tensor(logits), dim=0).numpy()
        return labels[index], probs[index]
