#!/usr/bin/env python
# -*- coding: utf-8 -*-

from __future__ import division
import sys
import os
import codecs
import argparse
from nile.api.v1 import (
    clusters,
    filters as nf,
    extractors as ne,
    aggregators as na,
    Record
)
import getpass
import datetime
import itertools
import numpy as np
from collections import Counter


def get_2n_category(n):
    """Return the power-of-two bucket that ``n`` falls into.

    The bucket is the smallest ``2 ** x`` with ``1 <= x <= 24`` that is not
    less than ``n``; anything above ``2 ** 24`` lands in ``2 ** 25``. The
    smallest possible bucket is therefore 2, also for zero and negatives.

    Args:
        n: request count; must be an ``int`` (checked with ``assert``).

    Returns:
        int: the bucket upper bound, a power of two between 2 and 2 ** 25.
    """
    assert isinstance(n, int)
    exponent = 1
    # Grow the bucket until it covers n or the overflow bucket is reached.
    while exponent < 25 and n > 2 ** exponent:
        exponent += 1
    return 2 ** exponent


def aggregate_queries_with_lr(groups):
    """Collapse every query group into one record with a sampled ``lr``.

    Reducer over a grouped stream. Groups whose key has an empty ``query``
    are dropped. For every other group a single ``Record`` is yielded that
    carries the key fields plus:

    * ``reqs``  -- the sum of ``reqs`` over all records of the group;
    * ``cat2n`` -- ``get_2n_category`` of that sum;
    * ``lr``    -- one of the group's truthy ``lr`` values, drawn at random
      with probability proportional to the ``reqs`` accumulated for it.
      When no record of the group has an ``lr`` the field is ``None``;
      when the accumulated weights do not sum to a positive number the
      draw is uniform.

    Args:
        groups: iterable of ``(key, records)`` pairs. ``key`` must expose
            ``query`` and support ``vars()``; each record must expose
            ``lr`` and ``reqs``.

    Yields:
        Record: one aggregated record per group with a truthy ``query``.
    """
    for key, records in groups:
        if not key.query:
            continue

        reqs = 0
        lrs = {}
        for rec in records:
            if rec.lr:
                lrs[rec.lr] = lrs.get(rec.lr, 0) + rec.reqs
            reqs += rec.reqs

        # Work on a copy so the grouping key object is never mutated.
        result = dict(vars(key))
        result['cat2n'] = get_2n_category(reqs)
        result['reqs'] = reqs

        if not lrs:
            # np.random.choice rejects an empty population, so a group
            # without any lr would otherwise abort the whole reducer.
            result['lr'] = None
        else:
            candidates = list(lrs)
            total = float(sum(lrs.values()))
            if total > 0:
                weights = [lrs[lr] / total for lr in candidates]
            else:
                # Weights that cannot be normalised: fall back to uniform.
                weights = None
            # Draw an index rather than a value: the original key object is
            # returned untouched (no numpy array coercion) and a dict view
            # is never handed to numpy.
            picked = np.random.choice(len(candidates), p=weights)
            result['lr'] = candidates[picked]

        yield Record(**result)


# Base basket size per country code. main() multiplies it by 1.2 and, for the
# 'web' service, by a further factor of 10 to get the sampling target.
queries_count_by_country = dict(
    RU=500,
    TR=250,
    UA=200,
    BY=125,
    KZ=125,
    UZ=125,
)

def main():
    """Run the basket-building pipeline on the Hahn cluster.

    For each country of a fixed list the function:

    1. filters the shared google queries table down to rows of that country
       with a non-empty ``query`` and a ``web``/``img`` service, adds a
       ``cat2n`` column and stores the result;
    2. reduces the stored rows per (country, platform, service, query) with
       ``aggregate_queries_with_lr`` and stores them sorted by ``reqs``;
    3. for every platform/service pair stores the matching aggregated rows
       and a per-``cat2n`` row count, reads the counts back and splits the
       target basket size between the categories, least populated first;
    4. writes a ``kpi`` and a ``validate`` pool, each built from a random
       sample of the allotted size per category and sorted on the negated
       ``cat2n``.
    """
    cluster = clusters.yt.Hahn(
        pool='search-research_{}'.format(getpass.getuser())
    ).env(
        templates={
            'job_root': '//home/images/dev/nerevar/baskets_img/2017-08-26-final-google-3',
        },
        package_paths=['.'],
        packages=['numpy']
    )

    for country in ['RU', 'UA', 'BY', 'KZ', 'TR', 'UZ']:
        queries_table = '$job_root/{}/queries'.format(country)
        aggr_table = '$job_root/{}/queries_aggr'.format(country)

        print('Prepare country {}'.format(country))

        # Step 1: per-country slice of the source table with a cat2n column.
        prepare_job = cluster.job()
        source = prepare_job.table(
            '//home/images/dev/nerevar/baskets_img/google/google_all_queries2'
        )
        selected = source.filter(
            nf.and_(
                nf.equals('country', country),
                nf.custom(lambda x: bool(x), 'query'),
                nf.custom(lambda x: x in {'web', 'img'}, 'service')
            )
        )
        categorized = selected.project(
            ne.all(), cat2n=ne.custom(get_2n_category, 'reqs')
        )
        categorized.sort('service', 'cat2n').put(queries_table)
        prepare_job.run()

        # Step 2: one aggregated row per query.
        aggregate_job = cluster.job()
        grouped = aggregate_job.table(queries_table).groupby(
            'country', 'platform', 'service', 'query'
        )
        reduced = grouped.reduce(
            aggregate_queries_with_lr,
            memory_limit=3*1024
        )
        reduced.sort('reqs').put(aggr_table)
        aggregate_job.run()

        for platform, service in itertools.product(
            ('desktop', 'touch'),
            ('web', 'img')
        ):
            # 20% on top of the base size; web baskets are ten times larger.
            target_number = int(queries_count_by_country[country] * 1.2)
            if service == 'web':
                target_number = target_number * 10
            print('target number for {} {}: {}'.format(
                platform, service, target_number
            ))

            # Step 3: rows of this platform/service and their cat2n counts.
            split_job = cluster.job()
            stats_table = '$job_root/{}/filtered_cat2n_stats_{}_{}'.format(
                country, platform, service
            )
            filtered_table = '$job_root/{}/filtered_{}_{}'.format(
                country, platform, service
            )

            filtered = split_job.table(aggr_table).filter(
                nf.and_(
                    nf.equals('platform', platform),
                    nf.equals('service', service)
                )
            ).sort('cat2n').put(filtered_table)

            filtered.groupby('cat2n').aggregate(
                count=na.count()
            ).put(stats_table)

            split_job.run()

            cats = dict(
                (row.cat2n, row.count) for row in cluster.read(stats_table)
            )

            # Hand out the target between categories, least populated first,
            # so that what a small category cannot supply is passed on to
            # the bigger ones that follow.
            records_by_cat = {}
            remaining = target_number
            total_cats = len(cats)
            for position, cat in enumerate(sorted(cats, key=cats.get)):
                cats_left = total_cats - position
                ask = min(remaining // cats_left + 1, cats[cat])
                records_by_cat[cat] = ask
                print('ask {} from category {}'.format(ask, cat))
                remaining -= ask

            # Step 4: one independently sampled pool per basket type.
            for basket_type in ('kpi', 'validate'):
                basket_job = cluster.job().env(
                    parallel_operations_limit=10
                )

                chosen = basket_job.table(filtered_table)
                to_concat = [
                    chosen.filter(nf.equals('cat2n', cat)).random(quota)
                    for cat, quota in records_by_cat.items()
                ]

                pool = basket_job.concat(*to_concat).project(
                    ne.all(),
                    inv_cat2n=ne.custom(lambda x: -x, 'cat2n')
                )
                pool.sort('inv_cat2n').put(
                    '$job_root/{country}/01_pool_google_{service}_{country}_{platform}_{basket_type}'.format(
                        country=country,
                        service=service,
                        platform=platform,
                        basket_type=basket_type
                    )
                )

                basket_job.run()


# Run the pipeline only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()
