from collections import defaultdict
import tempfile
import re
import subprocess
import shutil
import functools
import codecs
import datetime
import colour
import os
import tqdm
from s2sphere import Cell, CellId, LatLng
from fastkml import kml, styles
from shapely.geometry import Polygon
import numpy as np
from de2.common import Exporter, filter_keys


colors = list(colour.Color("red").range_to(colour.Color("green"), 10))
colors_bad = colors[::-1]
colors_non_opinionated = list(
    colour.Color("white").range_to(colour.Color("blue"), 10)
)
re_num = re.compile("[0-9]+")

metrics_names = {
    "NetworkType": "Тип мобильной сети",
    "bufferizations_duration_per_min": "Длительность буферизаций на мин. просмотра (с)",
    "bufferizations_duration_per_min_position": "Длительность буферизаций на мин. просмотра (с) — позиция",
    "bufferizations_duration_per_view": "Длительность буферизаций на просмотр (с)",
    "bufferizations_per_min": "Количество буферизаций на мин. просмотра (шт)",
    "bufferizations_per_view": "Количество буферизаций на просмотр (шт)",
    "first_buffer_duration_avg": "Средняя длительность первой буферизации (с)",
    "first_buffer_duration_p50": "Медианная длительность первой буферизации (с)",
    "first_buffer_throughput_avg": "Средняя скорость скачивания чанка во время первой буферизации (кбит/сек)",
    "first_buffer_throughput_p50": "Медианная скорость скачивания чанка во время первой буферизации (кбит/сек)",
    "first_buffer_more_than_3sec_share": "Доля первых буферизаций длительностью больше 3 секунд",
    "operator": "Мобильный оператор",
    "refuse_share": "Доля отказов от количества просмотров",
    "region": "Регион",
    "lat_lon": "Широта и долгота точек",
    "radio_devices": "Количество радио-девайсов",
    "radio_measurements": "Количество радио-измерений",
    "res_avg": "Среднее разрешение чанка",
    "SignalStrength_avg": "Средний уровень сигнала",
    "SignalStrength_avg_position": "Средний уровень сигнала — позиция",
    "SignalStrength_p10": "10-я перцентиль уровня сигнала",
    "SignalStrength_p50": "Медиана уровня сигнала",
    "SignalStrength_p50_position": "Медиана уровня сигнала — позиция",
    "SignalStrength_p90": "90-я перцентиль уровня сигнала",
    "s2_level": "Уровень клетки s2",
    "throughput_p50": "Медианная скорость скачивания чанка (кбит/сек)",
    "throughput_p50_position": "Медианная скорость скачивания чанка (кбит/сек) — позиция",
    "throughtput_avg": "Средняя скорость скачивания чанка (кбит/сек)",
    "traffic_gb": "Входящий видеотрафик (Гб)",
    "tvt": "Общее время просмотра (сек)",
    "video_chunks": "Количество видео-чанков (шт)",
    "video_devices": "Количество видео-девайсов",
    "video_sessions": "Количество видео-просмотров",
}
for value in [144, 240, 360, 480, 720, 1080]:
    metrics_names[
        "res{}_share".format(value)
    ] = "Доля чанков разрешения {}p".format(value)

ns = "{http://www.opengis.net/kml/2.2}"


class IdGen:
    def __init__(self):
        self.id_ = 0

    def __call__(self):
        self.id_ += 1
        return "id{}".format(self.id_)


def wrap_hex(clr):
    clr = clr.replace("#", "")
    if len(clr) == 3:
        clr = clr[0] * 2 + clr[1] * 2 + clr[2] * 2
    return clr[4:6] + clr[2:4] + clr[0:2]


def get_bucket(value, percentiles):
    bucket = 0
    while bucket < len(percentiles) and value > percentiles[bucket]:
        bucket += 1
    return bucket


def sort_key(s):
    num = re_num.search(s)
    if not num:
        return (0, s)
    return (int(num.group(0)), s)


def roundwrap(value):
    if not isinstance(value, float):
        return value
    return round(value, 4)


def get_description(element, func=None):
    return "\n".join(
        sorted(
            [
                "{}: {}".format(metrics_names[key], roundwrap(element[key]))
                for key in sorted(element)
                if key in metrics_names
            ],
            key=sort_key,
        )
    )


coords_cache = {}


metrics_names_to_color_scales = {
    "bufferizations_duration_per_min": colors_bad,
    "bufferizations_duration_per_min_position": colors_bad,
    "bufferizations_per_min": colors_bad,
    "first_buffer_duration_avg": colors_bad,
    "first_buffer_duration_p50": colors_bad,
    "first_buffer_throughput_avg": colors,
    "first_buffer_throughput_p50": colors,
    "refuse_share": colors,
    "throughput_p50": colors,
    "traffic_gb": colors,
    "video_devices": colors,
    "video_sessions": colors,
}


def get_color_scale(key):
    if key in metrics_names_to_color_scales:
        return metrics_names_to_color_scales[key]
    return colors


@functools.lru_cache(maxsize=16384)
def get_coordinates_raw(cellid):
    return [
        LatLng.from_point(Cell(CellId(cellid)).get_vertex(v)) for v in range(4)
    ]


def get_coordinates(cellid):
    if cellid in coords_cache:
        return coords_cache[cellid]
    ll = get_coordinates_raw(cellid)
    coords = [(l.lng().degrees, l.lat().degrees) for l in ll]
    coords_cache[cellid] = coords
    return coords


class BaseKmlExporter(Exporter):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.polygons = defaultdict(list)
        self.values_for_percentiles = defaultdict(list)
        self.scales = {}

    def get_color(self, key, value):
        if key not in self.scales:
            values = self.values_for_percentiles[key]
            self.scales[key] = {
                "min": min(values),
                "max": max(values),
                "percentiles": list(
                    np.percentile(values, list(range(10, 100, 10)))
                ),
            }
        bucket = get_bucket(value, self.scales[key]["percentiles"])
        metric_name = key[1]
        colors = get_color_scale(metric_name)
        return "aa" + wrap_hex(colors[bucket].get_hex())

    def process(self, row):
        keys = list(row.keys())
        keys = filter_keys(keys, self.config)
        description = get_description(row, keys)  # TODO: adapt for v2
        mapped_keys = [k for k in keys if k in self.config["keys_for_mapping"]]
        for k in mapped_keys:
            path = (
                row.get("region") or None,
                k,
                eval(self.config["subfolder_predicate"]),
            )
            value = row[k]
            if not value:
                continue
            self.polygons[path].append(
                {
                    "s2_value": row["s2_value"],
                    "description": description,
                    "metric_value": value,
                }
            )
            self.values_for_percentiles[path].append(value)


class KmzExporter(BaseKmlExporter):
    @staticmethod
    def new_doc():
        idgen = IdGen()
        kmlfile = kml.KML()
        doc = kml.Document(ns, idgen())
        kmlfile.append(doc)
        return (idgen, kmlfile, doc)

    def create_folders_if_necessary(self, folder_dict, key, folderkey):
        if folderkey not in folder_dict:
            folder = kml.Folder(
                ns, self.idgen(), metrics_names[folderkey[1]], ""
            )
            folder.visibility = 0
            self.doc.append(folder)
            folder_dict[folderkey] = folder
        if key not in folder_dict:
            folder = kml.Folder(ns, self.idgen(), key[2], "")
            folder.visibility = 0
            folder_dict[folderkey].append(folder)
            folder_dict[key] = folder

    def create_style(self, color):
        style_id = self.idgen()
        style_ = styles.Style(id=style_id)
        linestyle = styles.LineStyle(
            ns, self.idgen(), color="bbeeeeee", width=2
        )
        polystyle = styles.PolyStyle(ns, self.idgen(), color=color)
        style_.append_style(polystyle)
        style_.append_style(linestyle)
        self.doc.append_style(style_)
        self.styles_dict[color] = style_id

    def create_polygon(self, polygon, folder, key):
        color = self.get_color(key, polygon["metric_value"])
        if color not in self.styles_dict:
            self.create_style(color)
        style = self.styles_dict[color]
        coords = get_coordinates(polygon["s2_value"])
        pol = kml.Placemark(
            ns,
            self.idgen(),
            str(polygon["s2_value"]),
            polygon["description"],
            styleUrl="#{}".format(style),
        )
        pol.visibility = 0
        pol.geometry = Polygon(coords)
        folder.append(pol)

    def _save_kmlfile(self, filename):
        print("saving {}".format(filename))
        now = datetime.datetime.now()
        cwd = os.getcwd()
        with tempfile.TemporaryDirectory() as tmpdirname:
            os.chdir(tmpdirname)
            to_string = self.kmlfile.to_string()
            with open("doc.kml", "w") as f:
                f.write(to_string)
            kmz_fn = "{}.kmz".format(filename)
            subprocess.check_output(["zip", kmz_fn, "doc.kml"])
            if os.path.exists(os.path.join(self.wd, kmz_fn)):
                os.unlink(os.path.join(self.wd, kmz_fn))
            shutil.move(kmz_fn, self.wd)
            os.chdir(cwd)
        elapsed = (datetime.datetime.now() - now).total_seconds() / 60.0
        print("data saved in {} minutes".format(elapsed))

    def save_data(self):
        print("started saving data")
        now = datetime.datetime.now()
        self.kmlfile = None
        self.styles_dict = {}
        prev_key = None
        folder_dict = {}
        keys = sorted(self.polygons.keys())
        for key in tqdm.tqdm(keys):
            if not self.kmlfile or key[0] != prev_key[0]:
                if self.kmlfile:
                    self._save_kmlfile(prev_key[0])
                self.idgen, self.kmlfile, self.doc = self.new_doc()
                self.styles_dict = {}
            folderkey = (key[0], key[1])
            self.create_folders_if_necessary(folder_dict, key, folderkey)
            folder = folder_dict[key]
            for polygon in self.polygons[key]:
                self.create_polygon(polygon, folder, key)
            prev_key = key
        self._save_kmlfile(prev_key[0])
        elapsed = (datetime.datetime.now() - now).total_seconds() / 60.0
        print("total data saving time: {}".format(elapsed))


class KmlExporter(KmzExporter):
    def _save_kmlfile(self, filename):
        print("saving {}".format(filename))
        now = datetime.datetime.now()
        to_string = self.kmlfile.to_string()
        with open("{}/{}.kml".format(self.wd, filename), "w") as f:
            f.write(to_string)
        elapsed = (datetime.datetime.now() - now).total_seconds() / 60.0
        print("data saved in {} minutes".format(elapsed))

    def save_data(self):
        print("started saving data")
        now = datetime.datetime.now()
        self.kmlfile = None
        self.styles_dict = {}
        prev_key = None
        keys = sorted(self.polygons.keys())
        for key in tqdm.tqdm(keys):
            if not self.kmlfile or key != prev_key:
                if self.kmlfile:
                    self._save_kmlfile(
                        "{}_{}".format(prev_key[1], prev_key[2])
                    )
                self.idgen, self.kmlfile, self.doc = self.new_doc()
                self.styles_dict = {}
            for polygon in self.polygons[key]:
                self.create_polygon(polygon, self.kmlfile, key)
            prev_key = key
        self._save_kmlfile("{}_{}".format(prev_key[1], prev_key[2]))
        elapsed = (datetime.datetime.now() - now).total_seconds() / 60.0
        print("total data saving time: {}".format(elapsed))
