# -*- coding: utf-8 -*-

from yabs.tabtools import mr_do_aggregate, Sum, Count, StatAggregator, Grep, Mapper, BasicMapper
from yabs.tabutils import read_ts_table
from yabs.logconfig import get_logs_regexp_time
from yabs.tabutils import TemporaryTableWithMeta


class CounterIDMapper(BasicMapper):
    """Mapper that adds an integer ``CounterID`` field to every record.

    ``CountersID`` is read as the text of a Python list literal. The first
    element of that list becomes ``CounterID``; when the field is empty or
    the list has no elements, ``CounterID`` is set to 0.
    """

    def __init__(self):
        super(CounterIDMapper, self).__init__()

    def __call__(self, records):
        """Yield each record of ``records`` with ``CounterID`` filled in.

        Raises ``ValueError`` / ``SyntaxError`` (from ``ast.literal_eval``)
        when ``CountersID`` is not a valid Python literal.
        """
        # Function-scope import keeps this class self-contained.
        import ast

        for r in records:
            raw = (r.CountersID or "").strip()
            # SECURITY: this value comes from log data. The previous code
            # passed it to eval(), which would execute any expression found
            # in the field. literal_eval accepts only literals (lists,
            # tuples, numbers, strings, ...) and never runs code.
            counter_ids = ast.literal_eval(raw) if raw else []
            r.CounterID = counter_ids[0] if counter_ids else 0

            yield r

    def setFormat(self, fmt):
        """Pass ``fmt`` to the base class, then register the ``CounterID`` int field."""
        super(CounterIDMapper, self).setFormat(fmt)
        self._output_format.addField('CounterID', int)


class AvgD120Mapper(BasicMapper):
    """Mapper that attaches a per-counter average cost to every record.

    On construction the whole ``action_stat_table`` is read and, for every
    row with ``D120Click > 0``, the ratio ``RealCost / D120Click`` is stored
    under the row's ``CounterID``. Each processed record then receives that
    ratio in ``AvgD120RealCost`` (0.0 when its ``CounterID`` has no entry).
    """

    def __init__(self, action_stat_table):
        """Build the CounterID -> average cost lookup.

        :param action_stat_table: name of the table passed to
            ``read_ts_table``; its rows must provide ``CounterID``,
            ``D120Click`` and ``RealCost``.
        """
        super(AvgD120Mapper, self).__init__()
        self._action_stat = dict()

        for r in read_ts_table(action_stat_table):
            counter_id = r['CounterID']
            d120click = r['D120Click']
            real_cost = r['RealCost']
            # Rows without D120 clicks are skipped to avoid dividing by zero.
            if d120click > 0:
                # float() forces true division: under Python 2 the plain
                # `/` truncates when both operands are integers, which would
                # silently round the average down.
                self._action_stat[counter_id] = float(real_cost) / d120click

    def __call__(self, records):
        """Yield each record of ``records`` with ``AvgD120RealCost`` set."""
        # NOTE(review): the lookup relies on r.CounterID having the same type
        # as the CounterID values returned by read_ts_table - confirm.
        for r in records:
            r.AvgD120RealCost = self._action_stat.get(r.CounterID, 0.0)

            yield r

    def setFormat(self, fmt):
        """Pass ``fmt`` to the base class, then register the ``AvgD120RealCost`` float field."""
        super(AvgD120Mapper, self).setFormat(fmt)
        self._output_format.addField('AvgD120RealCost', float)


class D120AdjustedMapper(BasicMapper):
    """Mapper that derives the D120-adjusted cost and income of a record.

    For every record:
      * ``D120AdjustedRealCost = D120 * AvgD120RealCost``
      * ``D120AdjustedIncome   = D120 * AvgD120RealCost - VCost``
    """

    def __init__(self):
        super(D120AdjustedMapper, self).__init__()

    def __call__(self, records):
        """Yield each record of ``records`` with both adjusted fields set."""
        for rec in records:
            adjusted_cost = rec.D120 * rec.AvgD120RealCost
            rec.D120AdjustedRealCost = adjusted_cost
            rec.D120AdjustedIncome = adjusted_cost - rec.VCost
            yield rec

    def setFormat(self, fmt):
        """Pass ``fmt`` to the base class, then register the two float fields."""
        super(D120AdjustedMapper, self).setFormat(fmt)
        for field_name in ('D120AdjustedRealCost', 'D120AdjustedIncome'):
            self._output_format.addField(field_name, float)


if __name__ == '__main__':

    # Date range of the processed logs. The destination table name below is
    # derived from it, so the two cannot drift apart.
    begin_day = '20160409'
    end_day = '20160417'
    dst_table = 'users/stys/broadmatch/stats.%s_%s.t' % (begin_day, end_day)

    # Raw string: `\d` is a regexp escape, not a string escape.
    tables = get_logs_regexp_time(r'yabs-log/\d{6}/JoinedEFHFactors(?P<TIME>%Y%m%d%H)', begin_day, end_day)
    tables = sorted([t['name'] for t in tables])

    # List the source tables that are about to be processed.
    for t in tables:
        # Parenthesised form prints the same under Python 2 and also parses
        # under Python 3.
        print(t)

    with TemporaryTableWithMeta() as tmp:
        # Stage 1: aggregate the source logs per
        # (ExperimentBits, Day, TypeID, SimDistance, CounterID) into a
        # temporary table.
        mr_do_aggregate(
            aggregator=StatAggregator(
                reducers=[
                    Count('Shows'),
                    Sum('RealCost', 'RealCost'),
                    Sum('EventCost', 'Cost'),
                    Sum('IsClick', 'Clicks'),
                    Sum('Rank', 'ClicksPredicted'),
                    Sum('HasCounter', 'VClicks'),
                    Sum('LClick', 'LClicks'),
                    Sum('SClick', 'SClicks'),
                    Sum('MaxDepth', 'Depth'),
                    Sum('MaxDuration', 'Duration'),
                    Sum('PAPC', 'PAPC'),
                    Sum('VCost', 'VCost'),
                    Sum('D120', 'D120')
                ],
                keys=['ExperimentBits', 'Day', 'TypeID', 'SimDistance', 'CounterID']
            ),
            premap=[
                # Filter first: the predicate reads only PageID and FraudBits,
                # which none of the mappers below modify, so the surviving
                # records are the same while discarded ones skip all the
                # per-record work.
                Grep('r.PageID == 2 and r.FraudBits == 0'),
                Mapper('r.Day = datetime.datetime.fromtimestamp(int(r.ShowTime)).strftime("%Y%m%d")'),
                # NOTE(review): presumably CounterType is 1 for a show and 2
                # for a click, making IsClick a 0/1 flag - confirm.
                Mapper('r.IsClick = r.CounterType - 1'),
                Mapper('r.RealCost = r.RealCost if r.IsClick else 0'),
                # The mappers below zero their metric for records without a
                # counter (HasCounter is falsy). Order matters: SClick reads
                # MaxDepth before it is zeroed, and D120 reads MaxDuration
                # after it is zeroed.
                Mapper('r.SClick = int(r.MaxDepth > 1 and r.CounterType == 2) if r.HasCounter else 0'),
                Mapper('r.LClick = r.LClick if r.HasCounter else 0'),
                Mapper('r.MaxDepth = r.MaxDepth if r.HasCounter else 0'),
                Mapper('r.MaxDuration = r.MaxDuration if r.HasCounter else 0'),
                Mapper('r.D120 = int(r.MaxDuration > 120) if r.HasCounter else 0'),
                Mapper('r.VCost = r.EventCost if r.HasCounter else 0'),
                Mapper('r.PAPC = r.PAPC if r.HasCounter else 0'),
                CounterIDMapper()
            ],
            # `datetime` is used by the Day mapper expression above.
            begin='import datetime',
            src_tables=tables,
            dst_tables=[tmp.name]
        )

        # Stage 2: attach the per-counter average cost, derive the
        # D120-adjusted metrics and re-aggregate without CounterID in the key.
        mr_do_aggregate(
            aggregator=StatAggregator(
                reducers=[
                    Sum('Shows', 'Shows'),
                    Sum('RealCost', 'RealCost'),
                    Sum('Cost', 'Cost'),
                    Sum('Clicks', 'Clicks'),
                    Sum('ClicksPredicted', 'ClicksPredicted'),
                    Sum('VClicks', 'VClicks'),
                    Sum('LClicks', 'LClicks'),
                    Sum('SClicks', 'SClicks'),
                    Sum('Depth', 'Depth'),
                    Sum('Duration', 'Duration'),
                    Sum('PAPC', 'PAPC'),
                    Sum('VCost', 'VCost'),
                    Sum('D120', 'D120'),
                    Sum('D120AdjustedIncome', 'D120AdjustedIncome'),
                    Sum('D120AdjustedRealCost', 'D120AdjustedRealCost')
                ],
                keys=['ExperimentBits', 'Day', 'TypeID', 'SimDistance']
            ),
            premap=[
                # AvgD120Mapper must run first: D120AdjustedMapper reads the
                # AvgD120RealCost field it adds.
                AvgD120Mapper(action_stat_table='users/stys/broadmatch/stat/ActionStat/201602/ActionStat20160229'),
                D120AdjustedMapper()
            ],
            src_tables=[tmp.name],
            dst_tables=[dst_table]
        )