from travel.rasp.pathfinder_proxy.const import TTransport
from travel.rasp.pathfinder_proxy.services.base_service import BaseService
from travel.rasp.pathfinder_proxy.tariff_collectors.train_api_collector import TrainApiCollector


class TrainApiService(BaseService):
    """Service that merges train API tariffs into pathfinder variants.

    Segments whose transport code equals ``self._transport_code`` get a
    matching tariff merged unconditionally. Suburban segments that carry
    ``trainNumbers`` get a tariff merged only when the tariff's ``number``
    is among those train numbers and the FRONT_SUBURBAN_TARIFFS setting is
    enabled.
    """

    # NOTE(review): presumably consumed by BaseService to derive
    # ``self._transport_code`` and to pick the collector -- confirm there.
    _TRANSPORT_TYPE = TTransport.TRAIN
    _COLLECTOR = TrainApiCollector

    def __init__(self, client, cache, settings):
        super().__init__(client, cache)
        # Stored from settings; not read inside this class (used elsewhere, presumably).
        self._prefix = settings.TRAINS_URL_PREFIX
        # When falsy, suburban segments never receive train tariffs.
        self._front_suburban_tariffs = settings.FRONT_SUBURBAN_TARIFFS

    @staticmethod
    def _merge_segment(segment, tariff):
        """Copy tariff data onto ``segment`` in place.

        :param segment: segment dict to mutate.
        :param tariff: tariff dict providing ``tariff``, ``hasDynamicPricing``,
            ``rawTrainName`` and ``provider``.

        Suburban segments are additionally marked with ``hasTrainTariffs = True``.
        """
        # (segment field, tariff field) pairs, copied in this order.
        for target_field, source_field in (
            ('tariffs', 'tariff'),
            ('hasDynamicPricing', 'hasDynamicPricing'),
            ('rawTrainName', 'rawTrainName'),
            ('provider', 'provider'),
        ):
            segment[target_field] = tariff[source_field]

        is_suburban = segment['transport']['code'] == TTransport.get_name(TTransport.SUBURBAN)
        if is_suburban:
            segment['hasTrainTariffs'] = True

    def _merge_variant_with_tariffs(self, variant, tariffs_by_key, querying):
        """Merge tariffs from ``tariffs_by_key`` into the segments of ``variant``.

        :param variant: dict with a ``segments`` list; matching segments are
            mutated in place via ``_merge_segment``.
        :param tariffs_by_key: mapping from the key returned by
            ``_get_segment_key(segment)`` to a tariff dict.
        :param querying: if truthy and every segment is of this service's
            transport type, nothing is merged unless every segment has a tariff.
        """
        suburban_code = TTransport.get_name(TTransport.SUBURBAN)
        segments = variant['segments']

        if querying and all(seg['transport']['code'] == self._transport_code for seg in segments):
            if any(self._get_segment_key(seg) not in tariffs_by_key for seg in segments):
                # Do not merge partial tariffs into an all-own-transport variant.
                return

        for segment in segments:
            code = segment['transport']['code']
            is_own = code == self._transport_code

            # Only own-transport segments and suburban segments that carry
            # train numbers are candidates for a tariff.
            if not is_own and not (code == suburban_code and segment['trainNumbers']):
                continue

            tariff = tariffs_by_key.get(self._get_segment_key(segment))
            if tariff is None:
                continue

            # A suburban candidate takes the tariff only if the tariff's train
            # number matches one of its train numbers and the flag is on.
            if is_own or (tariff['number'] in segment['trainNumbers'] and self._front_suburban_tariffs):
                self._merge_segment(segment, tariff)
