#!/usr/bin/env python
# -*- coding: utf-8 -*-

from __future__ import division
import math
import string
import re
import numpy as np

from nile.utils.misc import coerce_path
from nile.api.v1 import (
    clusters,
    filters as nf,
    # extractors as ne,
    aggregators as na,
    Record
)

import itertools
from collections import defaultdict
from functools import partial

import multiprocessing
from multiprocessing.dummy import Pool as ThreadPool

# def get_geobase(path='geobase.json'):
#     gb = json.load(open(path))
#     geobase = {int(x['id']): preprocess_geobase(x) for x in gb}
#     return geobase


# def preprocess_geobase(z):
#     x = copy.deepcopy(z)
#     for k in ['id', 'type']:
#         x[k] = int(x[k])
#     x['path'] = [int(y) for y in x['path'].split(', ') if y]
#     return x

# # https://wiki.yandex-team.ru/statbox/infrastructure/doc/geobase-binary/
# geobase = get_geobase('geobase-2018-02-10.json')


def get_2n_category(n):
    """Return ``log2(n)`` rounded to the nearest integer (power-of-two bucket).

    Raises ``ValueError`` (math domain error) when ``n <= 0``.
    """
    exponent = math.log(n, 2)
    return int(round(exponent))


def get_yt_exists(yt):
    """Build a predicate telling whether a YT table exists and is non-empty.

    :param yt: YT client exposing ``exists(path)`` and
        ``get_attribute(path, name)``.
    :returns: ``yt_exists(table) -> bool``.
    """
    def yt_exists(table):
        # A table only counts when it exists *and* has a truthy row_count;
        # row_count is not queried for missing tables.
        return bool(yt.exists(table) and yt.get_attribute(table, 'row_count'))
    return yt_exists


class GetHRPath(object):
    """Resolve nile path templates into absolute ("human-readable") YT paths.

    Instances are callables: template placeholders (e.g. ``$job_root``) are
    substituted from ``cluster.environment.templates`` and the result is
    guaranteed to start with ``//``.
    """

    def __init__(self, cluster):
        self.cluster = cluster  # nile cluster providing the template environment

    def __call__(self, path):
        resolved = str(
            coerce_path(path).eval(**self.cluster.environment.templates)
        )
        if resolved.startswith('//'):
            return resolved
        return '//' + resolved


def _sample_region(regions):
    """Pick one region id at random, weighted by its accumulated frequency.

    :param regions: mapping ``region_id -> summed frequency``.
    :returns: the sampled region id as ``int``, or ``None`` when the mapping is
        empty or its total weight is not positive (``np.random.choice`` would
        raise in both cases).
    """
    region_ids = list(regions)
    weights = [regions[region_id] for region_id in region_ids]
    total = float(sum(weights))
    if not region_ids or total <= 0:
        return None
    # Normalise by the regional total (not the overall query frequency) so the
    # probabilities sum to 1 even when some records carried no region.
    probabilities = [weight / total for weight in weights]
    return int(np.random.choice(region_ids, p=probabilities))


def aggregate_queries(groups, search_engine):
    """Collapse grouped query records into one output record per query text.

    For every ``(key, records)`` group whose ``key.text`` is non-empty:

    * record frequencies are summed;
    * a region is sampled at random, weighted by per-region frequency
      (``None`` when no record carried a region);
    * the first truthy ``other`` payload is copied and enriched with the
      aggregate fields ``frequency``, ``bucket``, ``service``, ``search`` and
      ``platform``.

    :param groups: iterable of ``(key, records)`` pairs; ``key`` is read for
        ``text``, ``country``, ``platform`` and ``service``, each record for
        ``region``, ``frequency`` and ``other``.
    :param search_engine: value stored as ``other['search']``.
    :yields: one ``Record`` per non-empty query text.
    """
    for key, records in groups:
        if not key.text:
            continue
        regions = defaultdict(int)
        other_value = None
        frequency = 0
        for rec in records:
            if rec.region:
                regions[int(rec.region)] += rec.frequency
            if not other_value and rec.get('other'):
                other_value = rec.get('other')
            frequency += rec.frequency

        # Copy so the enrichment below does not mutate the input record's
        # payload.  NOTE(review): a group with no truthy 'other' still fails
        # here (TypeError), as it did before.
        other = dict(other_value)
        bucket = get_2n_category(frequency)

        if key.platform == 'desktop':
            device = 'DESKTOP'
        else:
            device = other['os'].upper()

        other['frequency'] = frequency
        other['bucket'] = bucket
        other['service'] = key.service
        other['search'] = search_engine
        other['platform'] = key.platform

        yield Record(
            query_country=key.country,
            query_device=device,
            query_region_id=_sample_region(regions),
            query_text=key.text,
            query_uid=other['uid'],
            bucket=bucket,
            other=other,
        )


def aggregate_queries_yandex(groups):
    """Reducer: ``aggregate_queries`` with ``search_engine='yandex'``."""
    aggregated = aggregate_queries(groups, 'yandex')
    for record in aggregated:
        yield record


def aggregate_queries_google(groups):
    """Reducer: ``aggregate_queries`` with ``search_engine='google'``."""
    aggregated = aggregate_queries(groups, 'google')
    for record in aggregated:
        yield record


# Basket country code -> QueryLangFilter alphabet name.  Countries without a
# dedicated alphabet use 'loose' (Latin + Russian Cyrillic).
country_dict = dict(
    RU='loose',
    BY='belarusian',
    UA='ukrainian',
    TR='turkish',
    KZ='loose',
    UZ='loose',
    exUSSR='loose',
)


class QueryLangFilter(object):
    """Predicate accepting queries written only in an allowed alphabet.

    ASCII punctuation and ASCII digits are ignored.  Every remaining character
    must be a letter of the alphabet selected by ``type`` or a plain space.
    Values that are neither text nor UTF-8 encoded bytes are rejected.

    Supported ``type`` values: ``'strict'`` (Russian Cyrillic only),
    ``'loose'`` (Latin + Russian Cyrillic), ``'turkish'``, ``'belarusian'``,
    ``'ukrainian'``.
    """

    # Each pattern matches any run of characters *outside* the alphabet.
    _PATTERNS = {
        'strict': re.compile(u'[^а-яА-ЯёЁ ]+'),
        'turkish': re.compile(u'[^a-zA-ZçÇğĞıİöÖşŞüÜ ]+'),
        'loose': re.compile(u'[^a-zA-Zа-яА-ЯёЁ ]+'),
        'belarusian': re.compile(u'[^a-zA-Zа-яА-ЯёЁІіЎў ]+'),
        'ukrainian': re.compile(u'[^a-zA-Zа-яА-ЯёЁҐґІіЇїЄє ]+'),
    }
    # Text translate table deleting ASCII punctuation and digits
    # (works for both Python 2 ``unicode`` and Python 3 ``str``).
    _IGNORED = dict.fromkeys(map(ord, string.punctuation + string.digits))
    _TEXT_TYPE = type(u'')

    def __init__(self, type):
        # Key into _PATTERNS; an unknown value raises KeyError when called.
        self.type = type

    def __call__(self, query, type='loose'):
        # ``type`` is accepted for backward compatibility but ignored: the
        # alphabet always comes from ``self.type``.
        pattern = self._PATTERNS[self.type]

        if isinstance(query, bytes):
            try:
                text = query.decode('utf8')
            except UnicodeDecodeError:
                return False
        elif isinstance(query, self._TEXT_TYPE):
            text = query
        else:
            return False

        text = text.translate(self._IGNORED)
        return pattern.search(text) is None


def create_aggr_sample(hahn, job_root, params):
    """Build (if needed) the aggregated query table and its bucket stats.

    :param hahn: nile cluster; its driver client is used for existence checks
        and it creates and runs the jobs.
    :param job_root: YT directory holding the ``source_<search>_<platform>``
        input tables.
    :param params: ``(country, platform, search, service)`` tuple selecting
        one basket part.
    :returns: the ``//``-prefixed path of the ``01_queries_aggr`` table.

    Outputs live under ``$job_root/<country>/<platform>/<search>/<service>/``:
    ``01_queries_aggr`` (filtered source rows reduced per query text, sorted
    by ``bucket``) and ``02_stats`` (row count per ``bucket``).  Each one is
    skipped when it already exists with a non-zero row count.
    """
    client = hahn.driver.client
    has_rows = get_yt_exists(client)
    hr_path = GetHRPath(hahn)

    country, platform, search, service = params
    part_root = '$job_root/{country}/{platform}/{search}/{service}'.format(
        country=country, platform=platform, search=search, service=service,
    )
    queries_table = part_root + '/01_queries_aggr'

    if not has_rows(hr_path(queries_table)):
        print('\n{}  — create table'.format(params))

        # 'exUSSR' is an umbrella basket over several smaller countries.
        countries = (
            ['AZ', 'AM', 'GE', 'IL', 'KG', 'LV', 'LT', 'MD', 'TJ', 'TM', 'EE']
            if country == 'exUSSR' else [country]
        )
        aggr_fn = (aggregate_queries_google if search == 'google'
                   else aggregate_queries_yandex)
        source_table = '{job_root}/source_{search}_{platform}'.format(
            job_root=job_root, search=search, platform=platform,
        )

        job = hahn.job()
        stream = job.table(source_table)
        stream = stream.filter(nf.and_(
            nf.custom(lambda x: x in countries, 'country'),
            nf.equals('service', service),
            nf.equals('platform', platform),
            nf.custom(lambda x: bool(x), 'text'),
            nf.custom(QueryLangFilter(country_dict[country]), 'text'),
        ))
        stream = stream.groupby('country', 'service', 'platform', 'text')
        stream = stream.reduce(aggr_fn, memory_limit=3 * 1024)
        stream = stream.sort('bucket')
        stream.put(queries_table)
        job.run()
    else:
        print('\n{}  — table exists'.format(params))

    stats_table = part_root + '/02_stats'

    if not has_rows(hr_path(stats_table)):
        print('{}  — create stats table'.format(params))

        job = hahn.job()
        stream = job.table(queries_table)
        stream = stream.groupby('bucket')
        stream = stream.aggregate(count=na.count())
        stream.put(stats_table)
        job.run()
    else:
        print('{}  — stats table exists'.format(params))

    return hr_path(queries_table)


def main(token=None, job_root='//home/images/dev/nerevar/baskets_img/2018Q3'):
    """Build aggregated query samples for every basket part in parallel.

    :param token: YT token passed to ``clusters.Hahn``.
    :param job_root: YT directory used both as the ``$job_root`` template and
        as the location of the ``source_<search>_<platform>`` tables.
    :returns: list of ``//``-prefixed queries-table paths, in the order of the
        country x platform x search x service product.
    """
    hahn = clusters.Hahn(
        token=token
    ).env(
        templates=dict(
            job_root=job_root,
        ),
        package_paths=['.'],
        packages=['numpy']
    )

    fn = partial(create_aggr_sample, hahn, job_root)
    basket_parts = itertools.product(
        ['RU', 'UA', 'BY', 'KZ', 'UZ', 'exUSSR'],
        ['touch', 'desktop'],
        ['google', 'yandex'],
        ['img', 'web'],
    )

    # Leave one core free, but never drop below a single worker.
    processes = multiprocessing.cpu_count() - 1 or 1
    print('processes {}'.format(processes))
    pool = ThreadPool(processes=processes)
    try:
        tables_list = pool.map(fn, basket_parts)
    except BaseException:
        # On failure (or interruption) drop the tasks still queued so they do
        # not keep launching YT operations; on Python 2.7 ``map`` re-raises
        # the first failure without waiting for the remaining tasks.
        pool.terminate()
        raise
    pool.close()
    pool.join()

    return tables_list


if __name__ == '__main__':
    # Script entry point: run the whole pipeline with main()'s defaults.
    main()
