import pytz

from travel.library.python.aioapp.utils import localize_dt
from travel.rasp.pathfinder_proxy.client_tariffs.ticket_daemon_result import make_empty_interlines, make_interlines
from travel.rasp.pathfinder_proxy.const import MSK_TZ, UTC_TZ
from travel.rasp.pathfinder_proxy.tariff_collectors.base_tariff_collector import (BaseTariffCollector, PreparedTariffs,
                                                                                  make_station_key)
from travel.rasp.pathfinder_proxy.tariff_storages.interline_storage import InterlineStorage


def prepare_transfer_variants(transfer_variants, transport_code):
    """Gather departure bounds and endpoints from single-transport transfer variants.

    Variants that contain any segment whose ``transport.code`` differs from
    ``transport_code`` are skipped entirely.

    For every remaining variant:

    * the first segment's ``departure`` string, minus its last 9 characters, is
      localized as UTC;
    * that departure, converted to ``MSK_TZ``, widens the
      ``departure_date_from`` / ``departure_date_to`` window;
    * the first segment's ``stationFrom`` and the last segment's ``stationTo``
      are recorded via ``make_station_key``;
    * the UTC departure is appended to ``departure_dts``.

    :param transfer_variants: iterable of variant dicts, each with a ``segments`` list.
    :param transport_code: the only transport code a variant may use to be kept.
    :return: a populated ``PreparedTariffs`` instance.
    """
    result = PreparedTariffs()

    for transfer in transfer_variants:
        legs = transfer['segments']
        has_foreign_transport = any(leg['transport']['code'] != transport_code for leg in legs)
        if has_foreign_transport:
            continue

        first_leg, last_leg = legs[0], legs[-1]

        # NOTE(review): the trailing 9 characters are presumably a UTC offset or
        # fractional-seconds suffix -- confirm against the upstream payload format.
        raw_departure = first_leg['departure'][:-9]
        utc_departure = localize_dt(raw_departure, pytz.UTC)
        msk_departure = utc_departure.astimezone(MSK_TZ)

        earliest = result.departure_date_from
        if earliest is None or earliest > msk_departure:
            result.departure_date_from = msk_departure

        latest = result.departure_date_to
        if latest is None or latest < msk_departure:
            result.departure_date_to = msk_departure

        result.departure_points.append(make_station_key(first_leg['stationFrom']))
        result.arrival_points.append(make_station_key(last_leg['stationTo']))
        # The UTC departure is stored here, not the Moscow-time one.
        result.departure_dts.append(utc_departure)

    return result


class InterlineCollector(BaseTariffCollector):
    """Tariff collector that requests prices with interline offers included.

    Results are cached through ``InterlineStorage`` and built with
    ``make_interlines`` / ``make_empty_interlines``.
    """

    _CACHE_STORAGE = InterlineStorage

    def _prepare_transfer_variants(self, transfer_variants):
        """Filter and summarize transfer variants for this collector's transport code."""
        return prepare_transfer_variants(transfer_variants, self._transport_code)

    async def _get_price_result(self, segment_key, poll, tld, language):
        """Fetch prices (including interlines) for one segment key.

        :param segment_key: a ``(departure_point, arrival_point, departure_dt)`` triple.
        :param poll: forwarded to the client and to ``_make_result``.
        :param tld: forwarded to the client.
        :param language: forwarded to the client.
        :return: the value built by ``_make_result`` from the client response.
        """
        departure_point, arrival_point, departure_dt = segment_key
        result = await self._client.get_prices(departure_point, arrival_point, departure_dt,
                                               poll=poll, tld=tld, language=language, include_interlines=True)
        return self._make_result(result, poll)

    @staticmethod
    def _prepare_tariff(tariffs):
        """Build a cache entry from the cheapest variant of an interline tariff.

        :param tariffs: object exposing ``variants`` (dicts with ``tariff``,
            ``partner`` and ``deep_link``) and ``segments`` (objects with
            ``departure_dt``, ``arrival_dt``, ``departure_station_id``,
            ``arrival_station_id`` and ``number``).
        :return: ``None`` when there are no variants; otherwise a dict with
            ``'key'``, a tuple of
            ``(departure_station_id, arrival_station_id, departure_dt, number)``
            per segment, and ``'result'``, an economy-class payload.

        Side effect: each segment's ``departure_dt`` / ``arrival_dt`` is replaced
        in place with its ``localize_dt(..., UTC_TZ)`` value.
        NOTE(review): a second call on the same segments passes already-localized
        values to ``localize_dt`` -- confirm that it accepts aware datetimes.
        """
        variants = tariffs.variants
        if not variants:
            return None

        # min() returns the first of several equally cheap variants, matching
        # a strict '<' scan over the list.
        min_variant = min(variants, key=lambda variant: variant['tariff']['value'])

        # Copy before normalizing the legacy 'RUR' code to 'RUB' so the caller's
        # (possibly shared or cached) variant dict is left untouched.
        price = dict(min_variant['tariff'])
        if price['currency'] == 'RUR':
            price['currency'] = 'RUB'

        for segment in tariffs.segments:
            segment.departure_dt = localize_dt(segment.departure_dt, UTC_TZ)
            segment.arrival_dt = localize_dt(segment.arrival_dt, UTC_TZ)

        key = tuple(
            (segment.departure_station_id, segment.arrival_station_id, segment.departure_dt, segment.number)
            for segment in tariffs.segments
        )

        return {
            'key': key,
            'result': {
                'electronicTicket': False,
                'classes': {
                    'economy': {
                        'partner': min_variant['partner'],
                        'price': price,
                        'orderUrl': min_variant['deep_link']
                    }
                }
            }
        }

    @staticmethod
    def _make_result(result, poll):
        """Convert a client price response into interline results via ``make_interlines``."""
        return make_interlines(result, poll)

    @staticmethod
    def _make_empty_result(querying):
        """Build an empty interline result via ``make_empty_interlines``."""
        return make_empty_interlines(querying)
