#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import division
import os
import argparse
from nile.api.v1 import (
    clusters,
    filters as nf,
    extractors as ne,
    aggregators as na,
    statface as ns,
    with_hints,
    Record,
)
import nile.files as nfi
from qb2.api.v1 import typing as qt, resources as qr, filters as qf
from pytils import (
    get_cluster,
    get_driver,
    date_range,
    yt_get_date_from_table as get_date,
    get_dates_from_stat,
    get_stat_headers,
)
import datetime
import json
import sys
import time
import traceback
from collections import Counter


# Content metadata table; its "JoinKey" column is joined against the mapped
# sessions' VideoContentID in process_date.
IRON_BRANCH_TABLE = "//home/videolog/strm_meta/iron_branch/concat"
# Largest value representable in an unsigned 64-bit (UInt64) column.
UINT64_MAX = 2 ** 64 - 1


def tryint(s, default=0):
    """Convert ``s`` with ``int()``, returning ``default`` if that fails.

    Only the errors ``int()`` raises for unconvertible input are handled:
    ``None``/wrong types (TypeError), non-numeric strings and NaN
    (ValueError), and infinities (OverflowError). Anything else, such as
    KeyboardInterrupt, propagates instead of being silently swallowed.
    """
    try:
        return int(s)
    except (TypeError, ValueError, OverflowError):
        return default


def get_max_prob_value(dct):
    """Return the key of ``dct`` with the largest value, or ``""``.

    Ties are resolved the way ``Counter.most_common`` resolves them.
    ``""`` is returned for an empty or ``None`` mapping, for a non-iterable
    argument, or when the values cannot be compared with each other.
    """
    try:
        top = Counter(dct).most_common(1)
    except (TypeError, ValueError):
        return ""
    return top[0][0] if top else ""


def repack_dict(d1, d2, values, t="str"):
    """Copy the keys listed in ``values`` from ``d1`` into ``d2`` in place.

    :param d1: source mapping.
    :param d2: destination mapping (mutated).
    :param values: iterable of keys to copy.
    :param t: ``"str"`` copies values as-is, with missing/``None`` values
        becoming ``""``. ``"int"`` converts values with ``tryint``, so
        missing or unparsable values become 0. Any other ``t`` copies nothing.
    """
    for key in values:
        if t == "str":
            # The fallback applies to the *value*. The old
            # ``d1.get(v or "")`` applied it to the key, which returned None
            # for missing keys.
            value = d1.get(key)
            d2[key] = "" if value is None else value
        elif t == "int":
            d2[key] = tryint(d1.get(key, "0"))


def process_region(reg, geobase):
    """Map a geobase region id to the id reported as ``RegionID``.

    Walks ``path`` of the looked-up region in order:

    * the first type-3 region whose id is not 225 is returned;
    * once a type-3 region with id 225 has been seen, the next type-5 region
      (or a type-3 region with another id) is returned instead;
    * if 225 was seen but nothing matched afterwards, 225 is returned.

    Returns 10000 when ``geobase.region_by_id`` raises or when no type-3
    region is on the path. (``get_country`` treats type 3 as "country"; the
    meaning of type 5 and of ids 225/10000 is not visible here -- confirm
    against the geobase documentation.)
    """
    try:
        region = geobase.region_by_id(reg)
    except Exception:
        return 10000

    saw_225 = False
    for node in region.path:
        if saw_225 and node.type == 5:
            return node.id
        if node.type != 3:
            continue
        if node.id == 225:
            saw_225 = True
        else:
            return node.id

    return 225 if saw_225 else 10000


def get_country(region):
    """Return ``short_name`` of the first type-3 entry on ``region.path``.

    Falls back to ``"UNK"`` when the path contains no type-3 entry.
    """
    return next(
        (node.short_name for node in region.path if node.type == 3),
        "UNK",
    )


def get_most_common(counter):
    """Return the most common element of a ``Counter``, or ``""``.

    ``""`` is returned for an empty counter and for arguments that are not
    Counter-like (e.g. ``None``) or whose counts cannot be compared.
    """
    try:
        top = counter.most_common(1)
    except (AttributeError, TypeError):
        return ""
    return top[0][0] if top else ""


def get_uniqid(yandexuid, yu_hash):
    """Derive a numeric user id for a session.

    ``yandexuid`` wins when it is truthy and not the ``"-"`` placeholder; it
    is parsed with ``int()``, and 0 is returned if that fails. Otherwise a
    32-character ``yu_hash`` is parsed as hex and reduced modulo 2**32.
    Returns 0 when neither source yields a number.

    The ``yandexuid`` result is not range-checked: it may be negative or wider
    than 64 bits, so callers must validate it before storing it.
    """
    if yandexuid and yandexuid != "-":
        try:
            return int(yandexuid)
        except (TypeError, ValueError, OverflowError):
            # TypeError: non-string/non-numeric values such as lists used to
            # escape this handler and crash the caller.
            return 0
    if yu_hash and len(yu_hash) == 32:
        try:
            return int(yu_hash, 16) % 2 ** 32
        except (TypeError, ValueError):
            return 0
    return 0


# Output schema of VhDetailedStatMapper: field name -> qb2 type.
vhds_mapper_output_schema = {
    "VideoContentID": qt.Optional[qt.String],
    "VideoSessionID": qt.String,
    "Service": qt.String,
    "IP": qt.String,
    "DetailedDeviceType": qt.String,
    "BrowserName": qt.String,
    "UserAgent": qt.String,
    "PageID": qt.Integer,
    "DeviceType": qt.Integer,
    "VideoCategoryID": qt.Integer,
    "EventTime": qt.Integer,
    "UpdateTime": qt.Integer,
    "UniqID": qt.UInt64,
    "Price": qt.Integer,
    "PartnerPrice": qt.Integer,
    "RegionID": qt.UInt64,
    "IsView": qt.Integer,
    "Duration": qt.Integer,
    "Hits": qt.Integer,
    "ShownHits": qt.Integer,
    "UserGender": qt.String,
    "UserAgeSegment": qt.String,
}


class VhDetailedStatMapper(object):
    """Nile mapper converting ``video-strm`` session rows into VHDS records.

    Every yielded ``Record`` carries the fields of
    ``vhds_mapper_output_schema``. An input row is skipped unless it has both
    a ``channel`` and a non-zero ``view_time``, or a non-zero ``price`` or
    ``partner_price``.
    """

    def __init__(self, update_time):
        """
        :param update_time: value copied verbatim into every record's
            ``UpdateTime`` field.
        """
        self.update_time = update_time
        # Textual ``device_type`` -> numeric ``DeviceType``. Values not listed
        # here (and missing values) are emitted as 0.
        self.device_type_dict = {
            "unknown": 0,
            "phone": 3,
            "tablet": 4,
            "desktop": 5,
        }

    def __call__(self, recs):
        """Yield one output ``Record`` per input row that passes the filter."""
        geobase = qr.get("Geobase")
        for rec in recs:
            # Nulls become 0 so the ``> 0`` comparisons below also work on
            # Python 3 and ``Duration`` never reaches ``int(None)``.
            # Truthiness is unchanged, so the filter below behaves as before.
            view_time = rec.get("view_time") or 0
            price = rec.get("price", 0) or 0

            has_watch = rec.get("channel", "") and view_time
            if not has_watch and not price and not rec.get("partner_price", 0):
                continue

            vcid = rec.get("video_content_id", "")
            if vcid == "novcid":
                vcid = ""
            ref_from = rec.get("ref_from", "") or ""
            if ref_from == "-":
                ref_from = ""

            uniq_id = get_uniqid(rec.get("yandexuid"), rec.get("yu_hash"))
            # The output column is UInt64. Values outside [0, UINT64_MAX],
            # including negative ids parsed from yandexuid, cannot be stored,
            # so they are zeroed.
            if not 0 <= uniq_id <= UINT64_MAX:
                uniq_id = 0

            device_type = rec.get("device_type", "unknown") or "unknown"
            # Normalised to 0/1 because the schema declares IsView as Integer.
            is_view = int(
                bool(rec.get("is_view_new") or view_time > 0 or price > 0)
            )

            yield Record.from_dict(
                {
                    "VideoContentID": vcid,
                    "VideoSessionID": rec.get("vsid", "") or "",
                    "Service": ref_from,
                    "IP": rec.get("ip", "") or "",
                    "DetailedDeviceType": rec.get("os_family", "") or "",
                    "BrowserName": rec.get("browser_name", "") or "",
                    "UserAgent": rec.get("user_agent", "") or "",
                    "PageID": tryint(rec.get("page_id", "0")),
                    "DeviceType": self.device_type_dict.get(device_type, 0),
                    "VideoCategoryID": tryint(rec.get("category_id", "0")),
                    "EventTime": tryint(rec.get("timestamp", "0")),
                    "UniqID": uniq_id,
                    "Price": tryint(rec.get("price", "0")),
                    "PartnerPrice": tryint(rec.get("partner_price", "0")),
                    "UpdateTime": self.update_time,
                    "RegionID": process_region(
                        tryint(rec.get("region", "0")), geobase
                    ),
                    "IsView": is_view,
                    "Duration": tryint(view_time),
                    "Hits": tryint(rec.get("hits_good", "0")),
                    "ShownHits": tryint(rec.get("winhits_good", "0")),
                    "UserGender": get_max_prob_value(rec.get("gender", {})),
                    "UserAgeSegment": get_max_prob_value(
                        rec.get("user_age_6s", {})
                    ),
                }
            )


# ContentTypeID values accepted as a "parent" element by try_get_from_parent.
# (What each id denotes is not visible in this file.)
_PARENT_CONTENT_TYPE_IDS = frozenset({2, 4, 6, 14, 32})


def try_get_from_parent(chain, field, t="str", noskip=False):
    """Read ``field`` from the last qualifying element of ``chain``.

    ``chain`` is scanned from the end. Unless ``noskip`` is true, elements
    whose ``ContentTypeID`` is missing or not in ``_PARENT_CONTENT_TYPE_IDS``
    are skipped. The first element that is not skipped decides the result:
    its ``field`` value, or the default if it has no such field.

    :param t: ``"str"`` gives a default of ``""``; anything else gives 0.
    :returns: the field value, or the default when ``chain`` is empty/None,
        no element qualifies, or the chosen element lacks ``field``.
    """
    default = "" if t == "str" else 0
    if not chain:
        return default
    for element in chain[::-1]:
        if not noskip:
            # ``noskip`` is checked first, so the type lookup happens (and
            # can fail) only when it actually matters.
            try:
                type_id = element["ContentTypeID"]
            except (IndexError, TypeError, KeyError):
                continue
            if type_id not in _PARENT_CONTENT_TYPE_IDS:
                continue
        try:
            return element[field]
        except (IndexError, TypeError, KeyError):
            return default
    return default


def try_get_from_chain(chain, field, t="str", elem=-1):
    """Return ``chain[elem][field]``, or a type-appropriate default.

    The default is ``""`` when ``t == "str"`` and 0 otherwise. It is used when
    ``chain`` is not indexable (e.g. None), ``elem`` is out of range, or the
    element lacks ``field``.
    """
    fallback = "" if t == "str" else 0
    try:
        value = chain[elem][field]
    except (IndexError, TypeError, KeyError):
        return fallback
    return value


def process_date(date, cluster, job_root, debug=False):
    """Build and run the nile job that reduces one day of VH sessions.

    Reads ``//cubes/video-strm/<date>/sessions`` and maps every row through
    ``VhDetailedStatMapper``. It then left-joins content/parent metadata from
    ``IRON_BRANCH_TABLE`` on ``VideoContentID`` and writes two tables:

    * ``<job_root>/<date>``: session-level rows whose ``UniqID`` and
      ``ContentUUID`` both pass ``qf.nonzero``, concatenated with a coarser
      roll-up (``IsView`` summed) of all remaining rows;
    * ``<job_root>2/<date>``: the same coarse roll-up applied to all rows.

    :param date: day to process; presumably a ``datetime.date``. It must
        support ``timetuple()`` and is rendered with ``str()`` in table paths.
    :param cluster: nile cluster the job is created on.
    :param job_root: YT directory receiving the main output; ``main`` scans
        it to find already-processed dates.
    :param debug: when true, also store the joined, un-aggregated stream in
        ``<job_root>/<date>_mapped``.
    :returns: path of the main reduced table.
    """
    root = job_root.rstrip("/")
    source_table = "//cubes/video-strm/{}/sessions".format(date)
    mapped_table = "{}/{}_mapped".format(root, date)
    # Outputs follow ``job_root``; they were previously hard-coded to test
    # directories. With main's default job_root these are exactly the old
    # paths.
    reduced_table = "{}/{}".format(root, date)
    reduced_table_2 = "{}2/{}".format(root, date)

    job = cluster.job()

    iron_branch = job.table(IRON_BRANCH_TABLE).project(
        VideoContentID="JoinKey",
        ContentUUID="UUID",
        ContentTypeID=ne.custom(
            lambda x: try_get_from_chain(x, "ContentTypeID", t="int", elem=-1),
            "chain",
        ).add_hints(type=int),
        ParentUUID=ne.custom(
            lambda x: try_get_from_chain(x, "UUID", elem=-2), "chain"
        ).add_hints(type=str),
        ParentTypeID=ne.custom(
            lambda x: try_get_from_chain(x, "ContentTypeID", t="int", elem=-2),
            "chain",
        ).add_hints(type=int),
    )

    # Unix timestamp of ``date`` interpreted in local time. This is the
    # portable equivalent of the former ``date.strftime("%s")``, a
    # glibc-only extension.
    update_time = int(time.mktime(date.timetuple()))

    stream = (
        job.table(source_table)
        .map(
            with_hints(output_schema=vhds_mapper_output_schema)(
                VhDetailedStatMapper(update_time)
            ),
            files=[nfi.StatboxDict("Geobasev6.bin")],
        )
        .join(iron_branch, type="left", by="VideoContentID")
        .project(
            ne.all(
                exclude=[
                    "ContentUUID",
                    "ContentTypeID",
                    "ParentUUID",
                    "ParentTypeID",
                    "VideoContentID",
                ]
            ),
            # Rows without an iron_branch match get nulls; normalise them.
            ContentUUID=ne.custom(lambda x: x or "", "ContentUUID").add_hints(
                type=str
            ),
            VideoContentID=ne.custom(
                lambda x: x or "", "VideoContentID"
            ).add_hints(type=str),
            ParentUUID=ne.custom(lambda x: x or "", "ParentUUID").add_hints(
                type=str
            ),
            ContentTypeID=ne.custom(
                lambda x: x or 0, "ContentTypeID"
            ).add_hints(type=int),
            ParentTypeID=ne.custom(lambda x: x or 0, "ParentTypeID").add_hints(
                type=int
            ),
        )
    )

    if debug:
        stream = stream.put(mapped_table)

    session_aggregators = dict(
        VideoContentID=na.max("VideoContentID"),
        IsView=na.max("IsView"),
        UniqID=na.max("UniqID"),
        Service=na.max("Service"),
        RegionID=na.max("RegionID"),
        UserGender=na.max("UserGender"),
        UserAgeSegment=na.max("UserAgeSegment"),
        IP=na.max("IP"),
        UserAgent=na.max("UserAgent"),
        DeviceType=na.max("DeviceType"),
        DetailedDeviceType=na.max("DetailedDeviceType"),
        BrowserName=na.max("BrowserName"),
        VideoCategoryID=na.max("VideoCategoryID"),
        Price=na.sum("Price"),
        PartnerPrice=na.sum("PartnerPrice"),
        Duration=na.sum("Duration"),
        Hits=na.sum("Hits"),
        ShownHits=na.sum("ShownHits"),
    )

    # Keep rows with money, hits or a view, and collapse them to one row per
    # (UpdateTime, content, parent, page, session).
    sessions = (
        stream.filter(
            qf.or_(
                qf.nonzero("Price"),
                qf.nonzero("PartnerPrice"),
                qf.nonzero("Hits"),
                qf.nonzero("ShownHits"),
                qf.nonzero("IsView"),
            )
        )
        .groupby(
            "UpdateTime",
            "ParentUUID",
            "ContentUUID",
            "PageID",
            "ParentTypeID",
            "ContentTypeID",
            "VideoSessionID",
        )
        .aggregate(**session_aggregators)
    )

    # Group keys of the coarse roll-up. Fields used as keys here are dropped
    # from the per-session aggregators; IsView is summed instead of maxed.
    coarse_keys = (
        "PageID",
        "UniqID",
        "ParentUUID",
        "ContentUUID",
        "VideoCategoryID",
        "ParentTypeID",
        "ContentTypeID",
        "BrowserName",
        "DeviceType",
        "DetailedDeviceType",
        "UserGender",
        "UserAgeSegment",
        "RegionID",
        "Service",
    )
    coarse_aggregators = {
        name: aggregator
        for name, aggregator in session_aggregators.items()
        if name not in coarse_keys
    }
    coarse_aggregators.update(
        UpdateTime=na.max("UpdateTime"),
        VideoSessionID=na.max("VideoSessionID"),
        IsView=na.sum("IsView"),
    )

    def has_uid_and_content():
        # Built fresh on each call so no filter object is shared between
        # branches of the job graph.
        return qf.and_(qf.nonzero("UniqID"), qf.nonzero("ContentUUID"))

    def coarse_rollup(source):
        return source.groupby(*coarse_keys).aggregate(**coarse_aggregators)

    to_concat = [
        # Normal situation: both user and content are known.
        sessions.filter(has_uid_and_content()),
        # Everything else: group by all available attributes.
        coarse_rollup(sessions.filter(qf.not_(has_uid_and_content()))),
    ]
    job.concat(*to_concat).put(reduced_table)

    coarse_rollup(sessions).put(reduced_table_2)

    job.run()

    return reduced_table


def main():
    """Command-line entry point: reduce every pending day of sessions.

    When both ``--from`` and ``--to`` are given, that date range is
    processed. Otherwise the script processes every day that has a
    ``//cubes/video-strm/<date>/sessions`` table and is later than the
    latest date found under ``--job_root`` (2018-09-17 if none is found).

    The names (last path component) of the produced tables are written to
    ``--output`` as a JSON list. A failure on one day is reported on stderr
    with its traceback and does not stop the remaining days. The first
    failure is re-raised once all days have been attempted.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--job_root", default="//home/videoquality/vh_analytics/vhds_reduced_test"
    )
    parser.add_argument("--from", default=None)
    parser.add_argument("--pool", default=None)
    parser.add_argument("--title", default="VH Detailed Stat")
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--to", default=None)
    parser.add_argument("--output")
    args = parser.parse_args()

    cluster = get_cluster(clusters, args)
    # ``from`` is a Python keyword, so the attribute is read via getattr.
    from_ = getattr(args, "from", None)
    to_ = getattr(args, "to", None)

    if from_ and to_:
        dates = date_range(from_, to_)
    else:
        client = get_driver(cluster).client
        processed_dates = sorted(
            get_date(x)
            for x in client.search(root=args.job_root, path_filter=get_date)
        )
        last_date = (
            processed_dates[-1] if processed_dates else datetime.date(2018, 9, 17)
        )
        available_dates = sorted(
            get_date(x)
            for x in client.search(
                root="//cubes/video-strm",
                path_filter=lambda x: get_date(x) and x.endswith("sessions"),
            )
        )
        dates = [x for x in available_dates if x > last_date]

    processed_tables = []
    exceptions = []
    for date in dates:
        print("processing {}".format(date))
        try:
            table = process_date(date, cluster, args.job_root, debug=args.debug)
        except Exception as e:
            # Keep processing the remaining dates; the first error is
            # re-raised at the end so the run still fails.
            sys.stderr.write("Exception: {}\n".format(e))
            traceback.print_exc()
            exceptions.append(e)
            continue
        if table:
            processed_tables.append(table.split("/")[-1])

    if processed_tables:
        # ``with`` guarantees the file is flushed and closed.
        with open(args.output, "w") as out:
            json.dump(processed_tables, out, indent=2)
    if exceptions:
        raise exceptions[0]


if __name__ == "__main__":
    main()
