#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import division
import argparse
import bisect
import codecs
import datetime
import json
import math
import os
import random
import re
import time

import nile
from nile.api.v1 import (
    clusters,
    filters as nf,
    extractors as ne,
    aggregators as na,
)

from pytils import get_host


def get_2n_category(n):
    """Return the power-of-two bucket of the count *n*.

    The bucket is ``log2(n)`` rounded to the nearest integer and capped at
    25, so every count from about ``2 ** 24.5`` upward lands in bucket 25.
    """
    bucket = int(round(math.log(n, 2)))
    return bucket if bucket < 25 else 25


def set_weight(x):
    """Return the sampling weight of a host that has *x* rows.

    The weight is ``1 / log2(x + 1)``. The denominator is floored at 1, so a
    host with a single row (or none) gets the maximum weight of 1 and larger
    hosts get progressively smaller weights.
    """
    denominator = math.log(x + 1, 2)
    if denominator < 1:
        denominator = 1
    return 1 / denominator


class WeightedChooser(object):
    """Nile reducer that picks records with probability tied to their weight.

    Random "darts" are thrown uniformly into ``[0, upper_bound]``. The
    records of a group are laid end to end on that interval, each occupying
    a segment of length ``rec.host_weight``. A record is emitted (once) when
    at least one dart lands in its segment ``(start, end]``. Surplus darts
    that hit an already-chosen segment are re-thrown into the part of the
    interval that has not been visited yet.
    """

    def __init__(self, number, upper_bound):
        """
        :param number: requested sample size; ``round(number * 1.2)`` darts
            are thrown up front.
        :param upper_bound: length of the interval darts are thrown into;
            the caller passes the sum of ``host_weight`` over the records
            that will be reduced.
        """
        sr = random.SystemRandom()
        self.upper_bound = upper_bound
        self.nums = [
            sr.uniform(0, self.upper_bound)
            for _ in range(int(round(number * 1.2)))
        ]

    def __call__(self, groups):
        """Yield the chosen records of every group, in input order."""
        sr = random.SystemRandom()
        for _key, records in groups:
            # Sorted private copy: bisect can count hits in O(log n), and
            # re-thrown darts do not leak into later groups or calls.
            darts = sorted(self.nums)
            start = 0
            for rec in records:
                end = start + rec.host_weight
                # Number of darts x with start < x <= end.
                hits = (bisect.bisect_right(darts, end)
                        - bisect.bisect_right(darts, start))
                if hits > 0:
                    yield rec
                start = end
                # One dart is enough to pick this record; re-throw the rest
                # into the not-yet-visited remainder of the interval.
                for _ in range(hits - 1):
                    bisect.insort(darts, sr.uniform(start, self.upper_bound))


def _allocate_per_category(cat_counts, target_number):
    """Split a requested sample size across cat2n buckets.

    Buckets are visited from the smallest to the largest. Each one is asked
    for ``remaining // buckets_left + 1`` records. A bucket holding fewer
    rows than that is taken whole, and the shortfall is left for the larger
    buckets that follow. Because of the ``+ 1``, the total may slightly
    exceed ``target_number``.

    :param cat_counts: mapping ``cat2n -> number of rows in that bucket``.
    :param target_number: requested total sample size.
    :returns: ``(records_by_cat, take_all)``. ``records_by_cat`` maps each
        bucket to the number of records to take (never negative).
        ``take_all`` maps a bucket to ``True`` when all its rows are taken.
    """
    records_by_cat = {}
    take_all = {}
    cat_left = len(cat_counts)
    for cat in sorted(cat_counts, key=lambda x: cat_counts[x]):
        # The remaining budget can overshoot below zero (the "+ 1", or a
        # negative --number); never request a negative number of records.
        ask = max(target_number // cat_left + 1, 0)
        if cat_counts[cat] < ask:
            ask = cat_counts[cat]
            take_all[cat] = True
        records_by_cat[cat] = ask
        print('Will take {} records from cat {}'.format(ask, cat))
        target_number -= ask
        cat_left -= 1
    return records_by_cat, take_all


def main():
    """Build a host-weighted sample of queries, stratified by count bucket.

    Steps:
      1. Read the source table and drop rows with ``shows <= 4``. If
         ``--regexp_filter`` is given, also drop rows whose query matches it.
      2. Tag each row with ``cat2n`` (``get_2n_category`` of the count
         column) and its host. Weight each host by ``set_weight`` of its row
         count.
      3. Split ``--number`` across the cat2n buckets. Sample each bucket
         with ``WeightedChooser``, or copy it whole when it is smaller than
         its quota.
      4. Concatenate the per-bucket samples into ``--target_table``. Dump
         per-host row counts of the result to ``--pool_stats``, and
         ``{"cluster", "table"}`` to ``--outfile``.
      5. Remove the temporary ``job_root`` unless ``--do_not_clean_up``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--pool', '-p', required=True)
    parser.add_argument('--pool_stats', '-ps', required=True)
    parser.add_argument('--cluster', default='hahn')
    parser.add_argument('--source_table', '-s', required=True)
    parser.add_argument('--regexp_filter', '-r', default=None)
    parser.add_argument('--target_table', '-t', required=True)
    parser.add_argument('--token', '-y', required=True)
    parser.add_argument('--job_root')
    parser.add_argument('--number', '-n', default=1000, type=int)
    parser.add_argument('--outfile', default='out.json')
    parser.add_argument('--do_not_clean_up', action='store_true')
    parser.add_argument('--query_column_name', '-q', default='query')
    parser.add_argument('--count_column_name', '-c', default='count')
    args = parser.parse_args()
    if not args.job_root:
        # Timestamp-suffixed temporary directory. int(time.time()) replaces
        # strftime('%s'), a non-portable glibc extension that yields the
        # same epoch seconds on Linux.
        job_root = '//home/videolog/tmp/sample{}'.format(int(time.time()))
    else:
        job_root = args.job_root
    # Client for the cluster named by --cluster; "$job_root" in the table
    # paths below is filled from this templates dict.
    hahn = getattr(clusters, args.cluster.title())(
        pool=args.pool, token=args.token
    ).env(
        templates=dict(
            job_root=job_root,
        )
    )

    stats_table = '$job_root/stats'
    pool_table = args.target_table

    job = hahn.job().env(
        parallel_operations_limit=10
    )

    # --source_table may instead name a local JSON file holding the table
    # path under "table" (the same shape this script writes to --outfile).
    if os.path.isfile(args.source_table):
        with open(args.source_table) as source_file:
            args.source_table = json.load(source_file)['table']

    stream = job.table(
        args.source_table
    ).filter(
        nf.custom(
            lambda x: x > 4, 'shows'
        )
    )

    if args.regexp_filter:
        stream = stream.filter(
            nf.custom(
                lambda x: not re.search(args.regexp_filter, x),
                args.query_column_name
            ),
        )

    stream_enriched = stream.project(
        ne.all(), cat2n=ne.custom(get_2n_category, args.count_column_name),
        host=ne.custom(get_host, args.query_column_name),
        files=[nile.files.LocalFile('pytils.py')]
    )

    # Rows per cat2n bucket; drives the per-bucket quotas below.
    stream_enriched.groupby(
        'cat2n'
    ).aggregate(
        count=na.count()
    ).put(
        stats_table
    )

    host_weights = stream_enriched.groupby(
        'host'
    ).aggregate(
        count=na.count()
    ).project(
        'host', 'count',
        host_weight=ne.custom(
            lambda x, y: (
                set_weight(x)
            ), 'count', 'host'
        )
    ).sort(
        'host_weight'
    ).put(
        '$job_root/host_weights'
    )

    stream_enriched.join(
        host_weights, by='host', type='inner'
    ).project(
        ne.all(), url_weight=ne.custom(
            lambda x: math.log(x, 2), args.count_column_name
        )
    ).project(
        ne.all(), weight=ne.custom(
            lambda x, y: x * y, 'url_weight', 'host_weight'
        )
    ).sort(args.count_column_name).put(
        '$job_root/aggregated'
    )

    job.run()

    cat_dict = {rec.cat2n: rec.count for rec in hahn.read(stats_table)}
    records_by_cat, take_all = _allocate_per_category(cat_dict, args.number)
    print('Will take {} records total'.format(sum(records_by_cat.values())))

    print('creating {}...'.format(pool_table))
    to_concat = []
    for cat in records_by_cat:
        result = '$job_root/cat2n_pool_{}'.format(cat)
        if not take_all.get(cat):
            # Materialize the bucket and the sum of its host weights, which
            # is the interval length WeightedChooser throws darts into.
            job = hahn.job()

            job.table(
                '$job_root/aggregated'
            ).filter(
                nf.equals('cat2n', cat), memory_limit=16384
            ).put(
                '$job_root/cat2n_{}'.format(cat)
            ).groupby().aggregate(
                upper_bound=na.sum('host_weight')
            ).put(
                '$job_root/upper_bound_{}'.format(cat)
            )

            job.run()

            upper_bound = hahn.read(
                '$job_root/upper_bound_{}'.format(cat)
            )[0].upper_bound

            # groupby() with no keys: the whole bucket is one reduce group.
            job = hahn.job()

            job.table(
                '$job_root/cat2n_{}'.format(cat)
            ).groupby().reduce(
                WeightedChooser(
                    number=records_by_cat[cat], upper_bound=upper_bound
                )
            ).put(
                result
            )
            job.run()
        else:
            # Bucket smaller than its quota: copy it whole.
            job = hahn.job()
            job.table(
                '$job_root/aggregated'
            ).filter(
                nf.equals('cat2n', cat), memory_limit=16384
            ).put(
                result
            )

            job.run()
        to_concat.append(result)

    job = hahn.job()

    to_concat = [job.table(x) for x in to_concat]
    job.concat(*to_concat).put(
        args.target_table
    )

    job.run()

    job = hahn.job()

    job.table(
        args.target_table
    ).groupby('host').aggregate(
        total=na.count()
    ).sort('total').put(
        '$job_root/pool_stats'
    )

    job.run()

    stats = [x.to_dict() for x in hahn.read('$job_root/pool_stats')]

    with codecs.open(args.outfile, 'w', 'utf8') as f:
        f.write(
            json.dumps({
                "cluster": args.cluster.lower(),
                "table": pool_table
            })
        )

    with open(args.pool_stats, 'w') as f:
        json.dump(stats, f, indent=2, sort_keys=True)

    if not args.do_not_clean_up:
        hahn.driver.client.remove(job_root, recursive=True)


if __name__ == "__main__":
    # Script entry point: parse CLI arguments and run the sampling pipeline.
    main()
