# -*- coding: utf-8 -*-
from datetime import date

import yt.wrapper as yt

from travel.avia.library.python.references.station_to_settlement import StationToSettlementCache
from travel.avia.library.python.references.station import StationCache


class FlightsCollector(object):
    """Extract distinct flights from one day of the variants log on YT.

    Every valid forward/backward segment of every variant is turned into a
    flight row (settlement pair, national version, carrier, flight number,
    departure date). A map-reduce keyed on all of those columns then collapses
    duplicates, so the output holds one row per distinct flight.
    """

    # national_version_id value from the variants log -> national version code.
    NATIONAL_VERSIONS_MAP = {
        1: 'ru',
        2: 'ua',
        3: 'kz',
        4: 'com',
        5: 'tr',
    }

    # Variant columns whose segment lists are expanded into flights, in order.
    _SEGMENT_FIELDS = ('forward_segments', 'backward_segments')

    # Columns identifying a flight; used as the reduce key for deduplication.
    _FLIGHT_KEY_COLUMNS = [
        'from_id', 'to_id', 'national_version', 'company_id',
        'airline_code', 'flight_number', 'departure_date',
    ]

    def __init__(self, yt_client, variants_log_path, station_cache, station_to_settlement_cache):
        # type: (yt.YtClient, str, StationCache, StationToSettlementCache) -> None
        """
        :param yt_client: client used to run the job and read its output.
        :param variants_log_path: YT directory containing one table per day,
            named ``YYYY-MM-DD``.
        :param station_cache: primary station -> settlement lookup.
        :param station_to_settlement_cache: fallback lookup, used when
            ``station_cache`` has no settlement for a station.
        """
        self._yt_client = yt_client
        self._variants_log_path = variants_log_path
        self._station_cache = station_cache
        self._station_to_settlement = station_to_settlement_cache

    def collect_flights_to_yt_table(self, date):
        # type: (date) -> str
        """Run the flights map-reduce over one day of the variants log.

        :param date: day to process; the source table is
            ``<variants_log_path>/<YYYY-MM-DD>``. Inside this method the
            parameter shadows the module-level ``datetime.date`` import. The
            name is kept for callers that pass it by keyword.
        :return: path of a temporary table (expiring after 6 hours) with one
            row per distinct flight.
        """
        source_table = yt.TablePath(
            yt.ypath_join(self._variants_log_path, date.strftime('%Y-%m-%d')),
            # Read only the columns the mapper actually uses.
            columns=['forward_segments', 'backward_segments', 'national_version_id']
        )
        output_table = self._yt_client.create_temp_table(
            path='//home/avia/tmp',
            prefix='avia_statistics_flights_',
            expiration_timeout=6 * 60 * 60 * 1000,  # 6 hours, in milliseconds
        )
        self._yt_client.run_map_reduce(
            mapper=self._get_variant_to_flights_mapper(),
            reducer=self._flight_reducer,
            source_table=source_table,
            destination_table=output_table,
            reduce_by=list(self._FLIGHT_KEY_COLUMNS),
        )
        return output_table

    def _validate_segment(self, segment, national_version_id):
        """Return True if ``segment`` can be converted by ``_segment_to_flight``.

        Missing keys make a segment invalid instead of raising, so a single
        malformed segment cannot fail the whole map job. Checks short-circuit,
        so settlement lookups only run for segments that passed the cheap checks.
        """
        if national_version_id not in self.NATIONAL_VERSIONS_MAP:
            return False
        if not (segment.get('company_id') and segment.get('departure_time')):
            return False
        route = segment.get('route')
        # _segment_to_flight unpacks the route into exactly
        # (airline_code, flight_number); anything else would raise there.
        if not route or len(route.split()) != 2:
            return False
        for station_field in ('departure_station_id', 'arrival_station_id'):
            station_id = segment.get(station_field)
            if not station_id or not self._map_station_id_to_settlement_id(station_id):
                return False
        return True

    def _map_station_id_to_settlement_id(self, station_id):
        """Resolve a station id to its settlement id.

        ``station_cache`` is tried first and ``station_to_settlement_cache`` is
        the fallback. Both are queried with ``raise_on_unknown=False``. If
        neither knows the station, the fallback's result is returned as-is
        (presumably ``None``; falsy either way, which ``_validate_segment``
        relies on).
        """
        settlement_id = self._station_cache.settlement_id_by_id(station_id, raise_on_unknown=False)
        if settlement_id:
            return settlement_id
        return self._station_to_settlement.settlement_id_by_id(station_id, raise_on_unknown=False)

    def _segment_to_flight(self, segment, national_version_id):
        """Convert a segment that passed ``_validate_segment`` into a flight row."""
        arrival_station_id = segment['arrival_station_id']
        departure_station_id = segment['departure_station_id']
        airline_code, flight_number = segment['route'].split()

        return {
            'from_id': self._map_station_id_to_settlement_id(departure_station_id),
            'to_id': self._map_station_id_to_settlement_id(arrival_station_id),
            'national_version': self.NATIONAL_VERSIONS_MAP[national_version_id],
            'company_id': segment['company_id'],
            'airline_code': airline_code,
            'flight_number': flight_number,
            'departure_date': self._map_departure_date(segment['departure_time']),
        }

    @staticmethod
    def _map_departure_date(departure_time):
        """Convert a POSIX timestamp into an integer ``YYYYMMDD`` date.

        NOTE(review): ``date.fromtimestamp`` uses the local timezone of the
        process (here, the YT job), so the result depends on the worker's TZ.
        Confirm whether ``departure_time`` should be interpreted as UTC or as
        airport-local time.
        """
        return int(date.fromtimestamp(departure_time).strftime('%Y%m%d'))

    def _get_variant_to_flights_mapper(self):
        """Build the YT mapper that turns one variant row into flight rows.

        Forward segments are emitted before backward ones; invalid segments
        are skipped. NOTE(review): the closure captures ``self`` (including the
        YT client), presumably serialized together with the job.
        """
        def mapper(variant):
            national_version_id = variant.get('national_version_id')
            for field in self._SEGMENT_FIELDS:
                # `or ()` also covers an explicit null in the column.
                for segment in variant.get(field) or ():
                    if self._validate_segment(segment, national_version_id):
                        yield self._segment_to_flight(segment, national_version_id)
        return mapper

    @staticmethod
    def _flight_reducer(key, rows):
        """Emit the reduce key once per group, i.e. one row per distinct flight."""
        yield key

    def iterate_flights(self, table):
        """Yield the rows of ``table``, e.g. the result of ``collect_flights_to_yt_table``."""
        for record in self._yt_client.read_table(table):
            yield record
