from tensorflow_serving.apis import prediction_service_pb2_grpc
from tensorflow_serving.apis import predict_pb2
from absl import app, flags, logging
from absl.flags import FLAGS
import grpc
import numpy as np
import tensorflow as tf
# The loaders below call ``.numpy()`` on tensors, which only works with eager
# execution; under the TF 1.x API it must be enabled before any op is created.
tf.enable_eager_execution()

flags.DEFINE_string(
    "model", "fortnite-gamestate",
    "Name of the served model to query; must be a key of the `models` "
    "table in this script.")
flags.DEFINE_string(
    "input", "fortnite-lobby.jpg",
    "Path to the JPEG image sent to the model.")
flags.DEFINE_string(
    "api_url", "35.166.208.94:8500",
    "host:port of the TensorFlow Serving gRPC endpoint.")


def resnet18_loader():
    """Build an input loader for the ResNet-18 based classifiers.

    Returns:
        A callable ``loader(img)`` that takes the path of a JPEG file and
        returns a float32 ``TensorProto`` of shape (1, 3, 224, 224). The
        image is decoded as RGB, resized to 224x224, divided by 255,
        standardized per channel and laid out channels-first.
    """
    def loader(img):
        rgb = tf.image.decode_jpeg(tf.read_file(img), channels=3)
        batch = tf.expand_dims(tf.image.resize(rgb, (224, 224)), 0)
        # NOTE(review): these mean/std values match the widely used ImageNet
        # channel statistics; presumably the model was trained with them.
        mean = tf.constant([0.485, 0.456, 0.406])
        std = tf.constant([0.229, 0.224, 0.225])
        batch = (batch / 255.0 - mean) / std
        # NHWC -> NCHW. Presumably the served model was exported from a
        # channels-first framework — confirm against the export.
        nchw = tf.transpose(batch, [0, 3, 1, 2])
        return tf.compat.v1.make_tensor_proto(nchw.numpy(), dtype=tf.float32)
    return loader


def resnet18_result(labels, output_key='add_9'):
    """Build a decoder for a ResNet-18 classifier's prediction response.

    Args:
        labels: Class names, indexed by the position of the corresponding
            logit in the model output.
        output_key: Key of the logits tensor in the response ``outputs``
            map. Defaults to ``'add_9'``, the name this script has always
            read (presumably the auto-generated name of the final graph op).

    Returns:
        A callable ``loader(output)`` that maps the ``outputs`` map of a
        PredictResponse to ``(label, probability)``. ``probability`` is the
        softmax probability of the highest-scoring class.

    The returned loader raises ValueError if the output tensor holds no values.
    """
    def loader(output):
        logits = np.array(output[output_key].float_val, dtype=np.float64)
        if logits.size == 0:
            raise ValueError("output tensor %r has no values" % (output_key,))
        index = int(np.argmax(logits))
        # Numerically stable softmax: shifting by the max leaves the result
        # unchanged but keeps exp() from overflowing on large logits.
        exp = np.exp(logits - logits.max())
        return labels[index], exp[index] / exp.sum()
    return loader


def digitnet_loader():
    """Build an input loader for the digit-recognition models.

    Returns:
        A callable ``loader(img)`` that takes the path of a JPEG file and
        returns a float32 ``TensorProto`` of shape (1, 1, 32, 32). The
        image is decoded as a single grayscale channel, resized to 32x32,
        divided by 255 and laid out channels-first.
    """
    def loader(img):
        gray = tf.image.decode_jpeg(tf.read_file(img), channels=1)
        batch = tf.expand_dims(tf.image.resize(gray, (32, 32)), 0) / 255.0
        # NHWC -> NCHW. Presumably the served model was exported from a
        # channels-first framework — confirm against the export.
        nchw = tf.transpose(batch, [0, 3, 1, 2])
        return tf.compat.v1.make_tensor_proto(nchw.numpy(), dtype=tf.float32)
    return loader


def digitnet_result(output_keys=('add_9', 'add_10', 'add_11'), blank_index=10):
    """Build a decoder for a digit-recognition model's prediction response.

    Each key in ``output_keys`` names one output head. The heads are read in
    order, and each contributes at most one character to the decoded number.

    Args:
        output_keys: Keys of the per-position logits tensors in the response
            ``outputs`` map, in reading order. The defaults are the names
            this script has always used.
        blank_index: Class index that means "no digit at this position".
            It contributes nothing to the string.

    Returns:
        A callable ``loader(outputs)`` that returns ``(digits, probability)``.
        ``digits`` is the decoded string, possibly empty. ``probability`` is
        the product of each head's softmax probability for its chosen class.
    """
    def loader(outputs):
        digits = []
        probability = 1.0
        for key in output_keys:
            logits = np.array(outputs[key].float_val, dtype=np.float64)
            index = int(np.argmax(logits))
            # Stable softmax; only the winning class's probability is needed.
            exp = np.exp(logits - logits.max())
            probability *= exp[index] / exp.sum()
            if index != blank_index:
                digits.append(str(index))
        return "".join(digits), probability
    return loader


def _resnet18_entry(labels):
    """Pair a fresh ResNet-18 input loader with a decoder over *labels*."""
    return {
        'in_loader': resnet18_loader(),
        'out_loader': resnet18_result(labels),
    }


def _digitnet_entry():
    """Pair a fresh digit-net input loader with its digit-string decoder."""
    return {
        'in_loader': digitnet_loader(),
        'out_loader': digitnet_result(),
    }


# Registry of servable models. Each key is the --model name, which main()
# also sends as the TF Serving model name. 'in_loader' turns the input path
# into the request tensor; 'out_loader' decodes the response outputs.
models = {
    'fortnite-kill': _resnet18_entry(['no', 'yes']),
    'fortnite-victory': _resnet18_entry(['no', 'yes']),
    'fortnite-gamestate': _resnet18_entry(['lobby', 'game', 'no']),
    'apex-kill': _digitnet_entry(),
    'apex-victory': _resnet18_entry(['no', 'yes']),
    'pubg-kill': _digitnet_entry(),
    'pubgmobile-kill': _digitnet_entry(),
    'generic-digits': _digitnet_entry(),
}


def main(_):
    """Send one image to TF Serving and print the decoded prediction.

    Looks up ``--model`` in ``models`` and builds the request tensor from
    ``--input``. It then calls ``Predict`` on ``--api_url`` over an insecure
    gRPC channel (5 second timeout) and prints the result of the model's
    output decoder.

    Raises:
        app.UsageError: If ``--model`` is not a key of ``models``.
            ``app.run`` reports this as a usage error instead of a traceback.
    """
    try:
        spec = models[FLAGS.model]
    except KeyError:
        raise app.UsageError(
            "Unknown --model %r; expected one of: %s"
            % (FLAGS.model, ", ".join(sorted(models))))

    img_proto = spec['in_loader'](FLAGS.input)

    request = predict_pb2.PredictRequest()
    request.model_spec.name = FLAGS.model
    # NOTE(review): this explicitly pins model version 0 rather than letting
    # the server pick its latest version — confirm that is intended.
    request.model_spec.version.value = 0
    request.inputs['input'].CopyFrom(img_proto)

    # The context manager closes the channel (and its sockets) even if the
    # RPC fails or times out.
    with grpc.insecure_channel(FLAGS.api_url) as channel:
        stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
        response = stub.Predict(request, timeout=5)

    print(spec['out_loader'](response.outputs))


if __name__ == '__main__':
    # absl's app.run parses the command-line flags, then calls main().
    app.run(main)
