#-*- coding: UTF-8 -*-
# Number of days following a recommendations snapshot date whose user-session
# tables main() requires and processes.
FUTURE_DAYS_COUNT = 7
# Number of per-day processed tables, counted backwards starting from the
# snapshot date itself, that main() concatenates as the user's viewing history.
PAST_DAYS_COUNT = 28

import requests
import nile
import argparse
import time
from nile.api.v1 import (
    filters as nf,
    aggregators as na,
    extractors as ne,
    statface as ns,
    clusters,
    Record
)
from qb2.api.v1 import (
    extractors as se,
    filters as sf
)
import yt.wrapper as yt
from datetime import datetime
from datetime import timedelta
import json

class calc_user_stats(object):
    """Reducer that condenses one user's session rows into per-user counters.

    Called with an iterable of ``(key, recs)`` groups. Each group is parsed
    with ``libra.ParseSessionWithFat`` and its requests are counted by type.
    Groups that libra fails to parse are skipped.

    Yields one ``Record`` for every user with at least one video request:
        uid, corrected_uid          -- raw key and key without its prefix
        video_request_count         -- requests that are TVideoRequestProperties
        web_request_count           -- requests that are TWebRequestProperties
        video_morda_request_count   -- video requests that are also
                                       TVideoPortalRequestProperties
        film_request_count          -- video requests about a "Film" object
        serial_request_count        -- video requests with "vserial" in RelevValues
        tvt                         -- accumulated viewing time over Film objects
        watched_objects             -- list of dicts describing distinct Film
                                       objects whose viewing time passed the
                                       threshold
    """

    def __call__(self, groups):
        # Imported inside the call rather than at module level; main() attaches
        # libra.so to the reduce operation, so presumably the module is only
        # importable on the cluster -- TODO confirm.
        import libra

        # Threshold applied to PlayingDuration / heartbeat Ticks and to the
        # resulting per-result viewing time. Presumably seconds -- TODO confirm.
        min_view = 45

        for key, recs in groups:
            uid = key.key
            # Strip the identifier-type prefix. The check is anchored to the
            # start of the string: a plain substring test would also fire on
            # ids that merely contain "y" and chop off an unrelated character.
            if uid.startswith("uu/"):
                corrected_uid = uid[3:]
            elif uid.startswith("y"):
                corrected_uid = uid[1:]
            else:
                corrected_uid = uid

            try:
                session = libra.ParseSessionWithFat(recs, './blockstat.dict')
            except Exception:
                # Best effort: an unparsable session is dropped, but
                # KeyboardInterrupt / SystemExit are no longer swallowed.
                continue

            watched_objects = []
            seen_object_ids = set()
            video_request_count = 0
            web_request_count = 0
            video_morda_request_count = 0
            film_request_count = 0
            serial_request_count = 0
            tvt = 0

            for r in session:
                if r.IsA("TWebRequestProperties"):
                    web_request_count += 1
                if not r.IsA("TVideoRequestProperties"):
                    continue
                video_request_count += 1
                if r.IsA("TVideoPortalRequestProperties"):
                    video_morda_request_count += 1
                if not r.IsA("TMiscRequestProperties"):
                    continue

                props = r.SearchPropsValues
                object_type = props.get("REPORT.object_type", "")
                object_id = props.get("REPORT.object_id", "")

                # NOTE(review): comparing object_id with "Film" is kept from
                # the original code; it looks like another property may have
                # been intended -- confirm.
                if object_type == "Film" or object_id == "Film":
                    film_request_count += 1
                if "vserial" in r.RelevValues:
                    serial_request_count += 1

                # Viewing time is only collected for identified Film objects.
                if object_type != "Film" or object_id == "":
                    continue

                for block in r.GetMainBlocks():
                    result = block.GetMainResult()
                    duration = r.FindVideoDurationInfo(result)
                    heartbeat = r.FindVideoHeartbeat(result, 'ANY')

                    url_tvt = 0
                    if duration and duration.PlayingDuration > min_view:
                        # Never credit more than the video's own length.
                        url_tvt = min(duration.PlayingDuration, duration.Duration)
                    if heartbeat and heartbeat.Ticks > min_view:
                        url_tvt = max(url_tvt, heartbeat.Ticks)
                    tvt += url_tvt

                    if url_tvt > min_view and object_id not in seen_object_ids:
                        seen_object_ids.add(object_id)
                        watched_objects.append({
                            "object_id": object_id,
                            "object_type": object_type,
                            "object_subtype": props.get("REPORT.object_subtype", ""),
                            "object_tags": props.get("REPORT.object_tags", ""),
                            "winner_types": props.get("REPORT.winner_types", ""),
                        })

            if video_request_count > 0:
                yield Record(watched_objects=watched_objects, uid=uid,
                             corrected_uid=corrected_uid,
                             video_request_count=video_request_count,
                             web_request_count=web_request_count,
                             video_morda_request_count=video_morda_request_count,
                             film_request_count=film_request_count,
                             serial_request_count=serial_request_count,
                             tvt=tvt)

class user_stats_aggregator(object):
    """Reducer that merges several per-day stat records of one user into one.

    Args:
        watched_objects_column_name: name of the output column that receives
            the merged list of watched objects. The *input* records are always
            read from their ``"watched_objects"`` field.

    Called with an iterable of ``(key, recs)`` groups, where ``key`` carries
    ``corrected_uid``. Yields one ``Record`` per group with ``uid`` set to that
    value, the summed request counters and ``tvt``, and the de-duplicated
    (by ``object_id``, first occurrence wins) list of watched objects.
    """

    # Numeric fields that are summed across the records of a group; a field
    # missing from a record counts as 0.
    _SUMMED_FIELDS = (
        "video_request_count",
        "web_request_count",
        "video_morda_request_count",
        "film_request_count",
        "serial_request_count",
        "tvt",
    )

    def __init__(self, watched_objects_column_name="watched_objects"):
        self.watched_objects_column_name = watched_objects_column_name

    def __call__(self, groups):
        for key, recs in groups:
            totals = dict.fromkeys(self._SUMMED_FIELDS, 0)
            watched_objects = []
            # A set keeps the duplicate check O(1) per object instead of
            # rescanning a growing list of ids.
            seen_object_ids = set()

            for rec in recs:
                for field in self._SUMMED_FIELDS:
                    totals[field] += rec.get(field, 0)
                for object_info in rec["watched_objects"]:
                    object_id = object_info["object_id"]
                    if object_id not in seen_object_ids:
                        seen_object_ids.add(object_id)
                        watched_objects.append(object_info)

            totals[self.watched_objects_column_name] = watched_objects
            yield Record(uid=key["corrected_uid"], **totals)

def recommendations_reformatter(recs):
    """Mapper that flattens raw recommendation rows into per-uid result rows.

    Each input row carries a JSON document in ``value``; its
    ``netflix.categories`` list is converted into a list of
    ``{"predicted_objects": [...], "type": <category name>}`` dicts, where
    every predicted object keeps ``object_id`` (taken from ``ontoid``),
    ``thumb_url`` and ``title``.

    One ``Record(uid, results)`` is yielded per target uid, with ``results``
    serialized back to JSON. Rows whose ``key`` contains ``"cid_"`` fan out to
    every entry of ``aliases``; all other rows use ``key`` itself as the uid.

    Raises KeyError for a numeric category type that has no known name.
    """
    import json

    category_names = {
        0: "CATEG_FILM",
        1: "CATEG_SERIES",
        2: "CATEG_ANIM_FILM",
        3: "CATEG_ANIM_SERIES",
        4: "CATEG_TV_SHOW",
        5: "CATEG_MIXED",
        6: "CATEG_WATCHED",
    }

    for row in recs:
        payload = json.loads(row["value"])
        shelves = [
            {
                "predicted_objects": [
                    {"object_id": query["ontoid"],
                     "thumb_url": query["thumb_url"],
                     "title": query["title"]}
                    for query in category["queries"]
                ],
                "type": category_names[category["type"]],
            }
            for category in payload["netflix"]["categories"]
        ]

        row_key = row["key"]
        if "cid_" in row_key:
            targets = list(row["aliases"])
        else:
            targets = [row_key]

        # The payload is the same for every target, so serialize it once.
        serialized = json.dumps(shelves)
        for target in targets:
            yield Record(uid=target, results=serialized)

class prediction_quality_calculator(object):
    """Reducer that scores a user's recommendations against what they watched.

    Args:
        cold_start: JSON string with the fallback recommendations (same layout
            as the ``results`` column), or None when no fallback is available.
        cold_start_mode: when True, the fallback replaces the user's own
            recommendations unconditionally; otherwise it is only used for
            users who have none.

    Called with ``(key, recs)`` groups keyed by ``uid``. The records of a
    group may carry ``watched_objects``, ``watched_objects_by_last_4_weeks``
    and/or ``results`` (a JSON string). The file ``objects_data.json`` (object
    id -> object type) is read from the working directory on every call.

    For every recommended category with at least one relevant watched object a
    ``Record`` is yielded with recall/precision over the whole list and over
    its first 8 and first 20 entries. "Relevant" means the watched object is
    either of the same category or present among the category's predictions.
    Users without watched objects or without any predictions yield nothing.
    """

    def __init__(self, cold_start, cold_start_mode=False):
        self.cold_start = cold_start
        self.cold_start_mode = cold_start_mode

    def __call__(self, groups):
        import json

        film_types = ('OTYPE_FILM', 'OTYPE_FILM_FILM', 'OTYPE_FILM_FILM_SERIES')
        series_types = ('OTYPE_FILM_SERIES',)

        def get_category(object_info, objects_data):
            """Return the CATEG_* name of a watched object, or "" if unknown."""
            category = ""
            object_type = objects_data.get(object_info["object_id"])
            if object_type in film_types:
                if "Animation" in object_info["object_tags"]:
                    category = "CATEG_ANIM_FILM"
                else:
                    category = "CATEG_FILM"
            elif object_type in series_types:
                if "Animation" in object_info["object_tags"]:
                    category = "CATEG_ANIM_SERIES"
                else:
                    category = "CATEG_SERIES"
            # NOTE(review): the original also required "Animation" not to be
            # in the tags, which is always true once the tags equal
            # "TVprogram". A substring test on "TVprogram" may have been
            # intended -- confirm before changing the comparison.
            if object_info["object_tags"] == "TVprogram":
                category = "CATEG_TV_SHOW"
            return category

        def is_categories_equal(category_first, category_second):
            """CATEG_MIXED matches every category, otherwise names must be equal."""
            return (category_first == category_second
                    or category_first == "CATEG_MIXED"
                    or category_second == "CATEG_MIXED")

        # Context manager so the file handle is closed deterministically.
        with open("objects_data.json", "r") as objects_file:
            objects_data = json.loads(objects_file.read())

        for key, recs in groups:
            have_cold_start = False
            watched_objects = []
            predicted_objects = []
            watched_objects_by_last_4_weeks = []
            for rec in recs:
                if "watched_objects" in rec:
                    watched_objects = rec["watched_objects"]
                if "results" in rec:
                    predicted_objects = json.loads(rec["results"])
                if "watched_objects_by_last_4_weeks" in rec:
                    watched_objects_by_last_4_weeks = rec["watched_objects_by_last_4_weeks"]

            if self.cold_start_mode or len(predicted_objects) == 0:
                # cold_start is None when the cold-start table had no rows;
                # treat that as "no predictions" instead of crashing in
                # json.loads(None).
                if self.cold_start is None:
                    predicted_objects = []
                else:
                    predicted_objects = json.loads(self.cold_start)
                have_cold_start = True
            if len(watched_objects) == 0 or len(predicted_objects) == 0:
                continue

            for category_info in predicted_objects:
                category_name = category_info["type"]
                category_predictions = category_info["predicted_objects"]
                succesful_predicted = 0
                succesful_predicted_top_8 = 0
                succesful_predicted_top_20 = 0
                intersected_object_ids = []
                intersected_object_infos = []
                intersected_past_object_infos = []

                # Watched objects that belong to this category.
                for object_info in watched_objects:
                    object_category = get_category(object_info, objects_data)
                    if is_categories_equal(category_name, object_category):
                        intersected_object_ids.append(object_info["object_id"])
                        intersected_object_infos.append(object_info)
                for object_info in watched_objects_by_last_4_weeks:
                    object_category = get_category(object_info, objects_data)
                    if is_categories_equal(category_name, object_category):
                        intersected_past_object_infos.append(object_info)

                # Predictions that were actually watched; a hit also counts as
                # relevant even if its category did not match above.
                for i, category_object in enumerate(category_predictions):
                    for object_info in watched_objects:
                        if object_info["object_id"] != category_object["object_id"]:
                            continue
                        if object_info["object_id"] not in intersected_object_ids:
                            intersected_object_ids.append(object_info["object_id"])
                            intersected_object_infos.append(object_info)
                        succesful_predicted += 1
                        if i < 8:
                            succesful_predicted_top_8 += 1
                        if i < 20:
                            succesful_predicted_top_20 += 1

                intersection_count = len(intersected_object_ids)
                if intersection_count == 0:
                    continue

                # An empty prediction list cannot contain hits, so precision
                # is 0 rather than a ZeroDivisionError.
                predicted_count = len(category_predictions)
                if predicted_count:
                    precision = succesful_predicted / float(predicted_count)
                else:
                    precision = 0.

                relevant = float(intersection_count)
                # NOTE(review): the output column is spelled
                # "watched_object_by_last_4_weeks" (singular "object"), unlike
                # the input column; kept as-is because it is part of the
                # output table schema.
                yield Record(uid=key['uid'],
                             watched_objects=intersected_object_infos,
                             watched_object_by_last_4_weeks=intersected_past_object_infos,
                             predicted_objects=predicted_objects,
                             have_cold_start=have_cold_start,
                             type=category_name,
                             recall=succesful_predicted / relevant,
                             precision=precision,
                             recall_top_8=succesful_predicted_top_8 / relevant,
                             precision_top_8=succesful_predicted_top_8 / 8.,
                             recall_top_20=succesful_predicted_top_20 / relevant,
                             precision_top_20=succesful_predicted_top_20 / 20.)

def metrics_aggregator(groups):
    """Reducer that averages the per-user quality metrics of each category.

    Called with ``(key, recs)`` groups keyed by ``type``. Yields one ``Record``
    per group holding ``type`` and the arithmetic mean of ``recall``,
    ``precision`` and their ``_top_8`` / ``_top_20`` variants over the group's
    records.
    """
    metric_names = (
        "recall",
        "precision",
        "recall_top_8",
        "precision_top_8",
        "recall_top_20",
        "precision_top_20",
    )
    for key, recs in groups:
        sums = dict.fromkeys(metric_names, 0.)
        count = 0.
        for rec in recs:
            count += 1
            for name in metric_names:
                sums[name] += rec[name]
        group_type = key["type"]
        averages = dict((name, sums[name] / count) for name in metric_names)
        yield Record(type=group_type, **averages)

def grep_cold_start(recs):
    """Mapper that keeps only the rows whose ``key`` is the string ``'0'``.

    Each surviving row is re-emitted as ``Record(value=..., key='0')``; every
    other column is dropped. Rows without a ``key`` are filtered out.
    """
    for row in recs:
        if row.get('key') != '0':
            continue
        yield Record(value=row['value'], key='0')

# Directory listed by main(); children whose names parse as %Y%m%d are treated
# as dated recommendation snapshots.
RECOMMENDATIONS_PREFIX = '//home/videoindex/recommender/backup/vitrina'
# A recommendations table path is assembled in main() as
# PREFIX + "/" + <snapshot> + SUFFIX + <version> + SUFFIX_SUFFIX.
RECOMMENDATIONS_SUFFIX = '/recommendations_merged.'
# Versions evaluated by main(). "filter_cold_start" is handled specially there:
# it reads the "filter_none" table and forces cold-start mode.
RECOMMENDATIONS_VERSIONS = ["filter_none", "filter_basic", "filter_family", "filter_tv_app", "filter_vh", "filter_cold_start"]
RECOMMENDATIONS_SUFFIX_SUFFIX = '.json'
# Output of the aggregation over the FUTURE_DAYS_COUNT days after a snapshot.
AGGREGATED_FUTURE_WATCHED_OBJECTS_TABLE = '//home/videolog/msvvitaly/watched_objects_by_user/watched_objects_by_next_week'
# Output of the aggregation over the PAST_DAYS_COUNT days up to a snapshot.
AGGREGATED_PAST_WATCHED_OBJECTS_TABLE = '//home/videolog/msvvitaly/watched_objects_by_user/watched_objects_by_last_4_weeks'
# Base name of the reformatted recommendations tables; main() appends the
# version name directly, without a separator.
REFORMATTED_RECOMMENDATIONS_TABLE = '//home/videolog/msvvitaly/watched_objects_by_user/reformatted_recommendations'
# Daily session tables are addressed as PREFIX + <%Y-%m-%d> + one of the two
# suffixes below; main() reads both for every processed day.
USER_SESSIONS_PREFIX = '//user_sessions/pub/search/daily/'
USER_SESSIONS_SUFFIX = '/clean'
USER_SESSIONS_STAFF_SUFFIX = '/yandex_staff'
# Per-day output of calc_user_stats is stored at PREFIX + <%Y-%m-%d>.
PROCESSED_TABLES_PREFIX = '//home/videolog/msvvitaly/watched_objects_by_user/'

def main():
    """Compute offline quality metrics for every dated recommendations snapshot.

    Command line: ``--token`` (optional) is passed to both the YT client and
    the Nile cluster.

    Steps:
      1. Dump the object id -> object type mapping from the onto_films table
         into the local file ``objects_data.json`` (shipped to reducers later).
      2. For every child of RECOMMENDATIONS_PREFIX named like ``%Y%m%d``:
         a. skip it unless user-session tables exist for each of the next
            FUTURE_DAYS_COUNT days;
         b. build the missing (or empty) per-day processed tables for those
            days with calc_user_stats, retrying every 5 minutes on failure;
         c. skip the snapshot if nothing had to be built in (b);
         d. aggregate watched objects for the following week and for the
            preceding PAST_DAYS_COUNT days;
         e. extract the cold-start recommendations;
         f. for every recommendations version, write per-user metrics,
            per-category averaged metrics and coverage ("recall_metric")
            tables.
    """
    # Local import: only this function needs it, for reporting retry causes.
    import traceback

    date_format = "%Y-%m-%d"

    def positive(column):
        """Predicate: the given column is greater than zero."""
        return nf.custom(lambda x: x > 0, column)

    def with_recomms(predicate):
        """Predicate: `predicate` holds and the row has recommendations."""
        return nf.and_(predicate, sf.defined('results'))

    def serial_or_film():
        """Predicate: the user made a serial request or a film request."""
        return nf.or_(positive('serial_request_count'), positive('film_request_count'))

    one_day = timedelta(days=1)
    parser = argparse.ArgumentParser()
    parser.add_argument('--token')
    args = parser.parse_args()

    yt_client = yt.YtClient('banach')
    kwargs = {}
    if args.token:
        yt_client.config['token'] = args.token
        kwargs['token'] = args.token
    cluster = clusters.Banach(**kwargs)

    objects_data = {}
    for rec in yt_client.read_table('//home/videoindex/recommender/prod/ontodb/onto_films.latest'):
        objects_data[rec["OntoId"]] = rec["ObjectType"]
    # Context manager so the file is flushed and closed even if writing fails.
    with open("objects_data.json", "w") as objects_file:
        objects_file.write(json.dumps(objects_data, indent=4))

    for table in sorted(yt_client.list(RECOMMENDATIONS_PREFIX)):
        print(table)
        try:
            date = datetime.strptime(table, "%Y%m%d")
        except (ValueError, TypeError):
            # Not a dated snapshot directory.
            continue

        print("Check if can calc.")

        can_calc = True
        current_date = date
        for _ in range(FUTURE_DAYS_COUNT):
            current_date += one_day
            sessions_dir = USER_SESSIONS_PREFIX + current_date.strftime(date_format)
            if not (yt_client.exists(sessions_dir + USER_SESSIONS_SUFFIX)
                    and yt_client.exists(sessions_dir + USER_SESSIONS_STAFF_SUFFIX)):
                can_calc = False
                break
        if not can_calc:
            continue
        print("Can calc, will check if need it.")

        # Check whether anything has to be (re)calculated for this snapshot.
        need_calc = False

        # NOTE(review): every table handle below is created from its own,
        # never-run job and later concatenated inside a different job. This
        # mirrors the original code; confirm that Nile supports it.
        current_date = date
        past_to_concat = []
        for _ in range(PAST_DAYS_COUNT):
            job = cluster.job()
            past_to_concat.append(job.table(PROCESSED_TABLES_PREFIX + current_date.strftime(date_format)))
            current_date -= one_day

        current_date = date
        future_to_concat = []
        for _ in range(FUTURE_DAYS_COUNT):
            job = cluster.job()
            current_date += one_day
            day = current_date.strftime(date_format)
            current_user_session_table = USER_SESSIONS_PREFIX + day + USER_SESSIONS_SUFFIX
            current_user_session_staff_table = USER_SESSIONS_PREFIX + day + USER_SESSIONS_STAFF_SUFFIX
            processed_table = PROCESSED_TABLES_PREFIX + day
            if not yt_client.exists(processed_table) or yt_client.row_count(processed_table) == 0:
                need_calc = True

                # Retry until the day is processed. Only Exception is caught,
                # so Ctrl-C / SystemExit can still stop the script, and the
                # cause of each failure is printed before sleeping.
                while True:
                    try:
                        print(current_user_session_table)
                        job = cluster.job()
                        sessions = job.concat(job.table(current_user_session_table),
                                              job.table(current_user_session_staff_table))
                        sessions.groupby('key').sort('subkey').reduce(
                            calc_user_stats(),
                            files=[nile.files.RemoteFile('statbox/statbox-dict-last/blockstat.dict'),
                                   nile.files.RemoteFile('statbox/resources/libra.so')],
                            memory_limit=8000
                        ).put(processed_table)
                        job.run()
                        print("ok")
                        break
                    except Exception:
                        traceback.print_exc()
                        print("Can't calc, sleeping for 5 minutes...")
                        time.sleep(5 * 60)
            future_to_concat.append(job.table(processed_table))
        if not need_calc:
            continue
        print("Need calc!")

        # Aggregate watched objects over the next week.
        job = cluster.job()
        job.concat(*future_to_concat).groupby('corrected_uid').reduce(
            user_stats_aggregator()
        ).put(AGGREGATED_FUTURE_WATCHED_OBJECTS_TABLE)
        job.run()

        # Aggregate watched objects over the last 4 weeks.
        job = cluster.job()
        job.concat(*past_to_concat).groupby('corrected_uid').reduce(
            user_stats_aggregator('watched_objects_by_last_4_weeks')
        ).put(AGGREGATED_PAST_WATCHED_OBJECTS_TABLE)
        job.run()
        print("Joined watched objects by week")

        # Extract the cold-start recommendations (rows with key '0').
        cold_start_table = '//home/videolog/msvvitaly/watched_objects_by_user/cold_start_' + table
        job = cluster.job()
        us = job.table(RECOMMENDATIONS_PREFIX + "/" + table + RECOMMENDATIONS_SUFFIX + "filter_none" + RECOMMENDATIONS_SUFFIX_SUFFIX)
        us.map(grep_cold_start).map(recommendations_reformatter).put(cold_start_table)
        job.run()
        # Stays None when the table is empty; the last row wins otherwise.
        cold_start = None
        for rec in yt_client.read_table(cold_start_table):
            cold_start = rec['results']

        # Calculate metrics for every recommendations version.
        version_mapper = {"filter_basic" : "safe", "filter_none" : "normal", "filter_family" : "family", "filter_tv_app" : "tv_app", "filter_vh" : "vh", "filter_cold_start" : "cold_start"}
        for version in RECOMMENDATIONS_VERSIONS:
            print("Calculating offline metric for " + version + " version")
            recommendations_table = RECOMMENDATIONS_PREFIX + "/" + table + RECOMMENDATIONS_SUFFIX + version + RECOMMENDATIONS_SUFFIX_SUFFIX
            is_cold_start_version = version == "filter_cold_start"
            if is_cold_start_version:
                # There is no dedicated table for this version: it is scored
                # against the "filter_none" users with forced cold start.
                recommendations_table = RECOMMENDATIONS_PREFIX + "/" + table + RECOMMENDATIONS_SUFFIX + "filter_none" + RECOMMENDATIONS_SUFFIX_SUFFIX
            if not yt_client.exists(recommendations_table):
                continue

            output_suffix = table + "_" + version_mapper[version]
            reformatted_table = REFORMATTED_RECOMMENDATIONS_TABLE + version
            user_info_table = "//home/videolog/msvvitaly/watched_objects_by_user/user_info_" + output_suffix

            # Reformat a 1% random sample of the recommendations table.
            job = cluster.job()
            us = job.table(recommendations_table).random(fraction=0.01)
            us.map(recommendations_reformatter).put(reformatted_table)
            job.run()

            # Join recommendations with watched objects and score each user.
            job = cluster.job()
            to_concat = [job.table(AGGREGATED_FUTURE_WATCHED_OBJECTS_TABLE),
                         job.table(AGGREGATED_PAST_WATCHED_OBJECTS_TABLE),
                         job.table(reformatted_table)]
            result = job.concat(*to_concat)
            result.groupby('uid').reduce(
                prediction_quality_calculator(cold_start, cold_start_mode=is_cold_start_version),
                memory_limit=8000,
                files=["objects_data.json"]
            ).sort('uid').put(user_info_table)
            job.run()

            # Average the per-user metrics inside each category.
            job = cluster.job()
            job.table(user_info_table).groupby('type').reduce(metrics_aggregator).put(
                "//home/videolog/msvvitaly/watched_objects_by_user/offline_metric_" + output_suffix)
            job.run()

            print("Calculation recall metric for " + version + " version")
            job = cluster.job()
            recommendations = job.table(reformatted_table)
            joined_users = job.table(AGGREGATED_FUTURE_WATCHED_OBJECTS_TABLE).join(
                recommendations, by_left="uid", by_right="uid", type='left')
            joined_users.aggregate(
                video_morda_users=na.count(predicate=positive('video_morda_request_count')),
                video_morda_users_with_recomms=na.count(predicate=with_recomms(positive('video_morda_request_count'))),
                video_users=na.count(predicate=positive('video_request_count')),
                video_users_with_recomms=na.count(predicate=with_recomms(positive('video_request_count'))),
                web_users=na.count(predicate=positive('web_request_count')),
                web_users_with_recomms=na.count(predicate=with_recomms(positive('web_request_count'))),
                video_serial_users_with_recomms=na.count(predicate=with_recomms(positive('serial_request_count'))),
                video_film_users_with_recomms=na.count(predicate=with_recomms(positive('film_request_count'))),
                video_serial_and_film_users_with_recomms=na.count(predicate=with_recomms(serial_or_film())),
                video_serial_users=na.count(predicate=positive('serial_request_count')),
                video_film_users=na.count(predicate=positive('film_request_count')),
                video_serial_and_film_users=na.count(predicate=serial_or_film())
            ).put("//home/videolog/msvvitaly/watched_objects_by_user/recall_metric_" + output_suffix)
            job.run()

# Run the pipeline only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()
