#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
from __future__ import division
import sys
import os
import codecs
import json
import argparse
import math
import random
from collections import Counter


# Random number generator backed by the operating system's entropy source.
# NOTE(review): nothing else in the visible code references ``rnd``; the
# functions below call the module-level ``random`` functions instead --
# confirm whether this instance is still needed.
rnd = random.SystemRandom()


# Sampling plan used when no --config file is given.
# Each entry holds:
#   "predicate" - a Python expression over a single record ``x``; the script's
#                 main routine evaluates it with eval() to select records.
#   "number"    - how many records to request from the selected records.
default_config = [
    {"predicate": "x.get('classifiers', {}).get('isFilm', '')", "number": 300},
    {"predicate": "x.get('classifiers', {}).get('isSerial', '')", "number": 300},
    {"predicate": "x.get('classifiers', {}).get('isPorn', '')", "number": 300},
    {"predicate": "x['service'] == 'vid' and '_yandex_' in x['table']", "number": 100},
    {"predicate": "'_google_' in x['table']", "number": 500},
    {"predicate": "x['service'] == 'web' and '_yandex_' in x['table']", "number": 500},
]


def renorm(value, t, n):
    """Mix ``value`` with the uniform share ``1 / n``.

    ``t`` is the mixing weight: the result is ``(1 - t) * value + t / n``,
    so ``t == 0`` leaves ``value`` as is and ``t == 1`` yields ``1 / n``.
    ``n`` must be non-zero.
    """
    uniform_part = t / n
    kept_part = (1 - t) * value
    return kept_part + uniform_part


def c_make(c, target_number, t=0):
    """Split ``target_number`` into per-key quotas proportional to ``c``.

    ``c`` maps a key to the number of items available for it. Each key's
    share of the total is smoothed with ``renorm`` (weight ``t``, which must
    lie in ``[0, 1]``), scaled by ``target_number`` and rounded up, so the
    quotas may add up to more than ``target_number``.

    Keys are then visited in descending order. While a key's quota exceeds
    its available count, one unit at a time is moved to a randomly chosen key
    that comes later in that order. The last key has no later key to hand its
    excess to, so its quota can stay above its available count.

    Returns a dict mapping each key of ``c`` to its integer quota.
    """
    assert 0 <= t <= 1
    total = sum(c.values())
    n_keys = len(c)

    quota = {}
    for key, available in c.items():
        share = renorm(available / total, t, n_keys)
        quota[key] = int(math.ceil(share * target_number))

    ordered = sorted(quota, reverse=True)
    for pos in range(len(ordered)):
        key = ordered[pos]
        later = ordered[pos + 1 :]
        if not later:
            # Nothing after this key to receive its excess.
            continue
        while quota[key] > c[key]:
            quota[key] -= 1
            receiver = random.choice(later)
            quota[receiver] += 1
    return quota


def resample_basic(
    recs,
    predicate,
    target_number,
    t,
    set_=None,
    query_field="text",
    count_field="frequency",
):
    """Sample up to ``target_number`` records, stratified by frequency bucket.

    Steps:

    1. Every record in ``recs`` gets a ``"cat2n"`` key holding the rounded
       base-2 logarithm of its ``count_field`` value. The records are
       modified in place, so the key is also present in the returned records.
    2. Records whose ``query_field`` value is already in ``set_`` are dropped.
    3. ``c_make`` turns the bucket sizes into per-bucket quotas; quotas are
       then decremented at random until their sum no longer exceeds
       ``target_number``.
    4. From each bucket a random subset of quota size is taken.

    Args:
        recs: list of dict records; each must contain ``query_field`` and
            ``count_field``.
        predicate: label printed in the progress line; not evaluated here.
        target_number: maximum number of records to return.
        t: smoothing weight in ``[0, 1]`` forwarded to ``c_make``.
        set_: set of already used query values. It is updated in place with
            the query values of the records chosen here. A fresh set is used
            when ``None``.
        query_field: record key used for de-duplication against ``set_``.
        count_field: record key holding the (positive) frequency.

    Returns:
        List of chosen records. It can be shorter than ``target_number`` when
        a bucket holds fewer records than its quota.

    Raises:
        ValueError: from ``math.log`` when a ``count_field`` value is not
            positive.
    """
    if set_ is None:
        set_ = set()

    for rec in recs:
        rec["cat2n"] = round(math.log(rec[count_field], 2))

    recs = [x for x in recs if x[query_field] not in set_]
    print(
        "Needed: {}. Available: {}. Predicate: {}".format(
            target_number, len(recs), predicate
        )
    )

    # Group the records by bucket once instead of rescanning ``recs`` for
    # every bucket. Insertion order matches first appearance in ``recs``.
    pools = {}
    for rec in recs:
        pools.setdefault(rec["cat2n"], []).append(rec)
    cats = Counter({cat: len(pool) for cat, pool in pools.items()})

    target_numbers = c_make(cats, target_number, t)

    # c_make rounds up, so the quotas can overshoot. Shave them one unit at a
    # time, preferring buckets that would still keep at least one record.
    while sum(target_numbers.values()) > target_number:
        keys = [k for k, v in target_numbers.items() if v > 1]
        if not keys:
            keys = [k for k, v in target_numbers.items() if v == 1]
        target_numbers[random.choice(keys)] -= 1

    result = []
    # Buckets are visited from the smallest to the largest.
    for cat in sorted(cats, key=lambda x: cats[x]):
        pool = pools[cat]
        random.shuffle(pool)
        chosen = pool[: target_numbers[cat]]
        result.extend(chosen)
        set_.update(x[query_field] for x in chosen)
    return result


def _load_json(path):
    """Read and parse the UTF-8 JSON file at ``path``, closing it afterwards."""
    with codecs.open(path, "r", "utf8") as src:
        return json.load(src)


def _dump_json(obj, path):
    """Write ``obj`` as indented, key-sorted JSON to ``path`` and close the file."""
    with codecs.open(path, "w", "utf8") as dst:
        json.dump(obj, dst, indent=2, sort_keys=True)


def _select_records(recs, predicate):
    """Return the records ``x`` of ``recs`` for which ``predicate`` is truthy."""
    # SECURITY: ``predicate`` is an arbitrary Python expression run through
    # eval(). Only use configuration files from a trusted source.
    code = compile(predicate, "<predicate>", "eval")
    return [x for x in recs if eval(code)]


def _sample_by_config(recs, config, coef, seen, query_field, count_field):
    """Run ``resample_basic`` for every config entry and concatenate the results.

    ``seen`` is shared between the entries and updated in place, so a query
    picked for one entry is not picked again for a later one. A failure while
    sampling one entry is reported and the remaining entries are still
    processed.
    """
    sampled = []
    for ask in config:
        pool = _select_records(recs, ask["predicate"])
        try:
            sampled.extend(
                resample_basic(
                    pool,
                    ask["predicate"],
                    ask["number"],
                    coef,
                    set_=seen,
                    query_field=query_field,
                    count_field=count_field,
                )
            )
        except Exception as err:
            # Best effort: skip the failing entry, but say why it failed.
            print(ask["predicate"])
            sys.stderr.write(
                "Sampling failed for predicate {}: {!r}\n".format(
                    ask["predicate"], err
                )
            )
    return sampled


def main():
    """Build the KPI and validation samples from a basket of records.

    Reads the basket (and optionally a sampling config) from JSON, draws the
    KPI sample, then draws the validation sample while excluding every query
    already present in the KPI sample. Both samples are written as JSON and a
    short summary is printed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--basket", required=True, help="JSON file with the list of records"
    )
    parser.add_argument(
        "--config",
        help="JSON file with a list of {predicate, number} entries; "
        "the built-in default is used when omitted",
    )
    parser.add_argument(
        "--kpi", required=True, help="output JSON file for the KPI sample"
    )
    parser.add_argument(
        "--validate", required=True, help="output JSON file for the validation sample"
    )
    parser.add_argument("--query_field", default="text")
    parser.add_argument("--count_field", default="frequency")
    parser.add_argument("--coef", default=0.0, type=float)
    args = parser.parse_args()

    config = _load_json(args.config) if args.config else default_config
    recs = _load_json(args.basket)

    kpi = _sample_by_config(
        recs, config, args.coef, set(), args.query_field, args.count_field
    )
    _dump_json(kpi, args.kpi)

    # Use the configured query field (not a hard-coded "text") so that
    # de-duplication also works with a custom --query_field.
    kpi_queries = {x[args.query_field] for x in kpi}

    # Pass a copy: the sampler extends the set it is given, and kpi_queries
    # is still needed unchanged for the overlap report below.
    validate = _sample_by_config(
        recs,
        config,
        args.coef,
        set(kpi_queries),
        args.query_field,
        args.count_field,
    )
    _dump_json(validate, args.validate)

    validate_queries = {x[args.query_field] for x in validate}

    print("Kpi len is {} ".format(len(kpi)))
    print("Validate len is {} ".format(len(validate)))

    overlap = len(kpi_queries & validate_queries)
    # Avoid dividing by zero when nothing was sampled for the KPI set.
    percent = round(overlap * 100 / len(kpi)) if kpi else 0
    print("The intersection is {} ".format(percent))


# Run the sampling pipeline only when the file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
