import codecs
import sys
import json
import argparse
from random import random, seed
from decision_tree import DecisionTree
from rules import generate_split_rules, add_calculated_factors_to_images
from utils import calc_signals_match, calc_corsa_xl_nodups, calc_corsa_dt_nodups, calc_r2
from collections import defaultdict

def fit(args):
    """Fit a DecisionTree on the relevant images of the input pool.

    Only images whose ``relevance`` is ``RELEVANT_PLUS`` or ``RELEVANT_MINUS``
    are used; they are grouped by ``qid`` before calculated factors are added.
    The fitted tree is written to ``args.output_tree_json`` (JSON) and
    ``args.output_tree_dot`` (DOT), ``save_dt_score_and_dump_leaves`` is
    called with ``args.output_pool``, and the pairwise match rate between
    ``dt_score`` and ``dbd_normed_v3`` is printed.

    :param args: parsed command-line namespace (see ``create_argument_parser``).
    """
    ####### read images and filter rels #########
    with codecs.open(args.input_pool, 'r', 'utf8') as f:
        # codecs.open already yields decoded text, so json.load needs no
        # ``encoding`` argument (ignored for decoded input, removed in 3.9).
        data = json.load(f)

    images = defaultdict(list)
    for elem in data:
        # drop '_other' and any pre-existing 'dt_score' before fitting
        elem.pop('_other', None)
        elem.pop('dt_score', None)
        if elem['relevance'] in ('RELEVANT_PLUS', 'RELEVANT_MINUS'):
            images[elem['qid']].append(elem)

    add_calculated_factors_to_images(images, flat=False)

    ####### generate splitting rules #######
    rules, _ = generate_split_rules()

    ####### load config file ########
    with open(args.config, 'r') as f:
        config = json.load(f)

    ####### fit the tree #########
    dt = DecisionTree(config=config, rules=rules)
    dt.fit(images)
    dt.dump_tree_json(args.output_tree_json)
    dt.dump_tree_dot(args.output_tree_dot)
    dt.save_dt_score_and_dump_leaves(args.output_pool)

    ####### calculate matched for various metric pairs #######
    dt_dbd_wins, dt_dbd_total, _ = calc_signals_match(images, 'dt_score', 'dbd_normed_v3')
    if dt_dbd_total:
        print('DecisionTree DbD match: {0}/{1} ({2:.2f}%)'.format(
            dt_dbd_wins, dt_dbd_total, 100.0 * dt_dbd_wins / dt_dbd_total))
    else:
        # avoid ZeroDivisionError when no pair could be compared
        print('DecisionTree DbD match: {0}/{1} (no comparable pairs)'.format(
            dt_dbd_wins, dt_dbd_total))

    # Other pairs this script used to report with the same
    # calc_signals_match(images, signal, target) call (code removed):
    #   default signals ("Corsa Unity v2"),
    #   ('corsax', 'dwelltime_boost'), ('dt_score', 'dwelltime_boost'),
    #   ('corsa-xr', 'dbd_score'), ('corsa-xr', 'dwelltime_boost'),
    #   and for each of 'corsa-stove-dbd-precise', 'corsa-stove-dbd-reduced',
    #   'corsa-stove-dbd-noutil-reduced': 'dbd_score' and 'dwelltime_boost'.


def predict(args):
    """Score every image of the input pool with a previously fitted tree.

    Unlike ``fit``, no relevance filtering is applied: all images are kept as
    a flat list. The tree is loaded from ``args.input_tree_json`` using
    ``args.config`` and the generated split rules, and the value returned by
    ``DecisionTree.predict`` is written to ``args.output_pool`` as indented,
    non-ASCII-escaped UTF-8 JSON.

    :param args: parsed command-line namespace (see ``create_argument_parser``).
    """
    ####### read images #########
    with codecs.open(args.input_pool, 'r', 'utf8') as f:
        # codecs.open already yields decoded text, so json.load needs no
        # ``encoding`` argument (ignored for decoded input, removed in 3.9).
        data = json.load(f)

    for elem in data:
        # drop '_other' and any pre-existing 'dt_score' before scoring
        elem.pop('_other', None)
        elem.pop('dt_score', None)

    add_calculated_factors_to_images(data, flat=True)

    ####### generate splitting rules #######
    rules, _ = generate_split_rules()

    ####### load config file ########
    with open(args.config, 'r') as f:
        config = json.load(f)

    ####### load tree and predict #########
    dt = DecisionTree.from_file(config=config, rules=rules, filename=args.input_tree_json)
    out_data = dt.predict(data)

    ######## save result ###########
    with codecs.open(args.output_pool, 'w', 'utf8') as f:
        json.dump(out_data, f, ensure_ascii=False, indent=4)


def approx_factors_profit(args):
    ####### read images #########
    with codecs.open(args.input_pool, 'r', 'utf8') as f:
        data = json.load(f, encoding='utf8')

    for elem in data:
        if '_other' in elem:
            elem.pop('_other')
        if 'dt_score' in elem:
            elem.pop('dt_score')

    add_calculated_factors_to_images(data, flat=True)

    ####### generate splitting rules #######
    rules, _ = generate_split_rules()

    ####### load config file ########
    with open(args.config, 'r') as f:
        config = json.load(f)

    ####### load tree and calc base metric value #########
    dt = DecisionTree.from_file(config=config, rules=rules, filename=args.input_tree_json)

    signals = ['images_vq3_v2_avg', 'aestetics_mean5', 'square_tanh_2_0', 'viewport', 'utility_v2_avg']
    base_value = calc_corsa_dt_nodups(dt.predict(data))
    print "Current value", base_value

    ####### increase each signal and recalc metric #######
    prob = max(0.0, args.coeff - 1.0)
    for signal in signals:
        data_copy = [img.copy() for img in data]
        sum_before = 0.0
        sum_after = 0.0
        for img in data_copy:
            if signal not in img:
                continue
            sum_before += img[signal]
            if signal in ['viewport', 'square_tanh_2_0']:
                img[signal] = min(1.0, img[signal] * args.coeff)
            else:
                r = random()
                if signal in ['images_vq3_v2_avg', 'aestetics_mean5']:
                    delta = 0.1
                else:
                    delta = 0.066667

                p0 = (1.0 - prob) ** 5
                p1 = (1.0 - prob) ** 4 * prob * 5
                p2 = (1.0 - prob) ** 3 * prob ** 2 * 10

                if r > p0 and r <= p0 + p1: # +1
                    img[signal] = min(1.0, img[signal] + delta)
                elif r > p0 + p1 and r <= p0 + p1 + p2: # +2
                    img[signal] = min(1.0, img[signal] + delta * 2)
                elif r > p0 + p1 + p2: # +3
                    img[signal] = min(1.0, img[signal] + delta * 3)

            sum_after += img[signal]

        new_value = calc_corsa_dt_nodups(dt.predict(data_copy))
        print " {} boost value {}, profit {}%, signal inc {}%".format(signal, new_value, (new_value / base_value - 1.0) * 100.0, (sum_after / sum_before - 1.0) * 100.0)


def custom(args):
    """Ad-hoc analysis run by the ``custom`` sub-command.

    Reads the pool, groups RELEVANT_PLUS/RELEVANT_MINUS images by ``qid`` and
    loads the tree from ``args.input_tree_json``. It then prints:

    - the pairwise match of ``dbd_score`` against ``corsa-xr``;
    - after ``calc_corsa_xl_nodups(data)``, the mean ``corsa-xl`` over all
      images (missing values count as 0.0);
    - the match of ``dbd_score`` against ``corsa-xl``.

    :param args: parsed command-line namespace (see ``create_argument_parser``).
    """
    ####### read images #########
    with codecs.open(args.input_pool, 'r', 'utf8') as f:
        # codecs.open already yields decoded text, so json.load needs no
        # ``encoding`` argument (ignored for decoded input, removed in 3.9).
        data = json.load(f)

    images = defaultdict(list)
    for elem in data:
        elem.pop('_other', None)
        elem.pop('dt_score', None)
        if elem['relevance'] in ('RELEVANT_PLUS', 'RELEVANT_MINUS'):
            images[elem['qid']].append(elem)

    add_calculated_factors_to_images(images, flat=False)

    ####### generate splitting rules #######
    rules, _ = generate_split_rules()

    ####### load config file ########
    with open(args.config, 'r') as f:
        config = json.load(f)

    ####### load tree and calc base metric value #########
    dt = DecisionTree.from_file(config=config, rules=rules, filename=args.input_tree_json)
    # NOTE(review): the return value is discarded; presumably predict also
    # annotates ``data`` in place — confirm, otherwise this call is unused.
    dt.predict(data)

    ###### custom code ##########
    r2 = 0  # R2 reporting disabled: calc_r2(images, 'dt_score', 'corsa-xr')

    def print_match(wins, total):
        # report 0.00% instead of raising ZeroDivisionError on an empty set
        pct = 100.0 * wins / total if total else 0.0
        print('match: {0}/{1} ({2:.2f}%), R2: {3:.2f}'.format(wins, total, pct, r2))

    wins, total, _ = calc_signals_match(images, 'dbd_score', 'corsa-xr')
    print_match(wins, total)

    calc_corsa_xl_nodups(data)

    wins, total, _ = calc_signals_match(images, 'dbd_score', 'corsa-xl')

    # float start value keeps the division true division on Python 2 as well
    total_xl = sum((img.get('corsa-xl', 0.0) for img in data), 0.0)
    print(total_xl / len(data) if data else 0.0)
    print_match(wins, total)


def create_argument_parser():
    """Build the command-line parser for the DecisionTree scoring tool.

    Global options (``--input-pool``, ``--output-pool``, ``--config``) must
    appear before the sub-command. Sub-commands are ``fit``, ``predict``,
    ``approx_profit`` and ``custom``; the chosen one is stored in
    ``args.command``.

    :returns: a configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(description='DecisionTree Scoring Model')
    subparsers = parser.add_subparsers(dest='command')
    # Python 3 makes sub-commands optional by default (Python 2 did not);
    # require one so a missing command is reported instead of yielding None.
    subparsers.required = True

    parser.add_argument('--input-pool', required=True, help='input pool (JSON list of images)')
    parser.add_argument('--output-pool', required=True, help='output pool path')
    parser.add_argument('--config', required=True, help='DecisionTree JSON config')

    fit_parser = subparsers.add_parser('fit', help='DecisionTree Fit')
    fit_parser.add_argument('--output-tree-json', required=True)
    fit_parser.add_argument('--output-tree-dot', required=True)

    predict_parser = subparsers.add_parser('predict', help='DecisionTree Predict')
    predict_parser.add_argument('--input-tree-json', required=True)

    approx_parser = subparsers.add_parser('approx_profit', help='Factors profit approximations for Corsa-XL')
    approx_parser.add_argument('--input-tree-json', required=True)
    approx_parser.add_argument('--coeff', required=True, type=float)

    custom_parser = subparsers.add_parser('custom', help='Custom code')
    custom_parser.add_argument('--input-tree-json', required=True)

    return parser


if __name__ == '__main__':
    parser = create_argument_parser()
    args = parser.parse_args()

    # Explicit dispatch table: an unexpected or missing command is reported
    # as a usage error instead of silently falling through to ``custom``.
    handlers = {
        'fit': fit,
        'predict': predict,
        'approx_profit': approx_factors_profit,
        'custom': custom,
    }
    handler = handlers.get(args.command)
    if handler is None:
        parser.error('a sub-command is required: {0}'.format(', '.join(sorted(handlers))))
    handler(args)
