#! /usr/bin/env python
# -*- coding: utf-8 -*-
#
# Provides: check_hbf

import argparse
from collections import OrderedDict
import errno
import json
import os
import re
import subprocess
import sys
import time
import urllib2

import psutil


# Either package counts as a valid agent installation (check_package stops
# at the first one that is reported as installed).
PACKAGE_NAMES = ("yandex-hbf-agent", "yandex-hbf-agent-static")
# Name searched for in process command lines by hbf_agent_cmd().
PROCESS_NAME = "yandex-hbf-agent"
# Agent configuration file and the configobj configspec it is validated with.
CONFIG_PATH = "/etc/yandex-hbf-agent/yandex-hbf-agent.conf"
CONFIGSPEC_PATH = "/usr/share/yandex-hbf-agent/yandex-hbf-agent.configspec"
# Fallback update period and random jitter (seconds) used when they cannot
# be read from the agent configuration.
DEFAULT_UPDATE_PERIOD = 50
DEFAULT_UPDATE_PERIOD_RANDOM = 10
# Service name emitted in the PASSIVE-CHECK output line.
CHECK_NAME = "check_hbf"
# Passed to `timetail -r`: matches the leading "YYYY-MM-DD HH:MM:SS"
# timestamp of an agent log line.
TIMETAIL_REGEXP = r"(^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2})"
# Default agent log file (overridable with --hbf-log-path); read via
# timetail on releases that do not use journalctl.
LOG_PATH = "/var/log/yandex-hbf-agent/yandex-hbf-agent.log"
# Persistent record of log lines that reported deletion of a protected
# chain; while it exists the protected-chains check keeps reporting CRIT.
DELETED_PROTECTED_CHAINS_LOG = \
    "/var/tmp/check_hbf_deleted_protected_chains.log"
# Each entry here with a "bridge" subdirectory is treated as a bridge device.
VIRTUAL_NET_DEVICES_PATH = "/sys/devices/virtual/net/"
# Absence of this path is treated as "br_netfilter module not loaded".
BR_NETFILTER_PATH = "/sys/module/br_netfilter"


class JugglerStatus(object):
    """Accumulator for the result of one Juggler passive check.

    Keeps the worst (highest) status reported so far and a "; "-separated
    list of all non-empty descriptions.  ``str()`` renders the
    ``PASSIVE-CHECK:<name>;<status>;<description>`` line, using "OK" as the
    description when nothing was reported.
    """

    def __init__(self, name):
        self.name = name
        self.status = 0
        self.description = ""

    def update(self, status, description=""):
        """Escalate to ``status`` if it is worse and append ``description``.

        Newlines in the description are escaped as a literal ``\\n`` so the
        output stays on a single line.
        """
        text = str(description).replace("\n", r"\n")
        # max() keeps the current value unless the new status is strictly
        # greater, exactly like an explicit "if status > self.status" test.
        self.status = max(self.status, status)
        if not text:
            return
        pieces = [self.description, text] if self.description else [text]
        self.description = "; ".join(pieces)

    def __str__(self):
        return "PASSIVE-CHECK:{};{};{}".format(
            self.name, self.status, self.description or "OK")


# Prefix of stderr lines that check_notrack ignores: a notice about waiting
# for the xtables lock (the iptables commands there are run with -w), not an
# actual error.
ANOTHER_APP = "Another app is currently holding the xtables lock"


def check_notrack(s):
    """Verify the raw table contains the expected NOTRACK rules.

    For both ``iptables`` and ``ip6tables`` the raw table is dumped with
    ``sudo -n <tool> -S -t raw -w``.  Each of the following raises ``s`` to
    CRIT (2): a non-zero exit code, stderr output other than xtables-lock
    notices, empty output, or a NOTRACK rule count outside 2..4.
    """
    for tool in ("iptables", "ip6tables"):
        argv = ["sudo", "-n", tool, "-S", "-t", "raw", "-w"]
        shown = " ".join(argv)
        proc = subprocess.Popen(argv, stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE)
        out, err = proc.communicate()
        if proc.returncode != 0:
            s.update(2, "`{}` returned {}".format(shown, proc.returncode))
        # Lock-waiting notices are expected noise when running with -w.
        meaningful_err = "\n".join(
            line for line in err.splitlines()
            if not line.startswith(ANOTHER_APP)
        )
        if meaningful_err:
            s.update(2, "`{}` stderr: {}".format(shown, meaningful_err))
        if not out:
            s.update(2, "Empty output from `{}`".format(shown))
            continue
        rules = count_notrack(out)
        if not 2 <= rules <= 4:
            s.update(2, "{} NOTRACK rules listed by `{}`".format(rules, shown))


def count_notrack(out):
    """Return the number of ``-A ... -j NOTRACK`` rules in ``iptables -S`` output."""
    return sum(
        1 for rule in out.splitlines()
        if rule.startswith("-A") and rule.endswith("-j NOTRACK")
    )


def check_package(s):
    """Report CRIT unless at least one of PACKAGE_NAMES is installed.

    Returns as soon as any package is found installed.  Otherwise every
    problem gathered for every package is applied to ``s``.
    """
    pending = []
    for package in PACKAGE_NAMES:
        ok, problems = _check_package(package)
        if ok:
            return
        pending += problems
    for status, description in pending:
        s.update(status, description)


def dpkg_stderr_junk(l):
    """Return True for dpkg-query stderr lines that are harmless noise.

    These are ``dpkg-query: warning:`` lines and lines reading
    "missing description" (leading whitespace allowed).
    """
    if l.startswith("dpkg-query: warning:"):
        return True
    return l.strip().startswith("missing description")


def _check_package(name):
    """Query dpkg for package ``name``.

    Returns ``(installed, problems)``.  ``problems`` is a list of
    ``(status, description)`` tuples, and ``installed`` is True exactly
    when that list is empty.  A problem is recorded for each of these: a
    non-zero dpkg-query exit code, stderr output that is not noise per
    dpkg_stderr_junk(), and a status other than "install ok installed".
    """
    proc = subprocess.Popen(
        ["dpkg-query", "--show", "--showformat", "${Status}", name],
        stdout=subprocess.PIPE, stderr=subprocess.PIPE,
    )
    out, err = proc.communicate()
    meaningful_err = "\n".join(
        line for line in err.splitlines() if not dpkg_stderr_junk(line)
    )
    problems = []
    if proc.returncode != 0:
        problems.append((2, "dpkg-query returned " + str(proc.returncode)))
    if meaningful_err:
        problems.append((2, meaningful_err))
    if out != "install ok installed":
        problems.append((2, "Package status: " + (out or "<empty>")))
    return not problems, problems


def hbf_agent_cmd(cmd):
    """Return True if command line ``cmd`` (a non-empty list) is the HBF agent.

    Two forms match: a python interpreter whose full command line mentions
    PROCESS_NAME, or an executable whose path ends with PROCESS_NAME (the
    static build).
    """
    executable = cmd[0]
    via_python = "python" in executable and PROCESS_NAME in " ".join(cmd)
    return via_python or executable.endswith(PROCESS_NAME)


def check_process(s):
    """Check that an HBF agent process is running with sane memory usage.

    Scans all processes and takes the first one whose command line
    satisfies hbf_agent_cmd().  Reports CRIT (2) if no such process exists
    or if its RSS exceeds 500 MiB.  Reports WARN (1) if its memory
    information cannot be read.

    Returns:
        True if a matching process was found, False otherwise.
    """
    # Older psutil releases only provide get_memory_info().
    if hasattr(psutil.Process(), "memory_info"):
        memory_attr_name = "memory_info"
    else:
        memory_attr_name = "get_memory_info"

    found = False
    memory = None
    for proc in psutil.process_iter():
        try:
            proc_data = proc.as_dict(attrs=["cmdline", memory_attr_name])
        except psutil.NoSuchProcess:
            continue
        cmdline = proc_data["cmdline"]
        if cmdline and hbf_agent_cmd(cmdline):
            found = True
            # NOTE(review): the key is "memory_info" for both attribute
            # names, since old psutil's as_dict() strips the "get_" prefix.
            # Confirm this if psutil compatibility is revisited.
            memory = proc_data["memory_info"]
            break

    if not found:
        s.update(2, "Process `{}` not found".format(PROCESS_NAME))
    elif memory is None:
        # as_dict() substitutes ad_value (None) when reading an attribute is
        # denied.  Dereferencing .rss here would crash the whole check.
        s.update(1, "Unable to get memory info of `{}`".format(PROCESS_NAME))
    elif memory.rss > 500 * 1024**2:
        s.update(2, "{} RSS is {} MiB".format(PROCESS_NAME,
                                              memory.rss / 1024**2))
    return found


def to_juggler(agent_status):
    """Convert the agent's /status payload into Juggler terms.

    Returns a shallow copy of ``agent_status`` whose "status" string
    ("OK"/"WARN"/"CRIT") is replaced by a numeric Juggler level.

    Raises:
        ValueError: if "status" is not a known value or "last_update" is not
            an int.
    """
    converted = dict(agent_status)
    level = agent_status["status"]
    # Agent CRIT is intentionally reported as WARN (1) until NOC improves
    # server performance.
    for known, juggler_level in (("OK", 0), ("WARN", 1), ("CRIT", 1)):
        if level == known:
            converted["status"] = juggler_level
            break
    else:
        raise ValueError("Invalid `status`: " + str(level))

    last_update = agent_status["last_update"]
    if not isinstance(last_update, int):
        raise ValueError("Invalid `last_update`: " + str(last_update))

    return converted


def get_update_period(s):
    """Return the agent's maximum update interval in seconds.

    Reads ``update_period`` and ``update_period_random`` from the ``main``
    section of the agent config, validated against its configspec, and
    returns their sum.  On any error a WARN (1) is recorded on ``s`` and
    DEFAULT_UPDATE_PERIOD + DEFAULT_UPDATE_PERIOD_RANDOM is returned instead.
    """
    try:
        import configobj
        import validate
        c = configobj.ConfigObj(CONFIG_PATH, configspec=CONFIGSPEC_PATH)
        c.validate(validate.Validator())
        main_section = c["main"]
        # configobj leaves a value as a string when validation does not
        # convert it (key missing from the configspec, or failing
        # validation).  Without int(), "50" + "10" would yield "5010" and the
        # staleness limit derived from it would be meaningless.  A
        # non-numeric value now raises and falls back to the defaults.
        update_period = int(main_section.get("update_period",
                                             DEFAULT_UPDATE_PERIOD))
        update_period_random = int(main_section.get(
            "update_period_random", DEFAULT_UPDATE_PERIOD_RANDOM))
    except Exception as e:
        s.update(1, "Unable to get `update_period`: " + str(e))
        update_period = DEFAULT_UPDATE_PERIOD
        update_period_random = DEFAULT_UPDATE_PERIOD_RANDOM
    return update_period + update_period_random


def check_status(s):
    """Check the agent's self-reported status from its local HTTP endpoint.

    Fetches JSON from http://localhost:9876/status and converts it with
    to_juggler().  Reports CRIT (2) in these cases:
    - the endpoint cannot be read or the payload is invalid;
    - ``last_update`` is older than five update periods;
    - ``last_update`` lies in the future.
    Finally merges the agent's own status and optional ``desc`` into ``s``.
    """
    url = "http://localhost:9876/status"
    try:
        response = urllib2.urlopen(url, timeout=5)
        try:
            agent_status = json.loads(response.read())
        finally:
            # Release the connection even if reading or decoding fails.
            response.close()
        juggler_status = to_juggler(agent_status)
    except Exception as e:
        s.update(2, "Unable to fetch agent status: " + str(e))
        return

    last_update = juggler_status["last_update"]
    update_period = get_update_period(s)
    update_limit = update_period * 5
    diff = int(time.time()) - last_update
    if diff > update_limit:
        desc = "HBF agent updated it's status {} s ago (limit is {} s)"
        desc = desc.format(diff, update_limit)
        s.update(2, desc)
    if diff < 0:
        desc = "`last update` value is from the future (clock skew)"
        s.update(2, desc)

    # This line runs outside the try above.  Use .get() so a payload
    # without "desc" cannot crash the check before any output is printed.
    s.update(juggler_status["status"], juggler_status.get("desc", ""))


def get_protected_chains(s):
    """Return the collection of protected chain names, or None if unknown.

    First asks the installed agent (``hbfagent.agent.get_local_rules``) for
    the v4 and v6 local rules and returns the union of every chain listed in
    their ``protected_chains`` mappings.  If that fails, falls back to the
    ``protected_chains`` key of CONFIG_PATH parsed as JSON, which returns
    None when the key is absent.  Each failure records a WARN (1) on ``s``.

    NOTE(review): get_update_period parses CONFIG_PATH with configobj.
    Confirm the JSON fallback can ever succeed.
    """
    try:
        from hbfagent.agent import get_local_rules
        config_dir = os.path.dirname(CONFIG_PATH)
        configspec_dir = os.path.dirname(CONFIGSPEC_PATH)
        rule_sets = [get_local_rules(family, config_dir, configspec_dir)
                     for family in ("v4", "v6")]
        chains = set()
        for rules in rule_sets:
            by_key = rules.protected_chains
            for key in by_key:
                chains.update(by_key[key])
        return chains
    except Exception as err:
        s.update(1, "Unable to get protected chains: " + str(err))

    try:
        with open(CONFIG_PATH) as config_file:
            return json.load(config_file).get("protected_chains")
    except Exception as err:
        s.update(1, "Unable to get `protected_chains`: " + str(err))
        return None


def run_and_check(cmd, s):
    """Run ``cmd`` and return its stdout.

    Each of these records a CRIT (2) on ``s``: a non-zero exit code, any
    stderr output, and empty stdout.  The output is returned in every case.
    """
    shown = " ".join(cmd)
    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                            stderr=subprocess.PIPE)
    out, err = proc.communicate()
    problems = []
    if proc.returncode != 0:
        problems.append("`{}` returned {}".format(shown, proc.returncode))
    if err:
        problems.append("`{}` stderr: {}".format(shown, err))
    if not out:
        problems.append("Empty output from `{}`".format(shown))
    for problem in problems:
        s.update(2, problem)
    return out


def get_log_timetail(s, log_path):
    """Return recent lines of ``log_path`` via ``timetail`` (run through sudo)."""
    timetail_cmd = [
        "sudo", "-n", "timetail",
        "-r", TIMETAIL_REGEXP,
        "-n", "300",
        "-j", "100000",
        log_path,
    ]
    return run_and_check(timetail_cmd, s)


def get_log_journalctl(s):
    """Return the agent's journal entries from the last five minutes."""
    return run_and_check(
        ["journalctl", "SYSLOG_IDENTIFIER=yandex-hbf-agent", "-S", "-5m"], s
    )


def get_log(s, log_path):
    """Return recent agent log text, choosing the source by OS release.

    Releases whose ``lsb_release -cs`` codename is xenial, bionic or focal
    are read from journalctl.  Everything else is read from ``log_path``
    with timetail.
    """
    release = run_and_check(["lsb_release", "-cs"], s).strip()
    if release not in ("xenial", "bionic", "focal"):
        return get_log_timetail(s, log_path)
    return get_log_journalctl(s)


def check_protected_chains(s, log_path):
    """Report CRIT if the agent log shows a protected chain being deleted.

    Recent log lines that delete or flush a protected chain are written to
    DELETED_PROTECTED_CHAINS_LOG.  That file is only rewritten when new
    matches are found, so once a deletion is recorded the alert stays
    raised on later runs until the file is removed.

    Args:
        s: JugglerStatus-like object receiving updates.
        log_path: agent log file, used when the log is read via timetail.
    """
    protected_chains = get_protected_chains(s)
    if protected_chains is None:
        s.update(1, "Unable to get `protected_chains`")
        return
    if not protected_chains:
        return
    protected_chains = set(protected_chains)
    re_delete_chain = re.compile(r"Delete chain '(.*)'|"
                                 r"chain '(.*)' will be flushed/deleted")

    # Save log lines with deleted protected chains to the persistent file.
    log_text = get_log(s, log_path)
    matched_lines = []
    for line in log_text.splitlines():
        m = re_delete_chain.search(line)
        if m and (m.group(1) or m.group(2)) in protected_chains:
            matched_lines.append(line)
    if matched_lines:
        with open(DELETED_PROTECTED_CHAINS_LOG, "w") as saved:
            saved.write("\n".join(matched_lines) + "\n")

    # Read the saved lines back (including those from earlier runs) and set
    # the monitoring status.
    deleted_protected_chains = set()
    try:
        with open(DELETED_PROTECTED_CHAINS_LOG) as saved:
            for line in saved:
                m = re_delete_chain.search(line)
                if m:
                    deleted_protected_chains.add(m.group(1) or m.group(2))
    except IOError as e:
        # A missing file simply means no deletion has been recorded yet.
        if e.errno != errno.ENOENT:
            raise
    if deleted_protected_chains:
        # Sorted so the description is stable between runs.
        s.update(2, "Following protected chains were deleted: {}".format(
            ", ".join(sorted(deleted_protected_chains))
        ))


def check_br_netfilter(s):
    """Report CRIT if bridge devices exist while br_netfilter is not loaded.

    A device under VIRTUAL_NET_DEVICES_PATH counts as a bridge when it has a
    ``bridge`` subdirectory.  The module counts as loaded when
    BR_NETFILTER_PATH exists.
    """
    bridges = [
        dev for dev in os.listdir(VIRTUAL_NET_DEVICES_PATH)
        if os.path.exists(os.path.join(VIRTUAL_NET_DEVICES_PATH, dev, "bridge"))
    ]
    if not bridges or os.path.exists(BR_NETFILTER_PATH):
        return
    s.update(2, "You have bridge devices ({}), but 'br_netfilter'"
                " is not loaded".format(", ".join(bridges)))


def print_and_exit(s):
    """Print the rendered check result and exit with status 0.

    The exit code is always 0; the check status travels in the printed line.
    """
    print(s)
    sys.exit(0)


# Checks run by main(), in this order.  Each key names a module-level
# function "check_<key>" and gets a --with-<key>/--without-<key> switch pair
# from parse_args() (underscores become dashes).  The value says whether the
# check is enabled by default.
check_defaults = OrderedDict([
    ("package", True),
    ("process", True),
    ("status", True),
    ("br_netfilter", True),
    ("notrack", False),
    ("protected_chains", False)
])


def parse_args(argv=None):
    """Parse command-line arguments.

    Adds ``--hbf-log-path`` and, for every entry of ``check_defaults``, a
    ``--with-<check>`` / ``--without-<check>`` pair that stores into
    ``args.<check>``.  Only the switch that flips the default is shown in
    ``--help``; the other is accepted silently.

    Args:
        argv: argument list to parse; defaults to ``sys.argv[1:]``.
    """
    # A positional string here would become the program name ("prog");
    # this text is a description.
    p = argparse.ArgumentParser(description="check HBF agent")
    p.add_argument("--hbf-log-path", dest="log", default=LOG_PATH,
                   help="Path to hbf-agent log")

    for check_name, check_default in check_defaults.items():
        opt_name = check_name.replace("_", "-")
        with_options = ["--with-" + opt_name]
        without_options = ["--without-" + opt_name]
        if check_default:
            default_options, alt_options = with_options, without_options
            alt_help = "disable check 'check_{}'".format(check_name)
        else:
            # The old misspelled "--without<check>" (no dash) stays accepted
            # as an alias for backward compatibility.
            default_options = without_options + ["--without" + opt_name]
            alt_options = with_options
            alt_help = "enable check 'check_{}'".format(check_name)
        # The default-setting option must be added first.  argparse takes
        # the default for a shared dest from the first action that defines it.
        p.add_argument(*default_options, dest=check_name,
                       action="store_const", const=check_default,
                       default=check_default, help=argparse.SUPPRESS)
        p.add_argument(*alt_options, dest=check_name, action="store_const",
                       const=not check_default, help=alt_help)
    return p.parse_args(argv)


def main():
    """Run every enabled check and print a single PASSIVE-CHECK line.

    If a check raises, the failure is recorded as CRIT (2) and the remaining
    checks still run.  Without this, an uncaught exception would abort the
    script before any result line is printed.
    """
    args = parse_args()

    # jctl breaks with non-ascii data
    os.environ["LC_ALL"] = "C"

    s = JugglerStatus(CHECK_NAME)

    module_namespace = globals()
    for check_name in check_defaults:
        if not getattr(args, check_name):
            continue
        check = module_namespace["check_" + check_name]
        # Only the protected-chains check needs the log path.
        if check_name == "protected_chains":
            check_args = (s, args.log)
        else:
            check_args = (s,)
        try:
            check(*check_args)
        except Exception as e:
            s.update(2, "check_{} failed: {}".format(check_name, e))

    print_and_exit(s)


# Run the enabled checks only when executed as a script, not on import.
if __name__ == "__main__":
    main()
