import typing
import functools
import aioboto3
from contextlib import AsyncExitStack
from aiobotocore.config import AioConfig
from google.protobuf.message import Message
from library.python.monlib.metric_registry import MetricRegistry

from smb.common.rmq.rpc.server import BaseRpcHandler

from crm.agency_cabinet.common.server.common.config import MdsConfig
from crm.agency_cabinet.ord.common import structs
from crm.agency_cabinet.ord.proto import request_pb2, common_pb2, reports_pb2, clients_pb2, import_data_pb2, \
    campaigns_pb2, acts_pb2, organizations_pb2, contracts_pb2, invites_pb2, client_rows_pb2

from crm.agency_cabinet.ord.server.src import procedures
from crm.agency_cabinet.ord.common.exceptions import (
    OrdException,
)


def process_exception(response_proto: typing.Type[Message]):
    """Build a decorator that converts ``OrdException`` into an error response.

    The decorated coroutine is awaited unchanged. If it raises ``OrdException``,
    the exception's ``proto_field`` names the field of ``response_proto`` to fill:

    * a field whose message type is ``Empty`` gets ``common_pb2.Empty()``;
    * any other message field gets
      ``common_pb2.ErrorMessageResponse(message=ex.message)``.

    If ``response_proto`` has no message-typed field called ``proto_field``, the
    original ``OrdException`` is re-raised. Previously the descriptor lookup
    raised an unrelated ``KeyError``/``AttributeError`` that hid the real error.

    :param response_proto: protobuf message class returned by the decorated handler.
    :return: decorator for ``async`` handler methods.
    """
    def wrapper(func):
        @functools.wraps(func)
        async def wrapped(*args, **kwargs):
            try:
                return await func(*args, **kwargs)
            except OrdException as ex:
                field_name = ex.proto_field
                # TODO: add generic error field for case when field_name not found or exception unexpected
                field = response_proto.DESCRIPTOR.fields_by_name.get(field_name)
                if field is None or field.message_type is None:
                    # No message field to carry the error: propagate the domain error as-is.
                    raise
                if field.message_type.name == 'Empty':
                    # TODO: use ErrorResponse
                    error_payload = common_pb2.Empty()
                else:
                    error_payload = common_pb2.ErrorMessageResponse(message=ex.message)
                # TODO: add handler tests to check every possible exception
                return response_proto(**{field_name: error_payload})
        return wrapped
    return wrapper


class Handler(BaseRpcHandler):
    """RPC handler for the ORD service.

    Each RPC coroutine follows the same pattern:

    1. decode the request proto into a ``structs`` object;
    2. run the matching ``procedures`` callable;
    3. wrap the result in the corresponding output proto.

    ``OrdException`` raised by a procedure is turned into an error response by
    ``process_exception``. That is why each decorator's ``response_proto`` must
    be the same message type the method returns.
    """

    _request_proto = request_pb2.RpcRequest

    def __init__(self, mds_cfg: MdsConfig, metric_registry: typing.Optional[MetricRegistry] = None):
        self.metric_registry = metric_registry
        self.mds_cfg = mds_cfg
        # S3 objects are created in ``setup``. They are declared here so the
        # attributes always exist, even before setup has run.
        self.boto3_session = None
        self.context_stack = None
        self.s3_resource = None
        self.s3_client = None

    async def ping(self, _: common_pb2.Empty) -> common_pb2.PingOutput:
        """Health-check RPC: always answers ``pong``."""
        return common_pb2.PingOutput(ping='pong')

    async def setup(self):
        """Initialise external resources; returns ``True`` once S3 is ready."""
        return await self._setup_s3()

    async def _setup_s3(self) -> bool:
        """Open an S3 resource and an S3 client against the configured MDS endpoint.

        Both async context managers are entered on a temporary ``AsyncExitStack``.
        If either fails, everything entered so far is closed before the error
        propagates. On success, ``pop_all()`` moves ownership to
        ``self.context_stack``, so the connections stay open after this returns.
        """
        self.boto3_session = aioboto3.Session(
            aws_access_key_id=self.mds_cfg.access_key_id,
            aws_secret_access_key=self.mds_cfg.secret_access_key
        )
        async with AsyncExitStack() as stack:
            s3_resource = await stack.enter_async_context(
                self.boto3_session.resource(
                    's3',
                    endpoint_url=self.mds_cfg.endpoint_url,
                    config=AioConfig(s3={'addressing_style': 'virtual'})
                )
            )
            s3_client = await stack.enter_async_context(
                self.boto3_session.client(
                    's3',
                    endpoint_url=self.mds_cfg.endpoint_url,
                    config=AioConfig(s3={'addressing_style': 'virtual'})
                )
            )
            # Both contexts entered: transfer them out so exiting this block keeps them open.
            self.context_stack = stack.pop_all()
        self.s3_resource = s3_resource
        self.s3_client = s3_client
        return True

    @process_exception(response_proto=reports_pb2.GetReportsInfoOutput)
    async def get_reports_info(self, message: reports_pb2.GetReportsInfo) -> reports_pb2.GetReportsInfoOutput:
        """Return reports info computed by ``procedures.GetReportsInfo``."""
        # TODO: unify function invocation
        result = await procedures.GetReportsInfo()(request=structs.GetReportsInfoRequest.from_proto(message))
        return reports_pb2.GetReportsInfoOutput(result=result.to_proto())

    @process_exception(response_proto=reports_pb2.GetDetailedReportInfoOutput)
    async def get_detailed_report_info(self, message: reports_pb2.GetDetailedReportInfo) -> reports_pb2.GetDetailedReportInfoOutput:
        """Return detailed info for one report via ``procedures.GetDetailedReportInfo``."""
        result = await procedures.GetDetailedReportInfo()(request=structs.GetDetailedReportInfoRequest.from_proto(message))
        return reports_pb2.GetDetailedReportInfoOutput(result=result.to_proto())

    @process_exception(response_proto=clients_pb2.GetReportClientsInfoOutput)
    async def get_report_clients_info(self, message: clients_pb2.GetReportClientsInfoInput) -> clients_pb2.GetReportClientsInfoOutput:
        """Return the clients of a report via ``procedures.GetReportClientsInfo``."""
        result = await procedures.GetReportClientsInfo()(request=structs.GetReportClientsInfoInput.from_proto(message))
        return clients_pb2.GetReportClientsInfoOutput(result=result.to_proto())

    @process_exception(response_proto=reports_pb2.SendReportOutput)
    async def send_report(self, message: reports_pb2.SendReportInput) -> reports_pb2.SendReportOutput:
        """Run ``procedures.SendReport``; an empty result signals success."""
        await procedures.SendReport()(
            request=structs.SendReportInput.from_proto(message))
        return reports_pb2.SendReportOutput(result=common_pb2.Empty())

    @process_exception(response_proto=reports_pb2.ReportExportOutput)
    async def report_export(self, message: reports_pb2.ReportExportInput) -> reports_pb2.ReportExportOutput:
        """Run ``procedures.ReportExport`` and return its result."""
        result = await procedures.ReportExport()(
            request=structs.ReportExportRequest.from_proto(message)
        )
        return reports_pb2.ReportExportOutput(result=result.to_proto())

    @process_exception(response_proto=reports_pb2.ReportExportOutput)
    async def get_report_export_info(self, message: reports_pb2.ReportExportInfoInput) -> reports_pb2.ReportExportOutput:
        """Return export info via ``procedures.GetReportExportInfo``."""
        result = await procedures.GetReportExportInfo()(
            request=structs.ReportExportInfoRequest.from_proto(message)
        )
        return reports_pb2.ReportExportOutput(result=result.to_proto())

    @process_exception(response_proto=reports_pb2.GetReportUrlOutput)
    async def get_report_url(self, message: reports_pb2.GetReportUrlInput) -> reports_pb2.GetReportUrlOutput:
        """Return a report URL via ``procedures.GetReportUrl``; needs ``setup`` to have run."""
        result = await procedures.GetReportUrl()(
            request=structs.GetReportUrlRequest.from_proto(message),
            s3_client=self.s3_client
        )
        return reports_pb2.GetReportUrlOutput(result=result.to_proto())

    @process_exception(response_proto=import_data_pb2.ImportDataOutput)
    async def import_data(self, message: import_data_pb2.ImportDataInput) -> import_data_pb2.ImportDataOutput:
        """Run ``procedures.ImportData`` and return its result."""
        result = await procedures.ImportData()(request=structs.ImportDataInput.from_proto(message))
        return import_data_pb2.ImportDataOutput(result=result.to_proto())

    @process_exception(response_proto=import_data_pb2.GetLockStatusOutput)
    async def get_lock_status(self, message: import_data_pb2.GetLockStatusInput) -> import_data_pb2.GetLockStatusOutput:
        """Return the lock status from ``procedures.GetLockStatus``."""
        result = await procedures.GetLockStatus()(request=structs.GetLockStatusInput.from_proto(message))
        return import_data_pb2.GetLockStatusOutput(result=result.to_proto())

    @process_exception(response_proto=reports_pb2.DeleteReportOutput)
    async def delete_report(self, message: reports_pb2.DeleteReportInput) -> reports_pb2.DeleteReportOutput:
        """Run ``procedures.DeleteReport``; an empty result signals success."""
        params = structs.DeleteReportRequest.from_proto(message)
        await procedures.DeleteReport()(request=params)
        return reports_pb2.DeleteReportOutput(result=common_pb2.Empty())

    @process_exception(response_proto=reports_pb2.CreateReportOutput)
    async def create_report(self, message: reports_pb2.CreateReportInput) -> reports_pb2.CreateReportOutput:
        """Run ``procedures.CreateReport`` and return the created report."""
        result = await procedures.CreateReport()(request=structs.CreateReportRequest.from_proto(message))
        return reports_pb2.CreateReportOutput(result=result.to_proto())

    @process_exception(response_proto=client_rows_pb2.GetClientRowsOutput)
    async def get_client_rows(self, message: client_rows_pb2.GetClientRowsInput) -> client_rows_pb2.GetClientRowsOutput:
        """Return client rows via ``procedures.GetClientRows``."""
        params = structs.GetClientRowsInput.from_proto(message)
        result = await procedures.GetClientRows()(request=params)
        return client_rows_pb2.GetClientRowsOutput(result=result.to_proto())

    @process_exception(response_proto=campaigns_pb2.GetCampaignsOutput)
    async def get_campaigns(self, message: campaigns_pb2.GetCampaignsInput) -> campaigns_pb2.GetCampaignsOutput:
        """Return campaigns via ``procedures.GetCampaigns``."""
        result = await procedures.GetCampaigns()(request=structs.GetCampaignsInput.from_proto(message))
        return campaigns_pb2.GetCampaignsOutput(result=result.to_proto())

    @process_exception(response_proto=acts_pb2.GetActsOutput)
    async def get_acts(self, message: acts_pb2.GetActsInput) -> acts_pb2.GetActsOutput:
        """Return acts via ``procedures.GetActs``."""
        result = await procedures.GetActs()(request=structs.GetActsInput.from_proto(message))
        return acts_pb2.GetActsOutput(result=result.to_proto())

    @process_exception(response_proto=client_rows_pb2.EditClientRowOutput)
    async def edit_client_row(self, message: client_rows_pb2.EditClientRowInput) -> client_rows_pb2.EditClientRowOutput:
        """Run ``procedures.EditClientRow``; an empty result signals success."""
        params = structs.EditClientRowInput.from_proto(message)
        await procedures.EditClientRow()(request=params)
        return client_rows_pb2.EditClientRowOutput(result=common_pb2.Empty())

    @process_exception(response_proto=clients_pb2.ClientShortInfoOutput)
    async def get_client_short_info(self, message: clients_pb2.ClientShortInfoInput) -> clients_pb2.ClientShortInfoOutput:
        """Return short client info via ``procedures.GetClientShortInfo``."""
        request = structs.ClientShortInfoInput.from_proto(message)
        info = await procedures.GetClientShortInfo()(request=request)
        return clients_pb2.ClientShortInfoOutput(result=info.to_proto())

    @process_exception(response_proto=acts_pb2.AddActOutput)
    async def add_act(self, message: acts_pb2.AddActInput) -> acts_pb2.AddActOutput:
        """Run ``procedures.AddAct`` and return the created act."""
        result = await procedures.AddAct()(request=structs.AddActInput.from_proto(message))
        return acts_pb2.AddActOutput(result=result.to_proto())

    @process_exception(response_proto=acts_pb2.EditActOutput)
    async def edit_act(self, message: acts_pb2.EditActInput) -> acts_pb2.EditActOutput:
        """Run ``procedures.EditAct``; an empty result signals success."""
        await procedures.EditAct()(request=structs.EditActInput.from_proto(message))
        return acts_pb2.EditActOutput(result=common_pb2.Empty())

    @process_exception(response_proto=clients_pb2.CreateClientOutput)
    async def create_client(self, message: clients_pb2.CreateClientInput) -> clients_pb2.CreateClientOutput:
        """Run ``procedures.CreateClient`` and return the created client."""
        result = await procedures.CreateClient()(request=structs.CreateClientInput.from_proto(message))
        return clients_pb2.CreateClientOutput(result=result.to_proto())

    @process_exception(response_proto=organizations_pb2.GetOrganizationsOutput)
    async def get_organizations(self, message: organizations_pb2.GetOrganizationsInput) -> organizations_pb2.GetOrganizationsOutput:
        """Return organizations via ``procedures.GetOrganizations``."""
        result = await procedures.GetOrganizations()(request=structs.GetOrganizationsInput.from_proto(message))
        return organizations_pb2.GetOrganizationsOutput(result=result.to_proto())

    # The error response must be built on the same message type the method returns.
    # It was previously ``ContractsList``, which produced the wrong message on error.
    @process_exception(response_proto=contracts_pb2.GetContractsOutput)
    async def get_contracts(self, message: contracts_pb2.GetContractsInput) -> contracts_pb2.GetContractsOutput:
        """Return contracts via ``procedures.GetContracts``."""
        request = structs.GetContractsInput.from_proto(message)
        contracts = await procedures.GetContracts()(request=request)
        return contracts_pb2.GetContractsOutput(result=contracts.to_proto())

    @process_exception(response_proto=invites_pb2.GetInvitesOutput)
    async def get_invites(self, message: invites_pb2.GetInvitesInput) -> invites_pb2.GetInvitesOutput:
        """Return invites via ``procedures.GetInvites``."""
        result = await procedures.GetInvites()(request=structs.GetInvitesInput.from_proto(message))
        return invites_pb2.GetInvitesOutput(result=result.to_proto())