#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
from __future__ import division
import sys
import argparse
from nile.api.v1 import (
    clusters,
    files,
    Record
)
import json
import getpass
import random
from pytils import tabulate, date_range, get_yt_exists
from collections import Counter
import datetime


# Random generator backed by the operating system's entropy source
# (random.SystemRandom). No function in the visible source references it.
rng = random.SystemRandom()


def parse_features(feat_text):
    """Parse a tab-separated ``id<TAB>value<TAB>id<TAB>value...`` string.

    Fields at even positions are converted to ``int`` feature ids and the
    field following each of them to a ``float`` value.  A trailing id that
    has no value is dropped, because ``zip`` stops at the shorter sequence.

    Returns:
        dict mapping int feature id -> float value.
    """
    fields = feat_text.split('\t')
    parsed = {}
    for raw_id, raw_value in zip(fields[0::2], fields[1::2]):
        feature_id = int(raw_id)
        parsed[feature_id] = float(raw_value)
    return parsed


def promolib_check(obj):
    """Record filter for ``promolib_dumpall`` entries.

    Passes when ``obj['h']`` equals ``'promolib_dumpall'`` and either
    ``obj['t']`` is truthy or ``'AM::EVENT_AD_INSTALL'`` occurs in
    ``obj['c']``.  The result is truthy/falsy rather than strictly boolean:
    a truthy ``obj['t']`` is returned as-is.
    """
    is_promolib = obj['h'] == 'promolib_dumpall'
    if not is_promolib:
        return is_promolib
    return obj['t'] or 'AM::EVENT_AD_INSTALL' in obj['c']


def default_search_check(obj):
    """Record filter for ``default_search`` entries.

    Passes when ``obj['h']`` equals ``'default_search'`` and ``obj['t']`` is
    truthy.  In that case ``obj['t']`` itself is returned, otherwise the
    falsy comparison result.
    """
    is_default_search = obj['h'] == 'default_search'
    if not is_default_search:
        return is_default_search
    return obj['t']


def no_check(obj):
    """Record filter that accepts everything; ``obj`` is not inspected."""
    return True


def default_target(obj):
    """Compute the default ``(target, weight)`` pair for a record.

    The target is 1 when ``obj['c']`` contains any of
    ``'AM::EVENT_AD_INSTALL'``, ``'SE::install'`` or ``'SE::dayuse'``,
    otherwise 0.  The weight is always 1.
    """
    events = obj['c']
    positive_markers = ('AM::EVENT_AD_INSTALL', 'SE::install', 'SE::dayuse')
    converted = any(marker in events for marker in positive_markers)
    return (1, 1) if converted else (0, 1)


def click_target(obj):
    """Compute a click-based ``(target, weight)`` pair for a record.

    The target is 1 when ``'click'`` occurs in ``obj['c']`` and 0 otherwise;
    the weight is always 1.
    """
    label = 1 if 'click' in obj['c'] else 0
    return label, 1


def make_pool_features(features, maxfeat=750):
    """Render a sparse feature dict as a dense tab-separated row.

    Args:
        features: mapping of integer feature index to value.
        maxfeat: lowest allowed value for the highest emitted index; the row
            is widened when ``features`` holds a larger index.

    Returns:
        A string with one field per index from 0 up to
        ``max(maxfeat, max(features))`` inclusive, joined by tabs.  Indices
        absent from ``features`` are written as ``'0'``.  An empty
        ``features`` mapping yields an all-zero row of ``maxfeat + 1`` fields.
    """
    # max() on an empty mapping raises ValueError, so only widen the row when
    # at least one feature is present.
    if features:
        maxfeat = max(maxfeat, max(features))
    return '\t'.join(
        str(features[index]) if index in features else '0'
        for index in range(maxfeat + 1)
    )


def bannerid_extractor(obj):
    """Return the extra output field ``bannerid`` taken from ``obj['b']``."""
    banner = obj['b']
    return {'bannerid': banner}


def reqid_extractor(obj):
    """Return the extra output field ``reqid`` derived from ``obj['v']``.

    Only the part of ``obj['v']`` before the first comma is considered; the
    reqid is whatever follows the last ``'='`` in that part (or the whole
    part when it contains no ``'='``).
    """
    first_item = obj['v'].partition(',')[0]
    reqid = first_item.rpartition('=')[2]
    return {'reqid': reqid}


class PoolMaker(object):
    """Mapper callable that turns JSON log records into pool rows.

    Args:
        check: callable ``obj -> truthy/falsy``; records whose parsed JSON
            fails the check are skipped.
        target: callable ``obj -> (target, weight)``; ``default_target`` is
            used when omitted (or falsy).
        add_features: optional callable ``features -> features`` applied to
            the parsed feature dict before it is rendered.
        extract_fields: optional callable ``obj -> dict`` of extra fields to
            put on each output record; defaults to no extra fields.
        maxfeat: passed to ``make_pool_features`` as the minimum highest
            feature index of the dense row.
    """

    def __init__(self, check, target=None, add_features=None, extract_fields=None, maxfeat=750):
        self.check = check
        self.target = target
        if not self.target:
            self.target = default_target
        self.add_features = add_features
        if extract_fields is not None:
            self.extract_fields = extract_fields
        else:
            self.extract_fields = lambda obj: {}
        # Running row id; seeded lazily from the first record seen.
        self.id_ = None
        self.maxfeat = maxfeat

    def __call__(self, records):
        """Yield one output ``Record`` per input record that passes ``check``.

        Each input ``record.value`` is decoded as UTF-8 (invalid bytes are
        replaced) and parsed as JSON; records that cannot be decoded or
        parsed are skipped.  The output key is a running integer id, the
        value is ``tabulate(target, '', weight, dense_features)``, and the
        fields returned by ``extract_fields`` are added to the record.
        """
        for record in records:
            # Seed the id once, from the first 9 digits of the absolute hash
            # of the first record's value.  ``is None`` (rather than a
            # truthiness test) keeps a legitimate seed of 0 from being
            # re-seeded on the next record.
            if self.id_ is None:
                self.id_ = int(str(abs(hash(record.value)))[:9])
            try:
                obj = json.loads(record.value.decode('utf8', errors='replace'))
            except Exception:
                # Malformed input is skipped on purpose.  Catching Exception
                # instead of using a bare ``except`` lets SystemExit and
                # KeyboardInterrupt propagate.
                continue
            if not self.check(obj):
                continue
            features = parse_features(obj['f'])
            if self.add_features:
                features = self.add_features(features)
            target, weight = self.target(obj)
            yield Record(
                key=str(self.id_),
                subkey="",
                value=tabulate(
                    target,
                    '',
                    weight,
                    make_pool_features(features, self.maxfeat)
                ),
                **self.extract_fields(obj)
            )
            # Only emitted rows consume an id.
            self.id_ += 1


def pool_stats(records):
    """Count pool rows per ``(target, weight)`` combination.

    Each ``record.value`` is split on tabs; field 0 is taken as the target
    and field 2 as the weight.  After all input has been consumed, one
    ``Record(target=..., weight=..., total=...)`` is yielded per distinct
    pair, where ``total`` is the number of rows with that pair.
    """
    tally = Counter(
        (fields[0], fields[2])
        for fields in (rec.value.split('\t') for rec in records)
    )
    for (target, weight), total in tally.items():
        yield Record(target=target, weight=weight, total=total)


def append_to_pool(job, table, pool_maker, path):
    """Add a map step that appends ``table`` to ``$job_root/<path>``.

    Args:
        job: job object to extend; the caller is responsible for running it.
        table: path of the input table.
        pool_maker: mapper callable applied to the input records.
        path: name of the output table under ``$job_root``.

    The source file that defines ``tabulate`` is attached to the map
    operation through ``files=``.
    """
    # ``__code__`` is available on Python 2.6+ and Python 3, whereas the
    # ``func_code`` alias exists only on Python 2.
    helper_source = tabulate.__code__.co_filename
    job.table(
        table
    ).map(
        pool_maker, files=[
            files.LocalFile(helper_source)
        ]
    ).put('$job_root/{}'.format(path), append=True)


def concat_learn_and_test(job):
    """Add a step writing ``$job_root/learn`` + ``$job_root/test`` to ``$job_root/all``."""
    merged = job.concat(
        job.table('$job_root/learn'),
        job.table('$job_root/test'),
    )
    merged.put('$job_root/all')


def make_stats_table(job, table):
    """Add a step that maps ``$job_root/<table>`` through ``pool_stats``.

    The result is written to ``$job_root/<table>_stats``.
    """
    source_path = '$job_root/{}'.format(table)
    stats_path = '$job_root/{}_stats'.format(table)
    job.table(source_path).map(pool_stats).put(stats_path)


def _fill_pool(cluster, tables, pool_maker, path):
    """Append each table in ``tables`` to ``$job_root/<path>``, one job per table."""
    for table in tables:
        print('table {} added to {}'.format(table, path))
        job = cluster.job()
        append_to_pool(job, table, pool_maker, path=path)
        try:
            job.run()
        except Exception as exc:
            # Best effort: one failing day must not abort the whole build,
            # but the failure is reported instead of being silently dropped.
            # Catching Exception (not a bare ``except``) keeps Ctrl-C and
            # sys.exit() working.
            sys.stderr.write(
                'error: job for table {} failed: {!r}\n'.format(table, exc)
            )


def main():
    """Build learn/test pools from the daily conveyor tables.

    Command-line options:
        --from / --to: inclusive date bounds in ``YYYY-MM-DD`` format; at
            least one is required, and a missing bound defaults to the other.
        --check, --target, --extract, --add_features: names of module-level
            callables used to configure ``PoolMaker``.

    Existing daily tables are sorted; the last ``len(tables) // 4`` of them
    form the test pool and the rest the learn pool.  Stats tables are built
    for both pools afterwards.  Exits with status 1 when no date is given or
    when three or fewer daily tables exist.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--from', '-f', default=None)
    parser.add_argument('--to', '-t', default=None)
    parser.add_argument('--check', '-c', default='promolib_check')
    parser.add_argument('--extract', '-e', default='bannerid_extractor')
    parser.add_argument('--target', '-T', default='default_target')
    parser.add_argument('--add_features', '-a', default=None)
    args = parser.parse_args()

    # ``from`` is a keyword, so the attribute has to be read via getattr.
    from_ = getattr(args, 'from')
    to_ = getattr(args, 'to')

    if not from_ and not to_:
        print('Please specify at least one date')
        sys.exit(1)
    # When only one bound is given, use it for both ends of the range.
    if not from_:
        from_ = to_
    if not to_:
        to_ = from_

    cluster = clusters.Hahn(
        pool='search-research_{}'.format(getpass.getuser())
    ).env(
        templates=dict(job_root='home/atom/pools/{}_{}_{}_{}_{}_{}'.format(
            args.check,
            args.target,
            args.add_features or 'no_features',
            args.extract,
            from_,
            to_
        ))
    )

    from_ = datetime.datetime.strptime(from_, '%Y-%m-%d').date()
    to_ = datetime.datetime.strptime(to_, '%Y-%m-%d').date()

    # The option values name module-level callables in this file.
    check = globals()[args.check]
    target = globals()[args.target]
    extract = globals()[args.extract]
    add_features = None
    if args.add_features:
        add_features = globals()[args.add_features]

    pool_maker = PoolMaker(
        check=check, target=target, add_features=add_features, extract_fields=extract
    )

    dr = date_range(from_, to_)
    nonempty = get_yt_exists(cluster.driver.client)
    tables = []
    for date in dr:
        table = 'home/personalization/atom/fml/conveyor/pool/{}/all'.format(
            date
        )
        if not nonempty('//' + table):
            print('error: table {} does not exist.'.format(table))
        else:
            tables.append(table)

    tables = sorted(tables)
    if len(tables) <= 3:
        print('error: not enough days to split to learn/test')
        sys.exit(1)

    # With at least four tables the test share is never empty.
    th = len(tables) // 4
    learn, test = tables[:-th], tables[-th:]
    _fill_pool(cluster, learn, pool_maker, 'learn')
    _fill_pool(cluster, test, pool_maker, 'test')

    job = cluster.job()
    # concat_learn_and_test(job)
    make_stats_table(job, 'learn')
    make_stats_table(job, 'test')

    job.run()

# Run the pool builder only when this file is executed as a script.
if __name__ == "__main__":
    main()
