#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import division
import sys
import os
import codecs
import copy
import argparse
import random
from nile.api.v1 import (
    clusters,
    filters as nf,
    extractors as ne,
    aggregators as na,
    files as nfi,
    Record,
    with_hints,
    extended_schema
)
from qb2.api.v1 import (
    filters as qf,
    extractors as qe,
    typing as qt,
    resources as qr,
    typing as qt
)
import datetime
# import json
# import hashlib
# import urlparse
# import re
from pytils import (
    date_range,
    optionalize_schema,
    yt_get_date_from_table as get_date,
    get_driver
)
from strm_cube_2_common_beta import (
    preprocessed_schema,
    StrmParser,
    JsTracerParser,
    RedirParser,
    DspParser,
    RtbParser,
    microsessions_reducer,
    add_session_markers,
    microsessions_schema,
    get_hash
)

# Nile cluster handle; assigned in main() and read as a global by
# process_date() and dsp_pipeline().
cluster = None
# Iron-branch content metadata; joined to sessions by JoinKey/UUID.
IRON_BRANCH_TABLE = '//home/videolog/strm_meta/iron_branch/concat'
# Page/impression dictionary (PageID, ImpID, type, video_type, ...).
PAGE_IMP_TABLE = '//home/videolog/strm_meta/page_imp'
# List of DSPID values passed to DspParser.
DSP_ID_TABLE = '//home/videolog/strm_meta/dsp_id'
# ASN -> ISP name mapping; read into a dict and handed to the parsers.
ASNAME_TABLE = (
    '//home/search-research/ensuetina/AS_MAP/proper_AS_names_corrected'
)
# Crypta socdem profiles, joined to sessions by yu_hash.
CRYPTA_TABLE = '//home/crypta/production/profiles/export/profiles_for_14days'
# Local modules shipped to YT operations via the ``files=`` argument so
# mappers/reducers defined in them can be imported on workers.
nfi_common_small = [
    nfi.LocalFile('strm_cube_2_common_beta.py'),
    nfi.LocalFile('pytils.py'),
]
# Same as above plus Statbox dictionaries; presumably needed by the parsers
# for geo / IP-origin lookups (TODO confirm in strm_cube_2_common_beta).
nfi_common = nfi_common_small + [
    nfi.StatboxDict('Geobasev6.bin'),
    nfi.StatboxDict('IpOriginV6.xml'),
]
# Default for the --job_root CLI option.
default_jobroot = '//home/videoquality/vh_analytics/strm_cube_2'


def common_pipeline(
    cluster, job, table, parser, asname_dict, intensity='large_data',
    memory_limit=5000, optional_kwargs=None
):
    """Map every record of ``table`` through ``parser`` into preprocessed rows.

    Args:
        cluster: unused; kept so all pipeline helpers share a signature.
        job: nile job the operation is added to.
        table: YT path of the source log table.
        parser: mapper class, instantiated as
            ``parser(asname_dict, **optional_kwargs)``.
        asname_dict: ASN -> ISP mapping forwarded to the parser.
        intensity: nile intensity hint for the map operation.
        memory_limit: memory limit for the map operation.
        optional_kwargs: extra keyword arguments for the parser constructor;
            ``None``/empty means none.

    Returns:
        The mapped nile stream typed with ``preprocessed_schema``.
    """
    source = job.table(table)
    parser_kwargs = optional_kwargs or {}
    mapper = parser(asname_dict, **parser_kwargs)
    typed_mapper = with_hints(output_schema=preprocessed_schema)(mapper)
    return source.map(
        typed_mapper,
        files=nfi_common,
        memory_limit=memory_limit,
        intensity=intensity,
    )


def dsp_pipeline(
    job, dsp_table
):
    """Parse a DSP log table into a narrow preprocessed stream.

    Page ids are read from ``PAGE_IMP_TABLE`` through the module-level
    ``cluster`` global, so ``main()`` must have assigned it first.

    NOTE(review): appears unused in this file, and it constructs
    ``DspParser(page_ids)`` while ``rtb_dsp_pipeline`` passes more
    arguments; confirm it still matches the parser's signature.

    Returns:
        A nile stream restricted to the columns projected below.
    """
    page_imp_rows = cluster.driver.yt_driver.read(PAGE_IMP_TABLE)
    page_ids = set(str(row['PageID']) for row in page_imp_rows)

    parser = with_hints(output_schema=preprocessed_schema)(DspParser(page_ids))
    parsed = job.table(dsp_table).map(
        parser,
        files=nfi_common_small,
        intensity='large_data',
    )
    return parsed.project(
        'os_family', 'browser_name', 'timestamp', 'yu_hash',
        'yandexuid', 'page_id', 'event', 'source', 'add_info',
    )


def merge_dicts(*args):
    """Merge mappings left to right into a NEW dict; later keys win.

    ``None`` (or any falsy value) arguments are skipped.  No argument is
    modified: the previous implementation updated ``args[0]`` in place,
    mutating the caller's record field value.

    Returns:
        A fresh ``dict`` (``{}`` when called with no arguments).
    """
    merged = {}
    for mapping in args:
        merged.update(mapping or {})
    return merged


def try_get_content_type_id(chain):
    """Return ``int(chain[-1]['ContentTypeID'])``, or -1 if unavailable.

    -1 is returned when ``chain`` is ``None``/empty, its last element is not
    subscriptable or lacks ``ContentTypeID`` (``KeyError`` was previously
    not caught and crashed the map), or the value is not int-convertible.
    """
    try:
        return int(chain[-1]['ContentTypeID'])
    except (ValueError, TypeError, AttributeError, IndexError, KeyError):
        return -1


def check_rtb_dsp(a, c):
    """Decide whether a joined rtb/dsp row should be kept.

    Args:
        a: DSP-side info; must contain ``DspID`` and ``producttype``.
        c: page/impression info; must contain ``video_type``.

    Returns:
        True for every non-``in_app`` impression.  For ``in_app`` ones, True
        unless ``DspID == 1`` with a ``producttype`` other than
        ``'auto-video-direct'``.

    Raises:
        KeyError: if any of the three fields is missing (all are read up
        front, regardless of ``video_type``).
    """
    dsp_id, product_type, video_type = (
        a['DspID'], a['producttype'], c['video_type']
    )
    if video_type != 'in_app':
        return True
    return dsp_id != 1 or product_type == 'auto-video-direct'


def rtb_dsp_pipeline(
    cluster, job, dsp_table, asname_dict, rtb_table=None
):
    """Build preprocessed rows from DSP (and optionally RTB) logs.

    Two modes:
      * ``rtb_table`` given: parse ``dsp_table`` and ``rtb_table``
        separately, deduplicate RTB rows by ``bidreqid`` and left-join them
        onto DSP rows by ``bidreqid``.
      * ``rtb_table`` falsy: ``dsp_table`` is parsed with
        ``process_rtb=True`` so the parser itself emits the RTB-side fields.

    The result is inner-joined with ``PAGE_IMP_TABLE`` on
    (page_id, imp_id), filtered with ``check_rtb_dsp``, and the three
    partial ``add_info_*`` dicts are merged into a single ``add_info``.
    An empty ``video_content_id`` is replaced with ``'novcid'``.

    Args:
        cluster: nile cluster, used to read the page-id and dsp-id
            dictionaries eagerly on the client.
        job: nile job the operations are added to.
        dsp_table: YT path of the DSP (or cooked rtb-dsp) log.
        asname_dict: ASN -> ISP mapping forwarded to the parsers.
        rtb_table: optional YT path of the raw RTB log.
    """
    # PageIDs (as strings) of all pages except those marked 'not_vh'.
    page_ids = {
        str(x['PageID']) for x in
        cluster.driver.yt_driver.read(PAGE_IMP_TABLE)
        if x.get('vh_partner', '') != 'not_vh'
    }

    dsp_ids = {
        x['DSPID'] for x in
        cluster.driver.yt_driver.read(DSP_ID_TABLE)
    }

    dsp_fields_common = [
        'os_family', 'browser_name', 'timestamp', 'yu_hash',
        'yandexuid', 'page_id', 'event', 'source',
        'imp_id'
    ]
    rtb_fields_common = [
        'vsid', 'video_content_id',
        'user_agent', 'browser_version', 'device_type',
        'region', 'country', 'ip', 'a_station', 'provider'
    ]

    # DSP output schema: the shared schema with add_info renamed to
    # add_info_a (so it can be merged later) plus an imp_id column.
    dsp_schema = copy.deepcopy(preprocessed_schema)
    dsp_schema.pop('add_info')
    dsp_schema['add_info_a'] = qt.Optional[qt.Json]
    dsp_schema['imp_id'] = qt.Optional[qt.String]

    if rtb_table:

        dsp_fields = ['bidreqid', 'add_info_a'] + dsp_fields_common

        # NOTE(review): project() receives a list here while other call
        # sites pass column names as varargs; confirm nile accepts both.
        dsp = job.table(
            dsp_table
        ).map(
            with_hints(
                output_schema=dsp_schema
            )(DspParser(page_ids, dsp_ids, add_info='add_info_a')),
            files=nfi_common_small, intensity='large_data'
        ).project(
            dsp_fields
        )

        rtb_schema = copy.deepcopy(preprocessed_schema)
        rtb_schema.pop('add_info')
        rtb_schema['add_info_b'] = qt.Optional[qt.Json]

        rtb_fields = ['bidreqid', 'add_info_b'] + rtb_fields_common

        # One RTB row per bidreqid so the left join below cannot fan out.
        rtb = job.table(
            rtb_table
        ).map(
            with_hints(
                output_schema=rtb_schema
            )(RtbParser(asname_dict)), files=nfi_common, intensity='large_data'
        ).unique(
            'bidreqid'
        ).project(
            rtb_fields
        )

        rtb_dsp = dsp.join(
            rtb, type='left', by='bidreqid'
        )
    else:
        # Single-table mode: the parser fills both the DSP and RTB columns,
        # so the schema also needs add_info_b.
        dsp_schema['add_info_b'] = qt.Optional[qt.Json]
        rtb_dsp_fields = (
            ['bidreqid', 'add_info_a', 'add_info_b'] +
            dsp_fields_common + rtb_fields_common
        )

        rtb_dsp = job.table(
            dsp_table
        ).map(
            with_hints(
                output_schema=dsp_schema
            )(DspParser(
                page_ids, dsp_ids,
                add_info='add_info_a', process_rtb=True,
                asname_dict=asname_dict
            )),
            files=nfi_common, intensity='large_data'
        ).project(
            rtb_dsp_fields
        )

    # Page/impression attributes, with ids cast to str to match the join
    # keys and the remaining attributes packed into add_info_c.
    page_imp = job.table(
        PAGE_IMP_TABLE
    ).project(
        category_id=ne.custom(str, 'category_id').add_hints(
            type=qt.Optional[qt.String]
        ),
        imp_id=ne.custom(str, 'ImpID').add_hints(
            type=qt.Optional[qt.String]
        ),
        page_id=ne.custom(str, 'PageID').add_hints(
            type=qt.Optional[qt.String]
        ),
        add_info_c=ne.custom(
            lambda ad_type, video_type, not_vh_partner, vh_partner:
            {
                'ad_type': ad_type,
                'video_type': video_type,
                'not_vh_partner': not_vh_partner,
                'vh_partner': vh_partner,
            }, 'type', 'video_type', 'not_vh_partner', 'vh_partner'
        ).add_hints(type=qt.Optional[qt.Json])
    )

    # Inner join drops rows whose (page_id, imp_id) is not in page_imp;
    # then merge add_info_a/b/c into one add_info column.
    return rtb_dsp.join(
        page_imp, type='inner', by=['page_id', 'imp_id']
    ).filter(
        nf.custom(
            check_rtb_dsp, 'add_info_a', 'add_info_c'
        )
    ).project(
        ne.all(exclude=[
            'add_info_a', 'add_info_b', 'add_info_c', 'video_content_id'
        ]),
        video_content_id=ne.custom(
            lambda x: x if x else 'novcid', 'video_content_id'
        ).add_hints(
            type=qt.Optional[qt.String]
        ),
        add_info=ne.custom(
            merge_dicts,
            'add_info_a', 'add_info_b', 'add_info_c'
        ).add_hints(type=qt.Optional[qt.Json])
    )


# ContentTypeID values that qualify a chain element as a "parent".
_PARENT_CONTENT_TYPE_IDS = frozenset((2, 4, 6, 14, 32))


def try_get_from_parent(chain, field, t='str'):
    """Return ``field`` from the last parent-typed element of ``chain``.

    ``chain`` is scanned from the end; the first element whose
    ``ContentTypeID`` is in ``_PARENT_CONTENT_TYPE_IDS`` supplies the value.

    Args:
        chain: sequence of element mappings, or ``None``.
        field: key to read from the parent element.
        t: ``'str'`` selects ``''`` as the fallback; anything else selects 0.

    Returns:
        The parent's ``field`` value, or the fallback when ``chain`` is
        ``None``/empty, has no parent-typed element, or the parent lacks
        ``field``.  Elements lacking ``ContentTypeID`` (or not subscriptable)
        are skipped instead of raising as they did before.
    """
    default = '' if t == 'str' else 0
    for element in (chain or [])[::-1]:
        try:
            content_type_id = element['ContentTypeID']
        except (IndexError, TypeError, KeyError):
            # Malformed element: it cannot be identified as a parent.
            continue
        if content_type_id not in _PARENT_CONTENT_TYPE_IDS:
            continue
        try:
            return element[field]
        except (IndexError, TypeError, KeyError):
            return default
    return default


def get_view_type(
    timetuple, timestamp, content_type_id, heur_category
):
    """Classify a view as ``'vod'``, ``'live'`` or ``'dvr'``.

    Args:
        timetuple: ``(start_time, finish_time)`` of the content, or ``None``
            (e.g. rows that got no iron-branch match).
        timestamp: time of the view.
        content_type_id: ContentTypeID of the content.
        heur_category: heuristic category; ``'vod'`` short-circuits.

    Returns:
        ``'vod'`` if ``heur_category == 'vod'``; ``'live'`` if the timestamp
        lies within [start_time, finish_time], or if there is no start_time
        and ``content_type_id == 2``; otherwise ``'dvr'``.
    """
    if heur_category == 'vod':
        return 'vod'
    bounds = timetuple or (None, None)
    start, finish = bounds[0], bounds[1]
    # Compare only when every operand is present: on Python 3 ordering
    # against None raises TypeError (rows without iron-branch data crashed
    # the job), and on Python 2 None silently sorted below every number.
    if (
        start is not None and finish is not None and timestamp is not None
        and start <= timestamp <= finish
    ):
        return 'live'
    if not start and content_type_id == 2:
        return 'live'
    return 'dvr'


def good_vcid(s):
    """Return True for a non-empty video_content_id other than 'novcid'."""
    if not s:
        return False
    return s != 'novcid'


def process_date(date, args):
    """Build the STRM cube tables for one date under ``<args.job_root>/<date>``.

    * ``preprocessed``: the strm-access, js-tracer, redir and cooked
      rtb-dsp logs parsed into ``preprocessed_schema``, concatenated and
      sorted by (vsid, timestamp).  An existing table is reused when
      ``--redo`` is not given, or when it is given and the table's
      modification date is not earlier than the redo date.
    * ``sessions``: microsessions joined with iron-branch content metadata
      and Crypta profiles, with session markers added.  Skipped when
      ``args.no_sessions`` is set.

    Both tables get a ``build_time`` attribute.  Uses the module-level
    ``cluster`` assigned by ``main()``.

    Raises:
        Exception: if the cooked rtb-dsp table for ``date`` does not exist
            and the preprocessed table has to be built.
    """
    date_f = format(date)
    date_rev = date.strftime('%d-%m-%Y')
    strm_table = '//logs/strm-access-log/1d/{}'.format(date)
    redir_table = '//logs/redir-log/1d/{}'.format(date)
    js_tracer_table = '//logs/jstracer-log/1d/{}'.format(date)
    js_tracer_table_old = '//home/js_tracer/day_by_day/{}'.format(date_rev)
    # rtb_table = '//logs/bs-rtb-log/1d/{}'.format(date)
    # dsp_table = '//logs/bs-dsp-log/1d/{}'.format(date)
    rtb_dsp_cooked_table = (
        '//statbox/cooked_logs/bs-dsp-cooked-log/v1/1d/{}'.format(date)
    )
    driver = get_driver(cluster)
    client = driver.client

    root = '{}/{}'.format(args.job_root, date_f)
    preprocessed = '{}/preprocessed'.format(root)
    sessions = '{}/sessions'.format(root)

    # Both lookups are best-effort: an absent/unparseable --redo or a missing
    # table leaves the value as None.  ``except Exception`` (instead of a
    # bare except) lets KeyboardInterrupt / SystemExit propagate.
    try:
        redo_date = get_date(args.redo)
    except Exception:
        redo_date = None

    try:
        mtimedate = get_date(driver.get_attribute(
            preprocessed, 'modification_time'
        ))
    except Exception:
        mtimedate = None

    job = cluster.job()

    reuse_preprocessed = driver.exists(preprocessed) and (
        not args.redo or (redo_date and mtimedate and redo_date <= mtimedate)
    )
    if reuse_preprocessed:
        print('using existing preprocessed table')
        stream = job.table(preprocessed)
    else:
        # Only needed by the parsers, so read it only when building.
        asname_dict = {
            rec.ASN: rec.ISP
            for rec in driver.read(ASNAME_TABLE)
        }
        to_concat = [
            common_pipeline(
                cluster, job, strm_table, StrmParser, asname_dict,
                memory_limit=5000, intensity='data'
            )
        ]
        # Prefer the new js-tracer log; fall back to the old day-by-day one.
        if driver.exists(js_tracer_table):
            to_concat.append(
                common_pipeline(
                    cluster, job,
                    js_tracer_table, JsTracerParser, asname_dict,
                    optional_kwargs={'new': True}
                )
            )
        elif driver.exists(js_tracer_table_old):
            to_concat.append(
                common_pipeline(
                    cluster, job,
                    js_tracer_table_old, JsTracerParser, asname_dict
                )
            )
        if driver.exists(redir_table):
            to_concat.append(
                common_pipeline(
                    cluster, job, redir_table, RedirParser, asname_dict
                ),
            )
        if driver.exists(rtb_dsp_cooked_table):
            print('using rtb-dsp cooked log')
            to_concat.append(
                rtb_dsp_pipeline(
                    cluster, job, rtb_dsp_cooked_table, asname_dict,
                    rtb_table=None
                )
            )
        # elif (
        #     cluster.driver.exists(rtb_table) and
        #     cluster.driver.exists(dsp_table)
        # ):
        #     print('using rtb and dsp tables')
        #     to_concat.append(
        #         rtb_dsp_pipeline(
        #             cluster, job, dsp_table, asname_dict,
        #             rtb_table=rtb_table
        #         )
        #     )
        # elif cluster.driver.exists(dsp_table):
        #     print('using dsp table only')
        #     to_concat.append(
        #         rtb_dsp_pipeline(
        #             cluster, job, dsp_table, asname_dict,
        #             rtb_table=None
        #         )
        #     )
        else:
            raise Exception('neither of rtb-dsp tables exist, failing')

        stream = job.concat(
            *to_concat
        ).sort(
            'vsid', 'timestamp'
        ).put(
            preprocessed
        )

    job.run()

    # NOTE(review): build_time is refreshed even when the existing
    # preprocessed table was reused above -- confirm this is intended.
    client.set_attribute(
        preprocessed,
        'build_time',
        datetime.datetime.now().strftime(
            '%Y-%m-%dT%H:%M:%S'
        )
    )

    if args.no_sessions:
        return

    job = cluster.job()

    # Salted copy of the iron branch for the video_content_id join: five
    # copies with rand = 0..4, so a hot video_content_id is spread over five
    # join keys.  The "rand" column added to the stream below MUST draw
    # from exactly the same 0..4 range.
    iron_branch_short = job.concat(*[
        job.table(
            IRON_BRANCH_TABLE
        ).project(
            "UUID",
            video_content_id="JoinKey",
            rand=ne.const(salt)
        )
        for salt in range(5)
    ])

    # NOTE(review): ContentTypeID and ParentTypeID are computed identically
    # (the parent element's ContentTypeID); the item's own type is already
    # in content_type_id.  Confirm this duplication is intended.
    iron_branch = job.table(
        IRON_BRANCH_TABLE
    ).project(
        'heur_category',
        'UUID',
        'ContentGroupID',
        content_duration='duration',
        content_type_id=ne.custom(
            try_get_content_type_id, 'chain'
        ).add_hints(type=qt.Integer),
        channel='computed_channel',
        program='computed_program',
        timetuple=ne.custom(
            lambda x, y: (x, y), 'start_time', 'finish_time'
        ).add_hints(
            type=qt.Json
        ),
        ContentTypeID=ne.custom(
            lambda x: try_get_from_parent(
                x, 'ContentTypeID', t='int'
            ), 'chain'
        ).add_hints(type=int),
        ParentUUID=ne.custom(
            lambda x: try_get_from_parent(x, 'UUID'), 'chain'
        ).add_hints(type=str),
        ParentTypeID=ne.custom(
            lambda x: try_get_from_parent(
                x, 'ContentTypeID', t='int'
            ), 'chain'
        ).add_hints(type=int),
    )

    # One Crypta profile per yu_hash so the left join below cannot fan out.
    crypta = job.table(
        CRYPTA_TABLE
    ).project(
        'age_segments', 'user_age_6s', 'gender',
        'income_segments', 'exact_socdem',
        yandexuid_from_crypta=ne.custom(str, 'yandexuid').add_hints(
            type=str
        ),
        yu_hash=ne.custom(
            lambda x: get_hash(str(x)), 'yandexuid'
        ).add_hints(type=str),
        files=nfi_common_small
    ).unique(
        'yu_hash'
    )

    stream = job.table(
        preprocessed
    ).filter(
        qf.nonzero("vsid"), intensity="data"
    ).project(
        ne.all(), rand=ne.custom(
            # randint is inclusive on BOTH ends: 0..4 matches the five salt
            # copies of iron_branch_short.  The previous randint(0, 5) left
            # about 1/6 of rows without a join partner (and thus no UUID).
            lambda x: random.randint(0, 4), 'vsid'
        ).with_type(int), intensity="data"
    )

    with_vcids = stream.filter(
        qf.nonzero("video_content_id"), intensity="data"
    ).join(
        iron_branch_short, by=["video_content_id", "rand"], type="left",
        assume_unique_right=True
    )
    without_vcids = stream.filter(
        qf.not_(qf.nonzero("video_content_id"))
    )

    stream = job.concat(
        with_vcids, without_vcids
    ).groupby(
        'vsid'
    ).sort(
        'timestamp', 'UUID'
    ).reduce(
        with_hints(
            output_schema=optionalize_schema(qt, microsessions_schema)
        )(microsessions_reducer),
        files=nfi_common_small, memory_limit=5000, intensity="data"
    )

    # Only rows with a UUID can be enriched with iron-branch metadata.
    stream_with_vcid = stream.filter(
        nf.custom(bool, 'UUID')
    )

    stream_without_vcid = stream.filter(
        nf.custom(lambda x: not x, 'UUID')
    )

    stream_with_vcid = stream_with_vcid.join(
        iron_branch, by='UUID', type='left',
        assume_unique_right=True
    )

    stream = job.concat(
        stream_with_vcid,
        stream_without_vcid
    )

    # Only rows with a yu_hash can be enriched with Crypta profiles.
    stream_with_yu_hash = stream.filter(
        nf.custom(bool, 'yu_hash')
    )

    stream_without_yu_hash = stream.filter(
        nf.custom(lambda x: not bool(x), 'yu_hash')
    )

    # NOTE(review): the 'novcid' placeholder is applied only to rows that
    # have a yu_hash; rows without one keep an empty video_content_id.
    stream_with_yu_hash = stream_with_yu_hash.project(
        ne.all(exclude=['video_content_id']),
        video_content_id=ne.custom(
            lambda x: x if x else 'novcid', 'video_content_id'
        ).add_hints(type=qt.Optional[qt.String])
    )

    stream_with_yu_hash = stream_with_yu_hash.join(
        crypta, by='yu_hash', type='left', assume_unique_right=True
    )

    stream = job.concat(
        stream_with_yu_hash,
        stream_without_yu_hash
    )

    # Fill yandexuid from Crypta when missing, classify views, then add
    # per-user session markers.
    stream = stream.project(
        ne.all(exclude=['yandexuid', 'yandexuid_from_crypta']),
        yandexuid=ne.custom(
            lambda x, y: (x if x else y),
            'yandexuid', 'yandexuid_from_crypta'
        ).add_hints(type=str),
        view_type=ne.custom(
            get_view_type,
            'timetuple', 'timestamp', 'content_type_id', 'heur_category'
        ).add_hints(type=str),
    ).groupby(
        'user_id'
    ).sort(
        'timestamp'
    ).reduce(
        with_hints(
            output_schema=extended_schema(
                is_view_old=int,
                is_view_new=int,
                new_ad_session_crutch=int,
                new_user_session=int
            )
        )(
            add_session_markers
        ), files=nfi_common, memory_limit=2000  # intensity='large_data'
    ).project(
        ne.all(),
        fielddate=ne.const(date_f)
    )

    stream.put(
        sessions,  # schema=optionalize_schema(qt, microsessions_schema)
    )

    job.run()

    client.set_attribute(
        sessions,
        'build_time',
        datetime.datetime.now().strftime(
            '%Y-%m-%dT%H:%M:%S'
        )
    )


def _table_dates(client, **search_kwargs):
    """Return sorted dates parsed from tables found by ``client.search``.

    ``search_kwargs`` are forwarded to ``client.search(node_type='table',
    ...)``; paths for which ``get_date`` yields a falsy value are skipped.
    """
    parsed = (
        get_date(path)
        for path in client.search(node_type='table', **search_kwargs)
    )
    return sorted(d for d in parsed if d)


def main():
    """CLI entry point: connect to Hahn and run ``process_date`` per date.

    With both ``--from`` and ``--to`` the dates come from
    ``date_range(from, to)``.  Otherwise every strm-access-log date later
    than the newest date among existing ``*sessions`` tables under
    ``--job_root`` is processed.  Requires the ``YT_TOKEN`` and
    ``YQL_TOKEN`` environment variables.
    """
    global cluster
    parser = argparse.ArgumentParser()
    parser.add_argument('--from')
    parser.add_argument('--to')
    parser.add_argument('--pool', default='videolog')
    parser.add_argument('--redo')
    # --no_redir and --rtb_dsp are accepted for CLI compatibility but are
    # not read anywhere in this file.
    parser.add_argument('--no_redir', action='store_true')
    parser.add_argument('--no_sessions', action='store_true')
    parser.add_argument('--rtb_dsp', action='store_true')
    parser.add_argument('--job_root', default=default_jobroot)
    args = parser.parse_args()

    # 'from' is a keyword, so the attribute must be fetched with getattr.
    from_ = getattr(args, 'from')
    to_ = getattr(args, 'to')
    job_root = args.job_root

    hahn_kwargs = dict(
        token=os.environ['YT_TOKEN'],
        yql_token=os.environ['YQL_TOKEN'],
        pool=args.pool
    )

    cluster = clusters.yql.Hahn(
        **hahn_kwargs
    ).env(
        yt_spec_defaults=dict(
            pool_trees=["physical"],
            tentative_pool_trees=["cloud"]
        ),
        templates=dict(
            tmp_root='//home/videoquality/vh_analytics/tmp',
            title='STRM Cube v2'
        )
    )
    client = get_driver(cluster).client

    if from_ and to_:
        dates_to_process = date_range(from_, to_)
    else:
        processed_dates = _table_dates(
            client,
            root=job_root,
            path_filter=lambda x: (x or '').endswith('sessions'),
        )
        if not processed_dates:
            # Previously an opaque IndexError; give an actionable message.
            sys.exit(
                'no processed sessions tables found under {}; '
                'pass --from and --to to bootstrap'.format(job_root)
            )

        print('last date: {}'.format(processed_dates[-1]))

        available_dates = _table_dates(
            client, root='//logs/strm-access-log/1d'
        )
        if not available_dates:
            print('no available dates found')
            return

        print('last available date: {}'.format(available_dates[-1]))

        dates_to_process = [
            x for x in available_dates if x > processed_dates[-1]
        ]

    for date in dates_to_process:
        print('processing {}...'.format(date))
        process_date(date, args)


if __name__ == "__main__":
    main()
