#!/usr/bin/env python
# -*- coding: utf-8 -*-

# https://st.yandex-team.ru/EXPERIMENTS-21325

from nile.api.v1 import (
    Record,
    files,
    clusters,
    cli,
    with_hints,
    filters as nf,
    aggregators as na,
    extractors as ne,
    statface as ns #obligatory for Statface
)
from qb2.api.v1 import (
    QB2,
    resources as sr,
    extractors as se,
    filters as sf,
    filters as qf
)

import os #obligatory for Statface
import sys #obligatory for Statface
import re #obligatory for Statface
import argparse #obligatory for Statface
import getpass #obligatory for Statface
import datetime
import time
import re

import uatraits
import random
import itertools

# Remote path of the ``export_page`` dictionary.  make_job attaches it to the
# parse_chevent_coll map as ``files.RemoteFile(EXPORT_PAGE)``; parse_chevent_coll
# then (presumably via that attached file) loads it with sr.json('export_page').
EXPORT_PAGE = "//statbox/statbox-dict-last/export_page"

@with_hints(
    output_schema=dict(
        uid=str,
        testid=str
    )
)
def myMap_redir(recs):
    """Emit a (uid, testid) record for every requested experiment a user hit.

    Each input record is read for:
      * ``tids``      -- ``;``-separated experiment entries whose first
                         comma-separated field is the test id, e.g.
                         ``76083,0,73;63208,0,63``;
      * ``testids``   -- comma-separated test ids requested for this job
                         (spaces are ignored);
      * ``yandexuid`` -- the user id; records where it is ``""`` are dropped.

    Only experiments whose id appears in ``testids`` are emitted.  Duplicate
    (uid, testid) pairs are not removed here.  Records whose fields are
    missing or cannot be parsed are skipped.
    """
    for rec in recs:
        # Best effort: skip malformed records.  Only the parsing is guarded,
        # and with ``except Exception`` instead of a bare ``except``, so
        # GeneratorExit / KeyboardInterrupt raised at a ``yield`` propagate
        # instead of being swallowed.
        try:
            tids = rec.tids
            wanted = set(rec.testids.replace(' ', '').split(','))
            if len(tids) < 1:
                continue
            exps = [x.split(',', 1)[0] for x in tids.split(';') if ',' in x]
        except Exception:
            continue

        matched = [testid for testid in exps if testid in wanted]
        if not matched:
            continue

        # yandexuid is only read once a requested experiment was found.
        try:
            uid = rec.yandexuid
        except Exception:
            continue
        if uid == "":
            continue

        for testid in matched:
            yield Record(uid=uid, testid=testid)


@with_hints(
    output_schema=dict(
        page_id=str,
        impid=str,
        click=int,
        show=int,
        cost=float,
        uid=str
    )
)
def parse_chevent_coll(recs):
    """Turn event rows for selected pages into per-event show/click records.

    The relevant page ids are the keys of the ``export_page`` JSON resource
    whose entry has a ``Name`` containing ``collections.yandex``; rows for
    any other ``pageid`` are dropped.  A row with ``countertype == "2"`` is
    counted as a click, every other row as a show, and ``cost`` is non-zero
    only for clicks.

    Yields records with ``page_id``, ``impid``, ``click``, ``show``,
    ``cost`` and ``uid`` (``str(uniqid)``).
    """
    export_page = sr.get(sr.json('export_page'))
    # A set keeps the per-row membership test O(1).
    collection_pages = set(
        page_id for page_id, page in export_page.iteritems()
        if "Name" in page and "collections.yandex" in page["Name"]
    )

    for rec in recs:
        pageid = rec.pageid
        if pageid not in collection_pages:
            continue

        # countertype "2" marks a click; anything else is treated as a show.
        is_click = rec.countertype == "2"
        click = 1 if is_click else 0
        show = 0 if is_click else 1

        # NOTE(review): 30 and 1000000 are hard-coded scale factors applied to
        # eventcost (currency rate / micro-units?) -- confirm with log owners.
        eventcost = float(rec.eventcost)
        cost = click * eventcost * 30 / 1000000
        yield Record(
            page_id=pageid,
            impid=rec.impid,
            click=click,
            show=show,
            cost=cost,
            uid=str(rec.uniqid)
        )


@with_hints(
    output_schema=dict(
        page_id=str,
        impid=str,
        click=int,
        show=int,
        cost=float,
        testid=str
    )
)
def add_totals(recs):
    """Fan every row out into itself plus its ``_total_`` roll-up copies.

    Each input row is emitted four times, with ``(page_id, impid)`` set to
    every combination of {own page_id, '_total_'} x {own impid, '_total_'}
    (own/own first, '_total_'/'_total_' last).  ``click``, ``show``,
    ``cost`` and ``testid`` are copied unchanged, so grouping on the keys
    afterwards (as make_job does) also yields per-page, per-impid and grand
    totals.
    """
    total_key = '_total_'
    for row in recs:
        page_keys = (row.get('page_id'), total_key)
        imp_keys = (row.get('impid'), total_key)
        for page_key in page_keys:
            for imp_key in imp_keys:
                yield Record(
                    page_id=page_key,
                    impid=imp_key,
                    click=row.click,
                    show=row.show,
                    cost=row.cost,
                    testid=row.testid,
                )


@with_hints(
    output_schema=dict(
        page_id=str,
        impid=str,
        click=int,
        show=int,
        cost=float,
        bucket=int,
        testid=str
    )
)
def gen_bucket(recs):
    """Tag every row with a uniformly random bucket number in [0, 100).

    All other columns pass through unchanged.  The bucket comes from the
    module-level ``random`` generator, which nothing in this module seeds,
    so assignments differ between runs.
    """
    for row in recs:
        bucket_no = random.randrange(100)
        yield Record(
            page_id=row.page_id,
            impid=row.impid,
            click=row.click,
            show=row.show,
            cost=row.cost,
            bucket=bucket_no,
            testid=row.testid,
        )

def make_jobroot_from_table(s):
    """Return everything before the last ``/`` of a table path.

    A path containing no ``/`` yields ``''``.
    """
    parent, _sep, _leaf = s.rpartition('/')
    return parent

def parse_from_path(s):
    """Split a path at its last ``/`` into ``[parent, leaf]``.

    With no ``/`` in *s* the parent is ``''`` and the whole string is the
    leaf.  A list (not a tuple) is returned.
    """
    parent, _sep, leaf = s.rpartition('/')
    return [parent, leaf]


@cli.statinfra_job(options=[cli.Option('test_ids', default='?')])
def make_job(job, nirvana, options):
    """Build the job computing show/click/cost metrics per experiment.

    Inputs (date range taken from the ``@dates`` path template):
      * ``//logs/collections-redir-log/1d`` -- yandexuid values seen in each
        requested experiment (see myMap_redir);
      * ``//logs/bs-chevent-log/1d`` -- rows with placeid 542 and fraudbits
        0, restricted to collection pages by parse_chevent_coll.
    The two are inner-joined on ``uid``.

    Outputs:
      * ``nirvana.output_tables[0]`` -- show/click/cost sums per
        (page_id, impid, testid), including ``_total_`` roll-ups, sorted by
        those keys;
      * ``$tmp_files/<output table name>_buckets`` -- summed cost
        (``money``) per (testid, bucket), where each
        (page_id, impid, testid, uid) group got one of 100 random buckets.

    Options:
      * ``test_ids`` -- comma-separated experiment ids (required).

    Raises:
        ValueError: if ``test_ids`` was left at its ``'?'`` default.
    """
    output_table = nirvana.output_tables[0]
    output_folder, table_name = parse_from_path(output_table)

    job = job.env(
        yt_spec_defaults=dict(
            pool_trees=["physical"],
            tentative_pool_trees=["cloud"]
        ),
        templates=dict(
            job_root=output_folder,
            tmp_files=output_folder + "/temporary"
        )
    )

    testids = options.test_ids
    if testids == "?":
        # '?' is the option default: no experiments were requested.  Fail
        # fast instead of building a job that filters on the literal '?'.
        raise ValueError('wrong testids: the test_ids option is required')

    # One row per (testid, uid) for users seen in a requested experiment.
    uids = job.table('//logs/collections-redir-log/1d/@dates') \
        .qb2(
            log='redir-log',
            fields=['yandexuid', se.log_field('tids').allow_override()],
            filters=[sf.default_filtering('redir-log'), sf.defined('tids', 'yandexuid')],
            mode='yamr_lines'
        ) \
        .project(ne.all(), testids=ne.const(testids)) \
        .map(myMap_redir) \
        .unique('testid', 'uid')

    # Per-event show/click/cost rows for collection pages.
    events = job.table('//logs/bs-chevent-log/1d/@dates') \
        .filter(
            nf.and_(
                nf.equals('placeid', '542'),
                nf.equals('fraudbits', '0')
            )
        ) \
        .map(
            parse_chevent_coll,
            files=[files.RemoteFile(EXPORT_PAGE)],
            memory_limit=3 * 1024,
            intensity='default'
        )

    joined = events.join(uids, by='uid')

    # Sum per user first, then give each user-level group a random bucket
    # and sum per bucket.
    result = joined.groupby('page_id', 'impid', 'testid', 'uid') \
        .aggregate(
            show=na.sum('show'),
            click=na.sum('click'),
            cost=na.sum('cost')
        ) \
        .map(gen_bucket) \
        .groupby('page_id', 'impid', 'testid', 'bucket') \
        .aggregate(
            show=na.sum('show'),
            click=na.sum('click'),
            cost=na.sum('cost')
        )

    result.map(add_totals) \
        .groupby('page_id', 'impid', 'testid') \
        .aggregate(
            show=na.sum('show'),
            click=na.sum('click'),
            cost=na.sum('cost')
        ) \
        .sort('page_id', 'impid', 'testid') \
        .put(output_table)

    # myMap_redir already keeps only the requested testids, so a single
    # group-by covers all of them: one pass over `result`, and an id repeated
    # in test_ids cannot duplicate rows.
    result.groupby('testid', 'bucket') \
        .aggregate(money=na.sum('cost')) \
        .put("$tmp_files/" + table_name + "_buckets")

    return job


if __name__ == "__main__":
    # Hand control to nile's CLI runner (presumably dispatching to the
    # @cli.statinfra_job-decorated make_job).
    cli.run()
