# coding=utf-8

import io
from collections import namedtuple
from datetime import datetime, timedelta

from travel.avia.shared_flights.diff_builder.codeshares_map import CodesharesMap
from travel.avia.shared_flights.diff_builder.utils import (
    convert_to_latin,
    get_bucket_key,
    ensure_string,
    write_binary_string,
)
from travel.avia.shared_flights.lib.python.consts.consts import MIN_APM_FLIGHT_BASE_ID, MIN_SIRENA_FLIGHT_BASE_ID
from travel.avia.shared_flights.lib.python.consts.consts import (
    MIN_AMADEUS_RETAINED_FLIGHT_BASE_ID,
    MIN_SIRENA_RETAINED_FLIGHT_BASE_ID,
)
from travel.avia.shared_flights.lib.python.consts.consts import (
    MIN_AMADEUS_RETAINED_FLIGHT_PATTERN_ID,
    MIN_SIRENA_RETAINED_FLIGHT_PATTERN_ID,
)
from travel.avia.shared_flights.lib.python.consts.consts import (
    MIN_SIRENA_FLIGHT_PATTERN_ID,
    MIN_SIRENA_CODE_SHARE_FLIGHT_PATTERN_ID,
)
from travel.avia.shared_flights.lib.python.date_utils.date_index import DateIndex
from travel.avia.shared_flights.lib.python.date_utils.date_mask import DateMaskMatcher
from travel.avia.shared_flights.lib.python.db_models.existing_flight_base import ExistingFlightBase
from travel.avia.shared_flights.lib.python.db_models.existing_flight_pattern import ExistingFlightPattern
from travel.proto.shared_flights.ssim.flights_pb2 import TFlightBase, TFlightPattern, EFlightBaseSource

# Identity of an already flown ("existing") flight pattern. The fields follow, in order,
# columns [1]..[11] selected by FlightsBuilder.fetch_flyouts, which merges the departure
# days of all consecutive rows sharing the same key.
FlightPatternKey = namedtuple(
    'FlightPatternKey',
    'marketing_carrier marketing_flight_number leg_seq_number source flight_base_id '
    'marketing_carrier_iata is_codeshare arrival_day_shift designated_carrier '
    'departure_day_shift is_derivative',
)

# Comparison key used by FlightsBuilder.fetch_existing_flight_bases to detect consecutive
# rows that describe the same flight base. Note that designated_carrier is not part of it.
ExistingFlightBaseKey = namedtuple(
    'ExistingFlightBaseKey',
    'operating_carrier operating_carrier_iata operating_flight_number leg_seq_number '
    'departure_station departure_station_iata scheduled_departure_time departure_terminal '
    'arrival_station arrival_station_iata scheduled_arrival_time arrival_terminal '
    'aircraft_model flying_carrier_iata intl_dom_status traffic_restriction_code source',
)

# Value of the map returned by FlightsBuilder.fetch_existing_flight_bases:
# the id under which the flight base was written plus its bucket key.
FlightBaseIdEntry = namedtuple('FlightBaseIdEntry', 'id bucket_key')
# Departure and arrival station ids of a flight base (stored per flight base id).
FlightBaseStations = namedtuple('FlightBaseStations', 'station_from station_to')


# Max days in the past to put into snapshot. Also used as the default look-back
# (relative to "now") when FlightsBuilder is created without an explicit start_date.
MAX_PAST_DAYS = 32
# Log progress after this number of rows processed
LOG_INDEX_STEP = 250000
# Number of rows to fetch in a single portion when caching large tables
# (assigned to `itersize` of the named database cursors)
FETCH_ITERSIZE = 1000


# Loads flight bases and flight patterns from the database into memory
class FlightsBuilder(object):
    def __init__(self, logger, start_date=None):
        self.logger = logger
        self._start_date = start_date if start_date else datetime.now() - timedelta(days=MAX_PAST_DAYS)

    def fetch_flight_bases(self, conn, transport_models):
        """Load scheduled and existing flight bases into a single in-memory buffer.

        :param conn: database connection used to open the cursors.
        :param transport_models: mapping from an aircraft model code to a transport model id.
        :return: tuple ``(mem_entry, flight_bases_ids, flight_bases_stations)``:
            the ``io.BytesIO`` holding the serialized flight bases, the map returned by
            ``fetch_existing_flight_bases`` and the map from a flight base id to its
            ``FlightBaseStations``.
        """
        stations_by_flight_base = {}
        serialized = io.BytesIO()

        # Scheduled flight bases are written first; the existing ones are appended to the same buffer.
        self._fetch_flight_bases_to_temp_file(conn, serialized, stations_by_flight_base, transport_models)
        existing_ids = self.fetch_existing_flight_bases(conn, serialized, stations_by_flight_base, transport_models)
        return serialized, existing_ids, stations_by_flight_base

    def _fetch_flight_bases_to_temp_file(self, conn, mem_entry, flight_bases_stations, transport_models):
        """Stream scheduled flight bases from the database into ``mem_entry``.

        Rows of ``sirena_flight_base``, ``flight_base`` and ``apm_flight_base`` are read through
        a named cursor, converted to ``TFlightBase`` messages and written with
        ``write_binary_string``.

        :param conn: database connection.
        :param mem_entry: writable binary stream receiving the serialized flight bases.
        :param flight_bases_stations: dict updated in place: flight base id -> ``FlightBaseStations``.
        :param transport_models: mapping from an aircraft model code to a transport model id.
        """
        self.logger.info('Reading flight_bases from the database')
        # The --[N] markers are the positions used to index each fetched row below.
        columns = '''
                id,                       --[0]
                bucket_key,               --[1]
                operating_carrier,        --[2]
                operating_carrier_iata,   --[3]
                operating_flight_number,  --[4]
                itinerary_variation,      --[5]
                leg_seq_number,           --[6]
                departure_station,        --[7]
                departure_station_iata,   --[8]
                scheduled_departure_time, --[9]
                departure_terminal,       --[10]
                arrival_station,          --[11]
                arrival_station_iata,     --[12]
                scheduled_arrival_time,   --[13]
                arrival_terminal,         --[14]
                aircraft_model,           --[15]
                flying_carrier_iata,      --[16]
                intl_dom_status,          --[17]
                traffic_restriction_code, --[18]
                designated_carrier        --[19]
        '''
        query = '''
            select
                {fields}
            from
                sirena_flight_base
            union
            select
                {fields}
            from
                flight_base
            union
            select
                {fields}
            from
                apm_flight_base
            order by
                bucket_key,
                itinerary_variation
            '''.format(
            fields=columns
        )
        with conn.cursor(name='flight_bases_scheduled') as cursor:
            cursor.itersize = FETCH_ITERSIZE
            cursor.execute(query)
            loaded = 0
            for loaded, row in enumerate(cursor, 1):
                fb = TFlightBase()
                fb.Id = int(row[0])
                fb.BucketKey = ensure_string(row[1])
                fb.OperatingCarrier = int(row[2])
                fb.OperatingCarrierIata = ensure_string(row[3])
                fb.OperatingFlightNumber = ensure_string(row[4])
                fb.ItineraryVariationIdentifier = ensure_string(row[5])
                fb.LegSeqNumber = int(row[6])

                # departure side
                fb.DepartureStation = int(row[7])
                fb.DepartureStationIata = ensure_string(row[8])
                fb.ScheduledDepartureTime = int(row[9])
                fb.DepartureTerminal = convert_to_latin(ensure_string(row[10]))

                # arrival side
                fb.ArrivalStation = int(row[11])
                fb.ArrivalStationIata = ensure_string(row[12])
                fb.ScheduledArrivalTime = int(row[13])
                fb.ArrivalTerminal = convert_to_latin(ensure_string(row[14]))

                fb.AircraftModel = ensure_string(row[15])
                # The transport model is looked up by the model code as stored in the message.
                fb.AircraftTypeId = self._get_transport_model(transport_models, fb.AircraftModel)
                fb.FlyingCarrierIata = ensure_string(row[16])
                fb.IntlDomesticStatus = ensure_string(row[17])
                fb.TrafficRestrictionCode = ensure_string(row[18])
                fb.DesignatedCarrier = int(row[19])
                # The source is not stored in these tables; it is derived from the id range.
                fb.Source = self._get_source(fb.Id)

                write_binary_string(mem_entry, fb.SerializeToString())
                flight_bases_stations[fb.Id] = FlightBaseStations(fb.DepartureStation, fb.ArrivalStation)
                if loaded % LOG_INDEX_STEP == 0:
                    self.logger.info('Loaded %d scheduled flight bases', loaded)
            self.logger.info('Loaded %d scheduled flight bases', loaded)
        self.logger.info('Done fetching scheduled flight bases')

    def fetch_flight_patterns(self, conn, fb_ids, fb_stations, p2p_cache):
        """Collect scheduled and existing flight patterns and serialize them.

        Scan 1 reads the rows of ``flight_pattern``, ``apm_flight_pattern`` and
        ``sirena_flight_pattern`` and registers each of them in a ``CodesharesMap``; flight
        patterns restored from the existing flyouts are then added to the same map. Scan 2
        walks ``codeshares_map.generate_flights_to_write()``, reports each pattern with a known
        flight base to ``p2p_cache.add_segment`` and writes every pattern to the result buffer.

        :param conn: database connection.
        :param fb_ids: flight bases data handed over to ``generate_flight_patterns_from_flyouts``.
        :param fb_stations: mapping from a flight base id to its ``FlightBaseStations``.
        :param p2p_cache: object exposing ``add_segment(flight_pattern, station_from, station_to)``.
        :return: ``io.BytesIO`` with the serialized flight patterns.
        """
        self.logger.info('Reading flight_patterns from the database')
        columns = '''
            id,                          --[0]
            bucket_key,                  --[1]
            flight_base_id,              --[2]
            flight_leg_key,              --[3]
            operating_from,              --[4]
            operating_until,             --[5]
            operating_on_days,           --[6]
            marketing_carrier,           --[7]
            marketing_carrier_iata,      --[8]
            marketing_flight_number,     --[9]
            is_administrative,           --[10]
            is_codeshare,                --[11]
            arrival_day_shift,           --[12]
            departure_day_shift,         --[13]
            operating_flight_pattern_id, --[14]
            leg_seq_number,              --[15]
            div(id, {}) as is_sirena,    --[16]
            is_derivative                --[17]
            '''.format(
            MIN_SIRENA_FLIGHT_PATTERN_ID
        )
        # Sirena's flights should be processed last, since we need to extract codeshares info first
        query = '''
            select
                {fields}
            from flight_pattern
            union
            select
                {fields}
            from apm_flight_pattern
            union
            select
                {fields}
            from sirena_flight_pattern
            order by
                is_sirena,
                is_codeshare,
                leg_seq_number
            '''.format(
            fields=columns
        )
        codeshares_map = CodesharesMap(MIN_SIRENA_CODE_SHARE_FLIGHT_PATTERN_ID, self.logger)

        with conn.cursor(name='scheduled_flight_patterns') as cursor:
            cursor.itersize = FETCH_ITERSIZE
            self.logger.info('Parsing the flight-patterns data, scan 1')
            cursor.execute(query)
            scheduled_count = 0
            for scheduled_count, row in enumerate(cursor, 1):
                codeshares_map.add_flight(self._get_flight_pattern(row))
                if scheduled_count % LOG_INDEX_STEP == 0:
                    self.logger.info('Loaded %d scheduled flight patterns', scheduled_count)
            self.logger.info('Loaded %d scheduled flight patterns', scheduled_count)

        with conn.cursor(name='existing_flight_patterns') as cursor:
            cursor.itersize = FETCH_ITERSIZE
            self.logger.info('Addig existing flyouts to the flight patterns set')
            existing_count = 0
            # Both helpers are generators, so rows are pulled from the cursor lazily inside the loop.
            flyouts = self.fetch_flyouts(cursor)
            restored_patterns = self.generate_flight_patterns_from_flyouts(flyouts, fb_ids, codeshares_map)
            for existing_count, restored in enumerate(restored_patterns, 1):
                codeshares_map.add_flight(restored)
                if existing_count % LOG_INDEX_STEP == 0:
                    self.logger.info('Loaded %d existing flight patterns', existing_count)
            self.logger.info('Loaded %d existing flight patterns', existing_count)

        self.logger.info('Parsing the flight-patterns data, scan 2')
        mem_entry = io.BytesIO()
        for written, fp in enumerate(codeshares_map.generate_flights_to_write(), 1):
            stations = fb_stations.get(fp.FlightId)
            if not stations:
                self.logger.error('Unable to find flight base for flight pattern: %s', fp)
            else:
                p2p_cache.add_segment(fp, stations.station_from, stations.station_to)
            # The pattern is written even when its flight base is unknown.
            write_binary_string(mem_entry, fp.SerializeToString())
            if written % LOG_INDEX_STEP == 0:
                self.logger.info('Flight patterns processed so far: %d', written)

        self.logger.info('Done fetching flight patterns')
        return mem_entry

    def _get_flight_pattern(self, row):
        """Convert one row of the scheduled flight patterns query into a ``TFlightPattern``.

        Indexes match the ``--[N]`` markers of the query built in ``fetch_flight_patterns``.
        Column 16 (``is_sirena``) only drives the ordering of that query and is not read here.

        :param row: indexable database row.
        :return: populated ``TFlightPattern``.
        """
        date_format = '%Y-%m-%d'

        fp = TFlightPattern()
        fp.Id = int(row[0])
        fp.BucketKey = ensure_string(row[1])
        fp.FlightId = int(row[2])
        fp.FlightLegKey = ensure_string(row[3])
        fp.OperatingFromDate = row[4].strftime(date_format)
        fp.OperatingUntilDate = row[5].strftime(date_format)
        fp.OperatingOnDays = int(row[6])

        fp.MarketingCarrier = int(row[7])
        fp.MarketingCarrierIata = ensure_string(row[8])
        fp.MarketingFlightNumber = ensure_string(row[9])
        fp.IsAdministrative = bool(row[10])
        fp.IsCodeshare = bool(row[11])
        # Falsy values (NULL included) of the optional numeric columns are stored as 0.
        fp.ArrivalDayShift = int(row[12] or 0)
        fp.DepartureDayShift = int(row[13] or 0)
        fp.OperatingFlightPatternId = int(row[14] or 0)
        fp.LegSeqNumber = int(row[15] or 0)
        fp.IsDerivative = bool(row[17])
        return fp

    def _get_transport_model(self, transport_models, code):
        """Return the transport model id registered for ``code``, or 0 if there is none.

        :param transport_models: mapping from an aircraft model code to a transport model id;
            it is not touched when ``code`` is empty.
        :param code: aircraft model code, may be empty or ``None``.
        :return: the mapped id; 0 for an empty code, an unknown code or a falsy mapped value.
        """
        if code:
            return transport_models.get(code) or 0
        return 0

    def _get_source(self, flight_base_id):
        """Derive the source of a flight base from the id range it belongs to.

        :param flight_base_id: flight base id; a falsy value yields ``TYPE_UNKNOWN``.
        :return: ``EFlightBaseSource`` value. Ids strictly above ``MIN_SIRENA_FLIGHT_BASE_ID``
            are Sirena, ids strictly above ``MIN_APM_FLIGHT_BASE_ID`` are APM, the rest are Amadeus.
        """
        if not flight_base_id:
            source = EFlightBaseSource.TYPE_UNKNOWN
        elif flight_base_id > MIN_SIRENA_FLIGHT_BASE_ID:
            # Checked before the APM threshold, exactly as the ranges are ordered in the original checks.
            source = EFlightBaseSource.TYPE_SIRENA
        elif flight_base_id > MIN_APM_FLIGHT_BASE_ID:
            source = EFlightBaseSource.TYPE_APM
        else:
            source = EFlightBaseSource.TYPE_AMADEUS
        return source

    def fetch_existing_flight_bases(self, conn, mem_entry, flight_bases_stations, transport_models):
        """Append the existing flight bases to ``mem_entry`` under freshly assigned ids.

        Rows of the ``ExistingFlightBase`` table are read in the order of the
        ``ExistingFlightBaseKey`` fields, so rows that describe the same flight base are
        adjacent. Each run of equal keys is written once; every original row id of the run is
        mapped to the id assigned to that single written flight base.

        :param conn: database connection.
        :param mem_entry: writable binary stream receiving the serialized flight bases.
        :param flight_bases_stations: dict updated in place: new flight base id -> ``FlightBaseStations``.
        :param transport_models: mapping from an aircraft model code to a transport model id.
        :return: dict mapping an original row id to ``FlightBaseIdEntry(new id, bucket key)``.
        """
        flight_bases_ids = {}
        # Both counters are incremented before use, so the first assigned id is MIN_* + 1.
        current_amadeus_fb_id = MIN_AMADEUS_RETAINED_FLIGHT_BASE_ID
        current_sirena_fb_id = MIN_SIRENA_RETAINED_FLIGHT_BASE_ID
        # Counts the flight bases actually written (duplicates are not counted).
        row_index = 0
        with conn.cursor(name='flight_bases') as cursor:
            cursor.execute(
                '''
                select
                    id,                       -- [0]
                    operating_carrier,        -- [1]
                    operating_carrier_iata,   -- [2]
                    operating_flight_number,  -- [3]
                    leg_seq_number,           -- [4]
                    departure_station,        -- [5]
                    departure_station_iata,   -- [6]
                    scheduled_departure_time, -- [7]
                    departure_terminal,       -- [8]
                    arrival_station,          -- [9]
                    arrival_station_iata,     -- [10]
                    scheduled_arrival_time,   -- [11]
                    arrival_terminal,         -- [12]
                    aircraft_model,           -- [13]
                    flying_carrier_iata,      -- [14]
                    intl_dom_status,          -- [15]
                    traffic_restriction_code, -- [16]
                    designated_carrier,       -- [17]
                    source                    -- [18]
                from
                    {}
                order by
                    operating_carrier,        -- [1]
                    operating_carrier_iata,   -- [2]
                    operating_flight_number,  -- [3]
                    leg_seq_number,           -- [4]
                    departure_station,        -- [5]
                    departure_station_iata,   -- [6]
                    scheduled_departure_time, -- [7]
                    departure_terminal,       -- [8]
                    arrival_station,          -- [9]
                    arrival_station_iata,     -- [10]
                    scheduled_arrival_time,   -- [11]
                    arrival_terminal,         -- [12]
                    aircraft_model,           -- [13]
                    flying_carrier_iata,      -- [14]
                    intl_dom_status,          -- [15]
                    traffic_restriction_code, -- [16]
                    source                    -- [18]
                '''.format(
                    ExistingFlightBase.__table__
                )
            )
            # Key and assigned id of the most recently written flight base.
            current_fb_key = None
            current_fb_id = None
            for row in cursor:
                flight_base = TFlightBase()
                flight_base.Id = int(row[0])
                flight_base.OperatingCarrier = int(row[1])
                flight_base.OperatingCarrierIata = ensure_string(row[2])
                flight_base.OperatingFlightNumber = ensure_string(row[3])
                flight_base.LegSeqNumber = int(row[4])

                flight_base.DepartureStation = int(row[5])
                flight_base.DepartureStationIata = ensure_string(row[6])
                flight_base.ScheduledDepartureTime = int(row[7])
                flight_base.DepartureTerminal = convert_to_latin(ensure_string(row[8]))

                flight_base.ArrivalStation = int(row[9])
                flight_base.ArrivalStationIata = ensure_string(row[10])
                flight_base.ScheduledArrivalTime = int(row[11])
                flight_base.ArrivalTerminal = convert_to_latin(ensure_string(row[12]))
                flight_base.AircraftModel = ensure_string(row[13])
                flight_base.AircraftTypeId = self._get_transport_model(transport_models, flight_base.AircraftModel)
                flight_base.FlyingCarrierIata = ensure_string(row[14])
                flight_base.IntlDomesticStatus = ensure_string(row[15])
                flight_base.TrafficRestrictionCode = ensure_string(row[16])
                # designated_carrier is neither sorted on nor part of the key below, so the
                # value of the first row of a run is the one that gets written.
                flight_base.DesignatedCarrier = int(row[17])
                flight_base.Source = int(row[18])
                # The bucket key is not read from the table; it is computed from the row.
                flight_base.BucketKey = get_bucket_key(
                    flight_base.OperatingCarrier,
                    flight_base.OperatingFlightNumber,
                    flight_base.LegSeqNumber,
                )

                flight_base_key = ExistingFlightBaseKey(
                    operating_carrier=flight_base.OperatingCarrier,
                    operating_carrier_iata=flight_base.OperatingCarrierIata,
                    operating_flight_number=flight_base.OperatingFlightNumber,
                    leg_seq_number=flight_base.LegSeqNumber,
                    departure_station=flight_base.DepartureStation,
                    departure_station_iata=flight_base.DepartureStationIata,
                    scheduled_departure_time=flight_base.ScheduledDepartureTime,
                    departure_terminal=flight_base.DepartureTerminal,
                    arrival_station=flight_base.ArrivalStation,
                    arrival_station_iata=flight_base.ArrivalStationIata,
                    scheduled_arrival_time=flight_base.ScheduledArrivalTime,
                    arrival_terminal=flight_base.ArrivalTerminal,
                    aircraft_model=flight_base.AircraftModel,
                    flying_carrier_iata=flight_base.FlyingCarrierIata,
                    intl_dom_status=flight_base.IntlDomesticStatus,
                    traffic_restriction_code=flight_base.TrafficRestrictionCode,
                    source=flight_base.Source,
                )

                # Same flight base as the previously written one: only remember the id mapping.
                if current_fb_key == flight_base_key:
                    flight_bases_ids[flight_base.Id] = FlightBaseIdEntry(current_fb_id, flight_base.BucketKey)
                    continue

                # The original id must be used as the map key BEFORE flight_base.Id is overwritten.
                # Every source other than Sirena takes its id from the Amadeus counter.
                if flight_base.Source == EFlightBaseSource.TYPE_SIRENA:
                    current_sirena_fb_id += 1
                    flight_bases_ids[flight_base.Id] = FlightBaseIdEntry(current_sirena_fb_id, flight_base.BucketKey)
                    flight_base.Id = current_sirena_fb_id
                else:
                    current_amadeus_fb_id += 1
                    flight_bases_ids[flight_base.Id] = FlightBaseIdEntry(current_amadeus_fb_id, flight_base.BucketKey)
                    flight_base.Id = current_amadeus_fb_id

                write_binary_string(mem_entry, flight_base.SerializeToString())
                # Stations are registered under the newly assigned id.
                flight_bases_stations[flight_base.Id] = FlightBaseStations(
                    flight_base.DepartureStation,
                    flight_base.ArrivalStation,
                )
                current_fb_key = flight_base_key
                current_fb_id = flight_base.Id

                row_index += 1
                if row_index % LOG_INDEX_STEP == 0:
                    self.logger.info('Loaded %d existing flight bases', row_index)
        self.logger.info('Loaded %d existing flight bases', row_index)
        return flight_bases_ids

    def fetch_flyouts(self, cursor):
        """Yield existing flyouts grouped by their ``FlightPatternKey``.

        The query is ordered by all key columns, so the rows of one key are adjacent; they are
        merged into a single ``(key, dates)`` pair, where ``dates`` lists the departure days
        formatted as ``YYYY-MM-DD`` in the order the rows were returned.

        :param cursor: database cursor to run the query on.
        :return: generator of ``(FlightPatternKey, list of date strings)`` tuples.
        """
        cursor.execute(
            '''
            select
                flight_departure_day,     -- [0]
                marketing_carrier,        -- [1]
                marketing_flight_number,  -- [2]
                leg_seq_number,           -- [3]
                source,                   -- [4]
                flight_base_id,           -- [5]
                marketing_carrier_iata,   -- [6]
                is_codeshare,             -- [7]
                arrival_day_shift,        -- [8]
                designated_carrier,       -- [9]
                departure_day_shift,      -- [10]
                is_derivative             -- [11]
            from
                {}
            order by
                marketing_carrier,        -- [1]
                marketing_flight_number,  -- [2]
                leg_seq_number,           -- [3]
                source,                   -- [4]
                flight_base_id,           -- [5]
                marketing_carrier_iata,   -- [6]
                is_codeshare,             -- [7]
                arrival_day_shift,        -- [8]
                designated_carrier,       -- [9]
                departure_day_shift,      -- [10]
                is_derivative             -- [11]
            '''.format(
                ExistingFlightPattern.__table__
            )
        )
        group_key = None
        group_dates = []
        for rows_seen, row in enumerate(cursor, 1):
            # Columns [1]..[11] are exactly the key fields, in declaration order.
            row_key = FlightPatternKey(*row[1:12])
            if row_key != group_key:
                # A new group starts: flush the previous one, if any.
                if group_key:
                    yield group_key, group_dates
                group_key, group_dates = row_key, []

            group_dates.append(row[0].strftime('%Y-%m-%d'))

            if rows_seen % LOG_INDEX_STEP == 0:
                self.logger.info('Loaded %d existing flyouts', rows_seen)

        # Flush the last group.
        if group_key:
            yield group_key, group_dates

    def generate_flight_patterns_from_flyouts(self, flyouts, flight_bases_data, codeshares_map):
        """Convert grouped existing flyouts into ``TFlightPattern`` messages.

        For every ``(key, dates)`` pair the dates are collected into a date mask, the dates
        already covered by the new data (as known to ``codeshares_map`` for the same marketing
        flight and source) are removed from it, and one flight pattern is yielded per mask
        produced by ``DateMaskMatcher.generate_masks``.

        Only Sirena and Amadeus flyouts are processed; flyouts of any other source are skipped
        before any lookup or date-mask work is done for them.

        :param flyouts: iterable of ``(FlightPatternKey, list of 'YYYY-MM-DD' strings)``.
        :param flight_bases_data: mapping from the stored flight base id to an entry with
            ``id`` and ``bucket_key`` attributes (``FlightBaseIdEntry``).
        :param codeshares_map: object exposing ``get_sirena_flights`` / ``get_amadeus_flights``.
        :return: generator of ``TFlightPattern``.
        :raises KeyError: if a Sirena/Amadeus flyout refers to a flight base id that is missing
            from ``flight_bases_data``.
        """
        date_index = DateIndex(self._start_date)
        dmm = DateMaskMatcher(date_index, self._start_date, MAX_PAST_DAYS + 1)
        # Both counters are incremented before use, so the first assigned id is MIN_* + 1.
        current_fp_amadeus_id = MIN_AMADEUS_RETAINED_FLIGHT_PATTERN_ID
        current_fp_sirena_id = MIN_SIRENA_RETAINED_FLIGHT_PATTERN_ID
        self.logger.info('Converting existing flyouts into flight patterns')
        for key, dates_list in flyouts:
            is_sirena = key.source == EFlightBaseSource.TYPE_SIRENA
            if not is_sirena and key.source != EFlightBaseSource.TYPE_AMADEUS:
                # don't process existing flight patterns from sources other than Amadeus or Sirena
                continue

            fb_data = flight_bases_data[key.flight_base_id]
            date_mask = dmm.new_date_mask()
            for d in dates_list:
                dmm.add_date_str(d, date_mask)

            # exclude past dates that already present in the new data to avoid duplicates in the flight patterns cache
            new_data_date_mask = dmm.new_date_mask()
            marketing_bucket_key = '{}.{}.{}'.format(
                key.marketing_carrier, key.marketing_flight_number, key.leg_seq_number
            )
            if is_sirena:
                new_data_flight_patterns = codeshares_map.get_sirena_flights(marketing_bucket_key)
            else:
                new_data_flight_patterns = codeshares_map.get_amadeus_flights(marketing_bucket_key)
            if new_data_flight_patterns:
                for new_flight_pattern in new_data_flight_patterns:
                    dmm.add_range(
                        new_flight_pattern.OperatingFromDate,
                        new_flight_pattern.OperatingUntilDate,
                        new_flight_pattern.OperatingOnDays,
                        new_data_date_mask,
                    )
                dmm.remove_mask(date_mask, new_data_date_mask)

            for mask in dmm.generate_masks(date_mask):
                flight_pattern = TFlightPattern()
                flight_pattern.FlightId = fb_data.id

                # Only Sirena and Amadeus keys reach this point, each with its own id range.
                if is_sirena:
                    current_fp_sirena_id += 1
                    flight_pattern.Id = current_fp_sirena_id
                else:
                    current_fp_amadeus_id += 1
                    flight_pattern.Id = current_fp_amadeus_id
                flight_pattern.BucketKey = fb_data.bucket_key
                flight_pattern.FlightLegKey = '{}.{}'.format(fb_data.bucket_key, key.leg_seq_number)
                # mask is (from date, until date, days of operation); dates are dot-separated.
                flight_pattern.OperatingFromDate = mask[0].replace('.', '-')
                flight_pattern.OperatingUntilDate = mask[1].replace('.', '-')
                flight_pattern.OperatingOnDays = mask[2]

                flight_pattern.MarketingCarrier = key.marketing_carrier
                flight_pattern.MarketingCarrierIata = key.marketing_carrier_iata
                flight_pattern.MarketingFlightNumber = key.marketing_flight_number
                flight_pattern.IsCodeshare = key.is_codeshare
                flight_pattern.ArrivalDayShift = key.arrival_day_shift
                flight_pattern.DepartureDayShift = key.departure_day_shift
                flight_pattern.LegSeqNumber = key.leg_seq_number
                # is_derivative is part of the grouping key, so it has to survive the round trip
                # just like it does for scheduled patterns; bool() maps a NULL column to False.
                flight_pattern.IsDerivative = bool(key.is_derivative)
                yield flight_pattern
