# coding=utf-8
from __future__ import unicode_literals

import io
from collections import defaultdict, namedtuple
from datetime import datetime, timedelta

from travel.avia.shared_flights.diff_builder.utils import write_binary_string
from travel.avia.shared_flights.lib.python.date_utils.date_index import DateIndex
from travel.avia.shared_flights.lib.python.date_utils.date_mask import DateMaskMatcher
from travel.avia.shared_flights.lib.python.date_utils.date_matcher import DateMatcher
from travel.proto.shared_flights.snapshots.p2p_cache_pb2 import TFlightKey, TP2PCacheEntry

# Station pair used as a p2p-cache key: (departure station, arrival station).
RouteKey = namedtuple('RouteKey', 'station_from station_to')
# Marketing flight identifier: (marketing carrier, marketing flight number).
FlightKey = namedtuple('FlightKey', 'carrier flight_number')
# Operating dates of one leg/route together with its arrival and departure day shifts.
DateMaskValue = namedtuple('DateMaskValue', 'date_mask arrival_shift departure_shift')


class P2PCache(object):
    """Point-to-point cache: maps (station_from, station_to) pairs to the flights serving them.

    Feed every flight-pattern segment through ``add_segment``, then call ``write_to_mem_file``
    (or iterate ``generate_p2p_cache``). The result holds, for every station pair that can be
    travelled on a single marketing flight (possibly across several consecutive legs whose
    operating dates line up), the sorted list of such flights.
    """

    def __init__(self, logger, start_date=None, max_days=440, lookback_days=75):
        """Create an empty cache.

        :param logger: logger used for progress messages in ``write_to_mem_file``.
        :param start_date: first date of the date index; when None, ``lookback_days`` days
            before ``datetime.now()`` is used.
        :param max_days: number of days the date masks span (the mask matcher receives
            ``max_days - 1``). Defaults to the historical value 440.
        :param lookback_days: offset of the default ``start_date`` from now; only used when
            ``start_date`` is None. Defaults to the historical value 75.
        """
        self._logger = logger
        if start_date is None:
            start_date = datetime.now() - timedelta(days=lookback_days)
        self._start_date = start_date
        self._max_days = max_days
        self._date_index = DateIndex(self._start_date)
        self._date_matcher = DateMatcher(self._date_index)
        # NOTE(review): the third argument looks like the last valid day offset (hence
        # ``max_days - 1``) — confirm against DateMaskMatcher.
        self._dmm = DateMaskMatcher(self._date_index, self._start_date, self._max_days - 1)
        # { FlightKey -> { leg_number -> { RouteKey -> DateMaskValue } } }; a missing entry is
        # created on first access with a fresh mask from new_date_mask() and zero day shifts.
        self._cache = defaultdict(
            lambda: defaultdict(lambda: defaultdict(lambda: DateMaskValue(self._dmm.new_date_mask(), 0, 0)))
        )
        # FlightKey -> True once any of the flight's patterns is not a codeshare.
        self._is_operating = defaultdict(bool)

    def add_segment(self, flight_pattern, station_from, station_to):
        """Record one leg of a flight pattern.

        Collects data into { FlightKey -> { leg_number -> { RouteKey -> DateMaskValue } } }.
        The pattern's operating range and weekdays are added to the leg/route date mask.

        :param flight_pattern: pattern providing MarketingCarrier, MarketingFlightNumber,
            LegSeqNumber, OperatingFromDate/UntilDate/OnDays, IsCodeshare and day shifts.
        :param station_from: departure station of the leg.
        :param station_to: arrival station of the leg.
        """
        flight_key = FlightKey(flight_pattern.MarketingCarrier, flight_pattern.MarketingFlightNumber)
        route_key = RouteKey(station_from, station_to)
        leg_routes = self._cache[flight_key][flight_pattern.LegSeqNumber]
        value = leg_routes[route_key]
        # The mask object is updated in place, so every pattern of this leg/route accumulates into it.
        self._dmm.add_range(
            flight_pattern.OperatingFromDate,
            flight_pattern.OperatingUntilDate,
            flight_pattern.OperatingOnDays,
            value.date_mask,
        )
        self._is_operating[flight_key] = self._is_operating[flight_key] or not flight_pattern.IsCodeshare
        # Only the day shifts are replaced here.
        # NOTE(review): if patterns of the same leg/route disagree on day shifts, the last one wins.
        leg_routes[route_key] = DateMaskValue(
            value.date_mask, flight_pattern.ArrivalDayShift, flight_pattern.DepartureDayShift
        )

    def generate_p2p_cache(self):
        """Yield ``(RouteKey, sorted list of FlightKey)`` for every station pair found."""
        result = defaultdict(set)

        for flight_key, flight_value in self._cache.items():
            self.add_routes(flight_key, flight_value, result)

        for route_key, flights_set in result.items():
            yield route_key, sorted(flights_set)

    def add_routes(self, flight_key, segments_cache, result):
        """Expand all routes of one flight into ``result``, starting from leg 1."""
        self.add_routes_internal(flight_key, 1, [], None, 0, segments_cache, result)

    def add_routes_internal(
        self, flight_key, leg_number, current_route, current_mask, arrival_shift, segments_cache, result
    ):
        """Depth-first walk over consecutive legs of one flight, registering station pairs.

        :param flight_key: flight being expanded; it is added to every pair it serves.
        :param leg_number: number of the next leg to try to append.
        :param current_route: stations visited so far (empty on the root call); never mutated.
        :param current_mask: dates on which the route so far operates (None on the root call).
        :param arrival_shift: arrival day shift of the last appended leg.
        :param segments_cache: { leg_number -> { RouteKey -> DateMaskValue } } of this flight.
        :param result: { RouteKey -> set of FlightKey } accumulator, updated in place.
        """
        if current_route:
            # Register pairs from the first station to every other one and from every station to
            # the last one. Pairs between intermediate stations were already registered when the
            # shorter prefixes of this route were visited on the way down.
            station_from = current_route[0]
            station_to = current_route[-1]
            for station in current_route:
                if station != station_from:
                    result[RouteKey(station_from, station)].add(flight_key)
                if station != station_to:
                    result[RouteKey(station, station_to)].add(flight_key)
            # Special case: flight U4 100 KTM-KTM and alike
            if len(current_route) > 1 and station_from == station_to:
                result[RouteKey(station_from, station_to)].add(flight_key)
        segments = segments_cache.get(leg_number)
        if not segments:
            return
        for route_key, mask_value in segments.items():
            mask = mask_value.date_mask
            if mask.is_empty():
                continue
            if not current_route:
                # Root: start a route from this leg, on a private copy of its mask.
                route = [route_key.station_from, route_key.station_to]
                new_mask = self._dmm.new_date_mask()
                self._dmm.add_mask(new_mask, mask)
                self.add_routes_internal(
                    flight_key, leg_number + 1, route, new_mask, mask_value.arrival_shift, segments_cache, result
                )
            else:
                # The next leg must depart from where the route currently ends.
                if current_route[-1] != route_key.station_from:
                    continue
                new_mask = self._dmm.new_date_mask()
                self._dmm.add_mask(new_mask, current_mask)
                # Move the route dates by the larger of the previous leg's arrival shift and this
                # leg's departure shift, then keep only the dates on which this leg operates.
                day_shift = max(mask_value.departure_shift, arrival_shift)
                if day_shift:
                    self._dmm.shift_days(new_mask, day_shift)
                self._dmm.intersect_mask(new_mask, mask)
                if new_mask.is_empty():
                    continue
                # Concatenate instead of list.copy(), which does not exist on Python 2 lists.
                route = current_route + [route_key.station_to]
                self.add_routes_internal(
                    flight_key, leg_number + 1, route, new_mask, mask_value.arrival_shift, segments_cache, result
                )

    def write_to_mem_file(self):
        """Serialize the cache as a stream of TP2PCacheEntry messages into a BytesIO.

        One entry per station pair is written with ``write_binary_string``. This method does not
        rewind the returned buffer.
        """
        mem_entry = io.BytesIO()
        self._logger.info('Dumping p2p-cache')
        routes_count = 0
        flights_count = 0
        for route_key, flights_list in self.generate_p2p_cache():
            p2p_cache_entry = TP2PCacheEntry()
            p2p_cache_entry.DepartureStationId = route_key.station_from
            p2p_cache_entry.ArrivalStationId = route_key.station_to
            for flight in flights_list:
                # add() builds the repeated element in place instead of copying a separate message.
                flight_proto = p2p_cache_entry.Flights.add()
                flight_proto.MarketingCarrierId = flight.carrier
                flight_proto.MarketingFlightNumber = flight.flight_number
                # .get() keeps this read path from inserting into the defaultdict.
                flight_proto.IsCodeshare = not self._is_operating.get(flight, False)
            flights_count += len(flights_list)
            write_binary_string(mem_entry, p2p_cache_entry.SerializeToString())
            routes_count += 1
        self._logger.info(
            'Done with p2p-cache, total routes: %d, total flight entries: %d', routes_count, flights_count
        )
        return mem_entry

    def print_cache(self):
        """Return a human-readable dump of the collected segments (for debugging)."""
        output = []
        for flight_key, flight_value in self._cache.items():
            output.append('Flight: {}'.format(flight_key))
            is_codeshare = not self._is_operating.get(flight_key, False)
            for leg_number, leg_value in flight_value.items():
                output.append('  Leg: {}'.format(leg_number))
                for route_key, mask_value in leg_value.items():
                    output.append('    Route: {}'.format(route_key))
                    output.append('      Arr shift: {}'.format(mask_value.arrival_shift))
                    output.append('      Dep shift: {}'.format(mask_value.departure_shift))
                    output.append('      Is codeshare: {}'.format(is_codeshare))
                    for int_date in self._dmm.get_dates(mask_value.date_mask):
                        output.append('      Date: {}'.format(int_date))
        return '\n'.join(output)
