#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import division
import os
import codecs
import argparse
import itertools
import datetime
import json
import pdb
import copy
from sessions import get_programs
from collections import Counter
import math


def srt(dct):
    """Return the items of *dct* as a new list ordered by ``'start_time'``.

    *dct* is any iterable of mappings carrying a ``'start_time'`` key.  The
    sort is stable: items with equal start times keep their input order.
    """
    ordered = list(dct)
    ordered.sort(key=lambda item: item['start_time'])
    return ordered


def tts(i):
    """Format the Unix timestamp *i* as local time ``'YYYY-MM-DD HH:MM:SS'``.

    *i* may be anything ``int()`` accepts (int, float, numeric string);
    fractional seconds are truncated by ``int()``.
    """
    seconds = int(i)
    return '{:%Y-%m-%d %H:%M:%S}'.format(
        datetime.datetime.fromtimestamp(seconds)
    )


def round_minutes(m):
    """Round *m* down to the nearest multiple of 10 (e.g. 37 -> 30)."""
    tens = math.floor(m / 10.0)
    return int(tens) * 10


def _to_epoch(dt):
    """Return the Unix timestamp of the naive local-time datetime *dt* as int.

    Replaces ``int(dt.strftime('%s'))``: ``%s`` is a platform (glibc)
    extension that Python does not guarantee, while ``time.mktime`` performs
    the same local-time -> epoch conversion portably.
    """
    import time
    return int(time.mktime(dt.timetuple()))


def _write_json(obj, path):
    """Dump *obj* to *path* as indented UTF-8 JSON and close the file."""
    with codecs.open(path, 'w', 'utf8') as fh:
        json.dump(obj, fh, ensure_ascii=False, sort_keys=True, indent=4)


def _merged_programs(channel, schedules):
    """Return *channel*'s programs from every schedule, deduplicated, sorted.

    Each schedule is a mapping as returned by ``get_programs`` (channel ->
    dict with an optional ``'programs'`` list).  Programs that serialize to
    identical JSON are kept once; the result is ordered by ``start_time``.
    """
    prs = []
    for schedule in schedules:
        prs += srt(schedule.get(channel, {}).get('programs') or [])
    unique = {json.dumps(v, sort_keys=True): v for v in prs}
    return srt(unique.values())


def _segments(prs, th_low, th_high):
    """Split [th_low, th_high] at every program start/end inside the window."""
    points = {th_low, th_high}
    for program in prs:
        for t in (program['start_time'], program['end_time']):
            if th_low <= t < th_high:
                points.add(t)
    points = sorted(points)
    return list(zip(points[:-1], points[1:]))


def _count_segments(seg_dict, start, end, predicate):
    """Count segments starting in [start, end) whose coverage passes predicate."""
    return sum(
        1 for seg, covering in seg_dict.items()
        if start <= seg[0] < end and predicate(covering)
    )


def main():
    """Find schedule gaps and overlapping programs in the next ~24 hours.

    Programs for yesterday, today and tomorrow are fetched via
    ``get_programs``.  For every channel present today, the window starting
    at "now" (rounded down to 10 minutes) is split into segments at program
    boundaries.  A segment covered by no program is a ``gap``; one covered
    by several is a ``conflict``.

    Outputs:
      * ``--error_list``: JSON list of individual gaps/conflicts.
      * ``--report_data``: JSON list of per-channel/per-timeframe counts.
      * ``--report``: if given, the counts are published to Statface
        (credentials from ``STAT_LOGIN`` / ``STAT_TOKEN``).
      * ``--yt_root``: if non-empty, the error list is written to
        ``<yt_root>/<timestamp>`` on Hahn (token from ``YT_TOKEN``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--error_list')
    parser.add_argument('--report_data')
    parser.add_argument('--report')
    parser.add_argument('--yt_root', default="//home/videolog/mma-1684")
    args = parser.parse_args()

    # Round "now" down to a 10-minute boundary; it is both the report's
    # fielddate and the start of the analysed window.
    now = datetime.datetime.now()
    datetime_ = now.replace(
        microsecond=0, second=0, minute=round_minutes(now.minute)
    )
    date = datetime_.date()
    date_y = date - datetime.timedelta(days=1)
    date_t = date + datetime.timedelta(days=1)

    p_ = get_programs(date.strftime('%Y-%m-%d'))
    p_y = get_programs(date_y.strftime('%Y-%m-%d'))
    p_t = get_programs(date_t.strftime('%Y-%m-%d'))
    fielddate = datetime_.strftime('%Y-%m-%d %H:%M:%S')
    # Analysed window: [th_low, th_high] spans 24 hours minus one second.
    th_low = _to_epoch(datetime_)
    th_high = _to_epoch(datetime_ + datetime.timedelta(seconds=86399))

    errors = Counter()
    error_list = []

    for channel in p_:
        # Neighbouring days are included so programs crossing midnight
        # cover the window edges.
        prs = _merged_programs(channel, [p_y, p_, p_t])

        seg_dict = {}
        for seg in _segments(prs, th_low, th_high):
            # NOTE(review): coverage is probed at seg[0] + 1; for a 1-second
            # segment that point already lies in the next segment — confirm
            # whether probing at seg[0] was intended.
            point = seg[0] + 1
            conflicting_programs = [
                x for x in prs if x['start_time'] <= point < x['end_time']
            ]
            seg_dict[seg] = len(conflicting_programs)
            if len(conflicting_programs) > 1 and args.debug:
                pdb.set_trace()
            if not conflicting_programs:
                error_list.append(
                    {
                        'channel': channel,
                        'error_type': u'gap',
                        'segment': u'{} — {}'.format(tts(seg[0]), tts(seg[1]))
                    }
                )
            elif len(conflicting_programs) > 1:
                error_list.append(
                    {
                        'channel': channel,
                        'error_type': u'conflict',
                        'segment': u'{} — {}'.format(tts(seg[0]), tts(seg[1])),
                        'conflicting_programs': [
                            x['title'] for x in conflicting_programs
                        ]
                    }
                )

        for hours in [1, 2, 4, 8, 24]:
            # Half-open window: a segment starting exactly at the N-hour
            # mark belongs to the following hours, not to the first N.
            window_end = th_low + hours * 3600
            timeframe = '{:02}hrs'.format(hours)
            counts = [
                ('gap', _count_segments(
                    seg_dict, th_low, window_end, lambda n: n == 0)),
                ('conflict', _count_segments(
                    seg_dict, th_low, window_end, lambda n: n > 1)),
            ]
            for error_type, count in counts:
                # Accumulate into the per-channel/per-type cell and into the
                # '_total_' roll-ups along both dimensions.
                for comb in itertools.product(
                    [channel, '_total_'], [error_type, '_total_']
                ):
                    errors[(fielddate, comb[0], comb[1], timeframe)] += count

    columns = ['fielddate', 'channel', 'error_type', 'timeframe']
    report_data = []
    for key, count in errors.items():
        res = dict(zip(columns, key))
        res['error_count'] = count
        report_data.append(res)

    _write_json(error_list, args.error_list)
    _write_json(report_data, args.report_data)

    if args.report:
        from nile.api.v1 import statface as ns
        client = ns.StatfaceClient(
            proxy='upload.stat.yandex-team.ru',
            username=os.environ['STAT_LOGIN'],
            password=os.environ['STAT_TOKEN']
        )

        ns.StatfaceReport().path(
            args.report
        ).scale('minutely').replace_mask(
            'fielddate'
        ).client(
            client
        ).data(
            report_data
        ).publish()
    if args.yt_root:
        from nile.api.v1 import clusters, Record
        hahn = clusters.yt.Hahn(token=os.environ['YT_TOKEN'])
        hahn.write(
            path='{}/{}'.format(
                args.yt_root, datetime_.strftime('%Y-%m-%dT%H:%M:%S')
            ),
            records=[Record(**x) for x in error_list]
        )


# Run the check only when executed as a script, not on import.
if __name__ == '__main__':
    main()
