import requests
from app.settings import PUNCHER_OAUTH_TOKEN
from collections import namedtuple
import traceback


# Inclusive port interval [port_start, port_end] allowed for protocol `proto`.
PuncherPortRange = namedtuple("PuncherPortRange", "proto port_start port_end")


class PuncherAPIClient(object):
    """
    Minimal client for the Puncher dynamic-firewall rules API.

    https://wiki.yandex-team.ru/noc/nocdev/puncher/api/
    """

    # Seconds to wait for the Puncher API. ``requests`` has no default
    # timeout, so without one a stalled connection would block forever.
    REQUEST_TIMEOUT = 30

    @staticmethod
    def _parse_port(port):
        """Parse one port spec into an inclusive ``(start, end)`` tuple.

        Accepts a single port (``"80"`` or ``80``) or a range
        (``"1000-2000"``); whitespace around numbers and the dash is ignored.
        Returns ``None`` when the spec is not recognised.
        """
        parts = [part.strip() for part in str(port).split("-")]
        if len(parts) == 1:
            # A single port is the degenerate range [port, port].
            parts = parts * 2
        # isdecimal (not isdigit): isdigit accepts characters such as "²"
        # that int() rejects, which would raise ValueError mid-parse.
        if len(parts) == 2 and all(part.isdecimal() for part in parts):
            return int(parts[0]), int(parts[1])
        return None

    @staticmethod
    def parse_rule(rule):
        """Convert one Puncher rule dict into a list of PuncherPortRange.

        Rules whose ``status`` is not ``"active"`` yield an empty list.
        Each entry of ``rule["ports"]`` (missing -> no ports) becomes one range
        tagged with ``rule["protocol"]``. Unrecognised port specs are reported
        on stdout and skipped.

        Raises:
            KeyError: an active rule with at least one valid port lacks
                a ``"protocol"`` key.
        """
        if rule.get("status") != "active":
            return list()

        puncher_port_ranges = list()
        for port in rule.get("ports", list()):
            bounds = PuncherAPIClient._parse_port(port)
            if bounds is None:
                print("[!] PuncherAPIClient:parse_rule. Unknown port type: {}".format(port))
                continue
            puncher_port_ranges.append(PuncherPortRange(
                proto=rule["protocol"],
                port_start=bounds[0],
                port_end=bounds[1],
            ))

        return puncher_port_ranges

    @classmethod
    def parse_rules(cls, rules):
        """Flatten ``parse_rule`` over an iterable of rule dicts, in order."""
        all_puncher_port_ranges = list()
        for rule in rules:
            all_puncher_port_ranges.extend(cls.parse_rule(rule))
        return all_puncher_port_ranges

    @classmethod
    def get_puncherportranges_from_inet(cls, dest="_EXTSVN_", timeout=REQUEST_TIMEOUT):
        """Fetch the port ranges opened from ``inet`` to ``dest``.

        Args:
            dest: Puncher destination to filter rules by.
            timeout: seconds passed to ``requests.get`` (connect/read timeout).

        Returns:
            List of PuncherPortRange from the active rules. It is empty when
            the API answers with a non-200 status.

        Raises:
            Whatever ``requests.get`` raises (connection errors, timeouts),
            and the decode error ``resp.json()`` raises on a non-JSON body.
        """
        url = "https://puncher.yandex-team.ru/api/dynfw/rules"
        params = {
            "rules": "exclude_rejected",
            "sort": "destination",
            "values": "all",
            "system": "",
            "destination": dest,
            "source": "inet"
        }

        headers = {
            "Authorization": "OAuth {}".format(PUNCHER_OAUTH_TOKEN)
        }

        resp = requests.get(url, params=params, headers=headers, timeout=timeout)
        if resp.status_code != 200:
            return list()

        rules = resp.json().get("rules", list())
        return cls.parse_rules(rules)


class CachedPuncherAPIClient(object):
    """Per-destination memoizing wrapper around PuncherAPIClient.

    Each destination is fetched at most once per instance. A failed fetch is
    logged and cached as an empty list, so this instance never retries that
    destination; create a new instance to refresh.
    """

    def __init__(self):
        # dest -> list of PuncherPortRange returned for that destination
        self.cached_dest_dict = dict()

    def _get_puncherportranges_from_inet_impl(self, dest="_EXTSVN_"):
        """Fetch ranges for ``dest`` uncached; return ``[]`` on any error.

        Only ``Exception`` subclasses are swallowed (and logged with a
        traceback). ``KeyboardInterrupt``, ``SystemExit`` and other
        ``BaseException``-only signals propagate to the caller.
        """
        try:
            return PuncherAPIClient.get_puncherportranges_from_inet(dest=dest)
        except Exception:
            print("[!] CachedPuncherAPIClient:_get_puncherportranges_from_inet_impl. Exception.")
            traceback.print_exc()
            return list()

    def get_puncherportranges_from_inet(self, dest="_EXTSVN_"):
        """Return port ranges for ``dest``, fetching them on first request.

        The returned list is the cached object itself, not a copy.
        """
        if dest not in self.cached_dest_dict:
            self.cached_dest_dict[dest] = self._get_puncherportranges_from_inet_impl(dest=dest)

        return self.cached_dest_dict[dest]

    def is_allowed_from_inet(self, dest, port, protocol):
        """Return True if Puncher opens ``protocol``/``port`` from inet to ``dest``."""
        puncherportranges = self.get_puncherportranges_from_inet(dest)
        return PuncherViolationChecker.is_allowed(port, protocol, puncherportranges)


class PuncherViolationChecker(object):
    """Matches a (port, protocol) pair against a collection of port ranges."""

    @staticmethod
    def is_allowed(port, protocol, puncherportranges):
        """Return True when any range has a matching protocol and contains port.

        Both range bounds are inclusive. An empty collection allows nothing.
        Objects in ``puncherportranges`` need ``proto``, ``port_start`` and
        ``port_end`` attributes.
        """
        return any(
            protocol == rng.proto and rng.port_start <= port <= rng.port_end
            for rng in puncherportranges
        )


