from google.protobuf.json_format import MessageToJson
from google.protobuf.any_pb2 import Any
from load.projects.cloud.loadtesting.logan import lookup_logger
from load.projects.cloud.loadtesting.server.api.common import handler
from load.projects.cloud.loadtesting.server.api.private_v1 import tank as tank_service
from yandex.cloud.priv.loadtesting.v2 import agent_instance_service_pb2, agent_instance_service_pb2_grpc, \
    agent_instance_pb2
from yandex.cloud.priv.loadtesting.v1 import tank_instance_service_pb2, tank_instance_pb2


class AgentServicer(agent_instance_service_pb2_grpc.AgentInstanceServiceServicer):
    """Private v2 ``AgentInstanceService`` implemented on top of the v1 tank-instance handlers.

    Each RPC is served by a nested handler class. The handler:

    - converts the agent-instance request into the equivalent tank-instance request;
    - runs the matching ``tank_service.TankServicer`` handler;
    - converts the result back, using ``agent_from_tank`` for resources and
      ``update_operation_fields`` for operations.
    """

    def __init__(self):
        self.logger = lookup_logger('AgentPrivate')

    class _TankDelegate:
        """Mixin for the handlers below: remembers the parent logger and forwards to tank handlers.

        It is listed before ``handler.BasePrivateHandler`` in each handler's bases, so its
        ``__init__`` runs first and chains to the base handler's ``__init__``.
        """

        def __init__(self, parent_logger):
            super().__init__(parent_logger)
            # The delegated tank handler is constructed with this same parent logger.
            self._parent_logger = parent_logger

        def _forward_by_id(self, tank_handler_cls, tank_request_cls):
            """Run ``tank_handler_cls`` on ``tank_request_cls(id=self.request.agent_instance_id)``."""
            tank_handler = tank_handler_cls(self._parent_logger)
            return tank_handler.handle(tank_request_cls(id=self.request.agent_instance_id), self.context)

        def _forward_operation(self, tank_handler_cls, tank_request_cls, tank_metadata_cls, agent_metadata_cls):
            """Forward an id-only operation RPC and rewrite the returned operation in agent terms."""
            tank_operation = self._forward_by_id(tank_handler_cls, tank_request_cls)
            update_operation_fields(self.db, tank_operation, tank_metadata_cls, agent_metadata_cls)
            return tank_operation

    class _Get(_TankDelegate, handler.BasePrivateHandler):
        _handler_name = 'Get'

        def proceed(self):
            tank_response = self._forward_by_id(tank_service.TankServicer._Get,
                                                tank_instance_service_pb2.GetTankInstanceRequest)
            return agent_instance_service_pb2.GetAgentInstanceResponse(
                agent_instance=agent_from_tank(tank_response.tank_instance))

    def Get(self, request, context):
        """Return the agent instance identified by ``request.agent_instance_id``."""
        rpc_handler = self._Get(self.logger)
        return rpc_handler.handle(request, context)

    class _List(_TankDelegate, handler.BasePrivateHandler):
        _handler_name = 'List'

        def proceed(self):
            tank_handler = tank_service.TankServicer._List(self._parent_logger)
            req = self.request
            tank_page = tank_handler.handle(
                tank_instance_service_pb2.ListTankInstancesRequest(
                    folder_id=req.folder_id,
                    page_size=req.page_size,
                    page_token=req.page_token,
                    filter=req.filter,
                ),
                self.context,
            )
            # Paging fields are passed through unchanged; each tank instance is converted.
            return agent_instance_service_pb2.ListAgentInstancesResponse(
                folder_id=tank_page.folder_id,
                agent_instances=list(map(agent_from_tank, tank_page.tank_instances)),
                next_page_token=tank_page.next_page_token,
            )

    def List(self, request, context):
        """Return one page of agent instances in ``request.folder_id``."""
        rpc_handler = self._List(self.logger)
        return rpc_handler.handle(request, context)

    class _Create(_TankDelegate, handler.BasePrivateHandler):
        _handler_name = 'Create'

        def proceed(self):
            tank_handler = tank_service.TankServicer._Create(self._parent_logger)
            tank_operation = tank_handler.handle(self._create_old_message(self.request), self.context)
            update_operation_fields(
                self.db, tank_operation,
                tank_instance_service_pb2.CreateTankInstanceMetadata,
                agent_instance_service_pb2.CreateAgentInstanceMetadata,
            )
            return tank_operation

        @staticmethod
        def _create_old_message(request: agent_instance_service_pb2.CreateAgentInstanceRequest) -> tank_instance_service_pb2.CreateTankInstanceRequest:
            """Translate a v2 create-agent request into the v1 create-tank request.

            Every field is copied under the same name.
            """
            return tank_instance_service_pb2.CreateTankInstanceRequest(
                folder_id=request.folder_id,
                name=request.name,
                description=request.description,
                labels=request.labels,
                preset_id=request.preset_id,
                service_account_id=request.service_account_id,
                zone_id=request.zone_id,
                network_interface_specs=request.network_interface_specs,
                metadata=request.metadata,
            )

    def Create(self, request, context):
        """Create an agent instance; returns the (agent-flavoured) operation."""
        rpc_handler = self._Create(self.logger)
        return rpc_handler.handle(request, context)

    class _Delete(_TankDelegate, handler.BasePrivateHandler):
        _handler_name = 'Delete'

        def proceed(self):
            return self._forward_operation(
                tank_service.TankServicer._Delete,
                tank_instance_service_pb2.DeleteTankInstanceRequest,
                tank_instance_service_pb2.DeleteTankInstanceMetadata,
                agent_instance_service_pb2.DeleteAgentInstanceMetadata,
            )

    def Delete(self, request, context):
        """Delete the agent instance identified by ``request.agent_instance_id``."""
        rpc_handler = self._Delete(self.logger)
        return rpc_handler.handle(request, context)

    class _Stop(_TankDelegate, handler.BasePrivateHandler):
        _handler_name = 'Stop'

        def proceed(self):
            return self._forward_operation(
                tank_service.TankServicer._Stop,
                tank_instance_service_pb2.StopTankInstanceRequest,
                tank_instance_service_pb2.StopTankInstanceMetadata,
                agent_instance_service_pb2.StopAgentInstanceMetadata,
            )

    def Stop(self, request, context):
        """Stop the agent instance identified by ``request.agent_instance_id``."""
        rpc_handler = self._Stop(self.logger)
        return rpc_handler.handle(request, context)

    class _Start(_TankDelegate, handler.BasePrivateHandler):
        _handler_name = 'Start'

        def proceed(self):
            return self._forward_operation(
                tank_service.TankServicer._Start,
                tank_instance_service_pb2.StartTankInstanceRequest,
                tank_instance_service_pb2.StartTankInstanceMetadata,
                agent_instance_service_pb2.StartAgentInstanceMetadata,
            )

    def Start(self, request, context):
        """Start the agent instance identified by ``request.agent_instance_id``."""
        rpc_handler = self._Start(self.logger)
        return rpc_handler.handle(request, context)

    class _Restart(_TankDelegate, handler.BasePrivateHandler):
        _handler_name = 'Restart'

        def proceed(self):
            return self._forward_operation(
                tank_service.TankServicer._Restart,
                tank_instance_service_pb2.RestartTankInstanceRequest,
                tank_instance_service_pb2.RestartTankInstanceMetadata,
                agent_instance_service_pb2.RestartAgentInstanceMetadata,
            )

    def Restart(self, request, context):
        """Restart the agent instance identified by ``request.agent_instance_id``."""
        rpc_handler = self._Restart(self.logger)
        return rpc_handler.handle(request, context)

    class _UpgradeImage(_TankDelegate, handler.BasePrivateHandler):
        _handler_name = 'UpgradeImage'

        def proceed(self):
            return self._forward_operation(
                tank_service.TankServicer._UpgradeImage,
                tank_instance_service_pb2.UpgradeImageTankInstanceRequest,
                tank_instance_service_pb2.UpgradeImageTankInstanceMetadata,
                agent_instance_service_pb2.UpgradeImageAgentInstanceMetadata,
            )

    def UpgradeImage(self, request, context):
        """Upgrade the image of the agent instance identified by ``request.agent_instance_id``."""
        rpc_handler = self._UpgradeImage(self.logger)
        return rpc_handler.handle(request, context)


def agent_from_tank(tank_instance: tank_instance_pb2.TankInstance) -> agent_instance_pb2.AgentInstance:
    """Build a v2 ``AgentInstance`` message from a v1 ``TankInstance`` message.

    Fields are copied one-to-one, except that ``tank_version`` becomes ``yandex_tank_version``.
    ``status`` and ``agent_version.status`` are first turned into enum *names* through the
    agent-side enums (``AgentInstance.Status`` and ``AgentVersion.VersionStatus``). The names
    are then passed to the message constructors.

    :raises ValueError: from ``Name()`` if a status number has no entry in the agent-side enum.
    """
    # NOTE(review): the tank enum *number* is looked up in the agent enum. That maps correctly
    # only if both enums share numbering. Confirm, or translate via the tank enum's names.
    agent_status = agent_instance_pb2.AgentInstance.Status.Name(tank_instance.status)

    source_version = tank_instance.agent_version
    version_status = agent_instance_pb2.AgentVersion.VersionStatus.Name(source_version.status)
    converted_version = agent_instance_pb2.AgentVersion(
        id=source_version.id,
        status=version_status,
        revision=source_version.revision,
        description=source_version.description,
        status_comment=source_version.status_comment,
    )

    return agent_instance_pb2.AgentInstance(
        id=tank_instance.id,
        folder_id=tank_instance.folder_id,
        created_at=tank_instance.created_at,
        compute_instance_updated_at=tank_instance.compute_instance_updated_at,
        name=tank_instance.name,
        description=tank_instance.description,
        labels=tank_instance.labels,
        service_account_id=tank_instance.service_account_id,
        preset_id=tank_instance.preset_id,
        yandex_tank_version=tank_instance.tank_version,
        status=agent_status,
        errors=tank_instance.errors,
        current_job=tank_instance.current_job,
        compute_instance_id=tank_instance.compute_instance_id,
        agent_version=converted_version,
    )


def update_operation_fields(db, operation, old_metadata_message, new_metadata_message):
    """Rewrite, in place, an operation returned by a tank handler so it describes an agent instance.

    Steps:

    1. ``operation.metadata`` is unpacked as ``old_metadata_message``.
    2. It is replaced by a packed ``new_metadata_message`` whose ``agent_instance_id`` is the
       old metadata's ``id``.
    3. If the db record for the operation has a non-empty ``done_resource_snapshot``, the tank
       is re-read from ``db.tank``. The snapshot is regenerated as ``AgentInstance`` JSON and
       written back with ``db.operation.add``, and ``operation.response`` is replaced by the
       packed ``AgentInstance``.

    :param db: database facade; this function uses ``db.operation.get/add`` and ``db.tank.get``.
    :param operation: operation message to rewrite; it is mutated in place.
    :param old_metadata_message: message class expected inside ``operation.metadata``.
    :param new_metadata_message: message class packed into ``operation.metadata`` instead.
    :return: the same ``operation`` object.
    :raises ValueError: if ``operation.metadata`` does not contain an ``old_metadata_message``.
    """
    # Change metadata.
    metadata = old_metadata_message()
    # Any.Unpack() reports a type mismatch by returning False rather than raising. Ignoring that
    # would leave metadata.id empty and push '' into the new metadata and the db lookup below.
    if not operation.metadata.Unpack(metadata):
        raise ValueError('operation {}: metadata has type {!r}, expected {}'.format(
            operation.id, operation.metadata.type_url, old_metadata_message.DESCRIPTOR.full_name))
    agent_instance_id = metadata.id
    new_metadata = Any()
    new_metadata.Pack(new_metadata_message(agent_instance_id=agent_instance_id))
    operation.metadata.CopyFrom(new_metadata)

    db_operation = db.operation.get(operation.id)
    if db_operation.done_resource_snapshot:
        # A snapshot is stored (presumably: the operation is done). Regenerate it from the
        # current tank row in agent form.
        db_tank = db.tank.get(agent_instance_id)
        agent_message = agent_from_tank(tank_service.DbToGrpcTranslator.tank(db_tank))
        db_operation.done_resource_snapshot = MessageToJson(agent_message)
        db.operation.add(db_operation)

        # The response must carry the same AgentInstance as the snapshot.
        response = Any()
        response.Pack(agent_message)
        operation.response.CopyFrom(response)
    return operation
