#!/usr/bin/env python
# -*- coding: utf-8 -*-

import argparse
import ast
import collections
from datetime import datetime
from itertools import repeat, izip
import multiprocessing
import urlparse

from nile.api.v1 import (
    filters as nf,
    aggregators as na,
    extractors as ne,
    clusters,
    Record,
)
import libra
import nile
import pandas as pd
import yt.wrapper as yt
import logging
logging.basicConfig(format='[%(asctime)s] %(filename)s[LINE:%(lineno)d] %(levelname)-8s %(message)s',
                    level=logging.ERROR)

# Number of worker processes in the multiprocessing pool (one date per task).
PROCESSES = 2
# Prefix of the job name; the per-date suffix ': tmp_<date>' is appended to it.
JOB_NAME = 'HOME TRAIN TEST'
# Root directory holding the per-date '<date>/total_events' tables that are
# joined onto the parsed session records.
REALSHOW_ROOT = '//home/search-research/kaminsky/EXPERIMENTS/RESULT'


def firstreduce(groups):
    """Reducer that flattens parsed sessions into one record per portal block.

    For every ``(key, recs)`` group the records are parsed with
    ``libra.ParseSession``.  Each request that is a
    ``TPortalRequestProperties`` with ``MContent == 'touch'`` produces one
    ``Record`` per main block whose main result is a ``TPortalResult``.

    Args:
        groups: iterable of ``(key, recs)`` pairs; ``key.key`` is the user id
            and ``recs`` the raw session records of that user.

    Yields:
        Record with the fields ``date``, ``uid``, ``reqid``, ``region``,
        ``ts``, ``m_content``, ``path``, ``parent``, ``position``, ``height``,
        ``clicks`` (list of dicts with ``url``/``path``/``dynamic``),
        ``num_clicks`` and ``enabled_test_ids``.
    """
    for key, recs in groups:
        uid = key.key
        # Only keys starting with 'y' are processed (presumably yandexuid
        # keys -- TODO confirm).  An empty or missing key is skipped instead
        # of raising IndexError on ``uid[0]``.
        if not uid or uid[0] != 'y':
            continue
        try:
            requests = libra.ParseSession(recs, "blockstat.dict")
        except Exception:
            # Deliberate best effort: a session that cannot be parsed is
            # dropped.  ``Exception`` (not a bare ``except``) keeps
            # KeyboardInterrupt/SystemExit propagating.
            continue
        for request in requests:
            if not request.IsA("TPortalRequestProperties"):
                continue
            # Cheap filter first: nothing else is needed for skipped requests.
            m_content = request.MContent
            if m_content != 'touch':
                continue
            ts = request.Timestamp
            # Calendar day of the request in the local time zone of the
            # machine running the reducer, formatted as 'YYYY-MM-DD'.
            req_day = datetime.fromtimestamp(ts).date().isoformat()
            reqid = request.ReqId
            region = request.UserRegion
            enabled_test_ids = [
                test.TestID for test in request.GetEnabledTestInfo()
            ]
            for block in request.GetMainBlocks():
                main_res = block.GetMainResult()
                if not main_res.IsA('TPortalResult'):
                    continue
                block_events = main_res.GetOwnEvents()
                clicks = [
                    {
                        'url': event.Url,
                        'path': event.Path,
                        'dynamic': int(event.IsDynamic),
                    }
                    for event in block_events
                    if event.IsA('TClick')
                ]
                yield Record(
                    date=req_day,
                    uid=uid,
                    reqid=reqid,
                    region=region,
                    ts=ts,
                    m_content=m_content,
                    path=main_res.Path,
                    parent=main_res.ParentPath,
                    position=main_res.Position,
                    height=block.Height,
                    clicks=clicks,
                    # NOTE(review): this counts ALL own events of the result,
                    # not only the TClick ones collected in ``clicks``.  Kept
                    # as is to preserve the output data -- confirm whether
                    # ``len(clicks)`` was intended.
                    num_clicks=len(block_events),
                    enabled_test_ids=enabled_test_ids,
                )


def argument_parser():
    """Parse the script's command-line options and return the namespace.

    Recognised flags (all optional strings, ``None`` when omitted):
    ``-p`` -> ``pool``, ``-r`` -> ``rpath``, ``-u`` -> ``usessions``,
    ``-d1`` -> ``date1``, ``-d2`` -> ``date2``.

    Returns:
        argparse.Namespace populated from ``sys.argv``.
    """
    # (flag, destination attribute, help text), in the order shown by --help.
    option_specs = (
        ('-p', 'pool', 'you hahn pool'),
        ('-r', 'rpath', 'result path'),
        ('-u', 'usessions', 'user sessions path'),
        ('-d1', 'date1', 'date1 format "yyyy-mm-dd"'),
        ('-d2', 'date2', 'date2 format "yyyy-mm-dd"'),
    )
    cli = argparse.ArgumentParser(description='Get parameters')
    for flag, destination, help_text in option_specs:
        cli.add_argument(flag, dest=destination, type=str, help=help_text)
    return cli.parse_args()


def one_process(job_root, pool, sessions_path, date):
    """Build and run the job that processes the sessions of one date.

    The job reduces ``<sessions_path>/<date>/clean`` with ``firstreduce``,
    left-joins the per-(reqid, path) ``realshow`` flag computed from
    ``$realshow_root/<date>/total_events`` and writes the result to
    ``$job_root/<date>``.

    Args:
        job_root: output directory; one table per date is written under it.
        pool: pool name passed to the Hahn cluster.
        sessions_path: root of the per-date user-session tables.
        date: date string used in the input and output table paths.
    """
    template_vars = dict(job_root=job_root, realshow_root=REALSHOW_ROOT)
    cluster = clusters.Hahn(pool=pool).env(templates=template_vars)
    job = cluster.job(JOB_NAME + ': tmp_' + date)

    # Files shipped to the reducer: the libra library and its dictionary.
    reducer_files = [
        nile.files.RemoteFile('//statbox/resources/libra.so'),
        nile.files.RemoteFile('//statbox/statbox-dict-last/blockstat.dict'),
    ]

    sessions = job.table('{}/{}/clean'.format(sessions_path, date))
    events = job.table('$realshow_root/{}/total_events'.format(date))

    # Keep only 'show' events whose path starts with the 'geotouch' component.
    is_geotouch_show = nf.and_(
        nf.equals('event_type', 'show'),
        nf.custom(lambda p: p.split('.')[0] == 'geotouch', 'path'),
    )
    shows = events.filter(is_geotouch_show)

    # Replace the full path by its third dot-separated component, keeping the
    # original in 'old_path' so the realshow marker can still be detected.
    shows = shows.project(
        reqid='reqid',
        yandexuid='yandexuid',
        timestamp='timestamp',
        old_path='path',
        path=ne.custom(lambda p: p.split('.')[2], 'path'),
        joined=ne.const(1),
    )
    shows = shows.project(
        ne.all(),
        realshow=ne.custom(lambda p: int('realshow' in p), 'old_path'),
    )

    # One row per (reqid, path): 1 if any of its show events was a realshow.
    to_join = shows.groupby('reqid', 'path').aggregate(
        realshow=na.max('realshow'),
    )

    parsed = sessions.groupby('key').sort('subkey').reduce(
        firstreduce,
        files=reducer_files,
        memory_limit=32 * 1024,
        intensity='large_data',
    )
    joined = parsed.join(to_join, by=('reqid', 'path'), type='left')
    joined.put('$job_root/' + date)

    job.run()


def one_process_star(all_args):
    """Call ``one_process`` with the packed positional arguments.

    ``Pool.map`` passes exactly one argument to its worker function, so the
    per-date arguments travel as a single sequence and are unpacked here.

    Args:
        all_args: sequence of positional arguments for ``one_process``.

    Returns:
        Whatever ``one_process`` returns.
    """
    outcome = one_process(*all_args)
    return outcome


def main():
    """Process every date in the requested range that has no result yet.

    Reads the command-line options, expands ``date1``..``date2`` into daily
    ``YYYY-MM-DD`` strings, drops the dates already present under the result
    directory and runs ``one_process`` for the remaining ones in a pool of
    ``PROCESSES`` worker processes.
    """
    args = argument_parser()
    job_root = args.rpath
    user_sessions = args.usessions

    dates = [
        day.strftime('%Y-%m-%d')
        for day in pd.date_range(args.date1, args.date2)
    ]

    try:
        yt.config.set_proxy("hahn")
        # Only nodes whose name starts with '2' are treated as finished dates.
        dates_done = {
            name for name in yt.list(job_root) if name.startswith('2')
        }
    except Exception as exc:
        # Best effort: if the result directory cannot be listed (for example
        # it does not exist yet) every requested date is processed.  The
        # failure is reported instead of being swallowed silently.
        logging.error(
            'Cannot list %s (%s); processing all requested dates',
            job_root, exc,
        )
        dates_done = set()

    # Filter instead of a set difference so the dates stay in chronological
    # order and the submission order is deterministic.
    dates_for_refresh = [day for day in dates if day not in dates_done]
    if not dates_for_refresh:
        return

    tasks = [
        (job_root, args.pool, user_sessions, day)
        for day in dates_for_refresh
    ]

    multiprocessing.freeze_support()
    proc_pool = multiprocessing.Pool(processes=PROCESSES)
    try:
        proc_pool.map(one_process_star, tasks)
    finally:
        # ``map`` has returned (or raised), so no result is pending; stop the
        # workers and reap them so no child process is left behind.
        proc_pool.terminate()
        proc_pool.join()


# Run the pipeline only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()
