# -*- coding: utf-8 -*-

import logging

from django.utils.translation import ugettext as _, gettext_noop as N_

from common.utils.geo import great_circle_distance_km, GreatCircleCalculationError
from travel.rasp.admin.scripts.schedule.utils import RaspImportError


# Module-level logger named after this module.
log = logging.getLogger(name=__name__)


def calculate_geo_distance(rtstations):
    """
    Set ``geo_distance`` on every station of the route, in place.

    The first station gets 0.0. Each following station gets the previous
    successfully measured station's ``geo_distance`` plus the
    ``great_circle_distance_km`` between the two stations.

    When the calculation raises ``GreatCircleCalculationError``, the station
    is left with ``geo_distance = None``. The next station is then measured
    from the last successfully measured one. Finally ``correct_distance``
    fills in the stations that are still invalid.

    An empty route is a no-op.
    """
    if not rtstations:
        # Nothing to measure; rtstations[0] below would raise IndexError.
        return

    for rts in rtstations:
        rts.geo_distance = None

    rtstations[0].geo_distance = 0.0
    last_good_rts = rtstations[0]

    for rts in rtstations[1:]:
        try:
            step_km = great_circle_distance_km(last_good_rts.station, rts.station)
        except GreatCircleCalculationError:
            # Best effort: keep None here, correct_distance() repairs it below.
            continue

        rts.geo_distance = last_good_rts.geo_distance + step_km
        last_good_rts = rts

    correct_distance(rtstations, 'geo_distance')


def calculate_geo_and_combine_distance(rtstations, params):
    """
    Set ``geo_distance`` and ``combine_distance`` on every station, in place.

    ``geo_distance`` is computed by ``calculate_geo_distance``. The first
    station's ``combine_distance`` is 0. For each following station a
    geo-based estimate is formed: the previous ``combine_distance`` plus the
    ``geo_distance`` step from the previous station.

    The station's own ``distance`` (supplier value) replaces the estimate
    when it lies within ``[ratio_s_min * estimate, ratio_s_max * estimate]``.
    If that yields a value smaller than the previous station's
    ``combine_distance``, an info message is logged and the estimate is used.

    :param rtstations: list of route stations; attributes are set on them.
    :param params: mapping with ``'ratio_s_min'`` and ``'ratio_s_max'`` keys.
    :raises RaspImportError: if a station after the first has no
        ``geo_distance`` after ``calculate_geo_distance``.

    An empty route is a no-op.
    """
    if not rtstations:
        # Nothing to compute; rtstations[0] below would raise IndexError.
        return

    calculate_geo_distance(rtstations)

    ratio_min = params['ratio_s_min']
    ratio_max = params['ratio_s_max']

    for rts in rtstations:
        rts.combine_distance = None

    rtstations[0].combine_distance = 0
    prev_rts = rtstations[0]

    for rts in rtstations[1:]:
        if rts.geo_distance is None:
            raise RaspImportError(_(u"Нет geo расстояния для станции {}").format(rts.station.title))

        geo_distance_from_prev_rts = rts.geo_distance - prev_rts.geo_distance
        # Combined distance estimated purely from geography.
        geo_estimate = prev_rts.combine_distance + geo_distance_from_prev_rts
        rts.combine_distance = geo_estimate

        if rts.distance is not None:
            # Accept the supplier distance only if it is plausible relative to the estimate.
            if ratio_min * geo_estimate <= rts.distance <= ratio_max * geo_estimate:
                rts.combine_distance = rts.distance

            if rts.combine_distance < prev_rts.combine_distance:
                log.info(
                    N_(
                        u"Комбинированное расстояние '%s' для станции '%s' получилось меньше, "
                        u"чем расстояние '%s' для предыдущей. Переделываем."
                    ),
                    rts.combine_distance,
                    rts.station.title,
                    prev_rts.combine_distance
                )
                rts.combine_distance = geo_estimate

        prev_rts = rts


def correct_supplier_distance(rtstations):
    """
    Repair the supplier-provided ``distance`` attribute of the stations in place.

    Delegates to ``correct_distance`` for the ``distance`` attribute.
    """
    correct_distance(rtstations, distance_attr='distance')


# Increment added per station when correct_distance() extends the trailing run
# of invalid distances (presumably kilometres, like great_circle_distance_km).
CONST_DISTANCE = 5.0


def correct_distance(rtstations, distance_attr):
    """
    Repair the ``distance_attr`` values of the stations in place.

    ``get_invalids_inside_and_tail`` decides which stations are invalid. Each
    run of invalid stations enclosed by valid ones is filled by equal steps
    between the enclosing values. Each station of the trailing invalid run
    gets the previous station's value plus ``CONST_DISTANCE``.
    """
    def read(rts):
        return getattr(rts, distance_attr)

    def write(index, value):
        setattr(rtstations[index], distance_attr, value)

    inner_ranges, tail = get_invalids_inside_and_tail(rtstations, read)

    for invalid_range in inner_ranges:
        # Split the gap into len + 1 equal steps so the last step lands on the good value.
        step = float(invalid_range.get_all_distance()) / (len(invalid_range) + 1)
        for index in invalid_range.indexes():
            write(index, read(rtstations[index - 1]) + step)

    for index in tail.indexes():
        write(index, read(rtstations[index - 1]) + CONST_DISTANCE)


def get_invalids_inside_and_tail(rtstations, distance_func):
    """
    Find the stations whose distances are invalid.

    A distance is valid when it is set (not None) and not smaller than the
    distance of the last valid station before it.

    :param rtstations: sequence of route stations; ``rts.station.title`` is
        used in log messages.
    :param distance_func: callable returning a station's distance (or None).
    :returns: ``(ranges, tail)`` where:

        - ``ranges`` is a list of ``RTSInvalidRange``, one per run of invalid
          stations followed by a valid one;
        - ``tail`` is an ``RTSInvalidTail`` for the invalid stations after the
          last valid one (empty if there are none).

    Stations without a distance that come before the first valid station are
    in neither result. No valid value precedes them, so they cannot be
    interpolated or extended and are left untouched. Previously they produced
    an index of -1, which Python resolves to the last station.
    """
    invalids_lists = []
    invalids_now = []

    last_good_rts = None

    for i, rts in enumerate(rtstations):
        distance = distance_func(rts)

        if distance is None:
            if last_good_rts is None:
                # No anchor before this station; skip it rather than wrap to rtstations[-1].
                continue
            invalids_now.append(i)

        elif last_good_rts is not None and distance_func(last_good_rts) > distance:
            log.error(
                N_(
                    u"Расстояние '%s' для станции '%s' меньше, "
                    u"чем расстояние '%s' для станции '%s'."
                ),
                distance,
                rts.station.title,
                distance_func(last_good_rts),
                last_good_rts.station.title
            )

            invalids_now.append(i)

        else:
            last_good_rts = rts

            # A valid station closes the current run of invalid ones.
            if invalids_now:
                invalids_lists.append(invalids_now)
                invalids_now = []

    rts_invalid_ranges = [
        RTSInvalidRange(rtstations, invalids[0], invalids[-1], distance_func)
        for invalids in invalids_lists
    ]

    # Whatever is still pending has no valid station after it.
    return rts_invalid_ranges, RTSInvalidTail(rtstations, invalids_now)


class RTSInvalidRange(object):
    """
    Run of stations ``first_index..last_index`` (inclusive) with invalid distances.

    The run is bracketed by the stations at ``good_before`` (``first_index - 1``)
    and ``good_after`` (``last_index + 1``). Their distances bound the
    interpolation.

    NOTE: ``first_index`` is expected to be at least 1. With 0, ``good_before``
    is -1, which Python indexes as the last station.
    """

    def __init__(self, rtstations, first_index, last_index, distance_func):
        self.rtstations = rtstations
        self.first_index = first_index
        self.last_index = last_index

        # Positions of the stations bracketing the run.
        self.good_before, self.good_after = first_index - 1, last_index + 1

        self.good_before_rts = rtstations[self.good_before]
        self.good_after_rts = rtstations[self.good_after]

        self.distance_func = distance_func

    def __len__(self):
        """Return the number of stations in the run."""
        return self.good_after - self.first_index

    def indexes(self):
        """Return the positions of the stations in the run, in ascending order."""
        return xrange(self.first_index, self.good_after)

    def get_all_distance(self):
        """Return the distance difference between the two bracketing stations."""
        distance_of = self.distance_func
        return distance_of(self.good_after_rts) - distance_of(self.good_before_rts)


class RTSInvalidTail(object):
    """
    Trailing run of stations with invalid distances (no valid station after them).

    Only the first element of ``indexes`` is read. It marks where the tail
    starts. An empty ``indexes`` gives an empty tail.
    """

    def __init__(self, rtstations, indexes):
        self.rtstations = rtstations
        # Without invalid stations the tail starts past the end and covers nothing.
        self.first_index = indexes[0] if indexes else len(rtstations)

    def indexes(self):
        """Return the positions from ``first_index`` up to the end of the route."""
        return xrange(self.first_index, len(self.rtstations))
