#-*- coding: UTF-8 -*-
import argparse
from nile.api.v1 import (
    filters as nf,
    aggregators as na,
    extractors as ne,
    clusters,
    Record
)
from qb2.api.v1 import (
    extractors as se,
    filters as sf
)
from time import time
from random import random
from math import log

def get_strat_factor(x, strat_factor):
    """Extract the stratification value from a tab-separated feature string.

    Args:
        x: tab-separated string of values.
        strat_factor: either the literal ``'target'`` or a string holding the
            integer index of the factor to stratify by.

    Returns:
        For ``'target'``: the first column converted to ``float``.
        Otherwise: the selected factor column (located 3 columns after the
        factor index) mapped through a piecewise-linear table of control
        points and truncated to ``int``.
    """
    if strat_factor == 'target':
        target_column = x.split('\t')[0]
        return float(target_column)

    # Control points of the piecewise-linear remap; consecutive points bound
    # equally wide segments of the input.
    control_points = (
        0, 688, 1366, 2045, 2723, 3401, 4084, 4781, 5524, 6373, 7311, 8587,
        10099, 12308, 15876, 23983, 7.51829e+06,
    )
    # Factor columns start this many columns after the beginning of the row.
    factors_shift = 3

    column_index = int(strat_factor) + factors_shift
    position = float(x.split('\t')[column_index])
    # Negative inputs are clamped to the first control point.
    position = 0 if position < 0 else position

    segments = len(control_points) - 1
    position *= segments
    lower = int(position)
    weight = position - lower
    if lower >= segments:
        # At or beyond the end of the table: pin to the last control point.
        lower, weight = segments - 1, 1
    upper = lower + 1
    return int(control_points[lower] * (1 - weight) + control_points[upper] * weight)

def get_bucket(x, min_value, max_value, buckets):
    """Return the 1-based index of the equal-width bucket that contains *x*.

    The range ``[min_value, max_value)`` is divided into ``buckets`` intervals
    of equal width; bucket ``i`` (1-based) covers
    ``[min_value + (i - 1) * step, min_value + i * step)``.

    Args:
        x: value to classify.
        min_value: lower bound of the first bucket (inclusive).
        max_value: upper bound of the last bucket.
        buckets: number of buckets.

    Returns:
        The bucket index in ``1..buckets``. Values that fall into no interval
        (below ``min_value`` or at/above ``max_value``) yield ``buckets``.

    Raises:
        ZeroDivisionError: if ``buckets`` is 0.
    """
    # Force float division: with integer arguments Python 2's ``/`` floors the
    # step, which shrinks (or zeroes) the bucket width and misplaces values.
    step = float(max_value - min_value) / buckets
    for i in range(buckets):
        if min_value + i * step <= x < min_value + (i + 1) * step:
            return i + 1
    return buckets

class strat_sample(object):
    """Mapper that keeps each record with a probability chosen by its bucket."""

    def __init__(self, prob_by_bucket):
        # Mapping: bucket id -> probability of keeping a record of that bucket.
        self.prob_by_bucket = prob_by_bucket

    def __call__(self, recs):
        """Yield the records of *recs* that survive a random draw.

        Each record's ``"bucket"`` field selects the keep probability; a
        bucket missing from the mapping raises ``KeyError``.
        """
        for record in recs:
            draw = random()
            keep_probability = self.prob_by_bucket[record["bucket"]]
            if draw < keep_probability:
                yield record

def _compute_prob_by_bucket(bucket_counts, sample_count):
    """Return ``{bucket: keep probability}`` for a stratified sample.

    The sample quota is shared equally among the buckets that are still
    unprocessed. A bucket smaller than its share is kept entirely
    (probability 1) and only its actual size is charged against the quota,
    so the remainder is redistributed among the following buckets.

    Args:
        bucket_counts: iterable of ``[bucket, count]`` pairs.
        sample_count: total number of records to aim for.
    """
    # The redistribution only works when small buckets are handled first.
    # The stats table is already sorted by count; the stable sort here just
    # makes that requirement explicit and is a no-op on sorted input.
    ordered = sorted(bucket_counts, key=lambda pair: pair[1])

    prob_by_bucket = {}
    count_to_sample = sample_count
    buckets_left = len(ordered)
    for bucket, count in ordered:
        bucket_size = count_to_sample // buckets_left + 1
        if bucket_size > count:
            prob_by_bucket[bucket] = 1
            count_to_sample -= count
        else:
            prob_by_bucket[bucket] = bucket_size / float(count)
            count_to_sample -= bucket_size
        buckets_left -= 1
    return prob_by_bucket


def main():
    """Build a stratified sample of the ``features`` table and split it.

    Steps:
      1. Parse command-line arguments and pick the YT cluster.
      2. Copy the tables listed in ``--tables_to_copy`` to the output
         directory, replacing existing copies.
      3. Assign every record of ``<input_directory>/features`` a log2-based
         bucket and count the records per bucket.
      4. Derive a keep probability per bucket and write the sampled records
         to ``<output_directory>/features``.
      5. Randomly split the sample into ``learn`` and ``test`` tables.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_directory', type=str, required=True)
    parser.add_argument('--tables_to_copy', nargs='+', default=[])
    parser.add_argument('--test_ratio', type=float, required=True)
    parser.add_argument('--buckets', type=int, required=True)
    parser.add_argument('--sample_count', type=int, required=True)
    parser.add_argument('--output_directory', type=str, required=True)
    parser.add_argument('--pool_type', type=str, required=True)
    parser.add_argument('--cluster', type=str, required=True)
    parser.add_argument('--strat_factor', type=str, required=True)
    args = parser.parse_args()

    # Any value other than 'hahn' selects Arnold.
    if args.cluster == 'hahn':
        cluster = clusters.yt.Hahn().env(parallel_operations_limit=10)
    else:
        cluster = clusters.yt.Arnold().env(parallel_operations_limit=10)

    if not cluster.driver.exists(args.output_directory):
        cluster.driver.mkdir(args.output_directory)

    input_directory = args.input_directory + "/"
    output_directory = args.output_directory + "/"

    for table in args.tables_to_copy:
        destination = output_directory + table
        # Remove a stale copy first so the copy does not hit an existing node.
        if cluster.driver.exists(destination):
            cluster.driver.remove(destination)
        cluster.driver.copy(input_directory + table, destination)

    # The timestamp suffix keeps temporary tables of separate runs apart.
    current_ts = str(time())
    features_with_bucket = "//tmp/msvvitaly/features_with_bucket_" + current_ts
    bucket_stats = "//tmp/msvvitaly/bucket_stats" + current_ts

    # The two pool types differ only in how the bucket is derived.
    if args.pool_type == 'deep_click':
        bucket_extractor = ne.custom(
            lambda x: int(log(1 + get_strat_factor(x, args.strat_factor), 2)),
            'value',
        )
    else:
        bucket_extractor = ne.custom(
            lambda x: int(log(x["users"], 2)),
            'logs_stats',
        )

    job = cluster.job()
    with_bucket = job.table(input_directory + "features") \
                     .project(ne.all(), bucket=bucket_extractor) \
                     .put(features_with_bucket)
    # Per-bucket record counts, smallest bucket first.
    with_bucket.groupby('bucket') \
               .aggregate(count=na.count()) \
               .sort('count') \
               .put(bucket_stats)
    job.run()

    data = [[rec["bucket"], rec["count"]] for rec in cluster.driver.read(bucket_stats)]
    print(data)

    prob_by_bucket = _compute_prob_by_bucket(data, args.sample_count)
    print(prob_by_bucket)

    job = cluster.job()
    features = job.table(features_with_bucket) \
                  .map(strat_sample(prob_by_bucket)) \
                  .sort('key', 'subkey') \
                  .put(output_directory + "features")
    # NOTE(review): presumably split() returns (rejected, accepted), so records
    # passing the predicate (probability test_ratio) land in `test` - confirm.
    learn, test = features.split(sf.custom(lambda x: random() < args.test_ratio))
    learn.sort('key', 'subkey').put(output_directory + "learn")
    test.sort('key', 'subkey').put(output_directory + "test")
    job.run()

# Run the sampling pipeline only when this file is executed as a script.
if __name__ == '__main__':
    main()
