# coding: utf-8

from __future__ import absolute_import

import logging
from collections import defaultdict

from django.utils.functional import cached_property

from common.models.geo import Station
from common.models.schedule import Route, RThread, RThreadType, TrainSchedulePlan, RTStation
from common.utils.caching import cache_until_switch
from common.utils.date import RunMask
from travel.rasp.admin.scripts.support_methods import esr_leading_zeros

from travel.rasp.admin.scripts.schedule.af_processors.suburban.exceptions import AfSuburbanProcessorError
from travel.rasp.admin.scripts.schedule.af_processors.suburban.utils import str_dates


log = logging.getLogger(__name__)


def get_affected_threads_by_basic_thread(thread, thread_el, today):
    """
    Ищем все нитки, с которыми есть пересечение по дням хождения.
    Нитки идентифицируются по номеру и начальной станции.
    """
    return AffectedThreadsByBasicThreadBuilder(thread, thread_el, today).build()


class AffectedThreadsByBasicThreadBuilder(object):
    def __init__(self, thread, thread_el, today):
        self.thread = thread
        self.thread_el = thread_el
        self.today = today

    def build(self):
        affected_threads = self.find_affected_threads()

        return self.group_by_basic_thread(affected_threads)

    def find_affected_threads(self):
        db_threads = self.find_threads()
        self.set_mask(db_threads)

        affected_threads = [db_thread for db_thread in db_threads if self.thread.mask & db_thread.mask]
        if not affected_threads:
            raise AfSuburbanProcessorError(
                u'Не нашли ни одной нитки number="%s" startstationparent="%s" совпадающей по дням хождения %s',
                self.thread.number, self.start_station_esr, u', '.join(map(str, self.thread.mask.dates())))
        return affected_threads

    @cached_property
    def start_station_esr(self):
        return esr_leading_zeros(self.thread_el.get('startstationparent', '0'))

    @cached_property
    def start_station(self):
        try:
            return Station.get_by_code('esr', self.start_station_esr)
        except Station.DoesNotExist:
            raise AfSuburbanProcessorError(u"Не нашли начальной станции с esr %s", self.start_station_esr)

    def possible_threads_qs(self):
        route_uids = generate_possible_route_uids(self.thread.number, self.thread.supplier, self.start_station)
        route_ids = list(Route.objects.filter(route_uid__in=route_uids).values_list('id', flat=True))
        return RThread.objects.filter(route_id__in=route_ids)

    def get_basic_threads(self):
        basic_db_thread = list(self.possible_threads_qs().filter(type=RThreadType.BASIC_ID))
        if not basic_db_thread:
            raise AfSuburbanProcessorError(u"Не нашли основных ниток с номером %s стартующих со станции %s %s",
                                           self.thread_el.get('number'), self.start_station.id,
                                           self.start_station.title)
        return basic_db_thread

    def find_threads(self):
        db_threads = list(self.possible_threads_qs())
        if not db_threads:
            raise AfSuburbanProcessorError(u"Не нашли ниток с номером %s стартующих со станции %s %s",
                                           self.thread_el.get('number'), self.start_station.id,
                                           self.start_station.title)
        return db_threads

    def group_by_basic_thread(self, affected_threads):
        affected_threads_by_basic_thread = defaultdict(list)

        basic_thread_getter = BasicThreadGetter(affected_threads, self.today)
        for db_thread in affected_threads:
            affected_threads_by_basic_thread[basic_thread_getter(db_thread)].append(
                (db_thread, self.thread.mask & db_thread.mask)
            )

        if len(affected_threads_by_basic_thread) > 1:
            log.warning(u'Изменение number="%s" startstationparent="%s" %s задевает несколько базовых ниток!!!',
                        self.thread.number, self.start_station_esr, str_dates(self.thread.mask.dates()))

        return affected_threads_by_basic_thread

    def set_mask(self, db_threads):
        for db_thread in db_threads:
            set_mask(db_thread, self.today)


class BasicThreadGetter(object):
    """
    RASPADMIN-812 Питоновский объект, возвращаемый в __call__,
    должен совпадать с тем, что лежит в threads
    """

    def __init__(self, threads, today):
        self.threads_by_id = {t.id: t for t in threads}
        self.today = today

    def __call__(self, thread):
        basic_thread_id = thread.basic_thread_id or thread.id
        basic_thread = self.threads_by_id.get(basic_thread_id, RThread.objects.get(id=basic_thread_id))
        set_mask(basic_thread, self.today)

        if basic_thread.type_id != RThreadType.BASIC_ID:
            if basic_thread == thread:
                raise AfSuburbanProcessorError(u"Нитка %s не имеет базовой", thread.uid)
            else:
                raise AfSuburbanProcessorError(u"У нитки %s базовая нитка не имеет базового типа", thread.uid)

        return basic_thread


@cache_until_switch
def _get_schedule_plans():
    return list(TrainSchedulePlan.objects.all())


def generate_possible_route_uids(number, supplier, start_station):
    route_uids = []

    rtstations = [RTStation(station=start_station)]
    for schedule_plan in [None] + _get_schedule_plans():
        temp_thread = RThread(number=number, supplier=supplier)
        temp_thread.rtstations = rtstations
        temp_thread.schedule_plan = schedule_plan
        route_uids.append(temp_thread.gen_route_uid(use_start_station=True, thread_type_check=False))

    return route_uids


def set_mask(db_thread, today):
    if not isinstance(getattr(db_thread, 'mask', None), RunMask):
        db_thread.mask = db_thread.get_mask(today=today)
