#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import division

import argparse
import datetime
import json
import os
import sys
import time
from collections import Counter

from nile.api.v1 import (
    clusters,
    filters as nf,
    extractors as ne,
    aggregators as na,
    statface as ns,
    with_hints,
    Record
)
import nile.files as nfi
from qb2.api.v1 import typing as qt, resources as qr
from yql.api.v1.client import YqlClient

from pytils import (
    get_cluster,
    get_driver,
    date_range,
    yt_get_date_from_table as get_date,
    get_dates_from_stat,
    get_stat_headers
)


# Content-hierarchy ("iron branch") table. process_date() projects its JoinKey
# as VideoContentID and left-joins it onto the mapped logs.
IRON_BRANCH_TABLE = '//home/videolog/strm_meta/iron_branch/concat'
# YQL query returning the newest distinct UpdateTime (sorted descending, first
# element) in VideohostingViewsDetailedStat. main() treats the value as a Unix
# timestamp and uses it to decide which processed dates still need pushing.
# NOTE(review): this string is never passed through str.format(), so `{0}`
# reaches YQL literally -- presumably element access in the legacy YQL syntax
# this query uses (note the bracketed table path). TODO confirm before editing.
query = """
use hahn;

select ListReverse(ListSort(LIST(distinct UpdateTime))){0} as LastDate
from [//home/video-hosting-dev/vandermond/VideohostingViewsDetailedStat]
"""


def tryint(s, default=0):
    """Convert ``s`` to int, returning ``default`` when that is impossible.

    Only conversion failures are swallowed: TypeError (None or unsupported
    types), ValueError (malformed strings, NaN) and OverflowError (infinite
    floats). Anything else -- e.g. KeyboardInterrupt -- propagates, unlike
    the previous bare ``except:``.
    """
    try:
        return int(s)
    except (TypeError, ValueError, OverflowError):
        return default


def get_max_prob_value(dct):
    """Return the key of ``dct`` with the largest value, or '' if none exists.

    ``dct`` is fed to ``collections.Counter``, so a mapping is read as
    key -> weight. Ties go to the key encountered first (Counter.most_common
    ordering).

    Returns '' when ``dct`` is None or empty, is not iterable, or holds
    values that cannot be compared with each other.
    """
    try:
        top = Counter(dct).most_common(1)
    except TypeError:
        # Non-iterable input, or weights of mutually incomparable types.
        return ''
    return top[0][0] if top else ''


def repack_dict(d1, d2, values, t='str'):
    """Copy the keys listed in ``values`` from ``d1`` into ``d2`` in place.

    Args:
        d1: source mapping.
        d2: destination mapping; mutated.
        values: iterable of keys to copy.
        t: 'str' copies values as-is, defaulting missing keys to '';
           'int' converts with tryint(), defaulting missing keys to 0.
           Any other value copies nothing.

    The previous version called ``d1.get(v or '')`` -- a misplaced
    parenthesis that looked up the wrong key for falsy ``v`` and yielded
    None, not '', for missing string keys.
    """
    for key in values:
        if t == 'str':
            d2[key] = d1.get(key, '')
        elif t == 'int':
            d2[key] = tryint(d1.get(key, '0'))


def process_region(reg, geobase):
    """Collapse a geobase region id into the region id used for reporting.

    Walks ``region.path`` and returns:
      * the id of the first type-3 (country) node, unless that id is 225;
      * for 225 (presumably Russia), the id of the first type-5 node that
        appears after it on the path, or 225 if there is none;
      * 10000 (presumably "Earth"/unknown) when the lookup raises or the
        path has no type-3 node.

    NOTE(review): assumes ``path`` is ordered root-first -- if it is
    leaf-first, the type-5 node precedes 225 and is never picked. Confirm.
    """
    try:
        region = geobase.region_by_id(reg)
    except Exception:
        return 10000

    inside_russia = False
    for node in region.path:
        if inside_russia and node.type == 5:
            return node.id
        if node.type != 3:
            continue
        if node.id == 225:
            inside_russia = True
        else:
            return node.id

    return 225 if inside_russia else 10000


def get_country(region):
    """Return ``short_name`` of the first type-3 (country) node on ``region.path``.

    Falls back to 'UNK' when the path contains no such node.
    """
    country_names = (node.short_name for node in region.path if node.type == 3)
    return next(country_names, 'UNK')


# Text columns copied from the input record; missing or empty values
# (including None) become ''.
_VHDS_STRING_FIELDS = (
    'VideoSessionID',
    'Service',
    'IP',
    'DetailedDeviceType',
    'BrowserName',
    'UserAgent',
)

# Integer columns parsed with tryint(); missing or unparsable values become 0.
_VHDS_INT_FIELDS = (
    'PageID',
    'DeviceType',
    'VideoCategoryID',
    'EventTime',
    'UniqID',
    'Price',
    'PartnerPrice',
)


def _day_start_timestamp(event_time):
    """Return the Unix timestamp of local midnight on the day of ``event_time``.

    ``event_time`` is a Unix timestamp; the day boundary is taken in the
    process's local time zone. Returns 0 when the value cannot be converted
    (wrong type, or outside the range the platform's time functions accept).
    """
    try:
        day_start = datetime.datetime.fromtimestamp(event_time).replace(
            hour=0, minute=0, second=0, microsecond=0
        )
        # time.mktime() is the portable equivalent of strftime('%s') (a
        # glibc-only directive); both interpret the naive datetime as local.
        return int(time.mktime(day_start.timetuple()))
    except (TypeError, AttributeError, ValueError, OverflowError, OSError):
        # Python 3 reports out-of-range timestamps with OverflowError/OSError
        # rather than ValueError; one bad row must not fail the whole job.
        return 0


@with_hints(
    output_schema=dict(
        VideoContentID=qt.Optional[qt.String],
        VideoSessionID=qt.String,
        Service=qt.String,
        IP=qt.String,
        DetailedDeviceType=qt.String,
        BrowserName=qt.String,
        UserAgent=qt.String,
        PageID=qt.Integer,
        DeviceType=qt.Integer,
        VideoCategoryID=qt.Integer,
        EventTime=qt.Integer,
        UpdateTime=qt.Integer,
        UniqID=qt.UInt64,
        Price=qt.Integer,
        PartnerPrice=qt.Integer,
        RegionID=qt.UInt64,
        IsView=qt.Integer,
        Duration=qt.Integer,
        Hits=qt.Integer,
        ShownHits=qt.Integer,
        UserGender=qt.String,
        UserAgeSegment=qt.String,
    )
)
def vh_detailed_stat_mapper(recs):
    """Normalize raw soc_dem rows into the typed schema declared above.

    For every input record yields one Record in which:
      * VideoContentID is copied as-is (it may be None);
      * text columns default to '' and integer columns go through tryint();
      * UpdateTime is local midnight of EventTime (0 if not convertible);
      * RegionID is collapsed with process_region();
      * IsView is 1 only when the row has a VideoContentID and a truthy
        SourceLog, and Duration is a fixed credit derived from it (0 for
        non-views, 5 when isKal is set, 10 otherwise; units not visible here);
      * UserGender / UserAgeSegment are the most likely keys of the
        'gender' / 'user_age_6s' maps.

    Reads the 'Geobase' resource, presumably provided by the Geobasev6.bin
    file that process_date() attaches to the map operation.
    """
    geobase = qr.get('Geobase')
    for rec in recs:
        result = {}
        result['VideoContentID'] = rec.get('VideoContentID', '')
        for field in _VHDS_STRING_FIELDS:
            result[field] = rec.get(field, '') or ''
        for field in _VHDS_INT_FIELDS:
            result[field] = tryint(rec.get(field, '0'))
        result['UpdateTime'] = _day_start_timestamp(result['EventTime'])
        result['RegionID'] = process_region(
            tryint(rec.get('RegionID', '0')), geobase
        )

        has_content = int(bool(result['VideoContentID']))
        # `and` yields the falsy SourceLog value (0) or has_content (0/1).
        is_view = tryint(rec.get('SourceLog', '0')) and has_content
        is_kal = tryint(rec.get('isKal', '0'))
        result['IsView'] = is_view
        if not is_view:
            result['Duration'] = 0
        elif is_kal:
            result['Duration'] = 5
        else:
            result['Duration'] = 10

        result['Hits'] = tryint(rec.get('Hit', '0'))
        result['ShownHits'] = tryint(rec.get('ShownHit', '0'))
        result['UserGender'] = get_max_prob_value(rec.get('gender', {}))
        result['UserAgeSegment'] = get_max_prob_value(
            rec.get('user_age_6s', {})
        )
        yield Record.from_dict(result)


# ContentTypeID values that qualify a chain element as a "parent".
# NOTE(review): the meaning of each id is not visible in this file -- confirm.
_PARENT_CONTENT_TYPE_IDS = frozenset([2, 4, 6, 14, 32])


def try_get_from_parent(chain, field, t='str'):
    """Return ``field`` of the last parent-typed element of ``chain``.

    ``chain`` is scanned from its last element to its first. The first
    element whose ``ContentTypeID`` is in ``_PARENT_CONTENT_TYPE_IDS`` is
    selected and its ``field`` value is returned.

    Args:
        chain: sequence of dict-like elements, or None.
        field: key to read from the selected element.
        t: 'str' makes the fallback '', any other value makes it 0.

    Returns:
        The selected element's ``field``, or the fallback when ``chain`` is
        None/empty, no element qualifies, or the selected element has no
        ``field``. Elements without a usable ``ContentTypeID`` are skipped
        instead of aborting the whole job.
    """
    default = '' if t == 'str' else 0
    for element in (chain or [])[::-1]:
        try:
            is_parent = element['ContentTypeID'] in _PARENT_CONTENT_TYPE_IDS
        except (IndexError, TypeError, KeyError):
            # Malformed element (missing/unhashable ContentTypeID): skip it.
            continue
        if not is_parent:
            continue
        try:
            return element[field]
        except (IndexError, TypeError, KeyError):
            return default
    return default


def process_date(date, cluster, job_root, debug=False):
    """Build the reduced detailed-stat table for one day of soc_dem logs.

    Reads ``//home/video-hosting/logs/soc_dem/1d/<date>``, normalizes rows
    with ``vh_detailed_stat_mapper``, left-joins the iron-branch hierarchy on
    ``VideoContentID`` and aggregates by (UpdateTime, ParentUUID,
    ContentUUID, PageID, ParentTypeID, ContentTypeID, VideoSessionID).

    Args:
        date: day to process; interpolated into table paths via str.format.
        cluster: nile cluster used to create and run the job.
        job_root: directory that receives the output table(s).
        debug: when true, also write the pre-aggregation stream to
            ``<job_root>/<date>_mapped``.

    Returns:
        Path of the reduced table, ``<job_root>/<date>``.
    """
    source_table = '//home/video-hosting/logs/soc_dem/1d/{}'.format(date)
    # Pre-aggregation snapshot; only written when debug is set.
    mapped_table = '{}/{}_mapped'.format(
        job_root, date
    )
    reduced_table = '{}/{}'.format(
        job_root, date
    )

    job = cluster.job()

    # Hierarchy lookup keyed by VideoContentID (the table's JoinKey).
    # NOTE(review): ContentTypeID and ParentTypeID below are produced by the
    # exact same expression, so they always hold the same value. If
    # ContentTypeID is meant to be the content's own type rather than its
    # parent's, this is a copy-paste bug -- confirm against the table schema.
    iron_branch = job.table(
        IRON_BRANCH_TABLE
    ).project(
        VideoContentID='JoinKey',
        ContentUUID='UUID',
        ContentTypeID=ne.custom(
            lambda x: try_get_from_parent(
                x, 'ContentTypeID', t='int'
            ), 'chain'
        ).add_hints(type=int),
        ParentUUID=ne.custom(
            lambda x: try_get_from_parent(x, 'UUID'), 'chain'
        ).add_hints(type=str),
        ParentTypeID=ne.custom(
            lambda x: try_get_from_parent(
                x, 'ContentTypeID', t='int'
            ), 'chain'
        ).add_hints(type=int),
    )

    # Map raw log rows, then left-join the hierarchy. Rows without a match
    # get null hierarchy columns, which the project() below turns into ''/0.
    stream = job.table(
        source_table
    ).map(
        vh_detailed_stat_mapper,
        files=[
            nfi.StatboxDict('Geobasev6.bin')
        ]
    ).join(
        iron_branch, type='left', by='VideoContentID'
    ).project(
        # Keep all other columns unchanged; the excluded ones are re-added
        # below with falsy values (null included) normalized to ''/0.
        ne.all(exclude=[
            'ContentUUID',
            'ContentTypeID',
            'ParentUUID',
            'ParentTypeID',
            'VideoContentID'
        ]),
        ContentUUID=ne.custom(lambda x: x or '', 'ContentUUID').add_hints(
            type=str
        ),
        VideoContentID=ne.custom(
            lambda x: x or '', 'VideoContentID'
        ).add_hints(
            type=str
        ),
        ParentUUID=ne.custom(lambda x: x or '', 'ParentUUID').add_hints(
            type=str
        ),
        ContentTypeID=ne.custom(lambda x: x or 0, 'ContentTypeID').add_hints(
            type=int
        ),
        ParentTypeID=ne.custom(lambda x: x or 0, 'ParentTypeID').add_hints(
            type=int
        ),
    )

    if debug:
        stream = stream.put(mapped_table)

    # One output row per key: descriptive columns take na.max, while the
    # money and activity metrics (Price, Duration, Hits, ...) are summed.
    stream = stream.groupby(
        'UpdateTime',
        'ParentUUID',
        'ContentUUID',
        'PageID',
        'ParentTypeID',
        'ContentTypeID',
        'VideoSessionID',
    ).aggregate(
        VideoContentID=na.max("VideoContentID"),
        IsView=na.max("IsView"),
        UniqID=na.max("UniqID"),
        Service=na.max("Service"),
        RegionID=na.max("RegionID"),
        UserGender=na.max("UserGender"),
        UserAgeSegment=na.max("UserAgeSegment"),
        IP=na.max("IP"),
        UserAgent=na.max("UserAgent"),
        DeviceType=na.max("DeviceType"),
        DetailedDeviceType=na.max("DetailedDeviceType"),
        BrowserName=na.max("BrowserName"),
        VideoCategoryID=na.max("VideoCategoryID"),
        Price=na.sum("Price"),
        PartnerPrice=na.sum("PartnerPrice"),
        Duration=na.sum("Duration"),
        Hits=na.sum("Hits"),
        ShownHits=na.sum("ShownHits")
    ).put(
        reduced_table
    )

    job.run()

    return reduced_table


def _list_table_dates(cluster, root):
    """Return the sorted dates parsed from the table paths found under ``root``.

    ``get_date`` serves both as the search ``path_filter`` (presumably so
    tables whose names are not dates are skipped) and as the path parser.
    """
    paths = get_driver(cluster).client.search(root=root, path_filter=get_date)
    return sorted(get_date(path) for path in paths)


def main():
    """Process new soc_dem days and write the dates that still need pushing.

    Steps:
      1. Choose the dates to process: ``--from``..``--to`` when both are
         given, otherwise every day under the soc_dem root that is newer
         than the latest table already in ``--job_root`` (or newer than
         2018-09-17 when ``--job_root`` is empty).
      2. Run ``process_date`` for each of them.
      3. Ask YQL for the newest UpdateTime in the table read by ``query``.
      4. Write the processed dates newer than that to ``--output`` as a JSON
         list; nothing is written when the list is empty.

    Requires the ``YQL_TOKEN`` environment variable.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--job_root',
        default='//home/videoquality/vh_analytics/vhds_reduced'
    )
    parser.add_argument('--from', default=None)
    parser.add_argument('--pool', default=None)
    parser.add_argument('--title', default='VH Detailed Stat')
    parser.add_argument('--to', default=None)
    # NOTE(review): --output has no default; if it is omitted while there are
    # dates to push, open(None) raises TypeError.
    parser.add_argument('--output')
    args = parser.parse_args()

    cluster = get_cluster(clusters, args)
    # 'from' is a Python keyword, so the attribute is read with getattr.
    from_ = getattr(args, 'from', None)
    to_ = getattr(args, 'to', None)

    if from_ and to_:
        dates = date_range(from_, to_)
    else:
        processed_dates = _list_table_dates(cluster, args.job_root)
        # Fixed starting point when nothing has been processed yet.
        last_date = (
            processed_dates[-1] if processed_dates
            else datetime.date(2018, 9, 17)
        )
        available_dates = _list_table_dates(
            cluster, '//home/video-hosting/logs/soc_dem/1d'
        )
        dates = [x for x in available_dates if x > last_date]

    for date in dates:
        print('processing {}'.format(date))
        process_date(date, cluster, args.job_root)

    yql = YqlClient(db='hahn', token=os.environ['YQL_TOKEN'])
    req = yql.query(query, title='YQL | Get Last Date')
    req.run()
    req.wait_progress()
    # LastDate is a Unix timestamp; turn it into a local calendar date so it
    # can be compared with the dates parsed from table names.
    last_published = datetime.datetime.fromtimestamp(
        int(req.full_dataframe['LastDate'][0])
    ).date()

    processed_dates = _list_table_dates(cluster, args.job_root)
    dates_to_push = [format(x) for x in processed_dates if x > last_published]

    if dates_to_push:
        # Context manager guarantees the file is flushed and closed.
        with open(args.output, 'w') as output_file:
            json.dump(dates_to_push, output_file, indent=2)


# Entry point: run main() only when the file is executed as a script.
if __name__ == "__main__":
    main()
