#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import division
import codecs
import argparse
from nile.api.v1 import (
    clusters,
    filters as nf,
    extractors as ne,
    aggregators as na,
)
import re
import os
import json
import datetime
import math
# import concurrent.futures


def get_2n_category(n, max_category=25):
    """Return the power-of-two bucket index of a count.

    The bucket is ``round(log2(n))`` capped at ``max_category``, so the
    integer counts 1, 2, 3-5, 6-11, ... fall into buckets 0, 1, 2, 3, ...

    Args:
        n: value to bucket (a row's count column).
        max_category: upper bound on the returned bucket index
            (default 25, the previously hard-coded cap).

    Returns:
        int bucket index. Zero, negative and missing (``None``) values map
        to bucket 0 instead of making ``math.log`` raise ``ValueError``.
    """
    if n is None or n <= 0:
        return 0
    return min(int(round(math.log(n, 2))), max_category)


# Every pool table written by sample() during this run, in creation order.
all_tables = list()


def _allocate_samples(category_sizes, number):
    """Split a budget of ``number`` rows across categories.

    Categories are visited from the smallest to the largest. Each one gets
    an equal share (rounded up) of the budget that is still unallocated,
    capped at the rows it actually has. Unused budget from small categories
    therefore flows to the larger ones, and the total never exceeds
    ``number``.

    Args:
        category_sizes: dict mapping a category to its available row count.
        number: total number of rows to draw.

    Returns:
        dict mapping each category to a positive number of rows to draw.
        Categories that get nothing are omitted.
    """
    quotas = {}
    budget = number
    categories_left = len(category_sizes)
    for category in sorted(category_sizes, key=lambda c: category_sizes[c]):
        # Ceiling division. The previous ``budget // left + 1`` added an extra
        # row whenever the share divided evenly, overshooting ``number``.
        share = int(-(-budget // categories_left))
        ask = max(0, min(share, category_sizes[category]))
        print('ask {} from category {}'.format(ask, category))
        if ask:
            quotas[category] = ask
        budget -= ask
        categories_left -= 1
    return quotas


def sample(target_table, source_table, hahn, count_column_name, number,
           memory_limit=2000):
    """Write a bucket-stratified random sample of ``source_table``.

    Rows are bucketed by ``cat2n = get_2n_category(row[count_column_name])``.
    A first job counts rows per bucket into a temporary stats table. The
    ``number``-row budget is then split across buckets, and a second job
    draws that many random rows from each bucket. It writes them, sorted by
    ``cat2n``, to ``target_table``. The target path is also appended to the
    module-level ``all_tables`` list.

    Args:
        target_table: destination table path.
        source_table: table to sample from.
        hahn: nile cluster object used to create jobs and read tables.
        count_column_name: column whose value is bucketed by power of two.
        number: total number of rows to sample.
        memory_limit: passed to each bucket's ``random()`` call
            (default 2000, the previously hard-coded value).

    Raises:
        ValueError: if no bucket receives any rows (empty source or
            ``number <= 0``). This replaces running a job with nothing to
            concatenate.
    """
    stats_table = '//home/videolog/tmp/{}/stats'.format(
        target_table.split('/')[-1]
    )
    pool_table = target_table

    def with_category(stream):
        # Both jobs must derive ``cat2n`` the same way. The sampling job used
        # to read the raw source table and filter/sort on ``cat2n`` without
        # computing it. Since the stats job computes it, the source
        # presumably lacks that column.
        return stream.project(
            ne.all(), cat2n=ne.custom(get_2n_category, count_column_name)
        )

    job = hahn.job()
    with_category(
        job.table(source_table)
    ).groupby(
        'cat2n'
    ).aggregate(
        count=na.count()
    ).put(
        stats_table
    )
    job.run()

    category_sizes = {
        rec.cat2n: rec.count for rec in hahn.read(stats_table)
    }
    quotas = _allocate_samples(category_sizes, number)
    if not quotas:
        raise ValueError(
            'nothing to sample from {} (number={})'.format(
                source_table, number
            )
        )

    print('creating {}...'.format(pool_table))
    job = hahn.job().env(
        parallel_operations_limit=10,
    )
    chosen = with_category(job.table(source_table))
    to_concat = [
        chosen.filter(
            nf.equals('cat2n', cat)
        ).random(quota, memory_limit=memory_limit)
        for cat, quota in quotas.items()
    ]
    job.concat(*to_concat).sort(
        'cat2n'
    ).put(
        pool_table
    )
    job.run()

    all_tables.append(pool_table)


def to_pairs(lst):
    """Pair up consecutive elements of a sliceable sequence.

    Elements at even positions are zipped with the element that follows
    them, e.g. ``['a', 1, 'b', 2]`` -> ``('a', 1), ('b', 2)``. A trailing
    unpaired element is dropped.
    """
    firsts = lst[0::2]
    seconds = lst[1::2]
    return zip(firsts, seconds)


# Per-service multipliers applied to a table's ``count`` name field to get
# the sample size; used when --coefficients is not given.
default_config = dict(
    web=10,
    vid=100,
)


def main():
    """Command-line entry point: build sampled pools for a list of tables.

    Reads a JSON list of source table paths (``--tables_list``). For each
    table it derives a target path and a sample size from the table's
    underscore-separated ``key_value`` name. It then samples the table with
    ``sample``, unless ``--reuse`` is set and a non-empty target already
    exists. Finally it writes the list of target tables to ``--outfile``
    as JSON.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--pool', '-p', required=True)
    parser.add_argument('--cluster', default='hahn')
    parser.add_argument('--tables_list', required=True)
    parser.add_argument('--coefficients')
    # NOTE(review): --regexp_filter and --query_column_name are parsed but
    # never used below; kept so existing command lines keep working.
    parser.add_argument('--regexp_filter', '-r', default=None)
    parser.add_argument('--outfile', default='out.json')
    parser.add_argument(
        '--root',
        default='//home/videolog/2018Q1_baskets/candidates/desktop_validate'
    )
    parser.add_argument('--query_column_name', '-q', default='query')
    parser.add_argument('--count_column_name', '-c', default='count')
    parser.add_argument('--min1k', action='store_true')
    parser.add_argument('--reuse', action='store_true')
    args = parser.parse_args()

    token = os.environ.get('YT_TOKEN')
    if not token:
        parser.error('YT_TOKEN environment variable is not set')
    hahn = getattr(clusters, args.cluster.title())(
        pool=args.pool, token=token
    )

    if args.coefficients:
        with open(args.coefficients) as coefficients_file:
            coefficients = json.load(coefficients_file)
    else:
        coefficients = default_config

    with open(args.tables_list) as tables_file:
        config = json.load(tables_file)

    output_tables = []
    for table in config:
        table_name = table.split('/')[-1]
        # The name is read as key_value_key_value...; it must at least
        # carry 'country', 'service' and 'count' (KeyError otherwise).
        parsed_table_name = dict(to_pairs(table_name.split('_')))
        system = 'yandex' if 'yandex' in table else 'google'
        target_table = '{}/{}/{}_{}_{}'.format(
            args.root, parsed_table_name['country'],
            parsed_table_name['country'], system, table_name
        )
        # Reuse only a target that exists AND has rows; a missing
        # row_count attribute counts as empty.
        if (
            args.reuse and
            hahn.driver.exists(target_table) and
            (hahn.driver.get_attribute(target_table, 'row_count') or 0)
        ):
            print('skipping {}'.format(target_table))
            output_tables.append(target_table)
            continue
        number = int(
            int(parsed_table_name['count']) *
            coefficients[parsed_table_name['service']]
        )
        if args.min1k and number < 1000:
            number = 1000
        sample(
            target_table,
            table,
            hahn,
            args.count_column_name,
            number
        )
        output_tables.append(target_table)

    # Use a context manager so the output is flushed and closed.
    with open(args.outfile, 'w') as outfile:
        json.dump(output_tables, outfile, indent=2)


if __name__ == '__main__':
    # Run the command-line entry point only when executed as a script.
    main()
