from __future__ import unicode_literals

from six.moves.urllib.parse import urlencode
from travel.avia.library.python.shared_dicts.common import get_binary_file
from travel.avia.library.python.shared_dicts.rasp import get_repository, ResourceType
from travel.avia.shared_flights.data_importer.existing_flights_loader import ExistingFlightsLoader
from travel.avia.shared_flights.data_importer.parse_args import parse_args
from travel.proto.dicts.rasp.carrier_pb2 import TCarrier
from travel.proto.dicts.rasp.transport_pb2 import TTransport
from travel.proto.shared_flights.snapshots.station_with_codes_pb2 import TStationWithCodes
from travel.proto.shared_flights.ssim.flights_pb2 import TFlightBase, TFlightPattern

import io
import json
import logging
import os
import struct
import tempfile
import zipfile
from collections import defaultdict
from datetime import datetime, timedelta
from sandbox.common import rest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from travel.avia.shared_flights.data_importer.date_shift_calculator import DateShiftCalculator
from travel.avia.shared_flights.data_importer.latest_ready_resource import LatestReadyResource
from travel.avia.shared_flights.data_importer.sirena_airlines_repository import SirenaAirlinesRepository
from travel.avia.shared_flights.data_importer.sirena_flights_parser import SirenaFlightsParser
from travel.avia.shared_flights.data_importer.sirena_routes_repository import SirenaRoutesRepository
from travel.avia.shared_flights.data_importer.storage.carrier import CarrierStorage
from travel.avia.shared_flights.data_importer.storage.missing_data import MissingDataStorage, InvalidDataEmailDumper, \
    ExtraDataStorage, flight_base_route
from travel.avia.shared_flights.data_importer.storage.station_codes import StationCodesStorage
from travel.avia.shared_flights.data_importer.storage.stations import StationStorage
from travel.avia.shared_flights.lib.python.db_locks.db_lock_handle import DbLockHandle
from travel.avia.shared_flights.lib.python.db_models.base import CustomEncoder
from travel.avia.shared_flights.lib.python.db_models.carrier import Carrier
from travel.avia.shared_flights.lib.python.db_models.db_lock import DbLockType
from travel.avia.shared_flights.lib.python.db_models.flight_base import FlightBase, SirenaFlightBase
from travel.avia.shared_flights.lib.python.db_models.flight_pattern import FlightPattern, SirenaFlightPattern
from travel.avia.shared_flights.lib.python.db_models.last_imported_info import LastImportedInfo
from travel.avia.shared_flights.lib.python.db_models.station import Station
from travel.avia.shared_flights.lib.python.db_models.stop_point import StopPoint
from travel.avia.shared_flights.lib.python.db_models.timezone import Timezone
from travel.avia.shared_flights.lib.python.db_models.transport_model import TransportModel
from travel.avia.shared_flights.lib.python.settings import (
    PGAAS_CLUSTER_ID,
    PGAAS_DATABASE_NAME,
    PGAAS_USER,
    PGAAS_PASSWORD,
    PGAAS_PORT,
    SSL_ROOT_CERT,
    ENVIRONMENT,
    SOLOMON_TOKEN,
)
from travel.avia.shared_flights.tasks.monitoring.db import collect_and_send
from travel.avia.shared_flights.tasks.monitoring.helpers import solomon_push_reporter_maker, UnsupportedEnvironmentError
from travel.library.python.dicts.transport_model_repository import TransportModelRepository
from travel.library.python.sender import TransactionalApi
from typing import Dict

# Names of the sandbox resource types that carry the flight data. The Amadeus and Sirena
# names are parameterised by the current ENVIRONMENT value; the production Sirena name is
# fixed so that it can be referenced from a non-production environment as well.
AMADEUS_FLIGHTS_RESOURCE_TYPE = 'AVIA_SHARED_FLIGHTS_AMADEUS_{}_RESOURCE'.format(ENVIRONMENT)
SIRENA_FLIGHTS_RESOURCE_TYPE = 'AVIA_SHARED_FLIGHTS_SIRENA_{}_RESOURCE'.format(ENVIRONMENT)
SIRENA_FLIGHTS_PRODUCTION_RESOURCE_TYPE = 'AVIA_SHARED_FLIGHTS_SIRENA_PRODUCTION_RESOURCE'
# Environment-specific marker name (not referenced in the visible part of this module).
EXTRA_DATA_COMPUTED = 'EXTRA_DATA_COMPUTED_{}'.format(ENVIRONMENT)

# E-mail sender settings, each read from the process environment with a fallback default.
# SENDER_AUTH_KEY has no default and is None when the variable is not set.
SENDER_HOST = os.getenv('AVIA_SENDER_HOST', 'https://test.sender.yandex-team.ru')
SENDER_AUTH_KEY = os.getenv('AVIA_SENDER_AUTH_KEY')
SENDER_ACCOUNT_SLUG = os.getenv('AVIA_SENDER_ACCOUNT_SLUG', 'ya.tickets')
SENDER_MISSING_DATA_CAMPAIGN_SLUG = os.getenv('AVIA_SENDER_MISSING_DATA_CAMPAIGN_SLUG', 'SB9UCGQ3-YL02')
# True only when the variable is set to 'true' (compared case-insensitively).
SENDER_FOR_TESTING = os.getenv('AVIA_SENDER_FOR_TESTING', '').lower() == 'true'
# Comma-separated list of recipients; may contain empty strings, which callers skip.
MISSING_DATA_ALERT_EMAILS = os.getenv('AVIA_DATA_IMPORTER_MISSING_DATA_ALERT_EMAILS',
                                      'avia-alerts@yandex-team.ru').split(',')


class FlightsDataImporter(object):

    def __init__(self, logger, start_date=None, oauth_token=None):
        """Set up the importer state and the e-mail reporting helpers.

        :param logger: logger used for all progress and error messages.
        :param start_date: accepted for interface compatibility; not used by this constructor.
        :param oauth_token: token handed over to the resource/dictionary fetchers.
        """
        self.logger: logging.Logger = logger
        self.oauth_token = oauth_token
        # Flipped to True once any flight source has actually been re-imported.
        self._reimported_ = False

        # Reference dictionaries start empty and are filled in by fetch_rasp_dicts().
        self.carriers: CarrierStorage = None
        self.station_codes = None
        self.timezones = None
        self.stations: StationStorage = None

        # Collectors of invalid data plus the client used to mail the report.
        missing_data_storage = MissingDataStorage()
        extra_data_storage = ExtraDataStorage()
        sender_client = TransactionalApi(
            host=SENDER_HOST,
            auth=(SENDER_AUTH_KEY, ''),
            account_slug=SENDER_ACCOUNT_SLUG,
            campaign_slug=SENDER_MISSING_DATA_CAMPAIGN_SLUG,
        )
        self.missing_data = missing_data_storage
        self.extra_data = extra_data_storage
        self.email_client = sender_client
        self.invalid_data_email_dumper = InvalidDataEmailDumper(
            missing_data_storage,
            extra_data_storage,
            sender_client,
            ENVIRONMENT
        )

    def import_flights_data(self, args):
        """Run every enabled import stage and then mail the invalid-data report.

        The stages are executed in a fixed order: Amadeus, Sirena, stop points,
        transport models, rasp dictionaries (carriers, timezones, stations).

        :param args: parsed command-line arguments. The ``no*`` attributes are tested for
            truthiness: a stage runs when its flag is truthy (presumably the flags are
            declared as "store_false" options in parse_args - confirm there). Also read
            here: ``echosql`` and ``forceemail``.
        :raises Exception: when ENVIRONMENT is neither TESTING nor PRODUCTION, or when the
            DB engine/session cannot be created; errors of the import stages propagate.
        """
        self.logger.info('Importing flights data')

        if ENVIRONMENT not in ['TESTING', 'PRODUCTION']:
            raise Exception('Please specify YANDEX_ENVIRONMENT_TYPE to be TESTING or PRODUCTION')

        # connect to database
        conn_params = {
            'port': PGAAS_PORT,
            'sslmode': 'require',
            'target_session_attrs': 'read-write',
            'sslrootcert': SSL_ROOT_CERT,
        }
        conn_string_for_work = 'postgresql+psycopg2://{user}:{password}@{hosts}/{database}?{query_string}'.format(
            user=PGAAS_USER,
            password=PGAAS_PASSWORD,
            database=PGAAS_DATABASE_NAME,
            hosts='c-{cluster_id}.rw.db.yandex.net'.format(cluster_id=PGAAS_CLUSTER_ID),
            query_string=urlencode(conn_params),
        )

        db_engine = create_engine(conn_string_for_work, echo=args.echosql)
        if not db_engine:
            raise Exception('Unable to create DB engine')
        # NOTE(review): SQLAlchemy documents Engine.execution_options() as returning a new
        # engine object; the result is discarded here, so this call looks like a no-op.
        # Left as is to keep the current behaviour - confirm whether streaming is wanted.
        db_engine.execution_options(stream_results=True)

        session_factory = sessionmaker(bind=db_engine)

        # Metrics are best-effort: a failure here must not prevent the import itself.
        # Only Exception is caught, so that Ctrl-C / SystemExit still stop the process.
        try:
            solomon_push_api = solomon_push_reporter_maker(ENVIRONMENT, SOLOMON_TOKEN)
            collect_and_send(session_factory, solomon_push_api)
        except UnsupportedEnvironmentError:
            self.logger.warning(
                'Sending metrics is not supported for %s environment',
                ENVIRONMENT,
            )
        except Exception:
            self.logger.exception('Unexpected error while sending metrics to solomon')

        # Sanity check that sessions can be produced; every stage opens its own session,
        # so the probe is released right away instead of being left open.
        session = session_factory()
        if not session:
            raise Exception('Unable to create DB session')
        session.close()

        if args.noamadeus:
            self.logger.info('Importing Amadeus data')
            self.import_amadeus_data(session_factory, args)
        if args.nosirena:
            self.logger.info('Importing data from Sirena')
            self.import_sirena_data(session_factory, args)
        if args.nostoppoints:
            self.logger.info('Importing unknown stop points from the flight sources')
            self.import_stop_points(session_factory, args)
        if args.notransportmodels:
            self.logger.info('Importing transport models from the sandbox resource')
            self.import_transport_models(session_factory, args)
        if args.noraspdata:
            self.logger.info('Importing carriers, stations and timezones from rasp sandbox resources')
            self.import_carriers(session_factory)
            self.import_timezones(session_factory)
            self.import_stations(session_factory)
        if args.noemail and (self._reimported_ or args.forceemail):
            # Extra data is optional for the report, so a failure is only logged.
            try:
                self.extra_data.collect_if_needed(session_factory)
            except Exception:
                self.logger.exception('Cannot collect extra data')
            for email in MISSING_DATA_ALERT_EMAILS:
                if not email:
                    continue
                self.invalid_data_email_dumper.send(email, SENDER_FOR_TESTING)
        else:
            self.logger.info(
                'Emails were not sent: forced not to send (%s) or no new schedule data (%s) and not forced to send (%s)',
                not args.noemail,
                not self._reimported_,
                not args.forceemail,
            )

    def import_amadeus_data(self, session_factory, args):
        """Reload the Amadeus flight bases and flight patterns from the latest resource.

        The import is skipped (with a log message) when the latest resource cannot be
        fetched, has no download url or attributes, or has already been imported and
        ``args.forceamadeus`` is not set. Otherwise both tables are truncated, refilled
        from temp files and the LastImportedInfo row for the Amadeus resource type is
        updated.

        :param session_factory: callable returning a new SQLAlchemy session.
        :param args: parsed command-line arguments; ``force`` and ``forceamadeus`` are read.
        :raises Exception: any error of the import itself is logged and re-raised after
            the session has been rolled back.
        """
        # The lock handle is what detects another import in progress.
        timeout = 10000  # seconds
        session = session_factory()
        with DbLockHandle(session, DbLockType.DBLOCK_IMPORT_AMADEUS, self.logger, args.force, timeout):
            with LatestReadyResource([AMADEUS_FLIGHTS_RESOURCE_TYPE], self.logger) as last_resource:
                if not last_resource:
                    self.logger.warning('Unable to fetch the latest Amadeus resource. Result: %s.', last_resource)
                    return

                last_resource_proxy_url = last_resource.get_resource_proxy_url()
                if not last_resource_proxy_url:
                    self.logger.warning('The latest Amadeus resource does not have url to download. Result: %r.',
                                        last_resource)
                    return

                last_resource_attributes = last_resource.get_attributes()
                if not last_resource_attributes:
                    self.logger.warning('The latest Amadeus resource does not have any attributes. Result: %r.',
                                        last_resource_attributes)
                    return

                # Find out which resource was imported the last time.
                last_imported = session.query(LastImportedInfo).filter(
                    LastImportedInfo.resource_type == AMADEUS_FLIGHTS_RESOURCE_TYPE).one_or_none()
                if (
                    last_imported and
                    last_imported.imported_resource_id == last_resource.get_id() and
                    not args.forceamadeus
                ):
                    self.logger.info('Resource %r has been already imported into the database.', last_resource)
                    return
                if args.forceamadeus:
                    self.logger.info('Amadeus parser is forced')
                # There is no previous record on the very first import, so the log line
                # about it must be guarded (same as in the Sirena import).
                if last_imported:
                    self.logger.info(
                        'Last resource used for Amadeus import: %r, created at: %r, data mark: %r.',
                        last_imported.imported_resource_id,
                        last_imported.created_at,
                        last_imported.imported_date,
                    )
                self.logger.info(
                    'New resource to be used for Amadeus import: %r, created at: %r, data mark: %r.',
                    last_resource.get_id(),
                    last_resource_attributes.get('created_at'),
                    last_resource_attributes.get('source_data_date'),
                )

                try:
                    # Fetch rasp reference files
                    self.fetch_rasp_dicts()

                    # Fetch flight bases
                    flight_bases: Dict[int, TFlightBase] = self.fetch_flight_bases(
                        self.carriers,
                        self.stations,
                        last_resource_proxy_url,
                    )

                    # Fetch flight patterns
                    flight_patterns = self.fetch_flight_patterns(
                        self.carriers,
                        flight_bases,
                        last_resource_proxy_url,
                    )

                    existing_flights_loader = ExistingFlightsLoader(self.logger)
                    existing_flights_loader.update_existing_amadeus_flights(session_factory())

                    flight_patterns_temp_file = self.save_flight_patterns_to_temp_file(flight_patterns)
                    flight_bases_temp_file = self.save_flight_bases_to_temp_file(flight_bases)

                    # Let GC free some memory
                    del flight_bases
                    del flight_patterns

                    # Push temp files into the database.
                    # Truncate tables first.
                    session.close()
                    session = session_factory()

                    self.logger.info('About to clean the flight bases and flight patterns from the database.')
                    cursor = session.bind.raw_connection().cursor()
                    cursor.execute('SET LOCAL lock_timeout = 60000')
                    cursor.execute('TRUNCATE {}'.format(FlightBase.__tablename__))
                    cursor.execute('TRUNCATE {}'.format(FlightPattern.__tablename__))

                    self.logger.info('Truncated the flight bases and flight patterns database tables.')

                    # Now push the files to the database
                    fb_count = self.push_text_file_to_table(cursor, flight_bases_temp_file, FlightBase.__tablename__)
                    self.logger.info('New Amadeuse\'s flight bases count %d', fb_count)
                    fp_count = self.push_text_file_to_table(cursor, flight_patterns_temp_file, FlightPattern.__tablename__)
                    self.logger.info('New Amadeuse\'s flight patterns count %d', fp_count)

                    self.logger.info('Updated Amadeuse\'s flight bases/flight patterns tables')

                    if not last_imported:
                        last_imported = LastImportedInfo()
                        last_imported.created_at = datetime.now()

                    last_imported.updated_at = last_resource_attributes.get('created_at')
                    last_imported.imported_date = last_resource_attributes.get('source_data_date')
                    last_imported.imported_resource_id = last_resource.get_id()
                    last_imported.resource_type = AMADEUS_FLIGHTS_RESOURCE_TYPE
                    session.merge(last_imported)
                    session.commit()
                    self._reimported_ = True
                except Exception:
                    self.logger.exception('Amadeus data import has been aborted')
                    # Never persist a half-finished bookkeeping update after a failure.
                    session.rollback()
                    raise
                finally:
                    # Release the session on both the success and the failure path.
                    session.close()

    def import_sirena_data(self, session_factory, args):
        """Reload the Sirena flight bases and flight patterns from the latest resource.

        The import is skipped (with a log message) when the latest resource cannot be
        fetched, has no download url or attributes, or has already been imported and
        ``args.forcesirena`` is not set. Otherwise both Sirena tables are truncated,
        refilled from temp files, the LastImportedInfo row is updated and the unknown
        stop points found by the parser are pushed to their table.

        :param session_factory: callable returning a new SQLAlchemy session.
        :param args: parsed command-line arguments; ``force`` and ``forcesirena`` are read.
        :raises Exception: any error of the import itself is logged, the session is
            rolled back and the error is re-raised.
        """
        # test if another import is in progress
        timeout = 900  # seconds
        session = session_factory()
        with DbLockHandle(session, DbLockType.DBLOCK_IMPORT_SIRENA, self.logger, args.force, timeout):
            # grab out when was the last time we've imported data
            resource_types = [SIRENA_FLIGHTS_RESOURCE_TYPE]
            # In TESTING the production Sirena resource type is accepted as well.
            if ENVIRONMENT == 'TESTING':
                resource_types.append(SIRENA_FLIGHTS_PRODUCTION_RESOURCE_TYPE)
            with LatestReadyResource(resource_types, self.logger) as last_resource:
                if not last_resource:
                    self.logger.warning('Unable to fetch the latest Sirena resource. Result: %s.', last_resource)
                    return

                last_resource_proxy_url = last_resource.get_resource_proxy_url()
                if not last_resource_proxy_url:
                    self.logger.warning('The latest Sirena resource does not have url to download. Result: %r.',
                                        last_resource)
                    return

                last_resource_attributes = last_resource.get_attributes()
                if not last_resource_attributes:
                    self.logger.warning('The latest Sirena resource does not have any attributes. Result: %r.',
                                        last_resource_attributes)
                    return

                # The bookkeeping row is always looked up under the environment-specific
                # Sirena resource type, whichever resource type was actually found above.
                last_imported = session.query(LastImportedInfo).filter(
                    LastImportedInfo.resource_type == SIRENA_FLIGHTS_RESOURCE_TYPE).one_or_none()
                if (
                    last_imported and
                    last_imported.imported_resource_id == last_resource.get_id() and
                    not args.forcesirena
                ):
                    self.logger.info('Resource %r has been already imported into the database.', last_resource)
                    return

                if args.forcesirena:
                    self.logger.info('Sirena parser is forced')

                # There is no previous record on the very first import.
                if last_imported:
                    self.logger.info(
                        'Last resource used for Sirena import: %r, created at: %r, data mark: %r.',
                        last_imported.imported_resource_id,
                        last_imported.created_at,
                        last_imported.imported_date,
                    )
                self.logger.info(
                    'New resource to be used for Sirena import: %r, created at: %r, data mark: %r.',
                    last_resource.get_id(),
                    last_resource_attributes.get('created_at'),
                    last_resource_attributes.get('source_data_date'),
                )

                try:
                    # Fetch rasp reference files
                    self.fetch_rasp_dicts()

                    # Fetch sirena flights
                    flight_bases, flight_patterns, unknown_stop_points = self.fetch_sirena_flights(
                        self.carriers.by_id.values(),
                        self.stations,
                        self.timezones,
                        last_resource_proxy_url,
                    )
                    existing_flights_loader = ExistingFlightsLoader(self.logger)
                    existing_flights_loader.update_existing_sirena_flights(session_factory())

                    flight_bases_temp_file = self.save_flight_bases_to_temp_file(flight_bases)
                    flight_patterns_temp_file = self.save_flight_patterns_to_temp_file(flight_patterns)

                    # Push temp files into the database.
                    # Truncate tables first.
                    # A fresh session replaces the one used for the queries above.
                    session.close()
                    session = session_factory()

                    self.logger.info('About to clean the flight bases and flight patterns from the database.')
                    # The bulk statements go through a raw DBAPI cursor taken from the session's engine.
                    cursor = session.bind.raw_connection().cursor()
                    cursor.execute('SET LOCAL lock_timeout = 60000')
                    cursor.execute('TRUNCATE {}'.format(SirenaFlightBase.__tablename__))
                    cursor.execute('TRUNCATE {}'.format(SirenaFlightPattern.__tablename__))

                    self.logger.info('Truncated Sirena\'s flight bases/flight patterns tables.')

                    # Now push the files to the database
                    fb_count = self.push_text_file_to_table(cursor, flight_bases_temp_file, SirenaFlightBase.__tablename__)
                    self.logger.info('New Sirena\'s flight bases count %d', fb_count)
                    fp_count = self.push_text_file_to_table(cursor, flight_patterns_temp_file, SirenaFlightPattern.__tablename__)
                    self.logger.info('New Sirena\'s flight patterns count %d', fp_count)

                    self.logger.info('Updated Sirena\'s flight bases/flight patterns tables')

                    # Create the bookkeeping row on the very first import.
                    if not last_imported:
                        last_imported = LastImportedInfo()
                        last_imported.created_at = datetime.now()

                    last_imported.updated_at = datetime.now()
                    last_imported.imported_date = last_resource_attributes.get('created_at')
                    last_imported.imported_resource_id = last_resource.get_id()
                    last_imported.resource_type = SIRENA_FLIGHTS_RESOURCE_TYPE
                    session.merge(last_imported)
                    session.commit()

                    # Stop points reported by the parser are pushed through a separate session.
                    self.push_unknown_stop_points_to_table(unknown_stop_points, session_factory(), add_only=True)
                    self._reimported_ = True
                except Exception as e:
                    self.logger.exception('Data import from Sirena has been aborted')
                    session.rollback()
                    raise e
                finally:
                    session.close()

    def import_stop_points(self, session_factory, args):
        """Refresh the set of stop points that cannot be matched to a known station.

        Stop points already stored in the database are loaded, every route point from
        the ``flight_status`` table that matches no station (by IATA code, Sirena code
        or default title) is added, entries that have become known in the meantime are
        dropped, and the result is passed to ``push_unknown_stop_points_to_table``.

        :param session_factory: callable returning a new SQLAlchemy session.
        :param args: parsed command-line arguments; only ``force`` is read.
        :raises Exception: any error is logged, the session is rolled back and the
            error is re-raised.
        """
        # The lock handle is what detects another import in progress.
        timeout = 1200  # seconds
        session = session_factory()
        with DbLockHandle(session, DbLockType.DBLOCK_IMPORT_STOP_POINTS, self.logger, args.force, timeout):
            try:
                # Fetch rasp reference files
                self.fetch_rasp_dicts()

                # Stop points without a code cannot be matched and are ignored.
                stop_points_by_code = {
                    stop_point.station_code: stop_point
                    for stop_point in session.query(StopPoint).all()
                    if stop_point.station_code
                }

                # Positions of the leg identification columns in the query below;
                # columns 0-3 hold the route points themselves.
                marketing_carrier_column = 4
                marketing_flight_number_column = 5
                leg_number_column = 6
                route_point_columns = 4
                flight_statuses = session.execute(
                    '''
                    select
                        departureroutepointfrom,
                        departureroutepointto,
                        arrivalroutepointfrom,
                        arrivalroutepointto,
                        airlineid,
                        flightnumber,
                        legnumber
                    from
                        flight_status
                    '''
                )

                # Collect the Sirena codes and default titles once, so that each route
                # point costs a set lookup instead of a scan over all the stations.
                stations_by_iata = self.stations.by_iata
                sirena_codes_and_titles = set()
                for station in self.stations.by_id.values():
                    sirena_codes_and_titles.add(station.SirenaCode)
                    sirena_codes_and_titles.add(station.Station.TitleDefault)

                def is_known_station(code):
                    return code in stations_by_iata or code in sirena_codes_and_titles

                for routepoints in flight_statuses:
                    for index in range(route_point_columns):
                        routepoint = routepoints[index]
                        if not routepoint:
                            continue
                        if routepoint in stop_points_by_code or is_known_station(routepoint):
                            continue
                        unknown_point = StopPoint()
                        unknown_point.station_code = routepoint
                        unknown_point.unknown_since = datetime.now()
                        unknown_point.leg_key = '{}.{}.{}'.format(
                            routepoints[marketing_carrier_column],
                            routepoints[marketing_flight_number_column],
                            routepoints[leg_number_column],
                        )
                        stop_points_by_code[routepoint] = unknown_point

                # Stop points stored earlier may have become known stations since then.
                stop_points_by_code = {
                    code: stop_point
                    for code, stop_point in stop_points_by_code.items()
                    if not is_known_station(code)
                }

                self.push_unknown_stop_points_to_table(stop_points_by_code, session_factory())
            except Exception:
                self.logger.exception('Stop points import has been aborted')
                session.rollback()
                raise
            finally:
                session.close()

    def import_transport_models(self, session_factory, args):
        """Load the transport models and store them in the database.

        The models come from the file named by ``args.transportmodelsfile`` when it is
        set and from the sandbox resource otherwise.

        :param session_factory: callable returning a new SQLAlchemy session.
        :param args: parsed command-line arguments; ``force`` and
            ``transportmodelsfile`` are read.
        :raises Exception: any error is logged, the session is rolled back and the
            error is re-raised.
        """
        lock_timeout = 60  # seconds
        session = session_factory()
        # Entering the lock handle is what detects another import in progress.
        import_lock = DbLockHandle(
            session,
            DbLockType.DBLOCK_IMPORT_TRANSPORT_MODELS,
            self.logger,
            args.force,
            lock_timeout,
        )
        with import_lock:
            try:
                models_file = args.transportmodelsfile
                if not models_file:
                    models_by_id = self.import_transport_models_from_sandbox()
                else:
                    models_by_id = self.import_transport_models_from_file(models_file)

                # NOTE(review): the sandbox loader may return None; it is passed on as is.
                self.push_transport_models_to_table(models_by_id, session)
            except Exception as e:
                self.logger.exception('Transport models import has been aborted')
                session.rollback()
                raise e
            finally:
                session.close()

    def import_transport_models_from_sandbox(self):
        """Read plane and helicopter transport models from the rasp dictionary resource.

        Models having neither ``Code`` nor ``CodeEn`` are skipped.

        :return: dict mapping model id to the model, or None when the repository
            could not be fetched.
        """
        repository = get_repository(
            ResourceType.TRAVEL_DICT_RASP_TRANSPORT_MODEL_PROD,
            oauth=self.oauth_token,
        )
        if not repository:
            self.logger.error('Unable to fetch the transport models data.')
            return None

        air_transport_types = (TTransport.EType.TYPE_PLANE, TTransport.EType.TYPE_HELICOPTER)
        transport_models = {
            model.Id: model
            for model in repository.itervalues()
            if (model.Code or model.CodeEn) and model.TransportType in air_transport_types
        }

        self.logger.info('Parsed transport models from sandbox: %s', format_number(len(transport_models)))
        return transport_models

    def import_transport_models_from_file(self, file_name):
        """Read plane and helicopter transport models from a local binary dump.

        Models having neither ``Code`` nor ``CodeEn`` are skipped.

        :param file_name: path of the file whose content is handed to
            ``TransportModelRepository.load_from_string``.
        :return: dict mapping model id to the model.
        """
        with open(file_name, mode='rb') as dump_file:
            serialized_models = dump_file.read()

        repository = TransportModelRepository()
        repository.load_from_string(serialized_models)

        air_transport_types = (TTransport.EType.TYPE_PLANE, TTransport.EType.TYPE_HELICOPTER)
        transport_models = {
            model.Id: model
            for model in repository.itervalues()
            if (model.Code or model.CodeEn) and model.TransportType in air_transport_types
        }

        self.logger.info('Parsed transport models from %s: %s', file_name, format_number(len(transport_models)))
        return transport_models

    def fetch_timezones(self, sandbox):
        """Build an id -> code mapping from the rasp timezone dictionary.

        Timezones with an empty ``Code`` are left out.

        :param sandbox: not used by this method.
        :return: dict mapping timezone id to timezone code, or None when the
            repository could not be fetched.
        """
        repository = get_repository(
            ResourceType.TRAVEL_DICT_RASP_TIMEZONE_PROD,
            oauth=self.oauth_token,
        )
        if not repository:
            self.logger.error('Unable to fetch the time zones data.')
            return None

        code_by_id = {tz.Id: tz.Code for tz in repository.itervalues() if tz.Code}

        self.logger.info('Parsed time zones: %s', format_number(len(code_by_id)))
        return code_by_id

    def fetch_rasp_dicts(self):
        """Fill in whichever rasp reference dictionaries are still empty.

        Every dictionary is fetched only while its attribute is falsy, so repeated
        calls reuse what has been loaded already. The station codes are resolved
        before the stations because the station storage is built on top of them.
        """
        sandbox_client = rest.Client()

        self.carriers = self.carriers or CarrierStorage(
            self.logger, sandbox_client, oauth_token=self.oauth_token)

        self.station_codes = self.station_codes or StationCodesStorage(
            self.logger, sandbox_client, oauth_token=self.oauth_token)

        self.timezones = self.timezones or self.fetch_timezones(sandbox_client)

        self.stations = self.stations or StationStorage(
            self.logger, sandbox_client, self.station_codes, oauth_token=self.oauth_token)

    def fetch_flight_bases(self, carriers: CarrierStorage, stations: StationStorage, last_resource_proxy_url):
        """Download and parse ``flight_bases.bin`` of the given resource.

        The file is read as a sequence of records, each made of a 4-byte little-endian
        signed length followed by that many bytes of a serialized TFlightBase. A record
        is kept only when its operating carrier and both of its stations are known;
        unknown carriers and stations are reported to ``self.missing_data``.

        :param carriers: storage used to resolve the operating carrier by code.
        :param stations: storage used to resolve both stations by IATA code.
        :param last_resource_proxy_url: base url of the resource to download from.
        :return: dict mapping flight base id to the TFlightBase with the carrier and
            station ids filled in.
        :raises Exception: when the file cannot be loaded or a record has no
            operating carrier.
        """
        self.logger.info('Going to fetch flight bases from %s', last_resource_proxy_url)
        payload = get_binary_file('{}/flight_bases.bin'.format(last_resource_proxy_url), self.oauth_token)
        if not payload:
            raise Exception('Unable to load flight bases')
        payload_size = len(payload)
        self.logger.info('Fetched flight_bases file: %s', format_number(payload_size))

        stored = {}
        parsed_total = 0
        unknown_endpoint_total = 0
        offset = 0
        while offset < payload_size:
            # Length prefix first, then the record body itself.
            body_start = offset + 4
            (record_size,) = struct.unpack('<i', payload[offset:body_start])
            body_end = body_start + record_size
            record = TFlightBase()
            record.ParseFromString(payload[body_start:body_end])
            offset = body_end
            parsed_total += 1

            carrier_iata = record.OperatingCarrierIata
            if not carrier_iata:
                raise Exception('Invalid flight base proto: no operating carrier: {}'.format(record))

            carrier = carriers.by_code.get(carrier_iata)
            if not carrier:
                self.missing_data.add_carrier(carrier_iata, meta={
                    'source': 'amadeus',
                    'route': flight_base_route(record),
                })
                continue

            # Both endpoints are looked up (and reported when missing) before the
            # record is possibly dropped.
            endpoints = []
            for station_iata, endpoint_type in (
                (record.DepartureStationIata, 'departure station'),
                (record.ArrivalStationIata, 'arrival station'),
            ):
                station = stations.by_iata.get(station_iata)
                if not station:
                    self.missing_data.add_station(station_iata, meta={
                        'type': endpoint_type,
                        'source': 'amadeus',
                        'route': flight_base_route(record),
                    })
                endpoints.append(station)
            origin, destination = endpoints

            if not (origin and destination):
                unknown_endpoint_total += 1
                continue

            record.OperatingCarrier = carrier.Id
            record.DepartureStation = origin.Station.Id
            record.ArrivalStation = destination.Station.Id
            stored[record.Id] = record

        self.logger.info('Parsed flight_bases: %s, stored: %s', format_number(parsed_total),
                         format_number(len(stored)))
        self.logger.info('Skipped because of unknown endpoint: %s', format_number(unknown_endpoint_total))

        return stored

    def save_flight_bases_to_temp_file(self, flight_bases):
        """Dump the flight bases into a '^'-separated temp file, one flight base per line.

        The file is written through the handle returned by ``NamedTemporaryFile`` and
        is kept on disk (``delete=False``). Taking only the ``.name`` of a throw-away
        temporary file and re-opening that path would leave a window in which the
        name does not exist and could be claimed by somebody else.

        :param flight_bases: dict whose values are the flight bases to dump.
        :return: path of the written file; removing it is up to the caller.
        """
        # NOTE(review): the tempfile module creates the file accessible to its owner
        # only - confirm that whatever reads it runs under the same user.
        with tempfile.NamedTemporaryFile(mode='w', delete=False) as flight_bases_file:
            flight_bases_temp_file = flight_bases_file.name
            for flight_base in flight_bases.values():
                # The order of the fields defines the column order of the dump.
                line = '^'.join(FlightsDataImporter.to_str_list([
                    flight_base.Id,
                    FlightsDataImporter.flight_base_bucket_key(flight_base),
                    flight_base.OperatingCarrier,
                    flight_base.OperatingCarrierIata,
                    flight_base.OperatingFlightNumber,
                    flight_base.ItineraryVariationIdentifier,
                    flight_base.LegSeqNumber,
                    flight_base.DepartureStation,
                    flight_base.DepartureStationIata,
                    flight_base.ScheduledDepartureTime,
                    flight_base.LocalDepartureUtcTimeVariation,
                    flight_base.DepartureTerminal,
                    flight_base.ArrivalStation,
                    flight_base.ArrivalStationIata,
                    flight_base.ScheduledArrivalTime,
                    flight_base.LocalArrivalUtcTimeVariation,
                    flight_base.ArrivalTerminal,
                    flight_base.AircraftModel,
                    flight_base.FlyingCarrierIata,
                    flight_base.IntlDomesticStatus,
                    flight_base.TrafficRestrictionCode,
                    flight_base.DesignatedCarrier,
                ]))
                flight_bases_file.write('{}\n'.format(line))
        self.logger.info('Flight bases temp file: %s', flight_bases_temp_file)
        return flight_bases_temp_file

    def fetch_flight_patterns(
        self,
        carriers: CarrierStorage,
        flight_bases: Dict[int, TFlightBase],
        last_resource_proxy_url,
    ):
        """Download and parse ``flight_patterns.bin`` from the resource proxy.

        The payload is read as a sequence of records, each one a 4-byte
        little-endian signed length followed by that many bytes of a
        serialized ``TFlightPattern``.

        A pattern is skipped (and counted) when its flight base is not in
        ``flight_bases``, when its marketing carrier is not in
        ``carriers.by_code`` (also reported to ``self.missing_data``), or when
        it starts operating later than one year from now.

        :param carriers: carrier storage; ``by_code`` is used for IATA lookups.
        :param flight_bases: known flight bases keyed by flight base id.
        :param last_resource_proxy_url: base URL the file name is appended to.
        :return: ``defaultdict(list)`` mapping bucket key to the list of
            flight patterns that belong to it.
        :raises Exception: when the file cannot be loaded or is empty, when a
            pattern has no marketing carrier IATA, or when a flight leg key
            has no usable '.' separator.
        """
        self.logger.info('Going to fetch flight patterns from %s', last_resource_proxy_url)
        date_shift_calculator = DateShiftCalculator(self.stations.by_id, self.timezones, self.logger)
        flight_patterns_content = get_binary_file('{}/flight_patterns.bin'.format(last_resource_proxy_url),
                                                  self.oauth_token)
        if not flight_patterns_content:
            raise Exception('Unable to load flight patterns')
        self.logger.info('Fetched flight_patterns: %s', format_number(len(flight_patterns_content)))

        # Parse flight patterns
        # NOTE(review): compared as a string below, which presumes
        # OperatingFromDate is formatted as YYYY-MM-DD - confirm.
        last_day_to_process = (datetime.now() + timedelta(days=365)).strftime('%Y-%m-%d')
        flight_patterns = defaultdict(list)
        pos = 0
        flight_patterns_count = 0

        too_far_in_the_future_legs = 0
        skipped_due_to_unknown_carrier = 0
        skipped_due_to_unknown_flight_base = 0
        content_size = len(flight_patterns_content)
        while pos < content_size:
            size = struct.unpack('<i', flight_patterns_content[pos:(pos + 4)])[0]
            pos += 4
            flight_pattern_proto = TFlightPattern()
            flight_pattern_proto.ParseFromString(flight_patterns_content[pos:(pos + size)])
            pos += size
            flight_patterns_count += 1
            if not flight_pattern_proto.MarketingCarrierIata:
                raise Exception('Invalid flight pattern proto: no marketing carrier')

            if flight_patterns_count % 100000 == 0:
                self.logger.info('Processed %d flight patterns so far.', flight_patterns_count)

            # Look the flight base up once and reuse it below.
            flight_base = flight_bases.get(flight_pattern_proto.FlightId)
            if not flight_base:
                skipped_due_to_unknown_flight_base += 1
                continue

            carrier = carriers.by_code.get(flight_pattern_proto.MarketingCarrierIata)
            if not carrier:
                self.missing_data.add_carrier(flight_pattern_proto.MarketingCarrierIata, meta={
                    'source': 'amadeus',
                    'route': flight_base_route(flight_base),
                })
                skipped_due_to_unknown_carrier += 1
                continue

            # Must be set before the bucket key is built: the fallback key
            # below is derived from MarketingCarrier, so it has to hold the
            # resolved carrier id already.
            flight_pattern_proto.MarketingCarrier = carrier.Id

            # replace carrier IATA with its ID in the flight pattern bucket key
            bucket_key = flight_pattern_proto.BucketKey

            if not bucket_key and flight_pattern_proto.FlightLegKey:
                # Fall back to the leg key with its last '.'-separated part cut off.
                last_dot = flight_pattern_proto.FlightLegKey.rfind('.')
                if last_dot <= 0:
                    raise Exception('Invalid flight pattern leg key: {}'.format(flight_pattern_proto))
                bucket_key = flight_pattern_proto.FlightLegKey[:last_dot]
            if bucket_key:
                flight_pattern_proto.BucketKey = FlightsDataImporter.replace_carrier_iata_with_id(
                    carriers.by_code,
                    bucket_key,
                )
            else:
                flight_pattern_proto.BucketKey = FlightsDataImporter.marketing_flight_pattern_bucket_key(
                    flight_pattern_proto,
                )

            if last_day_to_process < flight_pattern_proto.OperatingFromDate:
                too_far_in_the_future_legs += 1
                continue

            flight_pattern_proto.ArrivalDayShift = date_shift_calculator.calculate_arrival_day_shift(
                flight_pattern_proto,
                flight_base,
            )
            flight_patterns[flight_pattern_proto.BucketKey].append(flight_pattern_proto)

        self.logger.info(
            'Flight patterns skipped due to unknown carrier: %s',
            format_number(skipped_due_to_unknown_carrier)
        )
        self.logger.info(
            'Flight patterns skipped due to unknown flight base: %s',
            format_number(skipped_due_to_unknown_flight_base)
        )

        self.logger.info('Parsed flight_patterns: %s', format_number(flight_patterns_count))
        self.logger.info('Too far in the future legs: %s', format_number(too_far_in_the_future_legs))
        return flight_patterns

    def save_flight_patterns_to_temp_file(self, flight_patterns):
        """Dump flight patterns into a '^'-separated temporary text file.

        ``flight_patterns`` maps a key to a list of patterns; one line is
        written per pattern, with the column order fixed by the list below and
        every value converted with ``str``.

        :param flight_patterns: mapping whose values are lists of flight
            pattern records.
        :return: path of the written file. The file is NOT removed here, the
            caller owns it.
        """
        lines_count = 0
        # delete=False keeps the file for the caller. Writing through the
        # handle returned by tempfile (instead of taking `.name` from a
        # throw-away object and re-opening it) avoids the window in which the
        # already deleted, predictable path could be claimed by someone else.
        with tempfile.NamedTemporaryFile(mode='w', delete=False) as flight_patterns_file:
            flight_patterns_temp_file = flight_patterns_file.name
            for legs_list in flight_patterns.values():
                for flight_pattern in legs_list:
                    lines_count += 1
                    line = '^'.join(FlightsDataImporter.to_str_list([
                        flight_pattern.Id,
                        flight_pattern.BucketKey,
                        flight_pattern.FlightId,
                        flight_pattern.FlightLegKey,
                        flight_pattern.OperatingFromDate,
                        flight_pattern.OperatingUntilDate,
                        flight_pattern.OperatingOnDays,
                        flight_pattern.MarketingCarrier,
                        flight_pattern.MarketingCarrierIata,
                        flight_pattern.MarketingFlightNumber,
                        flight_pattern.IsAdministrative,
                        flight_pattern.IsCodeshare,
                        flight_pattern.Performance,
                        flight_pattern.ArrivalDayShift,
                        flight_pattern.OperatingFlightPatternId,
                        flight_pattern.LegSeqNumber,
                        flight_pattern.DesignatedCarrier,
                        flight_pattern.DepartureDayShift,
                        flight_pattern.IsDerivative,
                    ]))
                    flight_patterns_file.write('{}\n'.format(line))
        self.logger.info('Stored flight_patterns: %s', format_number(lines_count))
        self.logger.info('Flight patterns temp file: %s', flight_patterns_temp_file)
        return flight_patterns_temp_file

    def fetch_sirena_flights(self, carriers: Dict[int, TCarrier], stations: StationStorage, timezones, last_resource_proxy_url):
        """Download ``sirena.zip`` and parse the flights stored in it.

        The archive is expected to hold ``airlines.pb2.bin`` and
        ``routes.pb2.bin``; both are loaded into their repositories and then
        handed to ``SirenaFlightsParser``.

        :param carriers: carriers passed through to the parser.
        :param stations: station storage; its ``by_id`` is passed to the parser.
        :param timezones: timezones passed through to the parser.
        :param last_resource_proxy_url: base URL the archive name is appended to.
        :return: whatever ``SirenaFlightsParser.parse_data`` returns.
        :raises Exception: when the archive cannot be loaded or is empty.
        """
        archive_url = '{}/sirena.zip'.format(last_resource_proxy_url)
        archive_bytes = get_binary_file(archive_url, self.oauth_token)
        if not archive_bytes:
            raise Exception('Unable to load sirena.zip contents')
        self.logger.info('Loaded bytes for sirena flights: %d', len(archive_bytes))

        with zipfile.ZipFile(io.BytesIO(archive_bytes)) as archive:
            raw_airlines = archive.read("airlines.pb2.bin")
            airlines = SirenaAirlinesRepository()
            airlines.load_from_string(raw_airlines)
            self.logger.info('Airlines count: {}'.format(len(airlines.values())))

            raw_routes = archive.read("routes.pb2.bin")
            routes = SirenaRoutesRepository()
            routes.load_from_string(raw_routes)
            self.logger.info('Routes count: {}'.format(len(routes.values())))

            parser = SirenaFlightsParser(self.logger, carriers, stations.by_id, timezones, self.missing_data)
            return parser.parse_data(airlines.values(), routes.values())

    def push_text_file_to_table(self, cursor, data_file_name, table_name):
        """Bulk-load a '^'-separated text file into a table via ``copy_from``.

        :param cursor: DB cursor providing ``copy_from`` and ``rowcount``.
        :param data_file_name: path of the text file to load.
        :param table_name: destination table.
        :return: ``cursor.rowcount`` as reported right after the copy.
        """
        with open(data_file_name) as source:
            cursor.copy_from(source, table_name, sep='^', size=10000)
            copied_rows = cursor.rowcount
        return copied_rows

    def push_carriers_to_table(self, carriers_by_id: Dict[int, TCarrier], session):
        """Synchronize the carriers table with ``carriers_by_id``.

        Rows whose id is present in ``carriers_by_id`` are merged with the new
        data, rows whose id is absent are deleted, and ids that are not in the
        table yet are inserted. Everything is committed in one transaction.

        :param carriers_by_id: carriers keyed by id; not modified (a copy is
            consumed instead).
        :param session: DB session. It is rolled back on failure but is not
            closed here.
        :raises Exception: any error from the session is re-raised after the
            rollback.
        """
        carriers_by_id_copy = carriers_by_id.copy()
        deleted_carriers = []
        try:
            carriers_from_db = session.query(Carrier).all()
            for db_carrier in carriers_from_db:
                carrier = carriers_by_id_copy.pop(db_carrier.id, None)
                if not carrier:
                    deleted_carriers.append(db_carrier.id)
                else:
                    db_carrier.merge(carrier)
                    session.merge(db_carrier)

            # Whatever is left in the copy was not found in the table.
            for new_carrier in carriers_by_id_copy.values():
                db_carrier = Carrier()
                db_carrier.id = new_carrier.Id
                db_carrier.merge(new_carrier)
                session.merge(db_carrier)

            # Do not issue a DELETE with an empty IN list.
            if deleted_carriers:
                db_deleted_carriers = Carrier.__table__.delete().where(Carrier.id.in_(deleted_carriers))
                session.execute(db_deleted_carriers)
            session.commit()
        except Exception:
            # Do not leave the session inside a failed transaction.
            self.logger.exception('Error while updating carriers. Rolling back.')
            session.rollback()
            raise

        self.logger.info('Deleted carriers: %s', len(deleted_carriers))
        self.logger.info('Added carriers: %s', len(carriers_by_id_copy))
        self.logger.info('Updated carriers: %s', len(carriers_from_db) - len(deleted_carriers))

    def push_timezones_to_table(self, timezones, session):
        """Synchronize the timezones table with ``timezones``.

        Existing rows get their code refreshed, rows without a (truthy) code
        in ``timezones`` are deleted and unknown ids are inserted. On any
        error the transaction is rolled back and the error re-raised; the
        session is closed in every case.

        :param timezones: mapping of timezone id to timezone code; not
            modified (a copy is consumed instead).
        :param session: DB session, closed before returning.
        """
        pending = timezones.copy()
        db_rows = session.query(Timezone).all()
        stale_ids = []
        try:
            for row in db_rows:
                code = pending.pop(row.id, None)
                if code:
                    row.code = code
                    session.merge(row)
                else:
                    stale_ids.append(row.id)

            # Ids still pending are not in the table yet.
            for tz_id, tz_code in pending.items():
                fresh_row = Timezone()
                fresh_row.id = tz_id
                fresh_row.code = tz_code
                session.merge(fresh_row)

            delete_stale = Timezone.__table__.delete().where(Timezone.id.in_(stale_ids))
            session.execute(delete_stale)
            session.commit()
        except Exception as exc:
            self.logger.exception(
                u'Error while updating timezones. Rolling back. %r', exc
            )
            session.rollback()
            raise
        finally:
            session.close()

        updated_count = len(db_rows) - len(stale_ids)
        self.logger.info('Deleted timezones: %s', len(stale_ids))
        self.logger.info('Added timezones: %s', len(pending))
        self.logger.info('Updated timezones: %s', updated_count)

    def push_stations_to_table(self, stations: Dict[int, TStationWithCodes], session):
        """Insert new stations and update changed ones; nothing is deleted.

        An existing row is saved only when its JSON dump (via
        ``CustomEncoder``) differs before and after merging the new data.
        Rows whose id is absent from ``stations`` are only counted.

        :param stations: stations keyed by id; not modified (a copy is
            consumed instead).
        :param session: DB session used for the query, bulk save and commit.
        """
        pending = stations.copy()
        db_rows = session.query(Station).all()
        absent_ids = []
        changed_rows = []
        for row in db_rows:
            incoming = pending.pop(row.id, None)
            if not incoming:
                absent_ids.append(row.id)
                continue
            before = json.dumps(row, cls=CustomEncoder)
            row.merge(incoming)
            after = json.dumps(row, cls=CustomEncoder)
            if after != before:
                changed_rows.append(row)

        # Ids still pending are not in the table yet.
        created_rows = []
        for incoming in pending.values():
            row = Station()
            row.id = incoming.Station.Id
            row.merge(incoming)
            created_rows.append(row)

        rows_to_save = changed_rows + created_rows
        if rows_to_save:
            session.bulk_save_objects(rows_to_save)

        session.commit()
        self.logger.info('Stations not in dict: %s', len(absent_ids))
        self.logger.info('Added stations: %s', len(pending))
        self.logger.info('Updated stations: %s', len(rows_to_save) - len(pending))

    def push_transport_models_to_table(self, transport_models, session):
        """Synchronize the transport models table with ``transport_models``.

        Rows whose id is absent from ``transport_models`` are deleted, rows
        whose JSON dump (via ``CustomEncoder``) changes after merging the new
        data are saved, and unknown ids are inserted.

        :param transport_models: transport models keyed by id; not modified
            (a copy is consumed instead).
        :param session: DB session used for the query, bulk save, delete and
            commit.
        """
        transport_models_copy = transport_models.copy()
        transport_models_from_db = session.query(TransportModel).all()
        deleted_transport_models = []
        transport_models_to_save = []
        for db_transport_model in transport_models_from_db:
            transport_model = transport_models_copy.pop(db_transport_model.id, None)
            if not transport_model:
                # The id has to come from the DB row: `transport_model` is
                # None on this branch.
                deleted_transport_models.append(db_transport_model.id)
            else:
                old_transport_model_state = json.dumps(db_transport_model, cls=CustomEncoder)
                db_transport_model.merge(transport_model)
                if json.dumps(db_transport_model, cls=CustomEncoder) != old_transport_model_state:
                    transport_models_to_save.append(db_transport_model)

        # Whatever is left in the copy was not found in the table.
        for new_transport_model in transport_models_copy.values():
            db_transport_model = TransportModel()
            db_transport_model.id = new_transport_model.Id
            db_transport_model.merge(new_transport_model)
            transport_models_to_save.append(db_transport_model)

        if transport_models_to_save:
            session.bulk_save_objects(transport_models_to_save)

        if deleted_transport_models:
            db_deleted_transport_models = TransportModel.__table__.delete().where(
                TransportModel.id.in_(deleted_transport_models))
            session.execute(db_deleted_transport_models)
        session.commit()
        self.logger.info('Deleted transport models: %s', len(deleted_transport_models))
        self.logger.info('Added transport models: %s', len(transport_models_copy))
        self.logger.info('Updated transport models: %s', len(transport_models_to_save) - len(transport_models_copy))

    def push_unknown_stop_points_to_table(self, unknown_stop_points, session, add_only=False):
        """Store unknown stop points, optionally without deleting old ones.

        With ``add_only`` set, only stop points whose ``station_code`` is not
        in the table yet are merged. Otherwise DB rows are matched against the
        keys of ``unknown_stop_points`` by ``str(row)``: unmatched rows are
        deleted and the remaining (unmatched) input values are merged.

        :param unknown_stop_points: mapping whose values are stop point
            objects; not modified.
        :param session: DB session used for the query, merges, delete and
            commit.
        :param add_only: when true, never delete rows.
        """
        db_stop_points = session.query(StopPoint).all()
        known_station_codes = {db_point.station_code for db_point in db_stop_points}

        if add_only:
            added_points = 0
            for candidate in unknown_stop_points.values():
                if candidate.station_code in known_station_codes:
                    continue
                session.merge(candidate)
                added_points += 1
            session.commit()
            self.logger.info('Added unknown stop points: %s', added_points)
            return

        pending = unknown_stop_points.copy()
        stale_ids = []
        # NOTE(review): rows are matched here by str(row) while the add_only
        # branch matches by station_code - confirm the input keys follow
        # str(StopPoint).
        for db_point in db_stop_points:
            if pending.pop(str(db_point), None):
                continue
            stale_ids.append(db_point.id)

        for fresh_point in pending.values():
            session.merge(fresh_point)

        delete_stale = StopPoint.__table__.delete().where(StopPoint.id.in_(stale_ids))
        session.execute(delete_stale)
        session.commit()
        self.logger.info('Deleted unknown stop points: %s', len(stale_ids))
        self.logger.info('Added unknown stop points: %s', len(pending))

    @staticmethod
    def flight_base_bucket_key(flight_base):
        return '{}.{}.{}'.format(
            flight_base.OperatingCarrier,
            flight_base.OperatingFlightNumber,
            flight_base.LegSeqNumber,
        )

    @staticmethod
    def marketing_flight_pattern_bucket_key(flight_pattern):
        parts = flight_pattern.FlightLegKey.split('.') if flight_pattern.FlightLegKey else []
        leg_number = parts[2] if len(parts) >= 3 else 1
        return '{}.{}.{}'.format(
            flight_pattern.MarketingCarrier,
            flight_pattern.MarketingFlightNumber,
            leg_number,
        )

    @staticmethod
    def replace_carrier_iata_with_id(carriers: Dict[str, TCarrier], bucket_key):
        if not bucket_key:
            return bucket_key
        parts = bucket_key.split('.')
        if len(parts) != 3:
            return bucket_key
        carrier = carriers.get(parts[0])
        if not carrier:
            return bucket_key
        return '{}.{}.{}'.format(
            carrier.Id,
            parts[1],
            parts[2],
        )

    @staticmethod
    def to_str_list(items):
        for item in items:
            yield str(item)

    def import_carriers(self, session_factory):
        self.fetch_rasp_dicts()
        self.push_carriers_to_table(self.carriers.by_id, session_factory())

    def import_stations(self, session_factory):
        self.fetch_rasp_dicts()
        self.push_stations_to_table(self.stations.by_id, session_factory())

    def import_timezones(self, session_factory):
        self.fetch_rasp_dicts()
        self.push_timezones_to_table(self.timezones, session_factory())


def format_number(num):
    """Format a number with ',' as the thousands separator."""
    return format(num, ',')


def run_import(args):
    """Entry point: log the raw args, parse them and run the flights import."""
    log = logging.getLogger(__name__)
    log.info('Data importer args: %s', args)

    options = parse_args(args)
    importer = FlightsDataImporter(log)
    importer.import_flights_data(options)
