import json

import numpy as np
import os.path
from flask import current_app as app

import jafar
from jafar.estimators.base import BaseEstimator, GroupFeatureAwareEstimator
from jafar.pipelines.blocks.mapping import MISSING_ID
from jafar.storages import make_key
from jafar.utils import bincount_relative
from jafar.utils.structarrays import DataFrame


def sort_items_by_popularity(items, scores):
    """
    Reorder ``items`` and their aligned ``scores`` by descending score.

    :param items: numpy array of item ids, aligned element-wise with ``scores``
    :param scores: numpy array of popularity scores
    :return: tuple ``(items, scores)`` with both arrays permuted so that the
             highest score comes first
    """
    # argsort is ascending; reversing the permutation gives descending order
    descending_order = np.argsort(scores)[::-1]
    return tuple(column[descending_order] for column in (items, scores))


class Popular(BaseEstimator):
    """
    Non-personalized `most popular` recommender.

    Scores are computed once in `fit` from positive interactions
    (rows with ``value == 1``) and are the same for every user.
    """
    basket_required = False

    def _get_scores(self, X):
        """Return a per-item score vector of length ``n_items``."""
        # only rows with value == 1 count as positive interactions
        positive_mask = X['value'] == 1
        positive_items = np.array(X[positive_mask]['item'], dtype=np.int32)
        # NOTE: bincount is supposed to be the fastest way to estimate counts:
        # https://stackoverflow.com/a/42862472/2405210
        return bincount_relative(positive_items, length=self.n_items)

    def fit(self, X):
        """
        Compute and store item popularity.

        Stores the per-item scores under 'popularity' and, for cheap top-n
        lookups, the item ids and scores sorted by descending popularity under
        'items' and 'scores'.
        """
        X.assert_has_columns(['item', 'value'])

        scores = self._get_scores(X)
        self.storage.store(self.key_for('popularity'), scores)

        # precompute the full ranking so predict_top_n is just a slice
        ranked_items, ranked_scores = sort_items_by_popularity(np.arange(0, self.n_items), scores)

        # NOTE: use better keys than 'items' and 'scores'
        for suffix, payload in (('items', ranked_items), ('scores', ranked_scores)):
            self.storage.store(self.key_for(suffix), payload)

        return self

    def predict(self, X, basket=None):
        """Score each requested (user, item) pair with the item's global popularity."""
        X.assert_has_columns(('user', 'item'))
        scores = self.storage.get_proxy(self.key_for('popularity'))
        item_ids = X['item'].astype(np.int32)

        return DataFrame.from_dict({
            'user': X['user'],
            'item': X['item'],
            'value': scores[item_ids].astype(np.float32),
        })

    def predict_top_n(self, X, n, basket=None):
        """Return the same ``n`` most popular items for every user in ``X``."""
        X.assert_has_columns(('user',))

        ranked_items = self.storage.get_proxy(self.key_for('items'))[:n]
        ranked_scores = self.storage.get_proxy(self.key_for('scores'))[:n]
        users = X['user']
        n_users, per_user = len(users), len(ranked_items)

        # each user id is repeated per_user times; the ranking is tiled once per user
        return DataFrame.from_dict({
            'user': np.repeat(users, per_user).astype(np.int32),
            'item': np.tile(ranked_items, n_users).astype(np.int32),
            'value': np.tile(ranked_scores, n_users).astype(np.float32),
        })


class ConstantPopular(Popular):
    """
    `Most popular` recommender whose scores come from a fixed ranking file
    instead of from the interactions passed to `fit`.

    The ranking is read from ``<BASE_DIR>/fixtures/json/ranking.json`` and is
    used as a mapping from package name to popularity value.
    """
    def _get_scores(self, X):
        """
        Build a popularity vector of length ``n_items`` from the ranking fixture.

        ``X`` is not used: the scores depend only on the fixture file. Package
        names are resolved to ``package_name/class_name`` strings through
        ``jafar.top_class_name_loader`` (country 'RU') and then mapped to item
        ids through the stored 'item_map'. Names without an item id are skipped.

        :raises AssertionError: if every resulting popularity is zero.
        """
        result = np.zeros(self.n_items)
        path = os.path.join(app.config['BASE_DIR'], 'fixtures/json', 'ranking.json')
        with open(path) as read_file:
            ranking = json.load(read_file)

        top_map = jafar.top_class_name_loader.load(ranking.keys(), country='RU')
        classnames, popularities = [], []
        for classname_dict in top_map.itervalues():
            # only the first entry returned for a package is used
            first_entry = classname_dict[0]
            classnames.append('{package_name}/{class_name}'.format(**first_entry))
            popularities.append(ranking[first_entry['package_name']])

        item_map_key = make_key(self.key_prefix, 'item_map')
        # the third argument (None) appears to be what map_values yields for
        # unknown names, hence the `is not None` check below
        idx = self.storage.map_values(item_map_key, classnames, None)
        for i, popularity in zip(idx, popularities):
            if i is not None:
                result[i] = popularity

        # explicit raise instead of `assert` so the check survives `python -O`;
        # `any()` matches the message exactly (a sum can be 0 with mixed signs)
        # and is vectorized, unlike builtin sum() over a numpy array
        if not result.any():
            raise AssertionError("All popularities are zero")
        return result


class GroupFeatureAwarePopular(GroupFeatureAwareEstimator):
    """
    Popularity recommender that estimates item popularity separately inside
    each group of items sharing the same group-feature value(s) (for example,
    the same category).

    `fit` stores, under this estimator's keys:
      * 'overall_top_items' / 'overall_top_scores' - global ranking produced by
        a nested `Popular` estimator;
      * 'popularity' - per-item score computed within the item's own group;
      * 'top_items' / 'top_scores' - matrices of shape
        (n_group_features, top_n_items) holding the best items of each group
        and their scores, padded with MISSING_ID / 0 when a group is smaller.
    """
    # presumably consumed by get_group_features_list(); subclasses such as
    # CategoryAwarePopular override it - confirm in GroupFeatureAwareEstimator
    group_feature_name = None
    basket_required = False

    def __init__(self, n_users, n_items, normalize=True, storage=None, key_prefix=None, top_n_items=200):
        """
        :param normalize: if True, each group's item counts are divided by the
                          total number of positive interactions in that group
        :param top_n_items: number of top items (and scores) kept per group
        The remaining parameters are forwarded to GroupFeatureAwareEstimator.
        """
        super(GroupFeatureAwarePopular, self).__init__(n_users, n_items, storage, key_prefix)
        self.normalize = normalize
        self.top_n_items = top_n_items

    def fit(self, X):
        """
        Estimate overall and per-group popularity from interactions ``X``.

        ``X`` must contain 'item', 'value' and every group-feature column.
        Rows with a null group feature are dropped before anything else, so
        the overall ranking is also computed only from the remaining rows.
        """
        group_feature_names = self.get_group_features_list()

        X.assert_has_columns(['item', 'value'] + group_feature_names)
        X = X.drop_null(columns=group_feature_names)

        # keep overall popularity scores as well
        # NOTE(review): Popular is built without storage/key_prefix, so it uses
        # whatever its constructor defaults to - confirm that is a private store
        overall_estimator = Popular(
            n_users=self.n_users,
            n_items=self.n_items
        ).fit(X)

        # copy the nested estimator's ranking into this estimator's keys
        for key in ('items', 'scores'):
            value = overall_estimator.storage.get_object(
                overall_estimator.key_for(key)
            )
            self.storage.store(self.key_for('overall_top_{}'.format(key)), value)

        # leaving only positive interactions
        idx = X['value'] == 1

        # leave only relevant columns
        X = X[['item'] + group_feature_names][idx]

        group_feature_encoder = self.create_group_feature_encoder(group_feature_names, X)
        # one row of top_items / top_scores per encoded group value
        n_group_features = len(group_feature_encoder.classes_)
        popularity_scores = np.zeros(self.n_items, dtype=np.float32)
        top_items = np.full(shape=(n_group_features, self.top_n_items), fill_value=MISSING_ID, dtype=np.int32)
        top_scores = np.zeros(shape=(n_group_features, self.top_n_items), dtype=np.float32)

        for group_feature_value, frame in X.groupby(group_feature_names):
            # encode this group's value tuple into its row index
            group_feature_idx = group_feature_encoder.transform(
                np.array([tuple(group_feature_value)], dtype=group_feature_value.dtype))[0]
            # cannot use bincount here because items are not ordered integers
            group_feature_items, group_feature_scores = np.unique(frame['item'], return_counts=True)
            group_feature_scores = group_feature_scores.astype(np.float32)
            if self.normalize:
                group_feature_scores /= group_feature_scores.sum()
            group_feature_items, group_feature_scores = sort_items_by_popularity(
                group_feature_items, group_feature_scores
            )
            # an item that appears in several groups keeps the score of the
            # last group processed; items in no group stay at 0
            popularity_scores[group_feature_items] = group_feature_scores
            top_n = min(self.top_n_items, len(group_feature_items))
            top_items[group_feature_idx, :top_n] = group_feature_items[:top_n]
            top_scores[group_feature_idx, :top_n] = group_feature_scores[:top_n]

        self.storage.store(self.key_for('popularity'), popularity_scores)
        self.storage.store(self.key_for('top_items'), top_items)
        self.storage.store(self.key_for('top_scores'), top_scores)

        return self

    def predict(self, X, basket=None):
        """
        Score each requested (user, item) pair with the item's within-group
        popularity stored by `fit`; ``basket`` is not used.
        """
        X.assert_has_columns(('user', 'item'))
        popularity_scores = self.storage.get_proxy(self.key_for('popularity'))

        return DataFrame.from_dict(dict(
            user=X['user'],
            item=X['item'],
            value=popularity_scores[X['item'].astype(np.int32)].astype(np.float32)
        ))

    def predict_top_n_per_group_feature(self, X, n, basket=None, group_feature_values=None):
        """
        :param X: frame containing users to recommend to
        :param n: number of recommendations to return (per user, per category;
                  per user overall when group_feature_values is None)
        :param basket: accepted for interface compatibility; not used here
        :param group_feature_values: list of group feature values (e.g. categories)
                                     to restrict recommendations to; when None the
                                     overall ranking stored by `fit` is used instead
        :return: frame with 'user', 'item', 'value' columns. For every user the
                 rows follow the selected groups in order; padding entries equal
                 to MISSING_ID are removed.
        """
        group_feature_names = self.get_group_features_list()

        X.assert_has_columns(('user',))

        # NOTE: special case here: overall popularity cannot be obtained
        # from group popularities but is stored separately
        if group_feature_values is None:
            # add axis for compatibility
            top_n_items = np.atleast_2d(self.storage.get_proxy(self.key_for('overall_top_items'))[:n])
            top_n_scores = np.atleast_2d(self.storage.get_proxy(self.key_for('overall_top_scores'))[:n])
            n_group_features = 1
        else:
            group_feature_encoder = self.load_group_feature_encoder(group_feature_names)
            idx = self.get_group_feature_idx(group_feature_encoder, group_feature_values)
            n_group_features = len(idx)

            top_n_items = self.storage.get_proxy(self.key_for('top_items'))[idx, :n]
            top_n_scores = self.storage.get_proxy(self.key_for('top_scores'))[idx, :n]

        # each user id is repeated once per flattened (group, rank) cell; the
        # flattened group rows are tiled once per user to line up with it
        result = DataFrame.from_dict(dict(
            user=X['user'].repeat(top_n_items.shape[1] * n_group_features).astype(np.int32),
            item=np.tile(top_n_items.ravel(), len(X['user'])).astype(np.int32),
            value=np.tile(top_n_scores.ravel(), len(X['user'])).astype(np.float32),
        ))
        # drop the MISSING_ID padding of groups with fewer than n items
        return result[result['item'] != MISSING_ID]


class CategoryAwarePopular(GroupFeatureAwarePopular):
    """GroupFeatureAwarePopular configured with the 'category' group feature."""
    group_feature_name = 'category'


class SummarisingPopular(Popular):
    """
    Popularity variant that scores each item by the sum of its interaction
    values, rather than by counting only rows with ``value == 1``.
    """
    def _get_scores(self, X):
        """Return a length-``n_items`` vector of per-item summed 'value'."""
        values = X['value']
        totals = np.zeros(self.n_items)
        for item_id, row_idx in X.arggroupby('item'):
            totals[item_id] = np.sum(values[row_idx])
        return totals
