#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import division
import os
import argparse
import datetime
import hashlib
import math
import json
from nile.api.v1 import clusters, with_hints, modified_schema, Record
import nile.files as nfi
from qb2.api.v1 import typing as qt, resources as qr
from pytils import (
    get_cluster,
    get_driver,
    date_range,
    yt_get_date_from_table as get_date,
    apply_replacements,
)
from yql.api.v1.client import YqlClient


# YT table with the AS number -> provider name mapping (columns ASN, ISP).
ASNAME_TABLE = (
    "//home/search-research/ensuetina/AS_MAP/"
    "proper_AS_names_corrected"
)
# Title attached to the YQL operation that builds the chunks table.
TITLE = "Sessions2 | YQL"
# Root of the daily strm access log tables (one table per date).
STRM_ROOT = "//logs/strm-access-log/1d"
# Default output root; per-day tables are written to <root>/<date>/...
default_root = "//home/videoquality/vh_analytics/strm_video_nc"


def try_parse_json(s):
    """Parse *s* as JSON, returning ``{}`` when it cannot be parsed.

    Malformed JSON (``ValueError``) and non-string input such as ``None``
    (``TypeError``) are both treated as "nothing to parse".
    """
    try:
        return json.loads(s)
    except (ValueError, TypeError):
        return {}


def get_error_info(ystruct_):
    """Convert one element of a row's ``errors_all`` list into a plain dict.

    The listed error attributes are copied as-is. A ``details`` entry is
    added: the decoded JSON of ``error_details``, or
    ``{"unparsed": <raw value>}`` when it cannot be decoded.
    """
    wanted = (
        "timestamp",
        "video_content_id",
        "error_id",
        "error_details",
        "is_fatal",
    )
    info = {name: getattr(ystruct_, name) for name in wanted}
    raw_details = info["error_details"]
    try:
        parsed = json.loads(raw_details)
    except (ValueError, TypeError):
        parsed = {"unparsed": raw_details}
    info["details"] = parsed
    return info


def process_errors(rec, result):
    """Accumulate the errors of one chunk row into the session *result*.

    At most 200 errors are processed per session in total: once
    ``result["errors_all"]`` holds 200 entries, the remaining errors are
    ignored (including in the two tallies below).

    Updates *result* in place:
      * ``errors_all`` -- list of dicts produced by ``get_error_info``;
      * ``error_content_ids_ts`` -- {video_content_id: {error_id: [timestamp, ...]}};
      * ``error_content_ids`` -- {video_content_id: {error_id: count}}.
    """
    for raw_error in rec.get("errors_all") or []:
        collected = result["errors_all"]
        if len(collected) >= 200:
            break
        error = get_error_info(raw_error)
        collected.append(error)
        content_id = error["video_content_id"]
        error_id = error["error_id"]

        per_content_ts = result["error_content_ids_ts"].setdefault(
            content_id, {}
        )
        per_content_ts.setdefault(error_id, []).append(error["timestamp"])

        per_content_count = result["error_content_ids"].setdefault(
            content_id, {}
        )
        per_content_count[error_id] = per_content_count.get(error_id, 0) + 1


def get_view_type(timestamp, request_ts, live_timeout=30):
    """Classify a view as ``"live"`` or ``"dvr"``.

    The view is ``"live"`` when *timestamp* and *request_ts* differ by
    strictly less than *live_timeout* seconds, otherwise ``"dvr"``.
    Missing (falsy) values are treated as 0.
    """
    lag = abs((timestamp or 0) - (request_ts or 0))
    if lag < live_timeout:
        return "live"
    return "dvr"


def get_view_session(rec, result):
    """Describe one chunk row as a view-session dict and update the totals.

    Side effect: adds the row's ``bytes_sent``, ``tcpinfo_total_retrans``,
    ``chunks_count`` and duration (``end_ts - start_ts``) to the running
    totals stored in *result*.

    The returned dict carries ``_key`` =
    ``"<view_date>|<ref_from>|<content_id>|<view_type>"``. ``view_date`` is
    the date of ``start_ts`` as given by ``datetime.fromtimestamp`` (local
    time of the worker).
    """
    duration = rec["end_ts"] - rec["start_ts"]
    start_ts = rec["start_ts"]
    content_id = rec["video_content_id"] or "other"
    ref_from = rec["ref_from"] or "other"
    # Views of 30 s or less get 0 instead of a log-duration.
    if duration > 30:
        log_duration = math.log(duration)
    else:
        log_duration = 0
    view_date = datetime.datetime.fromtimestamp(start_ts).date()
    # Prefer the view type recorded with the request; otherwise derive it
    # from start_ts vs. start_ts_from_request.
    view_type = rec.get("view_type_from_request")
    if not view_type:
        view_type = get_view_type(start_ts, rec["start_ts_from_request"])
    sent = rec.get("bytes_sent") or 0
    retrans = rec.get("tcpinfo_total_retrans") or 0

    result["bytes_sent"] += sent
    result["tcpinfo_total_retrans"] += retrans
    result["chunks_count"] += rec["chunks_count"]
    result["duration"] += duration

    session = {
        "_key": "{}|{}|{}|{}".format(
            view_date, ref_from, content_id, view_type
        ),
        "_start": start_ts,
        "bytes_sent": sent,
        "chunks_count": rec["chunks_count"],
        "content_id": content_id,
        "log_view_duration": log_duration,
        "program_duration": 0,
        "tcpinfo_total_retrans": retrans,
        "video_content_id": content_id,
        "view_date": str(view_date),
        "view_duration": duration,
        "view_type": view_type,
    }
    # (output key, input column, value used when the column is empty)
    defaulted_fields = (
        ("page_id", "ref_partner_id", ""),
        ("path", "path", ""),
        ("channel_type", "channel_type", ""),
        ("request_ts", "start_ts_from_request", 0),
        ("video_category_id", "category_id", ""),
        ("video_content_name", "ref_video_content_name", ""),
        ("computed_program", "computed_program", "UNKNOWN"),
        ("view_channel", "computed_channel", "NO_CHANNEL"),
        ("view_channel_old", "channel_from_request", "NO_CHANNEL"),
    )
    for out_key, column, fallback in defaulted_fields:
        session[out_key] = rec.get(column) or fallback
    return session


def get_as(ip, ip_origins):
    """Return the list of ASes for *ip* as reported by *ip_origins*.

    ``["-"]`` is returned as the "unknown" marker when *ip* is empty, when
    ``region_by_ip`` raises ``ValueError``, or when it returns an empty value.
    """
    if not ip:
        return ["-"]
    try:
        origins = ip_origins.region_by_ip(ip)
    except ValueError:
        return ["-"]
    return origins or ["-"]


def parse_as(as_, asname_dict):
    """Map AS identifiers to provider names and drop placeholder names.

    ASes missing from *asname_dict* map to ``"other"``. The placeholders
    ``"other"`` and then ``"-"`` are removed unless that placeholder is the
    only name left. Returns the remaining names as a sorted list (empty when
    *as_* is empty or ``None``).
    """
    names = {asname_dict.get(asn, "other") for asn in (as_ or [])}
    for placeholder in ("other", "-"):
        if names != {placeholder}:
            names.discard(placeholder)
    return sorted(names)


def get_provider(ip, asname_dict, ip_origins):
    """Return a provider name for *ip*, or ``"unknown"``.

    Takes the first entry of the sorted name list from ``parse_as``. Any
    ``IndexError`` raised during the lookup (in particular an empty name
    list) yields ``"unknown"``.
    """
    try:
        names = parse_as(get_as(ip, ip_origins), asname_dict)
        provider = names[0]
    except IndexError:
        provider = "unknown"
    return provider


def get_hash(yandexuid, salt="e0440ebc0786e3d2cff6ef51319bc226"):
    """Return the hex MD5 digest of ``yandexuid + salt``.

    Text arguments are UTF-8 encoded before hashing because ``hashlib``
    only accepts bytes (on Python 3, and for non-ASCII unicode on Python 2).
    Byte-string arguments are hashed unchanged, so for ASCII ids the digest
    equals the one of the plain concatenation.
    """
    digest = hashlib.md5()
    # Feeding the parts sequentially hashes their concatenation.
    for part in (yandexuid, salt):
        if not isinstance(part, bytes):
            part = part.encode("utf-8")
        digest.update(part)
    return digest.hexdigest()


def parse_slots(slots):
    """Return the slot ids from a ``"id,...;id,...;..."`` slots string.

    Each ``;``-separated entry contributes its first ``,``-separated field.
    ``None`` or ``""`` yields ``[]``. Empty entries (e.g. from a trailing or
    doubled ``;``) are skipped instead of producing ``""`` ids.

    Raises:
        ValueError: *slots* is not a string; the message is ``repr(slots)``.
    """
    if slots is None or slots == "":
        return []
    try:
        entries = slots.split(";")
    except (AttributeError, TypeError):
        # Only this split can fail: entry.split(",")[0] always succeeds,
        # so non-string input is the one error case worth reporting.
        raise ValueError(repr(slots))
    return [entry.split(",")[0] for entry in entries if entry]


class Sessions2Reduce(object):
    """Nile reducer that folds the chunk rows of one session into one record.

    ``process_date`` groups the chunks table by (ref_from, vsid) and sorts it
    by start_ts, so each group handed to ``__call__`` is one session in
    start_ts order. For every group the reducer:

    * copies the first non-empty value of each column in
      ``_FIRST_NONEMPTY_COLUMNS``;
    * accumulates errors (``process_errors``) and per-row view sessions
      (``get_view_session``);
    * derives session-level columns (date, start/end time, hour, provider,
      yandexuid hash, session key, ...).
    """

    # Columns whose value is taken from the first row where it is non-empty.
    _FIRST_NONEMPTY_COLUMNS = (
        "ip",
        "browser_version",
        "browser_name",
        "ref_partner_id",
        "ref_yandexuid_hash",
        "ref_from_block",
        "ref_yandexuid",
        "ref_reqid",
        "slots",
        "os_family",
        "reg",
    )

    def __init__(self, asname_dict):
        # ASN -> ISP name mapping used to resolve the session's provider.
        self.asname_dict = asname_dict

    def __call__(self, groups):
        """Yield one ``Record`` per (ref_from, vsid) group."""
        # Resolved once per reducer invocation, not once per group.
        ip_origins = qr.get("IpOrigins")

        for key, recs in groups:
            result = self._init_result(key)
            for rec in recs:
                for column in self._FIRST_NONEMPTY_COLUMNS:
                    if rec.get(column) and not result.get(column):
                        result[column] = rec[column]
                process_errors(rec, result)
                result["view_session"].append(get_view_session(rec, result))
            self._finalize(result, ip_origins)
            yield Record.from_dict(result)

    @staticmethod
    def _init_result(key):
        """Return a fresh accumulator: the key columns plus empty totals."""
        result = key.to_dict()
        if not result["ref_from"]:
            result["ref_from"] = "other"
        result["errors_all"] = []
        result["error_content_ids_ts"] = {}
        result["error_content_ids"] = {}
        result["view_session"] = []
        result["bytes_sent"] = 0
        result["tcpinfo_total_retrans"] = 0
        result["chunks_count"] = 0
        result["duration"] = 0
        return result

    def _finalize(self, result, ip_origins):
        """Derive the session-level columns from the collected view sessions."""
        sessions = result["view_session"]
        first, last = sessions[0], sessions[-1]

        result["channel_type"] = first["channel_type"]
        result["path"] = first["path"]
        result["channel"] = first.get("view_channel") or "NO_CHANNEL"
        result["date"] = first["view_date"]
        result["start_time"] = first["_start"]
        result["end_time"] = last["_start"] + last["view_duration"]
        result["slots_arr"] = parse_slots(result.pop("slots", ""))
        result["start_hour"] = "{:02}".format(
            datetime.datetime.fromtimestamp(result["start_time"]).hour
        )

        yandexuid = result.pop("ref_yandexuid", "")
        result["yandexuid"] = yandexuid
        result["partner_id"] = result.pop("ref_partner_id", "")
        if not result.get("ref_yandexuid_hash"):
            # Hash only real ids: get_hash("") would give every session
            # without a yandexuid the same md5(salt) value.
            result["ref_yandexuid_hash"] = (
                get_hash(yandexuid) if yandexuid else ""
            )

        result["session_key"] = "{}_{}".format(
            result["start_time"], result["vsid"]
        )
        # "ip" is only set when some row had a non-empty ip; get_as maps a
        # missing ip to "-" instead of the reducer failing with KeyError.
        result["provider"] = get_provider(
            result.get("ip"), self.asname_dict, ip_origins
        )


def _minutes_since(start):
    """Return the minutes elapsed since the datetime *start*."""
    return (datetime.datetime.now() - start).total_seconds() / 60


def _build_chunks_table(yql_client, replacements):
    """Run sessions2.sql with *replacements* applied and fail loudly on error."""
    if yql_client is None:
        raise ValueError("a YqlClient is required to build the chunks table")
    # NOTE: the query file is resolved relative to the current directory.
    with open("sessions2.sql") as f:
        query = apply_replacements(f.read(), replacements)
    request = yql_client.query(query, title=TITLE)
    started = datetime.datetime.now()
    print("running query")
    request.run()
    request.wait_progress()
    # Without this check a failed query leaves the chunks table missing or
    # stale and the reduce job would still run over it.
    results = request.get_results()
    if not results.is_success:
        raise RuntimeError(
            "sessions2 YQL query failed: {}".format(
                "; ".join(str(error) for error in results.errors)
            )
        )
    print("query took {:.01f} minutes".format(_minutes_since(started)))


def _make_sessions2_reducer(asname_dict):
    """Wrap Sessions2Reduce with the output schema of the sessions table."""
    return with_hints(
        output_schema=modified_schema(
            exclude=[
                "channel_from_request",
                "computed_channel",
                "computed_program",
                "end_ts",
                "end_ts_from_request",
                "errors_all",
                "ref_partner_id",
                "ref_source",
                "ref_video_content_name",
                "ref_yandexuid",
                "slots",
                "start_ts",
                "start_ts_from_request",
                "video_content_id",
                "view_type_from_request",
                "yandexuid_from_cookies",
            ],
            extend=dict(
                channel=str,
                date=str,
                duration=int,
                end_time=int,
                error_content_ids=qt.Json,
                error_content_ids_ts=qt.Json,
                errors_all=qt.Json,
                partner_id=str,
                provider=str,
                session_key=str,
                slots_arr=qt.Json,
                start_hour=str,
                start_time=int,
                view_session=qt.Json,
                yandexuid=str,
            ),
        )
    )(Sessions2Reduce(asname_dict))


def process_date(
    date, cluster, pool=None, root=default_root, yql_client=None, redo=False
):
    """Build ``<root>/<date>/sessions`` for one day of strm access logs.

    1. Unless ``<root>/<date>/chunks`` already exists (or *redo* is set), run
       the YQL query from ``sessions2.sql`` to build it.
    2. Run the Nile reduce (``Sessions2Reduce``) over the chunks table,
       grouped by (ref_from, vsid) and sorted by start_ts.

    Args:
        date: ``datetime.date`` of the log table to process.
        cluster: Nile cluster used for reading tables and running the job.
        pool: YQL pool substituted for ``@pool``; defaults to ``"tmp"``.
        root: output root for the chunks and sessions tables.
        yql_client: client used to run the query; required when the chunks
            table has to be (re)built.
        redo: rebuild the chunks table even if it already exists.

    Raises:
        ValueError: the chunks table must be built but *yql_client* is None.
        RuntimeError: the YQL query finished unsuccessfully.
    """
    pool = pool or "tmp"
    input_table = "{}/{}".format(STRM_ROOT, date)
    if date >= datetime.date(2019, 3, 15):
        # From this date errors are read from the strm-player-access-log and
        # the referer via the _other{"referer"} expression.
        errors_input_table = input_table.replace(
            "strm-access-log", "strm-player-access-log"
        )
        error_referer = '_other{"referer"}'
    else:
        errors_input_table = input_table
        error_referer = "referer"
    chunks_table = "{}/{}/chunks".format(root, date)
    sessions_table = "{}/{}/sessions".format(root, date)

    driver = get_driver(cluster)
    asname_dict = {rec.ASN: rec.ISP for rec in driver.read(ASNAME_TABLE)}

    if redo or not driver.exists(chunks_table):
        _build_chunks_table(
            yql_client,
            [
                ("@input_table", input_table),
                ("@errors_input_table", errors_input_table),
                ("@error_referer", error_referer),
                ("@output_table", chunks_table),
                ("@pool", pool),
            ],
        )

    started = datetime.datetime.now()
    print("running reduce job")
    job = cluster.job()
    (
        job.table(chunks_table)
        .groupby("ref_from", "vsid")
        .sort("start_ts")
        .reduce(
            _make_sessions2_reducer(asname_dict),
            files=[nfi.StatboxDict("IpOriginV6.xml")],
        )
        .put(sessions_table)
    )
    job.run()
    print("reduce job took {:.01f} minutes".format(_minutes_since(started)))


def main():
    """Command-line entry point: build sessions tables for a set of dates.

    With both ``--from`` and ``--to`` the dates from ``date_range`` are
    processed. Otherwise every date available under ``STRM_ROOT`` that is
    newer than the newest non-empty ``sessions`` table under the output root
    (``--root``, default ``default_root``) is processed.

    Reads the ``YT_PROXY`` and ``YQL_TOKEN`` environment variables.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--report", default="Video/Others/Strm/imp_id_fails")
    parser.add_argument("--query", default="mma_1913_query.yql")
    parser.add_argument("--redo", action="store_true")
    parser.add_argument("--from", default=None)
    parser.add_argument("--pool", default=None)
    parser.add_argument("--title", default=TITLE)
    parser.add_argument("--root", default=None)
    parser.add_argument("--to", default=None)
    args = parser.parse_args()

    root = args.root or default_root

    cluster = get_cluster(clusters, args)
    yql_client = YqlClient(
        db=os.environ["YT_PROXY"], token=os.environ["YQL_TOKEN"]
    )

    # "from" is a Python keyword, so the attribute is read via getattr.
    from_ = getattr(args, "from")
    to_ = args.to

    if from_ and to_:
        dates = date_range(from_, to_)
    else:
        yt_client = get_driver(cluster).client
        # Search the configured output root so --root is honoured when
        # finding the last processed day.
        processed_dates = sorted(
            filter(
                None,
                (
                    get_date(s)
                    for s in yt_client.search(
                        root=root,
                        node_type="table",
                        attributes=["row_count"],
                        path_filter=lambda x: (x or "").endswith("sessions"),
                        object_filter=lambda x: x.attributes["row_count"] > 0,
                    )
                ),
            )
        )
        if not processed_dates:
            raise SystemExit(
                "no non-empty sessions tables under {}; "
                "pass --from and --to explicitly".format(root)
            )
        print("last date: {}".format(processed_dates[-1]))

        available_dates = sorted(
            get_date(x)
            for x in yt_client.search(root=STRM_ROOT, path_filter=get_date)
        )
        if not available_dates:
            raise SystemExit("no log tables found under {}".format(STRM_ROOT))
        print("last available date: {}".format(available_dates[-1]))

        dates = [x for x in available_dates if x > processed_dates[-1]]

    for date in dates:
        print("processing {}".format(date))
        process_date(
            date,
            cluster,
            pool=args.pool,
            yql_client=yql_client,
            root=root,
            redo=args.redo,
        )


if __name__ == "__main__":
    main()
