import argparse
import logging

import iptables

import ipaddr
from config import Config


# Module logger, named after the last dotted component of this module's
# import path.  A no-op handler is attached to it here; main() sets up the
# real logging output via logging.basicConfig().
log = logging.getLogger(__name__.rpartition(".")[2])
log.addHandler(logging.NullHandler())


# Rule targets that do_trace() never descends into as if they were chains.
SPECIAL_TARGETS = (
    'LOG', 'NFLOG',
    'ACCEPT', 'DROP', 'REJECT', 'RETURN',
)

# Width of one nesting level in the printed trace, and the text repeated
# once per level: a dot followed by padding up to INDENT_STEP characters.
INDENT_STEP = 4
INDENT_STRING = '.' + ' ' * (INDENT_STEP - 1)

# Chains tried, in this order, as the starting point when no chain is given.
ROOT_CHAINS = [
    "INPUT",
    "Y_FW",
    "OUTPUT",
    "Y_FW_OUT",
    "FORWARD",
]


class IPTablesTraceException(Exception):
    """Raised when a trace cannot be started or continued.

    Used for a table or chain name that is not present in the loaded
    ruleset, and for a ruleset in which no default root chain exists.
    """


class SGR:
    """Terminal escape sequences used to colour the printed trace.

    Every attribute is a ready-to-print string made of the ``csi`` prefix
    followed by a numeric code and ``m``.
    """

    # Common prefix of every sequence below (ESC followed by '[').
    csi = '\033['

    # Text attribute.
    bold = csi + '1m'

    # Foreground colours.
    red = csi + '31m'
    green = csi + '32m'
    yellow = csi + '33m'
    blue = csi + '34m'

    # Appended by paint() after the coloured text.
    reset = csi + '0m'


# Colour used by main() for a matched rule, keyed by the rule's target.
# A target that is not listed here gets main()'s default colour instead.
TRACE_COLORS = {
    'ACCEPT': SGR.green,
    # Logging targets share one colour.
    'LOG': SGR.yellow, 'NFLOG': SGR.yellow,
    # DROP and REJECT share one colour.
    'DROP': SGR.red, 'REJECT': SGR.red,
}


def paint(color, string):
    """Return ``string`` prefixed with ``color`` and followed by ``SGR.reset``.

    Args:
        color: escape sequence to put in front of the text (e.g. ``SGR.red``).
        string: the text to colour.
    """
    colored = color + string
    return colored + SGR.reset


def do_trace(table_obj, chain, level=0, dump_only=False):
    """Walk ``chain`` and the chains it leads to, collecting rules to display.

    The packet description (``ip``, ``dest_ip``, ``sport``, ``dport``,
    ``proto``), the table name and the display mode (``full`` / ``debug``)
    are read from the module-level ``args`` that main() sets.

    Args:
        table_obj: object indexable by table name and then by chain name,
            yielding the rules of that chain in order.  Each rule must
            provide ``match(...)``, ``jump`` and ``goto``.
        chain: name of the chain to start in.  When falsy, the first entry
            of ROOT_CHAINS that exists in the table is used.
        level: nesting depth stored with every collected rule.
        dump_only: when true the rules of this chain (and of everything
            reached from it) are listed without being matched; used in
            full-tree mode for chains the packet does not enter.

    Returns:
        A list of ``(level, rule, match)`` tuples in traversal order, where
        ``match`` is the (truthy or falsy) match result for that rule.

    Raises:
        IPTablesTraceException: the table or the chain does not exist, or
            no chain was given and none of ROOT_CHAINS is in the table.
    """
    if args.table not in table_obj:
        msg = "Table {!r} is not in iptables."
        raise IPTablesTraceException(msg.format(args.table))
    table = table_obj[args.table]

    if not chain:
        for chain in ROOT_CHAINS:
            if chain in table:
                break
        else:
            # Wrap in a 1-tuple so the list is always formatted as a single
            # value, whatever sequence type ROOT_CHAINS happens to be.
            raise IPTablesTraceException(
                "No root chain found (expecting one of %s)" % (ROOT_CHAINS,))
    if chain not in table:
        msg = "Chain {!r} is not in {!r} table."
        raise IPTablesTraceException(msg.format(chain, args.table))

    trace = []
    for rule in table[chain]:
        matched = not dump_only and rule.match(
            source_ip=args.ip, dest_ip=args.dest_ip,
            sport=args.sport, dport=args.dport, proto=args.proto)

        if matched or args.full or args.debug:
            trace.append((level, rule, matched))

        # A matched RETURN ends the walk of this chain.
        if matched and rule.jump == "RETURN":
            break

        if matched or args.full:
            target = rule.jump or rule.goto  # target can be None (no -j or -g)
            if target and target not in SPECIAL_TARGETS:
                # The sub-chain is only dumped (not matched) when this rule
                # did not match.  The flag is deliberately local to this
                # call: a rule that does not match must not stop the
                # following rules of the current chain from being matched.
                # ``matched`` can only be truthy when dump_only is false,
                # so ``not matched`` also covers the inherited dump_only.
                trace.extend(
                    do_trace(table_obj, target, level + 1, not matched))

        # A matched goto does not come back to this chain.
        if matched and rule.goto:
            break
    return trace


def parse_args():
    """Parse the command line and return the resulting namespace.

    The namespace carries: ``ip``, ``dest_ip``, ``table``, ``chain``,
    ``dport``, ``sport``, ``proto``, ``log_level``, ``file``, ``debug`` and
    ``full``.  ``--debug`` and ``--full-tree`` cannot be combined.
    """
    cli = argparse.ArgumentParser()

    # Addresses of the packet being traced.
    cli.add_argument('ip',
                     type=ipaddr.IPAddress,
                     help='source IP address to trace in iptables tree')
    cli.add_argument('--dest-ip',
                     type=ipaddr.IPAddress,
                     help='destination IP address to trace in iptables tree')

    # Where the trace starts.
    cli.add_argument('-t', '--table',
                     default='filter',
                     help='table name to trace in (default: %(default)s)')
    root_chain_names = " / ".join(ROOT_CHAINS)
    cli.add_argument('-c', '--chain',
                     help='Chain name to begin trace in (default: %s)' % (root_chain_names))

    # Ports and protocol of the packet being traced.
    cli.add_argument('-p', '--dport',
                     type=int,
                     help='destination port number')
    cli.add_argument('--sport',
                     type=int,
                     help='source port number')
    cli.add_argument('--proto',
                     choices=('tcp', 'udp', 'icmp', 'ipv6-icmp', 'ipencap', 'any'),
                     default='any',
                     help='protocol name (default: %(default)s)')

    # Logging verbosity and ruleset source.
    cli.add_argument('--log-level',
                     choices=('debug', 'info', 'warning', 'error', 'critical'),
                     default='error',
                     help='log level (default: %(default)s)')
    cli.add_argument('--file',
                     help='Load iptables dump from given file')

    # Output modes; at most one of them may be given.
    output_mode = cli.add_mutually_exclusive_group()
    output_mode.add_argument('--debug',
                             action='store_true',
                             help='show content of all visited chains')
    output_mode.add_argument('--full-tree',
                             dest='full',
                             action='store_true',
                             help='show entire tree')

    return cli.parse_args()


def main():
    """Command-line entry point: trace a packet and print the visited rules.

    Parses the command line into the module-level ``args`` (which do_trace()
    reads), configures logging, builds the iptables model either from the
    dump file given with ``--file`` or from the currently loaded ruleset,
    runs the trace and prints one line per collected rule, indented by its
    nesting level and coloured when the rule matched.

    Raises:
        ValueError: ``args.log_level`` does not name a numeric ``logging``
            level.
    """
    global args
    args = parse_args()
    config = Config()
    use_yandex_iptables = config.use_yandex_iptables()

    numeric_level = getattr(logging, args.log_level.upper(), None)
    if not isinstance(numeric_level, int):
        raise ValueError("Invalid log level: '{}'.".format(args.log_level))
    log_format = '%(asctime)s %(levelname)s: %(message)s'
    logging.basicConfig(level=numeric_level, format=log_format)

    dump = None
    if args.file:
        with open(args.file) as f:
            dump = f.read()

    # The address family of the source IP selects which ruleset is modelled.
    ip_version = "v4" if isinstance(args.ip, ipaddr.IPv4Address) else "v6"
    table_obj = iptables.IPTables(
        ip_version,
        rule_class=iptables.Rule,
        dump=dump,
        use_yandex_iptables=use_yandex_iptables,
    )

    # No dump file was read: load the currently active rules instead.
    if dump is None:
        table_obj.load_current(counters=True)

    # In the modes that also list unmatched rules, matched rules without a
    # dedicated colour are painted blue so they still stand out.
    if args.full or args.debug:
        default_color = SGR.blue
    else:
        default_color = SGR.reset

    for level, rule, match in do_trace(table_obj, args.chain):
        string = rule.string
        if match:
            color = TRACE_COLORS.get(rule.jump or rule.goto, default_color)
            string = paint(color, string)
        counters = " " + rule.counters if rule.counters else ""
        # Single parenthesised argument: behaves the same as a Python 2
        # print statement and as the Python 3 print() function.
        print(INDENT_STRING * level + string + counters)
