# coding=utf-8
from __future__ import unicode_literals

from travel.proto.dicts.rasp.carrier_pb2 import TCarrier
from typing import Iterable, Dict

from collections import defaultdict
from datetime import date, datetime, timedelta

from travel.avia.shared_flights.lib.python.consts.consts import MIN_SIRENA_FLIGHT_BASE_ID, MIN_SIRENA_FLIGHT_PATTERN_ID
from travel.avia.shared_flights.lib.python.date_utils.date_index import DateIndex
from travel.avia.shared_flights.lib.python.date_utils.date_matcher import DateMatcher
from travel.avia.shared_flights.lib.python.date_utils.date_shift import shift_week_days
from travel.avia.shared_flights.lib.python.db_models.stop_point import StopPoint
from travel.avia.shared_flights.data_importer.date_shift_calculator import DateShiftCalculator
from travel.proto.shared_flights.snapshots.station_with_codes_pb2 import TStationWithCodes
from travel.proto.shared_flights.ssim.flights_pb2 import TFlightBase, TFlightPattern
from travel.avia.shared_flights.data_importer.storage.missing_data import sirena_route_extra


class SirenaFlightsParser(object):

    def __init__(
        self,
        logger,
        carriers: Iterable[TCarrier],
        stations: Dict[int, TStationWithCodes],
        timezones,
        missing_data,
        start_date=None,
    ):
        self._logger = logger
        self._carriers: Iterable[TCarrier] = carriers
        self._stations: Iterable[TStationWithCodes] = stations.values()
        self._date_shift_calculator = DateShiftCalculator(stations, timezones, logger)
        self.missing_data = missing_data
        self._start_date = start_date if start_date else datetime.now() - timedelta(days=75)
        self._date_index = DateIndex(self._start_date)
        self._date_matcher = DateMatcher(self._date_index)

    def parse_data(self, sirena_airlines, routes):
        """Convert Sirena routes into flight bases and flight patterns.

        Every pair of consecutive stop points of a Sirena flight pattern is
        treated as one leg and produces one ``TFlightBase`` and one
        ``TFlightPattern``. ``routes`` is traversed twice: the first pass
        handles the non-codeshare patterns, the second pass handles the
        codeshare ones, which look their operating patterns up among the
        patterns collected so far.

        :param sirena_airlines: not used by this method.
        :param routes: Sirena routes; each is read for ``CarrierCode``,
            ``FlightNumber`` and ``FlightPatterns``. The operating dates of
            the flight patterns are overwritten in place.
        :return: tuple ``(flight_bases, flight_patterns, unknown_stop_points)``
            where ``flight_bases`` maps flight base id to ``TFlightBase``,
            ``flight_patterns`` maps bucket key to a list of ``TFlightPattern``
            and ``unknown_stop_points`` maps ``str(StopPoint)`` to the
            ``StopPoint`` built for a stop whose station was not found.
        :raises Exception: any error raised while parsing is logged and
            re-raised; the unknown carriers/stations gathered so far are
            logged in either case.
        """
        self._logger.info('Parsing data from Sirena')

        # NOTE(review): this index is based on the current moment, while
        # self._date_index (used by the date matcher) is based on start_date -
        # confirm that the difference is intended.
        date_index = DateIndex(datetime.now())

        unknown_stop_points = {}
        unknown_carriers = set()
        unknown_departure_stations = defaultdict(set)
        unknown_arrival_stations = defaultdict(set)
        flight_bases = {}
        flight_patterns = defaultdict(list)
        try:
            flight_base_id = MIN_SIRENA_FLIGHT_BASE_ID
            flight_pattern_id = MIN_SIRENA_FLIGHT_PATTERN_ID
            ivi = 0  # itinerary variation identifier (the term from SSIM format)
            # Non-codeshare patterns go first so that the codeshare pass can
            # find the operating patterns in flight_patterns.
            for codeshare_mode in [False, True]:
                for route in routes:
                    carrier = self.get_carrier(route.CarrierCode)
                    carrier_id = carrier.Id if carrier else 0

                    if not carrier_id:
                        # The whole route is skipped; since routes are traversed
                        # twice, this report is issued once per pass.
                        unknown_carriers.add(route.CarrierCode)
                        self.missing_data.add_carrier(route.CarrierCode, meta={
                            'type': 'missing carrier',
                            'source': 'sirena',
                            'route': sirena_route_extra(route),
                        })
                        continue

                    iata = carrier.Iata if carrier.Iata else route.CarrierCode

                    flight_patterns_to_process = [fp for fp in route.FlightPatterns if fp.IsCodeshare == codeshare_mode]

                    for sirena_flight_pattern in flight_patterns_to_process:
                        # One iteration per leg, i.e. per pair of consecutive stop points.
                        for leg_number_index in range(len(sirena_flight_pattern.StopPoints) - 1):
                            # The counters advance for every leg, including the legs
                            # skipped below, so the resulting ids may have gaps.
                            flight_base_id += 1
                            flight_pattern_id += 1
                            ivi += 1
                            # Leading spaces first, then leading zeros, are dropped.
                            flight_number = route.FlightNumber.lstrip(' ').lstrip('0')
                            leg_number = leg_number_index + 1
                            # The leg key falls back to the numeric carrier id when
                            # the carrier has no IATA code.
                            leg_key = SirenaFlightsParser.get_leg_key(
                                carrier.Iata if carrier.Iata else carrier_id,
                                flight_number,
                                leg_number,
                                ivi,
                            )
                            bucket_key = SirenaFlightsParser.get_bucket_key(carrier_id, flight_number, leg_number)
                            # The input pattern is modified in place, and this runs again
                            # for every leg of the same pattern.
                            # NOTE(review): presumably adjust_intdate is idempotent - confirm.
                            sirena_flight_pattern.OperatingFromDate = date_index.adjust_intdate(
                                sirena_flight_pattern.OperatingFromDate)
                            sirena_flight_pattern.OperatingUntilDate = date_index.adjust_intdate(
                                sirena_flight_pattern.OperatingUntilDate)

                            departure_point = sirena_flight_pattern.StopPoints[leg_number_index]
                            arrival_point = sirena_flight_pattern.StopPoints[leg_number_index + 1]

                            if not departure_point:
                                self._logger.error('Empty departure point: %s', leg_key)
                                continue

                            if not arrival_point:
                                self._logger.error('Empty arrival point: %s', leg_key)
                                continue

                            departure_station = self.get_station(departure_point)
                            arrival_station = self.get_station(arrival_point)

                            if not departure_station:
                                # Report the unknown station, remember the stop point
                                # and drop the leg.
                                departure_station_code = departure_point.StationCode or departure_point.CityCode
                                self.missing_data.add_station(
                                    departure_station_code,
                                    meta={
                                        'type': 'departure station',
                                        'source': 'sirena',
                                        'route': sirena_route_extra(route),
                                    })
                                unknown_departure_stations[departure_station_code].add(leg_key)
                                unknown_stoppoint = StopPoint()
                                unknown_stoppoint.merge(departure_point, leg_key)
                                unknown_stop_points[str(unknown_stoppoint)] = unknown_stoppoint
                                continue

                            if not arrival_station:
                                # Same handling as for an unknown departure station. It is
                                # reached only when the departure station is known.
                                arrival_station_code = arrival_point.StationCode or arrival_point.CityCode
                                self.missing_data.add_station(
                                    arrival_station_code,
                                    meta={
                                        'type': 'arrival station',
                                        'source': 'sirena',
                                        'route': sirena_route_extra(route),
                                    })
                                unknown_arrival_stations[arrival_station_code].add(leg_key)
                                unknown_stoppoint = StopPoint()
                                unknown_stoppoint.merge(arrival_point, leg_key)
                                unknown_stop_points[str(unknown_stoppoint)] = unknown_stoppoint
                                continue

                            # Flight base: the date-independent description of the leg.
                            flight = TFlightBase()
                            flight.Id = flight_base_id
                            flight.OperatingCarrier = carrier_id
                            flight.OperatingCarrierIata = iata
                            flight.OperatingFlightNumber = flight_number
                            flight.LegSeqNumber = leg_number
                            flight.ItineraryVariationIdentifier = str(ivi)
                            flight.ServiceType = 'J'

                            flight.DepartureStation = departure_station.Station.Id
                            flight.DepartureStationIata = SirenaFlightsParser.get_code(departure_station)
                            flight.ScheduledDepartureTime = departure_point.DepartureTime
                            flight.DepartureTerminal = departure_point.Terminal

                            flight.ArrivalStation = arrival_station.Station.Id
                            flight.ArrivalStationIata = SirenaFlightsParser.get_code(arrival_station)
                            flight.ScheduledArrivalTime = arrival_point.ArrivalTime
                            flight.ArrivalTerminal = arrival_point.Terminal
                            flight.AircraftModel = sirena_flight_pattern.AircraftModel
                            flight.IntlDomesticStatus = ''
                            flight.BucketKey = bucket_key

                            # Flight pattern: the operating period and week days of the leg,
                            # shifted by the departure day shift of the leg's departure point.
                            flight_pattern = TFlightPattern()
                            flight_pattern.Id = flight_pattern_id
                            flight_pattern.FlightId = flight_base_id
                            flight_pattern.FlightLegKey = leg_key
                            flight_pattern.LegSeqNumber = leg_number
                            flight_pattern.OperatingFromDate = SirenaFlightsParser.shift_date(
                                sirena_flight_pattern.OperatingFromDate,
                                departure_point.DepartureDayShift,
                            )
                            flight_pattern.OperatingUntilDate = SirenaFlightsParser.shift_date(
                                sirena_flight_pattern.OperatingUntilDate,
                                departure_point.DepartureDayShift,
                            )
                            flight_pattern.OperatingOnDays = shift_week_days(
                                sirena_flight_pattern.OperatingOnDays,
                                departure_point.DepartureDayShift,
                            )

                            flight_pattern.MarketingCarrier = carrier_id
                            flight_pattern.MarketingCarrierIata = iata
                            flight_pattern.MarketingFlightNumber = flight_number
                            flight_pattern.BucketKey = bucket_key
                            flight_pattern.ArrivalDayShift = self._date_shift_calculator.calculate_arrival_day_shift(
                                flight_pattern,
                                flight,
                            )
                            # The departure day shift is stored only for the legs after the first one.
                            if flight_pattern.LegSeqNumber > 1:
                                flight_pattern.DepartureDayShift = departure_point.DepartureDayShift

                            if not sirena_flight_pattern.IsCodeshare:
                                flight_patterns[bucket_key].append(flight_pattern)
                                flight_bases[flight.Id] = flight
                            else:
                                flight_pattern.IsCodeshare = True
                                operating_carrier = self.get_carrier(sirena_flight_pattern.OperatingFlight.CarrierCode)
                                operating_carrier_id = operating_carrier.Id if operating_carrier else 0

                                if not operating_carrier_id:
                                    # The leg is dropped. Unlike an unknown marketing carrier,
                                    # this case is only logged and is not passed to missing_data.
                                    unknown_carriers.add(sirena_flight_pattern.OperatingFlight.CarrierCode)
                                    continue

                                # NOTE(review): the operating flight number is used as is, whereas
                                # the marketing flight number above has its leading spaces and zeros
                                # stripped - verify that the bucket keys match for such numbers.
                                operating_bucket_key = SirenaFlightsParser.get_bucket_key(
                                    operating_carrier_id,
                                    sirena_flight_pattern.OperatingFlight.FlightNumber,
                                    leg_number
                                )
                                # flight_patterns is a defaultdict, so this lookup creates an
                                # empty list for a bucket key that has not been seen yet.
                                operating_fp_list = flight_patterns[operating_bucket_key]
                                operating_fp = None
                                flight_pattern.BucketKey = operating_bucket_key
                                # Take the first pattern of the operating bucket for which the
                                # date matcher reports an intersection with this codeshare pattern.
                                for fp in operating_fp_list:
                                    if self._date_matcher.intersect(
                                        flight_pattern.OperatingFromDate,
                                        flight_pattern.OperatingUntilDate,
                                        flight_pattern.OperatingOnDays,
                                        fp.OperatingFromDate,
                                        fp.OperatingUntilDate,
                                        fp.OperatingOnDays,
                                    ):
                                        operating_fp = fp
                                        break

                                if not operating_fp:
                                    # No operating pattern was found: derive one from the codeshare
                                    # pattern (marked IsDerivative) and register the flight base
                                    # under the operating carrier.
                                    new_fp = TFlightPattern()
                                    new_fp.MergeFrom(flight_pattern)
                                    flight_pattern_id += 1
                                    new_fp.Id = flight_pattern_id
                                    new_fp.MarketingCarrier = operating_carrier_id
                                    new_fp.MarketingCarrierIata = sirena_flight_pattern.OperatingFlight.CarrierCode
                                    new_fp.MarketingFlightNumber = sirena_flight_pattern.OperatingFlight.FlightNumber
                                    new_fp.BucketKey = operating_bucket_key
                                    new_fp.FlightLegKey = SirenaFlightsParser.get_leg_key(
                                        new_fp.MarketingCarrierIata if new_fp.MarketingCarrierIata else operating_carrier_id,
                                        sirena_flight_pattern.OperatingFlight.FlightNumber,
                                        leg_number,
                                        ivi,
                                    )
                                    new_fp.IsCodeshare = False
                                    new_fp.IsDerivative = True

                                    flight_patterns[operating_bucket_key].append(new_fp)

                                    flight.OperatingCarrier = operating_carrier_id
                                    flight.OperatingCarrierIata = sirena_flight_pattern.OperatingFlight.CarrierCode
                                    flight.OperatingFlightNumber = sirena_flight_pattern.OperatingFlight.FlightNumber
                                    flight.BucketKey = operating_bucket_key
                                    flight_bases[flight.Id] = flight
                                else:
                                    # Reuse the flight base of the operating pattern; the flight
                                    # base built above for this leg is not stored.
                                    flight_pattern.FlightId = operating_fp.FlightId
                                    flight_pattern.OperatingFlightPatternId = operating_fp.Id

                                # The codeshare pattern is stored under the marketing bucket key,
                                # while its BucketKey field holds the operating bucket key.
                                flight_patterns[bucket_key].append(flight_pattern)
        except Exception as e:
            self._logger.exception('Sirena data import has been aborted')
            raise e
        finally:
            # Runs both on success and on failure: log everything that could not be resolved.
            for carrier in unknown_carriers:
                self._logger.error('Unknown carrier: %s', carrier)
            for departure_station, leg_keys in unknown_departure_stations.items():
                self._logger.error('Unknown departure station: %s in flights %s', departure_station, leg_keys)
            for arrival_station, leg_keys in unknown_arrival_stations.items():
                self._logger.error('Unknown arrival station: %s in flights %s', arrival_station, leg_keys)
        return (flight_bases, flight_patterns, unknown_stop_points)

    @staticmethod
    def shift_date(int_date, day_shift):
        day = int_date % 100
        int_date //= 100
        month = int_date % 100
        year = int_date // 100
        shifted_date = date(year, month, day)
        if day_shift > 0:
            shifted_date = shifted_date + timedelta(days=day_shift)
        return shifted_date.strftime('%Y-%m-%d')

    def get_carrier(self, sirena_code):
        for carrier in self._carriers:
            if carrier.SirenaId == sirena_code or carrier.Iata == sirena_code:
                return carrier
        return None

    def get_station(self, stop_point):
        sirena_code = stop_point.StationCode
        if not sirena_code:
            sirena_code = stop_point.CityCode

        if not sirena_code:
            return None

        for station in self._stations:
            if station.SirenaCode == sirena_code or station.IataCode == sirena_code:
                return station
        return None

    @staticmethod
    def get_code(station):
        if not station:
            return None
        if station.IataCode:
            return station.IataCode
        if station.SirenaCode:
            return station.SirenaCode
        if station.IcaoCode:
            return station.IcaoCode
        return None

    @staticmethod
    def get_bucket_key(carrier_id, flight_number, leg_number):
        return '{}.{}.{}'.format(carrier_id, flight_number, leg_number)

    @staticmethod
    def get_leg_key(carrier, flight_number, leg_number, ivi):
        return '{}.{}.{}.{}'.format(
            carrier,
            flight_number,
            leg_number,
            ivi,
        )
