#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import division
import codecs
import argparse
from nile.api.v1 import clusters, filters as nf, extractors as ne, aggregators as na
from qb2.api.v1 import typing as qt
import re
import os
import json
import datetime
import math
import random
from videolog_common import get_driver

# Default porto layer paths for YT operations. ``main`` passes this list as
# ``layer_paths`` for both the mapper and the reducer whenever the config does
# not provide its own "layer_paths".
# NOTE(review): "lastest" in the first path looks like a typo, but it is part
# of a runtime path -- confirm the real layer name before changing it.
PYTHON3_LAYERS = [
    "//porto_layers/base/bionic/porto_layer_search_ubuntu_bionic_app_lastest.tar.gz",
    "//home/portolayer/bionic/python3/latest",
]


def get_2n_category(n):
    """Return the power-of-two bucket of ``n``, capped at 25.

    The bucket is the base-2 logarithm of ``n`` rounded to the nearest
    integer; anything that would land above 25 is reported as 25.

    Args:
        n: positive number (``math.log`` raises ``ValueError`` for n <= 0).

    Returns:
        int bucket index, at most 25.
    """
    exponent = int(round(math.log(n, 2)))
    # Everything from 2**25 upwards shares the top bucket.
    return exponent if exponent < 25 else 25


def get_tmp_table_name(prefix, entry):
    """Build the ``$job_root``-relative temporary table path for an entry.

    The path is ``$job_root/<prefix>_<source table with "/" turned into "_">``
    followed by ``_<name>`` when the entry carries a "name" key.

    Args:
        prefix: leading tag of the table name (callers use "tmp" and "stats").
        entry: config entry; must have "source_table", may have "name".

    Returns:
        The table path as a string.
    """
    flattened = "_".join(entry["source_table"].split("/"))
    # The optional name keeps apart entries that share one source table.
    suffix = "_{}".format(entry["name"]) if "name" in entry else ""
    return "$job_root/{}_{}{}".format(prefix, flattened, suffix)


def wrap_weak_schema(dict_):
    """Evaluate every value of a schema mapping as a Python expression.

    Args:
        dict_: mapping of column name to a string holding a Python expression
            (presumably a type such as ``int`` or a ``qt.*`` type -- confirm
            against the configs in use).

    Returns:
        A new dict with the same keys and the evaluated objects as values.
    """
    # NOTE(review): eval() runs arbitrary code taken from the config file.
    # Only use configs from trusted sources; names in the expressions are
    # resolved against this module's globals.
    return {k: eval(v) for k, v in dict_.items()}


def prefilter_tables(config, cluster, env_kwargs):
    """Write a filtered copy and per-category statistics for every entry.

    For each entry of ``config["entries"]`` the source table is read
    (optionally with a weak schema), optionally filtered, extended with a
    ``cat2n`` column when the table does not already have one, and stored in
    the entry's "tmp" table.  The same stream is then grouped by ``cat2n``;
    the number of rows (``count``) and the sum of the count column (``freqs``)
    per group go to the entry's "stats" table.  All entries are scheduled in a
    single job, which is run at the end.

    Args:
        config: parsed config dict.  Each entry must have "source_table" and
            "count_column_name" and may have "weak_schema", "filter" and
            "name".
        cluster: cluster object used to create the job and the YT client.
        env_kwargs: keyword arguments forwarded to ``job.env``.
    """
    job = cluster.job().env(**env_kwargs)
    yt = get_driver(cluster).client

    for entry in config["entries"]:
        source_table = entry["source_table"]
        count_column = entry["count_column_name"]

        table_kwargs = {}
        if "weak_schema" in entry:
            table_kwargs["weak_schema"] = wrap_weak_schema(entry["weak_schema"])
        stream = job.table(source_table, **table_kwargs)

        if "filter" in entry:
            # The first element is the source of a predicate, the rest are the
            # column names passed to it.
            # NOTE(review): eval() runs arbitrary code taken from the config
            # file; only use configs from trusted sources.
            args = [eval(entry["filter"][0])] + entry["filter"][1:]
            stream = stream.filter(nf.custom(*args))

        try:
            columns = {
                col["name"] for col in yt.get_attribute(source_table, "schema")
            }
        except Exception:
            # Best effort: when the schema attribute cannot be read, fall back
            # to the columns declared in the weak schema.  ``Exception`` rather
            # than a bare ``except`` so Ctrl-C / SystemExit still propagate.
            columns = set((entry.get("weak_schema") or {}).keys())

        if "cat2n" not in columns:
            stream = stream.project(
                ne.all(),
                cat2n=ne.custom(get_2n_category, count_column).with_type(int),
            )

        stream.put(get_tmp_table_name("tmp", entry)).groupby("cat2n").aggregate(
            count=na.count(), freqs=na.sum(count_column)
        ).put(get_tmp_table_name("stats", entry))

    job.run()


def _allocate_quotas(stats, target_number):
    """Split ``target_number`` between categories proportionally to ``freqs``.

    Categories are visited from the highest ``cat2n`` to the lowest.  Each one
    is asked for its share of what is still left to collect (rounded up),
    capped by the number of rows it actually has.  Because the remaining
    target and the remaining frequency mass are updated after every step, the
    shortfall of a capped category is passed on to the categories after it.

    Args:
        stats: list of records with ``cat2n``, ``count`` and ``freqs``.
        target_number: total number of rows wanted.

    Returns:
        dict mapping ``int(cat2n)`` to the integer number of rows to take.
    """
    remaining_target = target_number
    remaining_freqs = sum(rec.freqs for rec in stats)
    quotas = {}
    for rec in sorted(stats, key=lambda r: r.cat2n, reverse=True):
        ask = math.ceil(remaining_target * rec.freqs / float(remaining_freqs))
        if rec.count < ask:
            ask = rec.count
        quotas[int(rec.cat2n)] = int(ask)
        print("ask {} from category {}".format(ask, rec))
        remaining_target -= ask
        remaining_freqs -= rec.freqs
    return quotas


def filter_into_pool(entry, cluster, job, to_concat):
    """Schedule the per-category sampling of one entry's "tmp" table.

    Reads the entry's "stats" table, decides how many rows to take from each
    ``cat2n`` category so that about ``entry["target_number"]`` rows are taken
    in total, and appends one sampled stream per category to ``to_concat``.
    Nothing is executed here; the streams belong to ``job``, which the caller
    runs.

    Args:
        entry: config entry; reads "source_table", "target_number",
            "count_column_name" and the optional "name".
        cluster: cluster object used to read the stats table.
        job: job the sampled streams are attached to.
        to_concat: list that receives the resulting streams (mutated in place).
    """
    tmp_table = get_tmp_table_name("tmp", entry)
    stats_table = get_tmp_table_name("stats", entry)
    print(
        "filtering {} into pool...".format(
            entry["name"] if "name" in entry else entry["source_table"]
        )
    )

    recs = list(cluster.read(stats_table))
    target_number = entry["target_number"]
    print("target number is {}".format(target_number))
    quotas = _allocate_quotas(recs, target_number)

    chosen = job.table(tmp_table)

    for cat, target in quotas.items():
        filtered = chosen.filter(nf.equals("cat2n", cat))
        if cat in (0, 1):
            # Categories 0 and 1 are sampled by attaching a random key,
            # sorting by it and keeping the first ``target`` rows.
            # NOTE(review): the reason these two avoid ``.random()`` is not
            # visible here, and their rows keep the extra ``rnd`` column --
            # confirm both are intended.
            to_concat.append(
                filtered.project(
                    ne.all(),
                    rnd=ne.custom(
                        lambda x: random.randint(1, 100000), entry["count_column_name"]
                    ).with_type(int),
                )
                .sort("rnd")
                .take(target)
            )
        else:
            to_concat.append(filtered.random(target))


def main():
    """Command-line entry point: build a sampled pool table from a config.

    Steps:
      1. Parse the arguments and create the cluster object (the token is read
         from the ``YT_TOKEN`` environment variable).
      2. Load the JSON config.
      3. Unless ``--override_job_root`` is given, run ``prefilter_tables`` to
         produce the "tmp" and "stats" tables under a fresh job root.  With
         the override the prefilter step is skipped and the tables are read
         from the given root.
      4. Sample every entry, concatenate the samples, sort by ``cat2n`` and
         write ``config["pool_table"]``.
      5. Write ``{"cluster": ..., "table": ...}`` as JSON to ``--outfile``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--pool", "-p", required=True)
    parser.add_argument("--cluster", default="hahn")
    # The config is opened unconditionally below, so make argparse reject a
    # missing value instead of failing later inside ``open(None)``.
    parser.add_argument("--config", "-c", required=True)
    parser.add_argument("--outfile", default="out.json")
    parser.add_argument("--override_job_root")
    args = parser.parse_args()

    # NOTE(review): the "%s" (epoch seconds) strftime directive is a platform
    # extension rather than a documented Python format code.
    job_root = args.override_job_root or "tmp/sample{}".format(
        datetime.datetime.now().strftime("%s")
    )
    cluster = getattr(clusters, args.cluster.title())(
        pool=args.pool, token=os.environ["YT_TOKEN"]
    ).env(templates=dict(job_root=job_root))

    with open(args.config) as f:
        config = json.load(f)

    layer_paths = config.get("layer_paths") or PYTHON3_LAYERS
    yt_spec_defaults = {
        "mapper": {"layer_paths": layer_paths},
        "reducer": {"layer_paths": layer_paths},
    }
    env_kwargs = {
        "parallel_operations_limit": 10,
        "yt_spec_defaults": yt_spec_defaults,
    }

    if not args.override_job_root:
        prefilter_tables(config, cluster, env_kwargs)

    print("creating {}...".format(config["pool_table"]))
    job = cluster.job().env(**env_kwargs)
    to_concat = []
    for entry in config["entries"]:
        filter_into_pool(entry, cluster, job, to_concat)
    job.concat(*to_concat).sort("cat2n").put(config["pool_table"])
    job.run()

    with codecs.open(args.outfile, "w", "utf8") as f:
        f.write(
            json.dumps({"cluster": args.cluster.lower(), "table": config["pool_table"]})
        )

# Run the sampling pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
