

import json
import logging

from datetime import datetime
from sqlalchemy.dialects.postgresql import insert
from app.engines.base_engine import BaseEngine
from app.db.db import new_session
from app.db.models import DebbyProject, DebbyPolicy, DebbyPolicyScript, DebbyPolicyAdditionalOptions, PipelineScans
from app.db.models import DebbyScanResults
from app.db.models import DebbyTask, DebbyScan, DebbyTag, RelationProjectTag
from app.engines.utils import target_generator, list_of_ip_to_subnet
from app.settings import PROTO_IPV6, PROTO_IPV4, PROTO_UDP, PROTO_TCP
from app.utils import is_subnet, is_valid_ipv4_address, is_valid_ipv6_address, iterate_ports, uniq_ports
from app.utils import port_list_to_ports_str, addr_to_proto, datetime_msk_2_timestamp_utc
from app.validators import DebbyValidateException, check_port
# from app.puncher import CachedPuncherAPIClient
from app.db.models import PortCache, RelationAgentTag
from app import portcache


class NmapEngine(BaseEngine):
    """Engine that turns Debby projects into nmap scan tasks and processes their results.

    * ``new_tasks_payloads_generator`` splits a project's targets into chunks and
      builds one JSON task payload (containing an nmap argument string) per chunk.
    * ``scan_results_to_splunk_events`` flattens task results into Splunk events.
    * ``cache_ports`` upserts open ports found by a task into ``PortCache``.
    """

    # Timing/parallelism tuning appended to every nmap command line to speed scans up.
    _PERFORMANCE_ARGS = (
        '--max-rtt-timeout 300ms',
        '--min-rtt-timeout 50ms',
        '--initial-rtt-timeout 250ms',
        '--max-retries 2',
        '--max-scan-delay 5ms',
        '--min-hostgroup 64',
        '--max-hostgroup 64',
        '--min-parallelism 100',
    )

    @staticmethod
    def _prepare_profile(policy, targetlist):
        """Build the nmap argument profile for one chunk of targets.

        :param policy: DebbyPolicy; ``scan_type`` and ``ports`` are read directly,
            scripts and additional options are loaded from the DB by ``policy.id``.
        :param targetlist: list of IPv4/IPv6 addresses or ``addr/prefix`` subnets;
            all entries must use the same IP version.
        :return: ``{"args": <nmap argument string, targets last>}``
        :raises DebbyValidateException: on an unknown or mixed address family,
            an unknown ``scan_type`` or an invalid port specification.
        """
        # Every target must belong to one IP version, because IPv6 needs ``-6``.
        proto = None
        for target in targetlist:
            address = target.split('/')[0] if is_subnet(target) else target
            if is_valid_ipv4_address(address):
                target_proto = PROTO_IPV4
            elif is_valid_ipv6_address(address):
                target_proto = PROTO_IPV6
            else:
                raise DebbyValidateException('Unknown address type/protocol')
            if proto and proto != target_proto:
                raise DebbyValidateException('Target list with multiple protocols')
            proto = target_proto

        # Scan technique: -sU = UDP scan, -sS = TCP SYN scan.
        if policy.scan_type == PROTO_UDP:
            args = ['-sU']
        elif policy.scan_type == PROTO_TCP:
            args = ['-sS']
        else:
            raise DebbyValidateException('Unknown scan_type')

        # NSE scripts and free-form extra options attached to the policy.
        session = new_session()
        try:
            scripts_objs = session.query(DebbyPolicyScript) \
                .filter(DebbyPolicyScript.policy_id == policy.id).all()
            additional_options = session.query(DebbyPolicyAdditionalOptions) \
                .filter(DebbyPolicyAdditionalOptions.policy_id == policy.id).first()
        finally:
            session.close()
        scripts_str = ",".join(script_obj.name for script_obj in scripts_objs)
        additional_options_value = additional_options.value if additional_options else None

        # Ports: validate every comma-separated entry, then dedupe into one port spec.
        ports = policy.ports
        if ports:
            if not all(check_port(p.strip()) for p in ports.split(',')):
                raise DebbyValidateException('Incorrect ports')
            unique_ports = port_list_to_ports_str(uniq_ports(iterate_ports(ports.replace(' ', ''))))
            args.append('-p {}'.format(unique_ports))

        # Network protocol.
        if proto == PROTO_IPV6:
            args.append('-6')

        args.extend(NmapEngine._PERFORMANCE_ARGS)

        # Skip host discovery (ping) and treat every target as up.
        args.append('-Pn')

        # Verbose output, so many filtered/closed ports are not skipped in the report.
        args.append('-v')

        # NOTE(review): policy-provided options are appended verbatim (no escaping);
        # confirm how the agent tokenizes this string before trusting policy input.
        if additional_options_value:
            args.append(additional_options_value)

        if scripts_str:
            args.append('--script={}'.format(scripts_str))

        # Targets go last.
        args.append(' '.join(targetlist))

        return {"args": ' '.join(args)}

    @staticmethod
    def _get_targets_from_db(prev_scan_id):
        """Return the distinct IPs with ``enabled`` results in scan ``prev_scan_id``.

        :param prev_scan_id: id of a previously executed DebbyScan.
        :return: list of IP values (possibly empty).
        """
        session = new_session()
        try:
            query = (
                session.query(DebbyScanResults.ip)
                .filter(DebbyScanResults.scan_id == prev_scan_id)
                # SQLAlchemy needs ``== True`` (not ``is True``) to build the SQL comparison.
                .filter(DebbyScanResults.enabled == True)  # noqa: E712
                .distinct()
            )
            return [ip for (ip,) in query]
        finally:
            session.close()

    @staticmethod
    def new_tasks_payloads_generator(project_id, scan_id=None):
        """Yield ``(targets, payload)`` pairs, one per chunk of targets to scan.

        Target sources:

        * Pipeline scan. When ``scan_id`` is the ``next_scan_id`` of a PipelineScans
          row, the targets are the enabled result IPs of that row's
          ``prev_scan_id`` scan. In this case only the default exclude list is applied.
        * Fallback. If there is no pipeline predecessor, or it has no such results,
          ``project.targets`` is parsed. The accepted forms are ``"a, b, c"`` and
          ``"a, b exclude x, y"``.

        :param project_id: id of the DebbyProject to scan.
        :param scan_id: optional id of the scan being started.
        :yields: ``(chunk, payload)``, where
            ``chunk`` is a target chunk after ``list_of_ip_to_subnet``, and
            ``payload`` is a JSON string with ``engine``, ``profile`` and ``save_to_db``.
        :raises DebbyValidateException: if the project or its policy does not exist,
            or a profile cannot be built for a chunk.
        """
        session = new_session()
        try:
            project = session.query(DebbyProject).filter(DebbyProject.id == project_id).first()
            if project is None:
                raise DebbyValidateException('Project {} not found'.format(project_id))
            policy = session.query(DebbyPolicy).filter(DebbyPolicy.id == project.policy_id).first()
        finally:
            session.close()
        if policy is None:
            raise DebbyValidateException('Policy {} not found'.format(project.policy_id))

        targets = None
        if scan_id:
            session = new_session()
            try:
                scan_pipeline = session.query(PipelineScans) \
                    .filter(PipelineScans.next_scan_id == scan_id).first()
            finally:
                session.close()

            if scan_pipeline:
                targets = NmapEngine._get_targets_from_db(scan_pipeline.prev_scan_id)

        # Prepare the exclude list.
        exclude_list = ['_OPENVPNLOOPBACKS_']
        # exclude_list += ['_OEBS_DB_SRV_', '_OEBS_FRONT_SRV_', '_OEBS_STORAGE_SRV_', '_CRMDBSRV_', '_CRMDBTARGETSRV_']
        # exclude_list += ['_BI_PROD_DB_SRV_', '_BI_DEV_DB_SRV_']
        # exclude_list += ['_BALANCEDBSRV_', '_BALANCETARGETSRV_', '_METADBSRV_', '_BALANCENONPRODDBSRV_']

        if not targets:
            targetlist = project.targets
            parts = targetlist.split(' exclude ')
            if len(parts) == 2:
                targets = parts[0].split(', ')
                exclude_list += parts[1].split(', ')
            else:
                targets = targetlist.split(', ')

        # TODO: track scan progress (scanned vs. total targets) per yielded chunk.

        logging.warning("[+] new_tasks_payloads_generator. target_generator(%s, %s)", targets, exclude_list)
        for some_targets in target_generator(targets, exclude_list):
            some_targets = list_of_ip_to_subnet(some_targets)

            payload = json.dumps({
                "engine": project.engine,
                "profile": NmapEngine._prepare_profile(policy, some_targets),
                "save_to_db": project.save_to_db,
            })

            yield (some_targets, payload)

    @staticmethod
    def _overwrite_port_scripts(scripts_dict):
        """Normalize selected NSE script outputs in place.

        Only ``ssh-auth-methods`` is handled. When its output is a string, it is
        replaced by a list of the output's lines, processed as follows:

        * the ``"Supported authentication methods:"`` header line is dropped;
        * each remaining line is stripped;
        * empty lines are dropped.

        Any other value, including one that was already converted, is left as is.

        :param scripts_dict: mapping of script id to output, or any falsy/non-dict value.
        :return: the same object that was passed in (mutated when it is a dict).
        """
        # Results may carry an empty list (or other non-dict) instead of a mapping.
        if not scripts_dict or not isinstance(scripts_dict, dict):
            return scripts_dict

        ssh_auth_methods_output = scripts_dict.get("ssh-auth-methods")
        if ssh_auth_methods_output and isinstance(ssh_auth_methods_output, str):
            scripts_dict["ssh-auth-methods"] = [
                line.strip()
                for line in ssh_auth_methods_output.split("\n")
                if "Supported authentication methods:" not in line and line.strip()
            ]

        return scripts_dict

    @staticmethod
    def scan_results_to_splunk_events(scan_results, task_id, only_enabled=True):
        """Convert raw nmap task results into a flat list of Splunk events.

        The method produces:

        * one event per port entry of every result, whatever its state (``enabled``
          is True only for ``open`` ports);
        * one extra event per host that has ``os_matches``. It carries one
          ``"<accuracy>:<name>"`` line per match in ``scripts["nmap_os_detect"]``.

        :param scan_results:
        [
          {
            "addr": "5.255.255.55",
            "starttime": 1541678339,
            "ports": [
              {
                "state": "open",
                "time": 1541678339,
                "protocol": "tcp",
                "port": 22,
                "scripts": [],
                "service_name": ...,
                "service_product": ...,
                "service_version": ...
              }
            ],
            "os_matches": [{"accuracy": ..., "name": ...}]   # optional
          }
        ]
        :param task_id: id of the DebbyTask. The task's scan, project, policy and
            project tags are loaded from the DB to enrich every event.
        :param only_enabled: currently unused; kept for interface compatibility.
        :return:
        [
          {
            "event_type": "info",
            "projectId": 57,
            "projectName": "test57",
            "engine": ...,
            "logClosed": false,
            "tags": ["INTERNAL", "IPv4", "IPv6"],
            "policyId": ...,
            "dest_ip": "5.255.255.55",
            "resp": null,            # present on port events only
            "protocol": "ipv4",
            "time": 1541678339,
            "dest_port": 22,
            "transport": "tcp",
            "portState": "open",
            "enabled": true,
            "scripts": [],
            "service_name": ...,
            "service_product": ...,
            "service_version": ...,
            "scanId": 411,
            "scanStartTime": 1541678338,
            "taskId": 964
          }
        ]
        """
        session = new_session()
        try:
            task = session.query(DebbyTask).filter(DebbyTask.id == task_id).first()
            scan = session.query(DebbyScan).filter(DebbyScan.id == task.debbyscan_id).first()
            project = session.query(DebbyProject).filter(DebbyProject.id == scan.project_id).first()
            policy = session.query(DebbyPolicy).filter(DebbyPolicy.id == project.policy_id).first()
            tags = session.query(DebbyTag) \
                .filter(RelationProjectTag.project_id == project.id) \
                .filter(RelationProjectTag.tag_id == DebbyTag.id).all()
        finally:
            session.close()

        tag_list = [tag.value for tag in tags]

        # puncher = CachedPuncherAPIClient()

        events = []
        for result in scan_results:
            addr = result.get('addr')

            for port in result.get('ports') or []:
                port_scripts = NmapEngine._overwrite_port_scripts(port.get('scripts'))
                # Puncher enrichment is disabled (its import is commented out):
                # port_scripts.update({
                #     "puncher_allowed_from_inet": puncher.is_allowed_from_inet(
                #         addr, port.get("port"), port.get("protocol")
                #     )
                # })

                events.append({
                    'event_type': 'info',

                    'projectId': project.id,
                    'projectName': project.name,
                    'engine': project.engine,
                    'logClosed': project.log_closed,
                    'tags': tag_list,

                    'policyId': policy.id,

                    'dest_ip': addr,
                    'resp': None,
                    'protocol': addr_to_proto(addr),

                    'time': port.get('time'),
                    'dest_port': port.get('port'),
                    'transport': port.get('protocol'),
                    'portState': port.get('state'),
                    'enabled': port.get('state') == 'open',
                    'scripts': port_scripts,

                    'service_name': port.get('service_name'),
                    'service_product': port.get('service_product'),
                    'service_version': port.get('service_version'),

                    'scanId': scan.id,
                    'scanStartTime': datetime_msk_2_timestamp_utc(scan.create_time),

                    'taskId': task.id,
                })

            osmatches = result.get('os_matches')
            if osmatches:
                # One "<accuracy>:<name>" line per OS guess.
                os_detect = '\n'.join(
                    '{}:{}'.format(osmatch.get('accuracy'), osmatch.get('name'))
                    for osmatch in osmatches
                )

                events.append({
                    'event_type': 'info',

                    'projectId': project.id,
                    'projectName': project.name,
                    'engine': project.engine,
                    'logClosed': project.log_closed,
                    'tags': tag_list,

                    'policyId': policy.id,

                    'dest_ip': addr,
                    'protocol': addr_to_proto(addr),

                    'time': None,
                    'dest_port': None,
                    'transport': None,
                    'portState': None,
                    'enabled': True,
                    'scripts': {'nmap_os_detect': os_detect},

                    'service_name': None,
                    'service_product': None,
                    'service_version': None,

                    'scanId': scan.id,
                    'scanStartTime': datetime_msk_2_timestamp_utc(scan.create_time),

                    'taskId': task.id,
                })

        return events

    @staticmethod
    def cache_ports(task_results, project_name, scan_id, task_id, agent_id):
        """Upsert the open ports found in ``task_results`` into ``PortCache``.

        The scanning agent's tags determine the cache ``location``. Each open port is
        inserted, or its ``last_seen``/``info_source`` is updated when the
        (target, port, transport, location) row already exists. Malformed result
        or port entries are logged and skipped.

        :param task_results: result dicts shaped like the ``scan_results`` input of
            ``scan_results_to_splunk_events``. Each has ``addr``, plus ``ports``
            entries with ``state``/``port``/``protocol``/``time``.
        :param project_name: recorded in ``info_source``.
        :param scan_id: recorded in ``info_source``.
        :param task_id: recorded in ``info_source``. The full value is
            ``"nmap:<project_name>:<scan_id>:<task_id>"``.
        :param agent_id: id of the agent whose tags decide the location.
        :return: number of open-port rows sent to the upsert.
        """
        # The agent's tags tell where it scans from, which becomes the cache location.
        session = new_session()
        try:
            agent_tags = session.query(DebbyTag) \
                .outerjoin(RelationAgentTag, DebbyTag.id == RelationAgentTag.tag_id) \
                .filter(RelationAgentTag.agent_id == agent_id).all()
        finally:
            session.close()
        tags = [tag.value for tag in agent_tags]

        # Highest-priority match wins: YACLOUD > INTERNAL > GUEST > EXTERNAL > UNKNOWN.
        if "YACLOUD" in tags:
            location = "YACLOUD"
        elif "NESSUS" in tags or "INTERNAL" in tags:
            location = "INTERNAL"
        elif "GUEST" in tags:
            location = "GUEST"
        elif "AZURE" in tags or "EXTERNAL" in tags:
            location = "EXTERNAL"
        else:
            location = "UNKNOWN"
        location = portcache.location2code(location)

        # Collect (target, port, transport, last_seen) rows for every open port.
        parse_error_msg = "problem parsing cache_ports. project_name=%s, scan_id=%s, task_id=%s, e=%s"
        caching_results = []
        for result in task_results:
            try:
                addr = result.get('addr')
                ports = list(result.get('ports') or [])
            except Exception as e:  # best effort: skip a malformed result, keep the rest
                logging.exception(parse_error_msg, project_name, scan_id, task_id, e)
                continue

            for port in ports:
                try:
                    if port.get('state') != 'open':
                        continue
                    caching_results.append((
                        addr,
                        port.get('port'),
                        portcache.transport2code(port.get('protocol')),
                        # Naive datetime in the server's local timezone.
                        # NOTE(review): confirm PortCache.last_seen expects local time.
                        datetime.fromtimestamp(port.get('time')),
                    ))
                except Exception as e:  # best effort: skip only the malformed port
                    logging.exception(parse_error_msg, project_name, scan_id, task_id, e)

        if not caching_results:
            return 0

        # Insert, or refresh last_seen on conflict. This uses one statement per row
        # because a multi-row ON CONFLICT DO UPDATE fails in PostgreSQL when the same
        # key appears twice in one statement.
        info_source = "nmap:{}:{}:{}".format(project_name, scan_id, task_id)
        conflict_columns = [PortCache.target, PortCache.port, PortCache.transport, PortCache.location]
        session = new_session()
        try:
            conn = session.connection()
            for target, dest_port, transport, last_seen in caching_results:
                stmt = insert(PortCache).values(
                    target=target, port=dest_port, transport=transport, location=location,
                    last_seen=last_seen, info_source=info_source,
                ).on_conflict_do_update(
                    index_elements=conflict_columns,
                    set_=dict(last_seen=last_seen, info_source=info_source),
                )
                conn.execute(stmt)
            session.commit()
        finally:
            session.close()

        return len(caching_results)
