#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import division
import os
import argparse
from nile.api.v1 import (
    clusters,
    filters as nf,
    extractors as ne,
    aggregators as na,
    statface as ns,
    with_hints,
    Record
)
import nile.files as nfi
from qb2.api.v1 import typing as qt, resources as qr
from pytils import (
    date_range,
    yt_get_date_from_table as get_date,
    get_dates_from_stat,
    get_stat_headers
)
import itertools
import copy


# Root path under which the input session tables and this job's output
# tables are addressed (see process_date).
VH_ROOT = '//home/videoquality/vh_analytics'

# Output schema of process_session: eight string dimensions plus the
# integer length of a single gap.
init_schema = {
    'channel_old': qt.String,
    'provider': qt.String,
    'ref_from': qt.String,
    'os_family': qt.String,
    'browser_name': qt.String,
    'country': qt.String,
    'gap_type': qt.String,
    'error_type': qt.String,
    'gap_value': qt.Integer,
}

# Output schema of AddQuantiles: the same dimensions, with the raw gap value
# replaced by the report date, a quantile label and the quantile estimate.
up_schema = copy.deepcopy(init_schema)
del up_schema['gap_value']
up_schema.update(
    fielddate=qt.String,
    quantile=qt.String,
    value=qt.Float,
)

# Quantile levels requested from na.quantile_estimate in process_date.
qs_tup = (
    0.1, 0.2, 0.3,
    0.4, 0.5, 0.6,
    0.7, 0.8, 0.9,
)


def is_bad(e):
    """Tell whether an error id counts as a "bad" error.

    An id is bad when it is exactly ``'Stalled_Other'`` or ends with
    ``'_fatal'``. Falsy ids (``None``, empty string) are never bad.
    """
    if not e:
        return False
    return e == 'Stalled_Other' or e.endswith('_fatal')


def get_country(region):
    """Return the short name of the first type-3 entry in ``region.path``.

    Falls back to ``'UNK'`` when the path contains no entry whose ``type``
    equals 3 (presumably the country level of the geobase hierarchy --
    confirm against the Geobase documentation).
    """
    matching_names = (
        node.short_name for node in region.path if node.type == 3
    )
    return next(matching_names, 'UNK')


def process_session(recs):
    """Mapper: turn session records into per-gap records for every slice.

    For each input record the "bad" errors (see ``is_bad``) are collected
    from the optional ``errors_all`` field. Every entry of
    ``rec.view_session`` is then turned into a timeline::

        SessionStart, <bad errors strictly inside the view>, SessionEnd

    and for every pair of neighbouring timeline points one gap is measured
    as the difference of their timestamps. Each gap is emitted once for
    every combination of "actual value" / ``'_total_'`` over the eight
    dimensions (channel, provider, ref_from, os_family, browser_name,
    country, gap_type, error_type), i.e. 256 records per gap.

    ``error_type`` is the label of the timeline point that closes the gap.
    ``gap_type`` is ``'before_first'`` for the very first gap of a record
    and ``'intermediate'`` for all the others.

    Raises:
        Exception: if a gap turns out to be negative; the message carries
            the two timeline points, the whole timeline and ``rec.vsid``.
    """
    geobase = qr.get('Geobase')
    for rec in recs:
        raw_errors = getattr(rec, 'errors_all', []) or []
        fatal_events = []
        for err in raw_errors:
            if is_bad(err['error_id']):
                fatal_events.append((err['timestamp'], err['error_id']))

        country = get_country(geobase.region_by_id(int(rec.reg)))

        # NOTE(review): this flag is reset once per record, not once per
        # view session, so only the first gap of the first view session is
        # labelled 'before_first' -- confirm this is the intended meaning.
        first = True
        for vs in rec.view_session:
            st = vs['_start']
            et = st + vs['view_duration']

            # Timeline of the view: start marker, the bad errors that fall
            # strictly between start and end (ordered), end marker.
            timeline = [(st, 'SessionStart')]
            timeline.extend(
                sorted(ev for ev in fatal_events if st < ev[0] < et)
            )
            timeline.append((et, 'SessionEnd'))

            for prev_point, next_point in zip(timeline, timeline[1:]):
                gap = next_point[0] - prev_point[0]
                if gap < 0:
                    raise Exception('{}, {}, {}, {}'.format(
                        next_point, prev_point, timeline, rec.vsid
                    ))

                if first:
                    gap_label = 'before_first'
                    first = False
                else:
                    gap_label = 'intermediate'

                slices = itertools.product(
                    (vs['view_channel_old'], '_total_'),
                    (rec.provider, '_total_'),
                    (rec.ref_from, '_total_'),
                    (rec.os_family, '_total_'),
                    (rec.browser_name, '_total_'),
                    (country, '_total_'),
                    (gap_label, '_total_'),
                    (next_point[1], '_total_')
                )
                for (channel, provider, ref_from, os_family,
                     browser, country_dim, gap_type, error_type) in slices:
                    yield Record(
                        channel_old=channel,
                        provider=provider,
                        ref_from=ref_from,
                        os_family=os_family,
                        # Only the browser gets a placeholder for empty
                        # values; other dimensions are passed through.
                        browser_name=browser or "unknown",
                        country=country_dim,
                        gap_type=gap_type,
                        error_type=error_type,
                        gap_value=gap
                    )


class AddQuantiles(object):
    """Mapper that unfolds the ``quantiles`` field into one row per quantile.

    Each incoming record is expected to carry a ``quantiles`` field that
    iterates as ``(level, estimate)`` pairs. For every pair a copy of the
    record is emitted without ``quantiles`` and with three extra fields:
    ``fielddate`` (the date given to the constructor), ``quantile``
    (``'q<level>'``) and ``value`` (the estimate).
    """

    def __init__(self, fd):
        # Stored as a string so it can be written straight into 'fielddate'.
        self.fd = str(fd)

    def __call__(self, recs):
        for rec in recs:
            estimates = {level: value for level, value in rec.quantiles}

            # Fields shared by all rows produced from this record.
            base = rec.to_dict()
            del base['quantiles']
            base['fielddate'] = self.fd

            for level, value in estimates.items():
                row = copy.copy(base)
                row.update(quantile='q{}'.format(level), value=value)
                yield Record.from_dict(row)


def process_date(date, report, cluster, proxy, args):
    """Compute the gap quantiles for one date and publish them to Statface.

    Args:
        date: date identifying both the input sessions table and the
            output table; also written to ``fielddate``.
        report: Statface report path to publish into.
        cluster: Nile cluster object used to build and run the job.
        proxy: value passed as ``proxy`` to ``remote_publish``.
        args: parsed command-line arguments; only ``args.redo`` is read.

    The output table is (re)built only when it does not exist yet or when
    ``args.redo`` is set; publishing to Statface happens in either case.
    Statface credentials are taken from the ``STAT_LOGIN`` and
    ``STAT_TOKEN`` environment variables.
    """
    report_table = '{}/time_before_error/{}'.format(VH_ROOT, date)

    must_build = not cluster.driver.exists(report_table) or args.redo
    if must_build:
        job = cluster.job()

        sessions = job.table(
            '{}/strm_video_nc/{}/sessions'.format(VH_ROOT, date)
        )

        # Stage 1: explode sessions into per-gap records for every slice.
        gaps = sessions.map(
            with_hints(output_schema=init_schema)(process_session),
            files=[
                nfi.StatboxDict('Geobasev6.bin'),
            ],
            intensity='ultra_cpu'
        )

        # Stage 2: estimate gap quantiles within each slice.
        per_slice = gaps.groupby(
            'channel_old', 'provider', 'ref_from', 'os_family',
            'browser_name', 'country', 'gap_type', 'error_type'
        )
        estimated = per_slice.aggregate(
            quantiles=na.quantile_estimate(
                'gap_value', qs_tup
            )
        )

        # Stage 3: one row per (slice, quantile), then store the result.
        unfolded = estimated.map(
            with_hints(output_schema=up_schema)(AddQuantiles(date))
        )
        unfolded.sort(
            'fielddate'
        ).put(
            report_table
        )

        job.run()

    stat_client = ns.StatfaceClient(
        proxy='upload.stat.yandex-team.ru',
        username=os.environ['STAT_LOGIN'],
        password=os.environ['STAT_TOKEN']
    )

    stat_report = ns.StatfaceReport().path(
        report
    ).scale('daily').replace_mask(
        'fielddate'
    ).client(
        stat_client
    )
    stat_report.remote_publish(
        proxy=proxy,
        table_path=report_table,
        async_mode=False,
        upload_config=False
    )


def main():
    """Command-line entry point: pick the dates to process and run them.

    Options:
        --report  Statface report path (default
                  ``Video/Others/Strm/errors_gaps``).
        --from, --to  explicit date range; used only when BOTH are given.
        --pool    optional pool applied to the cluster configuration.
        --redo    rebuild the per-date table even if it already exists.

    Without an explicit range, every date found under
    ``//logs/strm-access-log/1d`` that is later than the last date already
    present in the Statface report is processed.

    The cluster is selected from the ``YT_PROXY`` environment variable;
    ``YT_TOKEN`` and ``YQL_TOKEN`` supply the credentials.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--report', default='Video/Others/Strm/errors_gaps')
    parser.add_argument('--from')
    parser.add_argument('--pool')
    parser.add_argument('--to')
    parser.add_argument('--redo', action='store_true')
    args = parser.parse_args()

    # First component of the proxy host name, title-cased, is the name of
    # the cluster factory looked up on clusters.yql.
    proxy = os.environ['YT_PROXY'].split('.')[0].title()
    cluster_factory = getattr(clusters.yql, proxy)
    cluster = cluster_factory(
        token=os.environ['YT_TOKEN'],
        yql_token=os.environ['YQL_TOKEN'],
    ).env(
        templates={'title': 'errors_gaps | YQL'}
    )
    if args.pool:
        cluster = cluster.update(pool=args.pool)

    # 'from' is a Python keyword, so it cannot be read as args.from.
    from_ = getattr(args, 'from')
    to_ = args.to

    # NOTE(review): when only one of --from / --to is supplied it is
    # silently ignored and the dates come from Statface instead.
    if from_ and to_:
        dates = date_range(from_, to_)
    else:
        print('getting dates from stat')
        stat_headers = get_stat_headers()

        report_dates = get_dates_from_stat(
            headers=stat_headers,
            report=args.report,
            dimensions=[
                'os_family', 'browser_name', 'country', 'provider',
                'ref_from', 'channel_old', 'gap_type', 'error_type'
            ],
            add_cgi='&quantile=q0.5'
        )
        last_date_from_stat = report_dates[-1]

        # NOTE(review): availability is checked against the raw access log,
        # while process_date reads the strm_video_nc sessions table --
        # confirm both exist for the same dates.
        log_tables = cluster.driver.yt_driver.client.search(
            root='//logs/strm-access-log/1d',
            path_filter=lambda path: get_date(path)
        )
        available_dates = sorted(get_date(table) for table in log_tables)

        dates = [
            day for day in available_dates if day > last_date_from_stat
        ]

    publish_proxy = proxy.lower()
    for date in dates:
        print('processing {}'.format(date))
        process_date(date, args.report, cluster, publish_proxy, args)


# Run the job only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
