#!/usr/bin/env python
#
# Provides: check_tun
#
# Simple tunnel checker: sniffs tunnel interfaces with tcpdump and prints a PASSIVE-CHECK status line.

import subprocess
import threading
import argparse
import commands
import time

# Command.run sends SIGTERM to a timed-out command's children for this many
# rounds before it also starts sending SIGKILL.
KILL_TRIES = 5

# tcpdump probes keyed by "<family>_<flag>"; {0} is the interface name.
# Each probe stops after capturing a single matching packet (-c1).
# Byte 13 of the TCP header holds the flags: SYN = 0x02, RST = 0x04, PSH = 0x08.
# tcpdump's tcp[] accessor applies to IPv4 only, so for IPv6 the flags byte is
# addressed as ip6[40 + 13] = ip6[53]. That offset assumes no IPv6 extension
# headers. ip6[6] == 6 (Next Header = TCP) keeps non-TCP IPv6 traffic
# (ICMPv6, UDP, ...) from matching and ending the one-packet capture without
# a TCP "Flags" line.
CHECKS = {
    "ipv4_syn": "sudo tcpdump -c1 -i {0} 'tcp[13] & 2 != 0' 2>/dev/null",
    "ipv4_push": "sudo tcpdump -c1 -i {0} 'tcp[13] & 8 != 0' 2>/dev/null",
    "ipv4_rst": "sudo tcpdump -c1 -i {0} 'tcp[13] & 4 != 0' 2>/dev/null",
    "ipv6_syn": "sudo tcpdump -c1 -i {0} 'ip6[6] == 6 and ip6[53] & 2 != 0' 2>/dev/null",
    "ipv6_push": "sudo tcpdump -c1 -i {0} 'ip6[6] == 6 and ip6[53] & 8 != 0' 2>/dev/null",
    "ipv6_rst": "sudo tcpdump -c1 -i {0} 'ip6[6] == 6 and ip6[53] & 4 != 0' 2>/dev/null",
}

# Address-family modes for check_tunnel_cmd.
CHECK_IP_OR = 0     # a packet on IPv4 or on IPv6 is enough (default)
CHECK_IP_AND = 1    # packets must be seen on both IPv4 and IPv6
CHECK_IP6_ONLY = 2  # only IPv6 is probed
CHECK_IP4_ONLY = 3  # only IPv4 is probed


class Command(object):
    """Run a shell command in a worker thread and kill it if it outlives a timeout.

    After run() returns, ``stdout`` and ``stderr`` hold whatever the command
    wrote before it exited or was killed.
    """

    def __init__(self, cmd):
        self.cmd = cmd        # shell command line, executed with shell=True
        self.process = None   # subprocess.Popen handle, assigned by the worker thread
        self.stdout = ""
        self.stderr = ""

    def _run_thread(self):
        # Worker body: start the shell and block until it exits, collecting output.
        self.process = subprocess.Popen(self.cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        self.stdout, self.stderr = self.process.communicate()

    def run(self, timeout=3):
        """Run the command, waiting at most ``timeout`` seconds before killing it.

        While the command is still alive, each round sends SIGTERM (via
        ``sudo kill``) to the direct children of the shell process, then
        waits up to one second. After KILL_TRIES rounds, SIGKILL is sent as
        well. Returns once the worker thread has finished.
        """
        thread = threading.Thread(target=self._run_thread)
        thread.start()
        thread.join(timeout)

        tries = 0
        while thread.is_alive():
            # Re-read the child list every round: the shell may not have spawned
            # its child yet on the first round. The worker thread may not even
            # have assigned self.process yet.
            pids = find_pid(self.process.pid) if self.process is not None else []
            for pid in pids:
                sudo_kill(pid, "TERM")
            tries += 1
            if tries > KILL_TRIES:
                for pid in pids:
                    sudo_kill(pid, "KILL")
            # Bounded wait: an unbounded join() would block here forever if the
            # command ignores SIGTERM, and the SIGKILL branch would never run.
            thread.join(1)


def find_pid(proc):
    """Return the PIDs of the direct children of process ``proc``, as ints.

    Lines of the ``ps`` output that do not parse as integers are skipped.
    """
    # Goes through `commands` rather than the Command class: Command.run calls
    # find_pid, so using Command here would recurse.
    status, listing = commands.getstatusoutput(
        "ps --no-header --ppid {0} | awk '{{print $1}}'".format(proc)
    )
    children = []
    for field in listing.split("\n"):
        try:
            child = int(field)
        except ValueError:
            continue
        children.append(child)
    return children


def sudo_kill(pid, sig):
    """Send signal ``sig`` (a name such as "TERM" or "KILL") to ``pid`` with ``sudo kill``.

    Best effort: the exit status and output of the kill command are discarded.
    """
    kill_line = "sudo kill -{0} {1}".format(sig, pid)
    commands.getstatusoutput(kill_line)


def check_tunnel_addr(tun):
    """Return True if a network interface named exactly ``tun`` exists.

    Runs ``ip link show dev <tun>``, which exits non-zero for an unknown
    device. The name is passed as a single argv element (no shell), so it is
    matched exactly and case-sensitively. For example, "tun1" does not match
    an existing "tun10", and address text elsewhere in ``ip`` output cannot
    produce a match.

    Returns False for an empty name or when the ``ip`` binary cannot be run.
    """
    if not tun:
        return False
    try:
        proc = subprocess.Popen(["ip", "link", "show", "dev", tun],
                                stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except OSError:
        # ip is not installed / not on PATH: report the interface as not found.
        return False
    proc.communicate()
    return proc.returncode == 0


def check_tunnel_cmd(check_type, timeout, tun, cmd):
    """Sniff interface ``tun`` for a TCP packet of kind ``cmd`` ("syn", "push" or "rst").

    ``check_type`` chooses which address families must see such a packet:
    either one (CHECK_IP_OR), both (CHECK_IP_AND), IPv4 only
    (CHECK_IP4_ONLY) or IPv6 only (CHECK_IP6_ONLY). Families are probed in
    order IPv4 then IPv6 and combined with short-circuit or/and. Each probe
    gets ``timeout`` seconds.

    Raises Exception("Unknown check type") for any other ``check_type``.
    """
    def probe(family):
        # True when tcpdump printed a TCP line (containing "Flags") for this probe.
        sniffer = Command(CHECKS["{0}_{1}".format(family, cmd)].format(tun))
        sniffer.run(timeout=timeout)
        return "Flags" in sniffer.stdout

    plans = (
        (CHECK_IP_OR, lambda: probe("ipv4") or probe("ipv6")),
        (CHECK_IP_AND, lambda: probe("ipv4") and probe("ipv6")),
        (CHECK_IP4_ONLY, lambda: probe("ipv4")),
        (CHECK_IP6_ONLY, lambda: probe("ipv6")),
    )
    for kind, plan in plans:
        if check_type == kind:
            return plan()
    raise Exception("Unknown check type")


def main(args):
    """Check every tunnel named in ``args.t`` and print one passive-check line.

    The address-family mode is taken from the first set flag in the order
    -v4, -v6, -v46; with none set, a packet on either family suffices.

    For each tunnel, the first matching diagnosis wins:
    - the interface is missing;
    - a PUSH packet is seen (healthy);
    - only SYN is seen;
    - only RST is seen;
    - no packets are seen at all.

    Prints status 2 listing all problems, or status 0 when there are none.
    """
    priority = (("v4", CHECK_IP4_ONLY), ("v6", CHECK_IP6_ONLY), ("v46", CHECK_IP_AND))
    check_type = next((kind for flag, kind in priority if getattr(args, flag)), CHECK_IP_OR)

    def diagnose(tun):
        # Returns an error message for this tunnel, or None when it looks healthy.
        if not check_tunnel_addr(tun):
            return "No {0} tunnel interface was found".format(tun)
        if check_tunnel_cmd(check_type, args.d, tun, "push"):
            return None
        if check_tunnel_cmd(check_type, args.d, tun, "syn"):
            return "No PUSH packet, only SYN on {0}".format(tun)
        if check_tunnel_cmd(check_type, args.d, tun, "rst"):
            return "Only RST packets on {0}".format(tun)
        return "No packets were found on {0}".format(tun)

    problems = []
    for tun in args.t:
        verdict = diagnose(tun)
        if verdict is not None:
            problems.append(verdict)

    if problems:
        print("PASSIVE-CHECK:check_tun;2;problems: {0}".format(", ".join(problems)))
    else:
        print("PASSIVE-CHECK:check_tun;0;ok")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Tunnels checkers for L7 balancer")
    # action="append" together with nargs="+" accepts both "-t tun0 tun1" and the
    # documented "-t tun0 -t tun1". The plain store action kept only the values
    # of the last -t. The resulting list of lists is flattened below, so main()
    # still receives a flat list of interface names (default ["tun0"]).
    parser.add_argument("-t", type=str, nargs="+", action="append", default=None, help="Use -t tun0 -t tun1 to specify multiple arguments. Worst time to check interface without packets is 30 seconds.")
    parser.add_argument("-v4", action="store_true", help="Check only IPv4")
    parser.add_argument("-v6", action="store_true", help="Check only IPv6")
    parser.add_argument("-v46", action="store_true", help="Check if IPv4 and IPv6 both working(default IPv4 or IPv6)")
    parser.add_argument("-d", type=int, default=3, help="Timeout for external commands")

    parsed = parser.parse_args()
    parsed.t = [name for group in parsed.t for name in group] if parsed.t else ["tun0"]
    main(parsed)
