import datetime
import time
import pickle
import json
import random
import logging
import requests
import argparse
import sys

from pprint import pformat

import yt.wrapper
import infra.analytics.io_limits_pipeline.utils as utils


def parse_signal_group(signal_group):
    """Return the signal selector registered for *signal_group*.

    The value is looked up in ``utils.signals_group_mapping``.

    :param signal_group: name of the signal group (the ``--signals_group`` CLI value).
    :return: the mapping value for that group.
    :raises Exception: if the group is not present in the mapping.
    """
    try:
        return utils.signals_group_mapping[signal_group]
    except KeyError as err:
        # Name the offending group and keep the KeyError as the cause to ease debugging.
        raise Exception("No such signal group located: {0!r}".format(signal_group)) from err


def get_existing_service_names(diff_table):
    """Collect the distinct ``nannyservice`` values stored in a YT table on ``hahn``.

    :param diff_table: YT path of a table whose rows carry a ``nannyservice`` column.
    :return: set of service names found in that table.
    """
    yt.wrapper.config.set_proxy("hahn")
    rows = yt.wrapper.read_table(diff_table)
    return {row["nannyservice"] for row in rows}


def unpickle_data(fname, diff_table):
    """Load the host -> pods mapping from a pickle file, optionally dropping known services.

    NOTE(review): ``pickle.load`` can execute arbitrary code stored in the file;
    only point this at trusted, locally produced files.

    :param fname: path to the pickle file; the loaded object is expected to be a
        dict whose values are lists of dicts with a ``service_name`` key.
    :param diff_table: optional YT table path; when truthy, pods whose
        ``service_name`` already appears in that table's ``nannyservice`` column
        are removed (hosts are kept even if their list becomes empty).
    :return: the (possibly filtered) mapping.
    """
    with open(fname, 'rb') as fh:
        loaded = pickle.load(fh)

    if not diff_table:
        return loaded

    already_done = get_existing_service_names(diff_table)
    filtered = {}
    for host, pods in loaded.items():
        filtered[host] = [pod for pod in pods if pod["service_name"] not in already_done]
    return filtered


def get_meta_yt_data(table_name, diff_table):
    """Read pod metadata from a YT table on ``hahn`` and group it by physical host.

    Each row of *table_name* must provide ``dc``, ``host``, ``pod_name``,
    ``service_name`` and ``yasm_tag``. Rows whose ``service_name`` already appears
    in *diff_table* are skipped.

    :param table_name: YT path of the pod metadata table.
    :param diff_table: optional YT path of a table of already-processed results;
        when falsy, no rows are skipped.
    :return: dict mapping host -> list of pod dicts with keys ``dc``, ``host``,
        ``pod_name``, ``service_name``, ``ssd`` (always ``True``) and ``yasm_tag``.
    """
    yt.wrapper.config.set_proxy("hahn")

    # Only consult the diff table when one was given, like unpickle_data does;
    # reading a ``None`` table path would presumably fail.
    existing_services = get_existing_service_names(diff_table) if diff_table else set()

    result = {}
    for obs in yt.wrapper.read_table(table_name):
        if obs["service_name"] in existing_services:
            continue
        pod = {
            "dc": obs["dc"],
            "host": obs["host"],
            "pod_name": obs["pod_name"],
            "service_name": obs["service_name"],
            "ssd": True,
            "yasm_tag": obs["yasm_tag"],
        }
        result.setdefault(obs["host"], []).append(pod)

    return result


def insert_data(lst, st_time, table_path="//home/runtimecloud/research/io_research/yasm_host_tags", appnd=True):
    """Write *lst* rows to a YT table on the ``hahn`` cluster and log the batch duration.

    :param lst: list of row dicts to write.
    :param st_time: ``time.time()`` value marking when this batch started; only used
        for the elapsed time reported in the log line.
    :param table_path: destination YT table path.
    :param appnd: passed as ``append`` to ``yt.wrapper.TablePath``.
    """
    yt.wrapper.config.set_proxy("hahn")
    destination = yt.wrapper.TablePath(table_path, append=appnd)
    yt.wrapper.write_table(destination, lst)
    elapsed = time.time() - st_time
    logging.info("Inserted {0} rows, it took {1} seconds to process".format(len(lst), elapsed))


def get_processed_hosts(path="processed_hosts_solomon_extraction.txt"):
    """Return the set of host names listed one per line in a text file.

    :param path: file to read; defaults to ``processed_hosts_solomon_extraction.txt``
        relative to the current working directory.
    :return: set of the file's lines with newline characters removed.
    """
    with open(path) as f:
        return {line.replace('\n', '') for line in f}


def add_null_signal(signals, result):
    """Ensure every name in *signals* is a key of *result*.

    Missing keys are inserted with the value ``None``; existing entries are left
    untouched. *result* is modified in place and also returned.

    :param signals: iterable of signal names that must be present as keys.
    :param result: dict to complete.
    :return: the same *result* dict.
    """
    for name in signals:
        result.setdefault(name, None)
    return result


def _solomon_time_bounds(tmstmp_st, tmstmp_end):
    """Return ``(start, end)`` formatted as UTC ISO-8601 strings for a Solomon query.

    :param tmstmp_st: window start, Unix timestamp in seconds.
    :param tmstmp_end: window end, Unix timestamp in seconds.
    """
    fmt = "%Y-%m-%dT%H:%M:%S.%fZ"
    # Render both ends in UTC so they agree with the trailing "Z". Shifting only one end by
    # a fixed offset from local time widens the window beyond [tmstmp_st, tmstmp_end] and
    # makes consecutive daily windows overlap (duplicated rows).
    return tuple(
        datetime.datetime.fromtimestamp(ts, tz=datetime.timezone.utc).strftime(fmt)
        for ts in (tmstmp_st, tmstmp_end)
    )


def _build_solomon_program(signals, selector_host, nanny, ctype, parsed_tags, deploy):
    """Return the Solomon selector program (``{signal=..., host=..., ...}``) for one pod.

    :param signals: value for the ``signal`` selector.
    :param selector_host: value for the ``host`` selector.
    :param nanny: ``nanny`` tag value; the selector is omitted when empty.
    :param ctype: ``ctype`` tag value; the selector is omitted when empty.
    :param parsed_tags: dict of the pod's YASM tags.
    :param deploy: when true, ``prj``/``deploy_unit``/``stage`` selectors are appended
        for those tags that are set.
    """
    expression_template = '''signal="{0}", host="{1}", cluster="host_*", nanny="{2}", ctype="{3}"'''

    additional_tagstring = ""
    if deploy:
        for tag in ("prj", "deploy_unit", "stage"):
            if parsed_tags.get(tag):
                additional_tagstring += ', {0}="{1}"'.format(tag, parsed_tags.get(tag))

    expr = "{" + expression_template.format(signals, selector_host, nanny, ctype) + additional_tagstring + "}"

    # Drop selectors for tags the pod does not carry rather than matching on an empty value.
    if nanny == "":
        expr = expr.replace(', nanny=""', '')
    if ctype == "":
        expr = expr.replace(', ctype=""', '')
    return expr


def construct_rows_from_signal(host, tmstmp_st, tmstmp_end, hosts_objs, signals, process_balancers, deploy=False,
                               request_timeout=60):
    """Fetch Solomon data for every pod on *host* over one time window and flatten it into rows.

    For each pod in ``hosts_objs[host]`` a selector is built from the pod's YASM tags and
    posted to the Solomon ``yasm_<itype>`` project; the returned ``vector`` is turned into
    rows by ``parse_vector``. Per-pod failures are logged at INFO level and the pod is skipped.

    :param host: key into *hosts_objs*.
    :param tmstmp_st: window start, Unix timestamp in seconds.
    :param tmstmp_end: window end, Unix timestamp in seconds.
    :param hosts_objs: mapping host -> list of pod dicts with ``service_name``, ``yasm_tag``,
        ``pod_name``, ``dc`` and ``host`` keys.
    :param signals: value substituted into the ``signal="..."`` selector.
    :param process_balancers: when false, pods whose service name contains ``"balancer"``
        are skipped.
    :param deploy: when true, the selector uses ``pod["host"]`` and adds ``prj``/``deploy_unit``/
        ``stage`` tags when present; otherwise it uses ``<pod_name>.<dc>.yp-c.yandex.net``.
    :param request_timeout: seconds to wait for each Solomon HTTP request.
    :return: list of row dicts produced for all pods of *host*.
    """
    final_result = []

    base_url = "https://solomon.yandex-team.ru/api/v2/projects/yasm_{0}/sensors/data"
    hdrs = {
        "Authorization": "OAuth {}".format(utils.get_solomon_token()),
        "Content-Type": "application/json;charset=UTF-8",
        "Accept": "application/json"
    }
    start, end = _solomon_time_bounds(tmstmp_st, tmstmp_end)
    downsampling = {"gridMillis": 5000, "aggregation": "LAST", "disabled": True}

    for pod in hosts_objs[host]:
        if "balancer" in pod["service_name"] and not process_balancers:
            continue

        try:
            if isinstance(pod["yasm_tag"], str):
                parsed_tags = parse_yasm_tags(pod["yasm_tag"])
            else:
                parsed_tags = pod["yasm_tag"]

            pod_url = pod["pod_name"] + "." + pod["dc"] + ".yp-c.yandex.net"

            try:
                itype = parsed_tags.get("itype")
                ctype = parsed_tags.get("ctype", "")
                nanny = parsed_tags.get("nanny", "")
            except AttributeError:
                # parsed_tags is None (unparsable tag string) or otherwise not a mapping.
                logging.info("Empty tags detected on pod {0} for service {1}".format(pod_url, pod["service_name"]))
                continue

            selector_host = pod["host"] if deploy else pod_url
            payload = {
                "program": _build_solomon_program(signals, selector_host, nanny, ctype, parsed_tags, deploy),
                "from": start,
                "to": end,
                "downsampling": downsampling,
            }

            try:
                response = requests.post(base_url.format(itype), data=json.dumps(payload), headers=hdrs,
                                         timeout=request_timeout)
                response.raise_for_status()
                reslt = response.json()
            except (requests.RequestException, ValueError):
                # ValueError covers a response body that is not valid JSON.
                logging.info("No data extracted from solomon on pod {0} for service {1}".format(
                    pod_url, pod["service_name"]))
                continue

            try:
                vector = reslt['vector']
                if not nanny:
                    # No nanny tag: label the rows with the pod's service_name instead.
                    nanny = pod["service_name"]

                final_result.extend(parse_vector(vector, host, pod_url, nanny, pod['dc']))
            except Exception:
                logging.info("Error while parsing vector on pod {0} for service {1}".format(pod_url,
                                                                                            pod["service_name"]))

        except Exception:
            logging.info("Error while parsing pod for service {0}".format(pod["service_name"]))

    logging.info("this many datapoints {0} processed for host {1}".format(len(final_result), host))

    return final_result


def parse_yasm_tags(tagstr):
    """Parse a ``key=value;key=value`` YASM tag string into a dict.

    Only the first ``=`` of each item separates key from value, so values that
    themselves contain ``=`` are kept intact. Duplicate keys keep the last value.

    :param tagstr: tag string, e.g. ``"itype=deploy;ctype=prod"``.
    :return: dict of tags, or ``None`` if any ``;``-separated item has no ``=``
        (this includes an empty string or a trailing ``;``).
    """
    parsed = {}
    for item in tagstr.split(";"):
        key, sep, value = item.partition("=")
        if not sep:
            return None
        parsed[key] = value
    return parsed


# Example signal name (presumably a histogram, given the "_hgram" suffix): balancer_report-report-service_total-input_size_hgram
def parse_vector(vector, host, pod, nanny, dc) -> list:
    """Flatten a Solomon ``vector`` response into one row dict per timestamp.

    Timestamps are taken from the first series (``vector[0]``); every series in
    *vector* is then read at the same index. Each row holds ``timestamp``
    (``int(ts / 1000)``, i.e. seconds if Solomon timestamps are milliseconds),
    one key per series named after its ``signal`` label, and the constant
    ``host``, ``pod``, ``nannyservice`` and ``dc`` columns.

    Per-point values:
      * non-histogram points (no ``buckets`` key): the point's ``sum``, with
        ``None`` replaced by ``0``;
      * histogram points with an empty ``buckets`` list: ``0.0``;
      * exactly one bucket with count >= 1: ``round(base ** (startPower + 1))``
        as a float. NOTE(review): presumably the upper bound of that bucket;
        confirm against the Solomon histogram format;
      * several buckets: the bucket-midpoint mean weighted by bucket counts.
    A histogram with a single bucket whose count is < 1 leaves the signal absent
    from the row.

    Error handling: an ``IndexError`` while reading one series at one timestamp
    (e.g. a shorter series, or more buckets than bounds) just omits that signal
    from the row. A ``KeyError`` anywhere stops parsing, is logged, and the rows
    completed so far are returned. Other exceptions propagate to the caller.

    :param vector: the ``vector`` list from a Solomon ``sensors/data`` response.
    :param host: value for the ``host`` column.
    :param pod: value for the ``pod`` column.
    :param nanny: value for the ``nannyservice`` column (also used in the log message).
    :param dc: value for the ``dc`` column.
    :return: list of row dicts; empty if *vector* is empty or its first entry has
        no truthy ``timeseries``.
    """
    result = []

    if len(vector) > 0 and vector[0].get("timeseries"):

        try:
            for tmsmpnum, tmstmp in enumerate(vector[0]["timeseries"]["timestamps"]):
                obs = {}
                obs["timestamp"] = int(tmstmp / 1000)

                for signum in range(len(vector)):

                    # Detect histogram signals (somewhat crudely) by whether the point has a 'buckets' key.
                    try:
                        if 'buckets' not in vector[signum]["timeseries"]["values"][tmsmpnum]:
                            try:
                                obs[vector[signum]["timeseries"]["labels"]["signal"]] = 0 if \
                                    vector[signum]["timeseries"] \
                                ["values"][tmsmpnum]["sum"] is None else \
                                    vector[signum]["timeseries"]["values"][tmsmpnum]["sum"]
                            except IndexError:
                                pass
                        else:
                            if len(vector[signum]["timeseries"]["values"][tmsmpnum]["buckets"]) == 0:
                                obs[vector[signum]['timeseries']['labels']['signal']] = float(0)
                            else:
                                if len(vector[signum]["timeseries"]["values"][tmsmpnum]["buckets"]) == 1 and \
                                        vector[signum]["timeseries"]["values"][tmsmpnum]["buckets"][0] >= 1:

                                    power = vector[signum]["timeseries"]["values"][tmsmpnum]["startPower"] + 1
                                    base = vector[signum]["timeseries"]["values"][tmsmpnum]["base"]

                                    obs[vector[signum]["timeseries"]["labels"]["signal"]] = float(round(base**power, 0))

                                elif len(vector[signum]["timeseries"]["values"][tmsmpnum]["buckets"]) > 1:
                                    # Histogram mean: each non-empty bucket contributes count * midpoint,
                                    # and the total is divided by the sum of all bucket counts.
                                    buckets = vector[signum]["timeseries"]["values"][tmsmpnum]["buckets"]

                                    bounds, total_sum = [], 0
                                    # FIXME: does not work for signals that can take negative values
                                    # (the first bucket's lower bound is assumed to be 0).
                                    lower_bound = 0

                                    # Turn the list of upper bounds into (lower, upper) pairs.
                                    for upper_bound in vector[signum]["timeseries"]["values"][tmsmpnum]["bounds"]:
                                        bounds.append((lower_bound, upper_bound))
                                        lower_bound = upper_bound

                                    for bucknum, buckval in enumerate(buckets):
                                        if buckval > 0:
                                            total_sum += buckval * ((bounds[bucknum][0] + bounds[bucknum][1]) / 2)

                                    # Guard skips the division when every bucket is empty (or all midpoints are 0).
                                    if total_sum > 0:
                                        total_sum = total_sum / sum(buckets)

                                    obs[vector[signum]["timeseries"]["labels"]["signal"]] = float(total_sum)

                    except IndexError:
                        # This series has no usable point at this timestamp: leave the signal out of the row.
                        pass

                obs["host"] = host
                obs["pod"] = pod
                obs["nannyservice"] = nanny
                obs["dc"] = dc

                result.append(obs)
        except KeyError:
            # A missing key aborts the whole vector; rows appended before this point are still returned.
            logging.info("Error while parsing vector for service {0}".format(nanny))

    else:
        pass

    return result


def process_solomon_data(host_chunk_size, tablename, filename, days, signals, deploy, input_table, diff_table,
                         test_run, process_balancers, test_sample_size=50):
    """Extract Solomon data for all known hosts and write it to a YT table chunk by chunk.

    Host metadata comes from *input_table* (YT) when given, otherwise from the pickle
    *filename*; both loaders receive *diff_table* to skip already-processed services.
    ``construct_rows_from_signal`` is called for every host and every window produced by
    ``generate_time_intervals(days)``. The rows of each chunk of *host_chunk_size* hosts
    are written with ``insert_data``; a failed write is retried once after 30 seconds.

    :param host_chunk_size: number of hosts per YT write.
    :param tablename: destination YT table.
    :param filename: pickle file with host metadata (used when *input_table* is falsy).
    :param days: number of daily windows to query.
    :param signals: signal selector passed through to the Solomon query.
    :param deploy: passed through as ``deploy`` to ``construct_rows_from_signal``.
    :param input_table: optional YT table with host metadata.
    :param diff_table: optional YT table of already-processed services.
    :param test_run: when true, only a random sample of hosts is processed.
    :param process_balancers: whether balancer pods are queried.
    :param test_sample_size: maximum number of hosts sampled when *test_run* is true.
    """
    time_intervals = generate_time_intervals(days)

    if not input_table:
        host_objs = unpickle_data(filename, diff_table)
    else:
        host_objs = get_meta_yt_data(input_table, diff_table)

    if test_run:
        # random.sample requires a sequence (dict views, like other sets, are rejected
        # since Python 3.11) and raises ValueError when asked for more items than exist,
        # so sample from a list and cap the size at the number of hosts.
        sampled = random.sample(list(host_objs.items()), min(test_sample_size, len(host_objs)))
        host_objs = dict(sampled)

    for chunk in split_chunks(list(host_objs), host_chunk_size):
        chunk_data = []
        chunk_started = time.time()

        for hst in chunk:
            for start, end in time_intervals:
                chunk_data.extend(construct_rows_from_signal(hst, start, end, host_objs, signals,
                                                             process_balancers, deploy=deploy))

        try:
            insert_data(chunk_data, chunk_started, table_path=tablename)
        except Exception:
            logging.exception("Failed to insert {0} rows into {1}; retrying in 30 seconds".format(
                len(chunk_data), tablename))
            time.sleep(30)
            insert_data(chunk_data, chunk_started, table_path=tablename)


def split_chunks(lst, chnksize):
    """Lazily yield consecutive slices of *lst* holding *chnksize* items each.

    The last slice may be shorter. Works for any sliceable sequence.

    :param lst: sequence to split.
    :param chnksize: slice length.
    """
    yield from (lst[offset:offset + chnksize] for offset in range(0, len(lst), chnksize))


def generate_time_intervals(days):
    """Build *days* consecutive one-day windows, newest first.

    The newest window ends five minutes before the current time. Each window is
    a ``(start, end)`` pair of integer Unix timestamps with ``end - start == 86400``.
    Each older window ends one second before the newer one starts.

    :param days: number of windows; zero or negative yields an empty list.
    :return: list of ``(start, end)`` tuples.
    """
    day_seconds = 86400
    # Stop five minutes short of now (presumably so the newest points have been ingested).
    newest_end = int(datetime.datetime.now().timestamp()) - 60 * 5
    # Window i ends i * (one day + 1 s) before the newest end.
    return [
        (newest_end - offset - day_seconds, newest_end - offset)
        for offset in (idx * (day_seconds + 1) for idx in range(days))
    ]


if __name__ == '__main__':
    # Log to stdout and to a local file.
    logging.basicConfig(
        handlers=[
            logging.StreamHandler(sys.stdout),
            logging.FileHandler("solomon_yasm_data_processing.log")
        ],
        level=logging.INFO,
        format='%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    parser = argparse.ArgumentParser(description="Extract YASM signals from Solomon and write them to a YT table.")
    parser.add_argument("--filename", type=str, default=None,
                        help="pickle file with host metadata (used when --source_table is not given)")
    parser.add_argument("--source_table", type=str, default=None, help="YT table with host metadata")
    parser.add_argument("--tablename", type=str, required=True, help="destination YT table")
    parser.add_argument("--days", type=int, default=7, help="number of daily windows to query")
    parser.add_argument("--chunknum", type=int, default=35, help="number of hosts per YT write")
    parser.add_argument("--process_balancers", type=utils.str2bool, default=False)
    parser.add_argument("--signals_group", type=str, default="hdd")
    # NOTE(review): --nirvana is parsed but not used by this entry point.
    parser.add_argument("--nirvana", type=utils.str2bool, default="false")
    parser.add_argument("--deploy", type=utils.str2bool, default="false")
    parser.add_argument("--diff_table", type=str, default=None,
                        help="YT table whose services are skipped as already processed")
    parser.add_argument("--test_run", type=utils.str2bool, default="false",
                        help="process only a random sample of hosts")
    args = parser.parse_args()

    # Fail fast instead of crashing later on open(None).
    if not args.filename and not args.source_table:
        parser.error("either --filename or --source_table must be given")

    st = time.time()
    utils.read_solomon_token()
    utils.read_yql_token()
    signals = parse_signal_group(args.signals_group)
    process_solomon_data(args.chunknum, args.tablename, args.filename, args.days, signals, args.deploy,
                         args.source_table, args.diff_table, args.test_run, args.process_balancers)
    # The message needs a %s placeholder: an extra argument without one makes logging
    # report a formatting error instead of emitting the line.
    logging.info('this much time it took total: %s', time.time() - st)
