# coding=utf-8
from __future__ import unicode_literals

import pytz

from datetime import datetime, timedelta
from travel.proto.shared_flights.snapshots.station_with_codes_pb2 import TStationWithCodes
from typing import Dict


class DateShiftCalculator(object):
    """Computes the arrival day shift of a flight leg.

    The shift is the number of calendar days between the local departure date
    and the local arrival date. Both scheduled times are interpreted as
    wall-clock times in the time zone of the respective station.
    """

    def __init__(self, stations: Dict[int, TStationWithCodes], timezones, logger):
        """
        :param stations: station id -> station proto; looked up by the flight's
            ``DepartureStation`` / ``ArrivalStation`` ids.
        :param timezones: mapping from a station's ``Station.TimeZoneId`` to a
            time zone name passed to ``pytz.timezone``.
        :param logger: receives an error message for every rejected flight.
        """
        self.logger = logger
        self.timezones = timezones
        self.stations: Dict[int, TStationWithCodes] = stations

    def calculate_arrival_day_shift(self, flight_pattern_proto, flight_base_proto):
        """Return how many days after the local departure date the flight arrives.

        The arrival is taken to be the first occurrence of the scheduled arrival
        wall-clock time (in the arrival station's time zone) that is strictly
        later than the departure instant. Legs are therefore assumed to last
        more than zero and at most about one day.

        :param flight_pattern_proto: provides ``OperatingFromDate``, the local
            departure date. '-' and '.' separators are stripped, and the
            remaining 8 characters are read as YYYYMMDD.
        :param flight_base_proto: provides ``DepartureStation``/``ArrivalStation``
            ids and ``ScheduledDepartureTime``/``ScheduledArrivalTime`` as HHMM
            integers.
        :returns: the day shift. It may be negative when the arrival zone is far
            enough behind the departure zone. Returns 0 when any input is
            missing or invalid, in which case an error is logged.
        """
        departure_station = self.stations.get(flight_base_proto.DepartureStation)
        if not departure_station:
            self.logger.error('Missing departure station: %s', flight_base_proto)
            return 0

        departure_time_zone_code = self.timezones.get(departure_station.Station.TimeZoneId)
        if not departure_time_zone_code:
            self.logger.error('Missing departure time zone: %s', departure_station)
            return 0

        departure_day = self._parse_date(flight_pattern_proto.OperatingFromDate)
        if departure_day is None:
            self.logger.error('Invalid departure date: %s', flight_pattern_proto)
            return 0

        departure_time = self._parse_hhmm(flight_base_proto.ScheduledDepartureTime)
        if departure_time is None:
            self.logger.error('Invalid departure time: %s', flight_base_proto)
            return 0

        arrival_time = self._parse_hhmm(flight_base_proto.ScheduledArrivalTime)
        if arrival_time is None:
            self.logger.error('Invalid arrival time: %s', flight_base_proto)
            return 0

        arrival_station = self.stations.get(flight_base_proto.ArrivalStation)
        if not arrival_station:
            self.logger.error('Missing arrival station: %s', flight_base_proto)
            return 0

        arrival_time_zone_code = self.timezones.get(arrival_station.Station.TimeZoneId)
        if not arrival_time_zone_code:
            self.logger.error('Missing arrival time zone: %s', arrival_station)
            return 0

        arrival_tz = pytz.timezone(arrival_time_zone_code)
        departure = pytz.timezone(departure_time_zone_code).localize(
            departure_day.replace(hour=departure_time[0], minute=departure_time[1])
        )
        # Arrival wall-clock time on the departure date. Candidates on other
        # dates are derived from it by whole-day offsets.
        naive_arrival = departure_day.replace(hour=arrival_time[0], minute=arrival_time[1])

        def arrival_at(day_offset):
            # Localize each candidate separately. Adding a timedelta to an
            # already-localized pytz datetime keeps its old UTC offset, which
            # is wrong when a DST transition lies in between.
            return arrival_tz.localize(naive_arrival + timedelta(days=day_offset))

        # Find the smallest day offset whose arrival candidate is strictly
        # after departure (aware datetimes compare as absolute instants).
        day_offset = 0
        while arrival_at(day_offset) <= departure:
            day_offset += 1
        while arrival_at(day_offset - 1) > departure:
            day_offset -= 1

        # Candidate dates are departure_day + day_offset, so the offset is the
        # shift between the local arrival and local departure dates.
        return day_offset

    @staticmethod
    def _parse_date(raw_date):
        """Parse an operating date into a naive midnight datetime, or None if invalid.

        '-' and '.' separators are removed. The rest must be exactly 8
        characters, read as YYYYMMDD, forming a real calendar date.
        """
        compact = raw_date.replace('-', '').replace('.', '')
        if len(compact) != 8:
            return None
        try:
            return datetime(int(compact[0:4]), int(compact[4:6]), int(compact[6:8]))
        except ValueError:
            # Non-digit characters or an impossible date (e.g. month 13, Feb 30).
            return None

    @staticmethod
    def _parse_hhmm(raw_time):
        """Split an HHMM integer into (hour, minute), or return None if out of range.

        Hours must be 0-23 and minutes 0-59. In particular, 2400 is rejected
        because ``datetime`` cannot represent hour 24.
        """
        if raw_time < 0:
            return None
        hours, minutes = divmod(raw_time, 100)
        if hours > 23 or minutes > 59:
            return None
        return hours, minutes
