import abc
import datetime as dt
from typing import Collection, Optional, Sequence, Union

import click
import prettytable
from google.protobuf import struct_pb2
from google.protobuf import timestamp_pb2

import tasklet.api.v2.data_model_pb2 as data_model
import tasklet.api.v2.schema_registry_service_pb2 as schema_registry_service
import tasklet.api.v2.tasklet_service_pb2 as tasklet_service


def _timestamp_to_date(ts: timestamp_pb2.Timestamp) -> str:
    """Render *ts* as ``YYYY-MM-DDTHH:MM:SS`` in the machine's local time zone.

    Only ``ts.seconds`` is used, so sub-second precision is dropped, and no
    UTC offset is appended to the result.
    """
    moment = dt.datetime.fromtimestamp(ts.seconds)
    return f"{moment:%Y-%m-%dT%H:%M:%S}"


def human_sized(size: Union[int, float]) -> str:
    """Format *size* with a 1024-based magnitude prefix and a trailing ``b``.

    The value is divided by 1024 until its magnitude drops below 1024 (or the
    largest prefix ``Y`` is reached) and printed with one decimal place, e.g.
    ``human_sized(1536) == "1.5Kb"``.
    """
    prefixes = ("", "K", "M", "G", "T", "P", "E", "Z", "Y")
    scaled = size
    index = 0
    # `not (x < 1024)` rather than `x >= 1024` so NaN keeps scaling to "Y",
    # exactly like the original loop did.
    while index < len(prefixes) - 1 and not abs(scaled) < 1024.0:
        scaled /= 1024.0
        index += 1
    # One-decimal output is always at least 3 characters wide, so no padding
    # width is needed.
    return f"{scaled:.1f}{prefixes[index]}b"


class ABCTableDumper(metaclass=abc.ABCMeta):
    """Base class for rendering protobuf messages as ``prettytable`` tables.

    Subclasses set ``headers`` (column titles, in display order) and implement
    ``_get_row_data`` to return one cell per header for a single message.
    """

    # Column titles, in the same order as the cells returned by ``_get_row_data``.
    headers: Sequence[str]
    # Column used to sort list tables; ``None`` means "sort by the first header".
    order_by: Optional[str] = None

    @classmethod
    @abc.abstractmethod
    def _get_row_data(cls, message) -> Sequence[str]:
        """Return the cells for *message*, one per entry of ``headers``."""
        # NotImplemented is meant for binary operators; an abstract hook must raise.
        raise NotImplementedError(f"{cls.__name__} must implement _get_row_data()")

    @classmethod
    def dump_single_object(cls, message) -> prettytable.PrettyTable:
        """Render *message* as a borderless, left-aligned (header, value) table."""
        table = prettytable.PrettyTable(("Key", "Value"), header=False, border=False, align="l")
        # zip() stops at the shorter sequence, so headers and row cells must be
        # kept in sync by subclasses or trailing values are silently dropped.
        for row in zip(cls.headers, cls._get_row_data(message)):
            table.add_row(row)
        return table

    @classmethod
    def dump_objects_list(cls, message_list: Collection) -> prettytable.PrettyTable:
        """Render *message_list* as a table with one row per message.

        Rows are sorted by ``order_by`` (or the first header when unset). Falsy
        entries such as ``None`` are shown as a row of ``"-"`` placeholders.
        """
        table = prettytable.PrettyTable(cls.headers, sortby=cls.order_by or cls.headers[0])
        for message in message_list:
            if not message:
                table.add_row(["-"] * len(cls.headers))
            else:
                table.add_row(cls._get_row_data(message))
        return table


class NamespaceDumper(ABCTableDumper):
    """Table layout for ``data_model.Namespace`` messages."""

    headers = ("Namespace", "Namespace ID", "Owner", "Created")

    @classmethod
    def _get_row_data(cls, namespace: data_model.Namespace) -> [str]:
        """Return name, id, owning account and creation time of *namespace*."""
        meta = namespace.meta
        return meta.name, meta.id, meta.account_id, _timestamp_to_date(meta.created_at)

    @classmethod
    def dump_namespace_object(cls, message):
        """Render the ``namespace`` field of *message* as a key/value table."""
        single = message.namespace
        return super().dump_single_object(single)

    @classmethod
    def dump_namespace_list(cls, message):
        """Render the ``namespaces`` field of *message* as a multi-row table."""
        many = message.namespaces
        return super().dump_objects_list(many)


class TaskletDumper(ABCTableDumper):
    """Table layout for ``data_model.Tasklet`` messages."""

    headers = ("Tasklet", "Tasklet ID", "Owner", "Catalog", "Created", "Revision")

    @classmethod
    def _get_row_data(cls, tasklet: data_model.Tasklet) -> [str]:
        """Return the metadata and spec cells for *tasklet*, in ``headers`` order."""
        meta, spec = tasklet.meta, tasklet.spec
        return (
            meta.name,
            meta.id,
            meta.account_id,
            spec.catalog,
            _timestamp_to_date(meta.created_at),
            spec.revision,
        )

    @classmethod
    def dump_tasklet_object(cls, message):
        """Render the ``tasklet`` field of *message* as a key/value table."""
        single = message.tasklet
        return super().dump_single_object(single)

    @classmethod
    def dump_tasklet_list(cls, message):
        """Render the ``tasklets`` field of *message* as a multi-row table."""
        many = message.tasklets
        return super().dump_objects_list(many)


class BuildDumper(ABCTableDumper):
    """Table layout for ``data_model.Build`` messages; list tables sort by creation time."""

    headers = ("Build ID", "Resource ID", "Type", "CPU", "RAM", "Storage", "Created")
    order_by = "Created"
    # Short names for workspace storage classes; any other value renders as "<invalid>".
    storage_class_to_string = {
        data_model.EStorageClass.E_STORAGE_CLASS_HDD: "hdd",
        data_model.EStorageClass.E_STORAGE_CLASS_SSD: "ssd",
        data_model.EStorageClass.E_STORAGE_CLASS_RAM: "ram",
    }

    @classmethod
    def _get_row_data(cls, build: data_model.Build) -> [str]:
        """Return identity, launch type, resource limits and creation time of *build*."""
        spec = build.spec
        workspace = spec.workspace
        storage_class_name = cls.storage_class_to_string.get(workspace.storage_class)
        storage_info = (
            f"{storage_class_name}:{human_sized(workspace.storage_size)}"
            if storage_class_name
            else "<invalid>"
        )
        limits = spec.compute_resources
        return (
            build.meta.id,
            spec.payload.sandbox_resource_id or "<undefined>",
            spec.launch_spec.type,
            # vcpu_limit is divided by 1000 -- presumably stored in millicores; confirm.
            f"{limits.vcpu_limit / 1000.0:.3f}",
            human_sized(limits.memory_limit),
            storage_info,
            _timestamp_to_date(build.meta.created_at),
        )

    @classmethod
    def dump_build_object(cls, message):
        """Render the ``build`` field of *message* as a key/value table."""
        single = message.build
        return super().dump_single_object(single)

    @classmethod
    def dump_build_list(cls, message):
        """Render the ``builds`` field of *message* as a multi-row table."""
        many = message.builds
        return super().dump_objects_list(many)


class LabelDumper(ABCTableDumper):
    """Table layout for ``data_model.Label`` messages."""

    headers = ("Label", "Label ID", "Build ID")

    @classmethod
    def _get_row_data(cls, label: data_model.Label) -> [str]:
        """Return name, id and the referenced build id (or "<undefined>") of *label*."""
        meta = label.meta
        target_build = label.spec.build_id or "<undefined>"
        return meta.name, meta.id, target_build

    @classmethod
    def dump_label_object(cls, message):
        """Render the ``label`` field of *message* as a key/value table."""
        single = message.label
        return super().dump_single_object(single)

    @classmethod
    def dump_label_list(cls, message):
        """Render the ``labels`` field of *message* as a multi-row table."""
        many = message.labels
        return super().dump_objects_list(many)


class ExecutionDumper(ABCTableDumper):
    """Table layout for ``data_model.Execution`` messages.

    The "Result" column summarises ``status.processing_result``. The
    single-object view also appends the serialized output when the execution
    produced one. List tables sort by creation time.
    """

    headers = ("Execution ID", "Author", "Label", "Build ID", "Created", "Status", "Task Status", "Result")
    order_by = "Created"

    # Display names for execution statuses. Values missing from this map are
    # rendered as ``UNKNOWN_STATUS_<value>`` instead of raising KeyError.
    execution_status_to_string = {
        data_model.EExecutionStatus.E_EXECUTION_STATUS_EXECUTING: "executing",
        data_model.EExecutionStatus.E_EXECUTION_STATUS_FINISHED: "finished",
        data_model.EExecutionStatus.E_EXECUTION_STATUS_INVALID: "invalid",
    }

    @classmethod
    def _error_code_to_string(cls, code: data_model.ErrorCodes.ErrorCode) -> str:
        """Return the enum name of a server error *code*, or ``UNKNOWN_CODE_<code>``."""
        try:
            return data_model.ErrorCodes.ErrorCode.Name(code)
        except ValueError:
            # Name() raises ValueError for values not declared in the enum.
            return f"UNKNOWN_CODE_{code}"

    @classmethod
    def _status_to_string(cls, status_value) -> str:
        """Return the display name of an execution status value, tolerating unknown values."""
        return cls.execution_status_to_string.get(status_value, f"UNKNOWN_STATUS_{status_value}")

    @classmethod
    def _processor_to_string(cls, status: "data_model.ExecutionStatus") -> str:
        """Describe the processor task attached to *status*, or ``"<no task>"``."""
        # NOTE(review): if ``updated_at`` is a message field (e.g. a Timestamp), its
        # truthiness may not reflect whether it is set; consider
        # ``status.processor.HasField("updated_at")`` -- confirm against the proto.
        if status.HasField("processor") and status.processor.updated_at:  # TODO: add YT processing
            return f"{status.processor.message} (id: {status.processor.sandbox_task_id})"
        return "<no task>"

    @classmethod
    def _processing_result_to_string(cls, status: "data_model.ExecutionStatus") -> str:
        """Summarise the result: "<pending>" until finished, then "OK" or a failure description."""
        if status.status != data_model.EExecutionStatus.E_EXECUTION_STATUS_FINISHED:
            return "<pending>"
        pr = status.processing_result
        kind = pr.WhichOneof("kind")
        if kind == "output":
            return "OK"
        if kind == "server_error":
            code_str = cls._error_code_to_string(pr.server_error.code)
            return f"Failed: {code_str}, Message: {pr.server_error.description}"
        # NB: every other case (including an unset oneof) is reported as a user error.
        return f"Failed: {pr.user_error.description}"

    @classmethod
    def _get_row_data(cls, execution: data_model.Execution) -> [str]:
        """Return the cells for *execution*, in ``headers`` order."""
        status = execution.status
        return (
            execution.meta.id,
            execution.spec.author,
            execution.spec.referenced_label,
            execution.meta.build_id,
            _timestamp_to_date(execution.meta.created_at),
            cls._status_to_string(status.status),
            cls._processor_to_string(status),
            cls._processing_result_to_string(status),
        )

    @classmethod
    def dump_execution_object(cls, message):
        """Render ``message.execution`` as a key/value table, plus an "Output" row on success."""
        execution: data_model.Execution = message.execution
        table = super().dump_single_object(execution)
        if execution.status.processing_result.WhichOneof("kind") == "output":
            table.add_row(("Output", execution.status.processing_result.output.serialized_output or "<no output>"))
        return table

    @classmethod
    def dump_execution_list(cls, message):
        """Render the ``executions`` field of *message* as a multi-row table."""
        return super().dump_objects_list(message.executions)


class AnyDumper(ABCTableDumper):
    """Generic layout: render a whole protobuf message as one JSON "Payload" cell."""

    headers = ("Payload",)
    order_by = "Payload"

    @classmethod
    def _get_row_data(cls, message) -> [str]:
        """Return a one-cell row holding the JSON serialization of *message*."""
        # Imported lazily, as before, so the module only needs json_format when used.
        from google.protobuf.json_format import MessageToJson

        payload = MessageToJson(message)
        return (payload,)


# Response message class -> callable that renders a message of that class as a
# table. Response classes rendered the same way are grouped with dict.fromkeys.
GLOBAL_DUMPING_FACTORY = {
    # Namespace
    **dict.fromkeys(
        (tasklet_service.GetNamespaceResponse, tasklet_service.CreateNamespaceResponse),
        NamespaceDumper.dump_namespace_object,
    ),
    tasklet_service.ListNamespacesResponse: NamespaceDumper.dump_namespace_list,
    # Tasklet
    **dict.fromkeys(
        (
            tasklet_service.GetTaskletResponse,
            tasklet_service.CreateTaskletResponse,
            tasklet_service.UpdateTaskletResponse,
        ),
        TaskletDumper.dump_tasklet_object,
    ),
    tasklet_service.ListTaskletsResponse: TaskletDumper.dump_tasklet_list,
    # Build
    **dict.fromkeys(
        (tasklet_service.GetBuildResponse, tasklet_service.CreateBuildResponse),
        BuildDumper.dump_build_object,
    ),
    tasklet_service.ListBuildsResponse: BuildDumper.dump_build_list,
    # Label
    **dict.fromkeys(
        (
            tasklet_service.GetLabelResponse,
            tasklet_service.CreateLabelResponse,
            tasklet_service.UpdateLabelResponse,
            tasklet_service.MoveLabelResponse,
        ),
        LabelDumper.dump_label_object,
    ),
    tasklet_service.ListLabelsResponse: LabelDumper.dump_label_list,
    # Execution
    **dict.fromkeys(
        (tasklet_service.GetExecutionResponse, tasklet_service.ExecuteResponse),
        ExecutionDumper.dump_execution_object,
    ),
    **dict.fromkeys(
        (tasklet_service.ListExecutionsByBuildResponse, tasklet_service.ListExecutionsByTaskletResponse),
        ExecutionDumper.dump_execution_list,
    ),
    # Schema registry: rendered generically as a single JSON payload cell
    **dict.fromkeys(
        (schema_registry_service.SchemaMetadata, struct_pb2.Struct),
        AnyDumper.dump_single_object,
    ),
}


def get_table_for_message(message) -> prettytable.PrettyTable:
    """Render *message* with the dumper registered for its exact type.

    Lookup uses ``type(message)`` as the key, so subclasses of a registered
    class are not matched.

    Raises:
        TypeError: if no dumper is registered for ``type(message)``.
    """
    dump_method = GLOBAL_DUMPING_FACTORY.get(type(message))
    if dump_method is None:
        # Previously execution fell through to `None(message)`, producing an
        # opaque "'NoneType' object is not callable"; keep the TypeError class
        # callers may already catch, but make the message meaningful.
        error_text = f"Can't print message of type '{type(message)}'"
        click.echo(error_text)
        raise TypeError(error_text)
    return dump_method(message)


def merge_tables(t1: prettytable.PrettyTable, t2: prettytable.PrettyTable) -> prettytable.PrettyTable:
    """Join two tables side by side, row by row, into a new table.

    Columns of *t2* whose names also appear in *t1* are dropped, so *t1*'s
    values win for shared columns. Neither input table is modified.

    Raises:
        ValueError: if the tables have different numbers of rows.
    """
    t1_rows, t2_rows = t1.rows, t2.rows
    if len(t1_rows) != len(t2_rows):
        raise ValueError(f"Can't merge tables with different number of rows ({len(t1_rows)} != {len(t2_rows)})")
    t1_names = list(t1.field_names)
    t2_names = list(t2.field_names)
    t1_name_set = set(t1_names)
    # Indices of t2 columns that do not duplicate a t1 column. Selecting them
    # here (instead of t2.del_column) avoids mutating the caller's t2.
    t2_keep = [i for i, name in enumerate(t2_names) if name not in t1_name_set]
    merged = prettytable.PrettyTable(t1_names + [t2_names[i] for i in t2_keep])
    for left, right in zip(t1_rows, t2_rows):
        merged.add_row(list(left) + [right[i] for i in t2_keep])
    return merged
