from sqlalchemy import text
from sqlalchemy.orm import Session

from travel.avia.flight_extras.application.models import Flight, FlightPassengerExperience, Source


class FlightPassengerExperienceUpdater(object):
    """Buffer FlightPassengerExperience objects and write them to the database in batches.

    Every buffered experience references a Flight identified by
    (company_iata, number). On flush, flights that have no id yet are looked
    up in the flights table, inserted if still missing, and the resolved ids
    are copied onto the buffered experiences before those are written.

    The buffer is flushed automatically whenever ``count`` reaches a multiple
    of ``BUFFER_SIZE``; callers must call ``flush()`` once more after the last
    ``add_flight_passenger_experience`` to write the remainder. This class
    never commits: the caller owns the session's transaction.
    """

    BUFFER_SIZE = 10000

    def __init__(self, session):
        # type: (Session) -> None
        self._session = session
        # Flight.key -> Flight. Kept across flushes (only the experiences are
        # cleared), so ids resolved once are reused instead of re-queried.
        self._flights = {}
        # FlightPassengerExperience.key -> experience waiting for the next flush.
        self._flight_passenger_experiences = {}
        # Number of experiences accepted so far; never reset by flush().
        self._count = 0

    def add_flight_passenger_experience(self, pe):
        # type: (FlightPassengerExperience) -> None
        """Buffer ``pe`` unless an experience with the same key is already pending.

        Triggers ``flush()`` each time the accepted count hits a multiple of
        ``BUFFER_SIZE``.
        """
        if pe.key in self._flight_passenger_experiences:
            return

        self._add_flight(pe.flight)
        self._flight_passenger_experiences[pe.key] = pe
        self._count += 1
        if self._count % self.BUFFER_SIZE == 0:
            self.flush()

    def flush(self):
        """Write pending flights and experiences, then empty the experience buffer.

        Does nothing when no experiences are pending.
        """
        if self._flight_passenger_experiences:
            self._flush_flights()
            self._flush_flight_passenger_experiences()
            self._flight_passenger_experiences.clear()

    @property
    def count(self):
        """Total number of experiences accepted since construction."""
        return self._count

    def _add_flight(self, flight):
        # type: (Flight) -> None
        """Remember ``flight`` unless one with the same key is already known.

        The first instance per key wins; experiences carrying another instance
        with the same key are later given the id of the stored one.
        """
        if flight.key not in self._flights:
            self._flights[flight.key] = flight

    def _pending_flights(self):
        # type: () -> list
        """Return the known flights whose id has not been resolved yet."""
        return [f for f in self._flights.values() if not f.id]

    @staticmethod
    def _bind_flight_pairs(flights, template):
        # type: (list, str) -> tuple
        """Render one SQL fragment per flight using named bind placeholders.

        ``template`` is formatted with ``iata`` and ``number`` set to the
        placeholder references (``:iata_<i>`` / ``:number_<i>``).
        Returns ``(fragments, params)`` ready for ``text(...)`` execution.
        Binding instead of inlining keeps quotes or colons in the data from
        breaking (or injecting into) the statement.
        """
        fragments = []
        params = {}
        for index, flight in enumerate(flights):
            iata_name = 'iata_{}'.format(index)
            number_name = 'number_{}'.format(index)
            fragments.append(template.format(iata=':' + iata_name, number=':' + number_name))
            # Bind exactly the text the previous inline '...' literal contained,
            # so non-string values are still compared/inserted as text.
            params[iata_name] = '{}'.format(flight.company_iata)
            params[number_name] = '{}'.format(flight.number)
        return fragments, params

    def _load_flights_ids(self):
        """Fill in ``id`` for known flights lacking one, from the flights table."""
        pending = self._pending_flights()
        if not pending:
            return

        conditions, params = self._bind_flight_pairs(
            pending, '(company_iata={iata} and number={number})',
        )
        result = self._session.execute(
            text('select id, company_iata, number from {} where {}'.format(
                Flight.__tablename__,
                ' or '.join(conditions),
            )),
            params,
        )
        for pk, iata, number in result:
            # NOTE(review): assumes Flight.key is '<company_iata> <number>'; confirm in the model.
            key = '{} {}'.format(iata, number)
            if key in self._flights:
                self._flights[key].id = pk

    def _flush_flights(self):
        """Make sure every known flight has an id, inserting the missing rows."""
        self._load_flights_ids()

        missing = self._pending_flights()
        if not missing:
            return

        rows, params = self._bind_flight_pairs(missing, '({iata}, {number})')
        self._session.execute(
            text('insert into {} (company_iata, number) values {}'.format(
                Flight.__tablename__,
                ', '.join(rows),
            )),
            params,
        )

        # Read back the ids the database assigned to the rows just inserted.
        self._load_flights_ids()

    def _flush_flight_passenger_experiences(self):
        """Attach resolved flight ids to buffered experiences and write them in one statement.

        If a flight's id is still unresolved here (e.g. the key read back from
        the database does not match Flight.key), ``flight_id`` is set to that
        unresolved value.
        """
        for exp in self._flight_passenger_experiences.values():
            exp.flight_id = self._flights[exp.flight.key].id

        self._session.execute(
            FlightPassengerExperience.get_multi_insert_sql(
                self._flight_passenger_experiences.values(),
                self._update_where(),
            ),
        )

    @staticmethod
    def _update_where():
        """Return the SQL condition handed to get_multi_insert_sql as its update filter.

        An existing row qualifies when its source's name starts with
        ``Source.FLIGHT_STATS_PREFIX`` or when its aircraft equals the incoming
        ``EXCLUDED.aircraft``. ``EXCLUDED`` presumably refers to PostgreSQL's
        INSERT ... ON CONFLICT pseudo-table; confirm in get_multi_insert_sql.
        The prefix is a class constant, so it is safe to inline here.
        """
        return """
               {table}.source_id in (select id from {source_table} where name like '{fs_prefix}%')
                  or {table}.aircraft=EXCLUDED.aircraft
               """.format(
            table=FlightPassengerExperience.__tablename__,
            source_table=Source.__tablename__,
            fs_prefix=Source.FLIGHT_STATS_PREFIX,
        )
