# coding=utf-8
from __future__ import unicode_literals

from collections import defaultdict, namedtuple
from datetime import datetime, timedelta
import gzip
import six

from travel.avia.shared_flights.lib.python.date_utils.date_index import DateIndex
from travel.avia.shared_flights.lib.python.date_utils.date_mask import DateMaskMatcher
from travel.avia.shared_flights.tasks.amadeus_parser.double_flyouts_filter import DoubleFlyoutsFilter
from travel.avia.shared_flights.tasks.amadeus_parser.flight_arr_times_cacher import FlightArrivalTimesCacher
from travel.avia.shared_flights.tasks.amadeus_parser.flight_base_data import FlightBaseDataFactory
from travel.avia.shared_flights.tasks.amadeus_parser.flight_dates import FlightDatesManager
from travel.avia.shared_flights.tasks.amadeus_parser.flight_segment import FlightSegmentFactory
from travel.avia.shared_flights.tasks.amadeus_parser.codeshares import CodesharesMap
import travel.proto.shared_flights.ssim.flights_pb2 as flights_pb2


# Marketing carrier and flight number of a flight, e.g. one entry of a segment's codeshare list.
FlightTitle = namedtuple('FlightTitle', 'marketing_carrier marketing_flight_number')
FlightData = namedtuple('FlightData', 'flight_base flight_patterns')

# Value kept for every distinct segment found in the csv file until the cache is turned into
# flight bases / flight patterns: the assigned flight base id and its operating-dates holder.
FlightSegmentValue = namedtuple('FlightSegmentValue', 'fb_id operating_dates')


class AmadeusFlightsParser(object):
    """ Parses the schedule data file fetched from Amadeus over sftp.

    The file is '^'-separated text whose first non-blank line is the header. It is
    given either uncompressed (``parse_flights_csv``) or gzip-compressed
    (``parse_flights_zip``). The file is read twice:

    - The first pass caches every segment in the double-flyouts filter and the
      arrival-times cacher.
    - The second pass builds the per-leg cache that ``list_flights`` turns into
      flight bases, flight patterns and codeshares.
    """

    def __init__(self, logger):
        self._logger = logger
        # Filled in by _parse_flights: flight base data -> FlightSegmentValue.
        self._all_flight_legs = None
        # Dates are indexed from 75 days before now. The mask matcher gets a window of
        # max_days from the same start; presumably that is the representable schedule
        # span (TODO confirm against DateMaskMatcher).
        self._start_date = datetime.now() - timedelta(days=75)
        max_days = 440
        self._date_index = DateIndex(self._start_date)
        self._dmm = DateMaskMatcher(self._date_index, self._start_date, max_days)
        self._double_flyouts_filter = DoubleFlyoutsFilter()
        self._arrival_times_cacher = FlightArrivalTimesCacher()

    def parse_flights_csv(self, csv_file):
        """ Parse the uncompressed data file at path ``csv_file`` (reads it twice). """
        self._cache_segments(self._get_lines_from_csv(csv_file))
        self._parse_flights(self._get_lines_from_csv(csv_file))

    def parse_flights_zip(self, zip_file):
        """ Parse the gzip-compressed data file at path ``zip_file`` (reads it twice). """
        self._cache_segments(self._get_lines_from_zip(zip_file))
        self._parse_flights(self._get_lines_from_zip(zip_file))

    def _get_lines_from_csv(self, csv_file):
        """ Yield the text lines of ``csv_file``; the file is closed when the generator finishes. """
        with open(csv_file) as input_file:
            for line in input_file:
                yield line

    def _get_lines_from_zip(self, zip_file):
        """ Yield the (bytes) lines of gzip file ``zip_file``; the file is closed when the generator finishes. """
        with gzip.GzipFile(zip_file) as input_file:
            for line in input_file:
                yield line

    def _cache_segments(self, lines_iterator):
        """ First pass: feed every segment to the duplicate-flyouts filter and the arrival-times cacher. """
        self._logger.info('Caching segments')

        def _cache_segment(flight_segment):
            self._double_flyouts_filter.cache_segment_on_date(flight_segment)
            self._arrival_times_cacher.cache_arrival_time(flight_segment)

        self._loop_over(lines_iterator, _cache_segment)
        self._logger.info('Done caching segments')

    def _parse_flights(self, lines_iterator):
        """ Second pass: accumulate operating and codeshare days per flight leg.

        Fills ``self._all_flight_legs``, keyed by the flight base data built for each
        segment. A segment is skipped when:

        - it yields no flight base data;
        - its operating carrier is not a two-character code;
        - it is flagged as a duplicate flyout.

        A malformed codeshare entry is logged and skipped on its own.
        """
        self._logger.info('Parsing flights')
        flight_base_data_factory = FlightBaseDataFactory(
            self._start_date,
            self._arrival_times_cacher,
            self._logger,
        )
        flight_dates_manager = FlightDatesManager(self._dmm)
        flight_base_id = Counter(0)
        self._double_flyouts_filter.extract_duplicate_flyouts()
        # A leg seen for the first time gets the next flight base id and a fresh dates holder.
        self._all_flight_legs = defaultdict(
            lambda: FlightSegmentValue(flight_base_id.inc(), flight_dates_manager.new_flight_dates())
        )

        def _parse_segment(flight_segment):
            flight_base_data = flight_base_data_factory.new_flight_base_data(flight_segment)
            if not flight_base_data:
                return
            if not flight_base_data.operating_carrier or len(flight_base_data.operating_carrier) != 2:
                self._logger.error('No carrier in flight base data: %s', flight_base_data)
                return
            if self._double_flyouts_filter.is_duplicate(
                flight_segment.operating_carrier,
                flight_segment.operating_flight_code,
                flight_segment.dep_date,
            ):
                self._logger.error('Duplicate flight: %s', flight_base_data)
                return
            # Look the leg up once; this also registers it on first sight.
            operating_dates = self._all_flight_legs[flight_base_data].operating_dates
            flight_dates_manager.add_operating_day(operating_dates, flight_segment.dep_date)
            if not flight_segment.codeshare_flights:
                return
            # Codeshares come as a comma-separated list of "<carrier> <number>" entries.
            for flight in flight_segment.codeshare_flights.split(','):
                flight_title_parts = [x for x in flight.split(' ') if x]
                if len(flight_title_parts) != 2:
                    # Skip only this entry so that valid codeshares listed after it are still recorded.
                    self._logger.error('incorrect flight codeshares: %s', flight_segment.__dict__)
                    continue
                flight_title = FlightTitle(flight_title_parts[0], flight_title_parts[1])
                flight_dates_manager.add_codeshare_day(
                    operating_dates,
                    flight_title,
                    flight_segment.dep_date,
                )

        self._loop_over(lines_iterator, _parse_segment)
        flight_base_data_factory.log_skipped_flights()
        self._logger.info(
            'Duplicate flights count: %d, flight*days count: %d',
            self._double_flyouts_filter.flights_count(),
            self._double_flyouts_filter.flights_and_days_count(),
        )
        self._logger.info('Done parsing flights')

    def _loop_over(self, lines_iterator, loop_body_func):
        """ Parse each data line into a flight segment and pass it to ``loop_body_func``.

        The first non-blank line is the '^'-separated header that configures the
        FlightSegmentFactory. Blank (whitespace-only) lines are skipped.
        """
        flight_segment_factory = None
        for line_index, line in enumerate(lines_iterator):
            # File lines keep their trailing newline, so a plain ``not line`` would never
            # catch an empty line; strip before testing (works for str and bytes).
            if not line.strip():
                continue
            if flight_segment_factory is None:
                headers = six.ensure_str(line).rstrip().split('^')
                flight_segment_factory = FlightSegmentFactory(headers)
                continue
            if line_index % 1000000 == 0:
                self._logger.info('csv line-count: {}.'.format(line_index))
            flight_segment = flight_segment_factory.parse_flight_segment(line)
            loop_body_func(flight_segment)

    def list_flights(self, on_flight_base, on_flight_pattern, on_codeshare):
        """ Emit the parsed schedule through the given callbacks.

        The callbacks are invoked as follows:

        - ``on_flight_base`` with each TFlightBase;
        - ``on_flight_pattern`` with each TFlightPattern, both operating and codeshare;
        - ``on_codeshare`` with each proto from the codeshares map built along the way.

        Does nothing if no flights have been parsed.
        """
        if not self._all_flight_legs:
            return

        flight_pattern_id = 0
        codeshares_map = CodesharesMap()
        max_flight_base_id = 0
        for fb_key, flight_segment_value in self._all_flight_legs.items():
            flight_base_id = flight_segment_value.fb_id
            flight_dates = flight_segment_value.operating_dates
            max_flight_base_id = max(max_flight_base_id, flight_base_id)
            flight_base = self.new_flight_base(fb_key, flight_base_id)
            on_flight_base(flight_base)

            # new_flight_pattern already copies the departure/arrival day shifts from fb_key.
            for fp_mask in self._dmm.generate_masks(flight_dates.operating_mask):
                flight_pattern_id += 1
                flight_pattern = self.new_flight_pattern(
                    flight_base_id,
                    flight_pattern_id,
                    flight_base.OperatingCarrierIata,
                    flight_base.OperatingFlightNumber,
                    flight_base.LegSeqNumber,
                    fb_key,
                    fp_mask,
                    False,
                )
                on_flight_pattern(flight_pattern)

            for flight_title, codeshare_mask in flight_dates.codeshares.items():
                for fp_mask in self._dmm.generate_masks(codeshare_mask):
                    flight_pattern_id += 1
                    fp_codeshare = self.new_flight_pattern(
                        flight_base_id,
                        flight_pattern_id,
                        flight_title.marketing_carrier,
                        flight_title.marketing_flight_number,
                        flight_base.LegSeqNumber,
                        fb_key,
                        fp_mask,
                        True,
                    )
                    on_flight_pattern(fp_codeshare)
                    codeshares_map.add_codeshare(
                        flight_title.marketing_carrier,
                        flight_title.marketing_flight_number,
                        flight_base.LegSeqNumber,
                        flight_base.OperatingCarrierIata,
                        flight_base.OperatingFlightNumber,
                        flight_base.LegSeqNumber,
                        fp_codeshare.OperatingFromDate,
                        fp_codeshare.OperatingUntilDate,
                    )
        # Flight base ids are handed out sequentially, so the max id equals the count.
        self._logger.info('Total flight bases   : {:,}'.format(max_flight_base_id))
        self._logger.info('Total flight patterns: {:,}'.format(flight_pattern_id))

        codeshares_count = 0
        for codeshare in codeshares_map.protos():
            codeshares_count += 1
            on_codeshare(codeshare)
        self._logger.info('Total codeshares {:,}'.format(codeshares_count))

    def new_flight_base(self, fb_key, flight_base_id):
        """ Build a TFlightBase proto with id ``flight_base_id`` from the flight base data ``fb_key``. """
        fb = flights_pb2.TFlightBase()
        fb.Id = flight_base_id
        fb.OperatingCarrierIata = fb_key.operating_carrier
        fb.OperatingFlightNumber = fb_key.operating_flight_number
        fb.LegSeqNumber = fb_key.leg_number

        fb.DepartureStationIata = fb_key.departure_station_iata
        fb.ScheduledDepartureTime = AmadeusFlightsParser.parse_time(fb_key.departure_time)
        fb.DepartureTerminal = fb_key.departure_terminal

        fb.ArrivalStationIata = fb_key.arrival_station_iata
        fb.ScheduledArrivalTime = AmadeusFlightsParser.parse_time(fb_key.arrival_time)
        fb.ArrivalTerminal = fb_key.arrival_terminal
        fb.AircraftModel = fb_key.aircraft_model
        return fb

    def new_flight_pattern(
        self,
        flight_base_id,
        flight_pattern_id,
        marketing_carrier,
        marketing_flight_number,
        leg_number,
        fb_key,
        fp_mask,
        is_codeshare,
    ):
        """ Build a TFlightPattern proto for one date mask of a flight base.

        ``fp_mask`` is indexed as (from date, until date, operating days); dots in the
        dates are replaced with dashes. Codeshare patterns are bucketed under the
        operating flight taken from ``fb_key``. Operating patterns are bucketed under
        their own marketing flight.
        """
        fp = flights_pb2.TFlightPattern()

        fp.Id = flight_pattern_id
        fp.MarketingCarrierIata = marketing_carrier
        fp.MarketingFlightNumber = marketing_flight_number
        fp.LegSeqNumber = leg_number
        fp.FlightId = flight_base_id

        fp.OperatingFromDate = fp_mask[0].replace('.', '-')
        fp.OperatingUntilDate = fp_mask[1].replace('.', '-')
        fp.OperatingOnDays = fp_mask[2]

        fp.FlightLegKey = AmadeusFlightsParser.get_leg_key(
            fp.MarketingCarrierIata,
            fp.MarketingFlightNumber,
            fp.LegSeqNumber,
            flight_base_id,
        )

        fp.IsCodeshare = is_codeshare
        if is_codeshare:
            fp.BucketKey = AmadeusFlightsParser.get_bucket_key(
                fb_key.operating_carrier,
                fb_key.operating_flight_number,
                fb_key.leg_number,
            )
        else:
            fp.BucketKey = AmadeusFlightsParser.get_bucket_key(
                fp.MarketingCarrierIata,
                fp.MarketingFlightNumber,
                fp.LegSeqNumber,
            )

        fp.DepartureDayShift = fb_key.departure_day_shift
        fp.ArrivalDayShift = fb_key.arrival_day_shift

        return fp

    @staticmethod
    def parse_time(time_text):
        # type: (str) -> int
        """ Convert a numeric time string (presumably HHMM) to an int.

        Shorter strings are zero-padded, e.g. '930' and '0930' both give 930. The '1'
        prefix keeps int() from failing on an empty string, which yields 0.
        """
        time_number = int('1' + time_text.zfill(4))
        return time_number - 10000

    @staticmethod
    def get_leg_key(marketing_carrier_iata, marketing_flight_number, leg_seq_number, itinerary_variation_identifier):
        """ Return 'carrier.number.leg.variation' with leading spaces/zeros stripped from the number. """
        return '{}.{}.{}.{}'.format(
            marketing_carrier_iata,
            marketing_flight_number.lstrip(' ').lstrip('0'),
            leg_seq_number,
            itinerary_variation_identifier,
        )

    @staticmethod
    def get_bucket_key(operating_carrier_iata, operating_flight_number, leg_seq_number):
        """ Return 'carrier.number.leg' with leading spaces/zeros stripped from the number. """
        return '{}.{}.{}'.format(
            operating_carrier_iata,
            operating_flight_number.lstrip(' ').lstrip('0'),
            leg_seq_number,
        )


class Counter(object):
    """ Mutable integer counter; each ``inc()`` returns the next value in sequence. """

    def __init__(self, init_value):
        # Last value handed out by inc(), or the starting value before the first call.
        self._value = init_value

    def inc(self):
        """ Advance by one and return the new value (init_value + 1, init_value + 2, ...). """
        self._value = self._value + 1
        return self._value
