# coding=utf-8
import json
from typing import Optional, Tuple

import click
import inject

from google.protobuf import descriptor_pb2
from google.protobuf import json_format

from tasklet.api.v2 import data_model_pb2 as data_model
from tasklet.api.v2 import tasklet_service_pb2 as tasklet_service
from tasklet.api.v2 import tasklet_service_pb2_grpc as tasklet_service_grpc

from tasklet.experimental.cli import consts
from tasklet.experimental.cli import context
from tasklet.experimental.cli import interfaces
from tasklet.experimental.cli import proto_support


def get_build_schema(
    build_id: str, ctx: context.TaskletContext, client: tasklet_service_grpc.TaskletServiceStub
) -> Tuple[data_model.IOSimpleSchemaProto, descriptor_pb2.FileDescriptorSet]:
    """Fetch the simple-proto I/O schema of a build together with its descriptor set.

    :param build_id: ID of the build whose schema is requested.
    :param ctx: CLI context used to execute the ``GetBuild`` request and to resolve
        the schema descriptors by hash (``ctx.get_schema``).
    :param client: Tasklet service stub providing ``GetBuild``.
    :return: ``(schema, fds)`` where ``schema`` is ``build.spec.schema.simple_proto``
        of the fetched build and ``fds`` is the ``.schema`` of the object returned by
        ``ctx.get_schema(schema.schema_hash)``.
    """
    get_build_resp: tasklet_service.GetBuildResponse = ctx.execute_request(
        client.GetBuild,
        tasklet_service.GetBuildRequest(build_id=build_id),
    )
    schema: data_model.IOSimpleSchemaProto = get_build_resp.build.spec.schema.simple_proto
    # The descriptors are looked up by content hash rather than embedded in the build.
    fds = ctx.get_schema(schema.schema_hash).schema
    return schema, fds


@click.group(name="execution", cls=context.GroupWithContextOptions, help="List or get tasklet executions")
@context.option_output
@context.option_cluster
def execution_subcommand():
    # Click group "execution": the list/get/abort/get-output/get-input commands in
    # this module attach to it via ``@execution_subcommand.command``.
    # The callback itself does nothing; the ``context.option_*`` decorators presumably
    # register shared output/cluster options for the group (see the context module).
    pass


@execution_subcommand.command(name="list", help="Show list of tasklet executions")
@context.tasklet_descriptor_argument(argument_type=context.TaskletArgumentType.OPTIONAL_BUILD)
@click.option(
    "-l", "--limit", metavar="<limit>", type=click.IntRange(1), default=20, show_default=True,
    help="Max number of executions to show."
)
def list_executions(namespace: str, tasklet: str, build_id: Optional[str], limit: int):
    """List executions of a tasklet, or of a single build when ``build_id`` is given.

    Pages through the tasklet service, feeding back the token of the previous page,
    until ``limit`` executions are collected or an empty page is returned. The
    accumulated response is dumped via ``ctx.dump_proto_message``; for table output
    the last page token is echoed as well. If nothing at all is found, a message is
    printed to stderr and nothing is dumped.
    """
    # noinspection PyTypeChecker
    ctx: context.TaskletContext = inject.instance(interfaces.ITaskletContext)
    client = ctx.driver.get_tasklet_client()

    # Resolve the RPC, request type and response container once; only the token
    # changes from page to page.
    if build_id:
        result = tasklet_service.ListExecutionsByBuildResponse()
        list_rpc = client.ListExecutionsByBuild

        def make_request(token):
            return tasklet_service.ListExecutionsByBuildRequest(build_id=build_id, token=token)
    else:
        result = tasklet_service.ListExecutionsByTaskletResponse()
        list_rpc = client.ListExecutionsByTasklet

        def make_request(token):
            return tasklet_service.ListExecutionsByTaskletRequest(
                namespace=namespace, tasklet=tasklet, token=token
            )

    while limit > 0:
        response = ctx.execute_request(list_rpc, make_request(result.token))
        if not response.executions:
            if not result.executions:
                # Report "no executions" only when nothing was collected at all; an
                # empty page after earlier non-empty ones just ends pagination.
                click.echo(f"No executions in '{namespace}/{tasklet}{':' + build_id if build_id else ''}'", err=True)
                return
            result.token = response.token
            break
        # NOTE(review): when this page is truncated to ``limit``, ``result.token`` still
        # points past the whole page, so a follow-up request with the echoed token would
        # presumably skip the dropped executions -- depends on service token semantics.
        result.executions.extend(response.executions[:limit])
        limit -= len(response.executions)
        result.token = response.token

    for execution in result.executions:
        # Strip legacy status fields; this CLI reads ``status.processing_result`` instead.
        execution.status.ClearField("error")
        execution.status.ClearField("result")

    ctx.dump_proto_message(result)
    if ctx.output == consts.Output.TABLE:
        click.echo(f"page token: {result.token}")


@execution_subcommand.command(name="get", help="Get execution info by ID")
@click.argument("execution_id")
def get_execution(execution_id: str):
    """Fetch one execution by ID and dump the response without its legacy status fields."""
    # noinspection PyTypeChecker
    ctx: context.TaskletContext = inject.instance(interfaces.ITaskletContext)
    get_rpc = ctx.driver.get_tasklet_client().GetExecution
    request = tasklet_service.GetExecutionRequest(id=execution_id)
    reply: tasklet_service.GetExecutionResponse = ctx.execute_request(get_rpc, request)

    # Legacy fields are cleared in place on the response before it is printed.
    execution_status = reply.execution.status
    for legacy_field in ("error", "result"):
        execution_status.ClearField(legacy_field)
    ctx.dump_proto_message(reply)


@execution_subcommand.command(name="abort", help="Abort execution")
@click.argument("execution_id")
@click.argument("reason")
def abort_execution(execution_id: str, reason: str):
    """Send an ``AbortExecution`` request for ``execution_id`` with the given ``reason``.

    The ``AbortExecutionResponse`` is discarded; this command prints nothing itself.
    """
    # noinspection PyTypeChecker
    ctx: context.TaskletContext = inject.instance(interfaces.ITaskletContext)
    abort_rpc = ctx.driver.get_tasklet_client().AbortExecution
    abort_request = tasklet_service.AbortExecutionRequest(id=execution_id, reason=reason)
    ctx.execute_request(abort_rpc, abort_request)


@execution_subcommand.command(name="get-output", help="Get execution output by ID")
@click.argument("execution_id")
@click.argument(
    "output_format",
    type=click.Choice((
        consts.ProtoFormat.PROTO_JSON,
        consts.ProtoFormat.PROTO_TEXT,
        consts.ProtoFormat.PROTO_BINARY,
    )),
    default=consts.ProtoFormat.PROTO_JSON,
)
def get_execution_output(execution_id: str, output_format: consts.ProtoFormat):
    """Print the output of a finished execution in the requested format.

    Exits with status 1, after reporting to stderr, when the execution is not in
    ``E_EXECUTION_STATUS_FINISHED`` or its processing result carries a
    ``server_error`` / ``user_error``. For ``PROTO_BINARY`` the serialized output
    bytes are written to stdout unchanged; for JSON/text they are decoded with the
    build's output message schema (see ``get_build_schema``).
    """
    # noinspection PyTypeChecker
    ctx: context.TaskletContext = inject.instance(interfaces.ITaskletContext)

    client = ctx.driver.get_tasklet_client()
    response: tasklet_service.GetExecutionResponse = ctx.execute_request(
        client.GetExecution,
        tasklet_service.GetExecutionRequest(id=execution_id),
    )
    status = response.execution.status
    if status.status != data_model.E_EXECUTION_STATUS_FINISHED:
        click.echo(f"Execution not finished. Status: {status.status}", err=True)
        # Raise SystemExit directly: the ``exit()`` builtin is injected by ``site``
        # and is not guaranteed to exist (e.g. ``python -S``, frozen executables).
        raise SystemExit(1)

    result = status.processing_result
    kind = result.WhichOneof("kind")
    if kind == "server_error":
        click.echo("Failed to process execution:", err=True)
        js = json_format.MessageToJson(result.server_error, indent=2, sort_keys=True)
        click.echo(js, err=True)
        raise SystemExit(1)
    elif kind == "user_error":
        click.echo("User error:", err=True)
        js = json_format.MessageToJson(result.user_error, indent=2, sort_keys=True)
        click.echo(js, err=True)
        raise SystemExit(1)

    serialized_output = result.output.serialized_output

    if output_format == consts.ProtoFormat.PROTO_BINARY:
        # Serialized protobuf is arbitrary binary data: ``bytes.decode()`` fails on
        # non-UTF-8 payloads and ``print`` appends a newline that corrupts the
        # message, so emit the raw bytes on the binary stdout stream instead.
        stdout = click.get_binary_stream("stdout")
        stdout.write(serialized_output)
        stdout.flush()
        return

    schema, fds = get_build_schema(response.execution.meta.build_id, ctx, client)
    output_message = schema.output_message

    formatter = proto_support.ProtoSupport.from_fds(fds)
    if output_format == consts.ProtoFormat.PROTO_JSON:
        js_dict = formatter.convert_protobuf_bytes_to_json(
            output_message,
            serialized_output
        )
        print(json.dumps(js_dict, sort_keys=True, indent=4, ensure_ascii=False))
    elif output_format == consts.ProtoFormat.PROTO_TEXT:
        output = formatter.convert_protobuf_bytes_to_text(
            output_message,
            serialized_output
        )
        print(output)
    else:
        # Unreachable while click.Choice restricts output_format to the values above.
        raise RuntimeError(f"Unexpected format {output_format}")


@execution_subcommand.command(name="get-input", help="Get execution input by ID")
@click.argument("execution_id")
@click.argument(
    "output_format",
    type=click.Choice((
        consts.ProtoFormat.PROTO_JSON,
        consts.ProtoFormat.PROTO_TEXT,
        consts.ProtoFormat.PROTO_BINARY,
    )),
    default=consts.ProtoFormat.PROTO_JSON,
)
def get_execution_input(execution_id: str, output_format: consts.ProtoFormat):
    """Print the input an execution was started with, in the requested format.

    For ``PROTO_BINARY`` the serialized input bytes are written to stdout unchanged;
    for JSON/text they are decoded with the build's input message schema
    (see ``get_build_schema``).
    """
    # noinspection PyTypeChecker
    ctx: context.TaskletContext = inject.instance(interfaces.ITaskletContext)

    client = ctx.driver.get_tasklet_client()
    response: tasklet_service.GetExecutionResponse = ctx.execute_request(
        client.GetExecution,
        tasklet_service.GetExecutionRequest(id=execution_id),
    )
    serialized_data = response.execution.spec.input.serialized_data
    if output_format == consts.ProtoFormat.PROTO_BINARY:
        # Serialized protobuf is arbitrary binary data: ``bytes.decode()`` fails on
        # non-UTF-8 payloads and ``print`` appends a newline that corrupts the
        # message, so emit the raw bytes on the binary stdout stream instead.
        stdout = click.get_binary_stream("stdout")
        stdout.write(serialized_data)
        stdout.flush()
        return

    schema, fds = get_build_schema(response.execution.meta.build_id, ctx, client)
    input_message = schema.input_message

    formatter = proto_support.ProtoSupport.from_fds(fds)
    if output_format == consts.ProtoFormat.PROTO_JSON:
        js_dict = formatter.convert_protobuf_bytes_to_json(
            input_message,
            serialized_data,
        )
        print(json.dumps(js_dict, sort_keys=True, indent=4, ensure_ascii=False))
    elif output_format == consts.ProtoFormat.PROTO_TEXT:
        output = formatter.convert_protobuf_bytes_to_text(
            input_message,
            serialized_data,
        )
        print(output)
    else:
        # Unreachable while click.Choice restricts output_format to the values above.
        raise RuntimeError(f"Unexpected format {output_format}")
