from contextlib import asynccontextmanager
from typing import Any, Optional

from aiohttp import ClientResponse

from mail.payments.payments.conf import settings
from mail.payments.payments.core.entities.spark import SparkData, SparkMethod
from mail.payments.payments.interactions.base import AbstractInteractionClient
from mail.payments.payments.interactions.spark.request import SparkRequest
from mail.payments.payments.interactions.spark.response import SparkResponse
from mail.payments.payments.utils.helpers import is_entrepreneur_by_inn


class SparkClient(AbstractInteractionClient):
    """SOAP client for the SPARK (Interfax) service.

    Each public lookup (``get_info``) runs inside its own SPARK session:
    authenticate, perform the request, send the END request, then close
    this client.
    """

    SERVICE = 'spark'
    BASE_URL = settings.SPARK_URL
    TVM_ID = None
    TVM_SESSION_CLS = None

    def __init__(self, **kwargs: Any) -> None:
        # NOTE(review): ``login``/``password`` are also forwarded to the base
        # initializer through **kwargs; confirm the base class tolerates them.
        super().__init__(**kwargs)
        self._login = kwargs.get('login', settings.SPARK_AUTH_LOGIN)
        self._password = kwargs.get('password', settings.SPARK_AUTH_PASSWORD)

    def _get_session_kwargs(self) -> dict:
        """Extend the base session kwargs with the SPARK-specific timeout."""
        return {
            **super()._get_session_kwargs(),
            **self._get_timeout_kwargs(settings.SPARK_TIMEOUT),
        }

    async def _make_request(
        self, interaction_method: str, method: str, url: str, **kwargs: Any
    ) -> Any:
        """Attach the SOAP headers SPARK requires and delegate to the base client.

        :param interaction_method: SPARK operation name; also used to build the
            ``SOAPAction`` header.
        :param method: HTTP method.
        :param url: Target URL.
        :param kwargs: Forwarded to the base ``_make_request``. A ``headers``
            mapping, if given, is copied rather than mutated.
        """
        # Build a fresh dict so a caller-supplied (possibly shared) headers
        # mapping is never modified, and so ``headers=None`` is tolerated.
        kwargs['headers'] = {
            **(kwargs.get('headers') or {}),
            'SOAPAction': f'http://interfax.ru/ifax/{interaction_method}',
            'Content-Type': 'text/xml; charset=utf-8',
        }
        return await super()._make_request(interaction_method, method, url, **kwargs)

    async def _process_response(self, response: ClientResponse, interaction_method: str) -> Any:
        """Return the raw response body; delegate HTTP error statuses (>= 400) to the base handler."""
        if response.status >= 400:
            await self._handle_response_error(response)
        return await response.read()

    async def _spark_session_start(self) -> None:
        """Send the SPARK auth request and validate its response."""
        response = await self.post(
            SparkMethod.AUTH.value,
            self.BASE_URL,
            data=SparkRequest.auth(login=self._login, password=self._password)
        )
        SparkResponse.check_auth_response(response)

    async def _spark_session_end(self) -> None:
        """Send the SPARK END request, close this client, then validate the END response."""
        try:
            response = await self.post(
                SparkMethod.END.value,
                self.BASE_URL,
                data=SparkRequest.end()
            )
        finally:
            # Close even if the END request itself fails, so the underlying
            # HTTP session is never leaked.
            await self.close()
        SparkResponse.check_end_response(response)

    @asynccontextmanager
    async def _spark_session(self):
        """Async context manager wrapping its body in a SPARK session.

        On successful start, the END request (and client close) always runs on
        exit, even if the body raises. If starting the session fails, END is not
        sent, but the client is still closed before the error propagates.
        """
        try:
            await self._spark_session_start()
        except BaseException:
            # Also covers cancellation: release the HTTP client, then re-raise.
            await self.close()
            raise
        try:
            yield
        finally:
            await self._spark_session_end()

    async def _get_entrepreneur_info(self, inn: str) -> SparkData:
        """Fetch and parse SPARK data for an individual entrepreneur by INN."""
        response = await self.post(
            SparkMethod.GET_ENTREPRENEUR.value,
            self.BASE_URL,
            data=SparkRequest.get_entrepreneur(inn=inn)
        )
        return SparkResponse.parse_entrepreneur_response(response)

    async def _get_company_info(self, inn: str, spark_id: Optional[int] = None) -> SparkData:
        """Fetch and parse SPARK data for a company by INN (and optional SPARK id)."""
        response = await self.post(
            SparkMethod.GET_COMPANY.value,
            self.BASE_URL,
            data=SparkRequest.get_company(inn=inn, spark_id=spark_id)
        )
        return SparkResponse.parse_company_response(response)

    async def get_info(self, inn: str, spark_id: Optional[int] = None) -> SparkData:
        """Look up SPARK data for ``inn`` inside a dedicated SPARK session.

        INNs that ``is_entrepreneur_by_inn`` classifies as entrepreneurs use the
        entrepreneur endpoint; all others use the company endpoint, which also
        receives ``spark_id``. The session is ended and this client closed
        afterwards, so a given instance is effectively single-use.
        """
        async with self._spark_session():
            if is_entrepreneur_by_inn(inn):
                return await self._get_entrepreneur_info(inn)
            return await self._get_company_info(inn, spark_id)
