# coding: utf-8
from __future__ import absolute_import, unicode_literals

import logging
from operator import attrgetter

from django.db.models import Q

from travel.rasp.admin.scripts.suburban_stops import walk_thread
from travel.rasp.admin.scripts.z_station_schedule import AFSuburbanProcessor


log = logging.getLogger(__name__)


def update_station_schedule(thread, recount_schedule_on_the_fly):
    """Re-run the suburban station-schedule processor for a single thread.

    Does nothing unless ``recount_schedule_on_the_fly`` is truthy.  Otherwise
    runs an ``AFSuburbanProcessor`` whose thread filter is additionally
    restricted to ``thread`` (by primary key), then calls ``walk_thread`` for
    the same thread with ``recount_schedule_on_the_fly=True``.

    Always returns ``None``.
    """
    if recount_schedule_on_the_fly:
        class TmpProcessor(AFSuburbanProcessor):
            def get_thread_filter(self):
                # Keep the processor's own filter, but AND it with this
                # thread's pk so only this one thread is processed.
                base_filter = super(TmpProcessor, self).get_thread_filter()
                return Q(pk=thread.id) & base_filter

        TmpProcessor().run()
        walk_thread(thread, recount_schedule_on_the_fly=True)


def str_dates(dates):
    """Return the ``str()`` form of every item in *dates*, joined by ', '."""
    return ', '.join(str(date_value) for date_value in dates)


# Route-station attributes compared pairwise, stop by stop, in threads_isequal().
ATTRS_TO_COMPARE_RTSTATIONS = [
    'station_id',
    'tz_arrival',
    'tz_departure',
    'platform',
    'is_combined',
    'is_searchable_from',
    'is_searchable_to',
    'in_station_schedule',
]

# Built once at import time instead of on every threads_isequal() call.
_RTS_ATTR_GETTERS = [attrgetter(a) for a in ATTRS_TO_COMPARE_RTSTATIONS]


def threads_isequal(thread1, thread2):
    """Return True if two threads have the same timing and the same stops.

    The threads are considered equal when all of the following hold:

    * their ``tz_start_time`` values are equal;
    * their ``is_combined`` values are equal;
    * their stop sequences have the same length;
    * the stops agree pairwise on every attribute in
      ``ATTRS_TO_COMPARE_RTSTATIONS``.

    A thread's stop sequence is its ``rtstations`` attribute when that exists
    and is non-empty, otherwise its ``path``.
    """
    if thread1.tz_start_time != thread2.tz_start_time:
        return False

    if thread1.is_combined != thread2.is_combined:
        return False

    # Materialise once so len() works whatever iterable ``path`` returns.
    rtstations1 = list(getattr(thread1, 'rtstations', None) or thread1.path)
    rtstations2 = list(getattr(thread2, 'rtstations', None) or thread2.path)

    # zip() stops at the shorter sequence.  Without this check, a thread whose
    # stops are a strict prefix of the other's would be reported as equal.
    if len(rtstations1) != len(rtstations2):
        return False

    for rts1, rts2 in zip(rtstations1, rtstations2):
        if any(ag(rts1) != ag(rts2) for ag in _RTS_ATTR_GETTERS):
            return False

    return True


def copy_thread_attributes(new_thread, thread):
    """Copy a fixed set of attributes from *thread* onto *new_thread*.

    Copies ``t_type``, ``supplier``, ``number``, ``company``,
    ``express_type``, ``express_lite``, ``schedule_plan`` and ``t_subtype``.
    Each attribute is read and assigned in turn.  Nothing is saved, and
    ``None`` is returned.
    """
    copied_names = (
        't_type', 'supplier', 'number', 'company',
        'express_type', 'express_lite', 'schedule_plan', 't_subtype',
    )
    for name in copied_names:
        value = getattr(thread, name)
        setattr(new_thread, name, value)


def path_is_equal(thread, db_thread):
    """Return True if both threads visit the same stations in the same order.

    Compares ``thread.rtstations`` with ``db_thread.path``.  The sequences
    must have equal length, and each pair of positions must have equal
    ``station`` values.
    """
    parsed_stops = thread.rtstations
    stored_stops = db_thread.path

    if len(parsed_stops) != len(stored_stops):
        return False

    return not any(
        parsed.station != stored.station
        for parsed, stored in zip(parsed_stops, stored_stops)
    )


def copy_path(thread, db_thread):
    """Copy identity and direction fields from stored stops onto parsed ones.

    Pairs ``thread.rtstations`` with ``db_thread.path`` position by position;
    pairing stops at the shorter sequence.  For each pair, the stored stop's
    ``id``, ``thread_id``, direction ids, subdirs and ``is_from_subdir`` are
    assigned to the parsed stop, which is then saved.
    """
    field_names = (
        'id', 'thread_id', 'arrival_direction_id', 'departure_direction_id',
        'arrival_subdir', 'departure_subdir', 'is_from_subdir',
    )
    for parsed_stop, stored_stop in zip(thread.rtstations, db_thread.path):
        for field_name in field_names:
            field_value = getattr(stored_stop, field_name)
            setattr(parsed_stop, field_name, field_value)

        # Copying ``id`` onto the parsed stop presumably makes this save()
        # update the existing row rather than insert — NOTE(review): confirm.
        parsed_stop.save()
