import aiohttp
import logging

from tenacity import TryAgain
from datetime import datetime
from typing import Any

from async_clients.clients.base import BaseClient
from async_clients.auth_types import TVM2

from intranet.trip.src.lib.travel.exceptions import TravelError


logger = logging.getLogger(__name__)


class TravelClient(BaseClient):
    """Common base for the Travel HTTP clients defined in this module.

    Declares TVM2 as the only entry in ``AUTH_TYPES`` and overrides
    ``parse_response`` so that failed responses surface as ``TravelError``.
    """

    AUTH_TYPES = {TVM2}

    async def parse_response(self, response: aiohttp.ClientResponse, **kwargs) -> dict:
        """Decode a response body or raise an error describing the failure.

        Args:
            response: The aiohttp response to inspect.
            **kwargs: Accepted for signature compatibility; not used here.

        Returns:
            The awaited result of the response method named by
            ``self.RESPONSE_TYPE``.

        Raises:
            TryAgain: If the status is listed in ``self.RETRY_CODES``.
            TravelError: If ``response.raise_for_status()`` rejects the
                status; the original ``aiohttp.ClientResponseError`` is
                attached as ``__cause__``.
        """
        # Retryable statuses are signalled with tenacity's TryAgain;
        # presumably the base client's retry wrapper picks this up — TODO confirm.
        if response.status in self.RETRY_CODES:
            raise TryAgain()
        try:
            response.raise_for_status()
        except aiohttp.ClientResponseError as exc:
            message = (
                f'Got status: {response.status}, for request: '
                f'{response.method} {response.url}'
            )
            # Chain explicitly so the underlying aiohttp error is kept as
            # the direct cause instead of incidental context.
            raise TravelError(
                content=message,
                status_code=response.status,
            ) from exc

        return await getattr(response, self.RESPONSE_TYPE)()


class TravelAviaClient(TravelClient):
    """Client for the airline-info and covid-restrictions REST endpoints."""

    async def get_airline_info(self) -> dict[str, Any]:
        """Fetch airline info.

        Sends a GET to ``/v1/backend/rest/airlines/airline_info`` asking for
        the ``default_tariff``, ``carryon``, ``baggage`` and ``iata`` fields.

        Returns:
            Whatever ``self._make_request`` returns for this call.
        """
        return await self._make_request(
            method='get',
            path='/v1/backend/rest/airlines/airline_info',
            params={'fields': 'default_tariff,carryon,baggage,iata'},
        )

    async def get_covid_restrictions(self, country_code: int) -> dict[str, Any]:
        """Fetch covid restriction info for a country.

        Sends a GET to ``/v1/backend/rest/country-restrictions/covid-info``.

        Args:
            country_code: Numeric code of the country; it is sent as the
                ``point_key`` query parameter, prefixed with the letter ``l``.

        Returns:
            Whatever ``self._make_request`` returns for this call.
        """
        point_key = f'l{country_code}'
        return await self._make_request(
            method='get',
            path='/v1/backend/rest/country-restrictions/covid-info',
            params={'point_key': point_key},
        )


class TravelTrainClient(TravelClient):
    """Client for the internal train-details endpoint."""

    async def get_train_details(
        self,
        station_from: int,
        station_to: int,
        when: datetime,
        number: str,
    ) -> dict:
        """Fetch details for one train.

        Sends a GET to ``/ru/api/internal/train-details/``.

        Args:
            station_from: Sent as the ``stationFrom`` query parameter.
            station_to: Sent as the ``stationTo`` query parameter.
            when: Sent as ``when``, formatted as ``YYYY-MM-DDTHH:MM:SS``;
                the format carries no timezone offset.
            number: Sent as the ``number`` query parameter.

        Returns:
            Whatever ``self._make_request`` returns for this call.
        """
        when_formatted = when.strftime("%Y-%m-%dT%H:%M:%S")
        query = {
            'stationFrom': station_from,
            'stationTo': station_to,
            'when': when_formatted,
            'number': number,
        }
        return await self._make_request(
            method='get',
            path='/ru/api/internal/train-details/',
            params=query,
        )
