from enum import Enum
from typing import Optional, Type, TypeVar, cast

from mail.payments.payments.core.actions.base.merchant import BaseMerchantAction
from mail.payments.payments.core.actions.merchant.check_roles import CheckUpdateMerchantRolesAction
from mail.payments.payments.core.actions.merchant.functionality import put_merchant_functionality
from mail.payments.payments.core.actions.mixins.callback_task import APICallbackTaskMixin
from mail.payments.payments.core.entities.enums import CallbackMessageType, MerchantDraftPolicy, PersonType
from mail.payments.payments.core.entities.merchant import (
    AddressData, BankData, Merchant, MerchantData, OrganizationData, PersonData
)
from mail.payments.payments.core.entities.not_fetched import NOT_FETCHED
from mail.payments.payments.core.entities.service import Service
from mail.payments.payments.core.exceptions import CoreFieldError, InnModificationError, MerchantIsAlreadyRegistered

# Dataclass types accepted by the generic patch/create helpers of UpdateMerchantAction.
Data = TypeVar(
    'Data',
    MerchantData,
    AddressData,
    BankData,
    OrganizationData,
    PersonData,
)


class UpdateMerchantAction(APICallbackTaskMixin, BaseMerchantAction):
    '''
    Apply a partial update to a merchant and save it.

    ``params`` must contain ``uid``; the other keys handled here are ``name``, ``username``,
    ``bank``, ``organization``, ``addresses``, ``persons`` and ``functionality``.
    Absent keys leave the corresponding value untouched.
    '''

    # Class-level switches not read in this module; presumably consumed by the base
    # action classes (merchant loading / transaction handling) -- confirm there.
    check_moderation_disapproved = True
    draft_policy = MerchantDraftPolicy.MERCHANT_DRAFT_FORBIDDEN
    for_update = True
    skip_oauth = True
    transact = True

    def __init__(self, params: dict, check_not_registered: bool = True, send_notifications: bool = False):
        '''
        :param params: update payload; must contain ``uid``, which is handed to the base action
            and excluded from ``self.params``. The caller's dict is shallow-copied, not modified.
        :param check_not_registered: when True, ``handle`` raises MerchantIsAlreadyRegistered
            for a registered merchant.
        :param send_notifications: when True, ``handle`` schedules MERCHANT_REQUISITES_UPDATED
            callback tasks after saving.
        '''
        # Copy before popping: popping straight from the argument mutated the caller's dict,
        # so reusing the same dict would fail with KeyError on 'uid'.
        params = dict(params)
        super().__init__(params.pop('uid'))
        self.params = params
        self.check_not_registered = check_not_registered
        self.send_notifications = send_notifications

    async def pre_handle(self):
        '''Run the base pre-handle step, then CheckUpdateMerchantRolesAction.'''
        await super().pre_handle()
        await CheckUpdateMerchantRolesAction().run()

    def _check_inn_not_changed(self) -> None:
        '''
        Forbid changing an INN that is already stored.

        Raises InnModificationError when the stored organization has an INN and ``params``
        either sets ``organization`` to None or supplies a different ``organization.inn``.
        ``handle`` calls this before patching, because the patch overwrites the stored INN.
        '''
        if 'organization' not in self.params:
            return
        current = self.data.organization
        if current is None or current.inn is None:
            return
        new_organization = self.params['organization']
        if new_organization is None:
            raise InnModificationError
        # A missing 'inn' key means "keep the current value".
        if current.inn != new_organization.get('inn', current.inn):
            raise InnModificationError

    @staticmethod
    def _check_has_attribute(data: Data, attrib_name: str, object_name: Optional[str] = None) -> None:
        '''Raise CoreFieldError keyed by ``[object_name.]attrib_name`` if ``data`` lacks that attribute.'''
        if not hasattr(data, attrib_name):
            key = attrib_name if not object_name else f'{object_name}.{attrib_name}'
            raise CoreFieldError(fields={key: 'Unexpected field name'})

    def _patch(self, field_name: str) -> bool:
        '''
        Handle the trivial cases for ``field_name`` of ``self.data``.

        :returns: True if nothing more is to be done -- the key is absent from ``params``
            (field unchanged) or is None (field set to None); False if ``params`` carries
            a value that the caller still has to merge.
        :raises CoreFieldError: if ``self.data`` has no such attribute.
        '''
        self._check_has_attribute(self.data, field_name)
        if field_name not in self.params:
            return True
        if self.params[field_name] is None:
            setattr(self.data, field_name, None)
            return True
        return False

    @staticmethod
    def _create_dataclass_instance(field_name: str, data_cls: Type[Data], values: dict) -> Data:
        '''Build ``data_cls(**values)``, reporting a constructor TypeError as CoreFieldError on ``field_name``.'''
        try:
            return data_cls(**values)
        except TypeError:
            raise CoreFieldError(fields={
                field_name: "Can't create new record with provided values"
            })

    def _patch_or_create_dict_data(self, field_name: str, data_cls: Type[Data]) -> None:
        '''
        Patch existing props or try to create new dataclass with provided values.

        If ``self.data.<field_name>`` is None a new ``data_cls`` is built from the dict in
        ``params``; otherwise each key of that dict is set on the existing object.
        '''
        if self._patch(field_name):
            return
        values = self.params[field_name]
        data_field = getattr(self.data, field_name)
        if data_field is None:
            setattr(self.data, field_name, self._create_dataclass_instance(field_name, data_cls, values))
            return
        for k, v in values.items():
            self._check_has_attribute(data_field, k, field_name)
            setattr(data_field, k, v)

    def _patch_or_create_dict_list_data(self,
                                        field_name: str,
                                        data_cls: Type[Data],
                                        key_name: str = 'type',
                                        enum_cls: Optional[Type[Enum]] = None) -> None:
        '''
        Patch/create lists of dataclasses (PersonData's, AddressData's).

        ``self.params[field_name]`` maps a key to a dict of values. The key is converted with
        ``enum_cls`` (when given) and matched against each list item's ``key_name`` attribute.
        Per key:
          * non-empty dict -- update the matching item, or append a new ``data_cls`` item;
          * None -- remove the matching item, if there is one;
          * any other falsy value (e.g. ``{}``) -- ignored.

        :raises CoreFieldError: for a key ``enum_cls`` rejects, an attribute the matched item
            does not have, or values ``data_cls`` cannot be built from.
        '''
        if self._patch(field_name):
            return
        # Fetched once; when the field was None/empty, the fresh list is attached to
        # self.data on the first append.
        data_list = getattr(self.data, field_name) or []
        for raw_key, value_dict in self.params[field_name].items():
            try:
                key = enum_cls(raw_key) if enum_cls is not None else raw_key
            except ValueError:
                # Report an unknown key as a field error instead of letting ValueError escape.
                raise CoreFieldError(fields={f'{field_name}.{raw_key}': 'Unexpected field name'})
            # Explicit `is None` check: a matched dataclass must never be taken as "no match".
            data_list_item = next((item for item in data_list if getattr(item, key_name) == key), None)
            if data_list_item is None:
                if value_dict:
                    new_item = self._create_dataclass_instance(
                        field_name, data_cls, {**value_dict, key_name: key},
                    )
                    data_list.append(new_item)
                    setattr(self.data, field_name, data_list)
            elif value_dict:
                for k, v in value_dict.items():
                    self._check_has_attribute(data_list_item, k, field_name)
                    setattr(data_list_item, k, v)
            elif value_dict is None:
                data_list.remove(data_list_item)

    async def _send_notifications(self, merchant: Merchant) -> None:
        '''
        Schedule a MERCHANT_REQUISITES_UPDATED callback task for each service client of each
        service_merchant record found for ``merchant.uid``.
        '''
        callback_message_type = CallbackMessageType.MERCHANT_REQUISITES_UPDATED
        async for service_merchant in self.storage.service_merchant.find(merchant.uid):
            service_id = service_merchant.service_id
            async for service_client in self.storage.service_client.find(service_id=service_id, with_service=True):
                service = cast(Service, service_client.service)
                # Drop the client's back-reference before nesting the client inside the service;
                # presumably avoids a service -> client -> service cycle -- confirm.
                service_client.service = NOT_FETCHED
                service.service_merchant = service_merchant
                service.service_client = service_client
                await self.create_service_merchant_callback_task(service, callback_message_type)

    async def handle(self) -> Merchant:
        '''
        Apply ``self.params`` to the merchant, save it and optionally notify services.

        :raises MerchantIsAlreadyRegistered: if ``check_not_registered`` is set and the
            merchant is registered.
        :raises InnModificationError: if the update would change an already-stored INN.
        :raises CoreFieldError: if the payload names unknown fields/keys or cannot build a record.
        :returns: the merchant returned by ``storage.merchant.save``, after ``load_data()``.
        '''
        self.logger.context_push(uid=self.uid)

        # self.merchant is presumably loaded by BaseMerchantAction; the asserts narrow Optionals.
        assert self.merchant
        assert self.merchant.data

        if self.check_not_registered and self.merchant.registered:
            raise MerchantIsAlreadyRegistered

        self.data = self.merchant.data
        self._check_inn_not_changed()

        # patch simple fields
        if 'name' in self.params:
            self.merchant.name = self.params['name']
        if 'username' in self.params:
            self.data.username = self.params['username']

        # patch flat data
        self._patch_or_create_dict_data('bank', BankData)
        self._patch_or_create_dict_data('organization', OrganizationData)

        # patch data with key (it is passed as nested dicts in action input)
        self._patch_or_create_dict_list_data('addresses', AddressData)
        self._patch_or_create_dict_list_data('persons', PersonData, enum_cls=PersonType)

        self.logger.context_push(merchant=self.merchant)

        if 'functionality' in self.params:
            functionality = self.params['functionality']
            # Dispatch to the action registered for this functionality's type.
            await put_merchant_functionality[functionality.type](
                merchant=self.merchant,
                data=functionality,
            ).run()

        merchant = await self.storage.merchant.save(self.merchant)
        merchant.load_data()

        if self.send_notifications:
            await self._send_notifications(merchant)

        return merchant
