#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import division
import os
import argparse
from nile.api.v1 import (
    clusters,
    filters as nf,
    extractors as ne,
    aggregators as na,
    statface as ns,
    with_hints,
    modified_schema,
    Record
)
import nile.files as nfi
import copy
from yql.api.v1.client import YqlClient
from qb2.api.v1 import typing as qt
from pytils import (
    yql_run,
    date_range,
    yt_get_date_from_table as get_date,
    get_dates_from_stat,
    get_stat_headers,
    StatPusher,
    get_driver
)
import itertools
from collections import Counter


# YQL query returning every distinct error_id from one day's preprocessed
# strm_cube_2 table. 'POOL' and 'DATE' are textual placeholders that
# process_date() substitutes via str.replace before running the query.
errors_query = """
pragma yt.Pool = 'POOL';
SELECT DISTINCT error_id
FROM [home/videoquality/vh_analytics/strm_cube_2/DATE/preprocessed]
"""


def get_duration_category(duration):
    """Map a content duration to a coarse bucket label.

    Thresholds are multiples of 60 and the labels are in minutes, so the
    input is treated as seconds. Missing, zero or negative durations map to
    ``'other'``; each bucket's upper bound is inclusive, and anything above
    240 minutes (or not comparable to the bounds) maps to ``"240m+"``.
    """
    if not duration or duration < 0:
        return 'other'
    # (inclusive upper bound, label) pairs in ascending order.
    buckets = (
        (60 * 5, "005m-"),
        (60 * 20, "005-20m"),
        (60 * 30, "020-30m"),
        (60 * 40, "030-40m"),
        (60 * 60, "040-60m"),
        (60 * 90, "060-90m"),
        (60 * 120, "090-120m"),
        (60 * 240, "120-240m"),
    )
    for upper_bound, label in buckets:
        if duration <= upper_bound:
            return label
    return "240m+"


def get_os(rec):
    """Return the record's ``os_family``, suffixed with ``_TV`` on TVs.

    Only ``device_type == 'tv'`` adds the suffix; every other device type
    gets ``os_family`` back unchanged.
    """
    family, device = rec['os_family'], rec['device_type']
    return '{}_TV'.format(family) if device == 'tv' else family


class SessionReducer(object):
    """Nile reducer that explodes each viewing session into per-slice rows.

    For every input group (one per ``vsid`` in ``process_date``), each
    record with a non-zero ``view_time`` yields one ``Record`` per
    combination of
    (duration bucket | '_total_') x (provider | '_total_') x
    (OS | '_total_') x (view_type | '_total_') x (error id) x
    (player_version | '_total_').

    Only errors in the group's last record that occur at or after its
    ``view_time`` are marked ``_fatal``. Error ids from ``all_errors`` that
    did not occur are still emitted (``rel_time`` = view time,
    ``error_count`` = 0), alongside the synthetic ids ``'_total_'`` and
    ``'_fatal_'``.
    """

    def __init__(self, all_errors):
        # Set of known error ids (plain and '_fatal' forms); ids missing from
        # a record are still emitted so that every slice is complete.
        self.all_errors = all_errors

    def __call__(self, groups):
        """Yield output rows for every ``(key, records)`` pair in ``groups``.

        The group key itself is not used.
        """
        for _key, recs in groups:
            # Slice combinations already emitted for this session; repeats
            # get vsids=0 and therefore view_time=0, so each session's view
            # time is attributed once per combination.
            vsid_set = set()
            # NOTE(review): the previous version appended every record and
            # then ran ``if len(buff) > 100: continue``, which skipped
            # nothing -- the whole group was always buffered. That behaviour
            # is kept explicitly; a real cap would also change which record
            # counts as "last" for the fatal classification below.
            buff = list(recs)
            last_index = len(buff) - 1
            for i, rec in enumerate(buff):
                vt = rec['view_time']
                if not vt:
                    continue
                duration_category = get_duration_category(
                    rec.get('content_duration', None)
                )
                # error id -> rel_time of its first occurrence in this record
                time_before_first_error = {}
                error_counter = Counter()
                # ``or []`` also covers a present-but-null ``errors`` column,
                # which ``rec.get('errors', [])`` would try to iterate.
                for error in rec.get('errors') or []:
                    eid = error['id']
                    if i != last_index:
                        # Only the session's last record can hold fatal errors.
                        eid = ensure_not_fatal(eid)
                    elif error['rel_time'] < vt:
                        eid = ensure_not_fatal(eid)
                    else:
                        eid = ensure_fatal(eid)
                    error_counter[eid] += 1
                    if eid not in time_before_first_error:
                        time_before_first_error[eid] = error['rel_time']
                # Error ids that never occurred count as first seen at the
                # end of the view.
                for eid in self.all_errors - set(time_before_first_error):
                    time_before_first_error[eid] = vt

                fatal_times = [
                    t for x, t in time_before_first_error.items()
                    if x == ensure_fatal(x)
                ]
                fminv = min(fatal_times + [vt])
                minv = min(list(time_before_first_error.values()) + [vt])

                error_counter['_total_'] = sum(error_counter.values())
                error_counter['_fatal_'] = sum(
                    n for x, n in error_counter.items()
                    if x == ensure_fatal(x)
                )
                time_before_first_error['_fatal_'] = fminv
                time_before_first_error['_total_'] = minv
                for comb in itertools.product(
                    (duration_category, '_total_'),
                    (rec['provider'], '_total_'),
                    (get_os(rec), '_total_'),
                    (rec['view_type'], '_total_'),
                    sorted(time_before_first_error.keys()),
                    (rec['player_version'], '_total_'),
                ):
                    dc, provider, os_family, view_type, eid, pv = comb
                    vsids = 0 if comb in vsid_set else 1
                    vsid_set.add(comb)
                    yield Record(
                        duration_category=dc,
                        error_id=eid,
                        view_type=view_type,
                        os_family=os_family,
                        provider=provider,
                        player_version=pv,
                        rel_time=float(time_before_first_error[eid]),
                        error_count=error_counter[eid],
                        vsids=vsids,
                        view_time=vt * vsids
                    )


def ensure_fatal(s):
    """Return *s* guaranteed to end with the ``'_fatal'`` marker.

    Ids that already carry the marker come back unchanged, so repeated
    calls are idempotent.
    """
    if not s.endswith('_fatal'):
        s = '{}_fatal'.format(s)
    return s


def ensure_not_fatal(s):
    """Return *s* with a trailing ``'_fatal'`` marker removed, if present.

    Ids without the marker are returned unchanged.
    """
    suffix = '_fatal'
    if s.endswith(suffix):
        # Bug fix: the previous ``s[:len('_fatal')]`` kept the FIRST six
        # characters of the id instead of dropping the six-character suffix.
        return s[:-len(suffix)]
    return s


def is_good_vcid(vcid):
    """Tell whether ``vcid`` looks like a usable video content id.

    Falsy values and the placeholders ``'-'`` and ``"novcid"`` are rejected.
    """
    placeholders = {'-', "novcid"}
    return bool(vcid) and vcid not in placeholders


@with_hints(output_schema=modified_schema(
    exclude=[
        'rel_time_quantiles',
        'error_count_quantiles',
        'error_count_histogram'
    ],
    extend={
        'quantile_type': qt.String,
        'rel_time_quantile_value': qt.Float,
        'error_count_quantile_value': qt.Float,
        'errors_per_vsid': qt.Float,
        'errors_per_vcid': qt.Float,
        'errors_per_60s': qt.Float,
        'vt_per_vcid': qt.Float,
        'rel_time_quantile_value_weighted': qt.Float,
        'errors_total': qt.Integer,
        'vcids_with_errors': qt.Integer,
        'vcids_without_errors': qt.Integer,
        'vcids_with_errors_share': qt.Float,
    }
))
def add_quantiles(recs):
    """Explode each aggregated row into one row per quantile level.

    The ``rel_time_quantiles`` and ``error_count_quantiles`` columns are
    sequences of (level, value) pairs, and ``error_count_histogram`` is
    presumably (error_count, frequency) pairs -- TODO confirm against
    ``na.histogram``. These three columns are dropped. For each level in
    ``rel_time_quantiles`` a row is emitted carrying
    ``quantile_type='q<level>'``, both quantile values and the derived
    ratio columns declared in the output schema.
    """
    for rec in recs:
        dct = rec.to_dict()
        rtq = dct.pop('rel_time_quantiles')
        ecq = dict(dct.pop('error_count_quantiles'))
        error_counter = Counter(dict(dct.pop('error_count_histogram')))
        if not rtq:
            # No levels -> no output rows (as before), and none of the
            # divisions below are attempted.
            continue

        # Everything below is independent of the quantile level, so it is
        # computed once per input row instead of once per level.
        total_sessions = sum(error_counter.values())
        without_errors = error_counter[0]
        # NOTE(review): this is the number of rows with at least one error,
        # not the total number of errors, although it is published as
        # ``errors_total`` and feeds the ``errors_per_*`` ratios. Kept as-is;
        # confirm which metric is intended.
        error_count = total_sessions - without_errors
        errorless = without_errors / total_sessions

        base = dict(dct)
        base['errors_total'] = error_count
        base['errors_per_60s'] = 60.0 * error_count / float(base['tvt'])
        base['vt_per_vcid'] = base['tvt'] / float(base['vcids'])
        base['errors_per_vsid'] = error_count / float(base['vsids'])
        base['errors_per_vcid'] = error_count / float(base['vcids'])
        base['vcids_without_errors'] = without_errors
        base['vcids_with_errors'] = base['vcids'] - without_errors
        base['vcids_with_errors_share'] = (
            base['vcids_with_errors'] / base['vcids']
        )

        for qpair in rtq:
            level = qpair[0]
            qv = float(qpair[1])
            ndct = dict(base)
            ndct['quantile_type'] = 'q{}'.format(level)
            ndct['rel_time_quantile_value'] = qv
            ndct['error_count_quantile_value'] = float(ecq[level])
            # Quantile scaled by the share of rows that had any error.
            ndct['rel_time_quantile_value_weighted'] = qv * (1 - errorless)
            yield Record(**ndct)


def process_date(
    date, report, cluster, replace_mask=None, async_mode=False,
    pool='loadbase', debug=False
):
    """Build the error cube for a single ``date`` and push it to Statface.

    Steps:
      1. Run a YQL query collecting every distinct ``error_id`` of the day;
         the ids are expanded into plain and ``_fatal`` variants.
      2. Run a nile job over the day's ``sessions`` table: ``SessionReducer``
         per ``vsid``, aggregation per slice, then ``add_quantiles``.
      3. Push the output table to ``report``. Unless ``debug`` is set, the
         table is removed afterwards.

    :param date: day to process; ``str(date)`` is used in table paths.
    :param report: Statface report path passed to ``StatPusher``.
    :param cluster: nile cluster the job runs on.
    :param replace_mask: unused; kept for backward compatibility.
    :param async_mode: unused; kept for backward compatibility.
    :param pool: YT pool for the YQL query; ``None`` means ``'loadbase'``.
    :param debug: also dump the pre-``add_quantiles`` aggregate to a tmp
        table, and keep the output table after pushing.
    :raises KeyError: if ``YT_PROXY`` or ``YQL_TOKEN`` is not set.
    """
    if pool is None:
        # main() forwards --pool, whose default is None; str.replace(...,
        # None) below raised TypeError whenever --pool was omitted.
        pool = 'loadbase'
    date_s = str(date)
    proxy = os.environ['YT_PROXY'].split('.')[0].lower()
    out_table_path = (
        "//home/videolog/time_before_first_error/yt/{date}".format(
            date=date
        )
    )
    input_table = (
        '//home/videoquality/vh_analytics/strm_cube_2/'
        '{}/sessions'.format(date)
    )
    client = YqlClient(
        db=proxy,
        token=os.environ["YQL_TOKEN"]
    )
    req = client.query(
        query=errors_query.replace(
            'POOL', pool
        ).replace('DATE', date_s),
        title='Time Before First Error | YQL'
    )
    req.run()
    # Normalise to plain ids (skipping null/empty ones), then add the
    # '_fatal' variant of each so every id exists in both forms.
    all_errors = {
        ensure_not_fatal(x) for x in req.results.full_dataframe['error_id']
        if x and ensure_not_fatal(x)
    }
    all_errors |= {
        ensure_fatal(x) for x in all_errors
    }

    job = cluster.job()

    stream = job.table(
        input_table
    ).filter(
        nf.custom(is_good_vcid, 'video_content_id'),
        nf.custom(lambda x: x > 0, 'view_time')
    ).groupby(
        'vsid'
    ).sort(
        'timestamp'
    ).reduce(
        with_hints(
            output_schema={
                'duration_category': qt.String,
                'error_id': qt.String,
                'view_type': qt.String,
                'os_family': qt.String,
                'provider': qt.String,
                'player_version': qt.String,
                'rel_time': qt.Float,
                'error_count': qt.Integer,
                'view_time': qt.Integer,
                'vsids': qt.Integer
            }
        )(SessionReducer(all_errors)), intensity='ultra_cpu'
    ).groupby(
        'duration_category', 'error_id', 'view_type', 'os_family',
        'provider', 'player_version'
    ).aggregate(
        vcids=na.count(),
        vsids=na.sum('vsids'),
        error_count_histogram=na.histogram(
            'error_count'
        ),
        error_count_quantiles=na.quantile(
            'error_count',
            [round(0.1 * x, 1) for x in range(1, 10)]
        ),
        rel_time_quantiles=na.quantile(
            'rel_time',
            [round(0.1 * x, 1) for x in range(1, 10)]
        ),
        tvt=na.sum('view_time')
    )

    if debug:
        stream.put(
            '//home/videolog/tmp/error_cube_tmp'
        )

    stream.map(
        add_quantiles, intensity='ultra_cpu'
    ).project(
        ne.all(),
        fielddate=ne.const(date_s)
    ).put(
        out_table_path
    )

    job.run()

    stat_pusher = StatPusher(
        cluster,
        report=report,
        remote_publish=True
    )

    print('Performing remote push of {}...'.format(
        out_table_path
    ))
    stat_pusher.push(out_table_path)
    if not debug:
        get_driver(cluster).remove(out_table_path)


def main():
    """Command-line entry point.

    Builds a nile YQL cluster for the proxy named in ``YT_PROXY`` and calls
    ``process_date`` for each selected date. When both ``--from`` and
    ``--to`` are given, the dates come from ``date_range``. Otherwise they
    are the dates of ``.../strm_cube_2/*/sessions`` tables that are later
    than the last date already present in ``--report``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--report', default='Video/Others/Strm/Stability/error_cube'
    )
    parser.add_argument('--from', default=None)
    parser.add_argument('--pool', default=None)
    parser.add_argument('--to', default=None)
    parser.add_argument('--debug', action='store_true')
    args = parser.parse_args()

    proxy_name = os.environ['YT_PROXY'].split('.')[0].title()
    make_cluster = getattr(clusters.yql, proxy_name)
    base_cluster = make_cluster(
        token=os.environ['YT_TOKEN'],
        yql_token=os.environ['YQL_TOKEN']
    )
    spec_defaults = dict(
        pool_trees=["physical"],
        tentative_pool_trees=["cloud"]
    )
    job_templates = dict(
        tmp_root='//home/videoquality/vh_analytics/tmp',
        title='Error Cube'
    )
    cluster = base_cluster.env(
        yt_spec_defaults=spec_defaults,
        templates=job_templates
    )
    if args.pool:
        cluster = cluster.update(pool=args.pool)

    # 'from' is a Python keyword, so the attribute needs getattr.
    date_from = getattr(args, 'from')
    date_to = args.to

    if not (date_from and date_to):
        stat_headers = get_stat_headers()
        last_date_from_stat = get_dates_from_stat(
            headers=stat_headers,
            report=args.report,
            dimensions=[]
        )[-1]
        session_tables = get_driver(cluster).client.search(
            root='//home/videoquality/vh_analytics/strm_cube_2',
            path_filter=lambda path: path.endswith('/sessions')
        )
        dates = [
            day for day in sorted(get_date(t) for t in session_tables)
            if day > last_date_from_stat
        ]
    else:
        dates = date_range(date_from, date_to)

    for date in dates:
        print('processing {}'.format(date))
        process_date(
            date, args.report, cluster, pool=args.pool, debug=args.debug
        )


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
