#!/usr/bin/env python3

""" Manage iptables&ipset rules """

from collections import defaultdict
import socket
from itertools import chain
import logging
import sys
import re
import subprocess
import argparse
import requests


# Conductor API endpoints (cached mirror).
C_URL = "https://c.yandex-team.ru/api-cached"
MYGROUP_API_URL = "{}/generator/aggregation_group?fqdn=".format(C_URL)
GROUPS2HOSTS_URL = "{}/groups2hosts/".format(C_URL)

IPT_CHAIN_NAME = "INPUT"
# Reject TCP connections to port {0} from any source not in ipset {1}.
RULE_TEMPLATE = (
    "-p tcp --dport {} -m set ! --match-set {} src"
    " -j REJECT --reject-with tcp-reset"
)

IPSET_NAME_MAX_LEN = 31  # maximum length of an ipset set name
TIMEOUT = 5  # timeout passed to requests.get(), seconds

def call_cmd(*args):
    """ Run a command without a shell and return its exit status.

    Every positional argument is an iterable of argv tokens; they are
    concatenated in order to form the final argv. Output is not captured.
    """
    argv = [token for part in args for token in part]
    return subprocess.call(argv)

def call_shell(*args):
    """ Run a command line through the shell and return its exit status.

    The tokens of all positional iterables are joined with single spaces,
    so callers may embed shell syntax such as redirections ("2>/dev/null").
    """
    command_line = " ".join(token for part in args for token in part)
    return subprocess.call(command_line, shell=True)


def get_my_group():
    """ Ask MYGROUP_API_URL which group this host's FQDN belongs to.

    Returns the raw response body on HTTP 200, and None on any other
    status code or on a requests-level error (connection, timeout, ...).
    """
    lookup_url = MYGROUP_API_URL + socket.getfqdn()
    try:
        response = requests.get(lookup_url, timeout=TIMEOUT)
        is_ok = response.status_code == requests.codes.ok  # pylint: disable=E1101
        return response.text if is_ok else None
    except requests.exceptions.RequestException:
        return None


def get_my_neighbors(my_group):
    """ Get addresses of all my neighbors from GROUPS2HOSTS_URL

    The endpoint is expected to return one host name per line; every host
    is resolved with getaddrinfo() and its addresses are grouped by
    address family.

    :param my_group: group name appended to GROUPS2HOSTS_URL
    :returns: defaultdict mapping address family -> set of IP strings, or
        None when the host list cannot be fetched or any host fails to
        resolve.
    """
    logger = logging.getLogger(__name__)
    try:
        req = requests.get(GROUPS2HOSTS_URL + my_group, timeout=TIMEOUT)
        if req.status_code != requests.codes.ok:      # pylint: disable=E1101
            return None
        hosts_list = req.text
    except requests.exceptions.RequestException:
        return None

    neighbors = defaultdict(set)
    for line in hosts_list.splitlines():
        host = line.strip()
        if not host:
            # Tolerate blank / whitespace-only lines in the response.
            continue
        try:
            records = socket.getaddrinfo(host, 0, 0, socket.SOCK_STREAM)
        except (OSError, UnicodeError) as err:
            # Do not return a partial list: a neighbor missing from the
            # allow-list would be locked out by the REJECT rule.
            logger.error("Failed to resolve neighbor %s: %s", host, err)
            return None
        # Each record is (family, socktype, proto, canonname, sockaddr);
        # sockaddr[0] is the address string for both IPv4 and IPv6.
        # see https://docs.python.org/3/library/socket.html#socket.getaddrinfo
        for family, _, _, _, sockaddr in records:
            neighbors[family].add(sockaddr[0])

    return neighbors


def make_ipset_name(group, s_proto_name):
    """ Build a safe ipset set name "<group>-<s_proto_name>".

    ipset names are limited to IPSET_NAME_MAX_LEN characters, and one more
    character is reserved for the "_" prefix of the temporary set. Longer
    names are therefore cut to their tail and prefixed with "s_". Any
    character outside [a-zA-Z0-9_-] is then replaced by "_".
    """
    full_name = "{}-{}".format(group, s_proto_name)
    logging.debug("ipset long name: %s, len: %d", full_name, len(full_name))
    budget = IPSET_NAME_MAX_LEN - 1
    prefix = "s_"
    if len(full_name) > budget:
        # Keep the tail so the family suffix ("-4"/"-6") survives.
        tail_len = budget - len(prefix)
        full_name = prefix + full_name[len(full_name) - tail_len:]
    return re.sub(r"[^a-zA-Z0-9_\-]", "_", full_name)


def make_iptables_rules(group, neighbors, port=1334):
    """ Create proper iptables rules

    For every known address family in *neighbors*, this does the following:
    - fill a temporary ipset with the neighbor addresses;
    - swap it with the live set (created on demand);
    - make sure the REJECT rule built from RULE_TEMPLATE is in
      IPT_CHAIN_NAME.

    :param group: group name used to derive the ipset names
    :param neighbors: mapping address family -> iterable of IP strings
    :param port: TCP destination port to protect
    """
    logger = logging.getLogger(__name__)
    short_proto_names = {socket.AF_INET: "4", socket.AF_INET6: "6"}
    iptables_name = {socket.AF_INET: "iptables", socket.AF_INET6: "ip6tables"}
    family_name = {socket.AF_INET: "inet", socket.AF_INET6: "inet6"}

    for family in neighbors:
        if family not in short_proto_names:
            logger.error("Unknown family: %s", family)
            continue

        ipset_name = make_ipset_name(group, short_proto_names[family])
        rule_spec = RULE_TEMPLATE.format(port, ipset_name)
        check_rule = " ".join(
            (iptables_name[family], "-C", IPT_CHAIN_NAME, rule_spec,
             "2>/dev/null"))
        add_rule = " ".join(
            (iptables_name[family], "-A", IPT_CHAIN_NAME, rule_spec))

        tmp_name = "_" + ipset_name
        # A temporary set left over from an interrupted run would make
        # "create" fail, and its stale addresses would then be swapped live.
        call_shell(["ipset", "destroy", tmp_name, "2>/dev/null"])
        call_cmd(["ipset", "create", tmp_name, "hash:ip", "family",
                  family_name[family]])
        for host in neighbors[family]:
            call_cmd(["ipset", "add", tmp_name, host])

        # Create main table to be sure it exists
        call_cmd(["ipset", "-!", "create", ipset_name, "hash:ip", "family",
                  family_name[family]])

        swapped = call_cmd(["ipset", "swap", tmp_name, ipset_name]) == 0
        call_cmd(["ipset", "destroy", tmp_name])
        if not swapped:
            # The live set may be empty (just created); adding the REJECT
            # rule now could cut off every neighbor.
            logger.error("Failed to swap ipset %s, rule left untouched",
                         ipset_name)
            continue

        if call_shell([check_rule]) == 0:
            logger.debug("Rule already found in chain %s", IPT_CHAIN_NAME)
        else:
            logger.debug("Rule not found in chain %s, adding it.",
                         IPT_CHAIN_NAME)
            call_shell([add_rule])


def remove_iptables_rules(group, neighbors, port=1334):
    """ Delete the REJECT rule for *port* for every known address family.

    Families whose ipset does not exist are skipped. The rule is deleted
    only if "iptables -C" reports that it is present. The ipsets
    themselves are left in place.
    """
    logger = logging.getLogger(__name__)
    suffixes = {socket.AF_INET: "4", socket.AF_INET6: "6"}
    binaries = {socket.AF_INET: "iptables", socket.AF_INET6: "ip6tables"}

    for family in neighbors:
        if family not in suffixes:
            logger.error("Unknown family: %s", family)
            continue

        set_name = make_ipset_name(group, suffixes[family])
        list_status = call_shell(
            ["ipset", "list", set_name, ">/dev/null", "2>/dev/null"])
        if list_status > 0:
            continue

        binary = binaries[family]
        rule_spec = RULE_TEMPLATE.format(port, set_name)
        check_status = call_shell(
            [binary, "-C", IPT_CHAIN_NAME, rule_spec, "2>/dev/null"])
        if check_status != 0:
            logger.debug("No rule found in chain %s", IPT_CHAIN_NAME)
            continue

        logger.debug("Rule found, delete it")
        call_shell([" ".join((binary, "-D", IPT_CHAIN_NAME, rule_spec))])

def monrun_check_iptables_rules(group, neighbors, port=1334):
    """ Check if rule is in iptables

    Prints one monrun-style "<status>;<message>" line. Status 2 means an
    ipset or a rule is missing, and status 0 means every rule was found.

    :param group: group name used to derive the ipset names
    :param neighbors: mapping address family -> addresses (only keys used)
    :param port: TCP destination port the rule protects
    :returns: 0 when all rules are present, 1 otherwise
    """
    logger = logging.getLogger(__name__)
    short_proto_names = {socket.AF_INET: "4", socket.AF_INET6: "6"}
    iptables_name = {socket.AF_INET: "iptables", socket.AF_INET6: "ip6tables"}

    rule_not_found = False
    for family in neighbors:
        if family not in short_proto_names:
            logger.error("Unknown family: %s", family)
            continue
        ipset_name = make_ipset_name(group, short_proto_names[family])
        if call_shell(["ipset", "list", ipset_name, ">/dev/null",
                       "2>/dev/null"]) > 0:
            print("2;Ipset {} not found".format(ipset_name))
            return 1
        if call_shell([iptables_name[family], "-C", IPT_CHAIN_NAME,
                       RULE_TEMPLATE.format(port, ipset_name),
                       "2>/dev/null"]) != 0:
            # Keep checking the remaining families (an ipset may be missing).
            rule_not_found = True

    if rule_not_found:
        print("2;Rule not found in iptables")
        return 1
    print("0;Rule found in iptables")
    return 0

def main():
    """ Parse command-line arguments and run the requested action.

    Exit codes:
        1 - group name could not be determined (or the action is unknown),
        2 - no neighbors were found for the group,
        3 - refused to close a dangerous port without --force.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("action", choices=["close", "open", "check", "status"])
    parser.add_argument("port", type=int, nargs="?", default=1334)
    parser.add_argument(
        "-f", "--force",
        help="Force to close even dangerous ports", action="store_true")
    parser.add_argument(
        "-g", "--group", dest="c_group",
        help="Use specified conductor group")
    args = parser.parse_args()

    logger = logging.getLogger(__name__)
    # An empty --group value falls back to auto-detection, as does no value.
    my_group = args.c_group or get_my_group()
    if my_group is None:
        logger.error("Failed to get my group name")
        sys.exit(1)
    logger.info("My group is: %s", my_group)

    my_neighbors = get_my_neighbors(my_group)
    if not my_neighbors:
        logger.error("No neighbors found")
        sys.exit(2)

    handlers = {
        "close": make_iptables_rules,
        "open": remove_iptables_rules,
        "check": monrun_check_iptables_rules,
        "status": monrun_check_iptables_rules,
    }
    handler = handlers.get(args.action)
    if handler is None:
        logger.error("Unknown action %s", args.action)
        sys.exit(1)

    dangerous_ports = (22,)
    if args.action == "close" and args.port in dangerous_ports:
        logger.error("Too dangerous to operate on port %d", args.port)
        if not args.force:
            sys.exit(3)
        logger.error("But I am forced to do that")

    handler(my_group, my_neighbors, args.port)


if __name__ == "__main__":
    main()
