# coding=utf-8
from __future__ import unicode_literals

from collections import namedtuple

from travel.avia.shared_flights.lib.python.date_utils.date_index import DateIndex


# Immutable record for a single flight leg, as assembled by
# FlightBaseDataFactory.new_flight_base_data from a flight segment.
FlightBaseData = namedtuple(
    'FlightBaseData',
    (
        # Which flight and which leg of it.
        'operating_carrier',
        'operating_flight_number',
        'leg_number',
        # Departure side of the leg.
        'departure_station_iata',
        'departure_time',
        'departure_terminal',
        # Arrival side of the leg.
        'arrival_station_iata',
        'arrival_time',
        'arrival_terminal',
        'aircraft_model',
        # Day offsets: arrival is the arrival date index minus the departure
        # date index; departure is 0 or 1 as decided by the factory.
        'arrival_day_shift',
        'departure_day_shift',
    ),
)


class FlightBaseDataFactory(object):
    """Builds FlightBaseData records from flight segments.

    Segments whose leg number cannot be determined are remembered so that
    they can be reported later via log_skipped_flights().
    """

    def __init__(self, start_date, arrival_times_cacher, logger):
        """Create a factory.

        :param start_date: passed to DateIndex; dates of processed segments
            are converted to indexes through that DateIndex.
        :param arrival_times_cacher: object providing
            get_arrival(carrier, flight_code, date_str, leg_number); used to
            look up the previous leg's arrival time.
        :param logger: logging-style logger (error() and info() are used).
        """
        self._date_index = DateIndex(start_date)
        self._arrival_times_cacher = arrival_times_cacher
        self._skipped_flights = set()
        self._logger = logger

    def new_flight_base_data(self, flight_segment):
        """Convert a flight segment into a FlightBaseData record.

        :param flight_segment: segment to convert.
        :return: FlightBaseData, or None when the segment has no route, its
            ``stops`` value is not the string '0', or its leg number is
            falsy (in the last case the segment is also recorded as skipped).
        :raises Exception: when the route has fewer than two '-'-separated
            points.
        """
        if not flight_segment.route:
            return None
        if flight_segment.stops != '0':
            return None

        route_points = flight_segment.route.split('-')
        if len(route_points) < 2:
            raise Exception('Invalid number of stops: {}'.format(flight_segment.__dict__))

        dep_index = self._date_index.get_index_for_date_str(flight_segment.dep_date)
        arrival_index = self._date_index.get_index_for_date_str(flight_segment.arrival_date)

        leg_number = flight_segment.get_leg_number()
        if not leg_number:
            # %r rather than %d: this branch is taken precisely when the leg
            # number is falsy, which includes None, and '%d' % None raises
            # TypeError while the log record is being formatted.
            self._logger.error('Invalid leg number: %s %r', flight_segment, leg_number)
            self._skipped_flights.add(flight_segment)
            return None

        departure_day_shift = 0
        if leg_number > 1:
            departure_day_shift = self._get_departure_day_shift(
                flight_segment, dep_index, leg_number,
            )

        return FlightBaseData(
            operating_carrier=flight_segment.operating_carrier,
            operating_flight_number=flight_segment.operating_flight_code,
            leg_number=leg_number,
            departure_station_iata=flight_segment.board_airport,
            departure_time=flight_segment.dep_time,
            departure_terminal=flight_segment.board_terminal,
            arrival_station_iata=flight_segment.off_airport,
            arrival_time=flight_segment.arrival_time,
            arrival_terminal=flight_segment.off_terminal,
            aircraft_model=flight_segment.aircraft_type,
            arrival_day_shift=arrival_index - dep_index,
            departure_day_shift=departure_day_shift,
        )

    def _get_departure_day_shift(self, flight_segment, dep_index, leg_number):
        """Return 1 or 0 based on the previous leg's cached arrival times.

        The previous leg's arrival is looked up twice: keyed by this
        segment's departure date and keyed by the date one index earlier.
        The result is 1 only when the earlier-date arrival is non-negative
        and later than this leg's departure time, while the same-date
        arrival is either negative or also later than the departure time.
        Negative values are treated as "no arrival" (presumably the cacher's
        not-found marker -- confirm against the cacher implementation).
        """
        carrier = flight_segment.operating_carrier
        flight_code = flight_segment.operating_flight_code
        prev_leg = leg_number - 1

        same_day_arrival = self._arrival_times_cacher.get_arrival(
            carrier,
            flight_code,
            flight_segment.dep_date,
            prev_leg,
        )
        # NOTE(review): when dep_index is 0 this requests index -1 from the
        # date index; confirm DateIndex.get_date_str handles that sensibly.
        prev_day_arrival = self._arrival_times_cacher.get_arrival(
            carrier,
            flight_code,
            self._date_index.get_date_str(dep_index - 1),
            prev_leg,
        )
        dep_time = flight_segment.get_dep_time_int()

        prev_day_predicate = prev_day_arrival >= 0 and prev_day_arrival > dep_time
        same_day_predicate = same_day_arrival < 0 or same_day_arrival > dep_time
        if prev_day_predicate and same_day_predicate:
            return 1
        return 0

    def log_skipped_flights(self):
        """Log carrier and flight code of every segment that was skipped."""
        for flight_segment in self._skipped_flights:
            self._logger.info(
                'Skipped flight: %s %s',
                flight_segment.operating_carrier,
                flight_segment.operating_flight_code,
            )
