# coding: utf-8

import datetime

from common.utils.date import timedelta2minutes
from travel.rasp.admin.timecorrection.utils import n_wize


class ThreadDurationManipulation(object):
    """
    Helper for computing moving and stop durations along a thread's path.

    The naive start datetime of the thread is computed once in the
    constructor; the departure datetime of the first path station is looked
    up lazily and cached on the instance.
    """

    def __init__(self, thread):
        self.thread = thread
        self._naive_start_dt = self.get_naive_start_dt(thread)
        # Cached by get_arrival_shift_in_current_tz on its first timezone conversion.
        self._first_rts_departure_dt = None

    def get_shift_and_stop_between_rts(self, rts_from, rts_to):
        """
        Return ``(moving_time, stop_time)`` for the leg rts_from -> rts_to.

        ``stop_time`` is the stop on rts_from (``tz_departure - tz_arrival``,
        or 0 when rts_from has no arrival); ``moving_time`` is the travel time
        to rts_to as produced by ``timedelta2minutes``.

        A stop equal to exactly 1 is reported as 0 and is left inside the
        moving time instead of being subtracted from it.

        :type rts_to: common.models.schedule.RTStation
        :type rts_from: common.models.schedule.RTStation
        """
        start_dt = self._naive_start_dt

        if rts_from.tz_arrival is None:
            # No arrival on rts_from: measure from its departure, no stop.
            elapsed = rts_to.get_arrival_dt(start_dt) - rts_from.get_departure_dt(start_dt)
            stop_minutes = 0
        else:
            # Measure arrival to arrival; the stop on rts_from is split off below.
            elapsed = rts_to.get_arrival_dt(start_dt) - rts_from.get_arrival_dt(start_dt)
            # NOTE(review): tz_arrival / tz_departure are presumably minute offsets - confirm.
            stop_minutes = rts_from.tz_departure - rts_from.tz_arrival

        elapsed_minutes = timedelta2minutes(elapsed)

        if stop_minutes == 1:
            # One-unit stops are folded into the moving time.
            return elapsed_minutes, 0

        return elapsed_minutes - stop_minutes, stop_minutes

    def get_arrival_shift_in_current_tz(self, shift, rts):
        """
        Return the time in minutes from the thread start, corrected for the
        station's timezone.

        When the station timezone differs from the thread's, the arrival
        moment (first-station departure + ``shift``) is converted to the
        station timezone, its wall-clock value is re-labelled with the
        timezone of the first departure, and the difference to the first
        departure is returned.

        :param shift: time from the thread's first station, in the start timezone
        :type rts: common.models.schedule.RTStation

        Example: Start(10:00 GMT+2), shift 60 min, Arrival(10:00 GMT+1)
        returns 0.
        """
        if self.thread.pytz == rts.pytz:
            return shift

        first_departure = self._first_rts_departure_dt
        if first_departure is None:
            first_station = self.thread.path.first()
            first_departure = first_station.get_departure_dt(self._naive_start_dt)
            self._first_rts_departure_dt = first_departure

        moment = first_departure + datetime.timedelta(minutes=shift)
        local_moment = moment.astimezone(rts.pytz)
        # Keep the station-local wall clock but compare it as if it were in
        # the start timezone.
        relabelled = local_moment.replace(tzinfo=first_departure.tzinfo)
        return timedelta2minutes(relabelled - first_departure)

    @staticmethod
    def get_naive_start_dt(thread):
        """
        Combine the date returned by ``thread.first_run`` for today's date
        with ``thread.tz_start_time`` into a single datetime.
        """
        today = datetime.datetime.now().date()
        run_date = thread.first_run(today)
        return datetime.datetime.combine(run_date, thread.tz_start_time)

    @staticmethod
    def calc_timedelta_from_start(start_rts, current_rts, naive_start_dt):
        """
        Return the timedelta between the departure from ``start_rts`` and the
        arrival at ``current_rts``; return None when both are equal.
        """
        if start_rts != current_rts:
            arrival_dt = current_rts.get_arrival_dt(naive_start_dt)
            departure_dt = start_rts.get_departure_dt(naive_start_dt)
            return arrival_dt - departure_dt
        return None

    @classmethod
    def get_supplier_duration_list(cls, thread):
        """
        Return a lazy generator of moving times, one for each station pair
        yielded by ``n_wize(thread.path)``.
        """
        manipulator = cls(thread)
        station_pairs = n_wize(thread.path)
        return (
            manipulator.get_shift_and_stop_between_rts(prev_rts, next_rts)[0]
            for prev_rts, next_rts in station_pairs
        )
