# import logging
# logging.basicConfig(filename='debug.log', level=logging.DEBUG)

import zlib
import random
import string
from collections import defaultdict
from lxml import etree
from gvm.connections import UnixSocketConnection, DebugConnection, TLSConnection
from gvm.protocols.gmpv8 import AliveTest
from gvm.protocols.latest import Gmp
from gvm.transforms import EtreeCheckCommandTransform
from gvm.protocols.gmpv7.types import InfoType


class OpenVASClientException(Exception):
    """Error raised by the OpenVAS client, e.g. for invalid arguments or unusable scan configs."""


def connection(func):
    """Decorate a client method so it runs inside a connected, authenticated session.

    The first positional argument of the wrapped callable must be the client
    instance (``self``); it has to provide ``_connect()``, ``_auth()`` and
    ``_disconnect()``.

    The session is opened and authenticated before ``func`` runs and is
    always closed afterwards, whether ``func`` returns or raises.  Exceptions
    propagate to the caller unchanged.

    :param func: method to wrap
    :return: wrapper that forwards positional and keyword arguments to
        ``func`` and returns its result
    """
    # Local import keeps this block self-contained.
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        cli = args[0]
        try:
            cli._connect()
            cli._auth()
            return func(*args, **kwargs)
        finally:
            # Single cleanup path for both the success and the error case.
            cli._disconnect()

    return wrapper


class OpenVASParser(object):
    """Turn GMP ``<result>`` XML elements into plain Python dictionaries.

    Children are looked up by iterating the element directly rather than via
    ``getchildren()``, which is deprecated in lxml and no longer exists in
    ``xml.etree.ElementTree``.
    """

    @staticmethod
    def _first_child(element, tag):
        """Return the first direct child of ``element`` tagged ``tag``, or None."""
        for child in element:
            if child.tag == tag:
                return child
        return None

    @staticmethod
    def _required_child(element, tag):
        """Return the first direct child tagged ``tag``.

        :raises IndexError: when no such child exists.  ``IndexError`` is kept
            because that is what the former list-indexing lookup raised.
        """
        child = OpenVASParser._first_child(element, tag)
        if child is None:
            raise IndexError('missing <{}> element'.format(tag))
        return child

    @staticmethod
    def _optional_text(element, tag):
        """Return the text of the first direct child tagged ``tag``, or None if absent."""
        child = OpenVASParser._first_child(element, tag)
        return None if child is None else child.text

    @staticmethod
    def _parse_nvt(nvt):
        """Extract the NVT description attached to a result.

        :param nvt: ``<nvt>`` element
        :return: dict with keys ``oid``, ``name``, ``family``, ``cvss``,
            ``xref``, ``tags``, ``cert``, ``solution``, ``cve`` and ``bid``.
            Optional entries are None when the element is missing; ``cve`` and
            ``bid`` are also None for the placeholders 'NOCVE' / 'NOBID'.
        :raises IndexError: if ``<name>`` or ``<family>`` is missing
        """
        required = OpenVASParser._required_child
        optional = OpenVASParser._optional_text

        nvt_info = dict()
        nvt_info['oid'] = nvt.get('oid')
        nvt_info['name'] = required(nvt, 'name').text
        nvt_info['family'] = required(nvt, 'family').text

        cvss = optional(nvt, 'cvss_base')
        # An empty <cvss_base/> carries no score; never hand None to float().
        nvt_info['cvss'] = float(cvss) if cvss else None

        for tag in ('xref', 'tags', 'cert', 'solution'):
            nvt_info[tag] = optional(nvt, tag)

        cve = optional(nvt, 'cve')
        nvt_info['cve'] = cve if cve and cve != 'NOCVE' else None
        bid = optional(nvt, 'bid')
        nvt_info['bid'] = bid if bid and bid != 'NOBID' else None

        return nvt_info

    @staticmethod
    def parse_result(result):
        """Convert one ``<result>`` element into a dictionary.

        :param result: ``<result>`` element
        :return: dict with keys ``qod``, ``qod_type``, ``port_id``,
            ``transport``, ``host``, ``hostname``, ``description``,
            ``severity``, ``threat``, ``name`` and ``nvt``.  ``port_id`` is -1
            for the pseudo port 'general'; ``hostname`` is None when absent or
            empty; ``nvt`` is an empty dict when the result has no ``<nvt>``.
        :raises IndexError: if a mandatory child element is missing
        """
        required = OpenVASParser._required_child
        optional = OpenVASParser._optional_text
        result_dict = dict()

        # qod: read <value>/<type> by tag rather than by child position
        qod = required(result, 'qod')
        result_dict['qod'] = float(required(qod, 'value').text)
        result_dict['qod_type'] = required(qod, 'type').text

        # port is formatted as "<port or 'general'>/<transport>"
        port_parts = required(result, 'port').text.split('/')
        result_dict['port_id'] = -1 if port_parts[0] == 'general' else int(port_parts[0])
        result_dict['transport'] = port_parts[1]

        # host address with optional nested <hostname>
        host = required(result, 'host')
        result_dict['host'] = host.text
        hostname = optional(host, 'hostname')
        result_dict['hostname'] = hostname if hostname else None

        # other
        result_dict['description'] = required(result, 'description').text
        result_dict['severity'] = float(required(result, 'severity').text)
        result_dict['threat'] = required(result, 'threat').text
        result_dict['name'] = required(result, 'name').text

        # nvt
        nvt = OpenVASParser._first_child(result, 'nvt')
        result_dict['nvt'] = dict() if nvt is None else OpenVASParser._parse_nvt(nvt)

        return result_dict


class OpenVASClient(object):

    def __init__(self, debug=False):
        """Build the GMP client on top of a TLS connection.

        :param debug: when true, the TLS connection is wrapped in a
            ``DebugConnection`` before being handed to ``Gmp``.
        """
        self._tls_connection = TLSConnection()
        if debug:
            self._connection = DebugConnection(self._tls_connection)
        else:
            self._connection = self._tls_connection

        self._transform = EtreeCheckCommandTransform()
        self._gmp = Gmp(connection=self._connection, transform=self._transform)

        self._username = 'admin'
        # Defined up front so the attribute always exists; the real value is
        # supplied through set_password() before any authenticated call.
        self._password = None

    ###
    # Connection
    ###

    def _connect(self):
        if not self._gmp.is_connected():
            self._gmp.connect()

    def _disconnect(self):
        if self._gmp.is_connected():
            self._gmp.disconnect()

    ###
    # Authentication
    ###

    def _auth(self):
        # if not self._gmp.is_authenticated(): ### THIS ONE WORKS NOT PROPERLY !!!
        self._gmp.authenticate(self._username, self._password)

    def set_username(self, username):
        self._username = username

    def set_password(self, password):
        self._password = password

    ###
    # Targets
    ###

    @connection
    def create_target(self, hosts, port_list_id, name=None):
        """Create a scan target whose hosts are always considered alive.

        :param hosts: list of hosts; any other type is rejected
        :param port_list_id: uuid of the port list to attach to the target
        :param name: target name; ten random uppercase letters/digits are
            used when it is empty
        :return: uuid of the new target, or None unless the response status
            is '201'
        :raises OpenVASClientException: if ``hosts`` is not a list
        """
        if not isinstance(hosts, list):
            raise OpenVASClientException('list of targets required')

        if not name:
            alphabet = string.ascii_uppercase + string.digits
            name = ''.join(random.choice(alphabet) for _ in range(10))

        response = self._gmp.create_target(
            name,
            hosts=hosts,
            port_list_id=port_list_id,
            alive_test=AliveTest.CONSIDER_ALIVE,
        )
        return response.get('id') if response.get('status') == '201' else None

    @connection
    def delete_target(self, uuid):
        """Delete the target ``uuid``; return True when the response status is '200'."""
        status = self._gmp.delete_target(uuid).get('status')
        return status == '200'

    ###
    # Port list
    ###

    @connection
    def create_port_list(self, port_list, name=None):
        """Create a port list.

        :param port_list: port range specification, e.g. ``T:1-3,7,9,U:12,155``
        :param name: port list name; ten random uppercase letters/digits are
            used when it is empty
        :return: uuid of the new port list, or None unless the response status
            is '201'
        """
        if not name:
            alphabet = string.ascii_uppercase + string.digits
            name = ''.join(random.choice(alphabet) for _ in range(10))

        response = self._gmp.create_port_list(name, port_range=port_list)
        return response.get('id') if response.get('status') == '201' else None

    @connection
    def delete_port_list(self, uuid):
        """Delete the port list ``uuid``; return True when the response status is '200'."""
        status = self._gmp.delete_port_list(uuid).get('status')
        return status == '200'

    ###
    # Scanner
    ###

    @connection
    def get_scanner_by_name(self, name):
        """Look up a scanner by its exact name.

        :param name: scanner name to match
        :return: uuid of the first ``<scanner>`` whose ``<name>`` text equals
            ``name``; None if the response status is not '200' or nothing
            matches
        """
        resp = self._gmp.get_scanners()
        if resp.get('status') != '200':
            return None

        # Iterate the element directly: getchildren() is deprecated in lxml.
        for scanner in resp:
            if scanner.tag != 'scanner':
                continue
            # A scanner entry without <name> cannot match; skip it rather
            # than fail the whole lookup.
            name_element = scanner.find('name')
            if name_element is not None and name_element.text == name:
                return scanner.get('id')

        return None

    def get_openvas_default_scanner(self):
        return self.get_scanner_by_name('OpenVAS Default')

    def get_cve_scanner(self):
        return self.get_scanner_by_name('CVE')

    ###
    # Export/Import Config
    ###

    @connection
    def export_config(self, config_uuid):
        """Export a scan config as zlib-compressed XML bytes.

        :param config_uuid: uuid of the config to export
        :return: compressed (level 9) UTF-8 XML of the whole response, or
            None when the response status is not '200'
        :raises OpenVASClientException: if serialising or compressing fails;
            the underlying error is chained as ``__cause__``
        """
        resp = self._gmp.get_config(config_uuid)
        if resp.get('status') != '200':
            return None

        try:
            config_data = etree.tostring(resp, encoding='utf8', method='xml')
            return zlib.compress(config_data, level=9)
        except Exception as exc:
            # ``except Exception`` (not a bare except) so KeyboardInterrupt /
            # SystemExit are never converted into a client error.
            raise OpenVASClientException('Unable to convert xml to string') from exc

    @connection
    def import_config(self, compressed_config):
        """Import a scan config previously produced by ``export_config``.

        :param compressed_config: zlib-compressed config XML
        :return: uuid of the imported config, or None when the response
            status is not '201'
        :raises OpenVASClientException: if decompression or the import call
            fails; the underlying error is chained as ``__cause__``
        """
        try:
            config_data = zlib.decompress(compressed_config)
            resp = self._gmp.import_config(config_data)
            if resp.get('status') != '201':
                return None
            return resp.get('id')
        except Exception as exc:
            # ``except Exception`` (not a bare except) so KeyboardInterrupt /
            # SystemExit are never reported as a bad config.
            raise OpenVASClientException('Incorrect config') from exc

    @connection
    def delete_config(self, config_uuid):
        """Delete the scan config ``config_uuid``.

        :return: True when the response status is '200', otherwise False
        """
        status = self._gmp.delete_config(config_uuid).get('status')
        return status == '200'

    ###
    # Scan Config
    ###

    @connection
    def get_config_by_name(self, name):
        """Look up a scan config by its exact name.

        :param name: config name to match
        :return: uuid of the first ``<config>`` whose ``<name>`` text equals
            ``name``; None if the response status is not '200' or nothing
            matches
        """
        resp = self._gmp.get_configs()
        if resp.get('status') != '200':
            return None

        # Iterate the element directly: getchildren() is deprecated in lxml.
        for config in resp:
            if config.tag != 'config':
                continue
            # A config entry without <name> cannot match; skip it rather
            # than fail the whole lookup.
            name_element = config.find('name')
            if name_element is not None and name_element.text == name:
                return config.get('id')

        return None

    def get_empty_config(self):
        return self.get_config_by_name('empty')

    @connection
    def fork_config(self, existing_config_uuid, new_config_name):
        """Create a new scan config from an existing one.

        :param existing_config_uuid: uuid of the config passed to
            ``create_config`` as the source
        :param new_config_name: name for the new config
        :return: uuid of the new config, or None when the call did not succeed
        """
        resp = self._gmp.create_config(existing_config_uuid, new_config_name)
        # Creation is acknowledged with '201' (the status the other create_*
        # methods of this client check for).  The former check accepted only
        # '200', which made a successful fork return None; '200' is still
        # accepted for backward compatibility.
        if resp.get('status') not in ('200', '201'):
            return None
        return resp.get('id')

    ###
    # Task
    ###

    @connection
    def create_task(self, config_id, target_id, scanner_id, name=None):
        """Create a scan task.

        :param config_id: uuid of the scan config
        :param target_id: uuid of the target
        :param scanner_id: uuid of the scanner
        :param name: task name; ten random uppercase letters/digits are used
            when it is empty
        :return: uuid of the new task, or None unless the response status is
            '201'
        """
        if not name:
            alphabet = string.ascii_uppercase + string.digits
            name = ''.join(random.choice(alphabet) for _ in range(10))

        response = self._gmp.create_task(
            name,
            config_id=config_id,
            target_id=target_id,
            scanner_id=scanner_id,
        )
        return response.get('id') if response.get('status') == '201' else None

    @connection
    def delete_task(self, task_uuid):
        """Delete the task ``task_uuid``.

        :return: True when the response status is '200', otherwise False
        """
        status = self._gmp.delete_task(task_uuid).get('status')
        return status == '200'

    @connection
    def run_task(self, task_uuid):
        """Start a task.

        :param task_uuid: uuid of the task to start
        :return: text of the first child element of the response (used by
            this client as the report uuid); None when the response status is
            not '202' or the response has no child element
        """
        resp = self._gmp.start_task(task_uuid)
        if resp.get('status') != '202':
            return None

        # Iterate the element instead of calling the deprecated
        # getchildren(); an empty response yields None, not an IndexError.
        first_child = next(iter(resp), None)
        return None if first_child is None else first_child.text

    @connection
    def stop_task(self, task_uuid):
        """Stop the task ``task_uuid``.

        :return: True when the response status is '202', otherwise False
        """
        status = self._gmp.stop_task(task_uuid).get('status')
        return status == '202'

    @connection
    def get_report_info(self, report_uuid):
        """Fetch a report and return its run status together with parsed results.

        The report is requested with ``details=True`` and
        ``ignore_pagination=True``.  Results are not filtered by severity.

        :param report_uuid: uuid of the report
        :return: ``{'status': <scan_run_status text>, 'results': [dict, ...]}``
            (status is e.g. 'Done'); None when the response status is not
            '200' or the nested ``<report><report>`` element is missing.
            ``status`` is None when ``<scan_run_status>`` is absent and
            ``results`` is empty when ``<results>`` is absent.
        """
        resp = self._gmp.get_report(report_uuid, details=True, ignore_pagination=True)
        if resp.get('status') != '200':
            return None

        # The payload is nested: response -> <report> -> <report>.
        outer_report = resp.find('report')
        if outer_report is None:
            return None
        report = outer_report.find('report')
        if report is None:
            return None

        status_element = report.find('scan_run_status')
        scan_run_status = None if status_element is None else status_element.text

        # Iterate <results> directly; getchildren() is deprecated in lxml.
        results_element = report.find('results')
        results = [] if results_element is None else list(results_element)

        processed_results = [OpenVASParser.parse_result(r) for r in results]

        return {
            'status': scan_run_status,
            'results': processed_results,
        }

    ###
    # Selecting NVTS
    ###

    def select_nvts_for_config(self, config_id, nvts):
        """
        Select specified nvt plugins for config

        :param config_id:   Config uuid. Config itself Should be empty.
        :param nvts:        List of dicts that describe nvts
                            [{"nvt_family": str, "nvt_oid": list(str)}, ...]
                            Example of one dict of that list:
                            [
                                {
                                    "nvt_family": "General",
                                    "nvt_oid": "1.3.6.1.4.1.25623.1.0.142673"
                                },
                                ...
                            ]

        Example:
        ```
        >>> import json
        >>> from app.openvas_client import OpenVASClient
        >>> f = open("/tmp/invts.json", 'r')
        >>> invts = json.load(f)
        >>> f.close()
        >>> config_id = 'bcbea983-c50f-40__-____-____________'
        >>> cli = OpenVASClient()
        >>> cli.set_password('_________________')
        >>> cli.select_nvts_for_config(config_id, invts)
        ```
        """

        family_nvts_dict = defaultdict(list)
        for nvt in nvts:
            family_nvts_dict[nvt['nvt_family']].append(nvt['nvt_oid'])

        self._connect()
        self._auth()

        for family, nvt_oids in family_nvts_dict.items():
            print('Selecting {} nvts from family {}'.format(len(nvt_oids), family))
            self._gmp.modify_config_set_nvt_selection(config_id, family, nvt_oids)

        self._disconnect()

    ###
    # EXPORT NVTS
    ###

    def get_nvts_from_xml_response(self, xml_get_info_response):

        xml_infos = list(filter(lambda x: x.tag == 'info' and 'id' in x.attrib, xml_get_info_response.getchildren()))

        nvts = list()

        for xml_info in xml_infos:
            nvt_oid = xml_info.attrib['id']
            nvt = list(filter(lambda x: x.tag == 'nvt', xml_info.getchildren()))[0]
            nvt_details = nvt.getchildren()
            qod = list(filter(lambda x: x.tag == 'qod', nvt_details))[0]
            qod_details = qod.getchildren()
            nvt_name = list(filter(lambda x: x.tag == 'name', nvt_details))[0].text
            nvt_creation_time = list(filter(lambda x: x.tag == 'creation_time', nvt_details))[0].text
            nvt_modification_time = list(filter(lambda x: x.tag == 'modification_time', nvt_details))[0].text
            nvt_category = list(filter(lambda x: x.tag == 'category', nvt_details))[0].text
            nvt_family = list(filter(lambda x: x.tag == 'family', nvt_details))[0].text
            nvt_cvss_score = list(filter(lambda x: x.tag == 'cvss_base', nvt_details))[0].text
            values = list(filter(lambda x: x.tag == 'cve_id', nvt_details))
            nvt_cve_id = values[0].text if values else None
            values = list(filter(lambda x: x.tag == 'xrefs', nvt_details))
            nvt_xrefs = values[0].text if values else None
            values = list(filter(lambda x: x.tag == 'solution', nvt_details))
            nvt_solution = values[0].text if values else None
            nvt_tags = list(filter(lambda x: x.tag == 'tags', nvt_details))[0].text
            nvt_qod_value = list(filter(lambda x: x.tag == 'value', qod_details))[0].text
            nvt_qod_type = list(filter(lambda x: x.tag == 'type', qod_details))[0].text

            new_nvt = {
                'nvt_oid': nvt_oid,
                'nvt_name': nvt_name,
                'nvt_creation_time': nvt_creation_time,
                'nvt_modification_time': nvt_modification_time,
                'nvt_category': nvt_category,
                'nvt_family': nvt_family,
                'nvt_cvss_score': nvt_cvss_score,
                'nvt_cve_id': nvt_cve_id,
                'nvt_xrefs': nvt_xrefs,
                'nvt_tags': nvt_tags,
                'nvt_solution': nvt_solution,
                'nvt_qod_value': nvt_qod_value,
                'nvt_qod_type': nvt_qod_type,
            }

            nvts.append(new_nvt)

        return nvts

    def export_nvts(self):
        """
        Export nvts in list of dicts format

        NVT infos are fetched page by page (1000 rows per request, sorted by
        creation time in reverse).  Each page uses its own connection, which
        is closed even when fetching or parsing the page fails.  Fetching
        stops after the first page that is not full.

        :return: list of nvt dicts as produced by
            ``get_nvts_from_xml_response``

        Example:
        ```
        >>> from app.openvas_client import OpenVASClient
        >>> cli = OpenVASClient()
        >>> cli.set_password('_________________')
        >>> nvts = cli.export_nvts()
        ```
        """

        nvts = list()

        # NOTE(review): the filter's ``first`` value starts at 0 here --
        # confirm against the GMP filter docs whether it is 1-based.
        first_idx = 0
        rows = 1000  # page size
        print('[+] Expecting about 65k-75k nvts.')
        while True:
            print('[+] Fetching... first_idx={}'.format(first_idx))

            filter_str = "sort-reverse=created rows={} first={}".format(rows, first_idx)
            try:
                self._connect()
                self._auth()
                nvt_xml_response = self._gmp.get_info_list(info_type=InfoType.NVT, filter=filter_str)
                new_nvts = self.get_nvts_from_xml_response(nvt_xml_response)
            finally:
                # Never leave the session open when a page fails.
                self._disconnect()

            nvts.extend(new_nvts)

            # Total nvts about 70k
            print('[+] Fetched {}. Total: {}'.format(len(new_nvts), len(nvts)))

            first_idx += len(new_nvts)

            # A short page means there is nothing left to fetch.
            if len(new_nvts) != rows:
                break

        return nvts


# cli = OpenVASClient(debug=True)
# cli.set_password('')
