# -*- coding: utf-8 -*-

from __future__ import print_function

from yp.client import YpClient, find_token

from yp.common import YtResponseError

from yt.yson import YsonEntity

from color import colored
from tabulate import tabulate
import numpy as np

import json
import logging


def _is_empty_cell(value):
    """Return True if a table cell should be rendered as an empty string.

    Empty cells are ``None``, NaN, and the empty string (which the table
    preprocessing substitutes for ``YsonEntity`` values). ``np.isnan`` cannot
    be applied to strings or lists, so non-numeric cells are handled here
    instead of raising ``TypeError``.
    """
    if value is None:
        return True
    try:
        return bool(np.isnan(value))
    except (TypeError, ValueError):
        # Non-numeric cell (string, list, ...): only "" counts as empty.
        return isinstance(value, str) and value == ""


def to_int_or_empty(value):
    """Convert a numeric cell to ``int``, or return ``''`` for an empty cell."""
    if _is_empty_cell(value):
        return ''
    return int(value)


def to_true_or_empty(value):
    """Return ``"True"`` for a truthy cell and ``''`` for falsy or empty cells."""
    if _is_empty_cell(value):
        return ''
    return "True" if bool(value) else ""


def _human_readable_value(value, unit, unit_names, digits_after_comma):
    if value == "na":
        return value

    i = 0
    while i + 1 < len(unit_names) and value >= unit:
        i += 1
        value /= (1.0 * unit)

    return str(round(value, digits_after_comma)) + " " + unit_names[i]


def human_readable_cores(cores, digits_after_comma=2):
    """Format a core count using decimal (x1000) steps: cores, Kcores, ..., Pcores."""
    core_units = ("cores", "Kcores", "Mcores", "Gcores", "Tcores", "Pcores")
    return _human_readable_value(
        value=cores,
        unit=1000,
        unit_names=core_units,
        digits_after_comma=digits_after_comma,
    )


def human_readable_memory(memory, digits_after_comma=3):
    """Format a byte count using binary (x1024) steps: B, KiB, ..., PiB."""
    byte_units = ("B", "KiB", "MiB", "GiB", "TiB", "PiB")
    return _human_readable_value(
        value=memory,
        unit=1024,
        unit_names=byte_units,
        digits_after_comma=digits_after_comma,
    )


def human_readable_bandwidth(bandwidth, digits_after_comma=3):
    """Format a bytes-per-second rate using binary (x1024) steps: B/s, KiB/s, ..., PiB/s."""
    rate_units = ("B/s", "KiB/s", "MiB/s", "GiB/s", "TiB/s", "PiB/s")
    return _human_readable_value(
        value=bandwidth,
        unit=1024,
        unit_names=rate_units,
        digits_after_comma=digits_after_comma,
    )


def _apply_to_cells(frame, func):
    """Apply ``func`` to every cell of ``frame`` and return the resulting frame.

    ``DataFrame.applymap`` is deprecated since pandas 2.1 in favour of
    ``DataFrame.map``; use the new name when available. The check is made on
    the class so that a column named ``map`` is never mistaken for the method.
    """
    if hasattr(type(frame), "map"):
        return frame.map(func)
    return frame.applymap(func)


def _filter_and_sort(table, args):
    """Apply ``--query`` and ``--sort-by``/``--sort-order`` from ``args``.

    Returns a new frame with a fresh 0..n-1 index; ``table`` is not modified.
    """
    if args is not None and args.query is not None:
        table = table.query(args.query)

    if args is not None and args.sort_by is not None:
        table = table.sort_values(args.sort_by.split(","), ascending=args.sort_order == "asc")

    return table.reset_index(drop=True)


def _format_cells(table, pretty_units):
    """Convert cell values into their display form and return a new frame."""
    # YsonEntity cells (presumably absent YSON values) are shown as empty cells.
    table = _apply_to_cells(table, lambda cell: "" if isinstance(cell, YsonEntity) else cell)

    if pretty_units:
        # Raw cpu values are in millicores (see --query help); show cores instead.
        for column in table.columns:
            if "cpu" in column:
                table[column] = table[column] / 1000

    for column in table.columns:
        if column == "ipv4":
            table[column] = table[column].apply(to_int_or_empty)
        elif column in ("net_10G", "alerts"):
            table[column] = table[column].apply(to_true_or_empty)
        elif pretty_units:
            converter = None
            if "ssd_bw" in column or "hdd_bw" in column or "net_bw" in column:
                converter = human_readable_bandwidth
            elif "memory" in column or "disk" in column or "ssd" in column or "hdd" in column:
                converter = human_readable_memory
            elif "cpu" in column:
                converter = human_readable_cores
            if converter is not None:
                table[column] = table[column].apply(converter)

    return table.fillna("")


def _preprocess_pandas_table(table, args=None, total_stats_columns=None):
    """Filter, sort, limit and format ``table`` for output.

    Args:
        table: source DataFrame; it is not modified.
        args: parsed CLI arguments (see ``register_common_args``), or None for defaults.
        total_stats_columns: columns summed into an extra ``"total"`` row, or None.

    Returns:
        A new DataFrame ready to be printed or converted to records.
    """
    table = _filter_and_sort(table, args)

    # The total is computed before --output-limit, so it covers every matching row.
    total = None
    if total_stats_columns is not None:
        total = table[total_stats_columns].sum()

    if args is not None and args.output_limit > 0:
        table = table.head(args.output_limit)

    # Work on an explicit copy to avoid pandas `SettingWithCopyWarning`.
    table = table.copy()
    if total is not None:
        table.loc["total"] = total
        table[total_stats_columns] = _apply_to_cells(table[total_stats_columns].fillna(0), np.int64)

    pretty_units = (args is None) or (not args.no_pretty_units)
    return _format_cells(table, pretty_units)


def print_pandas_table(table, args=None, total_stats_columns=None):
    """Preprocess ``table`` and print it to stdout.

    The format comes from ``args.format`` ("tabular" when ``args`` is None):
    "tabular" prints a fancy-grid table and "json" prints the row records.

    Raises:
        RuntimeError: for any other output format.
    """
    prepared = _preprocess_pandas_table(table, args, total_stats_columns)
    output_format = "tabular" if args is None else args.format

    if output_format == "tabular":
        print(tabulate(prepared, headers="keys", tablefmt="fancy_grid"))
        return
    if output_format == "json":
        records = prepared.to_dict(orient="records")
        print(json.dumps(records, indent=4))
        return
    raise RuntimeError('Unknown output format "{}"'.format(output_format))


def pandas_table_to_json(table, args=None, total_stats_columns=None):
    """Return ``table`` as a list of row dicts after the common preprocessing.

    An empty table short-circuits to ``[]``, so no "total" row is produced for it.
    """
    if not table.empty:
        prepared = _preprocess_pandas_table(table, args, total_stats_columns)
        return prepared.to_dict(orient="records")
    return []


def build_yp_client(args, cluster=None):
    """Create a YP client and check that it can talk to the cluster.

    Args:
        args: parsed CLI arguments; ``args.cluster`` and ``args.token`` are read.
        cluster: optional address that overrides ``args.cluster``.

    Returns:
        A ``YpClient`` for which a probe ``select_objects`` query succeeded.

    If the probe query fails, an error is logged and the process exits with
    status 1. For failures other than a rejected token, the exception text
    is included in the log.
    """
    # `exit` is a site-module convenience that may be missing (e.g. `python -S`).
    import sys

    address = args.cluster if cluster is None else cluster
    token = args.token if args.token is not None else find_token()
    client = YpClient(address, config=dict(token=token))

    try:
        # Cheap probe query: any connectivity or auth problem surfaces here.
        client.select_objects("pod", selectors=["/meta/id"], limit=1)
    except Exception as ex:  # Any failure makes the client unusable for the CLI.
        if isinstance(ex, YtResponseError) and ex.contains_code(109):  # Authentication error.
            logging.error(colored(
                "Your authentication token was rejected by the server. "
                "Please refer to https://wiki.yandex-team.ru/yp/accesscontrol for obtaining a valid token",
                "red"
            ))
        else:
            logging.error("%s: %s", colored("Error validating connection with YP", "red"), ex)
        sys.exit(1)

    return client


def register_yp_token_args(parser):
    """Add the optional ``--token`` flag carrying the YP OAuth token."""
    parser.add_argument(
        "--token",
        required=False,
        help="YP OAuth token",
    )


def register_common_args(parser):
    """Add the command-line options shared by the tool's commands.

    Registers ``--token`` (via ``register_yp_token_args``), then
    ``--cluster``/``--address``, ``--output-limit``, ``--query``, ``--sort-by``,
    ``--sort-order``, ``--no-pretty-units`` and ``--format``, in that order.
    """
    register_yp_token_args(parser)

    option_specs = [
        (
            ('--cluster', '--address'),
            dict(
                required=False,
                help='YP cluster (comma-separated list is also allowed for `account explain` mode, e.g. `sas,man,vla`)',
            ),
        ),
        (
            ('--output-limit',),
            dict(type=int, default=5000000, help='output row limit {%(default)s by default}'),
        ),
        (
            ('--query',),
            dict(help='pandas query for output filtering {cpu in millicores, memory and disk in bytes}'),
        ),
        (
            ('--sort-by',),
            dict(help='sort by tuple separated by commas'),
        ),
        (
            ('--sort-order',),
            dict(default='desc', choices=['asc', 'desc'], help='sort order'),
        ),
        (
            ('--no-pretty-units',),
            dict(required=False, help='do not make numbers human-readable', action='store_true'),
        ),
        (
            ('--format',),
            dict(default='tabular', choices=['json', 'tabular'], help='output format {%(default)s by default}'),
        ),
    ]
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)


def create_filter(filters):
    """Combine filter expressions into a single conjunction.

    Each non-empty expression is wrapped in parentheses and the results are
    joined with " AND ". Empty expressions are skipped; otherwise they would
    turn into a bare "()" clause.

    Args:
        filters: iterable of filter expressions, or None.

    Returns:
        The combined expression, or None if no non-empty expression is given.
    """
    clauses = ["({})".format(clause) for clause in filters or () if clause]
    if not clauses:
        return None
    return " AND ".join(clauses)


def create_node_filter(args, client, default_node_status=None):
    """Build a node filter expression from CLI arguments.

    The result is the conjunction (via ``create_filter``) of:
      * the ``/spec/node_filter`` of node segment ``args.segment`` ("default"
        when unset; an explicit empty string disables this part);
      * ``args.node_filter`` verbatim, if given;
      * an exact match of ``/meta/id`` against ``args.node_id``, if given;
      * a match of ``/status/hfsm/state`` against ``args.node_status``. It falls
        back to ``default_node_status`` when neither a node id nor a status is
        given; an explicit empty status disables this part.

    Returns:
        The combined filter string, or None when there is nothing to filter on.

    Raises:
        ValueError: if ``args.node_filter`` or ``args.node_id`` is given but empty.
    """
    filters = []

    segment = "default" if args.segment is None else args.segment
    # Explicitly specified --segment "" eliminates this filter.
    if segment:
        node_segment_filter = client.get_object("node_segment", segment, selectors=["/spec/node_filter"])[0]
        # Skip an empty segment filter: wrapped in parentheses it would be a bare "()".
        if node_segment_filter:
            filters.append(node_segment_filter)

    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if args.node_filter is not None:
        if not args.node_filter:
            raise ValueError("Node filter must be non-empty if specified")
        filters.append(args.node_filter)

    if args.node_id is not None:
        if not args.node_id:
            raise ValueError("Node id must be non-empty if specified")
        filters.append('[/meta/id] = "{}"'.format(args.node_id))

    node_status = args.node_status
    if default_node_status is not None and args.node_id is None and node_status is None:
        node_status = default_node_status

    # Explicitly specified --node-status "" eliminates this filter.
    if node_status:
        filters.append("[/status/hfsm/state] = \"{}\"".format(node_status))

    return create_filter(filters)


def register_common_nodes_args(parser):
    """Add node-selection flags: ``--segment``, ``--node-id`` and ``--node-filter``."""
    register_common_node_segment_args(parser)
    node_options = (
        ('--node-id', 'node id {man1-5227.search.yandex.net}'),
        ('--node-filter', 'node filter {[/labels/my_lovely_node] = true}'),
    )
    for flag, description in node_options:
        parser.add_argument(flag, help=description)


def register_measurement_args(parser):
    """Add the boolean ``--cpu`` switch (False unless passed)."""
    cpu_help = 'show cpu in real millicores instead of vcpu'
    parser.add_argument('--cpu', help=cpu_help, action='store_true')


def register_common_node_segment_args(parser):
    """Add the optional ``--segment`` flag naming a node segment."""
    parser.add_argument(
        '--segment',
        help='segment {dev, default, ""}',
    )


def register_nodes_status_args(parser, default=None):
    """Add the ``--node-status`` flag; ``default`` is used when it is omitted."""
    parser.add_argument(
        '--node-status',
        default=default,
        help='node status {up, down}',
    )
