#!/usr/bin/env python
# -*- coding: utf-8 -*-

from __future__ import division
from nile.utils.misc import coerce_path
from nile.api.v1 import (
    clusters,
    filters as nf,
    # extractors as ne,
    # aggregators as na,
    # Record
)
import itertools


def get_yt_exists(yt):
    """Build a predicate telling whether a YT table exists and is non-empty.

    :param yt: YT client object (here ``hahn.driver.client``) exposing
        ``exists(path)`` and ``get_attribute(path, name)``.
    :returns: a one-argument callable that returns ``True`` only when the
        node at the given path exists and its ``row_count`` attribute is
        truthy, and ``False`` otherwise (including for existing empty tables).
    """
    def table_has_rows(table):
        # A missing node short-circuits, so ``get_attribute`` is never asked
        # about a path that does not exist.
        if yt.exists(table):
            return bool(yt.get_attribute(table, 'row_count'))
        return False
    return table_has_rows


class GetHRPath(object):
    """Callable turning a templated table path into an absolute YT path.

    The path is passed through nile's ``coerce_path`` and evaluated against the
    cluster's ``environment.templates`` (presumably substituting placeholders
    such as ``$job_root`` -- confirm against nile docs). The result is then
    forced to start with ``//``.
    """

    def __init__(self, cluster):
        # nile cluster whose environment provides the template values.
        self.cluster = cluster

    def __call__(self, path):
        coerced = coerce_path(path)
        resolved = str(coerced.eval(**self.cluster.environment.templates))
        # Ensure an absolute Cypress-style path.
        return resolved if resolved.startswith('//') else '//' + resolved


def _basket_table_path(params, table_suffix):
    """Return the templated path of a per-slice table under ``$job_root``.

    :param params: ``(country, platform, search, service)`` tuple.
    :param table_suffix: last path component, e.g. ``'01_queries_aggr'``.
    """
    country, platform, search, service = params
    return '$job_root/{country}/{platform}/{search}/{service}/{suffix}'.format(
        country=country,
        platform=platform,
        search=search,
        service=service,
        suffix=table_suffix,
    )


def _target_number(country_queries, country, service):
    """Return how many rows to sample for one slice (truncated to int).

    The multiplication order is kept as-is so float rounding is unchanged.
    """
    target_number = country_queries / 4.0
    target_number *= 1.5  # duplicates
    if service == 'web':
        target_number *= 2.5  # relevance + nonsense
    else:
        target_number *= 1.5  # nonsense
    if country == 'exUSSR':
        target_number *= 1.5
    return int(target_number)


def _allocate_per_category(cats, target_number):
    """Split ``target_number`` rows across categories, smallest first.

    Categories are visited in ascending order of available rows. Each one is
    asked for an even share of what is still left
    (``remaining // categories_left + 1``), capped at the rows it actually
    has, so the shortfall of small categories rolls over to larger ones. The
    ``+ 1`` rounds each share up, so the total can slightly overshoot.

    :param cats: mapping category -> number of available rows.
    :returns: mapping category -> number of rows to sample.
    """
    records_by_cat = {}
    cat_left = len(cats)
    for cat in sorted(cats, key=lambda x: cats[x]):
        ask = target_number // cat_left + 1
        if cats[cat] < ask:
            ask = cats[cat]
        records_by_cat[cat] = ask
        print('ask {} from category {}'.format(ask, cat))
        target_number -= ask
        cat_left -= 1
    return records_by_cat


def _build_basket_part(hahn, source_table, records_by_cat, basket_part_table):
    """Run a nile job sampling ``records_by_cat[cat]`` random rows per bucket.

    The per-bucket samples are concatenated, sorted by ``bucket`` and written
    to ``basket_part_table``.
    """
    job = hahn.job().env(
        parallel_operations_limit=8
    )
    chosen = job.table(source_table)
    # NOTE(review): if the stats table is empty this concatenates zero
    # streams; behaviour then depends on nile's ``concat`` -- confirm.
    to_concat = [
        chosen.filter(nf.equals('bucket', cat)).random(count)
        for cat, count in records_by_cat.items()
    ]
    job.concat(
        *to_concat
    ).sort(
        'bucket'
    ).put(
        basket_part_table
    )
    job.run()


def sample_basket_part(hahn, params, queries_count, basket_types):
    """Sample one basket-part table per basket type for a single slice.

    Reads per-bucket counts from the slice's ``02_stats`` table. It then
    spreads a target row count across buckets (see ``_allocate_per_category``)
    and samples that many random rows per bucket from ``01_queries_aggr``.
    A basket-part table that already exists with a non-zero ``row_count``
    is reused as-is. A missing or empty one is (re)created. Each basket type
    is sampled by its own job, so different types may presumably share rows.

    :param hahn: nile cluster configured with a ``job_root`` template.
    :param params: ``(country, platform, search, service)`` tuple.
    :param queries_count: mapping country -> base query count; must contain
        ``params[0]``.
    :param basket_types: iterable of basket type names; one table each.
    :returns: list of templated basket-part table paths, in
        ``basket_types`` order.
    """
    yt_exists = get_yt_exists(hahn.driver.client)
    get_hr_path = GetHRPath(hahn)

    country, platform, search, service = params
    source_table = _basket_table_path(params, '01_queries_aggr')
    stats_table = _basket_table_path(params, '02_stats')

    target_number = _target_number(queries_count[country], country, service)
    print('target number for {} {} {} {}: {}'.format(
        country, platform, search, service, target_number
    ))

    cats = {rec.bucket: rec.count for rec in hahn.read(stats_table)}
    records_by_cat = _allocate_per_category(cats, target_number)

    baskets = []
    for basket_type in basket_types:
        basket_part_table = _basket_table_path(
            params, '03_basket_part_{}'.format(basket_type)
        )
        # Single-argument print(...) prints the same on Python 2 and 3,
        # unlike the former ``print '...'`` statements.
        if yt_exists(get_hr_path(basket_part_table)):
            print('\n{} — {} — table exists'.format(params, basket_part_table))
        else:
            print('\n{} — {} — create table'.format(params, basket_part_table))
            _build_basket_part(
                hahn, source_table, records_by_cat, basket_part_table
            )
        baskets.append(basket_part_table)
    return baskets


def main(token=None):
    """Sample KPI and validation basket parts for every slice on Hahn.

    Builds a Hahn cluster with ``job_root`` pointing at the 2018Q3 baskets
    directory. Then calls ``sample_basket_part`` with basket types ``kpi``
    and ``validate`` for every (country, platform, search, service)
    combination.

    :param token: passed unchanged to ``clusters.Hahn``.
    :returns: list of ``(slice_tuple, basket_tables)`` pairs in iteration
        order, where ``basket_tables`` is what ``sample_basket_part`` returned.
    """
    hahn = clusters.Hahn(token=token).env(
        templates=dict(
            job_root='//home/images/dev/nerevar/baskets_img/2018Q3',
        ),
        package_paths=['.'],
        packages=['numpy'],
    )

    # Base query count per country; sample_basket_part scales it into a
    # per-slice sampling target.
    queries_count = dict(
        RU=1680,
        BY=600,
        KZ=600,
        UA=360,
        UZ=360,
        exUSSR=400,
    )

    slices = itertools.product(
        ['RU', 'UA', 'BY', 'KZ', 'UZ', 'exUSSR'],
        ['touch', 'desktop'],
        ['google'],  # TODO: yandex
        ['img', 'web'],
    )
    return [
        (
            slice_params,
            sample_basket_part(
                hahn, slice_params, queries_count, ['kpi', 'validate']
            ),
        )
        for slice_params in slices
    ]


if __name__ == '__main__':
    # Script entry point: run with no explicit token.
    main()
