# -*- coding: utf-8 -*-

from collections import defaultdict
import contextlib
import difflib
import logging
import re
import urllib2

from sandbox import sdk2
from sandbox.common import types
from sandbox.sandboxsdk import environments


class SshExceptionRules(sdk2.Resource):
    """Resource holding the rule files produced by SshExceptionGenerator."""


class SshExceptionGenerator(sdk2.Task):
    """Generate HBF rules that accept SSH (TCP port 22) from user networks.

    The task downloads the list of user networks, drops networks whose
    description matches ``skip_re`` and renders one rule file per IP version
    (file name from ``path_template``). It publishes them as a new
    ``SshExceptionRules`` resource only when no previous READY resource is
    found or when the rules differ from it. Diffs are written to the task
    info and e-mailed.
    """

    class Requirements(sdk2.Task.Requirements):
        # ``ipaddr`` is imported lazily inside get_networks().
        environments = (environments.PipEnvironment("ipaddr", use_wheel=True),)

    class Parameters(sdk2.Task.Parameters):

        # Plain-text list: one "<network> [description]" entry per line.
        networks_url = sdk2.parameters.Url(
            "URL to the list of networks", required=True,
            default="https://racktables.yandex.net/export/networklist.php?report=usernets"  # noqa: E501
        )

        # Networks whose description matches this regex are excluded.
        skip_re = sdk2.parameters.String(
            "RE to skip some networks", required=True,
            default=(
                "PDAS|MobTest|HelpDesk|Промежуточная|Продавцы|Саппорты"
                "|Временная|ШАД|SIP|_YMPROXYSRV_|США|Разработчики|сейлы|гости"
            )
        )

        # ``{}`` is replaced with the IP version (4 or 6).
        path_template = sdk2.parameters.String(
            "Output path template", required=True,
            default="10-ssh-exception.v{}"
        )

    def on_execute(self):
        """Generate rules, diff them against the previous resource, publish.

        Raises:
            ValueError: if the fetched list does not contain both IPv4 and
                IPv6 networks (and nothing else).
        """
        logging.info("Fetching list of user networks.")
        all_nets = self.get_networks()
        # Both families are required so that both rule files are written.
        # An explicit check (not ``assert``) keeps it active under ``-O``.
        versions = sorted(all_nets.keys())
        if versions != [4, 6]:
            raise ValueError(
                "Expected IPv4 and IPv6 networks, got IP versions: {}".format(
                    versions))
        logging.info("Generating rules.")
        all_rules = {v: self.generate_rules(net)
                     for v, net in all_nets.items()}

        logging.info("Looking for previous resource.")
        # NOTE(review): assumes find() yields the most recent resource first;
        # confirm the default ordering of sdk2 Resource.find().
        prev_resource = SshExceptionRules.find(
            state=types.resource.State.READY
        ).first()
        # IP version -> list of "+ ..." / "- ..." ndiff lines.
        all_diffs = {}
        if prev_resource:
            logging.info("Found resource with ID {}.".format(prev_resource.id))
            data = sdk2.ResourceData(prev_resource)
            for v, rules in all_rules.items():
                file_name = self.Parameters.path_template.format(v)
                try:
                    prev_rules = data.path.joinpath(file_name).read_bytes()
                except IOError:
                    # Unreadable/missing file: treat previous rules as empty,
                    # so every new rule shows up as added.
                    prev_rules = ""
                prev_rules = prev_rules.splitlines()
                if prev_rules != rules:
                    all_diffs[v] = self.ndiff_plus_minus(prev_rules, rules)

        # Without a previous resource we always publish (all_diffs is empty).
        if prev_resource and not all_diffs:
            logging.info("No updates in rules.")
            self.set_info("No updates in rules.")
            return

        logging.info("Creating new resource.")
        data = sdk2.ResourceData(
            SshExceptionRules(self, "HBF SSH exception rules",
                              "ssh-exception-rules")
        )
        data.path.mkdir(0o755)
        for v, rules in all_rules.items():
            file_name = self.Parameters.path_template.format(v)
            # Trailing newline so that splitlines() on re-read yields `rules`.
            data.path.joinpath(file_name).write_bytes("\n".join(rules) + "\n")

        task_info = [
            "Created new resource with rules.",
            ""
        ]
        for v, diff in all_diffs.items():
            msg = "Diff for IPv{}:".format(v)
            logging.info(msg)
            task_info += [
                msg,
                ""
            ]
            for d in diff:
                logging.info(d)
                task_info.append(d)
            task_info.append("")
        self.set_info("\n".join(task_info) + "\n")
        self.email_notification(all_diffs)

    def get_networks(self):
        """Download the user network list and group networks by IP version.

        Each non-blank line of the response is parsed as
        ``<network> [description]``. Networks whose description matches the
        ``skip_re`` parameter are dropped. Blank or whitespace-only lines are
        ignored.

        Returns:
            defaultdict mapping IP version (``IPNetwork.version``) to a sorted
            list of ``ipaddr.IPNetwork`` objects.
        """
        # Packages installed via PipEnvironment are only available at runtime.
        from ipaddr import IPNetwork

        # closing() guarantees the HTTP response is released; Python 2
        # urllib2 responses cannot be used in a ``with`` statement directly.
        with contextlib.closing(
                urllib2.urlopen(self.Parameters.networks_url)) as r:
            data = r.read()
        skip_re = re.compile(self.Parameters.skip_re)
        nets = defaultdict(list)
        for line in data.splitlines():
            fields = line.split(None, 1)
            if not fields:
                # Blank / whitespace-only line: fields[0] would raise.
                continue
            net = fields[0]
            desc = fields[1] if len(fields) == 2 else ""
            if skip_re.search(desc):
                continue
            net = IPNetwork(net)
            nets[net.version].append(net)
        for version_nets in nets.values():
            version_nets.sort()
        return nets

    @staticmethod
    def generate_rules(networks):
        """Render filter-table rule lines allowing SSH for `networks`.

        Jumps from INPUT (dport 22) and OUTPUT (sport 22) to dedicated
        SSH_EXCEPTION_* chains, then adds one ACCEPT per network in each
        chain, in the order given.

        Args:
            networks: iterable of network objects; each is rendered via str().

        Returns:
            list of rule lines (no trailing newlines).
        """
        # NOTE(review): no "COMMIT" line is emitted; presumably the HBF
        # consumer merges this fragment with others and commits. Confirm.
        rules = [
            "*filter",
            ":SSH_EXCEPTION_INPUT -",
            ":SSH_EXCEPTION_OUTPUT -",
            "-A INPUT -p tcp --dport 22 -j SSH_EXCEPTION_INPUT",
            "-A OUTPUT -p tcp --sport 22 -j SSH_EXCEPTION_OUTPUT"
        ]
        for net in networks:
            rules.append("-A SSH_EXCEPTION_INPUT -s {} -j ACCEPT".format(net))
            rules.append("-A SSH_EXCEPTION_OUTPUT -d {} -j ACCEPT".format(net))
        return rules

    def email_notification(self, all_diffs):
        """E-mail a notice about the new resource, including per-version diffs.

        Args:
            all_diffs: dict mapping IP version to a list of diff lines; may be
                empty (e.g. on the first run with no previous resource).
        """
        email_body = [
            "Привет!",
            "",
            "Собрался новый ресурс с HBF правилами для SSH.",
            ""
        ]
        for v, diff in all_diffs.items():
            email_body += [
                "Изменения в IPv{}:".format(v),
                ""
            ] + diff + [
                ""
            ]
        email_body += [
            "Нужно проверить новые правила и убедиться, что они верны",
            "и нормально применяются на хостах.",
            "",
            "https://sandbox.yandex-team.ru/task/{}/".format(self.id),
            "",
            "",
            "-- ",
            "SSH_EXCEPTION_GENERATOR"
        ]

        # NOTE(review): recipient list is hard-coded.
        self.server.notification(
            subject="Собрался новый ресурс с HBF правилами для SSH",
            body="\n".join(email_body) + "\n",
            recipients=["max7255"],
            transport=types.notification.Transport.EMAIL
        )

    @staticmethod
    def ndiff_plus_minus(l, r):
        """Return only the removed ("- ...") and added ("+ ...") ndiff lines.

        Context ("  ...") and hint ("? ...") lines are dropped.
        """
        return [d for d in difflib.ndiff(l, r) if d.startswith(("+", "-"))]
