#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import division
import sys
import os
import codecs
import argparse
from nile.api.v1 import (
    clusters,
    filters as nf,
    extractors as ne,
    aggregators as na,
    Record
)
import re
import json
from nile.utils.misc import coerce_path
import getpass
import random
import datetime
import math
import copy
import string
import itertools
from collections import defaultdict
import numpy as np
from pytils import get_yt_exists


def get_2n_category(n):
    """Return the power-of-two bucket of a request count.

    The bucket is ``round(log2(n))`` capped at 25, so counts near 1, 2, 4,
    8, ... fall into buckets 0, 1, 2, 3, ...

    Non-positive counts (for which the logarithm is undefined and
    ``math.log`` would raise ``ValueError``, failing the whole job) are put
    into bucket 0.

    :param n: request count.
    :returns: integer bucket in ``[0, 25]`` for ``n >= 1``; 0 for ``n <= 0``.
    """
    if n <= 0:
        return 0
    return min(int(round(math.log(n, 2))), 25)


def aggregate_queries(groups, search_engine):
    """Collapse each ``(platform, service, query)`` group into one record.

    For every group with a non-empty query this:
      * sums ``reqs`` over the group and derives ``cat2n`` from the sum;
      * builds a synthetic uid
        ``<first uid or 'uid'>-<search_engine>-<service>-<cat2n>``;
      * keeps the first non-empty ``ts`` (stringified);
      * picks a single region ``lr`` at random, weighted by the requests seen
        from each region, together with that region's ``country``.

    :param groups: iterable of ``(key, records)`` pairs as handed to a nile
        reducer; ``key`` carries the grouping fields.
    :param search_engine: engine name embedded into the synthetic uid.
    :yields: ``Record`` with the key fields plus ``cat2n``, ``reqs``, ``uid``,
        ``ts``, ``lr`` and ``country``. When no record in the group has a
        region, ``lr`` is ``None`` and ``country`` is ``''``.
    """
    for key, records in groups:
        if not key.query:
            continue
        uid = ''
        ts = ''
        reqs = 0
        reqs_by_lr = defaultdict(int)
        country_by_lr = {}
        for rec in records:
            if rec.lr:
                lr = int(rec.lr)
                reqs_by_lr[lr] += rec.reqs
                country_by_lr[lr] = rec.country
            if not uid and rec.get('uid'):
                uid = rec.uid
            if not ts and rec.get('ts'):
                ts = rec.ts
            reqs += rec.reqs

        # vars() returns the key's own __dict__; copy it so the key is not mutated.
        result = dict(vars(key))
        result['cat2n'] = get_2n_category(reqs)
        result['reqs'] = reqs
        result['uid'] = '-'.join(
            [uid or 'uid', search_engine, key.service, str(result['cat2n'])]
        )
        result['ts'] = str(ts)

        if reqs_by_lr:
            # Build keys and weights together so they stay aligned; a list (not
            # a dict view) is required by np.random.choice.
            regions = list(reqs_by_lr)
            weights = [reqs_by_lr[r] for r in regions]
            total = float(sum(weights))
            # All-zero weights would divide by zero: fall back to a uniform pick.
            p = [w / total for w in weights] if total > 0 else None
            # Store a plain int rather than a numpy scalar.
            lr = int(np.random.choice(regions, p=p))
            result['lr'] = lr
            result['country'] = country_by_lr[lr]
        else:
            # No record carried a region; np.random.choice([]) would raise here.
            result['lr'] = None
            result['country'] = ''

        yield Record(**result)

def aggregate_queries_yandex(groups):
    """One-argument nile reducer: ``aggregate_queries`` for the 'yandex' engine."""
    for aggregated in aggregate_queries(groups, search_engine='yandex'):
        yield aggregated

def aggregate_queries_google(groups):
    """One-argument nile reducer: ``aggregate_queries`` for the 'google' engine."""
    for aggregated in aggregate_queries(groups, search_engine='google'):
        yield aggregated


class GetHRPath(object):
    """Callable that expands a templated table path into an absolute YT path.

    Placeholders such as ``$job_root`` are substituted from the cluster
    environment's ``templates``. A leading ``//`` is added when the
    expanded path lacks one.
    """

    def __init__(self, cluster):
        # Cluster whose ``environment.templates`` provide the template values.
        self.cluster = cluster

    def __call__(self, path):
        template_path = coerce_path(path)
        resolved = str(template_path.eval(**self.cluster.environment.templates))
        if resolved.startswith('//'):
            return resolved
        return '//' + resolved


def _ensure_filtered_queries(hahn, yt_exists, source_table, dest_table, step):
    """Create ``dest_table`` from the RU web/img rows of ``source_table``.

    Keeps rows with ``country == 'RU'``, a non-empty ``query`` and
    ``service`` in {'web', 'img'}. Adds ``cat2n`` computed from ``reqs`` and
    sorts by ``service, cat2n``. Only logs when ``dest_table`` already exists.

    :param step: step label used in log messages (e.g. '1.1').
    """
    if yt_exists(dest_table):
        print('{} {} already exists, using it'.format(step, dest_table))
        return
    print('{} {} does not exist, creating it'.format(step, dest_table))

    job = hahn.job()
    job.table(
        source_table
    ).filter(
        nf.and_(
            nf.custom(lambda x: x == 'RU', 'country'),
            nf.custom(lambda x: bool(x), 'query'),
            nf.custom(lambda x: x in {'web', 'img'}, 'service'),
        )
    ).project(
        ne.all(), cat2n=ne.custom(get_2n_category, 'reqs')
    ).sort(
        'service', 'cat2n'
    ).put(
        dest_table
    )
    job.run()


def _ensure_aggregated_queries(hahn, yt_exists, source_table, dest_table,
                               reducer, step):
    """Create ``dest_table`` by reducing ``source_table`` once per query.

    Groups by ``platform, service, query``, applies ``reducer`` and sorts
    the output by ``reqs``. Only logs when ``dest_table`` already exists.

    :param reducer: nile reducer, e.g. ``aggregate_queries_yandex``.
    :param step: step label used in log messages (e.g. '1.2').
    """
    if yt_exists(dest_table):
        print('{} {} already exists, using it'.format(step, dest_table))
        return
    print('{} {} does not exist, creating it'.format(step, dest_table))

    job = hahn.job()
    job.table(
        source_table
    ).groupby(
        'platform', 'service', 'query'
    ).reduce(
        reducer,
        memory_limit=3 * 1024
    ).sort('reqs').put(
        dest_table
    )
    job.run()


def _build_filtered_and_stats(hahn, source_table, filtered_table, stats_table,
                              platform, service):
    """Write one platform/service slice of ``source_table`` and its counts.

    ``filtered_table`` gets the matching rows sorted by ``cat2n``.
    ``stats_table`` gets one ``(cat2n, count)`` row per category.
    """
    job = hahn.job()
    filtered = job.table(
        source_table
    ).filter(
        nf.and_(
            nf.equals('platform', platform),
            nf.equals('service', service)
        )
    ).sort(
        'cat2n'
    ).put(
        filtered_table
    )
    filtered.groupby(
        'cat2n'
    ).aggregate(
        count=na.count()
    ).put(
        stats_table
    )
    job.run()


def _allocate_quotas(cat_counts, target_number):
    """Split ``target_number`` samples across ``cat2n`` categories.

    Categories are visited from the smallest to the largest. Each asks for
    an equal share of what is still unallocated (plus one), capped by the
    rows it actually has, so small categories pass their unused share on
    to larger ones.

    :param cat_counts: mapping ``cat2n -> number of rows in that category``.
    :param target_number: total number of rows wanted (int).
    :returns: mapping ``cat2n -> rows to sample``; ints, never negative.
    """
    quotas = {}
    remaining = target_number
    cats_left = len(cat_counts)
    for cat in sorted(cat_counts, key=lambda c: cat_counts[c]):
        ask = min(int(remaining // cats_left) + 1, cat_counts[cat])
        # The "+1" can over-allocate and drive `remaining` negative near the end.
        ask = max(ask, 0)
        quotas[cat] = ask
        print('ask {} from category {}'.format(ask, cat))
        remaining -= ask
        cats_left -= 1
    return quotas


def _sample_pool(hahn, filtered_table, quotas, pool_table):
    """Randomly sample ``quotas[cat]`` rows of each category into ``pool_table``.

    The pool is sorted on ``inv_cat2n = -cat2n``. With an ascending sort,
    this puts the largest ``cat2n`` first. Zero quotas are skipped; when
    nothing is left to sample, no job is run.
    """
    job = hahn.job().env(
        parallel_operations_limit=10
    )
    chosen = job.table(filtered_table)
    to_concat = [
        chosen.filter(nf.equals('cat2n', cat)).random(quota)
        for cat, quota in quotas.items()
        if quota > 0
    ]
    if not to_concat:
        print('nothing to sample for {}, skipping'.format(pool_table))
        return

    job.concat(
        *to_concat
    ).project(
        ne.all(),
        inv_cat2n=ne.custom(lambda x: -x, 'cat2n')
    ).sort(
        'inv_cat2n'
    ).put(
        pool_table
    )
    job.run()


def main():
    """Build sampled query pools for {yandex, google} x {desktop, touch} x {web, img}.

    1. Filter and aggregate Yandex queries (tables reused if present).
    2. Filter and aggregate Google queries (tables reused if present).
    3. For every (engine, platform, service) slice: count rows per ``cat2n``,
       spread the target size over the categories, then sample the
       ``01_pool_<engine>_<service>_<platform>_<basket_type>`` tables.
    """
    hahn = clusters.yt.Hahn(
        # pool='search-research_{}'.format(getpass.getuser())
    ).env(
        templates=dict(
            job_root='//home/images/dev/nerevar/baskets_img/2017-10-23-RU-DUPS',
        ),
        package_paths=['.'],
        packages=['numpy']
    )
    yt = hahn.driver.client
    yt_exists = get_yt_exists(yt)
    get_hr_path = GetHRPath(hahn)

    # 1. Yandex filtered & aggregated queries for Y.Web + Y.Img
    yandex_queries_table = get_hr_path('$job_root/yandex_queries')
    _ensure_filtered_queries(
        hahn, yt_exists,
        '//home/search-research/ensuetina/QUERIES_MINING/queries_with_country',
        yandex_queries_table, '1.1'
    )
    _ensure_aggregated_queries(
        hahn, yt_exists, yandex_queries_table,
        get_hr_path('$job_root/yandex_queries_aggr'),
        aggregate_queries_yandex, '1.2'
    )

    # 2. Google filtered & aggregated queries for G.Web + G.Img
    google_queries_table = get_hr_path('$job_root/google_queries')
    _ensure_filtered_queries(
        hahn, yt_exists,
        '//home/images/dev/nerevar/baskets_img/google/google_all_queries2',
        google_queries_table, '2.1'
    )
    _ensure_aggregated_queries(
        hahn, yt_exists, google_queries_table,
        get_hr_path('$job_root/google_queries_aggr'),
        aggregate_queries_google, '2.2'
    )

    # 3. Stats and sampled pools for Y.Web + Y.Img + G.Web + G.Img.
    # A 1-tuple: the former ('validate') was a plain string, so the loop ran
    # once per character and wrote pools suffixed _v, _a, _l, ...
    basket_types = ('validate',)
    for search, platform, service in itertools.product(
        ('yandex', 'google'),
        ('desktop', 'touch'),
        ('web', 'img')
    ):
        # 250 * 1.2 (presumably 250 plus a 20% reserve); web pools are 10x.
        # Kept an int so .random() receives integer sample sizes.
        target_number = int(round(250 * 1.2))
        if service == 'web':
            target_number *= 10
        print('target number for {} {} {}: {}'.format(
            search, platform, service, target_number
        ))

        # Resolve paths explicitly, as steps 1-2 do, so put() and hahn.read()
        # see the same absolute path.
        suffix = '{}_{}_{}'.format(search, platform, service)
        filtered_table = get_hr_path('$job_root/filtered_queries_' + suffix)
        stats_table = get_hr_path('$job_root/stats_filtered_cat2n_' + suffix)
        _build_filtered_and_stats(
            hahn, get_hr_path('$job_root/{}_queries_aggr'.format(search)),
            filtered_table, stats_table, platform, service
        )

        cat_counts = {rec.cat2n: rec.count for rec in hahn.read(stats_table)}
        quotas = _allocate_quotas(cat_counts, target_number)

        for basket_type in basket_types:
            pool_table = get_hr_path(
                '$job_root/01_pool_{search}_{service}_{platform}_{basket_type}'.format(
                    search=search,
                    service=service,
                    platform=platform,
                    basket_type=basket_type
                )
            )
            _sample_pool(hahn, filtered_table, quotas, pool_table)

# Run the whole pipeline (steps 1-3 in main) when executed as a script.
if __name__ == "__main__":
    main()
