from collections import Counter

import numpy as np

from jafar.exceptions import FeedException
from jafar.feed.base import FeedBlock, FeedPage
from jafar.feed.recommenders import TopNRecommender
from jafar.feed.strategy import (
    OneHotPositionConstraint,
    RecommendationGroupCapacityConstraint,
    RecommendationGroupAdjacencyConstraint,
    CardAdjacencyConstraint,
    BalancedRecommenderConstraint
)
from jafar.feed.strategy import Problem
from jafar.models.cards import Card
from jafar.tests import JafarTestCase
from jafar.tests.fixtures.mongo_configs import FAKE_EXPIRE_AT
from jafar.utils.structarrays import DataFrame


class StrategyTestCase(JafarTestCase):
    """Base class for strategy tests; supplies synthetic recommendation frames."""

    def get_test_recommendations(self, prefix, size=100):
        """
        Build a DataFrame of `size` fake recommendations.

        Items are named "<prefix>_item_<index>" and every item gets value 1.0.
        """
        item_names = np.array([
            "{}_item_{}".format(prefix, index) for index in xrange(size)
        ])
        item_values = np.ones(size)
        return DataFrame.from_dict(dict(item=item_names, value=item_values))


class IntegerProblemTestCase(StrategyTestCase):
    """
    Tests for solving a feed `Problem` under different constraint sets.

    In every test the flat solution vector is reshaped to `problem.dimensions`
    and indexed as [block, recommendation group, card type].
    """
    title_pairs = [('title', 'subtitle')]

    def get_card_type(self, number):
        """Return the synthetic card type name for card number `number`."""
        return 'type{}'.format(number)

    def get_content_type(self, number):
        """Return the synthetic content type name for card number `number`."""
        return 'content{}'.format(number)

    def solve_shortcut(self, n_blocks, group_sizes, card_sizes, constraints=None,
                       supported_types_sets=None, recommenders=None, previous_page=None,
                       card_weights=None, group_weights=None):
        """
        Build a `Problem`, apply the given constraints and solve it.

        :param n_blocks: number of blocks, forwarded to `Problem`
        :param group_sizes: one recommendation group is created per entry,
            holding that many test recommendations
        :param card_sizes: one `Card` is created per entry, with
            `min_count=size` and `reserved_count=size - 1`
        :param constraints: objects whose `apply(problem)` is called before solving
        :param supported_types_sets: supported content types per card;
            defaults to ('apps', 'promo') for every card
        :param recommenders: one recommender per group; defaults to a single
            shared `TopNRecommender` reused for every group
        :param previous_page: forwarded to `Problem` as is
        :param card_weights: defaults to ones, one per card
        :param group_weights: defaults to ones, one per group; used both as the
            per-group `weight` and as the `Problem` group weights
        :return: tuple `(solution, problem)`

        A `FeedException` raised by `problem.solve()` is turned into a test failure.
        """
        constraints = constraints or []
        if not recommenders:
            default_recommender = TopNRecommender(pipeline='kano', title_pairs=self.title_pairs)
            recommenders = [default_recommender] * len(group_sizes)
        if group_weights is None:
            group_weights = np.ones(len(group_sizes))

        recommendation_groups = [
            recommender.recommendation_group_class(
                explanation='group{}'.format(i), recommender=recommender,
                recommendations=self.get_test_recommendations('group{}'.format(i), size=capacity),
                weight=group_weights[i]
            ) for i, (capacity, recommender) in enumerate(zip(group_sizes, recommenders))
        ]

        if not supported_types_sets:
            supported_types_sets = [('apps', 'promo')] * len(card_sizes)
        card_types = [
            Card(
                min_count=card_size, reserved_count=card_size - 1,
                card_type=self.get_card_type(i), supported_content_types=types
            ) for i, (card_size, types) in enumerate(zip(card_sizes, supported_types_sets))
        ]
        if card_weights is None:
            card_weights = np.ones(len(card_types))

        randomization = 1.0
        problem = Problem(
            n_blocks, card_types, recommendation_groups, card_weights,
            group_weights, randomization, previous_page
        )
        for constraint in constraints:
            constraint.apply(problem)
        try:
            solution = problem.solve()
        except FeedException as e:
            self.fail("Couldn't solve a problem: {}".format(e))
        return solution, problem

    def test_unconstrained_problem(self):
        """Without constraints the solver still returns a flat array of the right size."""
        solution, problem = self.solve_shortcut(
            n_blocks=5, group_sizes=[100, 50, 200], card_sizes=[5, 10, 2]
        )
        # just verify that solution is a valid-size numpy array
        self.assertIsInstance(solution, np.ndarray)
        # np.prod is the canonical name; the np.product alias was removed in NumPy 2.0
        self.assertEqual(len(solution), np.prod(problem.dimensions))

    def test_one_hot_constraint(self):
        """With the one-hot constraint every block slice sums to exactly one."""
        solution, problem = self.solve_shortcut(
            n_blocks=5, group_sizes=[100, 50, 200],
            card_sizes=[5, 10, 2],
            constraints=[OneHotPositionConstraint()]
        )
        # verify one-hot condition
        solution = solution.reshape(problem.dimensions)
        for solution_slice in solution:
            self.assertEqual(solution_slice.sum(), 1)

    def test_capacity_constraint(self):
        """
        Capacity constraint makes sense only with onehot-constraint:
        otherwise simple "set everything to zero" solution will be possible.
        """
        group_sizes = [2, 2, 2]
        card_sizes = [1]
        constraints = [OneHotPositionConstraint(), RecommendationGroupCapacityConstraint()]
        solution, problem = self.solve_shortcut(
            n_blocks=6, group_sizes=group_sizes,
            card_sizes=card_sizes, constraints=constraints
        )
        solution = solution.reshape(problem.dimensions)
        for j in xrange(problem.n_recommendation_groups):
            self.assertLessEqual(np.sum(solution[:, j, :] * problem.card_sizes), group_sizes[j])

        # asking for more than 6 blocks should make the problem infeasible;
        # solve_shortcut reports a FeedException through self.fail(), which
        # raises AssertionError (unittest's default failureException)
        with self.assertRaises(AssertionError):
            self.solve_shortcut(
                n_blocks=7, group_sizes=group_sizes,
                card_sizes=card_sizes, constraints=constraints
            )

    def _verify_recommendation_group_adjacency_constraint(self, solution, problem):
        """Assert each block selects one group and neighbouring blocks use different groups."""
        solution = solution.reshape(problem.dimensions)
        for i in xrange(problem.n_blocks - 1):
            group_idx, _ = solution[i].nonzero()
            next_group_idx, _ = solution[i + 1].nonzero()
            self.assertEqual(len(group_idx), 1)
            self.assertEqual(len(next_group_idx), 1)
            self.assertNotEqual(group_idx[0], next_group_idx[0])

    def _verify_card_adjacency_constraint(self, solution, problem):
        """Assert each block selects one card and neighbouring blocks use different cards."""
        solution = solution.reshape(problem.dimensions)
        for i in xrange(problem.n_blocks - 1):
            _, card_idx = solution[i].nonzero()
            _, next_card_idx = solution[i + 1].nonzero()
            self.assertEqual(len(card_idx), 1)
            self.assertEqual(len(next_card_idx), 1)
            self.assertNotEqual(card_idx[0], next_card_idx[0])

    def test_recommendation_group_adjacency_constraint(self):
        """
        Again, one-hot constraint is required here to have a meaningful problem.
        """
        solution, problem = self.solve_shortcut(
            n_blocks=5, group_sizes=[10, 50, 200], card_sizes=[1, 5, 10],
            constraints=[OneHotPositionConstraint(), RecommendationGroupAdjacencyConstraint()]
        )
        self._verify_recommendation_group_adjacency_constraint(solution, problem)

    def test_card_adjacency_constraint(self):
        """Neighbouring blocks must not share a card type."""
        solution, problem = self.solve_shortcut(
            n_blocks=5, group_sizes=[10, 50, 200], card_sizes=[1, 5, 10],
            constraints=[OneHotPositionConstraint(), CardAdjacencyConstraint()]
        )
        self._verify_card_adjacency_constraint(solution, problem)

    def test_both_adjacency_constraints(self):
        """Group and card adjacency constraints must hold at the same time."""
        solution, problem = self.solve_shortcut(
            n_blocks=5, group_sizes=[10, 50, 200], card_sizes=[1, 5, 10],
            constraints=[
                OneHotPositionConstraint(),
                CardAdjacencyConstraint(),
                RecommendationGroupAdjacencyConstraint()
            ]
        )
        self._verify_recommendation_group_adjacency_constraint(solution, problem)
        self._verify_card_adjacency_constraint(solution, problem)

    def test_previous_page(self):
        """
        if previous page is present, its last card must affect card adjacency constraint
        """
        card_sizes = [1, 5, 10]
        # try every card type as the single block of the previous page
        for i, _ in enumerate(card_sizes):
            previous_page = FeedPage(blocks=[
                FeedBlock(
                    items=[{"package_name": "dummy"}],
                    title="dummy block",
                    subtitle="",
                    algorithm="handmade",
                    card_type=self.get_card_type(i),
                    content_type=self.get_content_type(i)
                )
            ], expire_at=FAKE_EXPIRE_AT)
            solution, problem = self.solve_shortcut(
                n_blocks=5, group_sizes=[10, 50, 200], card_sizes=card_sizes,
                constraints=[OneHotPositionConstraint(), CardAdjacencyConstraint()],
                previous_page=previous_page
            )
            solution = solution.reshape(problem.dimensions)
            _, card_idx = solution[0].nonzero()
            # fail with a clear assertion instead of an IndexError on an empty selection
            self.assertEqual(len(card_idx), 1)
            # the first block must not repeat the card type of the previous page
            self.assertNotEqual(card_idx[0], i)

    def _count_recommendation_groups_in_solution(self, solution, problem):
        """Return a Counter of the first selected group index of every block."""
        groups = []
        solution = solution.reshape(problem.dimensions)
        for solution_slice in solution:
            group_idx, _ = solution_slice.nonzero()
            groups.append(group_idx[0])
        return Counter(groups)

    def test_group_weights_negative(self):
        """A group with a negative weight must not appear in the solution."""
        solution, problem = self.solve_shortcut(
            n_blocks=5, group_sizes=[10, 50, 200], card_sizes=[1, 5, 10],
            group_weights=np.array([1, 0, -10]),
            constraints=[
                OneHotPositionConstraint(),
                CardAdjacencyConstraint(),
                RecommendationGroupAdjacencyConstraint(),
                BalancedRecommenderConstraint()
            ]
        )
        solution = solution.reshape(problem.dimensions)
        # group 2 shouldn't be present at all
        self.assertAlmostEqual(solution[:, 2, :].sum(), 0.0)

    def test_group_weights_positive(self):
        """With positive weights every group appears; the lowest-weighted one only once."""
        solution, problem = self.solve_shortcut(
            n_blocks=5, group_sizes=[10, 50, 200], card_sizes=[1, 5, 10],
            group_weights=np.array([5, 5, 1]),
            constraints=[
                OneHotPositionConstraint(),
                CardAdjacencyConstraint(),
                RecommendationGroupAdjacencyConstraint(),
                BalancedRecommenderConstraint()
            ]
        )
        solution = solution.reshape(problem.dimensions)
        # all groups should be present due to balancing constraint
        for group_idx in xrange(problem.n_recommendation_groups):
            self.assertNotAlmostEqual(solution[:, group_idx, :].sum(), 0.0)
        # but group 2 should only be present once
        self.assertAlmostEqual(solution[:, 2, :].sum(), 1.0)
