import logging

import numpy as np
from scipy import sparse
from sklearn.neighbors import NearestNeighbors

from jafar.estimators.base import BaseEstimator
from jafar.similarity_metrics import cosine
from jafar.storages.exceptions import StorageKeyError

# module-level logger for this file
logger = logging.getLogger(__name__)
# only let warnings and above through from the 'yt.packages.requests' logger
logging.getLogger('yt.packages.requests').setLevel(logging.WARNING)


class BaseKNN(BaseEstimator):
    """
    Item-to-item k-nearest-neighbours estimator.

    ``fit`` computes an item-item similarity matrix, keeps the top ``k``
    neighbours of every item and writes their indices and distances to the
    storage. ``get_neighbors`` reads them back and expands each input row
    into one row per stored neighbour.

    NOTE(review): ``storage``, ``key_for`` and ``construct_sparse_matrix`` are
    not defined in this class; they are presumably provided by
    ``BaseEstimator`` -- confirm against the base class.
    """

    def __init__(self, n_users, n_items, storage=None, key_prefix=None, k=30,
                 similarity=cosine, similarity_params=None):
        """
        :param n_users: forwarded to the base estimator
        :param n_items: forwarded to the base estimator; also caps ``k``
        :param storage: forwarded to the base estimator
        :param key_prefix: forwarded to the base estimator
        :param k: requested number of neighbours per item
        :param similarity: callable invoked in ``fit`` as
            ``similarity(M.T, **similarity_params)``
        :param similarity_params: optional dict of keyword arguments for ``similarity``
        """
        super(BaseKNN, self).__init__(n_users, n_items, storage, key_prefix)
        # an item cannot have more neighbours than there are other items
        self.k = int(min(k, n_items - 1))
        self.similarity = similarity
        self.similarity_params = similarity_params

    @staticmethod
    def knn_sparsify_metric(metric, k, block_size=1000):
        """
        Keep the top ``k + 1`` entries of every row of a square similarity
        matrix (``k`` neighbours plus, normally, the item itself), which makes
        the matrix much more sparse.

        :param metric: square sparse matrix that supports row slicing and ``toarray()``
        :param k: number of neighbours to keep per item; capped at ``n - 1``
        :param block_size: number of rows converted to a dense array at a time
        :return: tuple ``(metric, knn_indices, knn_distances)``: the sparsified
            matrix as CSR, an ``(n, k)`` integer array of neighbour indices
            (the item itself excluded) and an ``(n, k)`` array of
            ``1 - similarity`` for those neighbours
        :raises ValueError: if ``block_size`` is smaller than 1
        """

        # items count
        assert metric.shape[0] == metric.shape[1], 'metric must be a square matrix'
        if block_size < 1:
            raise ValueError('block_size must be a positive integer')
        n = metric.shape[0]
        k = min(k, n - 1)

        # will fill it item by item
        knn_indices = np.zeros((n, k), dtype=int)
        knn_distances = np.ones((n, k))

        # k neighbors + current item
        top = k + 1
        row_idx = np.repeat(np.arange(n), top)
        col_idx = np.zeros_like(row_idx)
        value = np.zeros_like(row_idx, dtype=np.float32)

        # we cannot convert the whole metric to dense since it will not fit in memory
        # we cannot work with it as a sparse matrix since it is slow
        # so let's convert it into dense arrays block by block
        for block_start in range(0, n, block_size):
            block = metric[block_start:min(block_start + block_size, n), :].toarray()

            for offset, metric_row in enumerate(block):
                item_idx = block_start + offset

                # find the `top` most similar items; after argpartition the first
                # element of the slice is the weakest of them, the rest are unordered
                top_indices = np.argpartition(metric_row, -top)[-top:]
                out = slice(item_idx * top, (item_idx + 1) * top)
                col_idx[out] = top_indices
                value[out] = metric_row[top_indices]

                # update indices/distances for 'predict with basket'
                neighbors = top_indices[top_indices != item_idx]
                # the item itself may be missing from the top (e.g. all metrics
                # are 0.0); then there is one candidate too many and the weakest
                # one (the first) is dropped
                neighbors = neighbors[len(neighbors) - k:]
                knn_indices[item_idx, :] = neighbors
                knn_distances[item_idx, :] = 1.0 - metric_row[neighbors]

        # make a matrix
        metric = sparse.csr_matrix((value, (row_idx, col_idx)), shape=(n, n))
        return metric, knn_indices, knn_distances

    @staticmethod
    def knn_sparsify_metric_sklearn(metric, k):
        """
        sklearn version of knn_sparsify_metric.

        Converts the whole matrix to a dense array, so it is only usable when
        ``n * n`` values fit in memory.

        :param metric: square sparse similarity matrix supporting ``toarray()``
        :param k: number of neighbours to keep per item; capped at ``n - 1``
        :return: tuple ``(metric, knn_indices, knn_distances)``: the sparsified
            CSR matrix (top-k entries per row plus a unit diagonal) and the
            index/distance arrays returned by ``NearestNeighbors.kneighbors()``
        """

        # items count
        assert metric.shape[0] == metric.shape[1], 'metric must be a square matrix'
        n = metric.shape[0]
        # same cap as knn_sparsify_metric: kneighbors() excludes the item itself,
        # so at most n - 1 neighbours exist
        k = min(k, n - 1)

        # similarity -> distance, used as a precomputed distance matrix
        neighbors = NearestNeighbors(n_neighbors=k, algorithm='auto', metric='precomputed').fit(1 - metric.toarray())
        knn_distances, knn_indices = neighbors.kneighbors()

        # leave only top-k neighbors from similarity metric, set others to zero
        row_idx = np.repeat(np.arange(n), k)
        col_idx = knn_indices.ravel(order='C')
        value = (1 - knn_distances).ravel(order='C')

        # add diagonal entries
        diagonal = np.arange(n)
        row_idx = np.hstack([row_idx, diagonal])
        col_idx = np.hstack([col_idx, diagonal])
        value = np.hstack([value, np.ones(n)])

        # make a matrix
        metric = sparse.csr_matrix((value, (row_idx, col_idx)), shape=(n, n))
        return metric, knn_indices, knn_distances

    def fit(self, X):
        """
        Compute item-item similarities and store the k nearest neighbours of every item.

        Writes three storage entries: ``knn_indices``, ``knn_distances`` and ``k``.

        :param X: frame with 'user', 'item' and 'value' columns
        :return: self
        """
        X.assert_has_columns(('user', 'item', 'value'))

        # parameter checks
        assert (isinstance(self.k, int)), 'k (neighbor count) must be integer'

        # convert to pivot sparse matrix
        M = self.construct_sparse_matrix(X['user'], X['item'], X['value']).astype('float32')
        similarity_params = self.similarity_params or {}
        # transposed so that the similarity is computed between the columns of M
        # (presumably items -- depends on construct_sparse_matrix, confirm)
        metric = self.similarity(M.T, **similarity_params)

        # keep top-k neighbor items for each item; the sparsified matrix itself is not stored
        _, knn_indices, knn_distances = self.knn_sparsify_metric(metric, self.k)

        self.storage.store(self.key_for('knn_indices'), knn_indices)
        self.storage.store(self.key_for('knn_distances'), knn_distances)
        # stored as a one-element list; get_neighbors reads it back with [0]
        self.storage.store(self.key_for('k'), [self.k])

        return self

    def get_neighbors(self, X):
        """
        Expand every row of ``X`` into one row per stored neighbour of its item.

        The original item is kept in a 'similar_to' column, 'item' is replaced
        by the neighbour and 'value' is set to ``1 - stored distance``.

        :param X: frame with an 'item' column
        :return: rows whose 'value' is non-zero, in reversed order after
            ``sort('value')``
        """
        items = np.int32(X['item'])
        neighbors = self.storage.get_matrix_rows(self.key_for('knn_indices'), items).ravel(order='C')
        distances = self.storage.get_matrix_rows(self.key_for('knn_distances'), items).ravel(order='C')

        # prefer the k the stored matrices were fitted with; fall back to the configured one
        try:
            k = self.storage.get_object(self.key_for('k'))[0]
        except StorageKeyError:
            k = self.k

        result = X.repeat(k)
        # keep original items as reference/explanation for recommendation
        result = result.append_column(result['item'], 'similar_to')
        result['item'] = neighbors
        result = result.append_column(1 - distances, 'value')
        # NOTE(review): the return value of sort() is ignored, so it presumably
        # sorts in place -- confirm against the frame implementation
        result.sort('value')
        return result[result['value'] != 0][::-1]
