# coding: utf-8

from datetime import timedelta, datetime

from common.models.schedule import RThread
from common.models.tariffs import ThreadTariff
from travel.rasp.library.python.common23.date import environment
from common.utils.date import timedelta2minutes, RunMask
from travel.rasp.admin.lib.logs import get_current_file_logger
from travel.rasp.admin.scripts.schedule.utils.times_approximation import fill_middle_times, fill_last_stations_using_middle_speed
from travel.rasp.admin.scripts.schedule.utils.route_compare import get_path_relation
from travel.rasp.admin.www.utils.mysql import fast_delete_tariffs_by_uids


log = get_current_file_logger()


class DuplicateCleaner(object):
    """Merge threads of the given routes whose station path is contained in a longer thread.

    A thread is merged into a "bigger" one when its path is a sub-path of (or equal to)
    the bigger thread's path, its start time matches the bigger thread's time at the
    shared start station (within ``max_sub_start_diff_from_other_station`` minutes) and
    their run days intersect. Merged days are removed from the smaller thread, which is
    deleted once it has no run days left.
    """

    def __init__(self, routes, tariffs=None, distances=None, middle_speed=50):
        self.middle_speed = middle_speed
        self.routes = routes
        # Mapping thread uid -> {tariff variant: mask}; updated in place as threads merge.
        self.tariffs = tariffs or {}
        # Mapping thread uid -> sequence of per-station distances (indexed like rtstations).
        self.distances = distances or {}
        self.remove_full_duplicates_flag = True
        # A sub-thread counts as a duplicate only if its start differs by at most 15 minutes
        # from the bigger thread's time at the same station.
        self.max_sub_start_diff_from_other_station = 15
        self.override_fuzzy_times_from_smaller_thread = True

    def clean(self):
        """Merge every thread of ``self.routes`` into longer threads it duplicates.

        Threads are processed from the shortest station path to the longest; each one is
        tried against every longer (or equally long) thread. A thread whose run days are
        all merged away is deleted.
        """
        log.info(u"Удаляем дубликаты")

        threads = list(RThread.objects.filter(route__in=self.routes))

        for thread in threads:
            thread.rtstations = list(thread.rtstation_set.all().select_related('station').order_by('id'))
            # NOTE(review): `real_fuzzy` is not read in this file; presumably consumed elsewhere.
            for rts in thread.rtstations[1:-1]:
                rts.real_fuzzy = True

            thread.station_path = [rts.station for rts in thread.rtstations]

        threads.sort(key=lambda t: len(t.station_path))

        while threads:
            thread = threads.pop(0)

            for bigger_thread in threads:
                relation = get_path_relation(thread.station_path, bigger_thread.station_path)

                if relation.code in (relation.SUB_PATH, relation.EQUAL):
                    merged = self.merge_threads(thread, bigger_thread)

                    # `thread.mask` is only assigned by a successful merge_threads call.
                    if merged and not thread.mask:
                        log.info(u"У нитки %s не осталось дней хождений удаляем", thread.uid)

                        thread.delete()
                        break
                else:
                    log.debug(u"Нитки %s %s относятся как %s пропускаем", thread.id, bigger_thread.id, relation)

    def merge_threads(self, thread, bigger_thread):
        """Try to merge ``thread`` into ``bigger_thread``, whose station path contains it.

        On success the bigger thread's station times are refined from the smaller
        thread and saved. If the smaller thread has tariff variants, the bigger thread's
        tariffs are regenerated from the union of both threads' variants. The days served
        by the bigger thread are removed from ``thread.mask``; the thread is saved if some
        days remain.

        :returns: True if the threads were merged, False if the pair was skipped.
        """
        # Local import; presumably avoids an import cycle — TODO confirm.
        from travel.rasp.admin.importinfo.two_stage_import.tariffs import gen_tariffs_from_variants

        log.debug(u"Пробуем слить нитки %s %s, %s %s", thread.title, thread.uid, bigger_thread.title,
                  bigger_thread.uid)

        sub_start_station = thread.station_path[0]

        bigger_mask = RunMask(bigger_thread.year_days, today=environment.today())
        mask = RunMask(thread.year_days, today=environment.today())

        sub_start_rts, sub_start_index = self.get_sub_start(sub_start_station, bigger_thread)
        if sub_start_rts is None:
            # The path relation should guarantee this station exists; be defensive anyway.
            return False

        bigger_dates = bigger_mask.dates()
        if not bigger_dates:
            # Without a run date the bigger thread's start cannot be anchored.
            log.debug(u"У нитки %s нет дней хождения, пропускаем", bigger_thread.uid)
            return False

        bigger_start_dt = datetime.combine(bigger_dates[0], bigger_thread.start_time)

        sub_start_dt = bigger_start_dt + timedelta(minutes=sub_start_rts.departure)
        smaller_start_dt = datetime.combine(sub_start_dt.date(), thread.start_time)

        sub_start_diff = abs(timedelta2minutes(smaller_start_dt - sub_start_dt))
        if sub_start_diff > self.max_sub_start_diff_from_other_station:
            log.debug(u"Разница старта поднитки больше %s минут пропускаем", self.max_sub_start_diff_from_other_station)
            return False

        elif sub_start_index == 0 and sub_start_diff > 0:
            log.debug(u"Нитки странтуют с одной станции, но время старта не совпадает, пропускаем")
            return False

        # Days between the bigger thread's departure and its arrival at the shared station;
        # `mask.shifted(shift)` expresses the smaller thread's days in the bigger one's frame.
        shift = (sub_start_dt.date() - bigger_start_dt.date()).days
        sub_mask = mask.shifted(shift)
        if not (sub_mask & bigger_mask):
            log.debug(u"Маски не пересекаются пропускаем")
            return False

        log.info(u"Сливаем нитки %s %s , %s %s", thread.title, thread.uid, bigger_thread.title,
                 bigger_thread.uid)

        # Align the shared station's departure with the smaller thread's exact start time,
        # keeping the original stop duration when the arrival was only an estimate.
        sub_stop_duration = sub_start_rts.departure - (sub_start_rts.arrival or 0)
        sub_start_rts.departure = timedelta2minutes(smaller_start_dt - bigger_start_dt)
        if sub_start_rts.is_fuzzy and sub_start_index != 0:
            sub_start_rts.arrival = sub_start_rts.departure - sub_stop_duration
            sub_start_rts.is_fuzzy = False

        self._copy_exact_times(thread, bigger_thread, sub_start_index, sub_start_rts.departure)

        # Drop the remaining estimated times so they are recomputed below.
        # NOTE(review): raises KeyError when no distances were supplied for this uid.
        for rts, distance in zip(bigger_thread.rtstations, self.distances[bigger_thread.uid]):
            if rts.is_fuzzy:
                rts.departure = None
                rts.arrival = None

            rts.distance = distance

        fill_last_stations_using_middle_speed(bigger_thread.rtstations, self.middle_speed)

        fill_middle_times(bigger_thread.rtstations, True)

        united_tariff_variants = self._unite_tariff_variants(thread, bigger_thread)

        for rts in bigger_thread.rtstations:
            rts.save()

        if united_tariff_variants:
            fast_delete_tariffs_by_uids([bigger_thread.uid])

            tariffs = []
            for variant, variant_mask in united_tariff_variants.items():
                tariffs.extend(gen_tariffs_from_variants(bigger_thread, variant, variant_mask))

            ThreadTariff.objects.bulk_create(tariffs)

        log.info(u"Смержили нитку %s %s в %s %s, очищаем ее",
                 thread.uid, thread.title, bigger_thread.uid, bigger_thread.title)

        thread.mask = self._remaining_mask(mask, bigger_mask, shift)
        if thread.mask:
            thread.year_days = str(thread.mask)
            thread.save()

        return True

    def _copy_exact_times(self, thread, bigger_thread, sub_start_index, departure_offset):
        """Copy exact intermediate times of ``thread`` onto fuzzy stations of ``bigger_thread``.

        The smaller thread's times are relative to its own start, so ``departure_offset``
        (the bigger thread's departure minute at the shared start station) is added.
        The first and last stations of the overlap are skipped.
        """
        if not self.override_fuzzy_times_from_smaller_thread:
            return

        intersection_length = len(thread.rtstations)
        # Slice bounds are absolute indices into the bigger thread, so offset them by where
        # the overlap starts; the end excludes the overlap's last station, whose departure
        # in the smaller thread is presumably empty.
        bigger_overlap = bigger_thread.rtstations[sub_start_index + 1:sub_start_index + intersection_length - 1]

        for smaller_rts, bigger_rts in zip(thread.rtstations[1:], bigger_overlap):
            if not smaller_rts.is_fuzzy and bigger_rts.is_fuzzy:
                bigger_rts.departure = smaller_rts.departure + departure_offset
                bigger_rts.arrival = smaller_rts.arrival + departure_offset
                bigger_rts.is_fuzzy = False

    def _unite_tariff_variants(self, thread, bigger_thread):
        """Return the union of both threads' tariff variants, or None if ``thread`` has none.

        Masks of variants present in both threads are OR-ed. The result is stored back as
        the bigger thread's variants so later merges into it keep these tariffs; the
        smaller thread's own mapping is left untouched.
        """
        smaller_variants = self.tariffs.get(thread.uid)
        if not smaller_variants:
            return None

        united = dict(self.tariffs.get(bigger_thread.uid) or {})
        for variant, variant_mask in smaller_variants.items():
            if variant in united:
                united[variant] = united[variant] | variant_mask
            else:
                united[variant] = variant_mask

        self.tariffs[bigger_thread.uid] = united
        return united

    @staticmethod
    def _remaining_mask(mask, bigger_mask, shift):
        """Return the days of ``mask`` that the bigger thread does not take over.

        ``mask.shifted(shift)`` maps the smaller thread's days onto the bigger thread's
        frame (as in the intersection check); the covered days are mapped back with
        ``shifted(-shift)`` before being subtracted.
        """
        merged_days = (mask.shifted(shift) & bigger_mask).shifted(-shift)
        return mask - merged_days

    def get_sub_start(self, sub_start_station, bigger_thread):
        """Return ``(rts, index)`` of the first stop of ``bigger_thread`` at ``sub_start_station``.

        Returns ``(None, None)`` when the station is not on the bigger thread's path.
        """
        for index, rts in enumerate(bigger_thread.rtstations):
            if rts.station == sub_start_station:
                return rts, index

        return None, None
