import numpy as np

from jafar.estimators.base import BaseEstimator
from jafar.utils.structarrays import DataFrame


class LocalityEstimator(BaseEstimator):
    """Score items by looking them up in a per-region item score matrix.

    ``fit`` builds a sparse matrix with one row per region and one column per
    item and saves it through ``self.storage``.  ``predict`` and
    ``predict_top_n`` read that matrix back and index it with the region
    codes found in the ``region_column`` column of their input.

    NOTE(review): ``self.storage``, ``self.key_for``, ``self.n_items`` and
    ``self.construct_sparse_matrix`` are not defined in this class; they are
    presumably provided by ``BaseEstimator`` -- confirm there.
    """
    basket_required = False

    def __init__(self, region_column='lbs_region_city', threshold=0.5, **kwargs):
        """Configure the estimator.

        :param region_column: name of the input column holding region codes
            at prediction time.
        :param threshold: scores must be strictly greater than this value to
            be returned by ``predict_top_n``.
        :param kwargs: forwarded unchanged to ``BaseEstimator.__init__``.
        """
        super(LocalityEstimator, self).__init__(**kwargs)
        self.region_column = region_column
        self.threshold = threshold

    def fit(self, X):
        """Build the region-by-item score matrix from ``X`` and store it.

        :param X: frame with ``item``, ``region`` and ``score`` columns.
        :return: ``self``.
        """
        X.assert_has_columns(['item', 'region', 'score'])
        # Rows are regions (0 .. max region code), columns are items.
        shape = X['region'].max() + 1, self.n_items
        item_region_matrix = self.construct_sparse_matrix(X['region'], X['item'], X['score'], shape=shape)
        self.storage.store(self.key_for('item_region_matrix'), item_region_matrix)
        return self

    def predict(self, X, basket=None):
        """Return the stored score for every (region, item) pair in ``X``.

        Rows whose region code is negative or not covered by the stored
        matrix get a score of 0.

        :param X: frame with ``user``, ``item`` and ``self.region_column``
            columns.
        :param basket: unused by this estimator.
        :return: ``DataFrame`` with ``user``, ``item`` and ``value`` columns,
            one row per row of ``X``.
        """
        X.assert_has_columns(['user', 'item', self.region_column])
        item_region_matrix = self.storage.get_proxy(self.key_for('item_region_matrix'))
        regions = X[self.region_column]
        valid_regions = (regions < item_region_matrix.shape[0]) & (regions >= 0)
        scores = np.zeros(len(X), dtype=np.float32)
        # Skip the sparse lookup entirely when there is nothing to look up,
        # so an empty frame (or one with only unknown regions) yields zeros.
        if valid_regions.any():
            scores[valid_regions] = item_region_matrix[
                regions[valid_regions], X['item'][valid_regions]
            ].A.ravel()
        return DataFrame.from_dict({
            'user': X['user'],
            'item': X['item'],
            'value': scores,
        })

    def predict_top_n(self, X, n, basket=None):
        """Return up to ``n`` highest-scoring items of each row's region.

        Only stored scores strictly greater than ``self.threshold`` are
        considered.  Rows whose region code is negative or not covered by the
        stored matrix produce no recommendations, matching the validity rule
        used by ``predict``.  The items returned for one row are the top
        ones, but they are not sorted by score.

        :param X: frame with ``user`` and ``self.region_column`` columns.
        :param n: maximum number of items to return per row of ``X``.
        :param basket: unused by this estimator.
        :return: ``DataFrame`` with ``user``, ``item`` and ``value`` columns;
            each user is repeated once per recommended item.
        """
        X.assert_has_columns(['user', self.region_column])
        regions = X[self.region_column]
        item_region_matrix = self.storage.get_proxy(self.key_for('item_region_matrix'))
        n_regions = item_region_matrix.shape[0]

        no_items = np.empty(0, dtype=np.int32)
        no_scores = np.empty(0, dtype=np.float32)

        recommendations = []
        recommendation_lengths = []
        recommendation_scores = []
        for region in regions:
            items, values = no_items, no_scores
            # A negative code must not reach getrow(): depending on the
            # sparse implementation it either wraps around to another
            # region's row or raises IndexError.
            if 0 <= region < n_regions:
                row = item_region_matrix.getrow(region)
                indices_above_threshold = np.nonzero(row.data > self.threshold)[0]
                kth = min(len(indices_above_threshold), n)
                if kth > 0:
                    # argpartition places the kth largest values in the last
                    # kth slots without fully sorting them.
                    top_k_indices = np.argpartition(row.data[indices_above_threshold], -kth)[-kth:]
                    indices = indices_above_threshold[top_k_indices]
                    items, values = row.indices[indices], row.data[indices]
            recommendation_lengths.append(len(items))
            recommendations.append(items)
            recommendation_scores.append(values)

        if recommendations:
            all_items = np.hstack(recommendations)
            all_scores = np.hstack(recommendation_scores)
        else:
            # np.hstack raises ValueError on an empty sequence (empty X).
            all_items, all_scores = no_items, no_scores
        # Pass the counts as an integer array: an empty Python list would be
        # converted to a float array, which cannot be used as repeat counts.
        repeats = np.asarray(recommendation_lengths, dtype=np.intp)
        return DataFrame.from_dict({
            'user': X['user'].repeat(repeats),
            'item': all_items,
            'value': all_scores
        })
