# coding: utf-8

import logging
import operator
import signal
from itertools import izip
from functools import partial

from common.utils.exceptions import SimpleUnicodeException
from travel.rasp.admin.timecorrection.duration_processing import ThreadDurationManipulation
from travel.rasp.admin.timecorrection.models_data_access import PathSpanProxy
from travel.rasp.admin.timecorrection.utils import accumulate, Constants

MIN_SPEED = 40 / Constants.MINUTES_IN_HOUR  # в минуту
MAX_SPEED = 90 / Constants.MINUTES_IN_HOUR

LENGTH_PENALTY = 100000
INVALID_SPAN_PENALTY = 10000
STATION_FUZZY_PENALTY = 1000

log = logging.getLogger(__name__)


def get_compare_diff_array(path_span_list, supplier_duration_list):
    """
    Возвращает массив чисел, где каждое число является количеством минут на которое необходимо откорректировать
    соответствующую, по порядковому номеру RTS
    :param supplier_duration_list:
    :param path_span_list:
    :return: diff-array,  len(diff-array) == len(thread.path)
    """

    diff_array = [0, ]
    for path_span, supplier_duration in izip(path_span_list, supplier_duration_list):
        if path_span.is_duration_need_correct(supplier_duration):
            diff_array.append(path_span.duration - supplier_duration)
        else:
            diff_array.append(0)

    return accumulate(diff_array)


def is_time_real(distance, duration):
    """
    Проверка реалистичности времени для данного расстояния
    :param distance: расстояние в километрах
    :param duration: время в минутах
    :return: bool
    """

    if duration == 0:
        return distance == 0
    current_speed = distance / float(duration)

    return MIN_SPEED <= current_speed <= MAX_SPEED


def objective_function(path_span_list, supplier_duration_list, new_supplier_duration_list,
                       use_station_fuzzy_penalty=False):
    """
    Целевая функция корректировки get_yaa_diff_array && get_rsa_diff_array
    :param use_station_fuzzy_penalty: Штраф за признак корректировки у станции
    :param path_span_list:
    :param supplier_duration_list:
    :param new_supplier_duration_list:
    :return:
    """

    result = 0
    if sum(supplier_duration_list) != sum(new_supplier_duration_list):
        result += LENGTH_PENALTY
    for supplier_duration, path_span in izip(new_supplier_duration_list, path_span_list):
        result += path_span.is_duration_need_correct(supplier_duration) * INVALID_SPAN_PENALTY

    if use_station_fuzzy_penalty:
        need_correction_flags = [path_span.is_duration_need_correct(supplier_duration) for path_span, supplier_duration
                                 in izip(path_span_list, supplier_duration_list)]

        for old_duration, new_duration, flag in izip(accumulate(supplier_duration_list),
                                                     accumulate(new_supplier_duration_list),
                                                     need_correction_flags):
            diff = abs(old_duration - new_duration)
            result += diff
            if not flag and round(diff):
                result += STATION_FUZZY_PENALTY
    else:
        for old_duration, new_duration in izip(accumulate(supplier_duration_list),
                                               accumulate(new_supplier_duration_list)):
            result += abs(old_duration - new_duration)

    return result


def yac(supplier_duration_array, path_span_list, target_function):
    """Yet another correction
    Корректирует нитку восстанавливая время между станциями до минимально возможного.
    Отклонение не делится.
    """
    best_array = supplier_duration_array
    best_target_function_result = target_function(path_span_list, supplier_duration_array, best_array)

    for n, path_span in enumerate(path_span_list):
        if not path_span.is_duration_need_correct(best_array[n]):
            continue

        work_array = best_array[:]
        delta = path_span.get_min_duration_to_valid(work_array[n])  # находим необходимую дельту
        work_array[n] += delta  # увеличиваем участок

        for i in xrange(len(best_array)):
            if work_array[i] <= delta:
                continue

            work_array[i] -= delta

            target_function_result = target_function(path_span_list, supplier_duration_array, work_array)
            if target_function_result < best_target_function_result:
                best_target_function_result = target_function_result
                best_array = work_array[:]

            work_array[i] += delta

    return best_array


def get_yac_diff_array(path_span_list, supplier_duration_list):
    """
    Возвращает diff алгоритма YAC
    Корректирует нитку восстанавливая время между станциями до минимально возможного.
    Отклонение не делится.
    :param supplier_duration_list:
    :param path_span_list:
    :return: diff-array,  len(diff-array) == len(thread.path)
    """
    new_supplier_duration_list = yac(supplier_duration_list, path_span_list, objective_function)
    diff = [0, ] + map(operator.sub, new_supplier_duration_list, supplier_duration_list)
    return accumulate(diff)


class BaseTimeOutTerminate(object):
    class TimeOutError(SimpleUnicodeException):
        pass

    ALGORITHM_TIMEOUT = 10

    def _do_work(self, *args, **kwargs):
        raise NotImplementedError

    def _do_work_with_timeout(self, *args, **kwargs):
        signal.signal(signal.SIGALRM, self.handler)

        signal.alarm(self.ALGORITHM_TIMEOUT)
        try:
            return self._do_work(*args, **kwargs)
        except self.TimeOutError:
            pass
        finally:
            signal.alarm(0)

    @classmethod
    def handler(cls, signum, frame):
        raise cls.TimeOutError('correction time out')


class RecursiveSearchBestDurationArray(BaseTimeOutTerminate):
    """
    Корректирует нитку восстанавливая время между станциями до минимально возможного.
    Отклонение перемещается на соседние участки. В отличии от get_yac_diff_array отклонение можно делить
    Большая вычислительная сложность
    """

    def __init__(self, supplier_duration_list, path_span_list):
        self.supplier_duration_list = supplier_duration_list
        self.path_span_list = path_span_list
        self.item_numbers_set = set(range(len(path_span_list)))
        self.objective_function = partial(objective_function, path_span_list=path_span_list,
                                          supplier_duration_list=supplier_duration_list)
        self._best_duration_list = None
        self._best_objective_function_result = 0

    def get_best_duration_combination(self):
        self._best_duration_list = self.supplier_duration_list[:]
        total_delta = 0
        corrected_path_set = set()
        for n, path_span in enumerate(self.path_span_list):
            if not path_span.is_one_country_path:
                corrected_path_set.add(n)  # чтобы не трогать участки через границу
            elif path_span.is_duration_need_correct(self._best_duration_list[n]):
                delta = path_span.get_min_duration_to_valid(self._best_duration_list[n])  # находим необходимую дельту
                self._best_duration_list[n] += delta  # увеличиваем участок
                total_delta += delta

        self._best_objective_function_result = self.objective_function(
            new_supplier_duration_list=self._best_duration_list)

        self._do_work_with_timeout(self._best_duration_list, total_delta, corrected_path_set)

        return self._best_duration_list

    def _do_work(self, duration_list, duration_delta, corrected_numbers_set):
        if duration_delta == 0 or not self.item_numbers_set ^ corrected_numbers_set:
            objective_function_result = self.objective_function(new_supplier_duration_list=duration_list)
            if objective_function_result < self._best_objective_function_result:
                self._best_objective_function_result = objective_function_result
                self._best_duration_list = duration_list
            return

        for item_number in self.item_numbers_set ^ corrected_numbers_set:
            copy_corrected_numbers_set = corrected_numbers_set.copy()
            copy_corrected_numbers_set.add(item_number)

            duration_list_copy = duration_list[:]
            if duration_delta > 0:  # участок надо уменьшать
                # на сколько можно уменьшить участок, значение >= 0
                max_delta = duration_list[item_number] - self.path_span_list[item_number].min_drive_time
                if max_delta > duration_delta:
                    duration_list_copy[item_number] -= duration_delta
                    sub_delta = 0
                elif max_delta > 0:
                    duration_list_copy[item_number] -= max_delta
                    sub_delta = duration_delta - max_delta
                else:
                    continue
            elif duration_delta < 0:
                # на сколько можно увеличить участок, значение >= 0
                min_delta = self.path_span_list[item_number].max_drive_time - duration_list[item_number]

                if min_delta > abs(duration_delta):
                    duration_list_copy[item_number] -= duration_delta
                    sub_delta = 0
                elif min_delta > 0:
                    duration_list_copy[item_number] += min_delta
                    sub_delta = duration_delta + min_delta
                else:
                    continue

            self._do_work(duration_list_copy, sub_delta, copy_corrected_numbers_set)


def get_rsa_diff_array(path_span_list, supplier_duration_list):
    """
    rsa - Recursive Search Best Duration Array
    Корректирует нитку восстанавливая время между станциями до минимально возможного.
    Отклонение перемещается на соседние участки. В отличии от get_yaa_diff_array отклонение можно делить
    Большая вычислительная сложность
    :param supplier_duration_list:
    :param path_span_list:
    :return: diff-array,  len(diff-array) == len(thread.path)
    """

    rsa = RecursiveSearchBestDurationArray(supplier_duration_list, path_span_list)
    new_supplier_duration_list = rsa.get_best_duration_combination()
    diff = [0, ] + map(operator.sub, new_supplier_duration_list, supplier_duration_list)
    return accumulate(diff)


def data_presentation_wrapper(correction_function):
    """
    Обертка для методов коррекции. Используется для отображения т.к. во view лучше передавать нитку целиком
    :param correction_function: метод коррекции
    :return: wrap(thread)
    """

    def wrap(thread):
        path_span_list = PathSpanProxy.get_pathspan_list(thread.path)
        supplier_duration_list = list(ThreadDurationManipulation.get_supplier_duration_list(thread))
        return correction_function(path_span_list, supplier_duration_list)

    return wrap


class CorrectionFunctions(object):
    COMPARE = 'compare'
    YAC = 'yac'
    RSA = 'rsa'

    _correction_functions = {
        COMPARE: get_compare_diff_array,
        YAC: get_yac_diff_array,
        RSA: get_rsa_diff_array,
    }

    def __new__(cls, correction_type):
        if correction_type not in cls._correction_functions:
            raise NotImplementedError(u'Не существующий тип корректировки %s' % correction_type)
        return data_presentation_wrapper(cls._correction_functions[correction_type])

    @classmethod
    def get_all_correction_function_types(cls):
        return set(cls._correction_functions.keys())

    @classmethod
    def select_correction_algorithm_for_view(cls, thread):
        path = list(thread.path)
        supplier_duration_list = list(ThreadDurationManipulation.get_supplier_duration_list(thread))
        return cls.select_correction_algorithm(path, supplier_duration_list)

    @classmethod
    def select_correction_algorithm(cls, path, supplier_duration_list):
        path_span_list = PathSpanProxy.get_pathspan_list(path)

        path_len = len(path)
        path_distance = sum(path_span.distance for path_span in path_span_list)
        path_duration = sum(supplier_duration_list)

        if path_len > 2 and is_time_real(path_distance, path_duration):
            return cls.YAC
        else:
            return cls.COMPARE

    @classmethod
    def get_correction_function_for_import(cls, rtstations, supplier_duration_list):
        correction_algorithm = cls.select_correction_algorithm(rtstations, supplier_duration_list)
        return cls._correction_functions[correction_algorithm]
