

from collections import defaultdict
from future.utils import iteritems
from functools import reduce
import json

from app.db.db import new_session
from app.db.models import DebbyAgent, DebbyTag, RelationAgentTag, DebbyScanResults, DebbyScanResultsService
from app.db.models import DebbyScanResultsScripts


# ---------------------------
# --- Task Import results ---
# ---------------------------


# def task_import_results(results, task_uuid, project_name, scan_id):
#     fname = os.path.join('/tmp', task_uuid)
#     with open(fname, "w") as write_file:
#         write_file.write(results)
#         # for result in results:
#         #     write_file.write(json.dumps(result) + '\n')
#
#     # specify config file
#     custom_env = os.environ.copy()
#     custom_env["IVRE_CONF"] = os.path.join('config', 'host_ivre.conf')
#     # BECAUSE CATEGORY LENGTH IS 32
#     categories = task_uuid.replace('-', '') + ',' + project_name
#     args = ['ivre', 'scan2db', '-c', categories, '-s', str(scan_id), fname]
#     process = subprocess.Popen(args, env=custom_env)
#
#     # wait to ensure successful exporting
#     process.wait()

# ---------------------------
# --- Get Available agent ---
# ---------------------------
from app.utils import timestamp_utc_2_datetime_msk, datetime_msk_2_timestamp_utc


def get_available_agent_2(required_tags_list, session=None, any_job_count=False, get_all=False):
    """
    Find agents that carry every required tag and return the least loaded one.

    :param required_tags_list: iterable of tag values (``DebbyTag.value``) that an agent
        must have. ``None`` or an empty iterable means "no tag requirement".
    :param session: optional existing DB session. When omitted, a new session is opened
        and always closed by this function.
    :param any_job_count: if False (default), agents with a non-zero ``max_jobs`` whose
        ``jobs`` has reached it are excluded; if True, job limits are ignored.
    :param get_all: if True, return the list of all matching agents instead of one agent.
    :return: the matching ``DebbyAgent`` with the fewest ``jobs`` (the first one on ties),
        a list of all matching agents when ``get_all`` is True, or ``None`` when no agent
        matches.
    """
    required_tags = list(required_tags_list or ())

    own_session = session is None
    s = new_session() if own_session else session
    try:
        # `.isnot(None)` renders SQL `jobs IS NOT NULL`. A plain Python
        # `DebbyAgent.jobs is not None` is evaluated before SQLAlchemy sees it, yields a
        # constant True and filters nothing, letting NULL-jobs agents reach the numeric
        # comparisons below.
        res = s.query(DebbyAgent, DebbyTag) \
            .filter(DebbyAgent.jobs.isnot(None)) \
            .outerjoin(RelationAgentTag, DebbyAgent.id == RelationAgentTag.agent_id) \
            .outerjoin(DebbyTag, DebbyTag.id == RelationAgentTag.tag_id) \
            .all()
    finally:
        # Only close sessions we created; a caller-supplied session stays open.
        if own_session:
            s.close()

    # Filter by capacity: max_jobs == 0 lets the agent through regardless of its job count.
    if not any_job_count:
        res = [r for r in res
               if r.DebbyAgent.max_jobs == 0 or r.DebbyAgent.jobs < r.DebbyAgent.max_jobs]

    # Group the joined rows into agent -> [tag values]. The outer join yields a row with
    # DebbyTag = None for an untagged agent, recorded here as a None entry.
    agent_tags = defaultdict(list)
    for r in res:
        agent_tags[r.DebbyAgent].append(r.DebbyTag.value if r.DebbyTag else None)

    agents = [agent for agent, tags in iteritems(agent_tags)
              if all(tag in tags for tag in required_tags)]

    if not agents:
        return None

    if get_all:
        return agents

    # min() keeps the first minimal element, the same tie-break as `x if x.jobs <= y.jobs else y`.
    return min(agents, key=lambda agent: agent.jobs)


# ----------------------------------------
# --- Save splunk_event object into db ---
# ----------------------------------------

def save_events_to_db(events):
    """
    Persist scan events as DebbyScanResults rows plus their service info and script output.

    Each event is committed in its own transaction: the result row, its optional
    DebbyScanResultsService row and its DebbyScanResultsScripts rows are written together.
    On any error the current event's transaction is rolled back and the exception is
    re-raised. Events committed before the error stay committed.

    :param events:
        [
          {
            "protocol": "ipv4",
            "event_type": "info",
            "tags": [
              "INTERNAL",
              "IPv4",
              "IPv6"
            ],
            "projectName": "test57",
            "logClosed": false,
            "taskId": 964,
            "scripts": {},
            "transport": "tcp",
            "dest_port": 22,
            "scanId": 411,
            "scanStartTime": 1541678338,
            "enabled": true,
            "portState": "open",
            "time": 1541678339,
            "dest_ip": "5.255.255.55",
            "service_name": ...,
            "service_product": ...,
            "service_version": ...
          },
          ...
        ]
        Only scanId, dest_ip, dest_port, time, transport, enabled, portState, the
        service_* fields and scripts are stored; the other keys are ignored.
    :return: None
    """
    session = new_session()
    try:
        for event in events:
            # Non-integer ports are stored as -1 (read back as 'general' by get_events_from_db).
            dest_port_prep = event['dest_port'] if isinstance(event['dest_port'], int) else -1

            event_time = timestamp_utc_2_datetime_msk(event["time"]) if event["time"] else None
            dsr = DebbyScanResults(scan_id=event['scanId'], ip=event['dest_ip'], port=dest_port_prep,
                                   time=event_time, transport=event['transport'],
                                   enabled=event['enabled'], state=event.get('portState'))
            session.add(dsr)
            # flush() (not commit()) assigns dsr.id for the child rows while keeping the
            # result row and its children in one transaction, so a failure cannot leave
            # a committed result without its service/scripts.
            session.flush()

            if event['service_name'] or event['service_product'] or event['service_version']:
                dsr_service = DebbyScanResultsService(scan_result_id=dsr.id, name=event['service_name'],
                                                      product=event['service_product'],
                                                      version=event['service_version'])
                session.add(dsr_service)

            # Each script value is wrapped as {"value": ...} so any JSON type round-trips
            # through the text column (unwrapped again in get_events_from_db).
            for (key, value) in iteritems(event.get('scripts') or {}):
                value_ = json.dumps({"value": value})
                dsr_script = DebbyScanResultsScripts(scan_result_id=dsr.id, key=key, value=value_)
                session.add(dsr_script)

            session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()


def get_events_from_db(scan_id):
    """
    Load all stored results of a scan and rebuild them as event dicts.

    Reconstructs the subset of event fields that save_events_to_db persists.

    :param scan_id: value of DebbyScanResults.scan_id to load.
    :return: list of dicts, one per result row, with keys ``time`` (UTC timestamp or None),
        ``scanId``, ``dest_ip``, ``dest_port`` (``'general'`` for stored port -1),
        ``transport``, ``enabled``, ``portState``, ``service_name``, ``service_product``,
        ``service_version`` (None when there is no service row) and ``scripts``
        (dict of script key -> decoded value).
    """
    session = new_session()
    try:
        dsr_list = session.query(DebbyScanResults, DebbyScanResultsService) \
            .outerjoin(DebbyScanResultsService, DebbyScanResults.id == DebbyScanResultsService.scan_result_id) \
            .filter(DebbyScanResults.scan_id == scan_id).all()

        dsr_id_list = [row.DebbyScanResults.id for row in dsr_list]
        # Skip the IN query entirely when there are no results to look up.
        if dsr_id_list:
            scripts = session.query(DebbyScanResultsScripts) \
                .filter(DebbyScanResultsScripts.scan_result_id.in_(dsr_id_list)).all()
        else:
            scripts = []
    finally:
        session.close()

    # Index scripts by parent result id once, instead of rescanning the whole script
    # list for every result row. Query order is preserved within each group.
    scripts_by_result = defaultdict(list)
    for script in scripts:
        scripts_by_result[script.scan_result_id].append(script)

    events = []
    for row in dsr_list:
        result = row.DebbyScanResults
        service = row.DebbyScanResultsService  # None when the outer join found no service row

        event = {
            # General
            'time': datetime_msk_2_timestamp_utc(result.time) if result.time else None,
            'scanId': result.scan_id,
            'dest_ip': result.ip,
            'dest_port': result.port if result.port != -1 else 'general',
            'transport': result.transport,
            'enabled': result.enabled,
            'portState': result.state,
            # Service info
            'service_name': service.name if service else None,
            'service_product': service.product if service else None,
            'service_version': service.version if service else None,
            # Scripts: values were stored as JSON {"value": ...}; a later duplicate key wins.
            'scripts': {script.key: json.loads(script.value)["value"]
                        for script in scripts_by_result[result.id]},
        }
        events.append(event)

    return events


def delete_events_for_scan(scan_id):
    """
    Delete every DebbyScanResults row belonging to ``scan_id`` and commit.

    On failure the transaction is rolled back and the exception re-raised; the session
    is always closed.

    NOTE(review): this is a bulk ``Query.delete()``, which does not apply ORM-level
    relationship cascades. Child DebbyScanResultsService / DebbyScanResultsScripts rows
    are only removed if the database foreign keys cascade on delete — confirm against
    the schema.

    :param scan_id: value of DebbyScanResults.scan_id whose results are removed.
    """
    s = new_session()
    try:
        s.query(DebbyScanResults).filter(DebbyScanResults.scan_id == scan_id).delete()
        s.commit()
    except Exception:
        s.rollback()
        raise
    finally:
        s.close()
