#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import division
import codecs
import argparse
from nile.api.v1 import (
    clusters,
    filters as nf,
    extractors as ne,
    aggregators as na,
)
import re
import os
import json
import datetime
import math
import random


def get_2n_category(n):
    """Return the power-of-two frequency bucket for *n*.

    The bucket is the base-2 logarithm of *n* rounded to the nearest
    integer; buckets above 25 are collapsed into bucket 25.
    """
    bucket = int(round(math.log(n, 2)))
    if bucket > 25:
        return 25
    return bucket


def main():
    """Build a frequency-stratified sample of queries and store it in a table.

    The script works in three stages:

    1. Unless ``--override_job_root`` is given, read the source table,
       optionally drop rows whose query matches ``--regexp_filter``, tag
       every row with its ``cat2n`` bucket (see ``get_2n_category``), store
       the tagged rows in ``$job_root/tmp`` and per-bucket statistics
       (row count and summed frequency) in ``$job_root/stats``.  When
       ``--override_job_root`` is given this stage is skipped and both
       tables are expected to exist already under that root.
    2. Read the statistics locally and decide how many rows to take from
       each bucket: proportionally to the bucket's share of the summed
       frequency that is still unallocated, but never more rows than the
       bucket contains.  Buckets are visited from the highest ``cat2n``
       down.
    3. Sample the decided number of rows from every bucket, concatenate
       them, sort by ``cat2n`` and write the result to ``--target_table``.

    Finally a JSON object with the cluster name and the target table path
    is written to ``--outfile``.

    Raises:
        KeyError: if the ``YT_TOKEN`` environment variable is not set.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--pool", "-p", required=True)
    parser.add_argument("--cluster", default="hahn")
    parser.add_argument("--source_table", "-s", required=True)
    parser.add_argument("--regexp_filter", "-r", default=None)
    parser.add_argument("--target_table", "-t", required=True)
    parser.add_argument("--number", "-n", default=1000, type=int)
    parser.add_argument("--outfile", default="out.json")
    parser.add_argument("--override_job_root")
    parser.add_argument("--query_column_name", "-q", default="query")
    parser.add_argument("--count_column_name", "-c", default="count")
    args = parser.parse_args()

    # NOTE(review): "%s" is not among Python's documented strftime
    # directives; it relies on the platform C library -- confirm this only
    # runs where it is supported.
    job_root = args.override_job_root or "tmp/sample{}".format(
        datetime.datetime.now().strftime("%s")
    )
    # The cluster class is looked up by its capitalised name,
    # e.g. "hahn" -> clusters.Hahn.
    cluster = getattr(clusters, args.cluster.title())(
        pool=args.pool, token=os.environ["YT_TOKEN"]
    ).env(templates=dict(job_root=job_root))

    intermediate_table = "$job_root/tmp"
    stats_table = "$job_root/stats"
    pool_table = args.target_table

    # Stage 1: tag rows with their bucket and collect per-bucket statistics.
    if not args.override_job_root:
        job = cluster.job()

        # --source_table may instead be a path to a local JSON file holding
        # the table path under the "table" key (the same layout this script
        # writes to --outfile).
        if os.path.isfile(args.source_table):
            with open(args.source_table) as source_file:
                args.source_table = json.load(source_file)["table"]

        stream = job.table(
            args.source_table,
            weak_schema={
                args.query_column_name: str,
                args.count_column_name: int,
            },
        )

        if args.regexp_filter:
            # Keep only the rows whose query does NOT match the regexp.
            stream = stream.filter(
                nf.custom(
                    lambda x: not re.search(args.regexp_filter, x),
                    args.query_column_name,
                )
            )

        stream.project(
            ne.all(),
            cat2n=ne.custom(get_2n_category, args.count_column_name).with_type(
                int
            ),
        ).put(intermediate_table).groupby("cat2n").aggregate(
            count=na.count(), freqs=na.sum(args.count_column_name)
        ).put(
            stats_table
        )

        job.run()

    # Stage 2: split the requested sample size between the buckets.
    recs = list(cluster.read(stats_table))

    records_by_cat = {}
    target_number_basket = args.number  # rows still to be allocated
    total_ = sum(x.freqs for x in recs)  # frequency mass not yet visited

    for cat in sorted(recs, key=lambda x: x.cat2n, reverse=True):
        # Share of the remaining quota proportional to this bucket's share
        # of the remaining frequency mass, rounded up ...
        ask = math.ceil(target_number_basket * cat.freqs / float(total_))
        # ... but a bucket cannot give more rows than it holds.
        if cat.count < ask:
            ask = cat.count
        records_by_cat[int(cat.cat2n)] = int(ask)
        print("ask {} from category {}".format(ask, cat))
        target_number_basket -= ask
        total_ -= cat.freqs

    # Stage 3: draw the rows and assemble the target table.
    print("creating {}...".format(pool_table))
    job = cluster.job().env(parallel_operations_limit=10)

    chosen = job.table(intermediate_table)
    to_concat = []

    for cat, target in records_by_cat.items():
        filtered = chosen.filter(nf.equals("cat2n", cat))
        if cat in (0, 1):
            # Buckets 0 and 1 are sampled by sorting on a random key and
            # taking the first `target` rows instead of using .random().
            # NOTE(review): the reason for this special case is not visible
            # in this file -- presumably these buckets are too large for
            # .random(); confirm.
            to_concat.append(
                filtered.project(
                    ne.all(),
                    rnd=ne.custom(
                        lambda x: random.randint(1, 100000),
                        args.query_column_name,
                    ).with_type(int),
                )
                .sort("rnd")
                .take(target)
            )
        else:
            to_concat.append(filtered.random(target))

    job.concat(*to_concat).sort("cat2n").put(pool_table)

    job.run()

    # Record where the sample was written so later steps can locate it.
    with codecs.open(args.outfile, "w", "utf8") as f:
        f.write(
            json.dumps({"cluster": args.cluster.lower(), "table": pool_table})
        )

# Run the sampling pipeline only when the file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
