#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import division
import sys
import os
import codecs
import argparse
from nile.api.v1 import (
    clusters,
    filters as nf,
    extractors as ne,
    aggregators as na,
    Record
)
import json
from nile.utils.misc import coerce_path
import getpass
import random
import datetime
import math
import copy
import itertools
from collections import Counter
from pytils import get_yt_exists
from v4_daily_yt import get_country


def get_geobase(path='geobase.json'):
    """Load a geobase JSON dump and index it by region id.

    :param path: path to a JSON file containing a list of region objects,
        each with at least ``id``, ``type`` and ``path`` fields (they are
        normalised by preprocess_geobase).
    :return: dict mapping int region id -> preprocessed region dict.
    """
    # Close the handle deterministically and decode as UTF-8 (the JSON
    # encoding) instead of whatever the process locale happens to be.
    with codecs.open(path, encoding='utf-8') as geobase_file:
        raw_regions = json.load(geobase_file)
    return {
        int(region['id']): preprocess_geobase(region)
        for region in raw_regions
    }


def preprocess_geobase(z):
    """Return a normalised deep copy of one raw geobase entry.

    The argument itself is left untouched. In the copy, ``id`` and ``type``
    become ints and ``path`` -- a string of ids separated by ``', '`` --
    becomes a list of ints, with empty pieces (e.g. from a trailing
    separator) dropped.
    """
    region = copy.deepcopy(z)
    region['id'] = int(region['id'])
    region['type'] = int(region['type'])
    path_parts = region['path'].split(', ')
    region['path'] = list(map(int, filter(None, path_parts)))
    return region


# Region index loaded once, at import time; the CountryFilter instances
# built in main() look regions up in it.
geobase = get_geobase(path='geobase20170915.json')


def get_2n_category(n, max_category=25):
    """Bucket a request count into a power-of-two category.

    The category is ``round(log2(n))`` capped at ``max_category``, e.g.
    1 -> 0, 2 -> 1, 3 -> 2, 1000 -> 10.

    Counts below 1 (such as 0) are put into category 0. Without this,
    ``math.log`` raises ``ValueError: math domain error`` and a single such
    row would abort the whole job.

    :param n: request count.
    :param max_category: upper bound on the returned category; defaults to
        25, the previously hard-coded cap.
    :return: int category, never larger than ``max_category``.
    """
    if n < 1:
        return 0
    return min(int(round(math.log(n, 2))), max_category)


def aggregate_queries(groups):
    """Collapse each reduce group into a single output record.

    For every ``(key, records)`` pair whose ``key.query`` is non-empty, this
    yields a Record made of the key's fields (via ``vars(key)``) plus:

    * ``uid`` / ``ts`` -- the first non-empty value seen in the group;
    * ``lr`` -- the first non-empty value. After that, every further
      non-empty ``lr`` replaces it with probability 1/10, so this is not a
      uniform pick across records;
    * ``reqs`` -- the sum of ``reqs`` over the group;
    * ``cat2n`` -- get_2n_category of that sum.

    Groups with an empty query are skipped.
    """
    chooser = random.SystemRandom()
    for key, records in groups:
        if not key.query:
            continue
        region = uid = stamp = ""
        total = 0
        for record in records:
            candidate = record.lr
            if candidate and (not region or chooser.randint(1, 10) == 10):
                region = candidate
            if not uid:
                uid = record.uid or uid
            if not stamp:
                stamp = record.ts or stamp
            total += record.reqs
        # NOTE(review): vars(key) may be the key's own attribute dict, in
        # which case it is updated in place -- confirm for nile Record.
        output = vars(key)
        output['uid'] = uid
        output['ts'] = stamp
        output['lr'] = region
        output['cat2n'] = get_2n_category(total)
        output['reqs'] = total
        yield Record(**output)


# Per-country numbers keyed by country code. Not read by main(); presumably
# target basket sizes -- TODO confirm.
numbers_by_country = dict(
    RU=2000,
    TR=950,
    UA=650,
    BY=450,
    KZ=400,
    UZ=400,
)


# Country code -> geobase region id. main() wraps the value in a
# CountryFilter, which accepts an ``lr`` only when this id occurs in that
# region's geobase ``path``. A country missing from this mapping makes
# main() raise KeyError when it has to build the country's ``queries`` table.
lr_by_country = dict(
    AZ=167,
    AM=168,
    GE=169,
    IL=181,
    KG=207,
    LV=206,
    LT=117,
    MD=208,
    TJ=209,
    TM=170,
    EE=179,
)


class GetHRPath(object):
    """Expand ``$template`` placeholders in a path using a cluster's templates.

    Calling an instance with e.g. ``'$job_root/AZ/queries'`` evaluates the
    path via nile's ``coerce_path(...).eval`` against
    ``cluster.environment.templates``. The result is returned as a string
    that is guaranteed to start with ``//``.
    """

    def __init__(self, cluster):
        # Cluster whose ``environment.templates`` drive the substitution.
        self.cluster = cluster

    def __call__(self, path):
        """Return ``path`` with templates expanded and a leading ``//``."""
        expanded = coerce_path(path).eval(**self.cluster.environment.templates)
        resolved = str(expanded)
        if resolved.startswith('//'):
            return resolved
        return '//' + resolved


def aggregate_queries_google(groups):
    """Collapse each Google reduce group into a single output record.

    For every ``(key, records)`` pair whose ``key.query`` is non-empty, this
    yields a Record made of the key's fields (via ``vars(key)``) plus:

    * ``lr`` -- the first non-empty ``region``. After that, every further
      non-empty ``region`` replaces it with probability 1/10;
    * ``reqs`` -- the sum of ``paircount`` over the group;
    * ``cat2n`` -- get_2n_category of that sum.

    Groups with an empty query are skipped.
    """
    chooser = random.SystemRandom()
    for key, records in groups:
        if not key.query:
            continue
        region = ""
        total = 0
        for record in records:
            candidate = record.region
            if candidate and (not region or chooser.randint(1, 10) == 10):
                region = candidate
            total += record.paircount
        # NOTE(review): vars(key) may be the key's own attribute dict, in
        # which case it is updated in place -- confirm for nile Record.
        output = vars(key)
        output['lr'] = region
        output['cat2n'] = get_2n_category(total)
        output['reqs'] = total
        yield Record(**output)


class CountryFilter(object):
    """Predicate: does a region id lie under a given geobase region?

    Used with ``nf.custom(..., 'lr')``. A region matches when ``country``
    appears in its geobase ``path`` list (the list built by
    preprocess_geobase; presumably the region's ancestor ids).
    """

    def __init__(self, country, geobase):
        """
        :param country: geobase id of the region to match against (e.g. a
            value from ``lr_by_country``).
        :param geobase: mapping int region id -> region dict with an
            optional ``path`` list, as returned by get_geobase.
        """
        self.geobase = geobase
        self.country = country

    def __call__(self, lr):
        """Return True if region ``lr`` has ``self.country`` in its path.

        ``lr`` may be an int or a numeric string. Values that cannot be
        converted to int (including ``None``), ids absent from the geobase,
        and regions without a ``path`` all yield False instead of raising.
        """
        try:
            region_id = int(lr)
        except (TypeError, ValueError):
            # TypeError covers None / missing lr values, which previously
            # escaped and crashed the filter.
            return False
        try:
            region = self.geobase[region_id]
        except KeyError:
            return False
        return self.country in region.get('path', [])


def _build_google_table(hahn, source_table, target_table, country):
    """Aggregate one country's Google queries from ``source_table`` into ``target_table``.

    Keeps rows whose ``domain`` starts with ``www.google.`` and ends with
    ``.<country in lower case>``. Each (domain, query) group is reduced with
    aggregate_queries_google, and the output is written sorted by ``reqs``.
    """
    suffix = '.' + country.lower()

    def is_country_google(domain):
        # ``domain`` may be empty/None in the source table.
        domain = domain or ''
        return domain.startswith('www.google.') and domain.endswith(suffix)

    job = hahn.job()
    job.table(
        source_table
    ).filter(
        nf.custom(is_country_google, 'domain')
    ).project(
        # NOTE(review): `country` is not one of the groupby keys below, and
        # aggregate_queries_google builds its output from vars(key) plus
        # lr/cat2n/reqs, so this column appears to be dropped -- confirm.
        ne.all(), country=ne.const(country)
    ).groupby(
        "domain", "query"
    ).reduce(
        aggregate_queries_google
    ).sort('reqs').put(
        target_table
    )
    job.run()


def _allocate_by_category(cats, target_number):
    """Split ``target_number`` samples across cat2n categories.

    Categories are visited from the smallest to the largest. Each one is
    asked for an even share, rounded up, of what is still needed, capped at
    the number of rows it actually holds. Any shortfall in small categories
    is therefore picked up by the larger ones, and the asks add up to
    ``target_number`` whenever the categories hold enough rows in total.
    Asks may be 0 when the target is smaller than the number of categories.

    :param cats: mapping cat2n -> number of rows in that category.
    :param target_number: desired total number of sampled rows.
    :return: dict cat2n -> number of rows to sample from it.
    """
    records_by_cat = {}
    remaining = target_number
    cats_left = len(cats)
    for cat in sorted(cats, key=lambda c: cats[c]):
        # Ceiling division: ``remaining // cats_left + 1`` over-asks whenever
        # the division is exact, which made pools larger than the target.
        ask = int(min(-(-remaining // cats_left), cats[cat]))
        records_by_cat[cat] = ask
        print('ask {} from category {}'.format(ask, cat))
        remaining -= ask
        cats_left -= 1
    return records_by_cat


def main():
    """Build per-country query pools on YT (Hahn) from a JSON config.

    ``--config`` names a JSON file whose ``countries`` maps a country code
    to a list of entries with ``platform``, ``service``, ``number`` and
    ``basket_types`` keys. For every country this creates, when missing:

    * ``$job_root/<country>/queries`` -- the country's web/vid queries
      whose ``lr`` lies under the region given by lr_by_country;
    * ``queries_aggr`` -- those queries aggregated per
      (country, platform, service, query);
    * ``filtered_desktop_google`` / ``filtered_touch_google`` -- queries
      from ``www.google.*`` domains ending in ``.<country>``;
    * for each configured (platform, service): the filtered table and its
      per-``cat2n`` row counts.

    Existing tables among the above are reused. Then, for each basket type,
    a ``pool_<platform>_<service>_<basket_type>`` table is written by sampling
    about ``number`` rows spread across the ``cat2n`` categories.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--config', default='2017-10-18_new_baskets_config.json'
    )
    args = parser.parse_args()
    hahn = clusters.yt.Hahn(
        pool='search-research_{}'.format(getpass.getuser())
    ).env(
        templates=dict(
            job_root='home/videolog/2017-10-18_world_baskets',
        )
    )
    yt = hahn.driver.client
    yt_exists = get_yt_exists(yt)
    get_hr_path = GetHRPath(hahn)
    # Close the config file instead of leaking the handle.
    with open(args.config) as config_file:
        config = json.load(config_file)

    for country in config['countries']:
        queries_table = get_hr_path('$job_root/{}/queries'.format(country))
        aggr_table = get_hr_path('$job_root/{}/queries_aggr'.format(country))
        google_table_desktop = get_hr_path(
            '$job_root/{}/filtered_desktop_google'.format(country)
        )
        google_table_touch = get_hr_path(
            '$job_root/{}/filtered_touch_google'.format(country)
        )

        if not yt_exists(queries_table):
            print('{} does not exist, creating it'.format(queries_table))
            job = hahn.job()

            job.table(
                '//home/search-research/ensuetina/'
                'QUERIES_MINING/exUSSR'
            ).filter(
                nf.and_(
                    nf.equals('country', country),
                    nf.custom(lambda x: bool(x), 'query'),
                    nf.custom(lambda x: x in {'web', 'vid'}, 'service'),
                    # Raises KeyError for countries absent from
                    # lr_by_country.
                    nf.custom(
                        CountryFilter(lr_by_country[country], geobase), 'lr'
                    )
                )
            ).project(
                ne.all(), cat2n=ne.custom(get_2n_category, 'reqs')
            ).put(
                queries_table
            )

            job.run()
        else:
            print('{} already exists, using it'.format(queries_table))

        if not yt_exists(aggr_table):
            print('{} does not exist, creating it'.format(aggr_table))
            job = hahn.job()

            job.table(
                queries_table
            ).groupby(
                'country', 'platform', 'service', 'query'
            ).reduce(
                aggregate_queries
            ).sort('reqs').put(
                aggr_table
            )

            job.run()
        else:
            print('{} already exists, using it'.format(aggr_table))

        # Desktop and touch Google tables share one pipeline; only the
        # source table differs.
        google_sources = (
            ('//home/goda/zyko/classes/nano_frequency_agg_sg2',
             google_table_desktop),
            ('//home/goda/zyko/classes/nano_frequency_agg_sgmob2',
             google_table_touch),
        )
        for source_table, google_table in google_sources:
            if not yt_exists(google_table):
                print('{} does not exist, creating it'.format(google_table))
                _build_google_table(hahn, source_table, google_table, country)
            else:
                print('{} already exists, using it'.format(google_table))

        for platform, service in itertools.product(
            ('desktop', 'touch'), ('web', 'vid', 'google')
        ):
            # If several config entries match, the last one wins.
            matching = [
                comp for comp in config['countries'][country]
                if comp['platform'] == platform and comp['service'] == service
            ]
            if not matching:
                continue
            local_config = matching[-1]
            target_number = local_config['number']
            print('target number for {} {}: {}'.format(
                platform, service, target_number
            ))
            stats_table = get_hr_path(
                '$job_root/{}/filtered_cat2n_stats_{}_{}'.format(
                    country, platform, service
                )
            )
            filtered_table = get_hr_path('$job_root/{}/filtered_{}_{}'.format(
                country, platform, service
            ))

            if not yt_exists(filtered_table):
                # Google's filtered tables are the ones built above; they
                # cannot be derived from aggr_table.
                if service == 'google':
                    print('google\'s table does not exist, exiting')
                    sys.exit(1)
                print('{} does not exist, creating it'.format(filtered_table))
                job = hahn.job()

                job.table(
                    aggr_table
                ).filter(
                    nf.and_(
                        nf.equals('platform', platform),
                        nf.equals('service', service)
                    )
                ).put(
                    filtered_table
                )

                job.run()
            else:
                print('{} already exists, using it'.format(filtered_table))

            if not yt_exists(stats_table):
                print('{} does not exist, creating it'.format(stats_table))
                job = hahn.job()
                job.table(
                    filtered_table
                ).groupby(
                    'cat2n'
                ).aggregate(
                    count=na.count()
                ).put(
                    stats_table
                )

                job.run()
            else:
                print('{} already exists, using it'.format(stats_table))

            cats = {rec.cat2n: rec.count for rec in hahn.read(stats_table)}
            records_by_cat = _allocate_by_category(cats, target_number)

            # Zero-sized asks add nothing to a pool, so drop them. Skip the
            # pools entirely when nothing is left to sample.
            asks = [
                (cat, ask) for cat, ask in records_by_cat.items() if ask > 0
            ]
            if not asks:
                print('nothing to sample for {} {} {}, skipping'.format(
                    country, platform, service
                ))
                continue

            job = hahn.job().env(
                parallel_operations_limit=10
            )
            for basket_type in local_config['basket_types']:
                pool_table = get_hr_path(
                    '$job_root/{}/pool_{}_{}_{}'.format(
                        country, platform, service, basket_type
                    )
                )
                print('creating {}...'.format(pool_table))

                chosen = job.table(
                    filtered_table
                )
                job.concat(*[
                    chosen.filter(nf.equals('cat2n', cat)).random(ask)
                    for cat, ask in asks
                ]).put(
                    pool_table
                )

            job.run()


if __name__ == "__main__":
    main()
