#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import division
import sys
import os
import codecs
import argparse
from nile.api.v1 import (
    clusters,
    filters as nf,
    extractors as ne,
    aggregators as na,
    grouping as ng,
    statface as ns,
    Record
)
import json
from collections import Counter
import itertools


# Input table read by main(); the date is hard-coded into the path.
table = '//home/videolog/strm_video/2018-06-14/sessions'
# table = '//home/videolog/tmp/refreshes_debug_1'
# Largest allowed gap between the end_time of one record and the start_time
# of the next before get_refreshes() starts a new session.
# NOTE(review): presumably seconds (30 minutes) - confirm the timestamp units.
TIMEOUT = 30 * 60


def _copy_errors(errors):
    """Return an independent ``{error_name: [timestamps]}`` copy of *errors*."""
    return {
        name: list(stamps or ())
        for name, stamps in (errors or {}).items()
    }


def merge_ecits(old, new):
    """Merge the error timestamps of *new* into *old*, modifying *old* in place.

    Both arguments are used as ``{content_id: {error_name: [timestamps]}}``.
    For a content id / error name present on both sides the timestamp lists
    are united, de-duplicated and sorted; anything present only in *new* is
    copied over.

    *new* is never modified and *old* never keeps references to containers
    owned by *new*.  Storing them by reference meant that a later merge into
    *old* silently rewrote the record that originally supplied the data.

    Returns None.
    """
    for vcid, new_errors in new.items():
        old_errors = old.get(vcid)
        if old_errors is None:
            # First time this content id is seen: take a private copy.
            old[vcid] = _copy_errors(new_errors)
            continue
        for error_name, stamps in (new_errors or {}).items():
            if error_name in old_errors:
                old_errors[error_name] = sorted(
                    set(old_errors[error_name]) | set(stamps or ())
                )
            else:
                old_errors[error_name] = list(stamps or ())


def process_session(session, error_window=30):
    """Count refreshes inside one session and attribute them to errors.

    *session* is an iterable of record dicts.  Every record must provide
    ``view_session`` (a list of view dicts) and ``vsid``; it may provide
    ``error_content_ids_ts``, used as
    ``{content_id: {error_name: [timestamps]}}``.

    All views of the session are concatenated.  A refresh is counted for
    every pair of consecutive views that carry a different ``vsid`` but the
    same ``view_channel``.  The refresh is attributed to the error of the
    previous view's ``video_content_id`` whose latest timestamp lies within
    ``[_start - error_window, _start]`` of the new view, or to
    ``'_no_errors_'`` when no error fits.

    Yields one ``Record`` per combination of
    (os_family | '_total_') x (browser_name | '_total_') x
    (provider | '_total_') x error key, with ``refreshes`` set to the counter
    value and ``sessions`` set to 1.  The three dimensions are taken from the
    last record of the session.  An empty session yields nothing.

    ``error_window`` uses the same units as the timestamps.
    """
    merged_errors = {}
    views = []
    last_rec = None
    for rec in session:
        last_rec = rec
        rec_errors = rec.get('error_content_ids_ts')
        if rec_errors:
            # Merge a private copy: merge_ecits may keep references to the
            # containers it receives, and the records' own error maps must
            # not be rewritten with data coming from other records.
            merge_ecits(merged_errors, {
                vcid: {
                    name: list(stamps or ())
                    for name, stamps in (by_name or {}).items()
                }
                for vcid, by_name in rec_errors.items()
            })
        # Tag each view with the vsid of its record (in place, as before) so
        # that record boundaries stay detectable after concatenation.
        for view in rec['view_session']:
            view['vsid'] = rec['vsid']
        views.extend(rec['view_session'])

    if last_rec is None:
        # Nothing to report; the dimensions below would be undefined.
        return

    refreshes = Counter()
    # NOTE(review): refresh detection is skipped entirely for sessions that
    # recorded no errors at all, so such sessions never contribute to
    # '_no_errors_'.  Kept as is - confirm this is intended.
    if merged_errors:
        for prev, cur in zip(views, views[1:]):
            if (
                cur['vsid'] == prev['vsid'] or
                cur['view_channel'] != prev['view_channel']
            ):
                continue
            start_ts = cur['_start']
            last_error_ts = 0
            last_error_type = ''
            errors = merged_errors.get(prev['video_content_id']) or {}
            for err, stamps in errors.items():
                fitting = [
                    ts for ts in stamps
                    if start_ts - error_window <= ts <= start_ts
                ]
                if not fitting:
                    continue
                # max() rather than the last element: timestamp lists are
                # only guaranteed to be sorted after a merge.
                latest = max(fitting)
                if latest > last_error_ts:
                    last_error_ts = latest
                    last_error_type = err
            refreshes[last_error_type or '_no_errors_'] += 1

    total = sum(refreshes.values())
    with_errors = sum(
        count for name, count in refreshes.items() if name != '_no_errors_'
    )
    if with_errors:
        refreshes['_all_errors_'] = with_errors
    if not total:
        # Make sure a '_no_errors_' row exists even without refreshes.
        refreshes['_no_errors_'] = 0
    refreshes['_total_'] = total

    for os_family, browser_name, provider, error in itertools.product(
        (last_rec['os_family'], '_total_'),
        (last_rec['browser_name'], '_total_'),
        (last_rec['provider'], '_total_'),
        sorted(refreshes)
    ):
        yield Record(
            os_family=os_family,
            browser_name=browser_name,
            provider=provider,
            error=error,
            refreshes=refreshes[error],
            sessions=1,
        )


def _flush_session(session, out1, out2):
    """Emit the stats of one session and, if it had refreshes, its raw rows."""
    has_refreshes = False
    for stat in process_session(session):
        out1(stat)
        if stat.refreshes:
            has_refreshes = True
    if has_refreshes:
        for raw in session:
            out2(Record(**raw))


def get_refreshes(groups, out1, out2):
    """Split each group's records into sessions and emit refresh statistics.

    For every ``(key, recs)`` pair in *groups* the records are cut into
    sessions: a new session starts when ``start_time`` of a record is more
    than ``TIMEOUT`` later than ``end_time`` of the previous one.

    ``out1`` receives every record produced by ``process_session``.
    ``out2`` receives the raw records of each session for which at least one
    produced record has a non-zero ``refreshes`` value.

    The same rule is applied to sessions closed in the middle of a group and
    to the last session of a group.  Previously the last session was dumped
    only for ``error == 'Stalled_First'``; such records always have
    ``refreshes >= 1``, so the rule used here covers that case as well.
    NOTE(review): narrow the rule in ``_flush_session`` if only
    'Stalled_First' sessions are wanted in the debug output.

    NOTE(review): the splitting assumes records arrive ordered by start_time
    within a group - confirm the caller guarantees that.
    """
    for key, recs in groups:
        current_session = []
        prev_end_time = 0
        for rec in recs:
            if (
                prev_end_time and
                (rec['start_time'] - prev_end_time) > TIMEOUT
            ):
                _flush_session(current_session, out1, out2)
                current_session = []
            current_session.append(rec.to_dict())
            prev_end_time = rec['end_time']
        if current_session:
            _flush_session(current_session, out1, out2)


def add_shares(groups):
    """Add ``error_share`` to every record of each group.

    Within a group the record with ``error == '_total_'`` supplies the
    denominator: its ``sessions`` value.  That record gets
    ``error_share = 1.0``; every other record gets ``sessions / total``
    rounded to four decimal places.

    The records of a group are buffered so that the same precision is used
    no matter where the '_total_' record appears in the input (previously
    records after it were rounded to 3 digits and records before it to 4).
    Records are yielded in input order.

    Raises:
        ValueError: if a share cannot be computed because the group has no
            '_total_' record or its ``sessions`` value is zero.
    """
    share_digits = 4
    for key, records in groups:
        rows = []
        total = None
        for rec in records:
            row = rec.to_dict()
            if row['error'] == '_total_':
                total = float(row['sessions'])
            rows.append(row)

        for row in rows:
            if row['error'] == '_total_':
                row['error_share'] = float(1)
            else:
                try:
                    share = row['sessions'] / total
                except (TypeError, ZeroDivisionError):
                    # TypeError: no '_total_' record, total is still None.
                    raise ValueError(
                        'cannot compute error_share for {} with total={!r}'
                        .format(json.dumps(row), total)
                    )
                row['error_share'] = float(round(share, share_digits))
            yield Record(**row)


def main():
    """Run the refresh pipeline on YT and publish the result to Statface.

    Steps:
      1. Read ``table``, keep rows with a truthy ``ref_yandexuid_hash`` and
         reduce them per ``ref_yandexuid_hash`` with ``get_refreshes`` into
         a statistics stream and a debug stream.
      2. Store the debug stream, sum the statistics per
         (os_family, browser_name, provider, error), add shares with
         ``add_shares`` and store the result as ``refreshes_non_schematic``.
      3. Copy that table to ``refreshes_schematic`` with an explicit schema.
      4. Read ``refreshes_schematic`` and publish it to the Statface report.

    Credentials are taken from the YT_TOKEN, STAT_LOGIN and STAT_TOKEN
    environment variables.
    """
    # NOTE(review): no command-line arguments are parsed; the output folder
    # and the report date below are hard-coded.
    job_root = '//home/videolog/tmp/refreshes_folder/r1'
    debug_path = '{}/debug'.format(job_root)
    non_schematic_path = '{}/refreshes_non_schematic'.format(job_root)
    schematic_path = '{}/refreshes_schematic'.format(job_root)

    hahn = clusters.yt.Hahn(token=os.environ['YT_TOKEN'])

    job = hahn.job()

    identified = job.table(table).filter(
        nf.custom(bool, 'ref_yandexuid_hash')
    )
    # NOTE(review): get_refreshes compares consecutive start/end times, but
    # no ordering inside a group is requested here - confirm the input order.
    refreshes, debug = identified.groupby(
        'ref_yandexuid_hash'
    ).reduce(get_refreshes)

    debug.sort('ref_yandexuid_hash', 'start_time').put(debug_path)

    summed = refreshes.groupby(
        'os_family', 'browser_name', 'provider', 'error'
    ).aggregate(
        refreshes=na.sum('refreshes'),
        sessions=na.sum('sessions')
    )
    with_shares = summed.groupby(
        'os_family', 'browser_name', 'provider'
    ).reduce(add_shares)
    # NOTE(review): this date duplicates the one embedded in ``table``.
    dated = with_shares.project(
        ne.all(), fielddate=ne.const('2018-06-14')
    )
    dated.sort(
        'os_family', 'browser_name', 'provider', 'error'
    ).put(non_schematic_path)

    job.run()

    # NOTE(review): a failure here is only printed; the code below then
    # reads refreshes_schematic regardless, which may be a table left over
    # from an earlier run - confirm this best-effort behaviour is wanted.
    try:
        job = hahn.job()

        job.table(non_schematic_path).put(
            schematic_path,
            schema=dict(
                fielddate=str,
                os_family=str,
                browser_name=str,
                provider=str,
                error=str,
                refreshes=int,
                sessions=int,
                error_share=float
            )
        )

        job.run()
    except Exception as e:
        print(e)

    client = ns.StatfaceClient(
        proxy='upload.stat.yandex-team.ru',
        username=os.environ['STAT_LOGIN'],
        password=os.environ['STAT_TOKEN']
    )

    report = ns.StatfaceReport().path(
        'Video/Others/Strm/refreshes'
    ).scale('daily').replace_mask('fielddate').client(client)

    rows = [x.to_dict() for x in hahn.read(schematic_path)]
    report.data(rows).publish()


# Run the pipeline only when the file is executed as a script.
if __name__ == "__main__":
    main()
