
import random
import datetime

from yp.common import YP_PRODUCTION_CLUSTERS, YP_TESTING_CLUSTERS, YP_PRESTABLE_CLUSTERS
import yt.wrapper
import getpass
import infra.analytics.io_limits_pipeline.utils as utils


from infra.rtc.analytics.yasm_metrics_collect.mappers import AbcToPodsMapper, ServiceToPodsMapper, PodsToPodsMapper, \
    PodsToMetricsMapper
from infra.rtc.analytics.yasm_metrics_collect.mappers import WalleToHostsMapper, HostsToHostsMapper, \
    HostsToMetricsMapper
from infra.rtc.analytics.yasm_metrics_collect.data import is_histogram
from infra.rtc.analytics.yasm_metrics_collect.requests import get_all_abcs, get_all_rtc_walles
from infra.rtc.analytics.yasm_metrics_collect.utils import INPUT_TO_MAPPER_TYPE

# FIXME
import infra.rtc.analytics.yasm_metrics_collect.utils as utils_for_preprocessing


def extract(string, tp):
    """Return the value part of a directive line.

    Drops the first ``len(tp)`` characters of ``string`` (the directive
    keyword; pass ``""`` to keep the whole string), then trims surrounding
    whitespace and any surrounding double quotes, single quotes or backticks.
    """
    remainder = string[len(tp):]
    trimmed = remainder.strip()
    return trimmed.strip("\"'`")


def parse_signal(sig):
    """Parse a ``SELECT`` line into its aggregations and signal expression.

    After removing the ``"SELECT "`` marker the text is split on ``":"``.
    The first part is a ``|``-separated list of aggregations, the second a
    ``||``-separated list of signals. Each element is cleaned with
    ``extract``.

    Returns:
        dict with ``"aggregation"`` (list of str) and ``"signals"`` (the
        cleaned signals re-joined with ``"||"``).

    Raises:
        IndexError: if the line contains no ``":"`` separator.
    """
    body = sig.replace("SELECT ", "")
    fields = body.split(":")

    aggregations = []
    for raw_aggregation in fields[0].split("|"):
        aggregations.append(extract(raw_aggregation, ""))

    cleaned_signals = []
    for raw_signal in fields[1].split("||"):
        cleaned_signals.append(extract(raw_signal, ""))

    return {
        "aggregation": aggregations,
        "signals": "||".join(cleaned_signals),
    }


def construct_timestamp(dtm):
    """Convert a textual moment into an integer Unix timestamp.

    Two forms are accepted:

    * absolute: ``"YYYY-mm-dd HH:MM:SS"``, interpreted by
      ``datetime.strptime`` as a naive (local-time) datetime;
    * relative: ``"-D:H:M:S"``, meaning that many days, hours, minutes and
      seconds before the current moment.

    Raises:
        IndexError: if a relative value has fewer than four ``:`` parts.
        ValueError: if a part is not an integer or the absolute form does
            not match the expected format.
    """
    if not dtm.startswith("-"):
        moment = datetime.datetime.strptime(dtm, "%Y-%m-%d %H:%M:%S")
        return int(moment.timestamp())

    now_ts = int(datetime.datetime.now().timestamp())
    parts = dtm[1:].split(":")
    days = int(parts[0])
    hours = int(parts[1])
    minutes = int(parts[2])
    seconds = int(parts[3])

    total_hours = days * 24 + hours
    total_minutes = total_hours * 60 + minutes
    offset = total_minutes * 60 + seconds
    return now_ts - offset


def construct_interval(intv):
    """Parse a ``PERIODS`` line into a list of Unix timestamps.

    The ``"PERIODS "`` marker is removed, the rest is split on ``" to "``
    and every piece is converted with ``construct_timestamp``.
    """
    without_keyword = intv.replace("PERIODS ", "")
    bounds = without_keyword.strip().split(" to ")
    return list(map(construct_timestamp, bounds))


def parse_query(query_file):
    query = {
        "signals": [],
        "ints": []
    }
    overwrite_items = None

    with open(query_file) as fle:
        for line in fle:
            line = line.strip()
            long_args = {
                "SELECT": ["signals", parse_signal],
                "PERIODS": ["ints", construct_interval]
            }

            for input, sve in long_args.items():
                if line.startswith(input):
                    query[sve[0]].append(sve[1](line))

            simple_args = {
                "PROXY": "proxy",
                "CACHE": "cache",
                "TO": "output_table",
                "ERROR": "error_table",
                "TIMEFORMAT": "intv_type",
                "PARSEAGGRS": "parse_aggrs",
                "RAWSAMPLING": "raw_sampling",
                "COLLECTALL": "collector"
            }
            for inp, sve in simple_args.items():
                if line.startswith(inp):
                    query[sve] = extract(line, inp)

            list_args = {
                "CLUSTERS": "clusters",
                "JOBS": "job_count",
                "TAGS": "host_tags"
            }
            for inp, sve in list_args.items():
                if line.startswith(inp):
                    args = extract(line, inp).split(",")
                    query[sve] = [extract(clst, "") for clst in args]

            if line.startswith("FROM"):
                args = extract(line, "FROM").split(":")
                query["input_type"] = extract(args[0], "")
                query["input_table"] = extract(args[1], "")

            if line.startswith("OVERWRITE"):
                args = extract(line, "OVERWRITE").split(":")
                overwrite_type = extract(args[0], "")

                overwrite_items = {INPUT_TO_MAPPER_TYPE[overwrite_type]: i for i in line.split(" ")[1].split(",")}

            if overwrite_items:
                yt.wrapper.config.set_proxy(query["proxy"])
                yt.wrapper.write_table(query["input_table"], overwrite_items)

    return query


def check_and_fix_query(query, preprocessor):
    """Validate a parsed query and fill in defaults, modifying it in place.

    Args:
        query: dict as produced by ``parse_query``.
        preprocessor: name of a zero-argument function in
            ``utils_for_preprocessing`` to call before validation, or a
            falsy value to skip preprocessing.

    Returns:
        The same ``query`` dict. In collector mode (``"collector"`` present)
        only the required fields, the proxy default and the collector value
        are checked. Otherwise all defaults are filled in, ``job_count`` is
        converted to a list of ints, ``parse_aggrs`` to the result of
        ``utils.str2bool``, ``raw_sampling`` to float, and
        ``signals_with_hist_count`` is set to the number of distinct signals
        requested with the ``"hist"`` aggregation.

    Raises:
        TypeError: a required field is missing, or a histogram signal is
            given more than one aggregation.
        ValueError: any other invalid or unknown value.
    """
    if preprocessor:
        # FIXME: for the MVP the table with pods is hard-coded
        preprocessing_func = getattr(utils_for_preprocessing, preprocessor)
        preprocessing_func()

    required_fields = ("output_table", "signals", "input_table", "ints")
    for field in required_fields:
        if field not in query:
            raise TypeError("No field {0} in query".format(field))

    if "proxy" not in query:
        query["proxy"] = "hahn"

    # Collector mode needs nothing beyond proxy and output table.
    if "collector" in query:
        if query["collector"] not in ["abcs", "walles"]:
            raise ValueError("Unknown collector value: {0}".format(query["collector"]))
        return query

    if not query["signals"]:
        raise ValueError("Empty signal array")

    hist_signals = set()
    for sg in query["signals"]:
        if is_histogram(sg["signals"]) and len(sg["aggregation"]) > 1:
            raise TypeError("Can only pass one aggregator for histogram signals")
        if "hist" in sg["aggregation"]:
            # update() mutates the set; union() returns a new set and its
            # result was being discarded, leaving the count at zero.
            hist_signals.update(sg["signals"].split("||"))
    query["signals_with_hist_count"] = len(hist_signals)

    if "cache" not in query:
        query["cache"] = "//tmp/" + getpass.getuser() + "/yasmcache" + str(random.randint(1000000, 9999999))

    if "clusters" not in query:
        query["clusters"] = YP_PRODUCTION_CLUSTERS
    known_clusters = set(YP_PRODUCTION_CLUSTERS + YP_PRESTABLE_CLUSTERS + YP_TESTING_CLUSTERS)
    for clst in query["clusters"]:
        if clst not in known_clusters:
            raise ValueError("Unknown cluster: {0}".format(clst))

    known_input_types = ("abcs", "pods", "pods_cached", "services", "walles", "hosts", "hosts_cached")
    if query["input_type"] not in known_input_types:
        raise ValueError(
            "Unknown input type: {0}, use one of these: {1}".format(
                query["input_type"], ", ".join(known_input_types)
            )
        )

    if not query.get("job_count"):
        query["job_count"] = [100]
    if len(query["job_count"]) > 2:
        raise ValueError("Job counts require <=2 numbers, got {0} instead".format(len(query["job_count"])))
    query["job_count"] = [int(count) for count in query["job_count"]]

    if "intv_type" not in query:
        query["intv_type"] = "timestamp"
    if query["intv_type"] not in ["timestamp", "string", "both"]:
        raise ValueError("Unknown date format: {0}, use one of these: timestamp, string, both".format(query["intv_type"]))

    if "parse_aggrs" not in query:
        query["parse_aggrs"] = True
    try:
        query["parse_aggrs"] = utils.str2bool(query["parse_aggrs"])
    except Exception as err:
        # The exception type raised by str2bool is not visible here, so catch
        # Exception (but no longer KeyboardInterrupt/SystemExit).
        raise ValueError("Parse arrgs: {0} is not a boolean-like value".format(query["parse_aggrs"])) from err

    if "raw_sampling" not in query:
        query["raw_sampling"] = 1.0
    try:
        query["raw_sampling"] = float(query["raw_sampling"])
    except (TypeError, ValueError, OverflowError) as err:
        raise ValueError("Raw sampling parameter should be float") from err

    if "host_tags" not in query:
        query["host_tags"] = []
    # launch() reads error_table unconditionally; an empty string means
    # "no error table".
    if "error_table" not in query:
        query["error_table"] = ""
    return query


def launch(query):
    """Execute a validated query against YT.

    Collector mode (``"collector"`` in ``query``): fetch the full list via
    ``get_all_abcs`` or ``get_all_rtc_walles`` and write it to
    ``query["output_table"]``.

    Otherwise run up to two map operations:

    1. Unless the input type is ``pods_cached`` / ``hosts_cached``, map
       ``query["input_table"]`` into ``query["cache"]`` with the mapper that
       matches the input type, using ``job_count[0]`` jobs.
    2. Map the cache (or the input table itself for the ``*_cached`` types)
       with ``PodsToMetricsMapper`` / ``HostsToMetricsMapper`` into
       ``query["output_table"]``, using ``job_count[-1]`` jobs. When an
       error table is configured it is passed as a second destination table.

    Args:
        query: dict as returned by ``check_and_fix_query``. A missing
            ``"error_table"`` is treated like an empty one.

    Raises:
        ValueError: if the collector value is not ``abcs`` or ``walles``.
    """
    if "collector" in query:
        yt.wrapper.config.set_proxy(query["proxy"])
        if query["collector"] == "abcs":
            data = get_all_abcs()
        elif query["collector"] == "walles":
            data = get_all_rtc_walles()
        else:
            # Previously fell through to an UnboundLocalError on `data`.
            raise ValueError("Unknown collector value: {0}".format(query["collector"]))
        yt.wrapper.write_table(query["output_table"], data, raw=False)
        return

    utils.read_token()
    yt.wrapper.config.set_proxy(query["proxy"])

    gib = 1024 * 1024 * 1024
    input_type = query["input_type"]
    # Tolerate queries without an ERROR directive instead of raising KeyError.
    error_table = query.get("error_table", "")
    collect_errors = error_table != ""

    if input_type in ["abcs", "services", "pods", "pods_cached"]:
        metrics_mapper_cls = PodsToMetricsMapper
        is_cached = input_type == "pods_cached"
    else:
        metrics_mapper_cls = HostsToMetricsMapper
        is_cached = input_type == "hosts_cached"

    if not is_cached:
        # Stage 1: map the input table into the cache table.
        if input_type == "abcs":
            first_stage_mapper = AbcToPodsMapper(utils.get_oauth_token(), query["clusters"])
        elif input_type == "services":
            first_stage_mapper = ServiceToPodsMapper(utils.get_oauth_token(), query["clusters"])
        elif input_type == "pods":
            first_stage_mapper = PodsToPodsMapper(utils.get_oauth_token())
        elif input_type == "walles":
            first_stage_mapper = WalleToHostsMapper(query["host_tags"])
        else:
            first_stage_mapper = HostsToHostsMapper(query["host_tags"])
        yt.wrapper.run_map(
            first_stage_mapper,
            source_table=query["input_table"],
            destination_table=query["cache"],
            memory_limit=3 * gib,
            job_count=query["job_count"][0],
        )

    # Stage 2: collect metrics for every row of the cache (or the input table
    # itself for the *_cached input types).
    metrics_mapper = metrics_mapper_cls(
        query["ints"],
        query["signals"],
        utils.get_solomon_token(),
        collect_errors,
        query["signals_with_hist_count"],
        query["intv_type"],
        query["parse_aggrs"],
        query["raw_sampling"],
    )
    yt.wrapper.run_map(
        metrics_mapper,
        source_table=query["input_table"] if is_cached else query["cache"],
        destination_table=[query["output_table"], error_table] if collect_errors else query["output_table"],
        memory_limit=30 * gib,
        job_count=query["job_count"][-1],
    )


def front_pipeline_combined(query_file, preprocessor):
    """Run the whole pipeline for one query file: parse, validate, launch.

    ``preprocessor`` is forwarded unchanged to ``check_and_fix_query``.
    """
    parsed_query = parse_query(query_file)
    validated_query = check_and_fix_query(parsed_query, preprocessor)
    launch(validated_query)
