#!/usr/bin/env python3

import argparse
import itertools
import json
import os
import re
import requests
from collections import defaultdict
from dataclasses import dataclass, asdict
from typing import Final, Tuple, Dict, List, ClassVar
from time import time
from sys import stderr
import urllib.parse

try:
    from numpy import percentile
    from tqdm import tqdm
    from yasmapi3 import GolovanRequest
except ImportError as e:
    print(f"ImportError: {e}")
    print("Try: pip3 install -i https://pypi.yandex-team.ru/simple numpy tqdm yasmapi3")
    exit(1)

# Report window: the last 30 days, in seconds.
RANGE: Final = 30 * 24 * 3600
# Step between data points requested from YASM: one hour, in seconds.
PERIOD: Final = 60 * 60
# Aggregation periods (seconds) considered valid for YASM; PERIOD is checked below.
VALID_YASM_PERIODS: Final = [5, 300, 3600, 10800, 21600, 43200, 86400]
assert PERIOD in VALID_YASM_PERIODS


def format_percent(x: float) -> str:
    """Format a percentage value for the report.

    Values whose magnitude is below 99.5 are rendered with two significant
    digits (``"%.2g"``); larger values are truncated to an integer.

    The cut-off is 99.5 rather than 100 because ``"%.2g"`` rounds values in
    [99.5, 100) up to three digits and then falls back to exponent notation
    (``"1e+02"``). Comparing the magnitude keeps large negative values out of
    exponent notation as well.
    """
    if abs(x) < 99.5:
        return "%.2g" % x
    return "%d" % int(x)


def optimal_to_color(k: float) -> Tuple[str, str]:
    """Map a "should add" value to a ``(color name, "R, G, B")`` pair.

    Positive values (capacity should be added) are red, negative values
    (capacity can be removed) are green; anything else — zero, or a value
    such as NaN that compares false both ways — is yellow.
    """
    if k > 0:
        color = ("red", "255, 0, 0")
    elif k < 0:
        color = ("green", "0, 255, 0")
    else:
        color = ("yellow", "255, 255, 0")
    return color


@dataclass
class ResourceInfo:
    """Per-resource summary rendered as one cell of the capacity table."""

    # Link to the YASM chart for this resource.
    yasm_url: str
    # Usage value at the requested percentile (shown with a "%" sign in the report).
    usage: float
    # Advice value: >= 0 is reported as "should add", < 0 as "can remove".
    should_add: float


@dataclass
class QloudHostsGroup:
    """A Qloud project/environment, optionally narrowed to one component.

    The string form ``project.env[.component]`` is used as the report section
    header and as the hash key.
    """

    project: str
    env: str
    # None (or empty) means the whole environment, without a component filter.
    component: str = None

    def __str__(self):
        suffix = "" if not self.component else f".{self.component}"
        return f"{self.project}.{self.env}{suffix}"

    def __hash__(self):
        return hash(str(self))

    @staticmethod
    def list(args):
        """Build one group per (project, env, component) combination of the CLI args."""
        combos = itertools.product(args.projects, args.envs, args.components)
        return [*itertools.starmap(QloudHostsGroup, combos)]

    @staticmethod
    def hosts():
        """Return the YASM ``hosts`` value used for Qloud signals."""
        return "QLOUD"


class QloudResources:
    """YASM signal definitions, tags and chart links for Qloud host groups.

    Usage is judged against ``limit_on_dc_fail = 100 * (dc_count - 1) / dc_count``,
    which is also drawn on every chart; presumably this is the usage level that
    still fits after losing one of ``dc_count`` data centers.
    """

    RESOURCE_TO_SIGNAL_NAME: Final = {
        "CPU": "quant(portoinst-cpu_guarantee_usage_perc_hgram,0.95)",
        "RAM": "mul(quant(portoinst-anon_limit_usage_perc_hgram,0.95),div(portoinst-memory_limit_gb_tmmv,portoinst-memory_guarantee_gb_tmmv))",
        "RX": "mul(quant(portoinst-net_rx_utilization_hgram,0.95),div(portoinst-net_limit_mb_summ,portoinst-net_guarantee_mb_summ))",
        "TX": "mul(quant(portoinst-net_tx_utilization_hgram,0.95),div(portoinst-net_limit_mb_summ,portoinst-net_guarantee_mb_summ))",
        "DISK_EPH": "portoinst-volume_/ephemeral_usage_perc_txxx",
        "DISK_ROOT": "portoinst-volume_root_usage_perc_txxx",
    }
    TAG_PATTERN: Final = "itype=qloud;prj=mail.{project}.{env};{optional_component}ctype=unknown"
    GOLOVAN_URL: Final = "https://yasm.yandex-team.ru/chart/hosts=QLOUD;graphs={{{signal},const(0),const({usage}),const({limit_on_dc_fail}),const(100)}};{tag}/?from={st}&to={et}"

    @staticmethod
    def list():
        """Return the names of all supported resources."""
        return QloudResources.RESOURCE_TO_SIGNAL_NAME.keys()

    @staticmethod
    def resource_tag(resource: str, hosts: QloudHostsGroup):
        """Return the YASM tag for ``hosts`` (``resource`` is not used: all signals share it)."""
        optional_component = "component=%s;" % hosts.component if hosts.component else ""
        return QloudResources.TAG_PATTERN.format(
            project=hosts.project, env=hosts.env, optional_component=optional_component
        )

    @staticmethod
    def resource_signal(resource: str, hosts: QloudHostsGroup):
        """Return the full ``tag:signal`` key requested from YASM."""
        signal_name = QloudResources.RESOURCE_TO_SIGNAL_NAME.get(resource)
        return "{}:{}".format(QloudResources.resource_tag(resource, hosts), signal_name)

    @staticmethod
    def resource_url(
        resource: str,
        hosts: QloudHostsGroup,
        usage: float,
        start_time: float,
        end_time: float,
        dc_count: float,
    ):
        """Build the YASM chart URL for one resource.

        Args:
            start_time, end_time: epoch seconds; converted to milliseconds in the URL.
            dc_count: number of data centers, used to draw the ``limit_on_dc_fail`` line.
        """
        signal_name = QloudResources.RESOURCE_TO_SIGNAL_NAME.get(resource)
        limit_on_dc_fail = 100 * (dc_count - 1) / dc_count
        url_templ = QloudResources.GOLOVAN_URL.format(
            signal=urllib.parse.quote(signal_name, safe=""),
            tag=QloudResources.resource_tag(resource, hosts),
            usage=usage,
            limit_on_dc_fail=limit_on_dc_fail,
            st=int(start_time * 1000),
            et=int(end_time * 1000),
        )
        return url_templ

    @staticmethod
    def resource_info(
        resource: str,
        group: QloudHostsGroup,
        values: List[float],
        start_time: float,
        end_time: float,
        target: float,
        dc_count: float,
        **kwargs,
    ):
        """Summarize one resource of ``group`` as a :class:`ResourceInfo`.

        Args:
            values: signal samples, one per time point.
            target: percentile (0..100) passed to ``numpy.percentile``.
            dc_count: number of data centers; must be greater than 1.
            **kwargs: remaining CLI arguments, ignored.

        Returns:
            ResourceInfo whose ``should_add`` is the usage rescaled to
            ``dc_count - 1`` DCs, minus 100.

        Raises:
            ValueError: if ``dc_count <= 1`` (the rescaling below would divide
                by zero or be meaningless).
        """
        if dc_count <= 1:
            raise ValueError(f"dc_count must be greater than 1, got {dc_count}")
        usage = percentile(values, target)
        assert usage <= max(values)
        url = QloudResources.resource_url(resource, group, usage, start_time, end_time, dc_count)
        should_add = usage * dc_count / (dc_count - 1) - 100
        return ResourceInfo(url, usage, should_add)


class ClusterInfoReader:
    """Minimal client listing Managed PostgreSQL clusters of a folder.

    The OAuth token is exchanged once for an IAM token, which is cached on the
    instance and sent as a Bearer token to the MDB API.
    """

    def __init__(self, oauthToken, verify=False):
        """
        Args:
            oauthToken: OAuth token exchanged for an IAM token.
            verify: forwarded to ``requests`` as ``verify`` (bool or CA bundle
                path). Defaults to False to keep the historical behaviour.
                NOTE(review): with False the OAuth/IAM tokens are sent without
                TLS certificate verification; prefer True or a CA bundle.
        """
        self.IamTokenUrl = "https://iam.cloud.yandex-team.ru/v1/tokens"
        self.ClustersListUrl = "https://gw.db.yandex-team.ru/managed-postgresql/v1/clusters"
        self.IamToken = None
        self.OauthToken = oauthToken
        self.Verify = verify

    def get_clusters(self, folderId):
        """Return all cluster descriptions (JSON dicts) in folder ``folderId``.

        Follows ``nextPageToken`` so that folders with more clusters than fit
        on one page are listed completely. A response without a ``clusters``
        key (empty folder) yields an empty list instead of a KeyError.

        Raises:
            requests.HTTPError: on a non-2xx response.
        """
        headers = {"Authorization": "Bearer {token}".format(token=self._iam_token())}
        params = {"folder_id": folderId}
        clusters = []
        while True:
            resp = requests.get(
                self.ClustersListUrl,
                params=params,
                headers=headers,
                timeout=60,
                verify=self.Verify,
            )
            resp.raise_for_status()
            body = resp.json()
            clusters.extend(body.get("clusters", []))
            next_page = body.get("nextPageToken")
            if not next_page:
                return clusters
            params = {"folder_id": folderId, "page_token": next_page}

    def _iam_token(self):
        """Return the cached IAM token, fetching it on first use."""
        if not self.IamToken:
            self.IamToken = self._get_iam_token(self._oauth_token())

        return self.IamToken

    def _get_iam_token(self, oauthToken):
        """Exchange ``oauthToken`` for an IAM token via the IAM API."""
        url = self.IamTokenUrl
        headers = {"Content-Type": "application/json"}
        data = json.dumps({"yandexPassportOauthToken": oauthToken})
        resp = requests.post(url, headers=headers, data=data, timeout=30, verify=self.Verify)
        resp.raise_for_status()

        return resp.json()["iamToken"]

    def _oauth_token(self):
        return self.OauthToken


@dataclass
class MdbCluster:
    """A Managed PostgreSQL cluster selected for the capacity report.

    Fields:
        folder_name: key of ``MDB_FOLDERS`` the cluster was found in.
        id: MDB cluster id.
        name: MDB cluster name.
        max_connections: ``maxConnections`` from the effective config minus 15
            (see the link in ``_max_connections``).
        database: key of the folder's ``cluster_regex`` that matched ``name``.
    """

    folder_name: str
    id: str
    name: str
    max_connections: int
    database: str
    # folder name -> {"id": YC folder id, "cluster_regex": {db name: regex over cluster names}}
    MDB_FOLDERS: ClassVar[Dict] = {
        "xiva": {
            "id": "foocn84l3bfvv91tsdkb",
            "cluster_regex": {
                "xstore": r"^xiva_xstore_(production_[0-9]{2}|corp)$",
                "xtable": r"^xiva_xtable_(production_[0-9]{2}|corp)$",
                "xconf": r"^xiva_conf$",
            },
        },
        "collectors": {
            "id": "foosaanhpqhofon45s8d",
            "cluster_regex": {
                "rpopdb": r"^rpopdb_production_[0-9]{2}$",
                "rpopdb_transfer": r"^rpopdb_transfer[0-9]{2}$",
            },
        },
        "tractor": {
            "id": "foom5upuus069lavolqb",
            "cluster_regex": {
                "production": r"tractor_disk_production",
            },
        },
    }
    TOKEN_FILE_PATH: ClassVar[str] = "~/.yc.token"

    def __str__(self):
        return f"{self.folder_name}.{self.name}"

    def __hash__(self):
        return hash(str(self))

    @staticmethod
    def _read_token():
        """Read the yc OAuth token from TOKEN_FILE_PATH, or exit(1) if the file is missing."""
        try:
            with open(os.path.expanduser(MdbCluster.TOKEN_FILE_PATH), "r") as token_file:
                return token_file.read().strip()
        except FileNotFoundError:
            print("ERROR: no yc token, read help", file=stderr)
            raise SystemExit(1)

    @staticmethod
    def _match_db_name(cluster_name, patterns):
        """Return the db name of the first pattern that ``re.match``es ``cluster_name``, else None."""
        for pattern, db_name in patterns.items():
            if re.match(pattern, cluster_name):
                return db_name
        return None

    @staticmethod
    def _max_connections(cluster):
        """Return the usable connection limit from a cluster's effective PostgreSQL config."""
        config = cluster["config"]
        pg_cfg = config["postgresqlConfig_{}".format(config["version"])]["effectiveConfig"]
        # https://docs.yandex-team.ru/cloud/managed-postgresql/qa/all#user-conn-number
        return int(pg_cfg["maxConnections"]) - 15

    @staticmethod
    def list(args):
        """Return the clusters of ``args.folders`` whose names match a configured pattern.

        Only patterns of ``args.databases`` are used when that list is non-empty.
        Unknown folder names are rejected (exit code 1) before any token or
        network access.
        """
        unknown = [folder for folder in args.folders if folder not in MdbCluster.MDB_FOLDERS]
        if unknown:
            print(
                "ERROR: unknown folders: {}; known folders: {}".format(
                    ", ".join(unknown), ", ".join(MdbCluster.MDB_FOLDERS)
                ),
                file=stderr,
            )
            raise SystemExit(1)

        cluster_reader = ClusterInfoReader(oauthToken=MdbCluster._read_token())

        clusters = []
        for folder_name in args.folders:
            folder_cfg = MdbCluster.MDB_FOLDERS[folder_name]
            patterns = {
                pattern: db_name
                for db_name, pattern in folder_cfg["cluster_regex"].items()
                if not args.databases or db_name in args.databases
            }
            if not patterns:
                continue

            for cluster in cluster_reader.get_clusters(folder_cfg["id"]):
                db_name = MdbCluster._match_db_name(cluster["name"], patterns)
                if not db_name:
                    continue
                clusters.append(
                    MdbCluster(
                        folder_name,
                        cluster["id"],
                        cluster["name"],
                        MdbCluster._max_connections(cluster),
                        db_name,
                    )
                )

        return clusters

    @staticmethod
    def hosts():
        """Return the YASM ``hosts`` value used for MDB signals."""
        return "CON"


class MdbResources:
    """YASM signal definitions, tags and chart links for Managed PostgreSQL clusters.

    Signal templates may contain ``{field}`` placeholders filled from the
    cluster's dataclass fields (``{max_connections}`` for "Connections").
    Usage is compared against a fixed critical threshold.
    """

    @dataclass
    class SignalParams:
        # YASM signal expression, formatted with the cluster's dataclass fields.
        name_template: str
        # Value of the ``itype`` tag the signal is requested under.
        itype: str

    RESOURCE_TO_SIGNAL_NAME_ITYPE: Final = {
        "CPU": SignalParams(
            "perc(sum(portoinst-cpu_usage_cores_tmmv, portoinst-cpu_usage_system_cores_tmmv), portoinst-cpu_guarantee_cores_tmmv)",
            "mdbdom0",
        ),
        "RAM": SignalParams(
            "mul(quant(portoinst-anon_limit_usage_perc_hgram,0.95),div(portoinst-memory_limit_gb_tmmv,portoinst-memory_guarantee_gb_tmmv))",
            "mdbdom0",
        ),
        "NET": SignalParams(
            "perc(sum(portoinst-net_rx_mb_summ, portoinst-net_tx_mb_summ), portoinst-net_guarantee_mb_summ)",
            "mdbdom0",
        ),
        "DISK": SignalParams(
            "perc(push-disk-used_bytes_pgdata_tmmx, push-disk-total_bytes_pgdata_tmmx)",
            "mailpostgresql",
        ),
        "IO": SignalParams(
            "perc(sum(portoinst-io_read_fs_bytes_tmmv, portoinst-io_write_fs_bytes_tmmv), portoinst-io_limit_bytes_tmmv)",
            "mdbdom0",
        ),
        "Connections": SignalParams(
            "perc(sum(push-postgres_conn_idle_tmmx, push-postgres_conn_idle_in_transaction_tmmx, push-postgres_conn_aborted_tmmx, push-postgres_conn_active_tmmx, push-postgres_conn_waiting_tmmx), {max_connections})",
            "mailpostgresql",
        ),
    }
    TAG_PATTERN: Final = "itype={itype};ctype={ctype};tier=primary"
    GOLOVAN_URL: Final = "https://yasm.yandex-team.ru/chart/hosts=CON;graphs={{{signal},const(0),const({usage}),const({crit_threshold}),const(100)}};{tag}/?from={st}&to={et}"

    @staticmethod
    def list():
        """Return the names of all supported resources."""
        return MdbResources.RESOURCE_TO_SIGNAL_NAME_ITYPE.keys()

    @staticmethod
    def _signal_name(resource: str, cluster: MdbCluster) -> str:
        """Render the signal expression of ``resource`` for ``cluster``."""
        signal_params = MdbResources.RESOURCE_TO_SIGNAL_NAME_ITYPE.get(resource)
        return signal_params.name_template.format(**asdict(cluster))

    @staticmethod
    def resource_tag(resource: str, cluster: MdbCluster):
        """Return the YASM tag (``tier=primary``, ``ctype`` = cluster id) for ``resource``'s itype."""
        signal_params = MdbResources.RESOURCE_TO_SIGNAL_NAME_ITYPE.get(resource)
        return MdbResources.TAG_PATTERN.format(itype=signal_params.itype, ctype=cluster.id)

    @staticmethod
    def resource_signal(resource: str, cluster: MdbCluster):
        """Return the full ``tag:signal`` key requested from YASM."""
        return "{}:{}".format(
            MdbResources.resource_tag(resource, cluster),
            MdbResources._signal_name(resource, cluster),
        )

    @staticmethod
    def resource_url(
        resource: str,
        cluster: MdbCluster,
        usage: float,
        start_time: float,
        end_time: float,
        crit_threshold: float,
    ):
        """Build the YASM chart URL for one resource.

        Args:
            start_time, end_time: epoch seconds; converted to milliseconds in the URL.
            crit_threshold: level drawn as a constant line on the chart.
        """
        url_templ = MdbResources.GOLOVAN_URL.format(
            signal=urllib.parse.quote(MdbResources._signal_name(resource, cluster), safe=""),
            tag=MdbResources.resource_tag(resource, cluster),
            usage=usage,
            crit_threshold=crit_threshold,
            st=int(start_time * 1000),
            et=int(end_time * 1000),
        )
        return url_templ

    @staticmethod
    def resource_info(
        resource: str,
        cluster: MdbCluster,
        values: List[float],
        start_time: float,
        end_time: float,
        target: float,
        crit_threshold: float,
        **kwargs,
    ):
        """Summarize one resource of ``cluster`` as a :class:`ResourceInfo`.

        Args:
            values: signal samples, one per time point.
            target: percentile (0..100) passed to ``numpy.percentile``.
            crit_threshold: usage level considered critical.
            **kwargs: remaining CLI arguments, ignored.

        Returns:
            ResourceInfo with ``should_add = usage - crit_threshold``.
        """
        try:
            usage = percentile(values, target)
            assert usage <= max(values)
        except Exception as e:
            # Best effort: log and continue with the remaining resources.
            # NOTE(review): falling back to usage=0 makes the report advise
            # "can remove <crit_threshold>%" for a resource with no usable data.
            print("ERROR", resource, cluster, e, file=stderr)
            usage = 0
        url = MdbResources.resource_url(
            resource, cluster, usage, start_time, end_time, crit_threshold
        )
        should_add = usage - crit_threshold
        return ResourceInfo(url, usage, should_add)


def resource_table(
    results: Dict[str, ResourceInfo],
    target: float,
    chart_width: int,
):
    """Print a wacko-markup table with one column per resource.

    Row one holds the resource names. Row two holds, per resource, the usage
    at percentile ``target`` with an add/remove advice, followed by a chart
    preview (the URL's host rewritten to ``s.yasm.``) linking to the full chart.
    """

    def emit_row(cells):
        # Every cell after the first is prefixed with the "|" separator. Cells
        # are consumed lazily, so each one is rendered right before printing.
        print("||")
        for position, cell in enumerate(cells):
            print(cell if position == 0 else "|" + cell)
        print("||")

    def chart_cell(info):
        color_name, color_rgb = optimal_to_color(info.should_add)
        verdict = "should add" if info.should_add >= 0 else "can remove"
        amount = format_percent(abs(info.should_add))
        preview_url = info.yasm_url.replace("yasm.", "s.yasm.")
        shown_usage = format_percent(info.usage)
        return f"""<span style="background-color: rgba({color_rgb}, 0.1)">!!(orange)usagevv{target}vv!!={shown_usage}%, !!({color_name}){verdict} **{amount}%**!!</span>
[![]({preview_url}&width={chart_width})]({info.yasm_url})"""

    print("#|")
    emit_row(f"%%(wacko wrapper=text align=center)**{name}**%%" for name in results)
    emit_row(chart_cell(info) for info in results.values())
    print("|#")


class SplitArgs(argparse.Action):
    """argparse action that stores a comma-separated value as a list of stripped strings."""

    def __call__(self, parser, namespace, values, option_string=None):
        items = list(map(str.strip, values.split(",")))
        setattr(namespace, self.dest, items)


def main():
    """Fetch YASM data for every host group and resource, then print one table per group."""
    args = parse_args()

    end_time: Final = time()
    start_time: Final = end_time - RANGE

    host_group_cls, resources_cls = hosts_group_and_resources(args.hosts)
    hosts_groups = host_group_cls.list(args)

    # Resolve every (group, resource) -> signal key once; the per-time-point
    # loop below only does dictionary lookups.
    signal_keys = [
        (hosts, resource, resources_cls.resource_signal(resource, hosts))
        for hosts in hosts_groups
        for resource in resources_cls.list()
    ]
    signals = [signal for _, _, signal in signal_keys]

    # results[group][resource] -> samples, one per time point returned by YASM.
    results = defaultdict(lambda: defaultdict(list))
    for _timestamp, point_values in tqdm(
        GolovanRequest(host_group_cls.hosts(), PERIOD, start_time, end_time, signals),
        total=RANGE // PERIOD,
        file=stderr,
    ):
        for hosts, resource, signal in signal_keys:
            results[hosts][resource].append(point_values[signal])

    for group, group_results in results.items():
        print("=== {}".format(group))
        data = {
            resource: resources_cls.resource_info(
                resource,
                group,
                series,
                start_time,
                end_time,
                **vars(args),
            )
            for resource, series in group_results.items()
        }
        resource_table(
            data,
            args.target,
            chart_width=560,
        )


def parse_args():
    """Parse and validate command-line arguments.

    Returns:
        argparse.Namespace with ``hosts`` ("qloud" or "mdb"), ``target`` and
        the options of the chosen sub-command.

    Exits via ``parser.error`` (status 2) when ``--target`` is outside
    [0, 100] or, for ``qloud``, when ``--dc_count`` is below 2.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description=f"""
Generate capacity report

For running with hosts=mdb save yc oauth token to {MdbCluster.TOKEN_FILE_PATH}
Get token: https://oauth.yandex-team.ru/authorize?response_type=token&client_id=8cdb2f6a0dca48398c6880312ee2f78d

Examples:
  $ ./generate_capacity_report.py qloud --projects 'xiva-server,xivahub,xivamob,xivamesh,xivaconf,xivadba,reaper' > /tmp/capacity-report.md
  $ ./generate_capacity_report.py mdb --folders 'xiva' > /tmp/capacity-report.md
        """,
    )
    parser.add_argument(
        "--target",
        dest="target",
        default=99,
        type=float,
        help="resource usage percentile (default: %(default)s)",
    )

    subparsers = parser.add_subparsers(dest="hosts", required=True)

    qloud_parser = subparsers.add_parser("qloud")
    qloud_parser.add_argument(
        "--projects",
        dest="projects",
        action=SplitArgs,
        required=True,
        help="comma separated projects",
    )
    qloud_parser.add_argument(
        "--envs",
        dest="envs",
        default=["production", "corp"],
        action=SplitArgs,
        help="comma separated environments (default: %(default)s)",
    )
    qloud_parser.add_argument(
        "--components",
        dest="components",
        default=[None],
        action=SplitArgs,
        help="comma separated components (default: %(default)s (aggregate all components in environment))",
    )
    qloud_parser.add_argument(
        "--dc_count",
        dest="dc_count",
        default=3,
        type=int,
        help="DC count (default: %(default)s)",
    )

    mdb_parser = subparsers.add_parser("mdb")
    mdb_parser.add_argument(
        "--folders",
        dest="folders",
        action=SplitArgs,
        required=True,
        help="comma separated folders",
    )
    mdb_parser.add_argument(
        "--databases",
        dest="databases",
        default=[],
        action=SplitArgs,
        help="comma separated databases (default: %(default)s (all databases configured for the folder))",
    )
    mdb_parser.add_argument(
        "--crit",
        dest="crit_threshold",
        default=90,
        # A single number: it is subtracted from usage percentages.
        type=float,
        help="crit level of resource usage (default: %(default)s)",
    )

    args = parser.parse_args()
    if not 0 <= args.target <= 100:
        parser.error(f"--target must be within [0, 100], got {args.target}")
    if args.hosts == "qloud" and args.dc_count < 2:
        parser.error(f"--dc_count must be at least 2, got {args.dc_count}")
    return args


def hosts_group_and_resources(hosts: str):
    """Return the ``(host group class, resources class)`` pair for a sub-command.

    Raises:
        NotImplementedError: if ``hosts`` is neither "qloud" nor "mdb".
    """
    known = (
        ("qloud", (QloudHostsGroup, QloudResources)),
        ("mdb", (MdbCluster, MdbResources)),
    )
    for name, classes in known:
        if hosts == name:
            return classes
    raise NotImplementedError(f"hosts={hosts}")


if __name__ == "__main__":
    # Generate the report only when executed as a script, not on import.
    main()
