#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import division
import sys
import os
import codecs
import argparse
import datetime
from nile.api.v1 import (
    filters as nf,
    aggregators as na,
    extractors as ne,
    clusters,
    Record,
    with_hints,
    extended_schema,
)
import re
import math
import json
import random
from collections import defaultdict
from videolog_common import get_driver, get_cluster

from qb2.api.v1 import extractors as se, filters as sf, typing as qt

import nile

from url_canonizer_py import CanonizePageUrl, CanonizeFrameUrl


# Local files attached (as nile.files.LocalFile) to the map operations built
# in platform_pipeline: the url_canonizer_py extension module plus the
# resource files shipped with it (presumably read by the canonizer at
# runtime -- confirm against the canonizer's own documentation).
FILE_LIST = [
    "simple_owners.lst",
    "url_canonizer_py.so",
    "url_generators.def",
    "url_generators.flashvars",
    "urlsconvert.original2thumb.re",
    "urlsconvert.parsedurl.re",
    "urlsconvert.srcurl.re",
    "urlsconvert.thumbs.re",
    "urlsconvert.urls.re",
]
# Default value of the --title command-line option.
TITLE = "Selrank Daily Additive Stats"


def date_range(from_, to_):
    """Return the list of consecutive days between two dates, inclusive.

    Args:
        from_: One end of the range, either a ``datetime.date`` or a
            ``"YYYY-MM-DD"`` string.
        to_: The other end of the range, in the same formats as ``from_``.

    Returns:
        A list of dates one day apart. The list is ascending when
        ``from_ <= to_`` and descending otherwise.

    Raises:
        ValueError: If a string argument is not in ``YYYY-MM-DD`` format.
    """
    # ``basestring`` only exists on Python 2; fall back to ``str`` so the
    # function does not die with NameError on Python 3.
    try:
        string_types = basestring
    except NameError:
        string_types = str

    if isinstance(from_, string_types):
        from_ = datetime.datetime.strptime(from_, "%Y-%m-%d").date()
    if isinstance(to_, string_types):
        to_ = datetime.datetime.strptime(to_, "%Y-%m-%d").date()

    first, last = min(from_, to_), max(from_, to_)
    days = [
        first + datetime.timedelta(days=offset)
        for offset in range((last - first).days + 1)
    ]
    if to_ < from_:
        # The caller asked for a backwards range: newest day first.
        days.reverse()
    return days


def count_lvt(view_time):
    """Return the log view time score for a view of ``view_time`` seconds.

    Views shorter than 30 seconds score 0; longer ones score
    ``log(view_time - 25)`` (natural logarithm).
    """
    return 0 if view_time < 30 else math.log(view_time - 25)


# Matches strings containing "yandex.<something>/video". VcParser.__call__
# uses it to skip records whose referer matches.
re_yandex = re.compile(r"yandex\.(.+)/video")


class VcParser(object):
    """Mapper that expands bar-navig-log records into per-video view records.

    Each input record carries a JSON payload in ``rec.dec``. Two payload
    layouts are supported (see ``_old_parse`` and ``_new_parse``). Every
    emitted record has the fields ``page_url``, ``frame_url``, ``guid``,
    ``yandexuid``, ``duration`` and ``length``.

    Malformed payload fragments are skipped rather than raised, so a single
    broken log line cannot abort the whole map operation.
    """

    def __init__(self, date):
        # Stored for the caller's convenience; the parser itself never reads it.
        self.date = date

    def _old_parse(self, dec, rec):
        """Yield view records from a legacy-layout payload.

        ``dec`` is an iterable of JSON strings. Each string decodes to a list
        of single-key dicts ``{page_url: {"p": [video_info, ...]}}`` where
        ``video_info`` is a positional list: guid at index 0, duration at 4,
        view time at 5 and (optionally) the frame URL at 11. When the guid
        slot holds a value whose string form is at most one character long,
        the guid is treated as absent and every other index shifts down by
        one.
        """
        for d in dec:
            try:
                parsed_vc = json.loads(d)
            except (TypeError, ValueError, AttributeError):
                continue
            if not isinstance(parsed_vc, list):
                continue
            for elem in parsed_vc:
                try:
                    # list(...) keeps this working on Python 3, where
                    # dict.keys() is a view and cannot be indexed.
                    url = list(elem.keys())[0]
                    p = elem[url].get("p")
                except (
                    IndexError,
                    KeyError,
                    ValueError,
                    AttributeError,
                    TypeError,
                ):
                    continue
                # dict.get returns None for a missing "p" instead of raising,
                # so the absence has to be checked explicitly.
                if not isinstance(p, list):
                    continue
                for video_info in p:
                    try:
                        guid = video_info[0]
                        cor = 0
                        if len(str(guid)) <= 1:
                            guid = None
                            cor = -1
                        viewtime = float(video_info[5 + cor])
                        duration = float(video_info[4 + cor])
                    except (IndexError, KeyError, TypeError, ValueError):
                        # Too short or non-numeric entry: skip just this one.
                        continue
                    try:
                        frame_url = video_info[11 + cor]
                    except IndexError:
                        # No frame URL recorded: fall back to the page URL.
                        frame_url = url
                    yield Record(
                        page_url=self._canonize_page_url(url),
                        frame_url=self._canonize_frame_url(frame_url),
                        guid=guid,
                        yandexuid=rec.yandexuid,
                        duration=duration,
                        length=viewtime,
                    )

    @staticmethod
    def _canonize_page_url(url):
        """Return the canonized page URL, or ``url`` unchanged on failure."""
        try:
            return CanonizePageUrl(url)
        except Exception:
            # Best effort: an uncanonizable URL is still useful as-is.
            return url

    @staticmethod
    def _canonize_frame_url(url):
        """Return the canonized frame URL, or ``url`` unchanged on failure."""
        try:
            return CanonizeFrameUrl(url)
        except Exception:
            # Best effort: an uncanonizable URL is still useful as-is.
            return url

    def _new_parse(self, dec, rec):
        """Yield view records from a new-layout payload.

        ``dec`` is an iterable of dicts ``{"url": page_url, "data": [...]}``.
        Every item of ``data`` describes one event of a player identified by
        ``uid`` and must carry ``total_played_duration``; ``frame_url`` and
        ``duration`` are optional. One record is emitted per ``uid``.
        """
        for d in dec:
            try:
                page_url = d["url"]
                events = d["data"]
            except (KeyError, TypeError, IndexError):
                continue
            if not isinstance(events, list):
                continue
            by_uid = defaultdict(lambda: {"length": 0, "duration": 0})
            for elem in events:
                try:
                    uid = elem["uid"]
                    played = float(elem["total_played_duration"])
                    dct = by_uid[uid]
                except (KeyError, TypeError, ValueError):
                    # Event without a usable uid / played time: skip it
                    # before it creates a half-filled entry in by_uid.
                    continue
                # NOTE(review): length keeps the value of the last event
                # while duration below is summed over events -- confirm this
                # asymmetry is intended.
                dct["length"] = played
                try:
                    dct["frame_url"] = elem["frame_url"]
                except KeyError:
                    dct["frame_url"] = page_url
                try:
                    dct["duration"] += float(elem["duration"])
                except (KeyError, TypeError, ValueError):
                    # Missing, null or non-numeric duration contributes 0.
                    pass
            for uid in by_uid:
                dct = by_uid[uid]
                yield Record(
                    page_url=self._canonize_page_url(page_url),
                    yandexuid=rec.yandexuid,
                    frame_url=self._canonize_frame_url(dct["frame_url"]),
                    length=dct["length"],
                    duration=dct["duration"],
                    guid=uid,
                )

    def __call__(self, recs):
        """Map input records to view records.

        Records whose referer matches ``re_yandex`` and records whose
        ``dec`` field is not valid JSON are dropped. The payload layout is
        chosen by the presence of the substring ``"rebuffering_times"`` in
        the raw ``dec`` text.
        """
        for rec in recs:
            if re_yandex.search(str(rec.referer)):
                continue
            try:
                dec = json.loads(rec.dec)
            except (TypeError, ValueError):
                continue
            if "rebuffering_times" in rec.dec:
                for rec_ in self._new_parse(dec, rec):
                    yield rec_
            else:
                for rec_ in self._old_parse(dec, rec):
                    yield rec_


class StatsParser(object):
    """Mapper that turns mobile "video statistics" events into view records.

    Each input record must expose ``device_id`` and ``raw_event_value`` (a
    JSON object with the keys ``"frame url"``, ``"page url"``,
    ``"play length"`` and, optionally, ``"length"``). Emitted records have
    the fields ``device_id``, ``frame_url``, ``page_url``, ``length`` (the
    play length), ``duration`` (the video length, or None) and ``lvt``.
    """

    def __init__(self, date):
        self.date = parse_date(date)

    def __call__(self, records):
        """Yield one view record per valid input event.

        Events are dropped when the JSON or the play length cannot be
        parsed, when either URL matches ``yandex.<...>/video``, when the
        play length is zero or above 86400 seconds, or when URL
        canonization fails.
        """
        re_yandex = re.compile(r"yandex\.(.+)/video")
        for rec in records:
            # The per-event date (rec.event_timestamp) is currently not
            # emitted.
            device_id = rec.device_id
            val = rec.raw_event_value

            try:
                obj = json.loads(val)

                frame_url = obj.get("frame url") or ""
                page_url = obj.get("page url") or ""
                play_length = float(obj.get("play length"))
            except (TypeError, AttributeError, ValueError):
                continue
            try:
                length = float(obj.get("length")) or 0
            except (TypeError, ValueError):
                # "length" is optional: a missing key makes obj.get return
                # None, and float(None) raises TypeError, not ValueError.
                length = None

            if re_yandex.search(frame_url) or re_yandex.search(page_url):
                continue

            if not play_length or play_length > 86400:
                continue

            try:
                page_url = CanonizePageUrl(page_url)
            except Exception:
                continue
            try:
                frame_url = CanonizeFrameUrl(frame_url)
            except Exception:
                continue

            # NOTE(review): the desktop path (count_lvt) uses
            # log(view_time - 25) for view_time >= 30, whereas this uses
            # log(play_length) for play_length > 30 -- confirm the
            # difference is intended.
            if play_length > 30:
                lvt = math.log(play_length, math.e)
            else:
                lvt = 0

            result = dict(
                device_id=device_id,
                frame_url=frame_url,
                page_url=page_url,
                length=play_length,
                # A zero or unparsable video length is reported as None.
                duration=length or None,
                lvt=lvt,
            )

            yield Record(**result)


class DataReducer(object):
    """Reducer merging the per-day ``data`` dicts of all rows in a group.

    The first row of a group supplies the base fields; the ``data`` of every
    following row is merged into it. Entries whose day key is more than 180
    days before the reducer's date are discarded, and a group is emitted only
    if some entry survives. ``static_data`` is taken from the first row that
    has a non-empty value.
    """

    def __init__(self, date):
        self.date = parse_date(date)

    def __call__(self, groups):
        for _key, rows in groups:
            merged = {}
            static = {}
            for row in rows:
                if not static:
                    # Rows without the attribute leave the current value as is.
                    static = getattr(row, "static_data", static)
                if merged:
                    merged["data"].update(row.data)
                else:
                    merged.update(row.to_dict())

            # Keep only the entries inside the 180-day window.
            recent = {}
            for day, stats in merged["data"].items():
                if parse_date(day) >= (
                    self.date - datetime.timedelta(days=180)
                ):
                    recent[day] = stats

            merged["data"] = recent
            merged["static_data"] = static
            if recent:
                yield Record(**merged)


def parse_date(s):
    """Parse a ``YYYY-MM-DD`` string into a ``datetime.date``.

    Returns None when ``s`` is not a string or does not match the format.
    """
    try:
        parsed = datetime.datetime.strptime(s, "%Y-%m-%d")
    except (TypeError, ValueError, AttributeError):
        return None
    return parsed.date()


# First YYYY-MM-DD looking token in a string (digits only, no range checks).
re_date = re.compile(r"[0-9]{4}-[0-9]{2}-[0-9]{2}")


def get_date_from_table(table):
    """Return the first ``YYYY-MM-DD`` substring of ``table``, or None."""
    found = re_date.search(table)
    if found is None:
        return None
    return found.group(0)


def platform_pipeline(platform, job, table, date):
    """Build the stream of per-view records for one daily log table.

    Args:
        platform: ``"desktop"`` (bar-navig-log, parsed by VcParser) or
            ``"mobile"`` (metrika-mobile-log, parsed by StatsParser).
        job: Nile job the operations are added to.
        table: Path of the source log table; a missing table is ignored.
        date: Date string handed to the parser.

    Returns:
        The resulting stream, or None for any other ``platform`` value.
    """
    files = [nile.files.LocalFile(name) for name in FILE_LIST]

    if platform == "desktop":
        source = job.table(table, ignore_missing=True)
        log_fields = [
            "timestamp",
            "yandexuid",
            "geo_id",
            "referer",
            "url",
            "parsed_http_params",
            se.dictitem("decoded_vc", from_="parsed_http_params").with_type(
                qt.List[qt.String]
            ),
            # Only the first decoded_vc item is parsed; "-" marks "none".
            se.custom(
                "dec", lambda x: x[0] if x else "-", "decoded_vc"
            ).with_type(qt.String),
        ]
        log_filters = [
            sf.default_filtering("bar-navig-log"),
            sf.defined("decoded_vc", "yandexuid"),
        ]
        parsed_log = source.qb2(
            log="bar-navig-log", fields=log_fields, filters=log_filters
        )

        view_schema = dict(
            page_url=qt.Optional[qt.String],
            yandexuid=qt.Optional[qt.String],
            frame_url=qt.Optional[qt.String],
            length=qt.Optional[qt.Float],
            duration=qt.Optional[qt.Float],
            guid=qt.Optional[qt.String],
        )
        views = parsed_log.map(
            with_hints(output_schema=view_schema)(VcParser(date)),
            files=files,
            memory_limit=2000,
        )
        stream = views.filter(
            sf.not_(sf.contains("page_url", "yandex.ru/video"))
        )

        # Views without a guid are passed through untouched.
        noguid = stream.filter(sf.not_(sf.nonzero("guid")))

        # Views with a guid are first collapsed per (user, guid) keeping the
        # maximum length, then summed per (user, page, frame).
        per_guid = (
            stream.filter(sf.nonzero("guid"))
            .groupby("yandexuid", "guid")
            .aggregate(
                page_url=na.any("page_url"),
                frame_url=na.any("frame_url"),
                length=na.max("length"),
                duration=na.any("duration"),
            )
        )
        scored = per_guid.project(
            ne.all(exclude=["guid"]),
            lvt=ne.custom(count_lvt, "length").with_type(
                qt.Optional[qt.Float]
            ),
        )
        guid = scored.groupby("yandexuid", "page_url", "frame_url").aggregate(
            length=na.sum("length"),
            lvt=na.sum("lvt"),
            duration=na.any("duration"),
        )

        combined = job.concat(noguid, guid)
        return combined.filter(sf.compare("length", "<", value=86400))

    if platform == "mobile":
        source = job.table(table, ignore_missing=True)
        log_fields = [
            "app_platform",
            "device_id",
            "event_name",
            "raw_event_value",
            "session_type",
            "date",
            "geo_id",
            "event_timestamp",
            se.dictitem("AppID", from_="parsed_log_line").with_type(
                qt.String
            ),
        ]
        log_filters = [
            sf.equals("event_name", "video statistics"),
            sf.region_belongs([225], field="geo_id"),
        ]
        parsed_log = source.qb2(
            log="metrika-mobile-log", fields=log_fields, filters=log_filters
        )

        view_schema = dict(
            page_url=qt.Optional[qt.String],
            device_id=qt.Optional[qt.String],
            frame_url=qt.Optional[qt.String],
            length=qt.Optional[qt.Float],
            duration=qt.Optional[qt.Float],
            lvt=qt.Optional[qt.Float],
        )
        return parsed_log.map(
            with_hints(output_schema=view_schema)(StatsParser(date)),
            files=files,
            memory_limit=2000,
        )


def pre_reduce(groups):
    """Collapse every group of view rows into a single summed row.

    For each ``(key, records)`` pair the lengths and lvt values are summed
    and the first truthy ``duration`` / ``page_url`` / ``frame_url`` is kept.
    Once a frame URL is known (or when a row has none), each further row has
    a 1-in-11 chance of replacing the duration with its own. A group is
    dropped entirely as soon as its running length leaves ``[0, 86400]`` or
    it contains more than 100 rows.
    """
    sampler = random.SystemRandom()
    for key, records in groups:
        total_length = 0
        total_lvt = 0
        seen = 0
        duration = None
        page_url = None
        frame_url = None

        for rec in records:
            total_length += rec.length
            total_lvt += rec.lvt

            if not duration:
                duration = getattr(rec, "duration", "") or duration
            if not page_url:
                page_url = getattr(rec, "page_url", "") or page_url

            rec_frame_url = getattr(rec, "frame_url", "")
            if frame_url or not rec_frame_url:
                # Nothing to adopt from this row: occasionally resample the
                # duration instead.
                if sampler.randint(0, 10) == 1 and getattr(
                    rec, "duration", ""
                ):
                    duration = rec.duration
            else:
                frame_url = rec_frame_url

            seen += 1
            if total_length > 86400 or total_length < 0 or seen > 100:
                # Implausible group: abandon it without emitting anything.
                break
        else:
            row = key.to_dict()
            row["length"] = total_length
            row["lvt"] = total_lvt
            row["duration"] = duration
            row["frame_url"] = frame_url
            row["page_url"] = page_url
            yield Record(**row)


def process_table_fast(table, result_path, job, fields, platform, threshold=5):
    """Add to ``job`` the computation of one day's stats into its own table.

    The day is taken from the table path and the result is written to
    ``<result_path>/<date>``. Nothing is executed here; the caller runs the
    job.

    Args:
        table: Path of the daily source log table.
        result_path: Directory receiving the per-day result table.
        job: Nile job the operations are added to.
        fields: Grouping fields of the final statistics.
        platform: ``"desktop"`` or ``"mobile"``.
        threshold: Minimum number of shows for a row to be kept; a falsy
            value disables the filter.
    """
    date = get_date_from_table(table)
    id_field = {"desktop": "yandexuid", "mobile": "device_id"}[platform]
    group_keys = [id_field] + fields

    # Grouping fields cannot be aggregated at the same time, so any
    # aggregate named like one of them is left out.
    candidates = dict(
        tvt=na.sum("length"),
        lvt=na.sum("lvt"),
        duration=na.any("duration"),
        page_url=na.any("page_url"),
        frame_url=na.any("frame_url"),
        users=na.count_distinct_estimate(id_field),
        shows=na.count(),
        avg_vt=na.mean("length"),
    )
    aggregates = {
        name: aggregator
        for name, aggregator in candidates.items()
        if name not in fields
    }

    reduced_schema = {
        id_field: qt.Optional[qt.String],
        "page_url": qt.Optional[qt.String],
        "frame_url": qt.Optional[qt.String],
        "duration": qt.Optional[qt.Float],
        "lvt": qt.Optional[qt.Float],
        "length": qt.Optional[qt.Float],
    }
    per_user = (
        platform_pipeline(platform, job, table, date)
        .groupby(*group_keys)
        .reduce(
            with_hints(output_schema=reduced_schema)(pre_reduce),
            intensity="cpu",
        )
    )
    stats = per_user.groupby(*fields).aggregate(**aggregates)

    if threshold:
        stats = stats.filter(nf.custom(lambda x: x >= threshold, "shows"))

    static_data = ne.custom(
        lambda d, page_url, frame_url: {
            "duration": d,
            "page_url": page_url,
            "frame_url": frame_url,
        },
        "duration",
        "page_url",
        "frame_url",
    ).with_type(qt.Json)
    # The day's numbers are nested under the date.
    data = ne.custom(
        lambda t, l, u, s, a: {
            date: {"tvt": t, "lvt": l, "users": u, "shows": s, "avg_vt": a}
        },
        "tvt",
        "lvt",
        "users",
        "shows",
        "avg_vt",
    ).with_type(qt.Json)

    stats.project(*fields, static_data=static_data, data=data).put(
        "{}/{}".format(result_path, date)
    )


def process_table(table, table_additive, hahn, fields, platform, threshold=5):
    """Compute one day's stats and merge them into the additive table.

    A new job is created and run on ``hahn``. The day's statistics are
    concatenated with the current contents of ``table_additive``, merged by
    DataReducer, sorted by ``fields`` and written back to
    ``table_additive``. The table's ``last_date`` attribute is then advanced
    if the processed day is later than it and not in the future. Row counts
    are printed before and after the run.

    Args:
        table: Path of the daily source log table (its date is taken from
            the path).
        table_additive: Path of the additive table to update in place.
        hahn: Cluster object used to create the job and the driver.
        fields: Grouping fields of the statistics.
        platform: ``"desktop"`` or ``"mobile"``.
        threshold: Minimum number of shows for a row to be kept; a falsy
            value disables the filter.
    """
    # Only referenced by the disabled "schematized copy" step below.
    table_additive_schematized = "{}_schematized".format(table_additive)
    date = get_date_from_table(table)
    job = hahn.job()
    get_row_count(table_additive, hahn)
    id_field = {"desktop": "yandexuid", "mobile": "device_id"}[platform]
    group_keys = [id_field] + fields

    # Grouping fields cannot be aggregated at the same time, so any
    # aggregate named like one of them is left out.
    candidates = dict(
        tvt=na.sum("length"),
        lvt=na.sum("lvt"),
        duration=na.any("duration"),
        page_url=na.any("page_url"),
        frame_url=na.any("frame_url"),
        users=na.count_distinct_estimate(id_field),
        shows=na.count(),
        avg_vt=na.mean("length"),
    )
    aggregates = {
        name: aggregator
        for name, aggregator in candidates.items()
        if name not in fields
    }

    reduced_schema = {
        id_field: qt.Optional[qt.String],
        "page_url": qt.Optional[qt.String],
        "frame_url": qt.Optional[qt.String],
        "duration": qt.Optional[qt.Float],
        "lvt": qt.Optional[qt.Float],
        "length": qt.Optional[qt.Float],
    }
    per_user = (
        platform_pipeline(platform, job, table, date)
        .groupby(*group_keys)
        .reduce(with_hints(output_schema=reduced_schema)(pre_reduce))
    )
    stats = per_user.groupby(*fields).aggregate(**aggregates)

    if threshold:
        stats = stats.filter(nf.custom(lambda x: x >= threshold, "shows"))

    static_data = ne.custom(
        lambda d, page_url, frame_url: {
            "duration": d,
            "page_url": page_url,
            "frame_url": frame_url,
        },
        "duration",
        "page_url",
        "frame_url",
    ).with_type(qt.Json)
    # The day's numbers are nested under the date so DataReducer can merge
    # them into the per-row history.
    data = ne.custom(
        lambda t, l, u, s, a: {
            date: {"tvt": t, "lvt": l, "users": u, "shows": s, "avg_vt": a}
        },
        "tvt",
        "lvt",
        "users",
        "shows",
        "avg_vt",
    ).with_type(qt.Json)
    day_stats = stats.project(*fields, static_data=static_data, data=data)

    # schema = {"data": qt.Json, "static_data": qt.Json}
    # for field in fields:
    #     schema[field] = str

    history = job.table(table_additive, ignore_missing=True)
    merged = (
        job.concat(day_stats, history)
        .groupby(*fields)
        .reduce(with_hints(output_schema=extended_schema())(DataReducer(date)))
    )
    merged.sort(*fields).put(table_additive)

    # The attribute is read before the job rewrites the table.
    last_date = parse_date(
        get_driver(hahn).client.get_attribute(table_additive, "last_date")
    )
    date_parsed = parse_date(date)
    get_row_count(table_additive, hahn)
    job.run()

    # Disabled: copy of the additive table into a schematized twin.
    # if hahn.driver.exists(table_additive_schematized):
    #     hahn.driver.remove(table_additive_schematized)

    # job = hahn.job()

    # job.table(table_additive).put(table_additive_schematized, schema=schema)

    # job.run()

    get_row_count(table_additive, hahn)
    if date_parsed > last_date and date_parsed <= datetime.date.today():
        driver = get_driver(hahn)
        driver.client.set_attribute(table_additive, "last_date", date)
        # driver.client.set_attribute(
        #     table_additive_schematized, "last_date", date
        # )


def get_row_count(table, cluster):
    """Print the ``row_count`` attribute of ``table`` on ``cluster``.

    0 is passed as the third argument of ``get_attribute`` (presumably the
    default used when the attribute is unavailable -- confirm against the
    client API).
    """
    client = get_driver(cluster).client
    row_count = client.get_attribute(table, "row_count", 0)
    print("table {} has {} rows".format(table, row_count))


def main():
    """Command-line entry point.

    Three modes, chosen from the arguments:

    * no complete ``--from``/``--to`` range: process every daily log table
      dated after the additive table's ``last_date`` and not in the future;
    * a range plus ``--fast``: compute each day in the range into its own
      table under ``//home/videolog/selrank_stats/by_day`` within one job;
    * a range without ``--fast``: merge each existing day of the range into
      the additive table, one job per day.
    """
    parser = argparse.ArgumentParser()
    for flag in ("--from", "--to", "--pool"):
        parser.add_argument(flag)
    parser.add_argument("--proxy", default="hahn")
    parser.add_argument("--fast", action="store_true")
    parser.add_argument("--threshold", type=int, default=5)
    parser.add_argument("--platform", default="desktop")
    parser.add_argument("--fields", default="frame_url")
    parser.add_argument("--title", default=TITLE)
    parser.add_argument("--additive", default=None)
    args = parser.parse_args()

    args.fields = args.fields.split(",")

    # A threshold of 1 or less is normalised to 0, which switches the
    # "shows" filter off in the processing functions.
    if args.threshold <= 1:
        args.threshold = 0
    if not args.additive:
        args.additive = "//home/videolog/selrank_stats/additive_{}_both_fields_wo_threshold_schematized_test".format(
            args.platform
        )

    hahn = get_cluster(clusters, args)

    table_additive = args.additive
    get_row_count(table_additive, hahn)
    # Disabled: bootstrap of a missing additive table.
    # if not hahn.driver.exists(table_additive):
    #     hahn.write(table_additive, [])
    #     hahn.driver.client.set_attribute(
    #         table_additive,
    #         'last_date',
    #         (
    #             datetime.date.today() - datetime.timedelta(days=180)
    #         ).strftime('%Y-%m-%d')
    #     )

    current_date = datetime.date.today()

    from_ = parse_date(getattr(args, "from"))
    to_ = parse_date(getattr(args, "to"))

    log_dir = {"desktop": "bar-navig-log", "mobile": "metrika-mobile-log"}[
        args.platform
    ]

    def table_day(path):
        # Date embedded in a table path, or None when there is none.
        return parse_date(get_date_from_table(path))

    def merge_into_additive(table):
        process_table(
            table,
            table_additive,
            hahn,
            args.fields,
            args.platform,
            threshold=args.threshold,
        )

    if not (from_ and to_):
        # Catch-up mode: everything newer than what was already merged.
        last_date = parse_date(
            get_driver(hahn).client.get_attribute(table_additive, "last_date")
        )

        def is_pending(path):
            day = table_day(path)
            return day and day > last_date and day <= current_date

        tables = list(
            get_driver(hahn).client.search(
                root="//logs/{}/1d".format(log_dir),
                path_filter=is_pending,
            )
        )
        print(format(tables))
        for table in tables:
            if get_driver(hahn).exists(table):
                merge_into_additive(table)
        print("finished")
        return

    if args.fast:
        # All days of the range are added to one job and run together.
        job = hahn.job().env(
            parallel_operations_limit=10,
            tmp_root="//home/videoquality/vh_analytics/tmp",
        )

        def is_in_range(path):
            day = table_day(path)
            return day and day >= from_ and day <= to_

        tables = list(
            get_driver(hahn).client.search(
                root="//logs/{}/1d".format(log_dir),
                path_filter=is_in_range,
            )
        )
        result_path = "//home/videolog/selrank_stats/by_day"
        for table in tables:
            process_table_fast(
                table,
                result_path,
                job,
                args.fields,
                args.platform,
                threshold=args.threshold,
            )
        job.run()
        return

    for date in date_range(from_, to_):
        table = "//logs/{}/1d/{}".format(log_dir, date)
        if get_driver(hahn).exists(table):
            get_row_count(table_additive, hahn)
            merge_into_additive(table)


# Run the command-line entry point only when the file is executed as a
# script, not when it is imported.
if __name__ == "__main__":
    main()
