import collections
import logging
import re
import struct
import socket
from copy import deepcopy
from subprocess import Popen, PIPE

from config import Config

import ipaddr
import prctl

from util import enumerate_lines
from util import Timer
from util import ArgParserFast
from util import LazyOptionParser

from hbfagent import orlyutil
from hbfagent import jugglerutil


# ORLY operation name requested before applying a much smaller ruleset.
ORLY_RULE = 'hbf-apply-rules'

# Built-in chains of each netfilter table; any other chain is user-defined.
BUILTIN_CHAINS = dict(
    filter={"INPUT", "FORWARD", "OUTPUT"},
    nat={"PREROUTING", "INPUT", "OUTPUT", "POSTROUTING"},
    mangle={"PREROUTING", "INPUT", "FORWARD", "OUTPUT", "POSTROUTING"},
    raw={"PREROUTING", "OUTPUT"},
    security={"INPUT", "FORWARD", "OUTPUT"},
)

# Logger named after the last component of the dotted module name.
log = logging.getLogger(__name__.rpartition(".")[2])
log.addHandler(logging.NullHandler())


def builtin_chain(table, chain):
    """Return True if ``chain`` is a built-in chain of netfilter ``table``.

    Unknown tables have no built-in chains.
    """
    return chain in BUILTIN_CHAINS.get(table, ())


class IPTablesError(Exception):
    """Raised when parsing a dump or running iptables-save/restore fails."""


class SimpleRule(object):
    """Lightweight rule: the raw rule line plus the chain it belongs to.

    Two rules compare equal when their raw lines are identical.
    """
    __slots__ = ['string', 'chain']

    # Prefixes of the lines treated as rules by IPTables.parse_dump.
    first_options = ("-A", "-I", "-P")

    def __init__(self, line, **_):
        # Extra keyword arguments (IPTables.parse_dump passes ``counters=``)
        # are accepted and ignored.
        self.string = line
        # The chain is the second whitespace-separated token, e.g. "INPUT" in
        # "-A INPUT -j ACCEPT". Splitting (rather than slicing between the
        # first two spaces) keeps the last character when the chain is the
        # final token and tolerates repeated spaces. Empty if absent.
        tokens = line.split(None, 2)
        self.chain = tokens[1] if len(tokens) > 1 else ""

    def __eq__(self, other):
        # Any object exposing ``string`` (SimpleRule or Rule) is comparable;
        # anything else defers to the other operand.
        try:
            other_string = other.string
        except AttributeError:
            return NotImplemented
        return self.string == other_string

    def __ne__(self, other):
        equal = self.__eq__(other)
        if equal is NotImplemented:
            return equal
        return not equal

    def __hash__(self):
        # Equal rules must hash equally.
        return hash(self.string)

    def __str__(self):
        return self.string

    def str_delete(self):
        """Return this rule rewritten as a ``-D`` (delete) command line."""
        return "-D " + " ".join(self.string.split()[1:])


class PortRange(LazyOptionParser):
    """Port specification as used by ``--sport``/``--dports``/``--ports``.

    The value is a comma-separated list of single ports and inclusive
    ``first:last`` ranges. As in iptables, an omitted first port means 0 and
    an omitted last port means 65535 (``:1024``, ``1024:``). Parsing is
    presumably deferred by the ``LazyOptionParser`` base until
    ``ensure_parsed`` is called.
    """

    def parse(self, string):
        """Store the value as a list of inclusive ``(low, high)`` pairs."""
        self._items = []
        for item in string.split(","):
            item = item.strip()
            if ":" in item:
                first, last = item.split(":")
                low = int(first) if first else 0
                high = int(last) if last else 65535
                self._items.append((low, high))
            else:
                port = int(item)
                self._items.append((port, port))

    def __contains__(self, port):
        """Return True if ``port`` lies inside any parsed range."""
        self.ensure_parsed()
        return any(low <= port <= high for low, high in self._items)


class Project_id_match(LazyOptionParser):
    """``--u32 "<offset>&<mask>=<value>"`` match on a 32-bit address word.

    Only offsets 16 and 32 are accepted. In the IPv6 header these are byte 8
    of the source and of the destination address respectively, so 16 marks a
    source match (``is_src``). ``match`` compares bytes 8..11 of a packed
    address with ``value`` under ``mask``.
    """

    def parse(self, text):
        def convert_hex(val):
            # Decimal or 0x/0X-prefixed hex; surrounding whitespace allowed.
            val = val.strip()
            if val[:2].lower() == "0x":
                return int(val[2:], 16)
            return int(val)

        # The expression may arrive quoted (iptables-save quotes u32
        # arguments), so strip whitespace and quotes before matching.
        expr = text.strip().strip('"\'')
        m = re.match(r"([^&]*)&([^=]*)=(.*)$", expr)
        if not m:
            raise ValueError("invalid u32 match %r" % text)
        offset, self.mask, self.project_id = tuple(map(convert_hex, m.groups()))

        if offset not in (16, 32):
            raise ValueError("Unknown IPv6 offset %d" % offset)
        self.src = (offset == 16)

    def is_src(self):
        """Return True if the match applies to the source address."""
        self.ensure_parsed()
        return self.src

    def match(self, ip):
        """Return True if ``ip`` (needs a ``packed`` attribute) matches under the mask."""
        self.ensure_parsed()
        ip_proj = struct.unpack(">L", ip.packed[8:12])[0]
        return ip_proj & self.mask == self.project_id & self.mask


class Rule(object):
    ''' Rule structure

    The rule line is tokenised by ``rule_parser``. Options it does not know
    are kept in ``unknown_args``. Parsed options (``chain``, ``source``,
    ``dport``, ``u32``, ...) are exposed as attributes through
    ``__getattr__``.
    '''
    first_options = ("-A", "-I", "-P")
    rule_parser = ArgParserFast()
    rule_parser.add_argument(*first_options, dest='chain')
    rule_parser.add_argument('-s', '--source')
    rule_parser.add_argument('-d', '--destination')
    rule_parser.add_argument('-p', '--protocol')
    rule_parser.add_argument('-j', '--jump')
    rule_parser.add_argument('-g', '--goto')
    rule_parser.add_argument('-f', dest='fragment', action='store_true')
    rule_parser.add_argument('--ports', type=PortRange)
    rule_parser.add_argument('--sport', '--source-port', type=PortRange)
    rule_parser.add_argument('--sports', '--source-ports', type=PortRange)
    rule_parser.add_argument('--dport', '--destination-port', type=PortRange)
    rule_parser.add_argument('--dports', '--destination-ports', type=PortRange)
    rule_parser.add_argument('--state')
    rule_parser.add_argument('--u32', type=Project_id_match)
    rule_parser.add_argument('-m', dest='module', action='append')

    def __init__(self, line, counters=None):
        self.string = line
        self.counters = counters

        self.rule, self.unknown_args = \
            self.rule_parser.parse_known_args(line.split())

    def __getattr__(self, name):
        # Only reached when normal lookup fails. Read ``rule`` straight from
        # the instance dict: on an instance created without __init__ (as
        # copy/pickle do), ``self.rule`` would re-enter __getattr__ and
        # recurse without bound.
        try:
            rule = self.__dict__['rule']
        except KeyError:
            raise AttributeError(name)
        return getattr(rule, name)

    def __str__(self):
        return self.string

    def match(self, source_ip=None, dest_ip=None, sport=None, dport=None, proto=None):
        """Return False only if the rule cannot match the given packet.

        A criterion is ignored when its argument is falsy (or ``proto`` is
        ``'any'``), or when the rule does not set the corresponding option.
        ``--ports`` rejects only when both ports are given and neither is
        listed.
        """
        if self.source and source_ip:
            net = ipaddr.IPNetwork(self.source)
            if source_ip not in net:
                return False

        if self.destination and dest_ip:
            net = ipaddr.IPNetwork(self.destination)
            if dest_ip not in net:
                return False

        for search, mports, port in [
            (sport, self.sports, self.sport),
            (dport, self.dports, self.dport),
        ]:
            if search:
                if mports and search not in mports:
                    return False
                if port and search not in port:
                    return False

        if (
            self.ports
            and sport and sport not in self.ports
            and dport and dport not in self.ports
        ):
            return False

        if self.protocol and proto and proto != 'any':
            if self.protocol != proto:
                return False

        if self.u32:
            if self.u32.is_src() and source_ip and not self.u32.match(source_ip):
                return False
            if not self.u32.is_src() and dest_ip and not self.u32.match(dest_ip):
                return False

        return True


class IPTables(collections.MutableMapping):
    '''Comparable collection of iptables tables, chains and their rules.

    Acts as a mutable mapping ``table -> {chain -> [rule, ...]}`` where each
    rule is an instance of ``rule_class``. A ``#Yandex-HBF-Agent: protected``
    header comment in a dump marks the following table or chain as
    protected. Protected chains are tracked in ``protected_chains``:
    ``apply`` creates them when missing and never flushes or deletes them.

    FIXME: Should be immutable collections.Mapping
    '''

    # Stats, overwritten on the instance by apply().
    test_time = 0.0
    apply_time = 0.0
    apply_time_protected = 0.0
    gc_time = 0.0

    rule_count = 0
    chain_count = 0

    header_re = re.compile(r"#\s*Yandex-HBF-Agent\s*:\s*(?P<value>\S+)", re.I)

    def __init__(self, ip_version="v6", dump=None, rule_class=SimpleRule, use_yandex_iptables=False):
        """
        :param ip_version: "v6" (ip6tables*) or "v4" (iptables*).
        :param dump: optional textual dump to parse right away.
        :param rule_class: class used to build rule objects from lines.
        :param use_yandex_iptables: run the binaries from /opt/yandex-iptables/.
        :raises ValueError: for any other ``ip_version``.
        """
        self.ip_version = ip_version
        self.use_yandex_iptables = use_yandex_iptables
        # Go through sudo unless all of these capabilities are in the
        # process's inheritable set.
        self.need_sudo = not (
            prctl.cap_inheritable.net_admin and
            prctl.cap_inheritable.net_raw and
            prctl.cap_inheritable.dac_override and
            prctl.cap_inheritable.sys_module
        )
        log.debug("Using yandex iptables: {}".format(self.use_yandex_iptables))
        ipt_prefix = "/opt/yandex-iptables/" if self.use_yandex_iptables else ""

        if ip_version == "v6":
            base = "ip6tables"
        elif ip_version == "v4":
            base = "iptables"
        else:
            raise ValueError("Invalid 'ip_version': " + str(ip_version))
        self._iptables_exe = ipt_prefix + base
        self._iptables_save_exe = ipt_prefix + base + "-save"
        self._iptables_restore_exe = ipt_prefix + base + "-restore"

        self.rule_class = rule_class
        self.tables = {}
        self.protected_chains = {}
        self.config = Config()
        if dump:
            self.parse_dump(dump)
            self._remove_builtin()
            self._check_protected_chains()

    def load_current(self, counters=False):
        """Merge the host's current ruleset (from iptables-save) into self.

        :raises IPTablesError: if iptables-save exits non-zero.
        """
        ec, dump = self._iptables_save(counters)
        if ec != 0:
            raise IPTablesError(self._iptables_save_exe + " failed.")
        self.parse_dump(dump)
        self._remove_builtin()

    def get_vxL_chain_dump(self, chain):
        """Return the stdout of ``<iptables> -v -x -n -L <chain> -w``.

        Failures are only logged at debug level; the output is returned anyway.
        """
        cmd = [self._iptables_exe, "-v", "-x", "-n", "-L", chain, "-w"]
        if self.need_sudo:
            cmd.insert(0, 'sudo')
        ec, output, error = self._popen(cmd)
        if ec != 0:
            log.debug("'{}' exit code: {}".format(" ".join(cmd), ec))
        if error:
            log.debug("'{}' stderr:\n{}".format(" ".join(cmd), error))
        return output

    @staticmethod
    def parse_chain(dump, target):
        """Return regex matches for rows of ``-v -x -n -L`` listing output.

        :param dump: listing text, e.g. from ``get_vxL_chain_dump``.
        :param target: regex fragment for the target column; it is
            interpolated unescaped, so regex syntax is honoured.
        :returns: list of match objects with groups ``pkts``, ``bytes``,
            ``target``, ``prot``, ``opt``, ``in``, ``out``, ``source``,
            ``destination`` and ``comment``.

        NOTE(review): source/destination accept only word characters, ``:``
        and ``/`` (no dots), so IPv4 rows will not match; looks IPv6-only.
        """
        target_compiled_re = re.compile(
            (
                r'(\s+)?(?P<pkts>\d+)\s+'
                r'(?P<bytes>\d+)\s+'
                r'(?P<target>{target})\s+'
                r'(?P<prot>\w+)\s+'
                r'(?P<opt>\w+)?\s+'
                r'(?P<in>[\w\*\-\_\@]+)\s+'
                r'(?P<out>[\w\*\-\_\@]+)\s+'
                r'(?P<source>[\w:/]+)\s+'
                r'(?P<destination>[\w:/]+)'
                r'(?P<comment>.*)'
            ).format(target=target),
            re.MULTILINE,
        )

        return list(target_compiled_re.finditer(dump))

    def parse_dump(self, dump):
        """Parse textual representation of rules and merge it into self.

        Understands ``*table``, ``:chain policy``, rule lines starting with
        ``rule_class.first_options``, ``COMMIT``, an optional leading
        ``[...]`` counters prefix and protected-marker header comments.
        When the dump has ``#BEGIN IP6TABLES``/``#BEGIN IPTABLES`` sections
        (HBF server output), only the section for ``self.ip_version`` is read.

        :raises IPTablesError: on a malformed line.
        """
        log.debug("Parsing dump.")
        line = ''

        def parse_error():
            # Uses ``n`` and ``line`` from the enclosing loop.
            msg = "Parse error in line {}".format(n)
            log.error(msg + ":\n" + enumerate_lines(dump))
            raise IPTablesError(msg + "\n" + repr(line))

        current_table = None
        # A protected header applies to the next table or chain line only.
        protected_line = False
        # Set when the current table itself was marked protected.
        protected_table = False
        skip_block = False
        for n, line in enumerate(dump.splitlines(), 1):
            line = line.rstrip()

            # if dump is a hbf server output, skip blocks from alien AF
            if line in [
                "#BEGIN IP6TABLES",
                "#BEGIN IPTABLES",
            ]:
                skip_block = (self.ip_version == "v6") != ("6" in line)
            if skip_block:
                continue

            # extract counters from lines
            counters = None
            if line.startswith("["):
                found = line.find("]") + 1
                if found:
                    counters = line[:found]
                    line = line[found:].strip()

            if len(line) == 0:
                continue
            elif line.startswith("#"):
                m = self.header_re.match(line)
                if m:
                    flag = m.group("value").lower()
                    if flag == "protected":
                        protected_line = True
            elif line == "COMMIT":
                current_table = None
                if protected_line:
                    log.warning(
                        "Line with 'COMMIT' cannot be marked as protected"
                    )
                    protected_line = False
                protected_table = False
            elif line.startswith("*"):
                tokens = line[1:].split()
                if current_table is not None or len(tokens) < 1:
                    parse_error()
                current_table = tokens[0]
                self.add_table(current_table)
                if protected_line:
                    protected_table = True
                    protected_line = False
            elif line.startswith(":"):
                tokens = line[1:].split()
                if current_table is None or len(tokens) < 2:
                    parse_error()
                chain, policy = tokens[0:2]
                if protected_line or protected_table:
                    try:
                        self.add_protected_chain(current_table, chain)
                    except IPTablesError as e:
                        log.warning("Skipping line {}: {}".format(n, e))
                        self.dump(dump)
                    protected_line = False
                else:
                    self.add_chain(current_table, chain)
            elif line.startswith(self.rule_class.first_options):
                if current_table is None:
                    parse_error()
                r = self.rule_class(line, counters=counters)
                if (r.chain not in self[current_table] and
                        not builtin_chain(current_table, r.chain)):
                    log.warning("Chain '{}' not defined.".format(r.chain))
                self.append_rule(current_table, r.chain, r)
                if protected_line:
                    log.warning("Rule cannot be marked as protected.")
                    protected_line = False
            else:
                parse_error()

    def _remove_builtin(self):
        """Remove all empty built-in chains, then remove all empty tables.

        Tables unknown to BUILTIN_CHAINS have no built-in chains.
        FIXME: Chain default policy is not considered.
        """
        # Iterate over snapshots because entries are deleted in the loops.
        for table in list(self.tables):
            chains = self.tables[table]
            for chain in list(chains):
                if builtin_chain(table, chain) and not chains[chain]:
                    del chains[chain]
            if not chains:
                del self.tables[table]

    def _check_protected_chains(self):
        """Warn about chains that are both protected and defined normally."""
        for table in self.protected_chains:
            for chain in self.protected_chains[table]:
                if table in self.tables and chain in self.tables[table]:
                    msg = ("Chain '{}' in table '{}' defined normally"
                           " as well as protected (won't be protected).")
                    msg = msg.format(chain, table)
                    log.warning(msg)

    def add_table(self, table):
        """Create ``table`` if it does not exist yet."""
        if table not in self.tables:
            self.tables[table] = {}

    def add_chain(self, table, chain):
        """Create ``chain`` (and ``table``) if they do not exist yet."""
        self.add_table(table)
        if chain not in self.tables[table]:
            self.tables[table][chain] = []

    def append_rule(self, table, chain, rule):
        """Append ``rule`` to ``chain``; strings are wrapped in ``rule_class``."""
        if isinstance(rule, basestring):
            rule = self.rule_class(rule)
        self.add_chain(table, chain)
        self.tables[table][chain].append(rule)

    def add_protected_chain(self, table, chain):
        """Mark ``chain`` in ``table`` as protected."""
        self.protected_chains.setdefault(table, set()).add(chain)

    @staticmethod
    def _popen(cmd, input=None):
        """Run ``cmd``, optionally feeding ``input`` on stdin.

        :returns: ``(returncode, stdout, stderr)``.
        """
        log.debug("Executing: '{}'.".format(" ".join(cmd)))
        if input is None:
            p = Popen(cmd, stdout=PIPE, stderr=PIPE)
            output, error = p.communicate()
        else:
            p = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE)
            output, error = p.communicate(input=input)
        return p.returncode, output, error

    def _iptables_save(self, counters=False):
        """Run iptables-save (with ``--counters`` if requested).

        :returns: ``(exit code, stdout)``.
        """
        cmd = [self._iptables_save_exe]
        if self.need_sudo:
            cmd.insert(0, 'sudo')
        if counters:
            cmd += ["--counters"]
        ec, output, error = self._popen(cmd)
        if ec != 0:
            log.error("'{}' exit code: {}".format(" ".join(cmd), ec))
        if error:
            log.debug("'{}' stderr:\n{}".format(" ".join(cmd), error))
        return ec, output

    def _iptables_restore(self, dump, test=False):
        """Feed ``dump`` to ``iptables-restore --noflush`` and return its exit code.

        ``test`` adds ``--test``. ``-w`` is passed only for the yandex
        binaries (presumably the stock restore may lack it).
        """
        opts = ["--noflush"]
        if test:
            opts += ["--test"]
        if self.use_yandex_iptables:
            opts += ["-w"]
        cmd = [self._iptables_restore_exe] + opts
        if self.need_sudo:
            cmd.insert(0, 'sudo')
        ec, output, error = self._popen(cmd, input=dump)
        if ec != 0:
            log.error("'{}' exit code: {}".format(" ".join(cmd), ec))
        if error:
            log.debug("'{}' stderr:\n{}".format(" ".join(cmd), error))
            self.dump(dump)
        return ec

    def count(self):
        """Return ``(tables, chains, rules)`` totals."""
        tables, chains, rules = 0, 0, 0
        tables += len(self)
        for t in self:
            chains += len(self[t])
            for c in self[t]:
                rules += len(self[t][c])
        return tables, chains, rules

    def dump(self, rules=None):
        """Log ``rules`` (default: ``str(self)``) via enumerate_lines at debug level."""
        if rules is None:
            rules = str(self)
        log.debug("Dumping rules:\n" + enumerate_lines(rules))

    def apply(self, force=False):
        """Install this ruleset on the host.

        This method runs these steps in order:

        1. Create missing protected chains.
        2. Test the ruleset with ``--test``.
        3. Check that it did not shrink too much (RTCNETWORK-157).
        4. Apply it.
        5. Flush/delete chains that are neither ours nor protected.

        :param force: skip ORLY approval for a much smaller ruleset.
        :raises IPTablesError: if testing, approval, applying or garbage
            collection fails.
        """
        current = IPTables(self.ip_version, use_yandex_iptables=self.use_yandex_iptables)
        current.load_current()
        _, _, current_nrules = current.count()
        log.debug('Number of current rules: {}'.format(current_nrules))

        self._create_protected_chains(current)

        # Apply rules.
        log.info("Applying IP{} rules.".format(self.ip_version))
        rules = str(self)

        ntables, nchains, nrules = self.count()
        log.debug('Number of new rules: {}'.format(nrules))
        t = Timer()
        ec = self._iptables_restore(rules, test=True)

        self.test_time = t.interval
        self.rule_count = nrules
        self.chain_count = nchains

        msg = "Tested {} tables, {} chains, {} rules in {:.3f} seconds."
        msg = msg.format(ntables, nchains, nrules, self.test_time)
        log.info(msg)
        if ec != 0:
            msg = "Testing IP{} rules: FAIL".format(self.ip_version)
            log.error(msg)
            raise IPTablesError(msg)

        self._check_ruleset_shrink(current_nrules, force)

        t = Timer()
        ec = self._iptables_restore(rules)

        self.apply_time = t.interval

        msg = "Applied {} tables, {} chains, {} rules in {:.3f} seconds."
        msg = msg.format(ntables, nchains, nrules, self.apply_time)
        log.info(msg)
        if ec != 0:
            msg = "Applying IP{} rules: FAIL".format(self.ip_version)
            log.error(msg)
            raise IPTablesError(msg)
        log.info("Applying IP{} rules: OK".format(self.ip_version))

        self._collect_garbage(current)

    def _create_protected_chains(self, current):
        """Create protected non-built-in chains missing from ``current``.

        A failure is only logged; apply() carries on.
        """
        rules = ""
        msg_prefix = "Creating protected IP{} chain".format(self.ip_version)
        for table in self.protected_chains:
            missing_chains = []
            for chain in self.protected_chains[table]:
                if ((table not in current or chain not in current[table]) and
                        not builtin_chain(table, chain)):
                    log.info(msg_prefix +
                             ": table '{}', chain '{}'.".format(table, chain))
                    missing_chains.append(chain)
            if missing_chains:
                rules += "*{}\n".format(table)
                rules += "\n".join(":{} -".format(chain)
                                   for chain in missing_chains) + "\n"
                rules += "COMMIT\n"
        if not rules:
            return

        t = Timer()
        ec = self._iptables_restore(rules)

        self.apply_time_protected = t.interval

        if ec != 0:
            log.error(msg_prefix + "s: FAIL")
        else:
            log.info(msg_prefix + "s: OK")

    def _check_ruleset_shrink(self, current_nrules, force):
        """RTCNETWORK-157: guard against a dramatically smaller ruleset.

        This method does the following:

        - Pushes ruleset CRIT/OK through jugglerutil.
        - For a ruleset at most half the current size, requires ORLY
          approval (when enabled) unless ``force`` is set.

        :raises IPTablesError: if ORLY refuses the operation.
        """
        # An empty current ruleset cannot shrink; without the first check an
        # empty -> empty apply would push CRIT and go through ORLY.
        if current_nrules > 0 and self.rule_count <= (current_nrules / 2):
            log.warning(
                'New ruleset is much smaller than previous. Cur: {} Prev: {}'.format(
                    self.rule_count, current_nrules
                )
            )
            jugglerutil.push_ruleset_crit()

            if force:
                log.info("Force flag set, ignore ORLY")
            elif self.config['orly']['orly_enabled']:
                log.info('ORLY Enabled')
                ok, err = orlyutil.start_operation(ORLY_RULE, socket.gethostname())
                if not ok:
                    msg = 'Rules apply is not allowed by ORLY: {}'.format(err)
                    log.error(msg)
                    raise IPTablesError(msg)
            else:
                log.info('ORLY Disabled, apply rules')
        else:
            jugglerutil.push_ruleset_ok()

    def _collect_garbage(self, current):
        """Flush (and delete, if not built-in) chains of ``current`` that are
        neither ours nor protected.

        :raises IPTablesError: if iptables-restore fails.
        """
        log.info("Collecting IP{} garbage.".format(self.ip_version))
        rules = ""
        for table in current:
            garbage = []
            for chain in current[table]:
                our = (table in self.tables and
                       chain in self.tables[table])
                protected = (table in self.protected_chains and
                             chain in self.protected_chains[table])
                if not our and not protected:
                    log.debug("Table '{}', chain '{}' will be"
                              " flushed/deleted.".format(table, chain))
                    garbage.append(chain)
            if garbage:
                lines = ["*{}".format(table)]
                lines.extend("-F " + str(chain) for chain in garbage)
                # Built-in chains can only be flushed, never deleted.
                lines.extend("-X " + str(chain) for chain in garbage
                             if not builtin_chain(table, chain))
                lines.append("COMMIT")
                rules += "\n".join(lines) + "\n"

        t = Timer()
        if rules:
            ec = self._iptables_restore(rules)
            if ec != 0:
                msg = "Collecting IP{} garbage: FAIL".format(self.ip_version)
                log.error(msg)
                raise IPTablesError(msg)

        self.gc_time = t.interval
        log.info("Collecting IP{} garbage: OK".format(self.ip_version))

    def apply_delete(self):
        """Delete from the host those of our rules that are currently installed.

        :raises IPTablesError: if iptables-restore fails.
        """
        current = IPTables(self.ip_version, use_yandex_iptables=self.use_yandex_iptables)
        current.load_current()
        to_delete = self & current
        rules = to_delete.str_delete()
        if rules:
            ec = self._iptables_restore(rules)
            if ec != 0:
                msg = "Failed to delete IP{} rules.".format(self.ip_version)
                raise IPTablesError(msg)

    def __str__(self):
        """Render as iptables-restore input.

        Every chain is declared as ``:chain -``. Built-in chains also get an
        explicit ``-F`` before their rules, presumably because a
        ``--noflush`` restore does not flush them by itself.
        """
        parts = []
        for table, chains in self.tables.items():
            headers = ["*{}\n".format(table)]
            body = []
            for chain, rules in chains.items():
                headers.append(":{} -\n".format(chain))
                if builtin_chain(table, chain):
                    body.append("-F {}\n".format(chain))
                body.append("\n".join(map(str, rules)))
                body.append("\n")
            parts.extend(headers)
            parts.extend(body)
            parts.append("COMMIT\n")
        return "".join(parts)

    def str_delete(self):
        """Render every rule as a ``-D`` line, grouped by table."""
        parts = []
        for table, chains in self.tables.items():
            parts.append("*{}\n".format(table))
            for rules in chains.values():
                parts.append("\n".join(r.str_delete() for r in rules) + "\n")
            parts.append("COMMIT\n")
        return "".join(parts)

    def _can_op(self, other):
        """Return False for non-IPTables operands.

        :raises ValueError: if the IP versions differ.
        """
        if not isinstance(other, self.__class__):
            return False
        if self.ip_version != other.ip_version:
            raise ValueError(
                "Cannot perform operation on objects with different"
                " IP versions"
            )
        return True

    def __eq__(self, other):
        if not self._can_op(other):
            return NotImplemented
        return self.tables == other.tables

    def __ne__(self, other):
        if not self._can_op(other):
            return NotImplemented
        return self.tables != other.tables

    def __add__(self, other):
        if not self._can_op(other):
            return NotImplemented
        result = deepcopy(self)
        result += other
        return result

    def __iadd__(self, other):
        """Append copies of ``other``'s rules and merge its protected chains."""
        if not self._can_op(other):
            return NotImplemented
        for table in other:
            for chain in other[table]:
                chain_copy = deepcopy(other[table][chain])
                self.add_chain(table, chain)
                self[table][chain] += chain_copy
        for table in other.protected_chains:
            if table not in self.protected_chains:
                self.protected_chains[table] = set()
            self.protected_chains[table] |= other.protected_chains[table]
        self._check_protected_chains()
        return self

    def __and__(self, other):
        """Return the rules and protected chains present in both operands.

        The result keeps this object's ``rule_class`` and iptables flavour.
        """
        if not self._can_op(other):
            return NotImplemented
        and_tables = IPTables(self.ip_version, rule_class=self.rule_class,
                              use_yandex_iptables=self.use_yandex_iptables)
        for table in self:
            for chain in self[table]:
                for rule in self[table][chain]:
                    if (table in other and chain in other[table] and
                            rule in other[table][chain]):
                        and_tables.append_rule(table, chain, deepcopy(rule))
        for table in self.protected_chains:
            if table in other.protected_chains:
                and_tables.protected_chains[table] = \
                    (self.protected_chains[table] &
                     other.protected_chains[table])
        return and_tables

    def __getitem__(self, name):
        return self.tables[name]

    def __setitem__(self, key, value):
        self.tables[key] = value

    def __delitem__(self, key):
        del self.tables[key]

    def __iter__(self):
        return iter(self.tables)

    def __len__(self):
        return len(self.tables)
