#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import division
from __future__ import unicode_literals

import argparse
import codecs
import csv
import datetime
import json
import os
import re
import shutil
import subprocess
import tempfile

import colour
import numpy as np
import pyexcel
import tqdm
import yt.wrapper as yt
from fastkml import kml, styles
from s2sphere import Cell, CellId, LatLng
from shapely.geometry import Polygon

# from simplekml import Kml, Color


# Ten-step colour ramps indexed by percentile bucket: red -> green for the
# ordinary metrics, white -> blue for the resolution-share metrics.
colors = list(colour.Color("red").range_to(colour.Color("green"), 10))
colors_non_opinionated = list(
    colour.Color("white").range_to(colour.Color("blue"), 10)
)

# Per grouping level: metric key -> scale info / KML folder. Both mappings
# start empty and are filled in by main().
scales = {level: {} for level in ("s2_p13", "s2_p14", "s2_p15")}
scales_folders = {level: {} for level in ("s2_p13", "s2_p14", "s2_p15")}

# First run of decimal digits in a string; used by sort_key().
re_num = re.compile("[0-9]+")

# Metric key -> human-readable (Russian) label used in folders/descriptions.
metrics_names = {
    "traffic_gbytes": "Трафик (Гб)",
    "views": "Количество видео просмотров (шт)",
    "kbit_sec": "Средняя скорость видео (кбит/сек)",
    "kbit_sec_median": "Медианная скорость видео (кбит/сек)",
    "avg_resolution": "Среднее разрешение чанка",
    "refuse_share": "Видео-отказы (% от кол-ва просмотров)",
    "first_bufferization_kbit_sec": "Скорость первой буферизации (кбит/сек)",
    "avg_first_buffer_duration": "Длительность первой буферизации (с)",
    "bufferizations_duration_per_min": "Длительность буферизаций на мин. просмотра (с)",
    "bufferizations_duration_per_view": "Длительность буферизаций за просмотр (с)",
    "bufferizations_per_min": "Количество буферизаций на мин. просмотра (шт)",
    "bufferizations_per_view": "Количество буферизаций за просмотр (шт)",
}
# One label per tracked resolution share ...
metrics_names.update(
    {
        "res{}_share".format(height): "Доля разрешения {}p".format(height)
        for height in (144, 240, 360, 480, 720, 1080)
    }
)
# ... and one per speed threshold, 100..5000 kbit/sec in steps of 100.
metrics_names.update(
    {
        "kbit_sec_less_than_{}_share".format(
            limit
        ): "Доля скорости ниже {} кбит/сек".format(limit)
        for limit in range(100, 5100, 100)
    }
)

# Fill colour -> id of the KML style already created for that colour.
styles_dict = {}


class IdGen:
    """Callable counter that hands out the ids "id1", "id2", ... in order."""

    def __init__(self):
        # Number of ids handed out so far; the next id is id_ + 1.
        self.id_ = 0

    def __call__(self):
        """Advance the counter and return the new id as a string."""
        current = self.id_ + 1
        self.id_ = current
        return "id" + str(current)


def wrap_hex(clr):
    """Return the hex colour *clr* with its channels reordered to BBGGRR.

    Every "#" is removed first and a three-digit shorthand ("f00") is
    expanded to six digits by doubling each digit.
    """
    digits = clr.replace("#", "")
    if len(digits) == 3:
        digits = "".join(digit * 2 for digit in digits)
    red, green, blue = digits[0:2], digits[2:4], digits[4:6]
    return blue + green + red


def get_bucket(value, percentiles):
    """Return the index of the first threshold that *value* does not exceed.

    Thresholds are scanned in order; if *value* is greater than all of
    them the result is ``len(percentiles)``.
    """
    for index, threshold in enumerate(percentiles):
        if not value > threshold:
            return index
    return len(percentiles)


def get_color(element, key):
    """Return the "aa"-prefixed fill colour for metric *key* of *element*.

    The value is bucketed against the percentiles stored in ``scales`` for
    the element's grouping level, and the bucket picks an entry of one of
    the module-level colour ramps.
    """
    grouping_level = element["grouping_level"]
    scale_info = scales[grouping_level][key]
    metric_value = element[key]
    bucket = get_bucket(metric_value, scale_info["percentiles"])

    is_resolution_share = key.startswith("res") and key.endswith("_share")
    uses_reversed_ramp = "refuse" in key or (
        "buffer" in key and key != "first_bufferization_kbit_sec"
    )
    if is_resolution_share:
        palette = colors_non_opinionated
    elif uses_reversed_ramp:
        # Refusal/bufferization metrics walk the red -> green ramp backwards.
        palette = list(reversed(colors))
    else:
        palette = colors
    return "aa" + wrap_hex(palette[bucket].get_hex())


def sort_key(s):
    """Return ``(n, s)`` where *n* is the first integer found in *s*, or 0.

    Sorting by this key orders strings by their first embedded number and
    breaks ties by the string itself.
    """
    match = re_num.search(s)
    first_number = int(match.group(0)) if match else 0
    return (first_number, s)


def get_description(element, func=None):
    """Render the metrics of *element* as newline-separated "label: value".

    Only keys present in ``metrics_names`` and accepted by *func* are
    included. When *func* is None, keys containing "less_than" and the key
    "city" are rejected. Lines are ordered by ``sort_key`` applied to the
    rendered line.
    """

    def default_filter(key):
        return "less_than" not in key and key != "city"

    accept = default_filter if func is None else func
    rendered = []
    for key in element:
        if key in metrics_names and accept(key):
            rendered.append("{}: {}".format(metrics_names[key], element[key]))
    rendered.sort(key=sort_key)
    return "\n".join(rendered)


def get_coordinates(cellid):
    """Return the four vertices of S2 cell *cellid* as (lng, lat) in degrees."""
    cell = Cell(CellId(cellid))
    corners = []
    for vertex_index in range(4):
        corner = LatLng.from_point(cell.get_vertex(vertex_index))
        corners.append((corner.lng().degrees, corner.lat().degrees))
    return corners


def read_jsonlines(filename):
    """Parse a JSON-lines file and return the decoded records as a list.

    Each non-blank line is decoded with ``json.loads``. Lines that are
    empty or contain only whitespace are skipped, so stray blank or
    space-only lines no longer abort the load.

    Raises:
        ValueError: if a non-blank line is not valid JSON.
    """
    records = []
    with open(filename) as f:
        # Iterate lazily instead of reading the whole file into one string.
        for raw_line in f:
            line = raw_line.strip()
            if line:
                records.append(json.loads(line))
    return records


def write_jsonlines(data, filename):
    """Write every item of *data* to *filename* as one JSON document per line.

    The file always ends with a newline, so an empty *data* produces a
    file holding a single newline.
    """
    with open(filename, "w") as out:
        encoded = [json.dumps(record) for record in data]
        out.write("\n".join(encoded))
        out.write("\n")


def wrap_utf8(s):
    """Decode *s* as UTF-8 when it is ``bytes``; return anything else as is."""
    return s.decode("utf8") if isinstance(s, bytes) else s


def wrap_utf8_dict(dct):
    """Return a copy of *dct* with every key and value passed through wrap_utf8."""
    decoded = {}
    for raw_key, raw_value in dct.items():
        decoded[wrap_utf8(raw_key)] = wrap_utf8(raw_value)
    return decoded


def load_data(tablenames, cwd, redownload=False):
    """Load the rows of every table in *tablenames* into one list.

    Each table is cached in *cwd* as a JSON-lines file named by
    ``get_json_name``. The table is read through ``yt.read_table`` only
    when that file is missing or *redownload* is true; the rows are always
    returned from the cache file. Progress and elapsed minutes are printed.
    """
    started_at = datetime.datetime.now()
    print("started loading data for {}".format(tablenames))
    records = []
    for tablename in tablenames:
        cache_path = os.path.join(cwd, get_json_name(tablename))
        needs_download = redownload or not os.path.exists(cache_path)
        if needs_download:
            rows = [wrap_utf8_dict(rec) for rec in yt.read_table(tablename)]
            write_jsonlines(rows, cache_path)
        records += read_jsonlines(cache_path)
    minutes = (datetime.datetime.now() - started_at).total_seconds() / 60.0
    print("data read in {} minutes".format(minutes))
    return records


def get_s2_level(element, style="new"):
    """Return the S2 value stored in *element*.

    With ``style="old"`` the value is read from the first key starting
    with "s2"; with any other style it is ``element["grouping_value"]``.
    """
    if style != "old":
        return element["grouping_value"]
    s2_keys = [name for name in element if name.startswith("s2")]
    return element[s2_keys[0]]


def generate_filename(tablenames):
    """Join the last "/"-separated component of every table name with "_"."""
    return "_".join(name.rsplit("/", 1)[-1] for name in tablenames)


def get_json_name(tablename):
    """Return the last "/"-separated component of *tablename* plus ".json"."""
    _, _, leaf = tablename.rpartition("/")
    return leaf + ".json"


def csv_save_as(filename, data):
    """Write *data* (an iterable of rows) to *filename* as UTF-8, ";"-separated.

    Every cell is rendered with ``format`` (so ``None`` becomes "None") and
    rows end with "\\n". Cells containing the delimiter, a double quote or
    a line break are quoted by ``csv.writer`` so they cannot split or merge
    columns; for all other cells the output is the plain ";"-joined text.
    """
    with codecs.open(filename, "w", "utf8") as f:
        writer = csv.writer(f, delimiter=";", lineterminator="\n")
        for row in data:
            writer.writerow([format(cell) for cell in row])


def _is_resolution_key(key):
    """Tell whether *key* is "avg_resolution" or a "res...share" metric."""
    return key == "avg_resolution" or (
        key.startswith("res") and key.endswith("share")
    )


def _is_speed_related_key(key):
    """Tell whether *key* is a "kbit_sec" metric or a "less_than" share."""
    return "kbit_sec" in key or "less_than" in key


def _description_filter(key):
    """Return the get_description filter for *key*'s metric family.

    None selects get_description's default filter.
    """
    if _is_resolution_key(key):
        return _is_resolution_key
    if "kbit_sec" in key:
        return _is_speed_related_key
    return None


def _style_for_color(color, doc, ns, idgen):
    """Return the id of the KML style for *color*, creating it on first use."""
    if color not in styles_dict:
        style_id = idgen()
        style_ = styles.Style(id=style_id)
        linestyle = styles.LineStyle(ns, idgen(), color="bbeeeeee", width=2)
        polystyle = styles.PolyStyle(ns, idgen(), color=color)
        style_.append_style(polystyle)
        style_.append_style(linestyle)
        doc.append_style(style_)
        styles_dict[color] = style_id
    return styles_dict[color]


def _move_into(name, dest_dir):
    """Move file *name* into *dest_dir*, replacing a same-named file there."""
    target = os.path.join(dest_dir, name)
    if os.path.exists(target):
        os.unlink(target)
    shutil.move(name, dest_dir)


def main():
    """Export the metrics of the input tables to XLSX/CSV and a KMZ map.

    Command-line options may be overridden by the JSON object in --config.
    Rows are loaded with load_data, then for every numeric metric a KML
    folder per grouping level is created and every element contributes one
    coloured polygon per metric. Output files are built in a temporary
    directory and moved into the directory the script was started from.

    Differences from the previous revision:
      * the working directory is restored before the temporary directory
        is removed;
      * missing input tables and empty data are reported instead of
        failing with TypeError / IndexError;
      * elements whose metric value is missing or not numeric, or whose
        grouping level has no scale, get no polygon for that metric
        instead of crashing the export;
      * the output base name uses only the file name of --config, so a
        config given with a directory path works;
      * doc.kml is written as UTF-8 regardless of the locale.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", "-i", nargs="+")
    parser.add_argument("--config", "-c")
    parser.add_argument("--output_formats", "-of", default="xlsx,csv")
    parser.add_argument("--redownload", action="store_true")
    args = vars(parser.parse_args())
    args["output_formats"] = args["output_formats"].split(",")

    if args.get("config"):
        with open(args["config"]) as f:
            obj = json.load(f)
        args.update(obj)
    if not args.get("input"):
        parser.error('no input tables: pass --input or set "input" in the config')
    cwd = os.path.abspath(os.getcwd())

    data = load_data(args["input"], cwd, redownload=args.get("redownload"))
    if not data:
        raise SystemExit("no rows loaded from {}".format(args["input"]))
    if args.get("config"):
        # Outputs are created by bare name inside the temporary directory,
        # so any directory part of the config path must be dropped.
        filename = os.path.basename(args["config"]).replace(".json", "")
    else:
        filename = generate_filename(args["input"])
    print("filename is {}".format(filename))

    idgen = IdGen()

    ns = "{http://www.opengis.net/kml/2.2}"

    with tempfile.TemporaryDirectory() as tmpdirname:
        os.chdir(tmpdirname)
        try:
            kmlfile = kml.KML()
            doc = kml.Document(ns, idgen())
            kmlfile.append(doc)

            # Metrics that get a colour scale: numeric (or None) in the
            # first row, excluding the cell id and the "less_than" shares.
            scale_keys = [
                k
                for k in data[0]
                if isinstance(data[0][k], (int, float, type(None)))
                and k != "grouping_value"
                and "less_than" not in k
            ]
            for k in sorted(
                scale_keys, key=lambda k: sort_key(metrics_names[k])
            ):
                mother_folder = kml.Folder(ns, idgen(), metrics_names[k], "")
                doc.append(mother_folder)
                for level in ["s2_p13", "s2_p14", "s2_p15"]:
                    values = np.array(
                        [
                            el[k]
                            for el in data
                            if isinstance(el.get(k), (int, float))
                            and el["grouping_level"] == level
                        ]
                    )
                    if len(values) == 0:
                        continue
                    scales[level][k] = {
                        "min": min(values),
                        "max": max(values),
                        "percentiles": list(
                            np.percentile(values, list(range(10, 100, 10)))
                        ),
                    }
                    folder = kml.Folder(ns, idgen(), level, "")
                    mother_folder.append(folder)
                    scales_folders[level][k] = folder

            csv_header = ["s2_id", "grouping_level"]
            for i in range(4):
                csv_header.extend(
                    ["lon{}".format(i + 1), "lat{}".format(i + 1)]
                )
            csv_keys = sorted(data[0].keys())
            csv_header.extend(csv_keys)
            csv_data = [csv_header]

            for element in tqdm.tqdm(data):
                level = element["grouping_level"]
                s2_id = get_s2_level(element)
                coords = get_coordinates(s2_id)
                name = str(s2_id)

                csv_row = [name, level]
                for pair in coords:
                    csv_row.extend(pair)
                csv_row.extend(element.get(key, "") for key in csv_keys)
                csv_data.append(csv_row)

                level_folders = scales_folders.get(level, {})
                # Description text per filter, built once per element.
                descriptions = {}
                # Iterating the list (not a set) keeps ids deterministic.
                for key in scale_keys:
                    folder = level_folders.get(key)
                    if folder is None:
                        continue
                    if not isinstance(element.get(key), (int, float)):
                        # None/missing values cannot be bucketed.
                        continue
                    key_filter = _description_filter(key)
                    if key_filter not in descriptions:
                        descriptions[key_filter] = get_description(
                            element, func=key_filter
                        )
                    style = _style_for_color(
                        get_color(element, key), doc, ns, idgen
                    )
                    pol = kml.Placemark(
                        ns,
                        idgen(),
                        name,
                        descriptions[key_filter],
                        styleUrl="#{}".format(style),
                    )
                    pol.geometry = Polygon(coords)
                    folder.append(pol)

            if "xlsx" in args["output_formats"]:
                xlsx_fn = "{}.xlsx".format(filename)
                pyexcel.save_as(dest_file_name=xlsx_fn, array=csv_data)
                _move_into(xlsx_fn, cwd)
            if "csv" in args["output_formats"]:
                csv_fn = "{}.csv".format(filename)
                csv_save_as(csv_fn, csv_data)
                _move_into(csv_fn, cwd)
            now = datetime.datetime.now()
            # Folder names and descriptions contain Cyrillic text, so the
            # encoding is pinned instead of taken from the locale.
            with codecs.open("doc.kml", "w", "utf8") as f:
                f.write(kmlfile.to_string())
            elapsed = (datetime.datetime.now() - now).total_seconds() / 60.0
            print("data saved in {} minutes".format(elapsed))
            kmz_fn = "{}.kmz".format(filename)
            subprocess.check_output(["zip", kmz_fn, "doc.kml"])
            _move_into(kmz_fn, cwd)
        finally:
            # Leave the temporary directory before it is deleted.
            os.chdir(cwd)


# Run the export only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
