# -*- coding: utf-8 -*-

import os
from argparse import ArgumentParser
from yabs.tabtools import mr_do_aggregate
from yabs.tabtools import Sum, Count, StatAggregator, Grep, Mapper, BasicMapper
from yabs.matrixnet import Matrixnet
from yabs.vwlib.mappers import VWApplyMapper
from yabs.logconfig import get_logs_regexp_time
from yabs.ml.dump import fetch_latest
from yabs.tabutils import read_ts_table


class AdjustedRealCostMapper(BasicMapper):
    """Attach per-domain adjusted cost columns to each record.

    On construction, reads the daily stats tables matching ``logs_regexp``
    for the period [start, finish] and builds, for every day, a mapping
    DomainID -> RealCost / <kind> (see ``value``).

    When called on records, for each record it looks up that mapping by the
    record's ``Day`` and ``DomainID`` and sets:

    * ``<kind>AdjustedRealCost`` = looked-up value * record.<kind>
    * ``<kind>AdjustedCost``     = record.Cost if looked-up value > 0 else 0

    Records whose ``Day`` has no stats table use the latest available day;
    DomainIDs missing from a day's stats get value 0.

    Args:
        start, finish: period passed to ``get_logs_regexp_time``.

    Keyword args:
        kind: stat column to normalise by (default 'D120Click'; the only
            supported value).
        max_value: ratios above this are discarded (replaced by 0).

    Raises:
        ValueError: unsupported ``kind``, or no stats tables in the period.
    """

    def __init__(self, start, finish, **kwargs):
        super(AdjustedRealCostMapper, self).__init__()
        self.start = start
        self.finish = finish
        self.field = kwargs.get('kind', 'D120Click')
        self.adjusted_real_cost_field = self.field + 'AdjustedRealCost'
        self.adjusted_cost_field = self.field + 'AdjustedCost'
        self.max_value = kwargs.get('max_value', 1000 * 1e6)

        if self.field == 'D120Click':
            self.logs_regexp = r'stat/ClickActionStat/\d{6}/ClickActionStat(?P<TIME>%Y%m%d)'
        else:
            # Raising a bare string is a TypeError in Python 2.6+; raise a
            # real exception so the intended message reaches the caller.
            raise ValueError('Unsupported field: %s' % self.field)

        # day ('%Y%m%d') -> {DomainID: value}
        self.stats = dict()
        stats_tables = get_logs_regexp_time(self.logs_regexp, self.start, self.finish)
        for t in stats_tables:
            data = {r.DomainID: self.value(r) for r in read_ts_table(t['name'])}
            day = t['datetime'].strftime('%Y%m%d')
            self.stats[day] = data

        if not self.stats:
            raise ValueError('No stats tables found for %s between %s and %s'
                             % (self.field, self.start, self.finish))

        # '%Y%m%d' strings sort chronologically, so max() is the latest day.
        # Derived from the collected keys rather than re-iterating
        # ``stats_tables`` (which may be a one-shot iterator).
        self.latest_day = max(self.stats)

    def __call__(self, records):
        """Yield each record with the two adjusted-cost fields set."""
        # Loop-invariant fallback for days without their own stats table.
        latest_stats = self.stats[self.latest_day]
        for r in records:
            adjusted_value = self.stats.get(r.Day, latest_stats).get(r.DomainID, 0)
            setattr(r, self.adjusted_real_cost_field, adjusted_value * getattr(r, self.field))
            setattr(r, self.adjusted_cost_field, r.Cost if adjusted_value > 0 else 0)
            yield r

    def value(self, r):
        """Return RealCost / <kind> for a stats row.

        Returns 0 when <kind> is at most 1 or when the ratio exceeds
        ``max_value``.
        """
        stat = getattr(r, self.field)
        if stat <= 1:
            return 0
        # float() avoids Python 2 floor division when both columns are ints;
        # the output columns are declared as float in setFormat.
        value = float(r.RealCost) / stat
        if value > self.max_value:
            return 0
        return value

    @staticmethod
    def download_table(t, recs):
        """Append all rows of table ``t`` to the list ``recs``."""
        recs.extend(read_ts_table(t))

    def setFormat(self, fmt):
        """Extend the output format with the two adjusted-cost float columns."""
        super(AdjustedRealCostMapper, self).setFormat(fmt)
        self._output_format.addField(self.adjusted_real_cost_field, float)
        self._output_format.addField(self.adjusted_cost_field, float)


class Tmp(object):
    """Plain object that only carries a ``name`` attribute.

    Not referenced in the visible part of this file.
    """

    def __init__(self, name):
        super(Tmp, self).__init__()
        # The single piece of state held by the object.
        self.name = name


def stats(src_tables, dst_table):
    """Aggregate raw joined-log rows into per-domain statistics.

    Rows are filtered by ``PageID == 2 and FraudBits == 0``, enriched with
    derived columns by the premap expressions below, then aggregated over
    (ExperimentBits, Day, TypeID, SimDistance, DomainID).

    Args:
        src_tables: names of the tables to read.
        dst_table: name of the single output table.
    """
    group_keys = ['ExperimentBits', 'Day', 'TypeID', 'SimDistance', 'DomainID']

    # Count('Shows') counts rows; every other reducer is Sum(input, output).
    sum_columns = (
        ('RealCost', 'RealCost'),
        ('EventCost', 'Cost'),
        ('IsClick', 'Clicks'),
        ('Rank', 'ClicksPredicted'),
        ('HasCounter', 'VClicks'),
        ('LClick', 'LClicks'),
        ('SClick', 'SClicks'),
        ('MaxDepth', 'Depth'),
        ('MaxDuration', 'Duration'),
        ('PAPC', 'PAPC'),
        ('VCost', 'VCost'),
        ('D120Click', 'D120Click'),
    )
    reducers = [Count('Shows')]
    reducers.extend(Sum(column, alias) for column, alias in sum_columns)
    aggregator = StatAggregator(reducers=reducers, keys=group_keys)

    # Evaluated in order: later expressions read columns set by earlier ones
    # (e.g. RealCost reads IsClick). IsClick = CounterType - 1 presumably
    # assumes CounterType is 1 or 2 -- TODO confirm against the log schema.
    # Day is derived with datetime.fromtimestamp, i.e. the local timezone of
    # the executing process.
    row_filter = Grep('r.PageID==2 and r.FraudBits == 0')
    derived_columns = [Mapper(expression) for expression in (
        'r.Day = datetime.datetime.fromtimestamp(int(r.ShowTime)).strftime("%Y%m%d")',
        'r.IsClick = r.CounterType - 1',
        'r.RealCost = r.RealCost if r.IsClick else 0',
        'r.SClick = int(r.MaxDepth > 1 and r.CounterType == 2) if r.HasCounter else 0',
        'r.LClick = r.LClick if r.HasCounter else 0',
        'r.MaxDepth = r.MaxDepth if r.HasCounter else 0',
        'r.MaxDuration = r.MaxDuration if r.HasCounter else 0',
        'r.D120Click = int(r.MaxDuration > 120) if r.HasCounter else 0',
        'r.VCost = r.EventCost if r.HasCounter else 0',
        'r.PAPC = r.PAPC if r.HasCounter else 0',
    )]

    mr_do_aggregate(
        aggregator=aggregator,
        premap=[row_filter] + derived_columns,
        # Presumably executed before the expressions so `datetime` resolves.
        begin='import datetime; import math',
        src_tables=src_tables,
        dst_tables=[dst_table],
    )


def aggregate(src_table, dst_table, start=None, finish=None):
    """Collapse per-domain stats over DomainID, adding adjusted-cost columns.

    Before aggregation each row goes through ``AdjustedRealCostMapper``
    (kind='D120Click'), which adds D120ClickAdjustedRealCost and
    D120ClickAdjustedCost. All metric columns are then summed over
    (ExperimentBits, Day, TypeID, SimDistance).

    Args:
        src_table: per-domain table (in ``__main__`` this is the output of
            ``stats``).
        dst_table: name of the output table.
        start, finish: period used by AdjustedRealCostMapper to pick its
            ClickActionStat tables. When omitted they fall back to the
            module-level ``args`` parsed in ``__main__`` (the original
            behaviour). Pass them explicitly when calling from another
            module, where ``args`` does not exist.
    """
    # Backward-compatible fallback to the script's parsed CLI arguments.
    if start is None:
        start = args.start
    if finish is None:
        finish = args.finish

    mr_do_aggregate(
        aggregator=StatAggregator(
            reducers=[
                Sum('Shows', 'Shows'),
                Sum('RealCost', 'RealCost'),
                Sum('Cost', 'Cost'),
                Sum('Clicks', 'Clicks'),
                Sum('ClicksPredicted', 'ClicksPredicted'),
                Sum('VClicks', 'VClicks'),
                Sum('LClicks', 'LClicks'),
                Sum('SClicks', 'SClicks'),
                Sum('Depth', 'Depth'),
                Sum('Duration', 'Duration'),
                Sum('PAPC', 'PAPC'),
                Sum('VCost', 'VCost'),
                Sum('D120Click', 'D120Click'),
                # Columns added by the AdjustedRealCostMapper premap below.
                Sum('D120ClickAdjustedRealCost', 'D120ClickAdjustedRealCost'),
                Sum('D120ClickAdjustedCost', 'D120ClickAdjustedCost'),
            ],
            keys=['ExperimentBits', 'Day', 'TypeID', 'SimDistance']
        ),
        premap=[
            AdjustedRealCostMapper(start, finish, kind='D120Click')
        ],
        src_tables=[src_table],
        dst_tables=[dst_table]
    )


if __name__ == '__main__':
    parser = ArgumentParser(
        description='Aggregate JoinedEFHFactors logs into experiment/day/slot/simdistance stats.')
    parser.add_argument('-s', '--start', required=True, help='start of the period')
    parser.add_argument('-f', '--finish', required=True, help='end of the period')
    parser.add_argument('-d', '--dst', required=True, help='prefix for the output tables')
    # Kept module-global: aggregate() falls back to it for start/finish.
    args = parser.parse_args()

    tables = get_logs_regexp_time(r'yabs-log/\d{6}/JoinedEFHFactors(?P<TIME>%Y%m%d%H)', args.start, args.finish)
    tables = sorted(t['name'] for t in tables)
    if not tables:
        # Fail early instead of launching MR jobs on an empty input list.
        parser.error('no JoinedEFHFactors tables found between %s and %s' % (args.start, args.finish))

    # output tables
    aggr_domain_table = '%s.%s_%s.exp_day_slot_simdistance_domain' % (args.dst, args.start, args.finish)
    aggr_table = '%s.%s_%s.exp_day_slot_simdistance' % (args.dst, args.start, args.finish)

    # collect basic stats per domain
    stats(tables, aggr_domain_table)

    # collapse over domains
    aggregate(aggr_domain_table, aggr_table)
