#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import division
import codecs
import argparse
from nile.api.v1 import (
    clusters,
    filters as nf,
    extractors as ne,
    aggregators as na,
)
import re
import os
import json
import datetime
import math


def get_2n_category(n, max_category=25):
    """Return the log2 bucket of a count ``n``.

    The bucket is ``round(log2(n))`` capped at ``max_category`` (25 by
    default), e.g. 1 -> 0, 2 -> 1, 3 -> 2, 1000 -> 10.

    The logarithm is undefined for non-positive ``n``. Such values are
    placed into bucket 0 instead of raising ``ValueError``, so a single
    zero count does not abort the whole map job.

    :param n: the count to bucket.
    :param max_category: the largest bucket index returned.
    :returns: an ``int`` bucket index in ``[.., max_category]``.
    """
    if n <= 0:
        return 0
    return min(int(round(math.log(n, 2))), max_category)


def _parse_args():
    """Parse the command-line options of the sampling script."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--pool', '-p', required=True)
    parser.add_argument('--cluster', default='hahn')
    parser.add_argument('--source_table', '-s', required=True)
    parser.add_argument('--regexp_filter', '-r', default=None)
    parser.add_argument('--target_table', '-t', required=True)
    parser.add_argument('--token', '-y', required=True)
    parser.add_argument('--number', '-n', default=1000, type=int)
    parser.add_argument('--outfile', default='out.json')
    parser.add_argument('--query_column_name', '-q', default='query')
    parser.add_argument('--count_column_name', '-c', default='count')
    return parser.parse_args()


def _allocate_sample_sizes(category_sizes, total):
    """Split ``total`` sample slots across categories as evenly as possible.

    Categories are visited from smallest to largest. Each one is asked for
    the ceiling of its even share of the slots still unassigned, capped by
    the number of rows it actually has. Slots a small category cannot fill
    are therefore redistributed to the larger categories visited later.

    :param category_sizes: mapping of category -> number of available rows.
    :param total: number of rows to sample overall.
    :returns: list of ``(category, ask)`` pairs in visiting order. For
        ``total >= 0`` the asks sum to
        ``min(total, sum(category_sizes.values()))``.
    """
    allocation = []
    remaining = total
    categories_left = len(category_sizes)
    for cat in sorted(category_sizes, key=lambda c: category_sizes[c]):
        # Ceiling division, so the asks never add up to more than ``total``.
        fair_share = -(-remaining // categories_left)
        ask = max(0, min(fair_share, category_sizes[cat]))
        allocation.append((cat, ask))
        remaining -= ask
        categories_left -= 1
    return allocation


def main():
    """Draw a sample of queries stratified by log2 of their count.

    1. Read ``--source_table``. If that argument names a local file, read the
       ``table`` key of the JSON stored in it instead. Optionally drop rows
       whose query column matches ``--regexp_filter``.
    2. Tag each row with ``cat2n = get_2n_category(count)``, store the rows
       in a temporary table and count the rows per ``cat2n``.
    3. Split ``--number`` across the categories, take ``.random(ask)`` rows
       from each, and write them sorted by ``cat2n`` to ``--target_table``.
    4. Write ``{"cluster": ..., "table": ...}`` as JSON to ``--outfile``.
    """
    import time  # local import: only needed for the job_root timestamp

    args = _parse_args()
    cluster = getattr(clusters, args.cluster.title())(
        pool=args.pool, token=args.token
    ).env(
        templates=dict(
            # strftime('%s') is a platform-specific extension (unsupported on
            # Windows); int(time.time()) gives the same epoch seconds portably.
            job_root='tmp/sample{}'.format(int(time.time())),
        )
    )

    intermediate_table = '$job_root/tmp'
    stats_table = '$job_root/stats'
    pool_table = args.target_table

    # Compile up front so an invalid pattern fails before any job is launched.
    exclude_re = re.compile(args.regexp_filter) if args.regexp_filter else None

    if os.path.isfile(args.source_table):
        with open(args.source_table) as f:
            args.source_table = json.load(f)['table']

    # First job: tag rows with their bucket and collect per-bucket row counts.
    job = cluster.job()
    stream = job.table(args.source_table)
    if exclude_re is not None:
        stream = stream.filter(
            nf.custom(
                lambda x: not exclude_re.search(x),
                args.query_column_name
            )
        )

    stream.project(
        ne.all(), cat2n=ne.custom(get_2n_category, args.count_column_name)
    ).put(
        intermediate_table
    ).groupby(
        'cat2n'
    ).aggregate(
        count=na.count()
    ).put(
        stats_table
    )
    job.run()

    cats = {rec.cat2n: rec.count for rec in cluster.read(stats_table)}

    records_by_cat = {}
    for cat, ask in _allocate_sample_sizes(cats, args.number):
        print('ask {} from category {}'.format(ask, cat))
        if ask > 0:
            records_by_cat[cat] = ask

    if not records_by_cat:
        # Concatenating zero streams would fail in a less readable way.
        raise SystemExit('nothing to sample into {}'.format(pool_table))

    # Second job: sample each bucket independently and merge the results.
    print('creating {}...'.format(pool_table))
    job = cluster.job().env(
        parallel_operations_limit=10
    )
    chosen = job.table(intermediate_table)
    to_concat = [
        chosen.filter(nf.equals('cat2n', cat)).random(ask)
        for cat, ask in records_by_cat.items()
    ]
    job.concat(*to_concat).sort(
        'cat2n'
    ).put(
        pool_table
    )
    job.run()

    with codecs.open(args.outfile, 'w', 'utf8') as f:
        f.write(
            json.dumps({
                "cluster": args.cluster.lower(),
                "table": pool_table
            })
        )


if __name__ == "__main__":
    # Run the sampling pipeline only when executed as a script, not on import.
    main()
