# -*- coding: utf-8 -*-
import logging
from datetime import date
from itertools import islice

from travel.avia.avia_statistics.flights_updater.lib.collector import FlightsCollector
from travel.avia.avia_statistics.flights_updater.lib.table import FlightsTable, Flight
from travel.avia.library.python.iata_correction import IATACorrector

logger = logging.getLogger(__name__)


class FlightsUpdater(object):
    """Copy collected flights into the YDB flights table, correcting carriers.

    Flights are collected into a YT table by ``FlightsCollector``. They are read
    back, their carriers are resolved in batches through ``IATACorrector``
    against the shared-flights API, and the results are written to
    ``FlightsTable``.
    """

    def __init__(self, flights_collector, shared_flights_api_base_url, flights_table, iata_correction_batch_size=1000):
        # type: (FlightsCollector, str, FlightsTable, int) -> None
        """
        :param flights_collector: produces the YT table and iterates its rows.
        :param shared_flights_api_base_url: base URL handed to ``IATACorrector``.
        :param flights_table: YDB table the flights are written to.
        :param iata_correction_batch_size: number of flights resolved and
            written per batch; must be at least 1.
        :raises ValueError: if ``iata_correction_batch_size`` is less than 1.
            A non-positive size previously never triggered a flush, so every
            flight silently accumulated into a single in-memory batch.
        """
        if iata_correction_batch_size < 1:
            raise ValueError(
                'iata_correction_batch_size must be a positive integer, got {!r}'.format(
                    iata_correction_batch_size,
                )
            )
        self._flights_collector = flights_collector
        self._shared_flights_api_base_url = shared_flights_api_base_url
        self._flights_table = flights_table
        self._iata_correction_batch_size = iata_correction_batch_size

    def update_flights(self, target_date):
        # type: (date) -> None
        """Collect flights for ``target_date`` and store them to YDB.

        Steps:

        1. Collect flights into a YT table.
        2. Ensure the YDB table exists.
        3. Write the flights in batches of ``iata_correction_batch_size``.
        4. Call ``FlightsTable.delete_old(target_date)``.

        :param target_date: date passed both to the collector and to
            ``delete_old``.
        """
        logger.info('start collecting flights to YT table')
        yt_flights_table = self._flights_collector.collect_flights_to_yt_table(target_date)
        logger.info('flights were collected to %s', yt_flights_table)

        logger.info('start storing flights to YDB')
        self._flights_table.create_if_doesnt_exist()
        total_processed = 0
        flights = self._flights_collector.iterate_flights(yt_flights_table)
        # Batches are pulled lazily, so at most one batch is held in memory at a time.
        for batch in self._iter_batches(flights, self._iata_correction_batch_size):
            self._process_batch(batch)
            total_processed += len(batch)
            logger.info('processed: %s', total_processed)
        logger.info('all %s flights were stored to YDB', total_processed)

        # NOTE(review): the message below says "older than or equal to"; the
        # actual comparison lives in FlightsTable.delete_old. Confirm they agree,
        # since flights departing on target_date itself may have just been written.
        logger.info(
            'start deleting from YDB flights with departure date older than or equal to %s',
            target_date.isoformat(),
        )
        self._flights_table.delete_old(target_date)

    @staticmethod
    def _iter_batches(items, batch_size):
        """Yield consecutive non-empty lists of at most ``batch_size`` items.

        The last list may be shorter. Nothing is yielded for empty input.
        """
        iterator = iter(items)
        while True:
            batch = list(islice(iterator, batch_size))
            if not batch:
                return
            yield batch

    def _process_batch(self, batch):
        """Resolve carriers for ``batch`` and write it with ``replace_batch``.

        :param batch: non-empty list of rows from
            ``FlightsCollector.iterate_flights``. Each row is read for the keys
            ``airline_code``, ``flight_number``, ``from_id``, ``to_id``,
            ``company_id``, ``national_version`` and ``departure_date``.
        """
        # A new corrector is created per batch, as before.
        # NOTE(review): it could likely be reused across batches if
        # IATACorrector is stateless and cheap to build; confirm first.
        company_id_by_flight_number = IATACorrector(
            self._shared_flights_api_base_url,
        ).flight_numbers_to_carriers(
            ((f['airline_code'], f['flight_number']) for f in batch),
        )

        def map_to_flight(f):
            # The lookup key must match the keys returned by
            # flight_numbers_to_carriers; presumably "<airline> <number>". Verify.
            flight_number = '{} {}'.format(f['airline_code'], f['flight_number'])
            return Flight(
                from_id=f['from_id'],
                to_id=f['to_id'],
                # Fall back to the collector's company_id when the corrector
                # returns no (or a falsy) carrier for this flight number.
                company_id=company_id_by_flight_number.get(flight_number) or f['company_id'],
                flight_number=flight_number,
                national_version=f['national_version'],
                departure_date=f['departure_date'],
            )

        self._flights_table.replace_batch([map_to_flight(f) for f in batch])
