from datetime import timedelta
from typing import Optional

from sendr_interactions import exceptions as interaction_errors
from sendr_utils import utcnow

from mail.payments.payments.core.actions.base.merchant import BaseMerchantAction
from mail.payments.payments.core.actions.shop.get_or_ensure_default import GetOrEnsureDefaultShopAction
from mail.payments.payments.core.entities.enums import MerchantOAuthMode, ShopType
from mail.payments.payments.core.entities.merchant import Merchant
from mail.payments.payments.core.entities.merchant_oauth import MerchantOAuth, OAuthToken
from mail.payments.payments.core.entities.shop import Shop
from mail.payments.payments.core.exceptions import (
    ChildMerchantError, KassaMeError, OAuthAlreadyExistsError, OAuthCodeError, ShopNotFoundError
)
from mail.payments.payments.interactions.exceptions import OAuthClientError
from mail.payments.payments.storage.exceptions import MerchantOAuthAlreadyExistsStorageError, ShopNotFound


class OAuthCompleteMerchantAction(BaseMerchantAction):
    """Complete the OAuth flow for a merchant and persist the resulting tokens.

    Exchanges the authorization ``code`` for an OAuth token, queries the Kassa
    ``me`` endpoint with the obtained access token, derives the OAuth mode
    (TEST/PROD) from the ``test`` flag of that response, resolves the shop the
    credentials belong to and stores a new ``MerchantOAuth`` record.
    """

    # Flags consumed by BaseMerchantAction; presumably they turn off loading of
    # the parent merchant, merchant data and moderation for this action —
    # NOTE(review): confirm against the base class.
    skip_parent = True
    skip_data = True
    skip_moderation = True

    def __init__(self,
                 code: str, uid: Optional[int] = None,
                 shop_id: Optional[int] = None,
                 merchant: Optional[Merchant] = None):
        """
        :param code: OAuth authorization code to exchange for a token.
        :param uid: merchant uid, forwarded to ``BaseMerchantAction``.
        :param shop_id: shop to attach the OAuth record to; when ``None`` the
            shop returned by ``GetOrEnsureDefaultShopAction`` is used.
        :param merchant: already-loaded merchant, forwarded to
            ``BaseMerchantAction``.
        """
        super().__init__(uid=uid, merchant=merchant)
        self.code = code
        self.shop_id = shop_id

    async def _get_shop(self, shop_id: Optional[int], mode: MerchantOAuthMode) -> Shop:
        """Resolve the shop the OAuth record will be bound to.

        The required shop type is derived from ``mode`` via
        ``ShopType.from_oauth_mode``.

        :param shop_id: explicit shop id, or ``None`` to fall back to the
            default shop provided by ``GetOrEnsureDefaultShopAction``.
        :param mode: OAuth mode reported by Kassa.
        :returns: the resolved ``Shop``.
        :raises ShopNotFoundError: if ``shop_id`` is given but storage has no
            shop with that id and the required type for this merchant.
        """
        assert self.merchant and self.merchant.uid
        # Take the uid from the loaded merchant in both branches so the explicit
        # lookup and the default-shop fallback always target the same merchant.
        uid = self.merchant.uid
        shop_type = ShopType.from_oauth_mode(mode)

        if shop_id is None:
            return await GetOrEnsureDefaultShopAction(uid=uid, default_shop_type=shop_type).run()

        try:
            return await self.storage.shop.get(uid, shop_id, shop_type=shop_type)
        except ShopNotFound as exc:
            raise ShopNotFoundError from exc

    async def handle(self) -> MerchantOAuth:
        """Exchange the OAuth code and store the merchant's OAuth credentials.

        :returns: the ``MerchantOAuth`` record returned by storage after creation.
        :raises ChildMerchantError: if the merchant has a parent merchant.
        :raises OAuthCodeError: if the OAuth client fails to exchange the code.
        :raises KassaMeError: if the Kassa ``me`` request fails.
        :raises ShopNotFoundError: if an explicit ``shop_id`` does not resolve.
        :raises OAuthAlreadyExistsError: if storage reports that such an
            OAuth record already exists.
        """
        assert self.merchant and self.merchant.uid
        # OAuth may only be completed by merchants without a parent.
        if self.merchant.parent_uid:
            raise ChildMerchantError

        try:
            oauth_token: OAuthToken = await self.clients.oauth.get_token(self.code)
        except OAuthClientError as exc:
            self.logger.context_push(error=exc.params)
            self.logger.error('OAuth get_token error')
            raise OAuthCodeError from exc

        try:
            data = await self.clients.kassa.me(oauth_token.access_token)
        except interaction_errors.BaseInteractionError as exc:
            self.logger.error('Kassa me error')
            raise KassaMeError from exc

        # The Kassa response's ``test`` flag decides which mode (and thus which
        # shop type) these credentials belong to.
        mode = MerchantOAuthMode.TEST if data['test'] else MerchantOAuthMode.PROD
        shop = await self._get_shop(self.shop_id, mode)

        merchant_oauth = MerchantOAuth(
            uid=self.merchant.uid,
            shop_id=shop.shop_id,  # type: ignore
            # Absolute expiry computed from the token's relative lifetime.
            expires=utcnow() + timedelta(seconds=oauth_token.expires_in),
            mode=mode,
            data=data,
        )
        # Tokens are set through the ``decrypted_*`` attributes; presumably the
        # entity handles their encrypted form — NOTE(review): confirm.
        merchant_oauth.decrypted_access_token = oauth_token.access_token
        merchant_oauth.decrypted_refresh_token = oauth_token.refresh_token

        try:
            return await self.storage.merchant_oauth.create(merchant_oauth)
        except MerchantOAuthAlreadyExistsStorageError as exc:
            raise OAuthAlreadyExistsError from exc
