import asyncio
import itertools

from travel.library.python.aioapp.utils import localize_dt
from travel.rasp.pathfinder_proxy.client_tariffs.train_api_result import make_empty_result, make_result
from travel.rasp.pathfinder_proxy.const import TTransport, UTC_TZ
from travel.rasp.pathfinder_proxy.tariff_collectors.base_tariff_collector import BaseTariffCollector, prepare_transfer_variants
from travel.rasp.pathfinder_proxy.tariff_storages.train_api_storage import TrainApiStorage


class TrainApiCollector(BaseTariffCollector):
    """Tariff collector for train segments that uses ``TrainApiStorage`` as cache storage.

    The instance keeps state between ``_collect_prices`` calls: the tariffs
    already received per segment and the set of segments whose polling finished.
    """

    _TRANSPORT_TYPE = TTransport.TRAIN
    _CACHE_STORAGE = TrainApiStorage
    # Not read anywhere in this class; presumably consumed by the base collector - TODO confirm.
    _MAX_ATTEMPTS = 20

    def __init__(self, *args, **kwargs):
        """Set up the per-instance caches, then delegate to the base initializer."""
        # segment key -> tariffs received for that segment
        self._local_cache = {}
        # segment keys that reported "not querying" and are not requested again
        self._polled_segment_keys = set()
        super().__init__(*args, **kwargs)

    async def _collect_prices(self, prepared_tariffs, tld, language):
        """Request tariffs for every segment that has not finished polling yet.

        Returns a ``(tariffs, querying)`` pair: a flat list of all tariffs cached
        on this instance so far, and ``True`` when at least one segment requested
        during this call reported that it is still being queried.
        """
        requested_keys = set(zip(
            prepared_tariffs.departure_points,
            prepared_tariffs.arrival_points,
            prepared_tariffs.departure_dts,
        ))

        # Segments that already answered "not querying" are skipped.
        pending_requests = []
        for key in requested_keys:
            if key in self._polled_segment_keys:
                continue
            pending_requests.append(self._get_price(key, tld, language))

        responses = await asyncio.gather(*pending_requests)

        still_querying = False
        for (tariffs, is_querying), key in responses:
            if not is_querying:
                self._polled_segment_keys.add(key)
            else:
                still_querying = True

            # Only non-empty results replace what is cached for the segment.
            if tariffs:
                self._local_cache[key] = tariffs

        cached_batches = self._local_cache.values()
        return list(itertools.chain.from_iterable(cached_batches)), still_querying

    async def _get_price(self, segment_key, tld, language):
        """Call the base ``_get_price`` and pair its result with ``segment_key``.

        The key is returned alongside the result so the caller can tell which
        segment each gathered response belongs to.
        """
        price_info = await super()._get_price(segment_key, tld, language)
        return price_info, segment_key

    @staticmethod
    def _make_result(result, poll):
        """Wrap ``result`` with ``make_result``; the ``poll`` argument is not used here."""
        return make_result(result)

    def _prepare_transfer_variants(self, transfer_variants):
        """Run ``prepare_transfer_variants`` with this collector's code plus the suburban code."""
        own_code = self._transport_code
        suburban_code = TTransport.get_name(TTransport.SUBURBAN)
        return prepare_transfer_variants(transfer_variants, {own_code, suburban_code})

    def _prepare_tariff(self, tariff):
        """Convert ``tariff`` into a ``{'key': ..., 'result': ...}`` dict.

        Note: ``tariff.departure_dt`` and ``tariff.arrival_dt`` are overwritten
        in place with the values returned by ``localize_dt(..., UTC_TZ)``.
        """
        tariff.departure_dt = localize_dt(tariff.departure_dt, UTC_TZ)
        tariff.arrival_dt = localize_dt(tariff.arrival_dt, UTC_TZ)

        # Same (departure, arrival, departure datetime) layout as the segment keys
        # built in _collect_prices.
        segment_key = (
            tariff.departure_station_id,
            tariff.arrival_station_id,
            tariff.departure_dt,
        )
        payload = {
            'tariff': tariff.tariff,
            'number': tariff.number,
            'hasDynamicPricing': tariff.has_dynamic_pricing,
            'rawTrainName': tariff.raw_train_name,
            'provider': tariff.provider,
        }
        return {'key': segment_key, 'result': payload}

    @staticmethod
    def _make_empty_result(querying):
        """Build an empty result via ``make_empty_result`` for the given ``querying`` flag."""
        return make_empty_result(querying)
