import logging

import numpy as np
import scipy
import scipy.stats
from ortools.constraint_solver import pywrapcp
from ortools.constraint_solver import routing_enums_pb2
from sklearn.metrics.pairwise import cosine_similarity

from jafar.pipelines import ids
from jafar.pipelines.blocks import SingleContextBlock
from jafar.pipelines.misc import extract_feature_names

logger = logging.getLogger(__name__)
# Limit on the number of solutions generated by the OR-tools routing search
# (assigned to ``solution_limit`` of the routing search parameters).
SOLUTION_LIMIT = 50


class PermutationBlock(SingleContextBlock):
    """Re-rank items inside configured rank windows so that similar items sit together.

    At inference time (``train=False``) every user's ``'value'`` column is turned
    into ordinal ranks (0 for the lowest value, ``m - 1`` for the highest).
    Then, for each window described by the rows of ``config_frame``, the items
    falling into that window get their ranks reassigned along a short path
    through their feature vectors, using cosine distance.  Neighbouring
    positions therefore hold similar items.  The resulting ranks are written
    to the predictions frame under ``predictions_column``.

    Each config row is a dict with ``'left'`` and ``'right'`` keys.  Together
    they describe a half-open window ``[left, right)`` of positions counted
    from the top of the user's ranking, where position 0 is the highest
    ``'value'``.

    NOTE(review): ranks are always computed from the ``'value'`` column, even
    when ``predictions_column`` names a different output column. Confirm this
    is intended.
    """

    def __init__(self, features, config_frame=ids.FRAME_KEY_RANGE,
                 predictions_column='value', input_data=ids.FRAME_KEY_PREDICTIONS):
        """
        :param features: selector passed (wrapped in a list) to ``extract_feature_names``
            to choose the columns used as item vectors.
        :param config_frame: key in ``context.data`` of the frame holding the
            ``left``/``right`` rank windows.
        :param predictions_column: column of the predictions frame that receives
            the rearranged ranks (replaced if it already exists).
        :param input_data: key in ``context.data`` of the predictions frame.
        """
        self.config_frame = config_frame
        self.features = features
        self.predictions_column = predictions_column
        super(PermutationBlock, self).__init__(
            input_data=[input_data], output_data=None, destroyed_data=None
        )

    @staticmethod
    def get_route(manager, routing, assignment):
        """Walk the route of vehicle 0 in a solved routing model.

        :param manager: the ``RoutingIndexManager`` the model was built with.
        :param routing: the solved ``RoutingModel``.
        :param assignment: the solution returned by the solver.
        :returns: ``(route_distances, route)``.  ``route`` lists the visited nodes
            starting from the route start.  ``route_distances[i]`` is the arc cost
            from ``route[i]`` to the next stop.  The last entry is the arc into
            the route end, which is the start depot again when start and end
            coincide (as in ``find_best_permutation``).
        """
        index = routing.Start(0)
        route_distances = []
        route = []
        while not routing.IsEnd(index):
            route.append(manager.IndexToNode(index))
            previous_index = index
            index = assignment.Value(routing.NextVar(index))
            route_distances.append(routing.GetArcCostForVehicle(previous_index, index, 0))
        return route_distances, route

    @staticmethod
    def normalize_matrix(distance_matrix):
        """Convert a float distance matrix into integer costs for the OR-tools callback.

        The main diagonal is zeroed and the values are rounded to the nearest
        integer.  The input array is left unmodified.

        :param distance_matrix: 2-D array-like of distances.
        :returns: nested list of Python ints.
        """
        # np.array copies, so the caller's matrix is not mutated.
        matrix = np.array(distance_matrix, dtype=float)
        n = matrix.shape[0]
        matrix[np.arange(n), np.arange(n)] = 0
        # Builtin ``int``: the ``np.int`` alias was removed in NumPy 1.24.
        return np.round(matrix).astype(int).tolist()

    @staticmethod
    def get_search_parameters():
        """Return routing search parameters.

        The search uses guided local search, capped at ``SOLUTION_LIMIT``
        solutions.
        """
        search_parameters = pywrapcp.DefaultRoutingSearchParameters()
        search_parameters.local_search_metaheuristic = (
            routing_enums_pb2.LocalSearchMetaheuristic.GUIDED_LOCAL_SEARCH)
        search_parameters.solution_limit = SOLUTION_LIMIT
        return search_parameters

    def find_best_permutation(self, X):
        """Order the rows of ``X`` so that consecutive rows are similar.

        The method works in three steps:

        1. Build a complete graph.  Edge weights are the cosine distances between
           rows, scaled by 100 and rounded to integers.
        2. Ask OR-tools for a short closed tour through all rows, starting at row 0.
        3. Cut the tour at its heaviest edge to obtain an open (Hamiltonian) path.

        :param X: 2-D array, one row per item (e.g. ALS embeddings).
        :returns: integer array holding a permutation of ``range(X.shape[0])``.
            It is the identity permutation when there are fewer than two rows
            or when the solver finds no assignment.
        """
        n = X.shape[0]
        if n < 2:
            # Nothing to reorder.  This also avoids building a routing model
            # whose depot (node 0) would not exist for an empty input.
            return np.arange(n)
        # Invert the closeness matrix so the shortest tour links similar rows.
        distance_matrix = self.normalize_matrix((1 - cosine_similarity(X)) * 100)
        manager = pywrapcp.RoutingIndexManager(n, 1, 0)
        routing = pywrapcp.RoutingModel(manager)

        def distance_callback(from_index, to_index):
            # The solver passes internal indices; map them back to matrix rows.
            from_node = manager.IndexToNode(from_index)
            to_node = manager.IndexToNode(to_index)
            return distance_matrix[from_node][to_node]

        transit_callback_index = routing.RegisterTransitCallback(distance_callback)
        routing.SetArcCostEvaluatorOfAllVehicles(transit_callback_index)

        assignment = routing.SolveWithParameters(self.get_search_parameters())
        if assignment is None:
            logger.debug('No assignment found')
            return np.arange(n)  # leave order the same
        values, route = self.get_route(manager, routing, assignment)
        # values[k] is the arc route[k] -> route[k + 1] (wrapping around).
        # Rotating by -(k + 1) puts the node after the heaviest arc first,
        # which splits the tour at its weakest link.
        return np.roll(route, -np.argmax(values) - 1)

    @staticmethod
    def _valid_pairs(range_pairs):
        """Return the config windows that can be rearranged, logging the rejected ones.

        A window is kept when both bounds are non-negative and it spans at
        least two positions (``right - left > 1``).
        """
        valid = []
        for pair in range_pairs.to_list_of_dicts():
            if pair['right'] - pair['left'] <= 1 or pair['right'] < 0 or pair['left'] < 0:
                logger.debug('Pair set incorrectly %d : %d', pair['left'], pair['right'])
                continue
            valid.append(pair)
        return valid

    def apply(self, context, train):
        """Rearrange ranks inside the configured windows for every user.

        The context is returned unchanged when training, when ``config_frame``
        is missing from ``context.data``, or when no feature columns are found.
        Otherwise the predictions frame is replaced by a copy carrying the new
        ``predictions_column``.

        :param context: pipeline context; its ``data`` mapping holds the frames.
        :param train: when true, the block does nothing.
        :returns: ``context``.
        """
        if train or self.config_frame not in context.data:
            logger.debug('Rearrangement was not performed')
            return context
        range_pairs = context.data[self.config_frame]
        predictions = context.data[self.input_data[0]]
        extracted_features = extract_feature_names(predictions, [self.features])
        if not len(extracted_features):
            logger.debug('Rearrangement was not performed')
            return context
        # Windows do not depend on the user.  Convert and validate them once,
        # not once per user, so invalid windows are also reported only once.
        pairs = self._valid_pairs(range_pairs)
        score = np.zeros(predictions.shape[0])
        for _, user_index in predictions.arggroupby('user'):  # user_index stands for users list of applications
            user_predictions = predictions[user_index]
            # Ordinal ranks are a permutation of 0..m-1; the highest 'value' gets m-1.
            ranks = scipy.stats.rankdata(user_predictions['value'], method='ordinal') - 1
            score[user_index] = ranks
            features = user_predictions[extracted_features].to_2d_array()
            top_rank = len(ranks) - 1
            for pair in pairs:
                # Windows count positions from the top (0 = best item); translate
                # [left, right) onto the ascending rank scale.
                left, right = top_rank - pair['right'] + 1, top_rank - pair['left'] + 1
                # range_index marks the user's items whose rank falls inside the window.
                range_index = (ranks >= left) & (ranks < right)
                if not np.any(range_index):
                    logger.debug('Pair is out of range %d : %d', pair['left'], pair['right'])
                    continue
                permutation_index = self.find_best_permutation(features[range_index])
                # Hand the window's ranks out in path order: the path start gets
                # the lowest rank of the window, the path end the highest.
                score[user_index[range_index][permutation_index]] = np.arange(max(left, 0), right)
        context.data[self.input_data[0]] = predictions.append_column(score, self.predictions_column, replace=True)
        return context
