import json
import math
import sys
from collections import defaultdict, Counter
from itertools import combinations
import codecs

class DecisionTree(object):
    """Binary decision tree whose leaves are ranked by a pairwise target.

    Every internal node holds a rule: a callable taking an image dict and
    exposing a ``name`` attribute.  Images for which the rule is true go to
    the right child, the rest to the left child (see ``Node._split``).  A
    split is scored from ``Node.calc_split_score``: how often right-side
    images beat left-side images of the same query.  After growing, leaves
    are compared pairwise, their wins are summed into ``factor_score`` and
    mapped onto [0, 1] as ``unit_factor_score``.

    Output uses single-argument ``print(...)`` calls and ``//`` so the class
    behaves the same under Python 2 and Python 3.
    """

    def __init__(self, config, rules, root=None):
        """Create a tree.

        :param config: mapping that must contain every key listed in
            ``attributes`` below; each value is copied onto ``self``
            (a missing key raises ``KeyError``).
        :param rules: candidate split rules (callables with a ``name``).
        :param root: optional pre-built root ``Node`` (see ``from_file``).
        """
        self.root = root

        self.rules = rules

        attributes = [
            'max_depth',
            'min_node_size',
            'dwelltime_boost',
            'min_split_scores_ratio',
            'dynamic_split_thresh_inc',
            'use_balanced_split',
            'balance_ratio',
            'dynamic_balance_inc',
            'use_integer_split',
            'use_integer_leaf_aggregation',
            'max_dbd_score',
            'use_pruning',
        ]

        for attr in attributes:
            setattr(self, attr, config[attr])

        print('DecisionTree config:')
        print(json.dumps(config, indent=4))

    def get_split_score(self, split_pairs_tuple, integer_mode=False):
        """Turn a ``(success_pairs, total_pairs)`` tuple into two scores.

        :param split_pairs_tuple: pair counts as returned by
            ``Node.calc_split_score``.
        :param integer_mode: when true, return the raw counts
            ``(success, total - success)`` instead of ratios.
        :returns: ``(success_score, failure_score)``.  In ratio mode the two
            values sum to 1.0, or are ``(0.0, 0.0)`` when there are no pairs.
        """
        success_pairs, total_pairs = split_pairs_tuple

        if integer_mode:
            return success_pairs, total_pairs - success_pairs

        if total_pairs == 0:
            return 0.0, 0.0

        ratio = float(success_pairs) / total_pairs
        return ratio, 1.0 - ratio

    def fit_with_custom_rules(self, images, custom_rules):
        """Seed the top of the tree with ``custom_rules``, then ``fit``."""
        self.root = Node(images)
        self.propagate_custom_rules(self.root, custom_rules)
        self.fit(images)

    def propagate_custom_rules(self, node, custom_rules):
        """Pre-build splits from an ordered list of rules.

        The middle rule becomes the split of ``node``.  The rules before it
        are applied recursively to the left child and the rules after it to
        the right child.
        """
        if not custom_rules:
            return

        # Floor division: the index must be an int on Python 2 and 3 alike.
        center = len(custom_rules) // 2
        node.best_rule = custom_rules[center]
        node.best_split_score = 0.0  # FIXME write real score
        node.create_children()

        self.propagate_custom_rules(node.left, custom_rules[:center])
        self.propagate_custom_rules(node.right, custom_rules[center + 1:])

    def fit(self, images):
        """Grow the tree on ``images`` and score its leaves.

        :param images: dict ``query -> list of image dicts``.

        With ``use_pruning``, leaves are scored once.  Splits whose two leaf
        children are ordered the wrong way round are then collapsed by
        ``pruning`` (which also resets leaf scores) before the final
        aggregation.  Diagnostics are printed to stdout.
        """
        if self.root is None:
            self.root = Node(images)

        self.fit_node(self.root)

        if self.use_pruning:
            print('Pruning...')
            preliminary_leaves = self.root.get_leaves()
            print('Preliminary {} leaves'.format(len(preliminary_leaves)))
            self.aggregate_leaves(preliminary_leaves)
            self.normalize_leaves(preliminary_leaves)
            self.pruning(self.root)

        all_leaves = self.root.get_leaves()
        print('Total {} leaves'.format(len(all_leaves)))

        scores_matrix, score_row_sum = self.aggregate_leaves(all_leaves)
        self.normalize_leaves(all_leaves)
        dbd_match_matrix, dbd_row_sums, success_pairs, total_pairs = self.calc_dbd_consistency(all_leaves)

        print('Scores matrix')
        self.print_leaf_matrix(scores_matrix, score_row_sum, '{0:.2f}')
        print('DbD match matrix')
        self.print_leaf_matrix(dbd_match_matrix, dbd_row_sums, '{0:06.1f}')
        # total_pairs can be 0 (e.g. no comparable image pairs); avoid
        # ZeroDivisionError in the report.
        predicted_pct = 100.0 * success_pairs / total_pairs if total_pairs else 0.0
        print('DbD pairs predicted: {0}/{1} ({2:.1f}%)'.format(success_pairs, total_pairs, predicted_pct))

        self.print_leaf_factor_scores(all_leaves)

    def fit_node(self, node):
        """Recursively grow the subtree rooted at ``node``.

        If ``node`` has no rule yet, every rule in ``self.rules`` is tried.
        Rules that violate the min-node-size, balance or min-split-score
        criteria are counted and skipped.  If no rule survives (or the depth
        limit is reached), ``node`` stays a leaf and records the counts via
        ``set_stopping_criteria``.  Otherwise the best rule is chosen,
        children are created, and both children are fitted.
        """
        stopping_criteria = Counter()

        # depth stopping criterion
        if node.depth >= self.max_depth:
            stopping_criteria['depth'] += 1
            node.set_stopping_criteria(stopping_criteria)
            return

        if node.best_rule is None:
            for rule in self.rules:
                left_images, right_images = node.try_rule(rule)
                n_left_images = Node.images_size(left_images)
                n_right_images = Node.images_size(right_images)

                # skip for min node size stopping criterion
                if n_left_images < self.min_node_size or n_right_images < self.min_node_size:
                    stopping_criteria['min_node_size'] += 1
                    continue

                if self.use_balanced_split:
                    # An empty left side is infinitely unbalanced; this guard
                    # is only reachable when min_node_size <= 0.
                    if n_left_images:
                        current_balance = float(n_right_images) / n_left_images
                    else:
                        current_balance = float('inf')
                    balance_thresh = self.balance_ratio + node.depth * self.dynamic_balance_inc
                    # skip for balanced stopping criterion
                    if not (1.0 / balance_thresh <= current_balance <= balance_thresh):
                        print('Balanced split is not possible')
                        stopping_criteria['balance'] += 1
                        continue

                split_pairs = Node.calc_split_score(left_images,
                                                    right_images,
                                                    click_weight=self.dwelltime_boost,
                                                    max_dbd_score=self.max_dbd_score)
                # "success" pairs are those where the right-side image wins.
                right_score, left_score = self.get_split_score(split_pairs,
                                                               integer_mode=self.use_integer_split)

                # skip for min split score stopping criterion
                split_thresh = self.min_split_scores_ratio + node.depth * self.dynamic_split_thresh_inc
                if right_score <= split_thresh * left_score:
                    stopping_criteria['min_split_score'] += 1
                    continue

                node.new_rule(rule, right_score)

                print('{} {} {} {}'.format(right_score, rule.name, n_left_images, n_right_images))

            # no rules within criteria
            if node.rules_size() == 0:
                node.set_stopping_criteria(stopping_criteria)
                return

            node.choose_best_rule()
            node.create_children()

        self.fit_node(node.left)
        self.fit_node(node.right)

    def pruning(self, node):
        """Collapse splits whose two leaf children are ordered wrongly.

        A node whose children are both leaves is "correct" when the left
        leaf's ``factor_score`` does not exceed the right leaf's.  Incorrect
        nodes become leaves (children, rule and split score dropped).  Every
        leaf reached has its factor scores reset to 0.0, including both
        children of a correct node, so the next ``aggregate_leaves`` starts
        from zero.
        """
        if node is None:
            return

        left, right = node.left, node.right
        if left is not None and left.is_leaf and right is not None and right.is_leaf:
            if left.factor_score <= right.factor_score:  # node is correct
                print('Node {} is correct'.format(node.best_rule.name))
                # Reset the surviving leaves as well; otherwise their
                # preliminary scores are added to the final aggregation.
                for leaf in (left, right):
                    leaf.factor_score = 0.0
                    leaf.unit_factor_score = 0.0
                return
            print('Applying pruning for {} ({}): L/R {}/{}'.format(
                node.best_rule.name, node.best_split_score, left.factor_score, right.factor_score))
            stopping_criteria = Counter()
            stopping_criteria['pruning'] += 1
            node.set_stopping_criteria(stopping_criteria)

            node.is_leaf = True

            node.left = None
            node.right = None

            node.best_rule = None
            node.best_split_score = None

            node.factor_score = 0.0
            node.unit_factor_score = 0.0
        elif node.is_leaf:
            node.factor_score = 0.0
            node.unit_factor_score = 0.0
        else:
            self.pruning(left)
            self.pruning(right)

    def calc_wins_matrix(self, all_leaves, integer_mode):
        """Build the leaf-vs-leaf wins matrix.

        Cell ``[i][j]`` (i != j) is the score of leaf ``i``'s images beating
        leaf ``j``'s images of the same query, as produced by
        ``get_split_score``.  Diagonal cells hold half of a leaf's internal
        pairs in integer mode, else 0.5.

        :returns: ``(wins_matrix, row_sums)``.
        """
        n_leaves = len(all_leaves)
        wins_matrix = [[0] * n_leaves for _ in range(n_leaves)]

        # fill the matrix with non-diagonal elements
        for (left_i, left_leaf), (right_i, right_leaf) in combinations(enumerate(all_leaves), 2):
            split_score_tuple = Node.calc_split_score(left_leaf.images,
                                                      right_leaf.images,
                                                      click_weight=self.dwelltime_boost,
                                                      max_dbd_score=self.max_dbd_score)
            print(split_score_tuple)
            right_score, left_score = self.get_split_score(split_score_tuple, integer_mode=integer_mode)

            wins_matrix[right_i][left_i] = right_score
            wins_matrix[left_i][right_i] = left_score

        # fill the matrix with diagonal elements
        for leaf_i, leaf in enumerate(all_leaves):
            if integer_mode:
                wins_matrix[leaf_i][leaf_i] = leaf.get_n_pairs() / 2.0
            else:
                wins_matrix[leaf_i][leaf_i] = 0.5

        score_row_sums = [sum(row) for row in wins_matrix]

        return wins_matrix, score_row_sums

    def aggregate_leaves(self, all_leaves):
        """Add each leaf's off-diagonal wins to its ``factor_score``.

        :returns: ``(scores_matrix, row_sums)`` from ``calc_wins_matrix``.
        """
        scores_matrix, score_row_sums = self.calc_wins_matrix(all_leaves, self.use_integer_leaf_aggregation)

        for (left_i, left_leaf), (right_i, right_leaf) in combinations(enumerate(all_leaves), 2):
            left_leaf.add_factor_score(scores_matrix[left_i][right_i])
            right_leaf.add_factor_score(scores_matrix[right_i][left_i])

        return scores_matrix, score_row_sums

    def normalize_leaves(self, all_leaves):
        """Map leaf ``factor_score`` values onto [0, 1] (``unit_factor_score``).

        Nothing is changed when there are no leaves or when the score range
        is below 0.01.
        """
        if not all_leaves:
            print('Normalization for no leaves is impossible')
            return

        min_val = min(leaf.factor_score for leaf in all_leaves)
        max_val = max(leaf.factor_score for leaf in all_leaves)

        if max_val - min_val < 0.01:
            print('Too small leaves scores diff for normalization')
            return

        for leaf in all_leaves:
            leaf.unit_interval_map(min_val, max_val)

    def calc_dbd_consistency(self, all_leaves):
        """Count image pairs whose order agrees with the leaf ranking.

        For two leaves, the pairs won by the leaf with the higher
        ``factor_score`` count as predicted; on a tie half of all their pairs
        count.  Pairs inside one leaf count as half predicted.

        :returns: ``(match_matrix, row_sums, success_pairs, total_pairs)``.
        """
        dbd_match_matrix, dbd_row_sums = self.calc_wins_matrix(all_leaves, True)

        # calc dbd consistency for non-diagonal elements
        total_pairs = 0
        success_pairs = 0
        for (left_i, left_leaf), (right_i, right_leaf) in combinations(enumerate(all_leaves), 2):
            total_pairs += dbd_match_matrix[left_i][right_i]
            total_pairs += dbd_match_matrix[right_i][left_i]
            if left_leaf.factor_score < right_leaf.factor_score:
                success_pairs += dbd_match_matrix[right_i][left_i]
            elif left_leaf.factor_score > right_leaf.factor_score:
                success_pairs += dbd_match_matrix[left_i][right_i]
            else:
                success_pairs += (dbd_match_matrix[right_i][left_i] + dbd_match_matrix[left_i][right_i]) / 2.0

        # calc consistency for diagonal elements
        for leaf_i in range(len(all_leaves)):
            success_pairs += dbd_match_matrix[leaf_i][leaf_i]
            total_pairs += dbd_match_matrix[leaf_i][leaf_i] * 2

        return dbd_match_matrix, dbd_row_sums, success_pairs, total_pairs

    def predict(self, data):
        """Annotate a flat list of image dicts with leaf scores and ids."""
        return self.root.propagate_data(data)

    def print_leaf_matrix(self, matrix, row_sums, str_template):
        """Print ``matrix`` tab-separated, each row followed by ``| row_sum``."""
        for row, row_sum in zip(matrix, row_sums):
            cells = [str_template.format(x) for x in row]
            print('\t'.join(cells + ['|', str(row_sum)]))

    def print_leaf_factor_scores(self, all_leaves):
        """Print the raw ``factor_score`` of every leaf."""
        print('Leaves factors:')
        for leaf in all_leaves:
            print(leaf.factor_score)

    def dump_tree_json(self, filename):
        """Write the tree as JSON (``{}`` when there is no root)."""
        print('Dumping tree to json {}'.format(filename))
        with open(filename, 'w') as f:
            if self.root is None:
                f.write('{}\n')
            else:
                json.dump(self.root.to_json(), f, indent=4)

    def dump_tree_dot(self, filename):
        """Write the tree in Graphviz DOT format (an empty graph without root)."""
        print('Dumping tree to dot {}'.format(filename))
        with open(filename, 'w') as f:
            if self.root is None:
                f.write('digraph DesicionTree {\n\n}\n')
            else:
                f.write(self.root.to_dot() + '\n')

    def save_dt_score_and_dump_leaves(self, filename):
        """Store each leaf's final score as ``dt_score`` on its images.

        The image dicts are modified in place, and all of them are written
        to ``filename`` as UTF-8 JSON.
        """
        out_data = []
        for leaf in self.root.get_leaves():
            leaf_score = leaf.get_final_score()
            for query in leaf.images:
                for img in leaf.images[query]:
                    img['dt_score'] = leaf_score
                    out_data.append(img)

        print('Dumping {} images to file {}'.format(len(out_data), filename))
        with codecs.open(filename, 'w', 'utf8') as f:
            json.dump(out_data, f, indent=4, ensure_ascii=False)

    @classmethod
    def from_file(cls, config, rules, filename):
        """Load a tree dumped by ``dump_tree_json``.

        Rule names in the file are resolved against ``rules``.
        """
        with open(filename, 'r') as f:
            tree_json = json.load(f)
        root = Node()
        rules_dict = {r.name: r for r in rules}
        root.propagate_rules(rules_dict, tree_json)
        return cls(config, rules, root)


class Node(object):
    """A node of ``DecisionTree``.

    ``images`` maps a query key to the list of image dicts for that query.
    A node is a leaf until ``create_children`` splits it with ``best_rule``.
    Images for which the rule returns true go to ``right``, the rest to
    ``left``.  Children get heap-style ids: ``2*i + 1`` for the left child
    and ``2*i + 2`` for the right child (see ``propagate_data``, ``to_dot``).
    """

    def __init__(self, images=None, depth=0):
        """:param images: dict ``query -> list of image dicts``.  A fresh empty
            dict is used when omitted, so nodes never share a default dict.
        :param depth: distance from the root (the root has depth 0).
        """
        self.left = None
        self.right = None
        self.parent = None

        self.images = {} if images is None else images

        # list of (reason, count) pairs, most common first
        self.stopping_criteria = []

        # candidate rules that passed the stopping criteria, with their scores
        self.rules = []
        self.split_scores = []
        self.best_rule = None
        self.best_split_score = None

        self.is_leaf = True
        self.depth = depth

        # raw aggregated leaf score and its [0, 1] normalisation
        self.factor_score = 0.0
        self.unit_factor_score = 0.0

    @staticmethod
    def images_size(images):
        """Total number of images over all queries."""
        return sum(map(len, images.values()))

    @staticmethod
    def calc_split_score(left_images, right_images, click_weight=0.0, max_dbd_score=1.0):
        """Count same-query image pairs ordered ``left < right``.

        Each image's value is
        ``(1 - click_weight) * dbd_normed_v3 + click_weight * sps_boost * max_dbd_score``.
        For every query present on both sides, each left image is paired
        with each right image.  Images without ``dbd_normed_v3`` are skipped.
        ``sps_boost`` is only read when ``click_weight`` is non-zero.

        :returns: ``(success, total)``.  ``success`` counts pairs whose left
            value is strictly smaller than the right value.
        """
        target_field = 'dbd_normed_v3'
        click_field = 'sps_boost'

        def blended_values(images):
            # Evaluate each image once instead of once per pair.
            values = []
            for img in images:
                if target_field not in img:
                    continue
                value = (1.0 - click_weight) * img[target_field]
                if click_weight:
                    value += click_weight * img[click_field] * max_dbd_score
                values.append(value)
            return values

        total = 0
        success = 0
        for query, query_left_images in left_images.items():
            if query not in right_images:
                continue

            left_values = blended_values(query_left_images)
            right_values = blended_values(right_images[query])

            total += len(left_values) * len(right_values)
            for left_value in left_values:
                success += sum(1 for right_value in right_values if left_value < right_value)

        return success, total

    def set_stopping_criteria(self, criteria):
        """Store a ``Counter`` of stopping reasons as (reason, count) pairs."""
        self.stopping_criteria = criteria.most_common()

    def get_n_pairs(self):
        """Number of unordered image pairs within each query, summed."""
        # n * (n - 1) is always even, so floor division is exact (and an int
        # on both Python 2 and 3).
        return sum(n * (n - 1) // 2 for n in map(len, self.images.values()))

    def add_factor_score(self, score):
        self.factor_score += score

    def get_final_score(self):
        """Score exported for this leaf: the normalised factor score."""
        return self.unit_factor_score

    def unit_interval_map(self, min_val, max_val):
        """Map ``factor_score`` linearly onto [0, 1], clamping outliers."""
        self.unit_factor_score = max(0.0, min(1.0, (self.factor_score - min_val) / (max_val - min_val)))

    def get_leaves(self):
        """Leaves of this subtree, left to right."""
        if self.is_leaf:
            return [self]

        left_result = self.left.get_leaves() if self.left is not None else []
        right_result = self.right.get_leaves() if self.right is not None else []
        return left_result + right_result

    def choose_best_rule(self):
        """Pick the candidate rule with the highest split score.

        Ties go to the earliest candidate.  Requires at least one rule.
        """
        best_ind = max(range(len(self.split_scores)), key=self.split_scores.__getitem__)
        self.best_rule = self.rules[best_ind]
        self.best_split_score = self.split_scores[best_ind]
        print('Best split rule: {}, score {}'.format(self.best_rule, self.best_split_score))

    def rules_size(self):
        return len(self.rules)

    def new_rule(self, rule, split_score):
        """Record a candidate rule together with its split score."""
        self.rules.append(rule)
        self.split_scores.append(split_score)

    def _split(self, images, rule, flat=False):
        """Partition images by ``rule``: true -> right, false -> left.

        :param flat: ``images`` is a plain list (returns two lists) instead
            of a ``query -> list`` dict (returns two ``defaultdict(list)``).
        """
        if flat:
            left_images = []
            right_images = []

            for img in images:
                if rule(img):
                    right_images.append(img)
                else:
                    left_images.append(img)

            return left_images, right_images

        left_images = defaultdict(list)
        right_images = defaultdict(list)

        for query in images:
            for img in images[query]:
                if rule(img):
                    right_images[query].append(img)
                else:
                    left_images[query].append(img)

        return left_images, right_images

    def split_node(self):
        return self._split(self.images, self.best_rule, False)

    def try_rule(self, rule):
        return self._split(self.images, rule, False)

    def propagate_data(self, data, leaf_id=0):
        """Route a flat list of image dicts down the tree.

        Each image reaching a leaf gets ``dt_score`` and ``dt_leaf_id`` set
        in place.  The images are returned ordered left leaves first.
        """
        if not data:
            return data
        if self.is_leaf:
            for img in data:
                img['dt_score'] = self.get_final_score()
                img['dt_leaf_id'] = str(leaf_id)
            return data
        left_data, right_data = self._split(data, self.best_rule, True)
        return (self.left.propagate_data(left_data, leaf_id * 2 + 1) +
                self.right.propagate_data(right_data, leaf_id * 2 + 2))

    def create_children(self):
        """Split ``images`` with ``best_rule`` into two child leaves."""
        left_images, right_images = self.split_node()
        self.left = Node(left_images, depth=self.depth + 1)
        self.right = Node(right_images, depth=self.depth + 1)
        self.left.parent = self
        self.right.parent = self
        self.is_leaf = False

    def propagate_rules(self, rules_dict, tree_json):
        """Rebuild this subtree from a dict produced by ``to_json``.

        :param rules_dict: rule name -> rule object.
        :raises AssertionError: when a rule or a required field is missing.
        """
        if 'rule' in tree_json:  # non-leaf node
            assert tree_json['rule'] in rules_dict, "Rule {} is not found".format(tree_json['rule'])
            current_rule = rules_dict[tree_json['rule']]

            assert 'split_score' in tree_json, "Split score is not present for node {}".format(current_rule.name)
            assert 'left' in tree_json, "No left child node for {}".format(current_rule.name)
            assert 'right' in tree_json, "No right child node for {}".format(current_rule.name)

            self.best_rule = current_rule
            self.best_split_score = tree_json['split_score']
            self.create_children()

            print("Node {} is successfully loaded".format(current_rule.name))
            self.left.propagate_rules(rules_dict, tree_json['left'])
            self.right.propagate_rules(rules_dict, tree_json['right'])
        else:  # leaf node
            assert 'unit_factor_score' in tree_json, "No unit score field"
            assert 'factor_score' in tree_json, "No score field"
            assert 'stopping_criteria' in tree_json, "No stopping criteria field"

            self.factor_score = tree_json['factor_score']
            self.unit_factor_score = tree_json['unit_factor_score']

            stopping_criteria = []
            for criterion in tree_json['stopping_criteria'].split(';'):
                if not criterion:
                    continue
                # split on the last ':' so the count is always the final field
                reason, cnt = criterion.rsplit(':', 1)
                stopping_criteria.append((reason, int(cnt)))
            self.stopping_criteria = stopping_criteria

            print("Leaf is successfully loaded")

    def to_json(self):
        """Serialise this subtree.

        Nodes without a rule or split score are written as leaves.
        """
        if self.best_rule is None or self.best_split_score is None:
            return {'n_images': Node.images_size(self.images),
                    'factor_score': self.factor_score,
                    'unit_factor_score': self.unit_factor_score,
                    'stopping_criteria': ';'.join('{}:{}'.format(criterion, cnt) for criterion, cnt in self.stopping_criteria)}

        left_json = self.left.to_json() if self.left else {}
        right_json = self.right.to_json() if self.right else {}

        return {'rule': self.best_rule.name,
                'split_score': self.best_split_score,
                'left': left_json,
                'right': right_json}

    def to_dot(self):
        """Render this subtree as a Graphviz ``digraph`` string.

        Node labels carry the heap-style index.  A missing child is shown as
        ``[child_index] None``, so the two missing children of one node stay
        distinct.
        """
        def get_dot_edges(node, index):
            if node.best_rule is None or node.best_split_score is None:
                return [], '[{0}] {1} images, {2:.2f} score'.format(index, Node.images_size(node.images), node.unit_factor_score)

            left_index = 2 * index + 1
            right_index = 2 * index + 2

            if node.left:
                left_edges, left_name = get_dot_edges(node.left, left_index)
            else:
                left_edges, left_name = [], '[{}] None'.format(left_index)

            if node.right:
                right_edges, right_name = get_dot_edges(node.right, right_index)
            else:
                right_edges, right_name = [], '[{}] None'.format(right_index)

            name = '[{}] {}'.format(index, node.best_rule.name)
            new_edges = ['"{}" -> "{}" [label=False]'.format(name, left_name),
                         '"{}" -> "{}" [label=True]'.format(name, right_name)]

            return new_edges + left_edges + right_edges, name

        dot_rules, _ = get_dot_edges(self, 0)

        dot = 'digraph DesicionTree {{\n{}\n}}'.format('\n'.join(['    ' + rule for rule in dot_rules]))

        return dot
