#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import division
import codecs
import argparse
import re
import string
import os
import json
import datetime
import random
from collections import OrderedDict, Counter
import math
from nile.api.v1 import (
    clusters,
    filters as nf,
    extractors as ne,
    aggregators as na,
    Record
)
from pytils import get_cluster, get_driver


def get_2n_category(n):
    """Return the power-of-two bucket of ``n``: ``round(log2(n))`` as an int.

    ``n`` may be anything ``float()`` accepts (a number or a numeric string).

    Raises:
        Exception: if ``n`` cannot be converted to a float or has no finite
            base-2 logarithm (zero, negative, infinite, NaN).  The message
            carries the type and repr of the value that was passed in.
    """
    try:
        return int(round(math.log(float(n), 2)))
    except (TypeError, ValueError, OverflowError):
        # Only conversion/domain failures are translated; interpreter-level
        # signals such as KeyboardInterrupt are left to propagate untouched.
        raise Exception('error on <{}> n: {}'.format(type(n), repr(n)))


# Country code -> name of the query-language filter applied to that country's
# queries (the value is passed to QueryLangFilter in main()).
country_dict = {
    'BY': 'belarusian',
    'UA': 'ukrainian',
    'TR': 'turkish',
}
country_dict.update(dict.fromkeys(('RU', 'KZ', 'UZ', 'exUSSR'), 'loose'))
# world
country_dict.update(dict.fromkeys(
    ('AZ', 'AM', 'GE', 'IL', 'KG', 'LV', 'LT', 'MD', 'TJ', 'TM', 'EE'),
    'strict',
))


# Matches any single character that is not a whitelisted letter (Latin,
# Russian, Turkish, Belarusian, Ukrainian), an ASCII digit or a space.
# Both cases of every letter are listed, including Ё, which lies outside the
# а-я / А-Я ranges and therefore has to be named explicitly.
bad_symbols = re.compile(u'[^a-zA-Zа-яА-ЯёЁçÇğĞıİöÖşŞüÜІіЎўҐґЇїЄє0-9 ]')


class QueryLangFilter(object):
    """Predicate that accepts queries written only in an allowed alphabet.

    ASCII punctuation and digits are ignored; after removing them the query
    passes when every remaining character is a space or a letter permitted
    by the filter kind chosen at construction time.
    """

    # Text type on both Python 2 (``unicode``) and Python 3 (``str``).
    _TEXT_TYPE = type(u'')

    # Filter kind -> compiled pattern matching a run of characters that are
    # NOT allowed for that kind.  Compiled once instead of on every call.
    _FORBIDDEN = {
        'strict': re.compile(u'[^а-яА-ЯёЁ ]+'),
        'turkish': re.compile(u'[^a-zA-ZçÇğĞıİöÖşŞüÜ ]+'),
        'loose': re.compile(u'[^a-zA-Zа-яА-ЯёЁ ]+'),
        'belarusian': re.compile(u'[^a-zA-Zа-яА-ЯёЁІіЎў ]+'),
        'ukrainian': re.compile(u'[^a-zA-Zа-яА-ЯёЁҐґІіЇїЄє ]+'),
    }

    # Translation table deleting ASCII punctuation and digits from text.
    _IGNORED = dict.fromkeys(map(ord, string.punctuation + string.digits))

    def __init__(self, type):
        """Remember the filter kind: one of the keys of ``_FORBIDDEN``."""
        self.type = type

    def __call__(self, query, type='loose'):
        """Return True when ``query`` contains only allowed characters.

        Args:
            query: text, or UTF-8 encoded bytes.
            type: unused; kept so existing callers keep working.  The kind
                given to the constructor is the one that applies.

        Returns:
            False for values that are neither text nor bytes, for bytes that
            are not valid UTF-8, and for queries containing a forbidden
            character; True otherwise (including for an empty query).

        Raises:
            KeyError: if the kind given to the constructor is unknown.
        """
        forbidden = self._FORBIDDEN[self.type]

        if isinstance(query, bytes):
            try:
                query = query.decode('utf8')
            except UnicodeDecodeError:
                return False
        elif not isinstance(query, self._TEXT_TYPE):
            return False

        # The mapping form of translate() works on text in Python 2 and 3;
        # the byte-string form translate(None, chars) exists only in Python 2.
        letters_only = query.translate(self._IGNORED)
        return forbidden.search(letters_only) is None


def weighted_choice(choices, p, rnd):
    """Pick one element of ``choices`` with probability proportional to ``p``.

    ``p[i]`` is the weight of ``choices[i]``; the weights need not sum to 1.
    ``rnd`` must provide ``uniform(a, b)``.  Returns None when ``choices`` is
    empty.
    """
    threshold = rnd.uniform(0, sum(p))
    cumulative = 0
    for index, candidate in enumerate(choices):
        cumulative += p[index]
        # The first candidate whose cumulative weight reaches the drawn
        # threshold is the winner.
        if cumulative >= threshold:
            return candidate
    return None


class QueryAggregator(object):
    """Reducer that collapses all records of one query into a single record.

    The output record starts as a copy of the group's first record; its count
    column becomes the sum of the group's counts, ``regions_counter`` maps
    ``str(region)`` to the number of records seen for that region, and
    ``region`` is drawn at random with probability proportional to those
    record counts.  Groups whose key has an empty query are dropped.
    """

    def __init__(self, query_column_name, count_column_name):
        self.query_column_name = query_column_name
        self.count_column_name = count_column_name

    def __call__(self, groups):
        """Yield one aggregated ``Record`` per ``(key, records)`` pair."""
        rnd = random.SystemRandom()
        count_column = self.count_column_name
        for key, records in groups:
            if not key.get(self.query_column_name, ''):
                continue
            result = {}
            region_counter = Counter()
            for rec in records:
                if not result:
                    result.update(rec.to_dict())
                    # The seed record is summed in below like every other
                    # record; start from zero so it is not counted twice
                    # (and so a None count in it cannot break the addition).
                    result[count_column] = 0
                result[count_column] += rec.get(count_column, 0) or 0
                region_counter[rec.region] += 1

            keys = sorted(region_counter.keys())
            # weighted_choice() normalises by the sum of the weights itself,
            # so raw record counts give the same distribution as fractions
            # and cannot divide by a zero total.
            weights = [region_counter[k] for k in keys]

            result['regions_counter'] = {
                str(k): v for k, v in region_counter.items()
            }
            result['region'] = weighted_choice(keys, weights, rnd)
            yield Record(**result)


def normalize_query(query):
    """Return ``query`` as clean lower-case unicode text.

    Non-unicode input is stringified and decoded as UTF-8 (undecodable bytes
    are replaced).  Every character matched by ``bad_symbols`` becomes a
    space, runs of spaces collapse to one, and the result is stripped and
    lower-cased.
    """
    if isinstance(query, unicode):
        text = query
    else:
        text = str(query).decode('utf8', errors='replace')
    cleaned = bad_symbols.sub(u' ', text)
    collapsed = re.sub(u' +', u' ', cleaned)
    return collapsed.strip().lower()


class QueryReaggregator(object):
    """Reducer that merges already aggregated records sharing one query.

    The output record starts as a copy of the group's first record; its count
    column becomes the sum of the group's counts, ``regions_counter`` is the
    sum of the records' ``regions_counter`` mappings, and ``region`` is drawn
    at random with probability proportional to the merged per-region counts.
    """

    def __init__(self, query_column_name, count_column_name):
        self.query_column_name = query_column_name
        self.count_column_name = count_column_name

    def __call__(self, groups):
        """Yield one merged ``Record`` per ``(key, records)`` pair."""
        rnd = random.SystemRandom()
        count_column = self.count_column_name
        for key, records in groups:
            result = {}
            regions_counter = Counter()
            for rec in records:
                if not result:
                    result.update(rec.to_dict())
                    # The seed record is summed in below like every other
                    # record; start from zero so it is not counted twice
                    # (and so a None count in it cannot break the addition).
                    result[count_column] = 0
                result[count_column] += rec.get(count_column, 0) or 0
                regions_counter += Counter(rec.regions_counter)
            result['regions_counter'] = regions_counter
            keys = sorted(regions_counter.keys())
            # weighted_choice() normalises by the sum of the weights itself,
            # so raw counts give the same distribution as fractions and
            # cannot divide by a zero total.
            weights = [regions_counter[k] for k in keys]
            result['region'] = weighted_choice(keys, weights, rnd)
            yield Record(**result)


class ClassifierFilter(object):
    """Predicate: does a mapping hold a truthy value under a fixed key?"""

    def __init__(self, key):
        self.key = key

    def __call__(self, dct):
        """Return True when ``dct`` has a truthy value for ``self.key``."""
        value = dct.get(self.key, None)
        return True if value else False


def get_name_from_dict(dict_):
    """Join the mapping's items into ``key_value_key_value...`` form.

    Items appear in the mapping's own iteration order.
    """
    return '_'.join('{}_{}'.format(*item) for item in dict_.items())


def _table_has_rows(yt, path):
    """Return True when ``path`` exists and its ``row_count`` is not 0."""
    return bool(yt.exists(path) and yt.get_attribute(path, 'row_count') != 0)


def main():
    """Command-line entry point: filter, bucket and sample query baskets.

    The config file is a JSON list of dicts.  Each dict must have a ``count``
    (the number of records wanted in the basket); ``classifiers`` and
    ``country`` select custom filters, and every remaining key/value pair
    becomes an equality filter on the source table.

    Stage 1 (one job for all baskets) writes, per basket, a filtered table
    with a ``cat2n`` column and a ``<name>_stats`` table with the number of
    records per ``cat2n`` value, both under ``<target_folder>/tmp``.
    Stage 2 (one job per basket) samples the requested number of records
    across the ``cat2n`` categories into ``<target_folder>/<name>``.

    The list of produced tables is dumped as JSON to ``--outfile``; the tmp
    folder is removed unless ``--debug`` is given.  Baskets whose tables
    already exist and are non-empty are skipped, so a crashed run can be
    restarted.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--pool', '-p', required=True)
    parser.add_argument('--debug', action="store_true")
    parser.add_argument('--cluster', default='hahn')
    parser.add_argument('--config', required=True)
    parser.add_argument('--source_table', '-s', required=True)
    parser.add_argument('--target_folder', '-t', required=True)
    parser.add_argument('--outfile', default='out.json')
    parser.add_argument('--query_column_name', '-q', default='query')
    parser.add_argument('--count_column_name', '-c', default='count')
    parser.add_argument('--nofilters', action='store_true')
    args = parser.parse_args()

    tmp_job_root = '{}/tmp'.format(args.target_folder)

    # vars() exposes the namespace's own dict, so the extra settings below
    # travel to get_cluster() together with the parsed arguments.
    vargs = vars(args)
    vargs["no_yql"] = True
    vargs["title"] = "Baskets Presampler"
    vargs["templates"] = dict(
        job_root=args.target_folder,
        tmp_root=tmp_job_root
    )

    hahn = get_cluster(clusters, vargs)

    # --source_table may name a local JSON file holding the real table path.
    if os.path.isfile(args.source_table):
        with open(args.source_table) as source_file:
            args.source_table = json.load(source_file)['table']

    with open(args.config) as config_file:
        config = json.load(config_file)

    job = hahn.job().env(
        parallel_operations_limit=20,
        default_memory_limit=(20*1024)
    )

    yt = hahn.driver

    input_stream_base = job.table(
        args.source_table
    )

    # Basket name -> requested number of records, in config order.
    counts_dict = OrderedDict()
    for dict_ in config:
        filters = []
        # The name is built from the full entry, 'count' included.
        name = get_name_from_dict(dict_)
        counts_dict[name] = dict_.pop('count')
        if not args.nofilters:
            if 'classifiers' in dict_:
                filters.append(
                    nf.custom(
                        ClassifierFilter(dict_['classifiers']), 'classifiers'
                    )
                )
                dict_.pop('classifiers')
            if 'country' in dict_:
                filters.append(
                    nf.custom(
                        QueryLangFilter(country_dict[dict_['country']]),
                        args.query_column_name
                    )
                )
        # Whatever is left in the entry is matched by plain equality.
        for key in dict_:
            filters.append(
                nf.equals(key.encode('utf8'), dict_[key].encode('utf8'))
            )

        output_table = '{}/{}_stats'.format(
            tmp_job_root, name
        )

        filtered_table = '{}/{}'.format(tmp_job_root, name)

        print('filtering into {}'.format(output_table))

        # Skip statistics and filtering for baskets finished by an earlier
        # run that crashed later on.
        if (
            _table_has_rows(yt, output_table)
            and _table_has_rows(yt, filtered_table)
        ):
            print('Table {} was already processed'.format(output_table))
            continue

        normalize_kwargs = {
            args.query_column_name: ne.custom(
                normalize_query, args.query_column_name
            )
        }

        input_stream = input_stream_base.filter(
            *filters
        )
        if not args.nofilters:
            # Aggregate per raw query, normalise the query text, then merge
            # the records whose normalised queries coincide.
            input_stream = input_stream.groupby(
                args.query_column_name
            ).reduce(
                QueryAggregator(
                    args.query_column_name, args.count_column_name
                )
            ).project(
                ne.all(), **normalize_kwargs
            ).groupby(
                args.query_column_name
            ).reduce(
                QueryReaggregator(
                    args.query_column_name, args.count_column_name
                )
            )

        input_stream.filter(
            nf.custom(bool, args.count_column_name)
        ).project(
            ne.all(), cat2n=ne.custom(get_2n_category, args.count_column_name)
        ).put(
            filtered_table
        ).groupby(
            'cat2n'
        ).aggregate(
            count=na.count()
        ).put(
            output_table
        )

    job.run()

    target_tables = []

    for table in counts_dict:
        tmp_table = '{}/{}'.format(tmp_job_root, table)
        stats_table = '{}/{}_stats'.format(tmp_job_root, table)
        target_table = '{}/{}'.format(args.target_folder, table)
        print('sampling into {}'.format(target_table))

        # Skip sampling for tables finished by an earlier run that crashed
        # later on.  Checked before touching the stats table, which is not
        # needed (and may not exist any more) for a finished basket.
        if _table_has_rows(yt, target_table):
            print('Table {} was already processed'.format(target_table))
            target_tables.append(target_table)
            continue

        # cat2n value -> number of records available in that category.
        cats = {rec.cat2n: rec.count for rec in hahn.read(stats_table)}

        records_by_cat = {}

        target_number_basket = counts_dict[table]

        # Walk the categories from smallest to largest: each one is asked for
        # an even share of what is still needed, capped by its own size, so
        # the shortfall of small categories is passed on to larger ones.
        cat_left = len(cats)
        for cat in sorted(cats, key=cats.get):
            ask = int(target_number_basket // cat_left + 1)
            if cats[cat] < ask:
                ask = cats[cat]
            records_by_cat[cat] = ask
            print('ask {} from category {}'.format(ask, cat))
            target_number_basket -= ask
            cat_left -= 1

        print('creating {}...'.format(target_table))
        job = hahn.job().env(
            parallel_operations_limit=20,
            default_memory_limit=(20*1024)
        )

        chosen = job.table(
            tmp_table
        )
        to_concat = []

        for cat, ask in records_by_cat.items():
            to_concat.append(
                chosen.filter(
                    nf.equals('cat2n', cat)
                ).random(ask, memory_limit=80000)
            )

        job.concat(*to_concat).sort(
            'cat2n'
        ).put(
            target_table
        )

        job.run()

        target_tables.append(target_table)

    with open(args.outfile, 'w') as out_file:
        json.dump(target_tables, out_file, indent=2)

    if not args.debug:
        get_driver(hahn).client.remove(tmp_job_root, recursive=True)

# Run the pipeline only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
