#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import division
import sys
import os
import codecs
import argparse
import datetime
import re
import json
import yt.wrapper as yt
from yql.api.v1.client import YqlClient
from videolog_common import YqlRunner, date_range, get_date, apply_replacements


# Title handed to YqlRunner for the operation started by main().
TITLE = "[MMA-4618] Query Spikes | YQL"
# Template of the directory searched for input tables; get_input_tables fills
# `timeframe` with "daily" and "fast".
ROOT = "//user_sessions/pub/{service}/{timeframe}"
# Locates a 10-digit number starting with "1"; get_ts_from_str reads it from
# group 2 and treats it as a Unix timestamp.
# NOTE(review): the outer groups are "(^[0-9])?", i.e. an optional digit
# anchored at the start of the string, so they do not delimit the number.
# A non-digit class "[^0-9]" may have been intended -- confirm before changing.
re_ts = re.compile("(^[0-9])?(1[0-9]{9})(^[0-9])?")


class Moscow(datetime.tzinfo):
    """Fixed-offset time zone: always UTC+3, never any daylight saving."""

    _UTC_OFFSET = datetime.timedelta(hours=3)
    _DST_SHIFT = datetime.timedelta(0)
    _ZONE_NAME = "Moscow"

    def utcoffset(self, dt):
        """Return the constant +03:00 offset; `dt` is ignored."""
        return self._UTC_OFFSET

    def tzname(self, dt):
        """Return the zone label; `dt` is ignored."""
        return self._ZONE_NAME

    def dst(self, dt):
        """Return a zero shift: this zone has no DST; `dt` is ignored."""
        return self._DST_SHIFT


# Shared tzinfo instance attached to every datetime built in this module.
moscow = Moscow()


def parse_ts(ts):
    """Turn a Unix timestamp (int or digit string) into a Moscow-zone datetime."""
    epoch_seconds = int(ts)
    return datetime.datetime.fromtimestamp(epoch_seconds, tz=moscow)


def get_ts_from_str(str_):
    """Extract the timestamp embedded in `str_` as a Moscow-zone datetime.

    Returns None when `re_ts` finds no match in the string.
    """
    found = re_ts.search(str_)
    if found is None:
        return None
    # Group 2 of re_ts holds the 10-digit timestamp itself.
    return parse_ts(found.group(2))


def parse_date(s):
    """Parse "YYYY-MM-DD" or "YYYY-MM-DDTHH:MM:SS" into a Moscow-zone datetime.

    The presence of "T" in `s` selects the format that carries a time part;
    a date-only string yields midnight of that day.
    """
    if "T" in s:
        layout = "%Y-%m-%dT%H:%M:%S"
    else:
        layout = "%Y-%m-%d"
    naive = datetime.datetime.strptime(s, layout)
    return naive.replace(tzinfo=moscow)


def get_input_tables(config, start_date, end_date):
    """Collect the "clean" session tables covering [start_date, end_date].

    Daily tables are used first. If their dates do not match
    `date_range(start_date, end_date)` exactly, the remainder of the window
    (from the day after the newest daily table, or from `start_date` when no
    daily table was found) is filled with "fast" tables, one per 30 minutes.

    Args:
        config: dict; only `config["service"]` is read here.
        start_date: timezone-aware datetime, start of the window.
        end_date: timezone-aware datetime, end of the window.

    Returns:
        A list of table paths: the daily tables, followed by the fast tables
        when any were needed.

    Raises:
        Exception: if a required half-hour fast table is missing.
    """
    service = config["service"]
    first_day = start_date.date()
    last_day = end_date.date()

    def _is_wanted_daily(path):
        # Evaluate get_date once per path instead of up to three times.
        if not path.endswith("clean"):
            return False
        day = get_date(path)
        if not day:
            return False
        return first_day <= day <= last_day

    needed_dates = date_range(start_date, end_date)
    dailies = list(
        yt.search(
            root=ROOT.format(service=service, timeframe="daily"),
            node_type="table",
            path_filter=_is_wanted_daily,
        )
    )
    if {get_date(x) for x in dailies} == set(needed_dates):
        return dailies

    # NOTE(review): only the newest daily table is considered below, so a gap
    # in the middle of the daily tables is not detected here -- confirm whether
    # that situation needs an explicit error.
    if dailies:
        newest = max(get_date(x) for x in dailies)
        nd = newest + datetime.timedelta(days=1)
        dt_mover = datetime.datetime(nd.year, nd.month, nd.day, tzinfo=moscow)
    else:
        dt_mover = start_date

    half_hour = datetime.timedelta(minutes=30)
    needed_datetimes = []
    while dt_mover <= end_date:
        needed_datetimes.append(dt_mover)
        dt_mover += half_hour

    if not needed_datetimes:
        # The daily tables already reach past end_date, so no fast table is
        # required. Previously this case evaluated min()/max() of an empty list
        # inside the path filter, raising ValueError whenever any fast table
        # existed and returning the dailies otherwise.
        return dailies

    # Loop-invariant bounds: compute once instead of once per candidate path.
    first_slot = min(needed_datetimes)
    last_slot = max(needed_datetimes)

    def _is_wanted_fast(path):
        if not path.endswith("clean"):
            return False
        ts = get_ts_from_str(path)
        if ts is None:
            return False
        return first_slot <= ts <= last_slot

    fasts = list(
        yt.search(
            root=ROOT.format(service=service, timeframe="fast"),
            node_type="table",
            path_filter=_is_wanted_fast,
        )
    )
    set_fasts = {get_ts_from_str(x) for x in fasts}
    missing = set(needed_datetimes) - set_fasts
    if set_fasts == set(needed_datetimes):
        return dailies + fasts
    raise Exception(
        "some tables are missing: {}".format(
            ",".join(map(str, sorted(missing)))
        )
    )


def set_default_values(config):
    """Fill `config` in place with a default for every key it does not have.

    Keys already present keep their value, whatever it is (including None or
    False). Nothing is returned.
    """
    # A tuple of pairs keeps the insertion order of the defaults fixed.
    fallbacks = (
        ("service", "video"),
        ("days", 7),
        ("days_for_spikes", 1),
        ("users_threshold_for_tmp", 50),
        ("users_threshold_for_spikes", 50),
        ("use_existing_tmp_table", False),
        ("save_tmp_table", False),
        ("override_tmp_table", None),
        ("override_output_table", None),
        ("root", "//home/videoquality/vh_analytics/queries_spikes"),
    )
    for name, fallback in fallbacks:
        if name in config:
            continue
        config[name] = fallback


def safe(s):
    """Render `s` as text usable inside a table name.

    Everything from the first "+" onwards is dropped, then each ":" and each
    space is replaced by "_".
    """
    head, _, _ = str(s).partition("+")
    return "".join("_" if ch in (":", " ") else ch for ch in head)


def _epoch_seconds(dt):
    """Return the Unix timestamp of the aware datetime `dt` as a decimal string."""
    import calendar  # local import: used only by this helper

    # strftime("%s") is a non-portable C-library extension that ignores tzinfo
    # and interprets the wall-clock fields in the host's local zone. Converting
    # through UTC gives the same instant on every machine.
    return str(calendar.timegm(dt.utctimetuple()))


def main():
    """Build the query-spikes YQL query from a JSON config and run it.

    Command-line arguments:
        --pool: value substituted for "@[pool]" in the query.
        --end_date / -e: end of the analysed window, "YYYY-MM-DD" or
            "YYYY-MM-DDTHH:MM:SS" (required).
        --output: path of the JSON file written once the query call returns.
        --config / -c: path to the JSON config (required).

    Config example (keys left out are filled in by set_default_values):
        {
            "service": "video",
            "days": 7,
            "days_for_spikes": 1,
            "users_threshold_for_tmp": 50,
            "users_threshold_for_spikes": 50,
            "save_tmp_table": true
        }

    Raises:
        Exception: if both "save_tmp_table" and "use_existing_tmp_table" are
            set, or if get_input_tables cannot find the needed tables.
        KeyError: if the YT_PROXY environment variable is not set.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--pool", default="robot-mma-nirvana")
    # Both options were always needed; omitting them used to end in a
    # TypeError further down instead of a usage message.
    parser.add_argument("--end_date", "-e", required=True)
    parser.add_argument("--output", default="output.json")
    parser.add_argument("--config", "-c", required=True)
    args = vars(parser.parse_args())

    with open(args["config"]) as f:
        config = json.load(f)
    set_default_values(config)

    if config["save_tmp_table"] and config["use_existing_tmp_table"]:
        raise Exception(
            "conflicting config --save_tmp_table and --use_existing_tmp_table"
        )

    end_date = parse_date(args["end_date"])
    start_date = end_date - datetime.timedelta(days=config["days"])

    input_tables = get_input_tables(config, start_date, end_date)

    # datetime.datetime is a subclass of datetime.date, so testing for
    # datetime.date first matched every value and always discarded the time of
    # day. Test for the datetime case first so it is kept.
    if isinstance(end_date, datetime.datetime):
        dt_ = end_date
    else:
        dt_ = datetime.datetime(
            end_date.year, end_date.month, end_date.day, tzinfo=moscow
        )
    ts_threshold_for_tmp_low = _epoch_seconds(
        dt_ - datetime.timedelta(config["days"])
    )
    ts_threshold_for_tmp_high = _epoch_seconds(dt_)
    ts_threshold_for_spikes = _epoch_seconds(
        dt_ - datetime.timedelta(config["days_for_spikes"])
    )

    # Default table names are derived from the service and the window bounds.
    table_stem = "{}/{}_{}_{}".format(
        config["root"], config["service"], safe(start_date), safe(end_date)
    )
    if not config["override_tmp_table"]:
        tmp_table = table_stem
    else:
        tmp_table = config["override_tmp_table"]

    if not config["override_output_table"]:
        output_table = table_stem + "_spikes"
    else:
        output_table = config["override_output_table"]

    with open("queries_spikes_stub.sql") as f:
        query = f.read()
    query = apply_replacements(
        query,
        {
            "@[pool]": args["pool"],
            "@[input_tables]": json.dumps(input_tables, indent=4),
            "@[ts_threshold_for_tmp_low]": ts_threshold_for_tmp_low,
            "@[ts_threshold_for_tmp_high]": ts_threshold_for_tmp_high,
            "@[ts_threshold_for_spikes]": ts_threshold_for_spikes,
            "@[users_threshold_for_tmp]": str(
                config["users_threshold_for_tmp"]
            ),
            "@[users_threshold_for_spikes]": str(
                config["users_threshold_for_spikes"]
            ),
            "@[fresh_intent_extractor]": "$get{}FreshIntent".format(
                config["service"].title()
            ),
            "@[tmp_table]": tmp_table,
            "@[output_table]": output_table,
        },
    )
    # Optional sections of the stub are wrapped in named comment markers;
    # removing a pair of markers activates the section between them.
    if config["save_tmp_table"]:
        query = apply_replacements(query, {"/*save_tmp": "", "save_tmp*/": ""})
    if config["use_existing_tmp_table"]:
        query = apply_replacements(
            query, {"/*use_existing_map": "", "use_existing_map*/": ""}
        )
    else:
        query = apply_replacements(
            query, {"/*create_new_map": "", "create_new_map*/": ""}
        )

    proxy = os.environ["YT_PROXY"].lower()
    yc = YqlClient(db=proxy)
    yr = YqlRunner(yc, title=TITLE)
    yr.run(
        query,
        wait=True,
        attachments=[
            {
                "path": "analytics/videolog/mma_4618_queries_spikes/queries_spikes_common.sql"
            },
            {
                "path": "analytics/videolog/mma_4618_queries_spikes/spikes_mapper.py"
            },
        ],
    )

    # Record where the result lives for whatever consumes --output.
    with open(args["output"], "w") as f:
        json.dump({"cluster": proxy, "table": output_table}, f)


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
