import asyncio
import logging
from dataclasses import dataclass, field
from datetime import datetime

from travel.library.python.aioapp.utils import localize_dt
from travel.rasp.library.python.httpclient.async_client import ApiError
from travel.rasp.pathfinder_proxy.const import MSK_TZ, UTC_TZ, Status
from travel.rasp.pathfinder_proxy.logs import log_proxy_tariffs

logger = logging.getLogger('tariff_collector')


def make_station_key(station):
    """Return the point key for a station: its ``id`` prefixed with ``s``."""
    station_id = station['id']
    return f's{station_id}'


def prepare_transfer_variants(transfer_variants, transport_codes):
    """Collect the request data for every matching segment of the variants.

    Segments whose ``segment['transport']['code']`` is not in
    ``transport_codes`` are ignored. For each remaining segment the
    departure/arrival station keys and the departure datetime are appended to
    the result, both stations are remembered by id, and the earliest/latest
    departure (converted to ``MSK_TZ``) is tracked.

    :param transfer_variants: iterable of dicts, each with a ``'segments'`` list.
    :param transport_codes: container of accepted transport codes.
    :return: a filled ``PreparedTariffs`` instance.
    """
    result = PreparedTariffs()

    matching_segments = (
        segment
        for variant in transfer_variants
        for segment in variant['segments']
        if segment['transport']['code'] in transport_codes
    )

    for segment in matching_segments:
        # The last nine characters of the departure string are dropped before
        # it is handed to localize_dt together with UTC_TZ.
        departure_utc = localize_dt(segment['departure'][:-9], UTC_TZ)
        departure_msk = departure_utc.astimezone(MSK_TZ)

        station_from = segment['stationFrom']
        station_to = segment['stationTo']
        result.stations_by_id[station_from['id']] = station_from
        result.stations_by_id[station_to['id']] = station_to

        earliest = result.departure_date_from
        if earliest is None or departure_msk < earliest:
            result.departure_date_from = departure_msk

        latest = result.departure_date_to
        if latest is None or departure_msk > latest:
            result.departure_date_to = departure_msk

        result.departure_points.append(make_station_key(station_from))
        result.arrival_points.append(make_station_key(station_to))
        result.departure_dts.append(departure_utc)

    return result


@dataclass
class PreparedTariffs:
    """Accumulator for the data extracted from transfer variant segments.

    ``departure_points``, ``arrival_points`` and ``departure_dts`` are parallel
    lists: ``prepare_transfer_variants`` appends one entry to each of them per
    accepted segment, so items at the same index describe the same segment.
    """

    # Station keys (see make_station_key) of the segments' departure stations.
    departure_points: list = field(default_factory=list)
    # Station keys (see make_station_key) of the segments' arrival stations.
    arrival_points: list = field(default_factory=list)
    # Departure datetimes as returned by localize_dt(..., UTC_TZ).
    departure_dts: list = field(default_factory=list)
    # Earliest departure seen so far, converted to MSK_TZ; None until the
    # first segment has been added.
    departure_date_from: datetime = None
    # Latest departure seen so far, converted to MSK_TZ; None until the
    # first segment has been added.
    departure_date_to: datetime = None
    # Station dicts of all seen departure/arrival stations, keyed by station id.
    stations_by_id: dict = field(default_factory=dict)


class BaseTariffCollector:
    """Base class that collects tariffs for the segments of transfer variants.

    Prices are looked up in a cache storage first and requested from the
    client when the cache has nothing final. Subclasses are expected to set
    ``_CACHE_STORAGE`` to a storage class and to implement ``_make_result``,
    ``_prepare_tariff`` and ``_make_empty_result``.
    """

    # Storage class; instantiated with the cache object in ``__init__``.
    _CACHE_STORAGE = None
    # Upper bound on the number of collection rounds per request.
    _MAX_ATTEMPTS = 10

    def __init__(self, client, cache, transport_code):
        """Remember the price client, wrap ``cache`` into the storage and
        remember the transport code this collector is responsible for.
        """
        self._client = client
        self._storage = self._CACHE_STORAGE(cache)
        self._transport_code = transport_code

    async def iter_tariffs_for_variants(self, transfer_variants, tld, language):
        """Yield ``(tariff_infos, querying)`` after every collection round.

        Rounds are repeated (at most ``_MAX_ATTEMPTS`` times) while at least
        one segment still reports ``querying``. Before round number ``k``
        (zero-based, ``k > 0``) the coroutine sleeps ``k`` seconds.

        Any exception is logged and re-raised.
        """
        try:
            prepared_tariffs = self._prepare_transfer_variants(transfer_variants)

            # Defaults keep the summary log below valid even when no round runs
            # (``_MAX_ATTEMPTS`` overridden with 0).
            attempts = 0
            querying = False
            for attempt in range(self._MAX_ATTEMPTS):
                if attempt:
                    # Linearly growing pause: 1s, 2s, 3s, ...
                    await asyncio.sleep(attempt)
                tariff_infos, querying = await self._collect_prices(prepared_tariffs, tld, language)
                # Number of rounds actually performed (not the zero-based index).
                attempts = attempt + 1
                yield tariff_infos, querying
                if not querying:
                    break

            log_line = 'Attempts count: {}. Querying: {}. Transport code: {}.'.format(
                attempts,
                querying,
                self._transport_code
            )
            if querying:
                # Still unfinished after the last round: dump the request data.
                log_line += ' Prepared tariffs: {}.'.format(repr(prepared_tariffs))

            logger.debug(log_line)

        except Exception as ex:
            logger.exception('get_tariffs_for_variants error: {}'.format(repr(ex)))
            raise

    async def get_tariffs_for_variants(self, transfer_variants, tld, language):
        """Run all collection rounds and return the last ``(tariff_infos, querying)``.

        Returns ``([], False)`` if no round was performed at all.
        """
        tariff_infos, querying = [], False
        async for tariff_infos, querying in self.iter_tariffs_for_variants(transfer_variants, tld, language):
            pass
        return tariff_infos, querying

    def _prepare_transfer_variants(self, transfer_variants):
        """Extract the request data for segments of this collector's transport code."""
        return prepare_transfer_variants(transfer_variants, {self._transport_code})

    async def _collect_prices(self, prepared_tariffs, tld, language):
        """Fetch prices for every distinct segment concurrently.

        :return: ``(tariffs, querying)`` where ``tariffs`` is the concatenation
            of all per-segment tariffs and ``querying`` is True if any segment
            is still being queried.
        """
        # dict.fromkeys removes duplicate segments while keeping first-seen
        # order, so the order of the resulting tariffs is deterministic
        # (a set would order string keys differently from process to process).
        segment_keys = dict.fromkeys(zip(
            prepared_tariffs.departure_points,
            prepared_tariffs.arrival_points,
            prepared_tariffs.departure_dts
        ))
        results = await asyncio.gather(*(
            self._get_price(segment_key, tld, language) for segment_key in segment_keys
        ))

        tariffs = []
        querying = False
        for result in results:
            if not result:
                continue
            result_tariffs, result_querying = result
            if result_querying:
                querying = True
            if result_tariffs:
                tariffs.extend(result_tariffs)

        return tariffs, querying

    async def _get_price(self, segment_key, tld, language):
        """Return ``(prepared_tariffs, querying)`` for one segment.

        The cache is consulted first; a cached entry is used as is when it has
        tariffs and is not marked as still querying. Otherwise the price is
        requested and the outcome is written back to the cache.

        Cache reads and writes are best-effort: their failures are logged and
        never prevent a fetched result from being returned. An ``ApiError``
        yields ``([], querying)`` where ``querying`` is True for the statuses
        402, 429, 499 and 500; any other fetch failure yields ``([], False)``.
        """
        departure_point, arrival_point, departure_dt = segment_key
        storage_key = departure_point, arrival_point, departure_dt, tld, language

        try:
            tariffs, querying = await self._storage.get(storage_key)
        except Exception as ex:
            logger.exception('_get_price get exception: {}'.format(repr(ex)))
            tariffs, querying = [], False

        if tariffs and not querying:
            # Final cached answer, no request needed.
            return self._prepare_tariffs(tariffs), querying

        try:
            result = await self._get_price_result(segment_key, querying, tld, language)
        except ApiError as ex:
            querying = ex.status in [402, 429, 499, 500]
            await self._store_result_safely(storage_key, self._make_empty_result(querying))
            return [], querying
        except Exception as ex:
            logger.exception('_get_price set exception: {}'.format(repr(ex)))
            return [], False

        # A failed cache write must not throw away tariffs we already have.
        await self._store_result_safely(storage_key, result)

        try:
            return self._prepare_tariffs(result.tariffs), result.querying
        except Exception as ex:
            logger.exception('_get_price set exception: {}'.format(repr(ex)))
            return [], False

    async def _store_result_safely(self, storage_key, result):
        """Write ``result`` to the cache; log and swallow any storage failure."""
        try:
            await self._storage.set(storage_key, result)
        except Exception as ex:
            logger.exception('_get_price set exception: {}'.format(repr(ex)))

    async def _get_price_result(self, segment_key, poll, tld, language):
        """Request prices for one segment from the client and log the outcome.

        Both successful and failed requests are reported through
        ``log_proxy_tariffs``; a failure is re-raised after being logged.

        :param poll: forwarded to the client as ``poll`` (the caller passes the
            cached ``querying`` flag here).
        :return: ``self._make_result(result, poll)``.
        """
        departure_point, arrival_point, departure_dt = segment_key
        try:
            result = await self._client.get_prices(departure_point, arrival_point, departure_dt,
                                                   poll=poll, tld=tld, language=language, asker='pathfinder')
        except Exception as ex:
            log_proxy_tariffs(
                cache_type=self._storage._CACHE_TYPE,
                point_from=departure_point,
                point_to=arrival_point,
                when=departure_dt,
                tld=tld,
                language=language,
                transport_type=self._transport_code,
                status=repr(ex),
                tariffs=None,
                poll=poll
            )
            raise

        log_proxy_tariffs(
            cache_type=self._storage._CACHE_TYPE,
            point_from=departure_point,
            point_to=arrival_point,
            when=departure_dt,
            tld=tld,
            language=language,
            transport_type=self._transport_code,
            status=Status.DONE.value,
            tariffs=result,
            poll=poll
        )

        return self._make_result(result, poll)

    def _prepare_tariffs(self, tariffs):
        """Apply ``_prepare_tariff`` to every tariff and return the list."""
        return [self._prepare_tariff(tariff) for tariff in tariffs]

    @staticmethod
    def _make_result(result, poll):
        """Turn a raw client response into a cacheable result.

        Must be implemented by subclasses. The returned object is stored in
        the cache and must expose ``tariffs`` and ``querying`` attributes.
        """
        raise NotImplementedError()

    @staticmethod
    def _prepare_tariff(tariff):
        """Convert one stored tariff into its output form (subclass hook)."""
        raise NotImplementedError()

    @staticmethod
    def _make_empty_result(querying):
        """Build the result cached after an ``ApiError`` (subclass hook)."""
        raise NotImplementedError()
