# Code generated by protoc-gen-twirp_python_fultonlambda v1.0.0, DO NOT EDIT.
# source: model_registry_twirp.proto

try:
    import httplib
except ImportError:
    import http.client as httplib

import base64
import json
import sys

from google.protobuf import symbol_database as _symbol_database

# Process-wide default protobuf symbol database. The client methods below call
# ``_sym_db.GetSymbol(<fully-qualified name>)`` to obtain the request/response
# message types, so this module does not import the generated ``*_pb2`` module
# itself.
# NOTE(review): presumably the corresponding ``*_pb2`` module must already be
# imported elsewhere for the lookups to succeed -- confirm with callers.
_sym_db = _symbol_database.Default()


class TwirpException(httplib.HTTPException):
    """Error reported by the Twirp service, or synthesized for a bad response.

    Attributes:
        code: Error code string; the ``"code"`` field of the JSON error body,
            or ``"internal"`` when the error was synthesized locally.
        message: Human-readable message; the ``"msg"`` field of the JSON error
            body, or a locally generated description.
        meta: Dict taken from the ``"meta"`` field of the JSON error body
            (``{}`` when the field is absent or the error was synthesized).
    """

    def __init__(self, code, message, meta):
        self.code = code
        self.message = message
        self.meta = meta
        super(TwirpException, self).__init__(message)

    @classmethod
    def from_http_err(cls, err):
        """Build a ``TwirpException`` from an error response.

        Args:
            err: Either the raw error body (a JSON document with ``code`` and
                ``msg`` fields and an optional ``meta`` field), or an
                HTTP-error-like object exposing the HTTP status as a ``code``
                attribute.

        Returns:
            A new exception instance. When ``err`` is not a well-formed JSON
            error body, the result has code ``"internal"``, empty ``meta``
            and a message describing the intermediary failure (including the
            HTTP status when ``err`` carries one).
        """
        try:
            jsonerr = json.loads(err)
            code = jsonerr["code"]
            msg = jsonerr["msg"]
            meta = jsonerr.get("meta")
            if meta is None:
                meta = {}
        except Exception:
            # Deliberately best-effort: whatever went wrong while decoding the
            # body (not JSON, wrong shape, missing keys, wrong input type) we
            # still want to hand the caller a TwirpException. ``Exception``
            # rather than a bare ``except`` lets KeyboardInterrupt/SystemExit
            # propagate.
            code = "internal"
            meta = {}
            # ``err`` may be a plain body string with no status attached, so
            # the status must be looked up defensively.
            status = getattr(err, "code", None)
            if status is None:
                msg = "Error from intermediary: response body is not a Twirp JSON error"
            else:
                msg = "Error from intermediary with HTTP status code {}".format(status)
                # Non-standard status codes have no registered reason phrase.
                reason = httplib.responses.get(status)
                if reason:
                    msg = "{} {}".format(msg, reason)
        return cls(code, msg, meta)


class LambdaFunctionException(Exception):
    """Raised when the Lambda invocation itself reports a function error.

    Attributes:
        error_type: Value of the ``FunctionError`` field of the invoke
            response.
        error_msg: Decoded response payload accompanying the error.
    """

    def __init__(self, error_type, error_msg):
        # Forward both values to ``Exception`` so that ``args`` is populated:
        # this makes ``str(exc)``/tracebacks show the failure details and lets
        # the exception be copied or pickled (reconstruction calls
        # ``cls(*args)``, which needs both constructor arguments).
        super(LambdaFunctionException, self).__init__(error_type, error_msg)
        self.error_type = error_type
        self.error_msg = error_msg


class TwitchVXModelRegistryLambdaClient(object):
    """Client for the TwitchVXModelRegistry Twirp service invoked through Lambda.

    Each RPC serializes its protobuf request, wraps it in an HTTP-style event
    (``path``/``httpMethod``/``headers``/``body``), invokes the Lambda function
    with that event, and parses the protobuf response out of the returned
    payload.
    """

    def __init__(self, lambda_endpoint, boto_lambda_client):
        """Create a new client for the TwitchVXModelRegistry service.

        Args:
            lambda_endpoint: The endpoint of the lambda function; passed as
                ``FunctionName`` to ``invoke`` and should conform to the
                InvokeFunction API requirement.
            boto_lambda_client: The boto3 client for lambda.
        """
        running_py2 = sys.version_info[0] <= 2
        # On Python 2 the endpoint is coerced to an ASCII byte string.
        self.__target = (
            lambda_endpoint.encode("ascii") if running_py2 else lambda_endpoint
        )
        self.__service_name = (
            "twitch.fulton.example.twitchvxmodelregistry.TwitchVXModelRegistry"
        )
        self.__lambda_client = boto_lambda_client

    def __make_request(self, body, full_method):
        """Invoke the Lambda function for one RPC and return the raw reply.

        Args:
            body: Serialized protobuf request bytes.
            full_method: Route suffix of the form ``/<service>/<Method>``;
                ``/twirp`` is prepended to build the event path.

        Returns:
            The base64-decoded response body (serialized protobuf bytes).

        Raises:
            LambdaFunctionException: The invoke response carries a truthy
                ``FunctionError`` field.
            TwirpException: The response ``statusCode`` is not 200.
        """
        event = json.dumps(
            {
                "path": "/twirp" + full_method,
                "httpMethod": "POST",
                "headers": {"Content-Type": "application/protobuf"},
                "isBase64Encoded": True,
                # JSON cannot carry raw bytes, so the protobuf is base64 text.
                "body": base64.b64encode(body).decode("utf-8"),
            }
        )
        invocation = self.__lambda_client.invoke(
            FunctionName=self.__target, Payload=event
        )

        function_error = invocation.get("FunctionError", False)
        if function_error:
            failure_text = invocation["Payload"].read().decode("utf-8")
            raise LambdaFunctionException(function_error, failure_text)

        http_reply = json.loads(invocation["Payload"].read().decode("utf-8"))
        if http_reply["statusCode"] == 200:
            return base64.b64decode(http_reply["body"])
        raise TwirpException.from_http_err(http_reply["body"])

    def __call_rpc(self, method_name, request_symbol, response_symbol, request):
        """Run one RPC: look up the message types, serialize, invoke, parse.

        Args:
            method_name: RPC name appended to the service name in the route.
            request_symbol: Fully-qualified name of the request message type.
            response_symbol: Fully-qualified name of the response message type.
            request: The request message to serialize.

        Returns:
            The response message parsed from the reply bytes.
        """
        to_wire = _sym_db.GetSymbol(request_symbol).SerializeToString
        from_wire = _sym_db.GetSymbol(response_symbol).FromString
        route = "/{}/{}".format(self.__service_name, method_name)
        reply_bytes = self.__make_request(body=to_wire(request), full_method=route)
        return from_wire(reply_bytes)

    def register_model(self, register_model_request):
        """Register the metadata for a model instance."""
        return self.__call_rpc(
            "RegisterModel",
            "twitch.fulton.example.twitchvxmodelregistry.RegisterModelRequest",
            "twitch.fulton.example.twitchvxmodelregistry.RegisterModelResponse",
            register_model_request,
        )

    def get_model(self, get_model_request):
        """Fetch the metadata for a model instance."""
        return self.__call_rpc(
            "GetModel",
            "twitch.fulton.example.twitchvxmodelregistry.GetModelRequest",
            "twitch.fulton.example.twitchvxmodelregistry.GetModelResponse",
            get_model_request,
        )

    def update_model(self, update_model_request):
        """Update the metadata for an existing model instance.

        Only updates or appends the fields in the metadata.
        """
        return self.__call_rpc(
            "UpdateModel",
            "twitch.fulton.example.twitchvxmodelregistry.UpdateModelRequest",
            "twitch.fulton.example.twitchvxmodelregistry.UpdateModelResponse",
            update_model_request,
        )

    def rollout_model(self, rollout_model_request):
        """Ask model registry to start replacing the old instance with the new one."""
        return self.__call_rpc(
            "RolloutModel",
            "twitch.fulton.example.twitchvxmodelregistry.RolloutModelRequest",
            "twitch.fulton.example.twitchvxmodelregistry.RolloutModelResponse",
            rollout_model_request,
        )

    def get_available_models(self, get_available_models_request):
        """Get the most current rollout entry for model families.

        Intended for clients that poll model families periodically.
        """
        return self.__call_rpc(
            "GetAvailableModels",
            "twitch.fulton.example.twitchvxmodelregistry.GetAvailableModelsRequest",
            "twitch.fulton.example.twitchvxmodelregistry.GetAvailableModelsResponse",
            get_available_models_request,
        )
