# -*- coding: utf-8 -*-

from __future__ import print_function

from . import abc_client
from .common import (
    build_yp_client,
    create_filter,
    pandas_table_to_json,
    print_pandas_table,
    register_common_args,
)

from library.python.par_apply import par_apply
from library.python import oauth as lpo

from color import colored
import pandas as pd

import argparse
import base64
import copy
import json
import logging
import os
import re


# Prefix of account ids backed by an ABC service; the full id is built as
# "<prefix>:<numeric service id>", e.g. "abc:service:123".
_ABC_SERVICE = "abc:service"


def _warning(message):
    """Log `message` at WARNING level, wrapped with `colored(..., "yellow")`."""
    highlighted = colored(message, "yellow")
    logging.warning(highlighted)


def _error(message):
    """Log `message` at ERROR level, wrapped with `colored(..., "red")`."""
    highlighted = colored(message, "red")
    logging.error(highlighted)


def _fatal_error(message):
    """Log `message` as an error and terminate the process with exit status 1.

    Raises:
        SystemExit: always, with code 1.
    """
    _error(message)
    # The `exit` builtin is injected by the `site` module for interactive use
    # and may be missing (e.g. `python -S`, embedded interpreters), so raise
    # SystemExit directly; the effect for callers is the same.
    raise SystemExit(1)


def _merge_df(dest_df, source_df):
    """Return `dest_df` with the rows of `source_df` appended.

    Args:
        dest_df: accumulated DataFrame, or None when nothing was accumulated yet.
        source_df: DataFrame whose rows are added.

    Returns:
        `source_df` itself when `dest_df` is None; otherwise a new DataFrame
        holding the rows of both frames with a fresh 0..n-1 index.
    """
    if dest_df is None:
        return source_df
    # DataFrame.append was deprecated in pandas 1.4 and removed in 2.0;
    # pd.concat is the supported equivalent.
    return pd.concat([dest_df, source_df], ignore_index=True)


def get_gpu_models(client, account_id, specified_segment_id):
    """Collect GPU model names mentioned in the per-segment limits of an account.

    Args:
        client: object providing `select_objects(object_type, selectors=..., filter=...)`.
        account_id: id of the account to look up.
        specified_segment_id: when not None, only this segment is inspected.

    Returns:
        None when the account is not found; otherwise a sorted list of the keys
        found under `gpu_per_model` of the inspected segments, without the "0" key.
        The list is empty when the account spec has no per-segment limits.
    """
    response = client.select_objects(
        "account",
        selectors=["/spec", "/status"],
        filter='[/meta/id] = "{}"'.format(account_id),
    )
    if len(response) == 0:
        return None

    assert len(response) == 1, "Expected exactly one object in response"
    spec, _status = response[0]

    # Tolerate a spec without "resource_limits" or without "per_segment",
    # the same way get_account_dataframe() does.
    per_segment = spec.get("resource_limits", {}).get("per_segment", {})

    gpu_models = set()
    for segment in per_segment:
        if specified_segment_id is not None and segment != specified_segment_id:
            continue
        gpu_models.update(per_segment[segment].get("gpu_per_model", {}))

    # NOTE(review): "0" looks like a placeholder model name - confirm with the data owner.
    gpu_models.discard("0")

    # Sorted so that the GPU columns of the resulting tables have a stable order.
    return sorted(gpu_models)


# Names of the non-GPU resource columns of the account and pod set tables.
# The order matters: get_account_dataframe() pairs these names positionally
# with the values returned by Columns.get_values() (cpu goes to "vcpu").
_COLUMNS = [
    "vcpu",
    "memory",
    "hdd",
    "hdd_bw",
    "ssd",
    "ssd_bw",
    "ipv4",
    "net_bw",
]


class Columns(object):
    """Resource amounts extracted from one per-segment resource dictionary.

    Missing entries are treated as 0. GPU amounts are kept per model in
    `self.gpu`, for the models listed in `gpu_models`.
    """

    def __init__(self, segment_spec, gpu_models):
        disks = segment_spec.get("disk_per_storage_class", {})
        hdd_spec = disks.get("hdd", {})
        ssd_spec = disks.get("ssd", {})
        gpu_spec = segment_spec.get("gpu_per_model", {})

        self.cpu = segment_spec.get("cpu", {}).get("capacity", 0)
        self.memory = segment_spec.get("memory", {}).get("capacity", 0)
        self.hdd = hdd_spec.get("capacity", 0)
        self.hdd_bw = hdd_spec.get("bandwidth", 0)
        self.ssd = ssd_spec.get("capacity", 0)
        self.ssd_bw = ssd_spec.get("bandwidth", 0)
        self.ipv4 = segment_spec.get("internet_address", {}).get("capacity", 0)
        self.net_bw = segment_spec.get("network", {}).get("bandwidth", 0)

        self.gpu_models = gpu_models
        self.gpu = dict(
            (model, gpu_spec.get(model, {}).get("capacity", 0))
            for model in gpu_models
        )

    def is_not_empty(self):
        """Return True if any scalar amount, or the sum of GPU amounts, is positive."""
        gpu_total = sum(self.gpu.values())
        amounts = (
            self.cpu,
            self.hdd,
            self.hdd_bw,
            self.ssd,
            self.ssd_bw,
            self.memory,
            self.ipv4,
            self.net_bw,
            gpu_total,
        )
        for amount in amounts:
            if amount > 0:
                return True
        return False

    def get_values(self):
        """Return amounts as a list: cpu, memory, hdd, hdd_bw, ssd, ssd_bw, ipv4, net_bw, then one per GPU model."""
        scalar_part = [
            self.cpu,
            self.memory,
            self.hdd,
            self.hdd_bw,
            self.ssd,
            self.ssd_bw,
            self.ipv4,
            self.net_bw,
        ]
        gpu_part = [self.gpu.get(model, 0) for model in self.gpu_models]
        return scalar_part + gpu_part

    def subtract(self, other):
        """Subtract the amounts of `other` from this object, in place."""
        scalar_names = ("cpu", "memory", "hdd", "hdd_bw", "ssd", "ssd_bw", "ipv4", "net_bw")
        for name in scalar_names:
            setattr(self, name, getattr(self, name) - getattr(other, name))

        remaining = {}
        for model in self.gpu_models:
            remaining[model] = self.gpu.get(model, 0) - other.gpu.get(model, 0)
        self.gpu = remaining


def get_account_dataframe(client, account_id, slug, specified_segment_id, gpu_models):
    """Build a table of limits, usage and free resources of an account per segment.

    For every segment listed in the account's per-segment limits (or only
    `specified_segment_id` when it is not None) three rows are produced, typed
    "limit", "usage" and "free" (limit minus usage). Segments where both the
    limits and the usage are empty are skipped.

    Returns:
        None when the account is not found; otherwise a DataFrame with columns
        "type", "segment", "account_id", "slug", the `_COLUMNS` resources and
        one column per entry of `gpu_models`.
    """
    found = client.select_objects(
        "account",
        selectors=["/spec", "/status"],
        filter='[/meta/id] = "{}"'.format(account_id),
    )
    if len(found) == 0:
        return None

    assert len(found) == 1, "Expected exactly one object in response"
    spec, status = found[0]

    per_segment_limits = spec.get("resource_limits", {}).get("per_segment", [])

    rows = []
    for segment in per_segment_limits:
        wanted = specified_segment_id is None or segment == specified_segment_id
        if not wanted:
            continue

        limits = Columns(per_segment_limits[segment], gpu_models)

        per_segment_usage = status["immediate_resource_usage"].get("per_segment", {})
        usage = Columns(per_segment_usage.get(segment, {}), gpu_models)

        # free = limits - usage, computed on a copy so that `limits` stays intact.
        free = copy.deepcopy(limits)
        free.subtract(usage)

        limits_present = limits.is_not_empty()
        usage_present = usage.is_not_empty()
        if not (limits_present or usage_present):
            continue

        row_key = [segment, account_id, slug]
        for kind, amounts in (("limit", limits), ("usage", usage), ("free", free)):
            rows.append([kind] + row_key + amounts.get_values())

    header = ["type", "segment", "account_id", "slug"] + _COLUMNS + gpu_models
    return pd.DataFrame(rows, columns=header)


def get_pod_sets_resource_usage_by_account_and_segment(client, account_id, segment_id, gpu_models):
    """Sum up resources requested by pods, grouped by pod set, for one account and segment.

    Pod sets are selected by `/spec/account_id` and `/spec/node_segment_id`.
    Pods whose own `/meta/account_id` is set and differs from `account_id` are
    accounted separately as "foreign" usage.

    Args:
        client: object providing `select_objects(object_type, selectors=..., filter=...)`.
        account_id: account id of the pod sets.
        segment_id: node segment id of the pod sets.
        gpu_models: list of GPU model names; one column per model is produced.

    Returns:
        Tuple `(usage_dataframe, foreign_usage_dataframe)`. Both frames have the
        columns "segment", "account_id", "pod_set_id", the `_COLUMNS` resources
        and one column per GPU model, with one row per pod set. In the foreign
        frame "account_id" is a comma-separated, sorted list of the foreign
        account ids seen in the pod set.
    """
    # Index internet addresses by the pod referenced in their status.
    internet_addresses_response = client.select_objects(
        "internet_address",
        selectors=["/meta/id", "/status/pod_id"],
    )
    pod_to_internet_addresses = dict()
    for address_id, pod_id in internet_addresses_response:
        if not pod_id:
            continue
        pod_to_internet_addresses.setdefault(pod_id, []).append(address_id)

    def get_internet_address_in_use_count(pod_id):
        return len(pod_to_internet_addresses.get(pod_id, []))

    def load_pods(pod_set_id):
        """Return `(pod_id, pod_set_id, pod_account_id, resources)` for every pod of the pod set."""
        response = client.select_objects(
            "pod",
            selectors=[
                "/meta/id",
                "/meta/pod_set_id",
                "/meta/account_id",
                "/spec/resource_requests",
                "/spec/disk_volume_requests",
                "/spec/gpu_requests",
            ],
            filter='[/meta/pod_set_id] = "{}"'.format(pod_set_id),
        )

        result = []
        for (
            pod_id,
            pod_pod_set_id,  # distinct name: do not shadow the function parameter
            pod_account_id,
            optional_resource_requests,
            disk_volume_requests,
            gpu_requests
        ) in response:
            resource_requests = optional_resource_requests or dict()

            # Only "hdd" and "ssd" storage classes are accounted; others are ignored.
            disk_requests = {"hdd": 0, "ssd": 0}
            disk_requests_bw = {"hdd": 0, "ssd": 0}
            if disk_volume_requests:
                for request in disk_volume_requests:
                    storage_class = request["storage_class"]
                    size = request["quota_policy"]["capacity"]
                    size_bw = request["quota_policy"].get("bandwidth_guarantee", 0)
                    if storage_class in disk_requests:
                        disk_requests[storage_class] += size
                    if storage_class in disk_requests_bw:
                        disk_requests_bw[storage_class] += size_bw

            resources = {
                "ipv4": get_internet_address_in_use_count(pod_id),
                "memory": resource_requests.get("memory_limit", 0),
                "vcpu": resource_requests.get("vcpu_guarantee", 0),
                "hdd": disk_requests["hdd"],
                "hdd_bw": disk_requests_bw["hdd"],
                "ssd": disk_requests["ssd"],
                "ssd_bw": disk_requests_bw["ssd"],
                "net_bw": resource_requests.get("network_bandwidth_guarantee", 0),
            }
            for model in gpu_models:
                if gpu_requests:
                    # Number of GPU requests of the pod that ask for this model.
                    resources[model] = sum(request.get("model", "") == model for request in gpu_requests)
                else:
                    resources[model] = 0

            result.append((
                pod_id,
                pod_pod_set_id,
                pod_account_id,
                resources,
            ))

        return result

    pod_set_responses = client.select_objects(
        "pod_set",
        selectors=["/meta/id"],
        filter=create_filter([
            '[/spec/account_id] = "{}"'.format(account_id),
            '[/spec/node_segment_id] = "{}"'.format(segment_id),
        ])
    )
    pod_set_ids = [pod_set_response[0] for pod_set_response in pod_set_responses]

    pod_sets_usage = {}
    pod_sets_foreign_usage = {}
    # NOTE(review): the third argument of par_apply is presumably the degree of parallelism - confirm.
    for pods_data in par_apply(pod_set_ids, load_pods, 20):
        for pod_id, pod_set_id, pod_account_id, resources in pods_data:
            if pod_account_id and pod_account_id != account_id:  # foreign account of pod
                totals = pod_sets_foreign_usage.setdefault(pod_set_id, {"account_ids": set()})
                totals["account_ids"].add(pod_account_id)
            else:
                totals = pod_sets_usage.setdefault(pod_set_id, {})
            for key in resources:
                totals[key] = totals.get(key, 0) + resources[key]

    flatten_pod_sets_data = []
    for pod_set_id in pod_sets_usage:
        data = dict(pod_sets_usage[pod_set_id])
        data["segment"] = segment_id
        data["account_id"] = account_id
        data["pod_set_id"] = pod_set_id
        flatten_pod_sets_data.append(data)

    flatten_pod_sets_data_foreign = []
    for pod_set_id in pod_sets_foreign_usage:
        data = dict(pod_sets_foreign_usage[pod_set_id])
        data["segment"] = segment_id
        # Sorted: set iteration order is arbitrary and would make the output unstable.
        data["account_id"] = ",".join(sorted(data.pop("account_ids")))
        data["pod_set_id"] = pod_set_id
        flatten_pod_sets_data_foreign.append(data)

    columns = ["segment", "account_id", "pod_set_id"] + _COLUMNS + gpu_models
    usage_dataframe = pd.DataFrame(flatten_pod_sets_data, columns=columns)
    foreign_usage_dataframe = pd.DataFrame(flatten_pod_sets_data_foreign, columns=columns)
    return usage_dataframe, foreign_usage_dataframe


# OAuth application credentials passed to `lpo.get_token()` in `_get_abc_token()`
# as the last-resort way of obtaining a token. They are stored base64-encoded
# and decoded once at import time.
# These are not really secrets, see https://pg.at.yandex-team.ru/5490
# App: https://oauth.yandex-team.ru/client/375ae554525e48f382feaf3c9cd3a2da
_CLIENT_ID = base64.b64decode("Mzc1YWU1NTQ1MjVlNDhmMzgyZmVhZjNjOWNkM2EyZGE=").decode("utf-8")
_CLIENT_SECRET = base64.b64decode("NjM0ODlkM2U5ZGM0NDNhZThhOTY0NDUyOGIyZDU5M2Q=").decode("utf-8")


def _token_required():
    """Abort via `_fatal_error` with instructions on how to supply an ABC OAuth token."""
    help_parts = (
        "ABC OAuth token is required. Please use either:\n\n",
        "  - OAuth over SSH authorization\n",
        "    See https://pg.at.yandex-team.ru/5490 for details. Preferred way, safe\n\n",
        "  - `ABC_TOKEN` environment variable (less secure, for automation only)\n",
        "  - ~/.abc/token file with token\n",
        "    Insecure, bad way! See https://clubs.at.yandex-team.ru/golang/441\n\n",
        "  - command-line argument --abc-token\n",
        "How to obtain token: https://wiki.yandex-team.ru/intranet/abc/api/#autentifikacija",
    )
    _fatal_error("".join(help_parts))


def _abc_access_required():
    """Abort via `_fatal_error`, telling the user to get network access to the ABC host."""
    fqdn = abc_client.AbcClient.FQDN
    template = (
        "ABC access is required for selected options. Please use Puncher "
        "to gain access to {}"
    )
    _fatal_error(template.format(fqdn))


def _abc_error_handler(subject, exception):
    """Abort via `_fatal_error` after a failed ABC request.

    Args:
        subject: what was being requested (slug, service id or account id).
        exception: the exception raised by the ABC client; its repr is included
            in the message so the actual failure reason is not lost.
    """
    _fatal_error(
        "Error requesting ABC services of `{subject}`: {exception!r}. "
        "You should probably request access to `{fqdn}` via Puncher or fix the token".format(
            subject=subject,
            exception=exception,
            fqdn=abc_client.AbcClient.FQDN,
        )
    )


def _get_abc_token(abc_token, no_abc_resolving, token_required):
    """Obtain an ABC OAuth token from the first source that provides one.

    Sources, in order: the `abc_token` argument, the `ABC_TOKEN` environment
    variable, the `~/.abc/token` file, and finally `lpo.get_token()`.

    Args:
        abc_token: token passed on the command line, may be None or empty.
        no_abc_resolving: value of the --no-abc-resolving option.
        token_required: whether the caller cannot proceed without a token.

    Returns:
        The token, or whatever falsy value `lpo.get_token()` returned when no
        token was found and `token_required` is false.

    Terminates the process (via `_abc_access_required` / `_token_required`)
    when a token is required but ABC resolving is disabled or no token is found.
    """
    if no_abc_resolving and token_required:
        _abc_access_required()

    # The most explicit way to specify token: command line args.
    if abc_token:
        return abc_token

    # Okay, let's try environment variable.
    abc_token = os.environ.get("ABC_TOKEN")
    if abc_token:
        return abc_token

    # Empty both args and environment, so try read from file.
    token_path = os.path.expanduser('~/.abc/token')
    if os.path.exists(token_path):
        with open(token_path, 'r') as f:
            abc_token = f.read().strip()
        # An empty token file must not short-circuit the lookup: fall through
        # to the remaining method (and to the `token_required` check) instead
        # of returning an empty string.
        if abc_token:
            return abc_token

    # The best way to obtain token is over SSH, see https://pg.at.yandex-team.ru/5490

    # Note: this method is applied as the LAST CHANCE so user can choose
    # other methods specifying command line args or environment variable.
    abc_token = lpo.get_token(_CLIENT_ID, _CLIENT_SECRET)

    if not abc_token and token_required:
        _token_required()

    return abc_token


def parse_args(argv):
    """Parse command-line arguments.

    Common options are registered by `register_common_args`; the options
    specific to this tool are declared in the table below and added in order.

    Args:
        argv: list of argument strings to parse.

    Returns:
        The namespace produced by `ArgumentParser.parse_args`.
    """
    parser = argparse.ArgumentParser(add_help=True)
    register_common_args(parser)

    # (names or flags, keyword arguments for add_argument), in help order.
    argument_table = [
        (
            ("account_id",),
            dict(help='account id {tmp, abc:service:123, abc:slug:upper, abc:slug:yp}'),
        ),
        (
            ("-r", "--recursive"),
            dict(action="store_true", help="explain for all children accounts recursively"),
        ),
        (
            ("--abc-token",),
            dict(type=str, help="Explicitly pass ABC token if OAuth-over-SSH method is unavailable"),
        ),
        (
            ("-p", "--pod-set-limit"),
            dict(type=int, default=100, help="show usage only for the limited number of pod sets"),
        ),
        (
            ("--node-segment",),
            dict(help="show information only about given node segment"),
        ),
        (
            ("-a", "--no-abc-resolving"),
            dict(action="store_true", help="Do not try to resolve services by ABC"),
        ),
        (
            ("-j", "--join-usage"),
            dict(action="store_true", help="join usage, limits and frees in single table"),
        ),
    ]
    for names, options in argument_table:
        parser.add_argument(*names, **options)

    return parser.parse_args(argv)


def get_accounts(account_id, recursive, abc_token, no_abc_resolving):
    """Resolve the requested account into a list of account ids with ABC slugs.

    `account_id` may be given as `abc:slug:<slug>` (resolved through ABC into
    `abc:service:<id>`), as `abc:service:<id>`, or as any other id (e.g. "tmp"),
    which is used as is. With `recursive`, the services returned by
    `AbcClient.list_service_children` are added as `abc:service:<id>` accounts.

    Args:
        account_id: account id as given on the command line.
        recursive: whether to add children of the ABC service.
        abc_token: explicitly passed ABC token, may be None.
        no_abc_resolving: value of the --no-abc-resolving option.

    Returns:
        List of dicts with keys "id" (account id) and "slug". The slug is
        "<unknown>" for ids that are not `abc:service:<number>` and "<root>"
        when the slug of such an account was not resolved.

    Terminates the process via `_fatal_error` when ABC is required but not
    accessible, when an ABC request fails, or when the service is not found.
    """
    abc_slug_pattern = r"^abc:slug:([a-zA-Z0-9_-]+)$"
    abc_slug_match = re.match(abc_slug_pattern, account_id)

    abc_service_pattern = r"^abc:service:(\d+)$"
    abc_service_match = re.match(abc_service_pattern, account_id)

    abc_required = recursive or abc_slug_match is not None
    abc_token = _get_abc_token(abc_token, no_abc_resolving, abc_required)

    abc = None
    try:
        abc = abc_client.AbcClient(abc_token)
    except abc_client.AbcInaccessible:
        if not no_abc_resolving:
            _warning(
                'ABC client is inaccessible. Please check access to `{}` '
                'or use --no-abc-resolving option'.format(abc_client.AbcClient.FQDN)
            )

    abc_service_id = None  # Optional integer.
    abc_service_infos_by_id = {}

    if abc_slug_match is not None:
        if abc is None:
            _abc_access_required()

        slug = abc_slug_match.group(1)
        try:
            abc_service_id = abc.get_service_id_by_slug(slug)
        except Exception as exc:
            _abc_error_handler(subject=slug, exception=exc)

        if abc_service_id is None:
            _fatal_error("Service with slug `{}` was not found".format(slug))

        account_ids = ["{}:{}".format(_ABC_SERVICE, abc_service_id)]
        abc_service_infos_by_id[abc_service_id] = {}
        abc_service_infos_by_id[abc_service_id]["slug"] = slug
    elif abc_service_match is not None:
        account_ids = [account_id]
        abc_service_id = int(abc_service_match.group(1))

        # The slug is looked up only when both a token and a client are available.
        if abc_token and abc:
            try:
                slug = abc.get_service_slug_by_id(abc_service_id)
            except Exception as exc:
                _abc_error_handler(subject=abc_service_id, exception=exc)

            if slug is None:
                _fatal_error("ABC service with account id `{}` was not found".format(abc_service_id))

            abc_service_infos_by_id[abc_service_id] = {}
            abc_service_infos_by_id[abc_service_id]["slug"] = slug
    else:
        # Non-ABC accounts, e.g. "tmp".
        account_ids = [account_id]

    if recursive:
        if abc_service_id is None:
            _fatal_error(
                "Recursive mode requires account id to be in either abc:slug:<slug> or abc:service:<nnn> format",
            )

        if not abc:
            _abc_access_required()

        try:
            recursive_abc_service_infos_by_id = abc.list_service_children(abc_service_id)
        except Exception as exc:
            _abc_error_handler(subject=account_id, exception=exc)

        abc_service_infos_by_id.update(recursive_abc_service_infos_by_id)
        account_ids += [
            "{}:{}".format(_ABC_SERVICE, account_id_)
            for account_id_ in recursive_abc_service_infos_by_id
        ]

    def get_slug(account_id):
        prefix = "{}:".format(_ABC_SERVICE)
        if not account_id.startswith(prefix):
            return "<unknown>"
        try:
            abc_service_id = int(account_id[len(prefix):])
        except ValueError:
            # Ids such as "abc:service:12a" carry the prefix but do not match
            # the service pattern above; they are handled as non-ABC accounts,
            # so report an unknown slug instead of crashing.
            return "<unknown>"
        # Root slug will not be queried when ABC token is not provided.
        return abc_service_infos_by_id.get(abc_service_id, {}).get("slug", "<root>")

    return [
        dict(id=account_id_, slug=get_slug(account_id_))
        for account_id_ in account_ids
    ]


def make_account_dataframes(client, accounts, node_segment, pod_set_limit):
    """Collect account and pod set resource tables for the given accounts.

    Args:
        client: object providing `select_objects`, passed to the data loaders.
        accounts: list of dicts with keys "id" and "slug".
        node_segment: when not None, only this node segment is considered.
        pod_set_limit: once this many pod sets have been collected, pod set
            usage of the remaining accounts/segments is not loaded
            (0 disables loading of pod set usage).

    Returns:
        Tuple of six items:
        (accounts usage, accounts limits, accounts free, pod sets usage,
        pod sets foreign usage, all account rows with the "type" column).
        The first five are DataFrames (empty when there is no data); the last
        one is None when no account produced any rows.
    """
    all_accounts_df = None
    all_pod_sets_usage_df = None
    all_pod_sets_foreign_usage_df = None

    for account in accounts:
        gpu_models = get_gpu_models(client, account["id"], node_segment)

        account_df = get_account_dataframe(
            client,
            account["id"],
            account["slug"],
            node_segment,
            gpu_models,
        )

        if account_df is None:
            _warning('Account "{}" not found'.format(account["id"]))
            continue

        if len(account_df) == 0:
            logging.debug('Account `%s` has zero limit and usage per every segment and resource', account["id"])
            continue

        all_accounts_df = _merge_df(all_accounts_df, account_df)

        for segment_id in account_df.segment.unique():
            # The limit applies to pod sets, so count the rows of the pod set
            # table (one row per pod set), not the rows of the accounts table.
            collected_pod_sets = 0 if all_pod_sets_usage_df is None else len(all_pod_sets_usage_df)
            if collected_pod_sets >= pod_set_limit:
                logging.debug(
                    'Pod set usage in account `%s` and segment `%s` is omitted due to --pod-set-limit',
                    account["id"],
                    segment_id,
                )
                continue

            (
                pod_sets_usage_dataframe,
                pod_sets_foreign_usage_dataframe
            ) = get_pod_sets_resource_usage_by_account_and_segment(
                client,
                account["id"],
                segment_id,
                gpu_models,
            )

            all_pod_sets_usage_df = _merge_df(all_pod_sets_usage_df, pod_sets_usage_dataframe)
            all_pod_sets_foreign_usage_df = _merge_df(all_pod_sets_foreign_usage_df, pod_sets_foreign_usage_dataframe)

    def extract_type(type_):
        """Return the account rows of the given type, re-indexed and without the "type" column."""
        if all_accounts_df is None:
            return pd.DataFrame()

        filtered = all_accounts_df[all_accounts_df.type == type_]
        # Non in-place calls: modifying a filtered slice in place may trigger
        # pandas' SettingWithCopyWarning.
        return filtered.reset_index(drop=True).drop(["type"], axis=1)

    all_accounts_usage_df = extract_type("usage")
    all_accounts_limits_df = extract_type("limit")
    all_accounts_free_df = extract_type("free")

    if all_pod_sets_usage_df is None:
        all_pod_sets_usage_df = pd.DataFrame()
    if all_pod_sets_foreign_usage_df is None:
        all_pod_sets_foreign_usage_df = pd.DataFrame()

    return (
        all_accounts_usage_df,
        all_accounts_limits_df,
        all_accounts_free_df,
        all_pod_sets_usage_df,
        all_pod_sets_foreign_usage_df,
        all_accounts_df,
    )


def process_cluster(args, accounts, cluster):
    """Load account data through a client built from `args` and print it.

    With `args.format == "json"` a single JSON document with all tables is
    printed (and `args.no_pretty_units` is forced to True); otherwise the
    tables are printed one after another as text.

    Args:
        args: parsed command-line arguments.
        accounts: list of account dicts as returned by `get_accounts`.
        cluster: cluster name; not used directly by this function.
    """
    client = build_yp_client(args)
    frames = make_account_dataframes(
        client,
        accounts,
        args.node_segment,
        args.pod_set_limit,
    )
    (
        usage_table,
        limit_table,
        free_table,
        pod_sets_table,
        pod_sets_foreign_table,
        joined_table,
    ) = frames

    # Everything except the identifying columns is a resource statistic.
    account_key_columns = ("account_id", "segment", "slug")
    accounts_stats_columns = [
        name for name in limit_table.columns
        if name not in account_key_columns
    ]
    pod_set_key_columns = ("pod_set_id", "segment", "account_id")
    pod_sets_stats_columns = [
        name for name in pod_sets_table.columns
        if name not in pod_set_key_columns
    ]

    if args.format == "json":
        args.no_pretty_units = True
        report = dict(
            accounts_usages=pandas_table_to_json(usage_table, args),
            accounts_limits=pandas_table_to_json(limit_table, args),
            accounts_frees=pandas_table_to_json(free_table, args),
            pod_sets_usages=pandas_table_to_json(pod_sets_table, args),
            pod_sets_foreign_usages=pandas_table_to_json(pod_sets_foreign_table, args),
        )
        print(json.dumps(report, indent=4))
        return

    if limit_table.empty and usage_table.empty:
        print("No account data to print")
    elif args.join_usage:
        print("Account(s) limits, usages and frees:")
        print_pandas_table(joined_table, args, accounts_stats_columns)
        print()
    else:
        sections = (
            ("Account(s) limits:", limit_table),
            ("Account(s) usages:", usage_table),
            ("Account(s) frees:", free_table),
        )
        for title, table in sections:
            print(title)
            print_pandas_table(table, args, accounts_stats_columns)
            print()

    if not pod_sets_table.empty:
        print("Pod set(s) usages:")
        print_pandas_table(pod_sets_table, args, pod_sets_stats_columns)
    elif args.pod_set_limit > 0:
        print("No pod set data to print")

    if not pod_sets_foreign_table.empty:
        print("Pod set(s) foreign usages (usage of pods with account_id != pod_set.account_id):")
        print_pandas_table(pod_sets_foreign_table, args, pod_sets_stats_columns)


def main(argv):
    """Entry point: resolve the accounts once, then report on every requested cluster.

    `args.cluster` is a comma-separated list of cluster names; when it is not
    given, a built-in default list is used. `args.cluster` is overwritten with
    the current cluster name before each `process_cluster` call.
    """
    args = parse_args(argv)

    accounts = get_accounts(
        args.account_id,
        args.recursive,
        args.abc_token,
        args.no_abc_resolving,
    )

    if not args.cluster:
        args.cluster = "sas,man,vla,iva,myt"

    for cluster_name in args.cluster.split(","):
        args.cluster = cluster_name
        if not args.format == "json":
            print("Cluster `{}`:".format(cluster_name))
        process_cluster(args, accounts, cluster_name)
