#!/usr/bin/env python
# -*- coding: utf-8 -*-

import nile
from nile.api.v1 import (
    clusters,
    Record,
    with_hints,
    aggregators as na,
    grouping as ng,
    filters as nf,
    extractors as ne,
    statface as ns
)

from pytils import date_range

from datetime import datetime as dt, timedelta

import json
import requests
import re
import urllib
import numpy as np
import itertools
import random
import argparse
import string
import os
import sys
import time
import math
import datetime
from collections import defaultdict, Counter

# Root path on the cluster where the per-date log tables live
# (overridable with --job_root).
job_root = '//home/videolog/strm_video'
# Cluster environment; assigned in main() before any job runs.
cluster = None


def get_stat_headers():
    """Build the Statface robot-auth HTTP headers from the environment.

    Reads ``STAT_LOGIN`` and ``STAT_TOKEN``; a missing variable raises
    KeyError.
    """
    login = os.environ['STAT_LOGIN']
    token = os.environ['STAT_TOKEN']
    return dict(StatRobotUser=login, StatRobotPassword=token)


def merge_names(a, b):
    """Pick one of two candidate names and return it normalized.

    *b* is preferred whenever it is non-empty; *a* is used when *b* is empty
    (or when both are equal).
    """
    if not b:
        return normalize_name(a)
    if a and a == b:
        return normalize_name(a)
    return normalize_name(b)


def _to_text(value):
    """Return *value* as unicode text: None -> u'', bytes decoded as UTF-8."""
    if value is None:
        return u''
    if isinstance(value, bytes):
        return value.decode('utf8')
    # Already text: calling .decode on it would fail (Python 3) or do an
    # implicit ASCII encode first (Python 2).
    return value


def merge_program_title(x):
    """Combine a program's series title and episode title for display.

    :param x: program dict with optional ``'program_title'`` and ``'title'``
        fields (UTF-8 bytes, text, or None).
    :return: ``u'<program_title>. <title>'`` when both are non-empty,
        otherwise just the title (``u''`` when it is missing or None).
    """
    pt = _to_text(x.get('program_title'))
    t = _to_text(x.get('title'))
    if pt and t:
        return u'{}. {}'.format(pt, t)
    return t


def transform_programs_value(dct):
    """Map each program of a channel entry to a short display title.

    :param dct: channel entry whose ``'programms'`` list holds program dicts;
        ``content_id`` is a UTF-8 byte string (it is ``.decode``d).
    :return: ``{content_id (unicode): truncated merged title}``.
    """
    titles = {}
    for program in dct['programms']:
        content_id = program['content_id'].decode('utf8')
        titles[content_id] = truncate_string(merge_program_title(program))
    return titles


def read_programs_table(table):
    """Load the per-channel program catalogue stored in *table*.

    Only the first row of the table is used; its ``'programs'`` field maps
    channel -> channel entry, and each entry is flattened with
    ``transform_programs_value``.

    :param table: path of the table, read through the global ``cluster``.
    :return: ``{channel: {content_id: display title}}``.
    :raises IndexError: if the table has no rows.
    """
    # Take just the first row instead of materializing the whole table.
    first_row = next(iter(cluster.read(table)), None)
    if first_row is None:
        raise IndexError('programs table {} is empty'.format(table))
    programs = first_row.to_dict()['programs']
    return {
        channel: transform_programs_value(entry)
        for channel, entry in programs.items()
    }


def _error_share(row, total, precision):
    """Return ``row['sessions'] / total`` rounded to *precision* places.

    :raises ValueError: if *total* is None (no ``'_total_'`` row) or zero.
    """
    if not total:
        raise ValueError(
            'cannot compute error_share with total={!r} for row {!r}'.format(
                total, row
            )
        )
    return round(row['sessions'] / total, precision)


def add_shares(groups, precision=3):
    """Nile reducer: attach ``error_share`` (fraction of sessions) to rows.

    Within each group the row whose ``error`` is ``'_total_'`` supplies the
    denominator; it is emitted with ``error_share = 1``. Every other row gets
    ``round(sessions / total_sessions, precision)``. Rows that arrive before
    the total row are buffered and emitted after the rest of the group.

    :param groups: iterable of ``(key, records)`` pairs.
    :param precision: decimal places for ``error_share``. Both paths now use
        it; previously rows after the total were rounded to 3 places and
        buffered rows to 4, so precision depended on row order.
    :raises ValueError: if a group has non-total rows but no ``'_total_'``
        row, or its total is zero.
    """
    for _key, records in groups:
        buffered = []
        total = None
        for rec in records:
            if rec.error == '_total_':
                total = float(rec['sessions'])
                res = rec.to_dict()
                res['error_share'] = 1
                yield Record(**res)
            elif total is None:
                # Denominator not seen yet; hold the row until it is.
                buffered.append(rec.to_dict())
            else:
                res = rec.to_dict()
                res['error_share'] = _error_share(res, total, precision)
                yield Record(**res)
        for res in buffered:
            res['error_share'] = _error_share(res, total, precision)
            yield Record(**res)


class SessionsMapper(object):
    """Nile mapper that explodes each session record into per-slice rows.

    For every input record whose ``date`` equals ``self.date`` it yields one
    row per combination of (os_family | '_total_') x (ref_from | '_total_') x
    (each testid | '_total_'). Each row carries the session's summed view
    durations (``tvt``, ``lvt``), ``sessions=1`` and the user hash, ready to
    be grouped and aggregated.
    """

    def __init__(self, date, programs=None):
        """
        :param date: only records whose ``date`` field equals this value are
            mapped.
        :param programs: optional ``{channel: {content_id: title}}`` lookup
            used by ``_get_p``; defaults to an empty mapping.
        """
        self.date = date
        # _get_p reads self.programs; it used to be never assigned, so every
        # call that reached the lookup raised AttributeError.
        self.programs = programs or {}

    def _get_p(self, channel, content_id, content_video_title):
        """Return a display name for a piece of content.

        Prefers the explicit video title, then the normalized title found in
        ``self.programs[channel][content_id]``, and finally ``"UNKNOWN"``.
        """
        if content_video_title:
            return content_video_title
        program_title = self.programs.get(channel, {}).get(content_id)
        if program_title:
            return normalize_name(program_title)
        return "UNKNOWN"

    def __call__(self, records):
        for rec in records:
            if rec['date'] != self.date:
                continue
            sessions = rec['view_session']
            # Treat a null slots_arr like a missing one instead of failing
            # to iterate None; empty test ids are dropped.
            testids = [x for x in (rec.get('slots_arr') or []) if x]
            if not testids:
                testids.append('_no_testids_')
            # Every session also contributes to the all-testids slice.
            testids.append('_total_')
            tvt = sum(x['view_duration'] for x in sessions)
            lvt = sum(x['log_view_duration'] for x in sessions)
            for os_family, ref_from, testid in itertools.product(
                (rec['os_family'], '_total_'),
                (rec['ref_from'], '_total_'),
                testids
            ):
                yield Record(
                    fielddate=rec['date'],
                    os_family=os_family,
                    ref_from=ref_from,
                    testid=testid,
                    tvt=tvt,
                    lvt=lvt,
                    sessions=1,
                    yandexuid=rec['ref_yandexuid_hash']
                )


def extract_parsed_as(field):
    """Return the only element of *field*, or ``"UNKNOWN"``.

    A None/empty field, or one with several elements, is ambiguous and
    yields the ``"UNKNOWN"`` placeholder.
    """
    items = field or []
    return items[0] if len(items) == 1 else "UNKNOWN"


def deutf8(dct):
    """Decode every UTF-8 byte-string value of *dct* to unicode, in place.

    Non-bytes values (numbers, None, already-decoded text) are left as-is.

    :param dct: mapping to modify.
    :return: the same mapping object.
    :raises UnicodeDecodeError: if a byte-string value is not valid UTF-8.
    """
    for field, value in dct.items():
        # bytes is str on Python 2, so this matches the old str check there;
        # on Python 3 it skips text, which has no .decode().
        if isinstance(value, bytes):
            dct[field] = value.decode('utf8')
    return dct


# Regex fragments that normalize_name() strips from titles (matched with
# re.I | re.U): part/episode/series/issue/season markers, spelled-out dates,
# dd.mm.yyyy dates and hh:mm times.
# The word markers use a negative lookahead (?![а-я]) so they must not be
# followed by another Cyrillic letter. The previous positive lookahead
# (?=[^а-я]) also demanded *some* following character, so a marker at the
# very end of a title (e.g. u'... 5 серия') was never removed.
BAD_SHIT = [
    u'част[ьи](?![а-я])', u'эпизод(?![а-я])', u'сери[яи](?![а-я])',
    u'выпуск(?![а-я])', u'сезон(?![а-я])',
    u'[0-9]+ (январ|феврал|март|апрел|ма|июн|июл|август|сентябр|октябр|ноябр|декабр).? [0-9]+( года)?',
    # "[.]" instead of "\." avoids an invalid escape sequence in a non-raw
    # literal; the regex is unchanged.
    u'[0-9]{1,2}[.][0-9]{1,2}[.][0-9]{4}',
    u'[0-9]{2}:[0-9]{2}'
]
# A number with an optional ordinal suffix (e.g. "-й", "-я"), optionally
# preceded by another number and an optional dash/"и" separator, then a
# space; prepended to a marker ("5 серия", "1 и 2 серии").
NUMBER_PREFIX = u'([0-9]+(-(ы?й|[ая]?я))? [-–—и]? ?)?[0-9]+(-(ы?й|[ая]?я))? '
# " <number>[ordinal suffix]" appended after a marker ("выпуск 12").
NUMBER_POSTFIX = u' [0-9]+(-(ы?й|[ая]?я))?'


def normalize_name(name):
    """Reduce a raw program/video title to a short canonical display name.

    Steps:
    - coerce to unicode (UTF-8 bytes decoded, invalid bytes replaced);
    - strip quote characters;
    - remove every ``BAD_SHIT`` marker together with an adjacent number;
    - truncate to 30 characters;
    - trim non-letter edges;
    - cut at the first punctuation mark.

    :param name: title as UTF-8 bytes, unicode text, or None.
    :return: the normalized unicode name, or ``u'BAD_NAME'`` if nothing
        usable remains.
    """
    name = name or u''
    # Only byte strings need decoding. Testing for bytes (rather than "not
    # unicode") is identical on Python 2, where bytes is str, and does not
    # depend on the Python-2-only ``unicode`` builtin.
    if isinstance(name, bytes):
        name = name.decode('utf8', errors='replace')
    name = name.strip()
    name = re.sub(u'["«»“”]', u'', name)
    flags = re.I | re.U
    for bad_shit in BAD_SHIT:
        # Remove "<number(s)> <marker>", then "<marker> <number>", then any
        # bare marker that is left.
        name = re.sub(NUMBER_PREFIX + bad_shit, u'', name, flags=flags)
        name = re.sub(bad_shit + NUMBER_POSTFIX, u'', name, flags=flags)
        name = re.sub(bad_shit, u'', name, flags=flags)
    name = truncate_string(name)
    name = intelligent_strip(name)
    name = till_first_punctuation(name)
    return name or u'BAD_NAME'


def intelligent_strip(name):
    """Trim *name* to the span from its first to its last letter.

    Letters are Latin a-z and Cyrillic а-я/ё, case-insensitively. Returns
    ``u"BAD_NAME"`` when *name* contains no letter at all.
    """
    starts = [
        m.start()
        for m in re.finditer(u'[а-яa-zё]', name, flags=(re.I | re.U))
    ]
    if len(starts) == 0:
        return u"BAD_NAME"
    first, last = starts[0], starts[-1]
    return name[first:last + 1]


def truncate_string(st_, thr=30):
    """Shorten *st_* to at most *thr* characters, preferring word boundaries.

    Strings already within *thr* are returned unchanged. Otherwise the
    longest proper prefix of whole words (re-joined with single spaces) that
    fits is returned. If not even the first word fits, the string is hard-cut
    to ``thr - 3`` characters plus ``'...'``.
    """
    if len(st_) <= thr:
        return st_
    words = st_.split()
    # Drop words from the end but stop at one word. The old loop went down to
    # zero words, so an over-long first word collapsed to u'' and the
    # ellipsis fallback below was unreachable.
    for count in range(len(words) - 1, 0, -1):
        candidate = u' '.join(words[:count])
        if len(candidate) <= thr:
            return candidate
    return u'{}...'.format(st_[:thr - 3])


def till_first_punctuation(name):
    """Return the part of *name* before its first ``,``/``.``/``?``/``!``.

    *name* is returned unchanged when it contains none of them.
    """
    first = re.search(u'[,\\.\\?!]', name)
    if first is None:
        return name
    return name[:first.start()]


def make_testid_report(date, report):
    """Build the per-testid cube for *date* and publish it to Statface.

    Runs a nile job over ``$job_root/<date>/sessions``. The job maps records
    with ``SessionsMapper`` and groups by (fielddate, os_family, ref_from,
    testid), summing ``tvt``/``lvt``/``sessions`` and estimating distinct
    users. The result is stored in ``$job_root/<date>/testids_report``, then
    uploaded to the Statface *report* with daily scale and
    ``replace_mask('fielddate')``.

    :param date: day to process (main passes ``format(date)``).
    :param report: Statface report path.
    """
    sessions_table = '$job_root/%s/sessions' % date
    cube_report_table = '$job_root/%s/testids_report' % date

    key_fields = ['fielddate', 'os_family', 'ref_from', 'testid']

    job = cluster.job()
    job.table(
        sessions_table
    ).map(
        SessionsMapper(format(date)), intensity='ultra_cpu'
    ).groupby(
        *key_fields
    ).aggregate(
        tvt=na.sum('tvt'),
        lvt=na.sum('lvt'),
        sessions=na.sum('sessions'),
        users=na.count_distinct_estimate('yandexuid'),
    ).sort(
        'fielddate'
    ).put(
        cube_report_table
    )
    job.run()

    client = ns.StatfaceClient(
        proxy='upload.stat.yandex-team.ru',
        username=os.environ['STAT_LOGIN'],
        password=os.environ['STAT_TOKEN']
    )

    rows = []
    skipped = 0
    for rec in cluster.read(cube_report_table):
        try:
            row = deutf8(rec.to_dict())
        except UnicodeDecodeError:
            # Best effort: a row with non-UTF-8 text is dropped rather than
            # failing the whole upload. The old bare except also hid every
            # other error (and KeyboardInterrupt) without a trace.
            skipped += 1
            continue
        if row:
            rows.append(row)
    if skipped:
        print('skipped {} rows with undecodable text'.format(skipped))

    ns.StatfaceReport().path(
        report
    ).scale('daily').replace_mask(
        'fielddate'
    ).client(
        client
    ).data(
        rows
    ).publish()


def get_date(s):
    """Extract the first ``YYYY-MM-DD`` date embedded in *s*.

    Returns a ``datetime.date``. Returns None when *s* is not a string, has
    no such substring, or the substring is not a valid calendar date.
    """
    try:
        found = re.search(r'[0-9]{4}-[0-9]{2}-[0-9]{2}', s)
        parsed = datetime.datetime.strptime(found.group(0), '%Y-%m-%d')
        return parsed.date()
    except (ValueError, TypeError, AttributeError):
        return None


# Sessions tables dated on or before this day are ignored when dates are
# discovered automatically.
# NOTE(review): presumably the first day with usable sessions data — confirm.
_MIN_SESSIONS_DATE = datetime.date(2018, 3, 14)


def _get_last_report_date(report):
    """Return the newest ``fielddate`` already published to *report*, or None.

    Requests the all-``_total_`` slice of the Statface report as JSON.
    Returns None when the report has no rows yet (this used to crash with
    IndexError).
    """
    headers = get_stat_headers()
    print('getting dates from report')
    dimensions = ['testid', 'os_family', 'ref_from']
    dim_totals = '&'.join('{}=_total_'.format(x) for x in dimensions)
    req = requests.get(
        'https://upload.stat.yandex-team.ru/{}?{}&_type=json'.format(
            report, dim_totals
        ),
        # NOTE(review): verify=False disables TLS certificate checks on a
        # request that carries credentials; consider pointing verify at the
        # internal CA bundle instead.
        headers=headers, verify=False
    )
    # Fail with the HTTP status instead of a confusing JSON/KeyError later.
    req.raise_for_status()
    print('parsing response')

    values = req.json()['values']
    if not values:
        return None
    newest = max(values, key=lambda x: x['fielddate'])
    # fielddate may carry a time part after a space; keep only the date.
    return get_date(newest['fielddate'].split(' ')[0])


def _list_available_dates():
    """Return the sorted dates of ``.../sessions`` tables under ``job_root``.

    Only dates after ``_MIN_SESSIONS_DATE`` are kept.
    """
    found = []
    for path in cluster.driver.client.search(
        root=job_root, node_type="table",
        path_filter=lambda x: x.endswith('/sessions')
    ):
        date = get_date(path)  # parse once per path, not three times
        if date and date > _MIN_SESSIONS_DATE:
            found.append(date)
    return sorted(found)


def main():
    """Command-line entry point: build and publish testid reports.

    With both ``--from`` and ``--to`` given, processes that date range.
    Otherwise processes every available sessions date newer than the last
    date already present in the ``--report`` Statface report.
    """
    global cluster
    global job_root
    parser = argparse.ArgumentParser()
    parser.add_argument('--pool')
    parser.add_argument('--parallel_operations_limit', type=int, default=10)
    parser.add_argument('--from', default=None)
    parser.add_argument('--job_root', default=None)
    parser.add_argument('--report', default='Video/Others/Strm/testids')
    parser.add_argument('--to', default=None)
    args = parser.parse_args()

    if args.job_root:
        job_root = args.job_root

    # "from" is a Python keyword, so it cannot be read as args.from.
    from_ = getattr(args, 'from')
    to_ = args.to

    cluster = clusters.yt.Hahn(token=os.environ['YT_TOKEN'], pool=args.pool).env(
        templates=dict(
            job_root=job_root
        ),
        parallel_operations_limit=args.parallel_operations_limit
    )

    if from_ and to_:
        dates_to_process = date_range(from_, to_)
    else:
        # Compare against the report we are about to publish to. This used to
        # be a hard-coded path, so a non-default --report was checked against
        # the wrong report's last date.
        last_date = _get_last_report_date(args.report)
        print('last date: {}'.format(last_date))

        available_dates = _list_available_dates()
        print('last available date: {}'.format(
            available_dates[-1] if available_dates else None
        ))

        if last_date:
            dates_to_process = [
                x for x in available_dates if x > last_date
            ]
        else:
            dates_to_process = available_dates

    print('dates to process: {}'.format(dates_to_process))

    for date in dates_to_process:
        print('running for {}'.format(date))
        make_testid_report(date=format(date), report=args.report)


if __name__ == '__main__':
    # Run only when executed as a script, not when imported.
    main()
