from typing import Optional

import os
import sys

import grpc
import google.protobuf.message as pb_message

from tasklet.api.v2 import context_pb2
from tasklet.api.v2 import executor_service_pb2
from tasklet.api.v2 import executor_service_pb2_grpc
from tasklet.api.v2 import well_known_structures_pb2 as wks_pb2

# Environment variable whose value is the filesystem path of the serialized
# tasklet Context protobuf (read by TaskletInterface.get_context).
TASKLET_CONTEXT_ENV_VAR: str = "TASKLET_CONTEXT"


class TaskletInterface:
    """Runtime bridge between a tasklet process and its executor.

    The executor address and the input/output/error file paths come from the
    command line; the context file path comes from the environment variable
    named by ``TASKLET_CONTEXT_ENV_VAR``. The instance may be used as a context
    manager so that the gRPC channel to the executor is closed on exit.
    """

    def __init__(self, args=None):
        """Parse the tasklet command-line arguments.

        Args:
            args: Positional arguments ``[executor_address, input_path,
                output_path, error_path, ...]``. Defaults to ``sys.argv[1:]``.
                Arguments beyond the fourth are ignored.

        Raises:
            RuntimeError: If fewer than 4 arguments are given.
            KeyError: If the context environment variable is not set.
        """
        if args is None:
            args = sys.argv[1:]

        if len(args) < 4:
            raise RuntimeError(f"Unsupported tasklet command. There must be at least 4 arguments ({len(args)} given)")

        self.__executor_address: str = args[0]
        self.__input_path: str = args[1]
        self.__output_path: str = args[2]
        self.__error_path: str = args[3]
        self.__context_path: str = os.environ[TASKLET_CONTEXT_ENV_VAR]
        # Lazily populated caches; see get_context() and executor_client.
        self.__context: Optional[context_pb2.Context] = None
        # The channel is kept so it can be released by close().
        self.__channel: Optional[grpc.Channel] = None
        self.__client: Optional[executor_service_pb2_grpc.ExecutorServiceStub] = None

    def get_context(self) -> context_pb2.Context:
        """Return the tasklet Context, reading it from the context file once.

        Raises:
            OSError: If the context file cannot be opened or read.
            google.protobuf.message.DecodeError: If the file is not a valid
                serialized Context.
        """
        if self.__context is None:
            context = context_pb2.Context()
            with open(self.__context_path, "rb") as context_file:
                context.ParseFromString(context_file.read())
            # Cache only after a successful read + parse, so a failure is not
            # silently remembered as an empty/partial context on later calls.
            self.__context = context
        return self.__context

    @property
    def executor_client(self) -> executor_service_pb2_grpc.ExecutorServiceStub:
        """Return the executor gRPC stub, creating it on first access.

        The stub uses an insecure (non-TLS) channel to the executor address
        given as the first command-line argument. Call ``close()`` (or use the
        instance as a context manager) to release the channel.
        """
        if self.__client is None:
            self.__channel = grpc.insecure_channel(self.__executor_address)
            self.__client = executor_service_pb2_grpc.ExecutorServiceStub(self.__channel)
        return self.__client

    def close(self) -> None:
        """Close the executor gRPC channel if one was opened.

        Safe to call more than once; a later access to ``executor_client``
        opens a new channel.
        """
        channel = self.__channel
        self.__channel = None
        self.__client = None
        if channel is not None:
            channel.close()

    def __enter__(self) -> "TaskletInterface":
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Never suppresses exceptions; only releases the channel.
        self.close()

    def get_secret(self, secret_ref: wks_pb2.SecretRef) -> executor_service_pb2.SecretValue:
        """Resolve ``secret_ref`` through the executor's GetSecretRef RPC and return its value."""
        resp: executor_service_pb2.GetSecretRefResponse = self.executor_client.GetSecretRef(
            executor_service_pb2.GetSecretRefRequest(ref=secret_ref),
        )
        return resp.value

    def read_input(self, message: pb_message.Message) -> pb_message.Message:
        """Parse the input file into ``message`` and return that same object."""
        with open(self.__input_path, "rb") as f:
            data = f.read()
        message.ParseFromString(data)
        return message

    def write_output(self, output: pb_message.Message):
        """Serialize ``output`` and write it to the output file."""
        self.write_raw_output(output.SerializeToString())

    def write_error(self, user_error: wks_pb2.UserError):
        """Serialize ``user_error`` and write it to the error file."""
        self.write_raw_error(user_error.SerializeToString())

    def write_raw_output(self, output: bytes):
        """Write raw bytes to the output file, replacing any previous content."""
        with open(self.__output_path, "wb") as out:
            out.write(output)

    def write_raw_error(self, output: bytes):
        """Write raw bytes to the error file, replacing any previous content."""
        with open(self.__error_path, "wb") as err:
            err.write(output)
