from argparse import ArgumentParser
from collections import namedtuple
import copy
import logging
import os
import re
import yaml

import mail.unistat.cpp.cython.logs as logs
from mail.unistat.cpp.cython.meters import (
    AccessLogCount,
    AccessLogCountByFirstStatusDigit,
    AccessLogRequestTimeHist,
    CountErrors,
    HttpClientHttpRequestTotalTimeHist,
    SupervisorLogRestartMeters,
)
import mail.sharpei.unistat.cpp.run as sharpei_unistat
from mail.sharpei.unistat.cpp.run import (
    AccessLogHistogram,
    AccessLogHistogramWeak,
    AccessLogRequestCounter,
    AccessLogTskv,
    HttpClientLogRequestCounter,
    PaLogHistogram,
    PaLogHistogramForEndpoint,
    StatusCounter,
)

# Logging setup for this script: records below WARNING are dropped and every
# emitted record is prefixed with a timestamp and its level name.
logging.basicConfig(
    level=logging.WARNING, format="[%(asctime)s] [%(levelname)s]: %(message)s"
)
log = logging.getLogger(__name__)

# Borders passed to the access log histogram meters.
# NOTE(review): units are not visible in this file -- presumably milliseconds; confirm.
ACCESS_LOG_BORDERS = (0, 5, 10, 20, 50, 100, 300, 500, 1000)

# Borders passed to the "pa" (profiler log) histogram meters.
PA_BORDERS = (0, 5, 10, 20, 30, 40, 50, 100, 200, 300, 500, 1000)

# Borders passed to the HTTP client total-request-time histogram meters.
HTTP_CLIENT_BORDERS = (0, 5, 10, 20, 50, 100, 300, 500, 1000)


class DSAndDiskMetricsProvider(object):
    """Meter definitions used when the provider type is "datasync_and_disk".

    Every method is static and builds a fresh list on each call.
    """

    # Endpoints that get a dedicated request counter and histogram.
    _ENDPOINTS = (
        "/create_user",
        "/get_user",
        "/ping",
        "/pingdb",
        "/reset_cache",
        "/stat",
        "/update_user",
    )

    # Messages that are counted in the sharpei log.
    _LOG_ERRORS = (
        "appropriate host not found",
        "cached shard databases not found",
        "cached shard databases roles not found",
        "cached shard name not found",
        "error in request to meta database",
        "invalid request",
        "reset error",
        "shard not found",
        "shard with alive master not found",
        "sharpei error",
        "unknown error",
    )

    # Operation names that get a histogram built from the profiler ("pa") log.
    _PA_OPERATIONS = (
        "get_all_shards",
        "get_master",
        "get_shard",
        "get_status",
        "get_user_data",
        "ping",
    )

    @staticmethod
    def get_endpoints():
        """Return a new list of the endpoint paths tracked individually."""
        return list(DSAndDiskMetricsProvider._ENDPOINTS)

    @staticmethod
    def get_sharpei_log_errors_to_search_for():
        """Return a new list of the error messages counted in the sharpei log."""
        return list(DSAndDiskMetricsProvider._LOG_ERRORS)

    @staticmethod
    def get_access_log_metrics(name_prefix):
        """Build the access log meters.

        The three common meters use fixed names instead of ``name_prefix``;
        only the per-endpoint meters receive ``name_prefix``.
        """
        # hardcoded to save backward compatibility
        collected = [
            AccessLogCount("xxx"),
            AccessLogCountByFirstStatusDigit(""),
            AccessLogHistogramWeak(ACCESS_LOG_BORDERS, "", "access_log_request"),
        ]

        for path in DSAndDiskMetricsProvider.get_endpoints():
            collected.append(AccessLogRequestCounter(path, name_prefix))
            collected.append(AccessLogHistogram(ACCESS_LOG_BORDERS, path, name_prefix))

        return collected

    @staticmethod
    def get_pa_metrics(sharddb_hosts):
        """Build one "pa" histogram per operation; ``sharddb_hosts`` is not used here."""
        return [
            PaLogHistogram(PA_BORDERS, operation, "pa")
            for operation in DSAndDiskMetricsProvider._PA_OPERATIONS
        ]

    @staticmethod
    def get_http_client_metrics():
        """Return an empty list: this provider defines no HTTP client meters."""
        return []

    @staticmethod
    def get_status_log_metrics(name_prefix, hosts):
        """Return an empty list: this provider defines no status log meters."""
        return []

    @staticmethod
    def get_sharpei_registration_messages_to_search_for():
        """Return an empty list: this provider counts no registration messages."""
        return []

class CloudMetricsProvider(object):
    """Meter definitions used when the provider type is "cloud".

    Every method is static and builds a fresh list on each call.
    """

    # Endpoints that get a dedicated request counter and histogram.
    _ENDPOINTS = (
        "/ping",
        "/stat",
        "/v3/stat",
        "/stat/strict",
    )

    # concrete error messages from sharpei::ErrorCategory::message(),
    # please refer to https://a.yandex-team.ru/arc/trunk/arcadia/mail/sharpei/include/internal/errors.h
    _LOG_ERRORS = (
        "appropriate host not found",
        "cached shard databases not found",
        "cached shard databases roles not found",
        "cached shard name not found",
        "cluster polling error",
        "roles for mode not found",
        "shard not found",
        "shard with alive master not found",
        "shards polling error",
        "sharpei error",
        "unknown error",
        "yc hosts cache is empty",
    )

    # (first meter argument, counter name argument, histogram name argument);
    # the exact meaning of each argument is defined by the meter classes.
    _HTTP_CLIENT_TARGETS = (
        ("ticket", "http_client_tvm", "http_client_tvm_ticket"),
        ("tokens", "http_client_iam", "http_client_iam_tokens"),
        ("managed-postgresql", "http_client_yc_hosts", "http_client_yc_hosts"),
    )

    @staticmethod
    def get_endpoints():
        """Return a new list of the endpoint paths tracked individually."""
        return list(CloudMetricsProvider._ENDPOINTS)

    @staticmethod
    def get_sharpei_log_errors_to_search_for():
        """Return a new list of the error messages counted in the sharpei log."""
        return list(CloudMetricsProvider._LOG_ERRORS)

    @staticmethod
    def get_access_log_metrics(name_prefix):
        """Build the common access log meters plus a counter and histogram per endpoint."""
        collected = [
            AccessLogCount(name_prefix),
            AccessLogCountByFirstStatusDigit(name_prefix),
            AccessLogHistogramWeak(ACCESS_LOG_BORDERS, "", name_prefix),
        ]

        for path in CloudMetricsProvider.get_endpoints():
            collected.append(AccessLogRequestCounter(path, name_prefix))
            collected.append(AccessLogHistogram(ACCESS_LOG_BORDERS, path, name_prefix))

        return collected

    @staticmethod
    def get_pa_metrics(sharddb_hosts):
        """Build the single "get_status" pa histogram; ``sharddb_hosts`` is not used here."""
        status_histogram = PaLogHistogram(PA_BORDERS, "get_status", "pa")
        return [status_histogram]

    @staticmethod
    def get_http_client_metrics():
        """Build a request counter followed by a total-time histogram for each target."""
        collected = []
        for key, counter_name, hist_name in CloudMetricsProvider._HTTP_CLIENT_TARGETS:
            collected.append(HttpClientLogRequestCounter(key, counter_name))
            collected.append(
                HttpClientHttpRequestTotalTimeHist(HTTP_CLIENT_BORDERS, key, hist_name)
            )
        return collected

    @staticmethod
    def get_status_log_metrics(name_prefix, hosts):
        """Return an empty list: this provider defines no status log meters."""
        return []

    @staticmethod
    def get_sharpei_registration_messages_to_search_for():
        """Return an empty list: this provider counts no registration messages."""
        return []


class MailMetricsProvider(object):
    """Meter definitions used when the provider type is "mail".

    Every method is static and builds a fresh list on each call.
    """

    # Endpoints that get a dedicated request counter and histogram.
    _ENDPOINTS = (
        "/org_conninfo",
        "/domain_conninfo",
        "/conninfo",
        "/ping",
        "/pingdb",
        "/reset",
        "/stat",
        "/sharddb_stat",
        "/v2/stat",
        "/v3/stat",
    )

    # concrete error messages from sharpei::ErrorCategory::message(),
    # please refer to https://a.yandex-team.ru/arc/trunk/arcadia/mail/sharpei/include/internal/errors.h
    _LOG_ERRORS = (
        "appropriate host not found",
        "cached shard databases not found",
        "cached shard databases roles not found",
        "cached shard name not found",
        "cluster polling error",
        "endpoint provider error",
        "error in request to meta database",
        "meta database request expired in apq queue",
        "meta database request timeout",
        "meta master provider error",
        "meta polling error",
        "roles for mode not found",
        "shard not found",
        "shard with alive master not found",
        "shards polling error",
        "sharpei error",
        "unknown error",
        "yc hosts cache is empty",
    )

    # Registration outcome messages that are counted in the sharpei log.
    _REGISTRATION_MESSAGES = (
        "successfully registered",
        "user already registered",
        "user registration already in progress",
        "sharddb error during registration",
        "maildb error during registration",
        "shard is occupied by user",
        "both mdb and sharddb prepare failed",
        "illegal mdb and sharddb state",
    )

    # pa operations measured separately for every sharddb host.
    _PA_PER_HOST_OPERATIONS = ("get_user_data", "get_all_shards")

    # pa operations measured without a host breakdown.
    _PA_OPERATIONS = (
        "get_domain_shard_id",
        "get_master",
        "get_org_shard_id",
        "get_reg_data",
        "get_shard",
        "get_status",
        "ping",
        "registration",
    )

    # (first meter argument, counter name argument, histogram name argument);
    # the exact meaning of each argument is defined by the meter classes.
    _HTTP_CLIENT_TARGETS = (
        ("userinfo", "http_client_bb", "http_client_bb_userinfo"),
        ("hosted_domains", "http_client_bb", "http_client_bb_hosted_domains"),
        ("ticket", "http_client_tvm", "http_client_tvm_ticket"),
        ("tokens", "http_client_iam", "http_client_iam_tokens"),
        ("managed-postgresql", "http_client_yc_hosts", "http_client_yc_hosts"),
    )

    @staticmethod
    def get_endpoints():
        """Return a new list of the endpoint paths tracked individually."""
        return list(MailMetricsProvider._ENDPOINTS)

    @staticmethod
    def get_sharpei_log_errors_to_search_for():
        """Return a new list of the error messages counted in the sharpei log."""
        return list(MailMetricsProvider._LOG_ERRORS)

    @staticmethod
    def get_sharpei_registration_messages_to_search_for():
        """Return a new list of the registration messages counted in the sharpei log."""
        return list(MailMetricsProvider._REGISTRATION_MESSAGES)

    @staticmethod
    def get_access_log_metrics(name_prefix):
        """Build the common access log meters plus a counter and histogram per endpoint."""
        collected = [
            AccessLogCount(name_prefix),
            AccessLogCountByFirstStatusDigit(name_prefix),
            AccessLogHistogramWeak(ACCESS_LOG_BORDERS, "", name_prefix),
        ]

        for path in MailMetricsProvider.get_endpoints():
            collected.append(AccessLogRequestCounter(path, name_prefix))
            collected.append(AccessLogHistogram(ACCESS_LOG_BORDERS, path, name_prefix))

        return collected

    @staticmethod
    def get_pa_metrics(sharddb_hosts):
        """Build the pa histograms.

        Per-host histograms come first (all hosts for "get_user_data", then
        all hosts for "get_all_shards"), followed by the host-independent ones.
        ``sharddb_hosts`` must be iterable.
        """
        collected = []
        for operation in MailMetricsProvider._PA_PER_HOST_OPERATIONS:
            for ep in sharddb_hosts:
                collected.append(
                    PaLogHistogramForEndpoint(PA_BORDERS, operation, "pa", ep)
                )
        for operation in MailMetricsProvider._PA_OPERATIONS:
            collected.append(PaLogHistogram(PA_BORDERS, operation, "pa"))
        return collected

    @staticmethod
    def get_http_client_metrics():
        """Build a request counter followed by a total-time histogram for each target."""
        collected = []
        for key, counter_name, hist_name in MailMetricsProvider._HTTP_CLIENT_TARGETS:
            collected.append(HttpClientLogRequestCounter(key, counter_name))
            collected.append(
                HttpClientHttpRequestTotalTimeHist(HTTP_CLIENT_BORDERS, key, hist_name)
            )
        return collected

    @staticmethod
    def get_status_log_metrics(name_prefix, hosts):
        """Build two status counters per host for shard "sharddb": flag True, then False."""
        collected = []
        for host in hosts:
            # The original code names the True flag "alive" and the False flag "dead".
            for alive_flag in (True, False):
                collected.append(
                    StatusCounter(name_prefix, "sharddb", host, alive_flag)
                )
        return collected


def get_metrics_provider(type):
    """Return a metrics provider instance for the given provider type name.

    Args:
        type: One of "datasync_and_disk", "mail" or "cloud".

    Returns:
        A new DSAndDiskMetricsProvider, MailMetricsProvider or
        CloudMetricsProvider instance respectively.

    Raises:
        AssertionError: if ``type`` is not one of the supported names.
    """
    if type == "datasync_and_disk":
        return DSAndDiskMetricsProvider()
    if type == "mail":
        return MailMetricsProvider()
    if type == "cloud":
        return CloudMetricsProvider()
    # Raised explicitly instead of "assert False": an assert statement is
    # removed under "python -O", which would make this function return None.
    # The exception type is kept so existing callers see the same error class.
    raise AssertionError("unknown metrics provider type: %r" % (type,))


def parse_args():
    """Parse the command line and load the settings file it points to.

    The only command-line option is ``--cfg_path``, the path of a YAML file
    whose top-level mapping holds the settings.

    Returns:
        A namedtuple (type name "args") with one field per top-level key of
        the YAML file, so that settings can be read as attributes.
    """
    parser = ArgumentParser()
    # Nothing can be loaded without the path, so let argparse report the
    # missing option instead of failing later when the path is opened.
    parser.add_argument("--cfg_path", type=str, required=True)
    cli_args = parser.parse_args()
    settings = read_yaml(cli_args.cfg_path)
    return namedtuple("args", settings.keys())(**settings)


def make_sharpei_log_metrics(prefix, errors):
    """Return a list with one CountErrors meter per message in ``errors``.

    Each meter is created as ``CountErrors(message, prefix)``; the order of
    ``errors`` is preserved.
    """
    return [CountErrors(message, prefix) for message in errors]


def extract_hostlist(data):
    """Collect the meta database hosts listed for the "sharpei" module.

    Looks at every entry of ``data["config"]["modules"]["module"]`` whose
    ``_name`` is "sharpei" and gathers the ``host`` of each item found under
    ``configuration.meta_connection.endpoint_provider.hostlist``.

    Returns:
        A list of hosts; empty when no sharpei module declares a hostlist.

    Raises:
        KeyError: if any of the expected keys is missing.
    """
    hosts = []
    for module in data["config"]["modules"]["module"]:
        if module["_name"] != "sharpei":
            continue
        endpoint_provider = module["configuration"]["meta_connection"]["endpoint_provider"]
        if "hostlist" not in endpoint_provider:
            continue
        hosts.extend(entry["host"] for entry in endpoint_provider["hostlist"])
    return hosts


def read_yaml(path):
    """Open the file at ``path`` for reading and return what yaml.safe_load parses from it."""
    with open(path, "r") as stream:
        parsed = yaml.safe_load(stream)
    return parsed


def read_included(data):
    """Load the configs referenced by the value of an ``include`` key.

    Args:
        data: Either a single mapping with a ``_file`` key or a list of such
            mappings; every ``_file`` is the path of a YAML file.

    Returns:
        A list with the ``config`` section of every referenced file, in the
        order the files are listed.

    Raises:
        AssertionError: if ``data`` is neither a dict nor a list.
    """
    if isinstance(data, dict):
        data = [data]
    elif not isinstance(data, list):
        # Raised explicitly instead of "assert False": an assert statement is
        # removed under "python -O", which would make this function return None.
        raise AssertionError("unsupported 'include' value: %r" % (data,))
    return [read_yaml(elem['_file'])['config'] for elem in data]


def preprocess_includes(data):
    """Expand ``include`` keys of a parsed config in place.

    Walks nested dicts and lists. Whenever a dict has an 'include' key, the
    key is removed and the contents returned by read_included() for its value
    are merged into that dict with dict.update(), so later contents override
    earlier keys. Values that are neither dict nor list are left untouched.
    Returns None.
    """
    if isinstance(data, dict):
        # Iterate over a snapshot of the keys: the dict is modified below.
        for name in list(data):
            if name != 'include':
                preprocess_includes(data[name])
                continue
            fragments = read_included(data[name])
            del data[name]
            for fragment in fragments:
                data.update(fragment)
        return
    if isinstance(data, list):
        for item in data:
            preprocess_includes(item)


def wrap_for_lookup(foo):
    """Adapt an extractor function for use as the ``find`` callback of lookup_in_cfg.

    Args:
        foo: A function taking the parsed config and returning the value
            found in it.

    Returns:
        A one-argument function that calls ``foo(data)`` and
          * returns its result on success;
          * returns None when ``foo`` raises KeyError (the value is absent);
          * raises AssertionError carrying the original message when ``foo``
            raises any other Exception.
    """
    def wrapped(data):
        try:
            return foo(data)
        except KeyError:
            return None
        except Exception as e:
            # Raised explicitly instead of "assert False, ...": an assert
            # statement is removed under "python -O", which would silently
            # swallow the error and return None as if the value were absent.
            raise AssertionError(str(e))

    return wrapped


def lookup_in_cfg(cfg_path, find):
    """
    Searches for a value in the config at |cfg_path| and then in its base configs.
    Each file is read, its includes are expanded and |find| is invoked on the result.
    If the lookup succeeds, |find| should return the found value, which is
    immediately returned as the result of the entire search.
    If the lookup fails in the given file, |find| should return None (any falsy
    result is treated the same way) and the search continues in the file named
    by the "base" key of the current file.
    Returns None when the value is not found and there is no further base file.
    """
    cfg = read_yaml(cfg_path)
    preprocess_includes(cfg)
    found = find(cfg)
    if found:
        return found
    if "base" not in cfg:
        return None
    return lookup_in_cfg(cfg["base"], find)


def make_sharpei_config(cfg_path):
    """Collect the log file paths declared in the sharpei config.

    Every path is searched with lookup_in_cfg(), i.e. in the file at
    ``cfg_path`` first and then in its base configs:
      * the "http_client", "access", "sharpei" and "status" logs are taken
        from ``config.log.<name>.sinks[0].path``;
      * the profiler log is taken from ``configuration.pa.log`` of the
        module whose ``_name`` is "sharpei".

    Returns:
        A SharpeiConfig whose fields are the found paths joined to
        os.curdir.

    Note:
        When a path is not found the lookup yields None and the
        os.path.join() call fails.
    """
    def find_log_path(name):
        # Path of the first sink of the log called `name`.
        return lookup_in_cfg(cfg_path, wrap_for_lookup(
            lambda data: data["config"]["log"][name]["sinks"][0]["path"]
        ))

    def find_pa_log(data):
        # Profiler log of the first "sharpei" module, None if there is no such module.
        modules = data["config"]["modules"]["module"]
        return next(
            (m["configuration"]["pa"]["log"] for m in modules if m["_name"] == "sharpei"),
            None,
        )

    pa_log_path = lookup_in_cfg(cfg_path, wrap_for_lookup(find_pa_log))
    return SharpeiConfig(
        httpclient_log=os.path.join(os.curdir, find_log_path("http_client")),
        access_log=os.path.join(os.curdir, find_log_path("access")),
        sharpei_log=os.path.join(os.curdir, find_log_path("sharpei")),
        profiler_log=os.path.join(os.curdir, pa_log_path),
        status_log=os.path.join(os.curdir, find_log_path("status")),
    )


# Log file paths consumed by main(), one field per log:
#   httpclient_log - HTTP client log
#   access_log     - access log
#   sharpei_log    - sharpei log
#   profiler_log   - profiler ("pa") log
#   status_log     - status log
SharpeiConfig = namedtuple(
    "SharpeiConfig",
    "httpclient_log access_log sharpei_log profiler_log status_log",
)


def main():
    """Build the log readers with their meters and pass them to sharpei_unistat.run().

    Steps:
      1. Load the settings (parse_args) and change into ``args.dir``.
      2. Read the log paths and the sharddb host list from the sharpei config.
      3. Ask the provider selected by ``args.type`` for its meters and create
         one log reader per log that has any meters. The status log reader is
         created whenever the host list is not empty.
      4. Call sharpei_unistat.run() with ``args.host``, ``args.port``, the
         readers and ``args.log``.

    Each provider getter is called once and its result is reused, so meter
    objects are not built twice.
    """
    args = parse_args()
    log.info("chdir %s", os.path.abspath(args.dir))
    os.chdir(args.dir)

    sharpei_config = make_sharpei_config(args.sharpei_config_path)

    # lookup_in_cfg() returns None when no config in the chain declares a
    # hostlist; normalize to an empty list so providers can iterate over it.
    sharddb_hosts = lookup_in_cfg(
        args.sharpei_config_path, wrap_for_lookup(extract_hostlist)
    ) or []

    fast_forward = not args.from_beginning

    provider = get_metrics_provider(args.type)

    logs_list = []

    error_messages = provider.get_sharpei_log_errors_to_search_for()
    if error_messages:
        registration_messages = provider.get_sharpei_registration_messages_to_search_for()
        sharpei_log = sharpei_unistat.SharpeiLog(
            [],
            make_sharpei_log_metrics("sharpei_errors", error_messages)
            + make_sharpei_log_metrics("sharpei_registration", registration_messages),
            fast_forward,
            sharpei_config.sharpei_log,
        )
        logs_list.append(sharpei_log)

    access_log_metrics = provider.get_access_log_metrics("access_log")
    if access_log_metrics:
        access_log = AccessLogTskv(
            [],
            access_log_metrics,
            fast_forward,
            sharpei_config.access_log,
        )
        logs_list.append(access_log)

    http_client_metrics = provider.get_http_client_metrics()
    if http_client_metrics:
        http_client_log = sharpei_unistat.SharpeiHttpClientLog(
            [],
            http_client_metrics,
            fast_forward,
            sharpei_config.httpclient_log,
        )
        logs_list.append(http_client_log)

    pa_metrics = provider.get_pa_metrics(sharddb_hosts)
    if pa_metrics:
        pa_log = sharpei_unistat.SharpeiPaLog(
            [], pa_metrics, fast_forward, sharpei_config.profiler_log
        )
        logs_list.append(pa_log)

    # Unlike the readers above, this one depends on the host list only, not
    # on whether the provider returns any status meters.
    if sharddb_hosts:
        status_log = sharpei_unistat.StatusLog(
            [],
            provider.get_status_log_metrics("polling_status", sharddb_hosts),
            fast_forward,
            sharpei_config.status_log,
        )
        logs_list.append(status_log)

    sharpei_unistat.run(args.host, args.port, logs_list, args.log, logLevel="info")


# Call main() only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
