# coding: utf-8

from __future__ import absolute_import

import logging
import pytz
from datetime import timedelta, datetime

from common.models.schedule import RThread
from stationschedule.models import ZTablo2

from travel.rasp.admin.scripts.schedule.af_processors.suburban.affected_threads_finder import get_affected_threads_by_basic_thread


# Module-level logger named after this module.
log = logging.getLogger(__name__)


def apply_changemode_delay(thread, thread_el, today, **kwargs):
    """Propagate delay/cancellation data of a change-mode thread to DB threads.

    Target threads are either the single ``RThread`` whose uid equals
    ``thread.uid`` or, when the thread has no uid, every thread returned by
    ``get_affected_threads_by_basic_thread``. If the uid lookup finds nothing,
    the miss is logged and the function returns without changes.

    For each target thread, per-station data (``thread.raw_stations``) takes
    priority. Otherwise a whole-thread ``delay_type`` of ``'cancel'`` or
    ``'delay'`` on ``thread_el`` is applied. Extra keyword arguments are
    accepted and ignored.
    """
    if not thread.uid:
        # No uid: fan out to every DB thread affected through its basic thread.
        affected_groups = get_affected_threads_by_basic_thread(thread, thread_el, today)
        db_threads = [
            affected
            for pairs in affected_groups.values()
            for affected, _mask in pairs
        ]
    else:
        try:
            db_threads = [RThread.objects.get(uid=thread.uid)]
        except RThread.DoesNotExist:
            log.info(u'Thread was not found by uid %s', thread.uid)
            return

    run_dates = thread.mask.dates(past=True)

    for target in db_threads:
        if thread.raw_stations:
            apply_delay_by_station(target, run_dates, thread.raw_stations, today)
            continue

        if thread_el.get('delay_type') in ('cancel', 'delay'):
            apply_thread_delay(target, thread_el, run_dates, today=today)


def apply_delay_by_station(db_thread, dates, raw_stations, today):
    """Apply per-station delays/cancellations to every running day of a thread.

    ``raw_stations`` is a sequence of ``(station, station_el)`` pairs in route
    order. For each date in ``dates`` on which the thread runs, the thread path
    is walked and matched against these pairs sequentially. A station that
    occurs more than once on the path is therefore matched in order.

    For a matched stop, ``station_el.get('delay_type')`` decides the action:

    * ``'delay'`` applies ``delay_arrival`` / ``delay_departure`` (minutes).
    * ``'cancel'`` cancels both arrival and departure.
    * Anything else, including a missing attribute, clears a previously
      stored delay.

    :param db_thread: thread model instance providing ``path``,
        ``tz_start_time`` and ``get_mask``.
    :param dates: dates to process; dates outside the thread's mask are skipped.
    :param raw_stations: ordered ``(station, station_el)`` pairs; may be empty,
        in which case nothing is done.
    :param today: reference date passed to ``db_thread.get_mask``.
    :raises ValueError: if ``delay_arrival`` / ``delay_departure`` is not an int.
    """
    def get_comments(station_el):
        # Strip only the leading 'delay_' prefix: 'delay_comment_x' -> 'comment_x'.
        # (str.replace would also drop any later 'delay_' occurrence.)
        prefix_len = len('delay_')
        return {
            attr[prefix_len:]: value
            for attr, value in station_el.attrib.items()
            if attr.startswith('delay_comment')
        }

    # Nothing to match against; previously this crashed with IndexError.
    if not raw_stations:
        return

    mask = db_thread.get_mask(today=today)

    for day in dates:
        if day not in mask:
            continue

        naive_start_dt = datetime.combine(day, db_thread.tz_start_time)

        # Apply changes to the stations in route order.
        pending = iter(raw_stations)
        next_station, next_station_el = next(pending)

        for rts in db_thread.path:
            if not rts.station == next_station:
                continue

            station_el = next_station_el
            delay_type = station_el.get('delay_type', 'none')

            if delay_type == 'delay':
                apply_rts_delay(naive_start_dt, db_thread, rts,
                                arrival_delay=int(station_el.get('delay_arrival', '0')),
                                departure_delay=int(station_el.get('delay_departure', '0')),
                                comments=get_comments(station_el))
            elif delay_type == 'cancel':
                apply_rts_delay(naive_start_dt, db_thread, rts,
                                arrival_cancelled=True, departure_cancelled=True,
                                comments=get_comments(station_el))
            else:
                # Unknown or absent delay type: remove any stored delay.
                apply_rts_delay(naive_start_dt, db_thread, rts, arrival_delay=0, departure_delay=0)

            try:
                next_station, next_station_el = next(pending)
            except StopIteration:
                # Every raw station has been matched for this day.
                break


def apply_thread_delay(db_thread, thread_el, dates, today, raw_stations=None):
    """Apply a whole-thread delay or cancellation to every stop of ``db_thread``.

    :param db_thread: thread model instance providing ``path``,
        ``tz_start_time`` and ``get_mask``.
    :param thread_el: element carrying ``delay_type`` (``'delay'`` or
        ``'cancel'``), ``delay_value`` (minutes; required for ``'delay'``) and
        optional ``delay_comment*`` attributes.
    :param dates: dates to process; dates outside the thread's mask are skipped.
    :param today: reference date passed to ``db_thread.get_mask``.
    :param raw_stations: unused; kept only for backward compatibility.
        Defaults to ``None`` instead of a shared mutable list.
    :raises ValueError: if ``delay_type`` is neither ``'delay'`` nor ``'cancel'``.
    """
    delay_type = thread_el.get('delay_type')

    # Validate once and precompute the per-stop arguments, so the loop below
    # does not need to branch on delay_type again.
    if delay_type == 'delay':
        delay_value = int(thread_el.get('delay_value'))
        rts_kwargs = {'arrival_delay': delay_value, 'departure_delay': delay_value}
    elif delay_type == 'cancel':
        delay_value = None
        rts_kwargs = {'arrival_cancelled': True, 'departure_cancelled': True}
    else:
        raise ValueError('Invalid delay type %s' % delay_type)

    # Strip only the leading 'delay_' prefix: 'delay_comment_x' -> 'comment_x'.
    # (str.replace would also drop any later 'delay_' occurrence.)
    prefix_len = len('delay_')
    comments = {
        attr[prefix_len:]: value
        for attr, value in thread_el.attrib.items()
        if attr.startswith('delay_comment')
    }

    mask = db_thread.get_mask(today=today)

    for day in dates:
        if day not in mask:
            continue

        naive_start_dt = datetime.combine(day, db_thread.tz_start_time)

        for rts in db_thread.path:
            log.info(u'Применяем изменения для станции %s delay=%s type=%s',
                     rts.station.title, delay_value, delay_type)
            apply_rts_delay(naive_start_dt, db_thread, rts, comments=comments, **rts_kwargs)


def apply_rts_delay(naive_start_dt, db_thread, rts,
                    arrival_delay=None, departure_delay=None,
                    arrival_cancelled=False, departure_cancelled=False,
                    comments=None):
    """Replace the ZTablo2 record for one stop of one thread run.

    First, existing ZTablo2 rows for ``rts`` are deleted if their original
    time lies within +/- 3 hours of this run's scheduled time. The departure
    time is used when known, otherwise the arrival time.

    Then, if the call carries an actual change (a non-zero delay, a
    cancellation, or comments), a new ZTablo2 row describing it is created.
    Otherwise the stop is left without a record, which removes a previously
    stored delay.

    :param naive_start_dt: naive local start datetime of the thread run.
    :param db_thread: thread model instance the stop belongs to.
    :param rts: path stop providing scheduled arrival/departure for the run.
    :param arrival_delay: arrival delay in minutes; falsy means no delay.
    :param departure_delay: departure delay in minutes; falsy means no delay.
    :param arrival_cancelled: mark the arrival at this stop as cancelled.
    :param departure_cancelled: mark the departure from this stop as cancelled.
    :param comments: mapping of ZTablo2 attribute name -> value to set.
    """
    original_arrival = rts.get_loc_arrival_dt(naive_start_dt)
    original_departure = rts.get_loc_departure_dt(naive_start_dt)

    # Times are filtered on and stored as naive datetimes.
    if original_arrival:
        original_arrival = original_arrival.replace(tzinfo=None)
    if original_departure:
        original_departure = original_departure.replace(tzinfo=None)

    # Departure wins as the lookup anchor when both times are known.
    if original_departure:
        anchor_field, anchor_dt = 'original_departure', original_departure
    elif original_arrival:
        anchor_field, anchor_dt = 'original_arrival', original_arrival
    else:
        # Without any scheduled time there is no way to locate or describe a
        # tablo record (this used to fail with UnboundLocalError).
        log.warning(u'Stop %s has neither arrival nor departure time, skipping', rts.station)
        return

    window = timedelta(hours=3)
    range_filter = {anchor_field + '__range': (anchor_dt - window, anchor_dt + window)}
    ZTablo2.objects.filter(**range_filter).filter(rtstation=rts).delete()

    # No changes: the stale record (if any) is already gone, nothing to create.
    if not (
        arrival_delay or departure_delay or
        arrival_cancelled or departure_cancelled or
        comments
    ):
        return

    z_tablo = ZTablo2()
    z_tablo.thread = db_thread
    # Link the row to its stop so the rtstation-based cleanup above can find
    # and replace it on the next run.
    z_tablo.rtstation = rts
    z_tablo.route = db_thread.route
    z_tablo.station = rts.station
    z_tablo.platform = rts.platform
    z_tablo.company_id = db_thread.company_id

    z_tablo.route_id = db_thread.route.id
    z_tablo.route_uid = db_thread.route.route_uid
    z_tablo.number = db_thread.number

    z_tablo.t_type_id = db_thread.t_type_id
    z_tablo.title = db_thread.title

    z_tablo.original_arrival = original_arrival
    z_tablo.original_departure = original_departure
    z_tablo.arrival = z_tablo.original_arrival
    z_tablo.departure = z_tablo.original_departure

    z_tablo.arrival_cancelled = arrival_cancelled
    z_tablo.departure_cancelled = departure_cancelled

    z_tablo.start_datetime = naive_start_dt
    # Localize the naive start in the thread's timezone, then store it as naive UTC.
    z_tablo.utc_start_datetime = db_thread.pytz.localize(naive_start_dt).\
        astimezone(pytz.UTC).replace(tzinfo=None)

    z_tablo.is_fuzzy = rts.is_fuzzy

    if original_arrival and arrival_delay:
        z_tablo.real_arrival = original_arrival + timedelta(minutes=arrival_delay)

    if original_departure and departure_delay:
        z_tablo.real_departure = original_departure + timedelta(minutes=departure_delay)

    if comments:
        for attr, value in comments.items():
            setattr(z_tablo, attr, value)

    z_tablo.save()
