# coding=utf-8
from __future__ import unicode_literals

import sqlalchemy as sa
from collections import defaultdict, namedtuple
from datetime import datetime, timedelta

from travel.avia.shared_flights.lib.python.date_utils.date_index import DateIndex
from travel.avia.shared_flights.lib.python.date_utils.date_mask import DateMaskMatcher
from travel.avia.shared_flights.lib.python.db_models.flight_base import FlightBase, SirenaFlightBase
from travel.avia.shared_flights.lib.python.db_models.flight_pattern import FlightPattern, SirenaFlightPattern
from travel.avia.shared_flights.lib.python.db_models.existing_flight_base import ExistingFlightBase
from travel.avia.shared_flights.lib.python.db_models.existing_flight_pattern import ExistingFlightPattern
from travel.proto.shared_flights.ssim.flights_pb2 import EFlightBaseSource


# strftime format for calendar dates, e.g. 2020-01-31.
YYYYMMDD = '%Y-%m-%d'


# Key of one stored flyout: a marketing flight leg on a given departure day
# (departure_day is a string formatted with YYYYMMDD).
FlyoutKey = namedtuple(
    typename='FlyoutKey',
    field_names=['departure_day', 'marketing_carrier', 'marketing_flight_number', 'leg_seq_number'],
)
# Groups flight patterns that share a flight base and a marketing flight leg.
SegmentKey = namedtuple(
    typename='SegmentKey',
    field_names=['flight_id', 'marketing_carrier', 'marketing_flight_number', 'leg_seq_number'],
)


class ExistingFlightsLoader(object):
    """Copies flights from the schedule tables into the existing-flights tables.

    For one source at a time (Amadeus or Sirena) the loader:

    * reads the flight patterns that operate on at least one day of a query window;
    * expands them into per-day flyouts restricted to the date-mask window set up by
      ``set_date_mask_matcher``, skipping flyouts that are already stored;
    * copies the flight bases referenced by the new flyouts under fresh ids;
    * bulk-inserts bases and flyouts, deletes records created too long ago and commits.
    """

    # Rows per multi-row INSERT statement.
    _INSERT_BATCH_SIZE = 2000
    # The first INSERT is flushed after this many rows, presumably so that a broken
    # insert fails before a large batch is built -- NOTE(review): confirm intent.
    _FIRST_FLUSH_SIZE = 5

    # Attributes copied verbatim from a flight pattern onto each of its flyouts.
    _FLYOUT_FIELDS_FROM_PATTERN = (
        'flight_base_id',
        'marketing_carrier_iata',
        'is_administrative',
        'is_codeshare',
        'is_derivative',
        'arrival_day_shift',
        'designated_carrier',
        'departure_day_shift',
    )

    # Attributes copied verbatim from a flight base onto its ExistingFlightBase copy.
    _FLIGHT_BASE_FIELDS = (
        'operating_carrier',
        'operating_carrier_iata',
        'operating_flight_number',
        'leg_seq_number',
        'departure_station',
        'departure_station_iata',
        'scheduled_departure_time',
        'departure_terminal',
        'arrival_station',
        'arrival_station_iata',
        'scheduled_arrival_time',
        'arrival_terminal',
        'aircraft_model',
        'flying_carrier_iata',
        'intl_dom_status',
        'traffic_restriction_code',
        'designated_carrier',
    )

    def __init__(self, logger):
        self.logger = logger
        self.set_date_mask_matcher(datetime.now(), 31)

    def set_date_mask_matcher(self, now, max_days):
        """(Re)configure the reference time and the date window used for flyouts.

        :param now: reference datetime; it anchors the query windows and is stamped
            as ``created_at`` on every inserted row.
        :param max_days: the window handed to DateMaskMatcher starts ``max_days``
            days before ``now``.
        """
        self._now = now
        self._max_days = max_days
        self._start_date = self._now - timedelta(days=self._max_days)
        self._date_index = DateIndex(self._start_date)
        self._dmm = DateMaskMatcher(self._date_index, self._start_date, self._max_days - 1)

    def _operating_days_condition(self, pattern_model, first_offset, last_offset=-30):
        """Build a filter matching patterns that operate on at least one day of a window.

        The window consists of the days ``self._now - timedelta(days=offset)`` for
        ``offset`` running from ``first_offset`` down to ``last_offset`` inclusive
        (negative offsets are future days). ``self._now`` is used instead of the
        wall clock so the query agrees with the configured date mask.
        """
        conditions = []
        for num_days in range(first_offset, last_offset - 1, -1):
            condition_day = (self._now - timedelta(days=num_days)).strftime(YYYYMMDD)
            conditions.append(sa.and_(
                pattern_model.operating_from <= condition_day,
                pattern_model.operating_until >= condition_day,
            ))
        return sa.or_(*conditions)

    def update_existing_amadeus_flights(self, session):
        """Snapshot Amadeus flights into the existing-flights tables.

        ``session`` is closed before returning, whether or not saving succeeded.
        """
        self.logger.info('Fetching existing Amadeus flights')
        try:
            # Window: yesterday through 30 days ahead.
            fps = session.query(FlightPattern).filter(
                self._operating_days_condition(FlightPattern, 1)).all()

            fps_to_keep, fbs_ids = self.get_flights_to_keep(fps, session, EFlightBaseSource.TYPE_AMADEUS)
            fbs = session.query(FlightBase).filter(FlightBase.id.in_(fbs_ids)).all()

            self.logger.info('Retaining Amadeus flight patterns: %d flyouts', len(fps_to_keep))
            self.logger.info('Existing Amadeus flight bases: %d, expecting %d', len(fbs), len(fbs_ids))
            self.renumber_and_save(
                session,
                fps_to_keep,
                fbs,
                EFlightBaseSource.TYPE_AMADEUS,
            )
        finally:
            session.close()
        self.logger.info('Done fetching existing Amadeus flights')

    def update_existing_sirena_flights(self, session):
        """Snapshot Sirena flights into the existing-flights tables.

        ``session`` is closed before returning, whether or not saving succeeded.
        """
        self.logger.info('Fetching existing Sirena flights')
        try:
            # NOTE(review): unlike the Amadeus query, this window starts at offset -1
            # (tomorrow), so patterns whose last operating day is yesterday or today
            # are not fetched. Kept as-is; confirm whether the asymmetry is intended.
            sirena_fps = session.query(SirenaFlightPattern).filter(
                self._operating_days_condition(SirenaFlightPattern, -1)).all()

            fps_to_keep, fbs_ids = self.get_flights_to_keep(sirena_fps, session, EFlightBaseSource.TYPE_SIRENA)
            sirena_fbs = session.query(SirenaFlightBase).filter(SirenaFlightBase.id.in_(fbs_ids)).all()

            self.renumber_and_save(
                session,
                fps_to_keep,
                sirena_fbs,
                EFlightBaseSource.TYPE_SIRENA,
            )
        finally:
            session.close()

        self.logger.info('Existing Sirena flight patterns: %d, retaining %d', len(sirena_fps), len(fps_to_keep))
        self.logger.info('Existing Sirena flight bases: %d', len(sirena_fbs))
        self.logger.info('Done fetching existing Sirena flights')

    def get_flights_to_keep(self, flight_patterns, session, source_id):
        """Return the flyouts of ``flight_patterns`` that are not stored yet for ``source_id``.

        Loads the keys of flyouts of this source already stored with a departure day
        on or after the day before the date-mask start, then delegates to
        ``get_flights_to_keep_internal``.

        :returns: tuple (list of new ExistingFlightPattern, set of source flight base ids
            referenced by them).
        """
        # Fetch existing keys so the bulk insert does not produce conflicts.
        existing_flyout_keys = set()
        # Subtract a day since self._now is not a midnight.
        last_day_to_select = self._start_date - timedelta(days=1)
        cursor = session.query(
            ExistingFlightPattern.flight_departure_day,
            ExistingFlightPattern.marketing_carrier,
            ExistingFlightPattern.marketing_flight_number,
            ExistingFlightPattern.leg_seq_number,
        ).yield_per(1000).enable_eagerloads(False).filter(
            ExistingFlightPattern.source == source_id).filter(
            ExistingFlightPattern.flight_departure_day >= last_day_to_select)
        for row in cursor:
            existing_flyout_keys.add(FlyoutKey(row[0].strftime(YYYYMMDD), row[1], row[2], row[3]))
        self.logger.info('Found %d existing flyout keys', len(existing_flyout_keys))
        return self.get_flights_to_keep_internal(flight_patterns, existing_flyout_keys, source_id)

    def _add_pattern_range(self, fp, date_mask):
        """Add the operating days of flight pattern ``fp`` to ``date_mask``."""
        self._dmm.add_range(
            self._date_index.adjust_strdate(fp.operating_from.strftime(YYYYMMDD)),
            self._date_index.adjust_strdate(fp.operating_until.strftime(YYYYMMDD)),
            fp.operating_on_days,
            date_mask,
        )

    def _make_flyout(self, fp, segment_key, flight_day, source_id):
        """Build an ExistingFlightPattern for ``segment_key`` departing on ``flight_day``.

        Per-pattern attributes are copied from ``fp``.
        """
        flyout = ExistingFlightPattern()
        flyout.flight_departure_day = flight_day
        flyout.marketing_carrier = segment_key.marketing_carrier
        flyout.marketing_flight_number = segment_key.marketing_flight_number
        flyout.leg_seq_number = segment_key.leg_seq_number
        flyout.source = source_id
        for field in self._FLYOUT_FIELDS_FROM_PATTERN:
            setattr(flyout, field, getattr(fp, field))
        flyout.created_at = self._now
        return flyout

    def get_flights_to_keep_internal(self, flight_patterns, existing_flyout_keys, source_id):
        """Expand flight patterns into per-day flyouts that are not stored yet.

        :param flight_patterns: schedule flight patterns to expand.
        :param existing_flyout_keys: set of FlyoutKey already stored; updated in place
            with the keys of the flyouts produced here.
        :param source_id: value stored in the ``source`` column of each flyout.
        :returns: tuple (list of new ExistingFlightPattern, set of flight_base_id values
            referenced by them).
        """
        # SegmentKey (flight base + marketing flight + leg) => list of flight patterns
        fetched_flights = defaultdict(list)
        for fp in flight_patterns:
            fetched_flights[SegmentKey(
                fp.flight_base_id, fp.marketing_carrier, fp.marketing_flight_number, fp.leg_seq_number,
            )].append(fp)

        fbs_ids = set()
        flyouts = []
        for segment_key, fp_list in fetched_flights.items():
            segment_mask = self._dmm.new_date_mask()
            # Day index => a pattern that operates on that day, so every flyout copies
            # the attributes of a pattern actually covering its day (the later pattern
            # wins where periods overlap).
            pattern_by_day = {}
            for fp in fp_list:
                self._add_pattern_range(fp, segment_mask)
                fp_mask = self._dmm.new_date_mask()
                self._add_pattern_range(fp, fp_mask)
                if not fp_mask.is_empty():
                    for day_index in self._dmm.get_date_indexes(fp_mask):
                        pattern_by_day[day_index] = fp
            if segment_mask.is_empty():
                continue

            added_flyouts = 0
            for day_index in self._dmm.get_date_indexes(segment_mask):
                flight_day = self._date_index.get_date(day_index)
                flyout_key = FlyoutKey(
                    flight_day.strftime(YYYYMMDD),
                    segment_key.marketing_carrier,
                    segment_key.marketing_flight_number,
                    segment_key.leg_seq_number,
                )
                if flyout_key in existing_flyout_keys:
                    continue

                source_fp = pattern_by_day.get(day_index, fp_list[-1])
                flyouts.append(self._make_flyout(source_fp, segment_key, flight_day, source_id))
                existing_flyout_keys.add(flyout_key)
                added_flyouts += 1

            if added_flyouts:
                fbs_ids.add(segment_key.flight_id)

        return flyouts, fbs_ids

    def _insert_in_batches(self, session, table, rows, label):
        """Insert ``rows`` (objects exposing ``to_dict()``) into ``table`` in batches.

        :returns: number of rows inserted.
        """
        count = 0
        batch = []
        for row in rows:
            batch.append(row.to_dict())
            count += 1
            if count == self._FIRST_FLUSH_SIZE or count % self._INSERT_BATCH_SIZE == 0:
                session.execute(table.insert().values(batch))
                self.logger.info('Saved %s so far: %d', label, count)
                batch = []
        if batch:
            session.execute(table.insert().values(batch))
        self.logger.info('Saved %s: %d', label, count)
        return count

    def renumber_and_save(self, session, fps_to_keep, fbs, source_id):
        """Store copies of ``fbs`` and the flyouts ``fps_to_keep`` in one transaction.

        Flight bases get ids above the current maximum ExistingFlightBase id and the
        flyouts are re-pointed at them. Rows created too long ago are deleted before
        committing. On any error the transaction is rolled back and the original
        exception is re-raised with its traceback intact.
        """
        current_fb_id = session.query(sa.func.max(sa.func.coalesce(ExistingFlightBase.id, 0))).scalar()
        self.logger.info("Current max flight base id: %s", current_fb_id)
        if not current_fb_id:
            current_fb_id = 0

        fps_to_keep, flight_bases_to_save = self.renumber(fps_to_keep, fbs, current_fb_id, source_id)

        self.logger.info(
            "Prepared %d flight-bases and %d flight-patterns to store",
            len(flight_bases_to_save),
            len(fps_to_keep),
        )
        try:
            self.logger.info("Storing flight-bases")
            self._insert_in_batches(
                session, ExistingFlightBase.__table__, flight_bases_to_save.values(), 'flight bases')

            self.logger.info("Storing flight-patterns")
            self._insert_in_batches(
                session, ExistingFlightPattern.__table__, fps_to_keep, 'flight patterns')

            # Clean old records.
            # Strictly, flight bases should be deleted by id (derived from the flight
            # patterns being deleted). Deleting by created_at is still safe as long as
            # a flight pattern is deleted no later than its flight base, which the
            # one-day-older cutoff for bases guarantees; it is also much faster in SQL.
            old_date_for_patterns = self._start_date - timedelta(days=2)
            old_date_for_bases = self._start_date - timedelta(days=3)
            session.execute(ExistingFlightPattern.__table__.delete().where(ExistingFlightPattern.created_at < old_date_for_patterns))
            session.execute(ExistingFlightBase.__table__.delete().where(ExistingFlightBase.created_at < old_date_for_bases))

            session.commit()
        except Exception:
            session.rollback()
            # Bare raise keeps the original traceback (``raise e`` drops it on Python 2).
            raise

    def renumber(self, fps_to_keep, fbs, current_fb_id, source_id):
        """Copy ``fbs`` into new ExistingFlightBase objects with sequential ids.

        Ids start at ``current_fb_id + 1``. Each flyout's ``flight_base_id`` is
        rewritten to the new id. Flyouts whose flight base is not among ``fbs`` are
        dropped with a warning instead of aborting the whole save.

        :returns: tuple (list of kept flyouts, dict old flight base id => new ExistingFlightBase).
        """
        flight_bases_to_save = {}
        for fb in fbs:
            current_fb_id += 1
            new_fb = ExistingFlightBase()
            new_fb.id = current_fb_id
            for field in self._FLIGHT_BASE_FIELDS:
                setattr(new_fb, field, getattr(fb, field))
            new_fb.created_at = self._now
            new_fb.source = source_id
            flight_bases_to_save[fb.id] = new_fb

        kept_fps = []
        orphan_count = 0
        for fp in fps_to_keep:
            new_fb = flight_bases_to_save.get(fp.flight_base_id)
            if new_fb is None:
                orphan_count += 1
                continue
            fp.flight_base_id = new_fb.id
            kept_fps.append(fp)
        if orphan_count:
            self.logger.warning(
                'Dropping %d flight patterns whose flight base was not found', orphan_count)

        return kept_fps, flight_bases_to_save
