import logging

from travel.library.python.aioapp.utils import localize_dt

from travel.rasp.pathfinder_proxy.const import TTransport, UTC_TZ

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


class BaseService:
    """Base for services that merge tariff results into transfer-variant segments.

    The base class leaves ``_TRANSPORT_TYPE`` and ``_COLLECTOR`` as ``None`` and
    ``_merge_segment`` unimplemented; concrete subclasses are presumably
    expected to provide all three.
    """

    # Value handed to TTransport.get_name() to obtain the transport code.
    _TRANSPORT_TYPE = None
    # Callable invoked as _COLLECTOR(client, cache, transport_code).
    _COLLECTOR = None

    def __init__(self, client, cache):
        """Resolve the transport code and remember the client and cache."""
        self._transport_code = TTransport.get_name(self._TRANSPORT_TYPE)
        self.client, self.cache = client, cache

    def get_collector(self):
        """Build a new collector from the stored client, cache and transport code."""
        collector_factory = self._COLLECTOR
        return collector_factory(self.client, self.cache, self._transport_code)

    def _merge_variants_with_tariffs(self, transfer_variants, tariffs, querying):
        """Index ``tariffs`` by key and merge them into every variant.

        ``None`` entries in ``tariffs`` are skipped (their presence is logged
        at debug level). ``querying`` is only forwarded, never read here.
        """
        if None in tariffs:
            logger.debug(f'None in tariffs ({self._transport_code})')

        tariffs_by_key = {}
        for entry in tariffs:
            if entry is None:
                continue
            # Look up the key before the result, as a dict comprehension would.
            entry_key = entry['key']
            tariffs_by_key[entry_key] = entry['result']

        for variant in transfer_variants:
            self._merge_variant_with_tariffs(variant, tariffs_by_key, querying)

    @staticmethod
    def _get_segment_key(segment):
        """Return the ``(station_from_id, station_to_id, departure)`` lookup key."""
        station_from_id = int(segment['stationFrom']['id'])
        station_to_id = int(segment['stationTo']['id'])
        # The last 9 characters of the departure string are dropped before
        # localizing; presumably a seconds + UTC-offset suffix (TODO confirm).
        departure = localize_dt(segment['departure'][:-9], UTC_TZ)
        return station_from_id, station_to_id, departure

    def _merge_variant_with_tariffs(self, variant, tariffs_by_key, querying):
        """Merge tariffs into the segments of ``variant`` that use this transport.

        Segments with a different transport code, or with no tariff under
        their key, are left untouched. ``querying`` is not read here.
        """
        for segment in variant['segments']:
            if segment['transport']['code'] == self._transport_code:
                segment_key = self._get_segment_key(segment)
                found = tariffs_by_key.get(segment_key)
                if found is not None:
                    self._merge_segment(segment, found)

    @staticmethod
    def _merge_segment(segment, tariff):
        """Apply ``tariff`` to ``segment``; not implemented in the base class."""
        raise NotImplementedError()

    async def iter_variants_with_tariffs(self, transfer_variants, tld='ru', language='ru'):
        """Merge each tariff batch from the collector into ``transfer_variants``.

        Yields ``None`` once per batch, after that batch has been merged, so
        the consumer can presumably re-read ``transfer_variants`` at that point.
        """
        collector = self.get_collector()
        batches = collector.iter_tariffs_for_variants(transfer_variants, tld, language)
        async for tariffs, querying in batches:
            self._merge_variants_with_tariffs(transfer_variants, tariffs, querying)
            yield
