# -*- coding: utf-8 -*-


import traceback
import os
import socket
import json
import zlib
import time
import subprocess
import random
import base64
import stat
import hashlib
import tempfile
import urllib.parse
import datetime

from logging import getLogger

import requests
from lxml import etree

from django.conf import settings
from django.template.loader import render_to_string
from django.urls import reverse
from django.db import transaction

from celery.worker.control import Panel
from celery.exceptions import SoftTimeLimitExceeded
from celery.signals import task_failure

from startrek_client import Startrek

from .celery_app import app

from .models import Vulnerability, VulnTicket, VulnerabilityType, HTTPTransaction, HTTPHeader, \
    RequestSamples, Scan

from .notification.send_mail import send_message
from .notification.solomon import Solomon
from .notification.metrics import Metrics
from .notification.errors import catch_error, ErrorTracer
from .task_start_burp import RemoteBurpTask
from .utils import upload_to_elliptics, killpstree, resolve_sandbox_resource, parse_balancer_log, parse_molly_json, \
    parse_burp_xml_log, parse_balancer_log_tmp, pg_wrap, is_l7, is_resolvable, log_vuln_to_splunk, \
    prepare_skynet_resource, read_from_mds, has_ipv6, is_s3, parse_datetime_from_startrack


logger = getLogger(__name__)


@app.task(bind=True, ignore_result=True, soft_time_limit=180)
def send_scan_result_notification(self, scan):
    """
    Email the vulnerabilities found by ``scan`` to the interested users.

    Recipients are the users attached to the scan or, when the scan has none,
    the users of its target. Production scans additionally notify the
    security alerts mailbox.

    Returns True when a message was sent and False when sending was skipped
    (rate limit hit, no recipients, scan not vulnerable or no sensitive
    vulnerabilities).
    """
    # Rate limit for production scans: skip when the target has already been
    # mailed within the last settings.EMAIL_PER_TARGET_DELAY hours.
    if scan.is_prod and scan.target.last_email:
        window_start = datetime.datetime.now() - datetime.timedelta(hours=settings.EMAIL_PER_TARGET_DELAY)
        if window_start < scan.target.last_email:
            return False

    # Users of the scan take precedence over users of the target.
    recipients = scan.users if scan.users.count() else scan.target.users
    emails = [user.email for user in recipients.all()]

    if scan.is_prod:
        emails.append('security-alerts@yandex-team.ru')

    if not emails or not scan.is_vulnerable or not scan.has_sensitive_vulns:
        return False

    # Only vulnerabilities above the informational level are listed, most severe first.
    vulnerabilities = Vulnerability.objects.filter(
        scan=scan,
        vuln_type__severity__gt=VulnerabilityType.SEVERITY_INFORMATION,
    ).order_by('-vuln_type__severity')

    send_message(to_email=emails,
                 template='notifications/new_vuln_notification.html',
                 subject='В сервисе "%s" обнаружены уязвимости' % scan.target.name,
                 data={'vulnerabilities': vulnerabilities, 'scan': scan},
                 sender_email=settings.EMAIL_HOST_USER)

    # Remember the send time so the rate limit above can be applied next time.
    scan.target.last_email = datetime.datetime.now()
    scan.target.save()
    return True


@app.task(bind=True, ignore_result=True, soft_time_limit=180)
def send_scan_start_notification(self, scan):
    """
    Tell the users of the scan's target that a scan has been started.

    Returns True when the message was sent and False when the target has no
    users to notify.
    """
    emails = [user.email for user in scan.target.users.all()]
    if not emails:
        return False

    subject = '[Molly] Запущено сканирование сервиса "{}"'.format(scan.target.name)
    send_message(to_email=emails,
                 template='notifications/new_scan_notification.html',
                 subject=subject,
                 data={'scan': scan},
                 sender_email=settings.EMAIL_HOST_USER)
    return True


@app.task(bind=True, ignore_result=True, soft_time_limit=90)
def send_scan_stats_to_solomon(self, scan):
    """
    Push the per-method response statistics of ``scan`` to Solomon.

    Every (method, status code) pair becomes a sensor named
    ``<method>_<status_code>`` whose value is the response count; the target
    id is passed along with each sensor. Nothing is sent when the scan has
    no statistics.
    """
    stats = scan.get_scan_response_stats()
    if not stats:
        return

    client = Solomon(project=settings.SOLOMON_PROJECT_ID,
                     cluster=settings.SOLOMON_CLUSTED_ID,
                     service=settings.SOLOMON_METRICS_SERVICE,
                     token=settings.SOLOMON_OAUTH_TOKEN)

    for method_stats in stats:
        method = method_stats.get('method')
        for response_stats in method_stats.get("responses"):
            sensor = '{}_{}'.format(method, response_stats.get("status_code"))
            client.push_sensor(sensor, response_stats.get("count", 0), str(scan.target.id))

    # Sensors are only queued by push_sensor(); flush() is called once at the end.
    client.flush()


def _first_node(element, tag):
    """Return the first node named ``tag`` in the subtree of ``element``, or None if there is none."""
    return next(element.iter(tag=tag), None)


def _parse_http_message(raw_message):
    """Split a raw HTTP message into (start line, list of header lines, body)."""
    head, _, body = raw_message.partition('\r\n\r\n')
    lines = head.split('\r\n')
    return pg_wrap(lines[0], 2048), list(map(pg_wrap, lines[1:])), body


def _save_http_headers(header_lines, relation):
    """Store every ``Name: value`` line as an HTTPHeader and attach it to ``relation``."""
    for line in header_lines:
        name, _, value = line.partition(':')
        header = HTTPHeader(name=name.strip()[:1024], value=value.strip()[:1024])
        header.save()
        relation.add(header)


def _save_http_transaction(request_response):
    """Persist one decoded request/response pair and return the saved HTTPTransaction."""
    request_line, request_headers, request_body = _parse_http_message(request_response.get('request', ''))
    status_line, response_headers, response_body = _parse_http_message(request_response.get('response', ''))

    # Payloads were decoded as latin-1, so encoding them back restores the original bytes.
    trans = HTTPTransaction(request_line=request_line,
                            request_body=request_body.encode("latin-1"),
                            status_line=status_line,
                            response_body=response_body.encode("latin-1"))
    trans.save()

    # The header lists already exclude the start line, so every line is stored.
    _save_http_headers(request_headers, trans.request_headers)
    _save_http_headers(response_headers, trans.response_headers)
    return trans


def _merge_combined_issues(issues):
    """
    Collapse issues of one "combine" vulnerability type into a single issue.

    The description is the details of the first issue followed by the list of
    affected URLs of all issues; request samples, severity and host are taken
    from the last issue.
    """
    first, last = issues[0], issues[-1]

    details = first.get('details', '')
    details += "<br>\n<b>URLs having same problem:</b> \n"
    for issue in issues:
        for rr in issue.get('request_info'):
            request_line = rr.get('request', '').split('\r\n\r\n', 1)[0].split('\r\n')[0]
            parts = request_line.split(' ')
            # A request line is "<method> <uri> <version>"; skip anything shorter.
            if len(parts) < 3:
                continue
            details += "\n"
            details += "* " + issue.get('host', '') + ' '.join(parts[1:-1])

    return {
        'details': details,
        'request_info': last.get('request_info', []),
        'severity': last.get('severity'),
        'host': last.get('host', ''),
        'vulntype': first.get('vulntype'),
    }


def save_burp_vulnerabilities(scan, xml_report):
    """
    Parse a Burp XML report and store its issues as vulnerabilities of ``scan``.

    For every stored vulnerability the sample HTTP transactions are saved, the
    vulnerability is logged to Splunk and, unless it is a known false positive
    or was already triaged, a tracker ticket task is queued (at most one per
    vulnerability type name, only when the scan has no report ticket and the
    type is severe enough). Per-severity issue counters are sent as metrics.

    :param scan: Scan instance the vulnerabilities belong to.
    :param xml_report: Burp report as XML text/bytes.
    :return: True on success; False when anything failed, in which case the
        error and traceback are written to ``scan.result_message``.
    """
    MAX_TRANS_SAMPLES = 10
    MAX_SIMILAR_VULNS = 10

    stats = dict()
    notified = set()
    try:
        parser = etree.XMLParser(encoding="utf-8", recover=True, dtd_validation=False,
                                 resolve_entities=False, no_network=True)
        tree = etree.fromstring(xml_report, parser=parser)

        # Phase 1: read the issues and group them by "<name>:<severity>".
        grouped_issues = dict()
        for issue in tree.iter(tag='issue'):
            issue_name = list(issue.iter(tag='name'))[0].text[:80]
            severity_node = _first_node(issue, 'severity')
            issue_severity = 'unknown' if severity_node is None else severity_node.text.lower()

            # Counted before any filtering, so ignored/archived types are included.
            stat_name = "vuln_" + issue_severity
            stats[stat_name] = stats.get(stat_name, 0) + 1

            vuln_type, _created = VulnerabilityType.objects.get_or_create(name=issue_name,
                                                                          scanner_severity=issue_severity)
            if vuln_type.severity == VulnerabilityType.SEVERITY_IGNORE or vuln_type.archived:
                continue

            issue_key = '{}:{}'.format(issue_name, issue_severity)
            similar_issues = grouped_issues.setdefault(issue_key, [])

            issue_info = {
                'name': issue_name,
                'severity': issue_severity,
                'vulntype': vuln_type,
                'request_info': [],
            }

            host_node = _first_node(issue, 'host')
            if host_node is not None:
                issue_info['host'] = host_node.text

            details_node = _first_node(issue, 'issueDetail')
            if details_node is not None:
                issue_info['details'] = details_node.text

            for pair_node in issue.findall('./requestresponse'):
                # Keep at most MAX_TRANS_SAMPLES request/response samples per issue.
                if len(issue_info['request_info']) >= MAX_TRANS_SAMPLES:
                    break
                rr = dict()
                for part in ('request', 'response'):
                    for node in pair_node.findall('./' + part):
                        if node.text:
                            rr[part] = base64.b64decode(node.text).decode('latin-1')
                issue_info['request_info'].append(rr)

            # Keep at most MAX_SIMILAR_VULNS issues of the same name and severity.
            if len(similar_issues) < MAX_SIMILAR_VULNS:
                similar_issues.append(issue_info)

        # Phase 2: "combine" vulnerability types are reported as one merged issue.
        final_issues = []
        for issues in grouped_issues.values():
            if issues[0]['vulntype'].combine:
                final_issues.append(_merge_combined_issues(issues))
            else:
                final_issues.extend(issues)

        # Phase 3: persist the vulnerabilities and queue tickets.
        for issue in final_issues:
            vuln_type = issue['vulntype']
            vuln = Vulnerability(scan=scan, vuln_type=vuln_type,
                                 description=issue.get('details', ''))
            vuln.save()

            for rr in issue.get('request_info'):
                vuln.http_details.add(_save_http_transaction(rr))

            log_vuln_to_splunk(vuln)

            if vuln.was_marked_as_fp_in_past():
                vuln.mark_as_fp()
                continue

            if vuln.is_false_positive or vuln.was_triaged_in_past():
                continue

            # De-duplicate by the name of the vulnerability type being saved
            # (the type was looked up by the issue name in phase 1).
            if not scan.report_ticket and vuln_type.name not in notified \
                    and vuln_type.severity >= scan.min_ticket_severity \
                    and vuln_type.tracker_severity:
                create_vuln_ticket.apply_async(args=(vuln,), queue='scan_result_parser')
                notified.add(vuln_type.name)

        # NOTE(review): metrics_client is not defined in the visible part of this
        # module - presumably a module-level Metrics instance; confirm.
        for name, count in stats.items():
            metrics_client.try_send_metric(name, count=count)

    except Exception as e:
        catch_error()
        formatted_lines = traceback.format_exc()
        # Both values must be passed as one tuple to the two-placeholder format.
        scan.result_message = ('Internal Molly error while saving '
                               'vulnerabilies: %s\nDebug:\n%s' % (str(e), formatted_lines))
        scan.save()
        logger.error('Error in file scanner_run:save_vulnerabilities: '
                     'Cannot parse xml report for scan %s: %s' % (scan.id, str(e)))

        return False

    return True


def fail_scan(scan, message, agent_host=''):
    """
    Mark ``scan`` as failed and save it.

    ``message`` is appended to the existing result message on a new line, the
    finish time is set to now and, when ``agent_host`` is given, it is stored
    on the scan as well.
    """
    scan.result_message = '\n'.join((scan.result_message, message))
    scan.finish = datetime.datetime.now()
    scan.status = Scan.ST_FAIL
    if agent_host:
        scan.agent_host = agent_host
    scan.save()


@app.task(bind=True, ignore_result=True, soft_time_limit=60)
def create_vuln_ticket(self, vuln):
    """
    Create a Startrek ticket for ``vuln`` and link it to the vulnerability.

    For "combine" vulnerability types an already existing ticket is reused
    when process_combine_vuln() reports that it handled the case. Otherwise
    the ticket is created in the target's queue, falling back to
    settings.ST_DEFAULT_QUEUE when that fails. After creation the security
    fields are set, a VulnTicket record is linked to the vulnerability and a
    webhook for the ticket is registered.

    Tracker errors are reported through catch_error() and are not raised.
    """
    st_queue = vuln.scan.target.st_queue or settings.ST_DEFAULT_QUEUE

    tags = ["crasher" if vuln.scan.is_prod else "molly"]
    if vuln.vuln_type.st_tags:
        # st_tags is a comma separated list; blank entries are ignored.
        tags.extend(tag.strip() for tag in vuln.vuln_type.st_tags.split(',') if tag.strip())

    report = {
        'issue': vuln,
        'scan': vuln.scan,
        'report_url': settings.APP_URL + reverse('show_report', args=[vuln.scan.uid]),
    }
    comment = render_to_string('notifications/vuln_ticket.txt', {'report': report})

    followers = []
    if vuln.vuln_type.severity >= VulnerabilityType.SEVERITY_HIGH:
        followers = settings.ST_OAUTH_USERNAMES
    severity = vuln.vuln_type.tracker_severity

    st = Startrek(useragent=settings.ST_USER_AGENT, base_url=settings.ST_URL, token=settings.ST_OAUTH_TOKEN)

    if vuln.vuln_type.combine and process_combine_vuln(tracker=st, vuln=vuln, comment=comment):
        return

    # Try the service queue first, then the default queue (when it is a different one).
    queues = [st_queue]
    if st_queue != settings.ST_DEFAULT_QUEUE:
        queues.append(settings.ST_DEFAULT_QUEUE)

    issue = None
    # The name bound by "except ... as" is unbound when the clause ends, so the
    # last error is kept explicitly for the final log message.
    last_error = None
    for queue in queues:
        try:
            issue = st.issues.create(
                queue=queue,
                type={'name': 'Bug'},
                summary=vuln.vuln_type.summary,
                description=comment,
                followers=followers,
                tags=tags,
            )
        except Exception as e:
            last_error = e
            catch_error({"tags": tags, "queue": queue, "followers": followers})
        else:
            break

    if not issue:
        logger.error('Ticket creation error (%s), all attempts failed.' % str(last_error))
        return

    # Re-read the issue from the tracker before updating it.
    try:
        issue = st.issues[issue.key]
    except Exception:
        catch_error()
        return

    issue.update(security='Yes', securitySeverity=str(severity), ignore_version_change=True)
    vt, created = VulnTicket.objects.get_or_create(tracker_type=VulnTicket.TT_ST, ticket_id=issue.key)
    vuln.tracker_tickets.add(vt)
    if not created:
        # The webhook is registered only together with a new VulnTicket record.
        return

    try:
        hook = issue.webhooks.create(endpoint=settings.APP_URL + reverse('st_webhook_handler'),
                                     filter={"key": issue.key})
    except Exception:
        catch_error({"ticket": issue.key})
        return

    vt.webhook_url = hook.self
    vt.save()


def process_combine_vuln(tracker=None, vuln=None, comment=None):
    """
    Try to reuse an already created ticket for a "combine" vulnerability.

    Looks through the tickets linked to vulnerabilities of the same type on
    the same target. For the first one that is still being triaged, its
    status is refreshed from the tracker and ``comment`` is added unless the
    ticket was updated less than 7 days ago.

    Return True if the comment was created, no comment is needed right now
    or a tracker error has occurred (the error is reported via catch_error()).
    Return False if no suitable ticket was found.
    """
    existing_tickets = VulnTicket.objects.filter(
        vulnerabilities__scan__target=vuln.scan.target,
        vulnerabilities__vuln_type=vuln.vuln_type
    )

    for ticket in existing_tickets:
        if not ticket.is_triaging or not ticket.ticket_id:
            continue

        # try to resolve ticket
        try:
            issue = tracker.issues[ticket.ticket_id]
        except Exception:
            catch_error({"ticket": ticket.ticket_id})
            return True

        if issue is None:
            continue

        status = None
        if issue.status is not None:
            status = ticket.get_status(issue.status.key.lower())

        resolution = None
        if issue.resolution is not None:
            resolution = issue.resolution.key.lower()

        ticket.update_status(
            status=status,
            resolution=resolution
        )

        # A recently updated ticket does not need one more comment.
        updated_at = parse_datetime_from_startrack(issue.updatedAt)
        if datetime.datetime.now() - updated_at < datetime.timedelta(days=7):
            return True

        # The triaging state may have changed after the status update above.
        if ticket.is_triaging:
            try:
                issue.comments.create(text=comment)
            except Exception:
                catch_error({"ticket": ticket.ticket_id})
            return True

    return False



@app.task(bind=True, ignore_result=True, soft_time_limit=60)
def send_report_to_ticket(self, scan):
    """
    Post the rendered scan report as a comment to the scan's report ticket.

    Tracker errors are reported via catch_error(), logged and not raised.
    """
    report = {
        'status': 'done',
        'vulnerabilities': scan.vulnerabilities_list,
        'report_url': settings.APP_URL + reverse('show_report', args=[scan.uid]),
        'is_vulnerable': scan.is_vulnerable,
    }
    comment = render_to_string('notifications/report.txt', {'report': report})

    try:
        tracker = Startrek(useragent=settings.ST_USER_AGENT, base_url=settings.ST_URL,
                           token=settings.ST_OAUTH_TOKEN)
        tracker.issues[scan.report_ticket].comments.create(text=comment)
    except Exception as err:
        catch_error()
        logger.error('Reporting to ticket error (%s)' % str(err))


@app.task(bind=True, ignore_result=True)
def parse_scan_report(self, scan_result, scan_id):
    """
    Process the result of a finished scanner run for the scan ``scan_id``.

    ``scan_result`` is a ``(scan_report_meta, scan_log, was_aborted,
    agent_host)`` tuple. ``scan_report_meta`` is either the URL of the
    uploaded report or an error marker such as ``'<no such host>'``; any
    error fails the scan with an explanatory message. Otherwise the report is
    downloaded, decompressed and saved as vulnerabilities, the scan status is
    set to aborted, failed or done, and the follow-up tasks (Solomon stats,
    ticket report, result email) are queued.
    """
    scan = Scan.objects.get(pk=scan_id)
    scan_report_meta, scan_log, was_aborted, agent_host = scan_result

    # Error markers returned instead of a report URL and the message shown for each.
    failure_messages = {
        '<max_size_exceeded>': 'Scanner returned err #3 (max size)',
        '<report read error>': 'Error reading report - scan was not started?',
        '<no such host>': 'Target DNS resolve failed. Check if target domain exists.',
        '<url format error>': 'Target URL format error.',
        '<target unavailable>': 'Connection to target failed. '
                                'Check if target is accessible from _SECURITY_MTN_NETS_',
        '<json config error>': 'Scan config generation error, please contact product-security@.',
        '<config error>': 'Scan config generation error, please contact product-security@.',
    }

    if not scan_report_meta:
        fail_scan(scan, 'Scanner returned err #1', agent_host)
        return

    failure_message = failure_messages.get(scan_report_meta)
    if failure_message is not None:
        # agent_host is recorded for every failure so the worker can be identified.
        fail_scan(scan, failure_message, agent_host)
        return

    # process was stopped by user
    if scan.status == Scan.ST_FAIL:
        return

    scan.scanner_report_url = scan_report_meta
    if not scan_report_meta.startswith(settings.ELLIPTICS_PROXY.get('read')):
        fail_scan(scan, 'Incorrect scan report url', agent_host)
        return

    compressed_scan_report = read_from_mds(scan_report_meta)
    if not compressed_scan_report:
        fail_scan(scan, 'Scanner report upload to mds error', agent_host)
        return

    # process terminated successfully
    # NOTE(review): the original author left an open question here about whether
    # the whole report should be kept in memory or filtered during report generation.
    try:
        scan_report = zlib.decompress(compressed_scan_report)
    except Exception:
        catch_error()
        fail_scan(scan, '\nerror while decompressing Burp scan report', agent_host)
        return

    if not scan_report:
        fail_scan(scan, '\nBurp returned empty result', agent_host)
        return

    if not save_burp_vulnerabilities(scan, scan_report):
        fail_scan(scan, '\nBurp returned err #2', agent_host)
        return

    finish_time = datetime.datetime.now()
    scan.finish = finish_time
    scan.scan_log = scan_log

    if agent_host:
        scan.agent_host = agent_host

    if was_aborted:
        scan.status = Scan.ST_ABORTED
    elif 'Error' in scan.result_message:
        scan.status = Scan.ST_FAIL
    else:
        scan.status = Scan.ST_DONE

    target = scan.target
    if target:
        target.last_scan = finish_time
        target.save()

    scan.save()

    logger.info('Trying to send scan {} stats to Solomon'.format(scan.uid))
    send_scan_stats_to_solomon.apply_async(args=(scan,), queue='scan_result_parser')

    if scan.report_ticket:
        logger.info('Sending scan {} report to ticket'.format(scan.uid))
        send_report_to_ticket.apply_async(args=(scan,), queue='scan_result_parser')

    # Non-production scans are mailed only when explicitly requested.
    if not scan.is_prod and not scan.send_mail:
        return

    send_scan_result_notification.apply_async(args=(scan,), queue='scan_result_parser')


def _override_host_header(rec, vhost):
    """Replace the value of every ``Host`` entry in ``rec['headers']`` with ``vhost``."""
    if 'headers' in rec:
        rec['headers'] = [
            {'Name': 'Host', 'Value': vhost} if item.get('Name', '').lower() == 'host' else item
            for item in rec['headers']
        ]
    return rec


def _write_json_lines(records, report_path):
    """
    Write ``records`` as JSON lines to a new temporary file in ``report_path``.

    Returns the path of the written file. When writing fails the partially
    written file is removed and the error is re-raised.
    """
    # Text mode is required: json.dumps() returns str, while the default mode
    # of NamedTemporaryFile is binary.
    out_fd = tempfile.NamedTemporaryFile(mode='w', dir=report_path, delete=False)
    try:
        with out_fd:
            for rec in records:
                out_fd.write(json.dumps(rec) + '\n')
    except Exception:
        os.unlink(out_fd.name)
        raise
    return out_fd.name


def prepare_resource_as_filepath(resource_url, sample_format, report_path, vhost):
    """
    Download request samples and convert them to a JSON-lines file.

    :param resource_url: ``http(s)://`` URL or ``rbtorrent:`` resource id.
    :param sample_format: one of the ``RequestSamples.FMT_*`` constants.
    :param report_path: directory where the temporary files are created.
    :param vhost: value that replaces the ``Host`` header of FMT_JSON samples.
    :return: path of the prepared samples file, or '' when the download or the
        conversion failed or the format is not supported. FMT_JSON2 samples
        are returned as downloaded; for the other formats the downloaded file
        is removed after a successful conversion.
    """
    prepared_samples_path = ''
    if resource_url.startswith(('https://', 'http://')):
        try:
            r = requests.get(resource_url, verify=settings.CA_FILE)
            with tempfile.NamedTemporaryFile(dir=report_path, delete=False) as fd:
                prepared_samples_path = fd.name
                for chunk in r.iter_content(chunk_size=128):
                    fd.write(chunk)
        except Exception as e:
            catch_error()
            logger.error('HTTP resource %s download error (%s)' % (resource_url, str(e)))
            return ''

    if resource_url.startswith('rbtorrent:'):
        prepared_samples_path = prepare_skynet_resource(resource_url, report_path)

    if not prepared_samples_path:
        logger.error('Sandbox resource skynet download error (%s)' % resource_url)
        return ''

    if sample_format == RequestSamples.FMT_JSON2:
        return prepared_samples_path

    # sample format -> (function producing records from a file path, error message template)
    converters = {
        RequestSamples.FMT_JSON: (
            lambda path: (_override_host_header(rec, vhost) for rec in parse_molly_json(path)),
            'Request samples parsing error: %s',
        ),
        RequestSamples.FMT_BURP: (parse_burp_xml_log, 'Burst request samples parsing error: %s'),
        RequestSamples.FMT_SERP: (parse_balancer_log, 'SERP request samples parsing error: %s'),
        RequestSamples.FMT_SERP2: (parse_balancer_log_tmp, 'SERP2 Request samples parsing error: %s'),
    }

    converter = converters.get(sample_format)
    if converter is None:
        return ''

    read_records, error_template = converter
    try:
        out_path = _write_json_lines(read_records(prepared_samples_path), report_path)
        # The raw download is no longer needed once it has been converted.
        os.unlink(prepared_samples_path)
        return out_path
    except Exception as e:
        catch_error()
        logger.error(error_template % str(e))
        return ''


#TODO: move to separate class
@app.task(bind=True, soft_time_limit=7200)
def remote_burp_call(self, txt_config, work_dir, profile_file, req_samples_url='',
                     sample_format=RequestSamples.FMT_JSON, ignore_time_limit=False):
    """Run a Burp scan on this worker and upload its report and HTTP log.

    The task validates the target, writes the scan profile and a per-run Burp
    project config, starts Burp and (once its proxy is up) the request
    repeater and gobuster, waits for Burp to exit, then uploads the results.

    :param txt_config: scan profile as a JSON string; the "burp-active-scanner"
        section and "scan_uid" key are read from it.
    :param work_dir: directory for temporary files (created if missing).
    :param profile_file: path the final profile is written to; its name
        without extension is the prefix for the .xml/.http/.pid files.
    :param req_samples_url: optional request samples location, passed to
        prepare_resource_as_filepath.
    :param sample_format: one of the RequestSamples.FMT_* constants.
    :param ignore_time_limit: when True, SoftTimeLimitExceeded does not abort
        the wait loop.
    :return: tuple (scan_report_url, scan_log_url, was_aborted,
        worker_hostname). On early failure the first element is an error
        marker such as '<json config error>' and the second is ''.
    """
    # XXX: ugly!
    was_aborted = False
    try:
        if not os.path.isdir(work_dir):
            os.makedirs(work_dir)
            os.chmod(work_dir, 0o777)
    except Exception:
        pass

    worker_hostname = socket.gethostname()
    report_tpl = os.path.splitext(profile_file)[0]
    try:
        config = json.loads(txt_config)
    except Exception:
        # `config` is unbound when parsing fails, so no scan uid can be reported here
        self.update_state(state="FAILURE", meta={'uid': None, 'worker': worker_hostname})
        return '<json config error>', '', False, worker_hostname

    active_scanner_config = config.get("burp-active-scanner", {})

    self.update_state(state="RECEIVED", meta={'uid': config.get("scan_uid"), 'worker': worker_hostname})

    parsed_initial_url = urllib.parse.urlparse(active_scanner_config.get('initial_url', ''))
    if not parsed_initial_url.netloc:
        self.update_state(state="FAILURE", meta={'uid': config.get("scan_uid"), 'worker': worker_hostname})
        return '<url format error>', '', False, worker_hostname

    if not is_resolvable(active_scanner_config.get('initial_url')):
        self.update_state(state="FAILURE", meta={'uid': config.get("scan_uid"), 'worker': worker_hostname})
        return '<no such host>', '', False, worker_hostname

    # availability probe: any response counts, only a connection-level error fails the scan
    try:
        requests.get(active_scanner_config.get('initial_url'),
                     timeout=5, allow_redirects=False, verify=False)
    except Exception:
        self.update_state(state="FAILURE", meta={'uid': config.get("scan_uid"), 'worker': worker_hostname})
        return '<target unavailable>', '', False, worker_hostname

    #XXX: temporary hack for SEARCHPRODINCIDENTS-3568
    if is_l7(active_scanner_config.get('initial_url', '')):
        active_scanner_config['complex_domain'] = True
        if active_scanner_config.get("qs_parameters"):
            active_scanner_config["qs_parameters"] += '&i-am-a-hacker=1'
        else:
            active_scanner_config["qs_parameters"] = 'i-am-a-hacker=1'

    active_scanner_config['report_path'] = report_tpl + '.xml'
    if req_samples_url:
        active_scanner_config['sample_requests'] = prepare_resource_as_filepath(req_samples_url, sample_format,
                                                                                work_dir,
                                                                                parsed_initial_url.netloc)

    active_scanner_config['scan_log_path'] = report_tpl + '.http'

    config['burp-active-scanner'] = active_scanner_config
    with open(profile_file, 'w') as configfile:
        configfile.write(json.dumps(config))
    # TODO: correct permissions
    os.chmod(profile_file, stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH)
    env = os.environ.copy()
    env['MOLLY_CONFIG'] = profile_file
    # phantomjs requires it
    env['QT_QPA_PLATFORM'] = 'offscreen'

    ratelimit = 0
    # Random Proxy Port for multiprocess Burp
    proxy_port = random.randrange(1025, 65535)
    #XXX: stderr and stdout here for debug only
    with open('/tmp/stderr.{}'.format(proxy_port), 'wb+') as stderr,\
            open('/tmp/stdout.{}'.format(proxy_port), 'wb+') as stdoutlog:
        stdout = subprocess.PIPE
        try:
            config_fd = tempfile.NamedTemporaryFile(mode="w+", dir=work_dir, delete=False)
            with open('/usr/lib/yandex/burp/burp_project_config.json', 'r') as fd:
                burp_config = json.load(fd)
                burp_config['proxy']['request_listeners'] = []
                burp_config['proxy']['request_listeners'].append({
                    "certificate_mode": "per_host",
                    "listen_mode": "loopback_only",
                    "listener_port": proxy_port,
                    "running": True,
                })
                if active_scanner_config.get('throttle'):
                    scanner_config = burp_config["scanner"].get("active_scanning_engine", {})
                    scanner_config["do_throttle"] = True
                    scanner_config["number_of_threads"] = 1
                    scanner_config["throttle_random"] = False
                    scanner_config["pause_before_retry_on_failure"] = active_scanner_config.get('throttle')
                    burp_config["scanner"]["active_scanning_engine"] = scanner_config
                    # NOTE(review): true division yields a float (e.g. "120.0") that is later
                    # passed to gobuster via -L -- confirm the tool accepts that form
                    ratelimit = 60000 / scanner_config["pause_before_retry_on_failure"]
                if active_scanner_config.get('collaborator_server'):
                    misc_config = burp_config["project_options"].get("misc", {})
                    misc_config["collaborator_server"] = active_scanner_config.get('collaborator_server')
                    burp_config["project_options"]["misc"] = misc_config
                json.dump(burp_config, config_fd)
                config_fd.close()
        except Exception:
            self.update_state(state="FAILURE", meta={'uid': config.get("scan_uid"), 'worker': worker_hostname})
            return '<config error>', '', False, worker_hostname
        else:
            self.update_state(state="STARTED", meta={'uid': config.get("scan_uid"), 'worker': worker_hostname})

            #XXX: move to config!
            burp_process = subprocess.Popen(['/usr/bin/java', '-jar', '-Xmx2048m',
                                             '-Djava.net.preferIPv6Addresses=true', '-Djava.awt.headless=true',
                                             '/usr/lib/yandex/burp/burpsuite_pro.jar',
                                             '--user-config-file=/usr/lib/yandex/burp/burp_user_config.json',
                                             '--config-file={config_path}'.format(config_path=config_fd.name)],
                                            stdout=stdout, stderr=stderr, env=env)
            # the pid file is what the kill_scanner control command reads to stop the scan
            with open(report_tpl + '.pid', 'w+') as fd:
                fd.write(str(burp_process.pid))

            proxy_started = False
            browser_process = None
            browser_running = False
            dirbuster_process = None
            dirbuster_running = False
            while burp_process.returncode is None:
                # wait
                try:
                    time.sleep(settings.SCANNER_POLL_PERIOD)
                    # get process status
                    burp_process.poll()
                    # check burp stdout to catch proxy status
                    line = burp_process.stdout.readline()
                    while line:
                        if b'proxy service started' in line.lower():
                            proxy_started = True
                            break
                        stdoutlog.write(line)
                        stdoutlog.flush()
                        line = burp_process.stdout.readline()
                    if proxy_started and (not browser_running or not dirbuster_running):
                        logger.info('Burp proxy should be available now at port {}, starting external utilities'.format(proxy_port))
                        time.sleep(settings.SCANNER_POLL_PERIOD)
                        if not browser_running:
                            req_samples_path = '/tmp/nonexistent'
                            if active_scanner_config.get('sample_requests'):
                                req_samples_path = active_scanner_config.get('sample_requests')

                            with open(tempfile.mktemp(".txt", "ph_stderr_", '/tmp'), 'wb+') as br_stderr,\
                                    open(tempfile.mktemp(".txt", "ph_stdout_", '/tmp'), 'wb+') as br_stdout:
                                browser_process = subprocess.Popen(['/usr/lib/yandex/burp/molly_repeater',
                                                                    '-f', req_samples_path, '-u',
                                                                    '://'.join([parsed_initial_url.scheme,
                                                                               parsed_initial_url.netloc]),
                                                                    '-proxy',
                                                                    '127.0.0.1:{port:d}'.format(port=proxy_port),
                                                                    '-r', '/usr/lib/yandex/repeater/resource/render.js'
                                                                    ], env=env, stdout=br_stdout, stderr=br_stderr)
                            browser_running = True
                        if not dirbuster_running:
                            with open(tempfile.mktemp(".txt", "gb_stderr_", '/tmp'), 'wb+') as gb_stderr,\
                                    open(tempfile.mktemp(".txt", "gb_stdout_", '/tmp'), 'wb+') as gb_stdout:
                                dirbuster_process = subprocess.Popen(['/usr/lib/yandex/burp/gobuster', '-w',
                                                                      '/usr/lib/yandex/burp/fuzz.txt', '-k',
                                                                      '-H', '{"X-Dirbuster-Code":"200"}',
                                                                      '-u', '://'.join([parsed_initial_url.scheme,
                                                                                        parsed_initial_url.netloc]),
                                                                      '-t', '1',
                                                                      '-L', str(ratelimit),
                                                                      '-p',
                                                                      'http://127.0.0.1:{port:d}'.format(port=proxy_port)
                                                                      ], env=env, stdout=gb_stdout, stderr=gb_stderr)
                            dirbuster_running = True
                    else:
                        if browser_process:
                            try:
                                browser_process.poll()
                            except Exception:
                                pass
                        if dirbuster_process:
                            try:
                                dirbuster_process.poll()
                            except Exception:
                                pass
                except SoftTimeLimitExceeded:
                    # with ignore_time_limit set the loop keeps waiting for Burp to finish
                    if not ignore_time_limit:
                        was_aborted = True
                        break

        # stop whatever is still running: helpers first, Burp itself last
        if browser_process and browser_process.returncode is None:
            try:
                killpstree(browser_process.pid)
                browser_process.poll()
            except Exception as e:
                logger.info('Browser process kill result: %s' % str(e))

        if dirbuster_process and dirbuster_process.returncode is None:
            try:
                killpstree(dirbuster_process.pid)
                dirbuster_process.poll()
            except Exception as e:
                logger.info('Dirbuster process kill result: %s' % str(e))

        if burp_process and burp_process.returncode is None:
            try:
                killpstree(burp_process.pid)
                burp_process.poll()
            except Exception as e:
                logger.info('Burp process kill result: %s' % str(e))

    try:
        with open(report_tpl + '.xml', 'rb') as fd:
            scan_report = fd.read()
    except (OSError, IOError):
        self.update_state(state="FAILURE", meta={'uid': config.get("scan_uid"), 'worker': worker_hostname})
        return '<report read error>', '', was_aborted, worker_hostname

    max_report = 150000000
    if len(scan_report) > max_report:
        # the report was read as bytes, so cut at the last complete tag and
        # re-close the root element using bytes literals
        scan_report = scan_report[:max_report]
        idx = scan_report.rfind(b'>')
        scan_report = scan_report[:idx+1]
        scan_report += b'</issues>'

    scan_log_url = ''
    scan_log_path = active_scanner_config['scan_log_path']
    try:
        if os.path.isfile(scan_log_path):
            with open(scan_log_path, 'rb') as log_fd:
                scan_log = log_fd.read()
            scan_log_url = upload_to_elliptics(hashlib.sha256(config.get("scan_uid").encode()).hexdigest(),
                                               scan_log,
                                               public=True,
                                               ttl=60)
    except Exception as e:
        # best effort: a missing log URL must not fail the whole scan
        logger.warning('Scan log upload error: %s' % str(e))
    finally:
        # we remove logs just after upload b/c they're too big to store on agents
        try:
            os.unlink(scan_log_path)
        except OSError:
            pass

    scan_report_url = ''
    try:
        compressed_scan_report = zlib.compress(scan_report)
        scan_report_url = upload_to_elliptics(hashlib.sha256(compressed_scan_report).hexdigest(),
                                              compressed_scan_report,
                                              ttl=7)
    except Exception:
        pass

    self.update_state(state="SUCCESS", meta={'uid': config.get("scan_uid"), 'worker': worker_hostname})

    return scan_report_url, scan_log_url, was_aborted, worker_hostname


# Register an instance of the class-based RemoteBurpTask in the Celery app's
# task registry (function tasks in this module are registered via @app.task).
app.tasks.register(RemoteBurpTask())


# Module-level metrics client, created at import time; run_scan uses it to
# report scan starts via try_send_metric.
metrics_client = Metrics(settings.METRICS_URL)
metrics_client.maybe_init_session()


@app.task(bind=True, rate_limit="16/m")
def run_scan(self, scan_id):
    """Prepare a scan profile and dispatch the scan to a scanner agent queue.

    Loads the Scan, marks it in progress, generates the Burp profile on disk,
    resolves the optional request samples and schedules the agent task for
    the scan's scanner type with parse_scan_report linked as a follow-up.

    :param scan_id: primary key of the Scan to run.
    :return: True when an agent task was scheduled; False when the scan is
        missing, already aborted/failed, misconfigured, skipped, or has a
        scanner type this task cannot dispatch.
    """
    def prepare_work_dir():
        # reports are grouped into per-day directories under the spool path
        start = datetime.datetime.now()
        work_dir = os.path.join(settings.MOLLY_SPOOL_PATH, str(start.year), str(start.month), str(start.day))
        try:
            if not os.path.isdir(work_dir):
                os.makedirs(work_dir)
                os.chmod(work_dir, 0o777)
        except Exception as e:
            logger.info('Report path already exist %s' % e)
        return work_dir

    worker_hostname = socket.gethostname()
    try:
        scan = Scan.objects.get(pk=scan_id)
    except Exception:
        logger.error('Scan with id {} not found'.format(scan_id))
        return False

    if scan.status in [Scan.ST_ABORTED, Scan.ST_FAIL]:
        logger.info('Scan status: {}, returning False'.format(scan.status))
        return False

    scan.status = Scan.ST_INPROGRESS
    scan.task_id = self.request.id
    self.update_state(state="INQUEUE", meta={'uid': self.request.id, 'worker': worker_hostname})
    scan.save()

    report_path = prepare_work_dir()
    profile_file = scan.gen_burp_scan_profile(report_path)
    if not profile_file:
        fail_scan(scan, "config error (profile file error)")
        return False

    scan.pidfile = os.path.splitext(profile_file)[0] + '.pid'
    scan.save()

    #  XXX: stop saving temporary files, with Burp it's usable for debug only
    with open(profile_file, 'r') as fd:
        profile = fd.read()

    if not profile:
        fail_scan(scan, "config error (check base json config)")
        return False

    profile = profile.replace('{scan_uid}', scan.uid)
    try:
        config = json.loads(profile)
        active_scanner_config = config.get("burp-active-scanner", {})
    except Exception:
        fail_scan(scan, "config error (profile load error)")
        return False

    task_queue = 'molly_agents'
    if scan.scan_task.is_prod:
        task_queue = 'crasher_agents'

    burp2_queue = 'burp2'
    if has_ipv6(active_scanner_config.get('initial_url', '')):
        burp2_queue = 'burp2chrome'

    # special condition for s3.yandex.net aliases (SECTASK-16183)
    # we scan only main domain, not aliases
    if is_s3(active_scanner_config.get('initial_url', '')) and scan.target.id != 9102:
        fail_scan(scan, "Skipping scan for S3 aliases")
        return False

    result_parser_task = parse_scan_report.subtask((scan_id,), queue='scan_result_parser')
    sample_requests_url = ''
    sample_requests_format = RequestSamples.FMT_JSON
    if scan.sample_requests:
        sample_requests_format = scan.sample_requests.format
        if scan.sample_requests.url.startswith('sandbox-resource:'):
            resource_info = resolve_sandbox_resource(scan.sample_requests.url.split(':')[-1])
            sample_requests_url = resource_info.get('skynet_id', '')
        else:
            sample_requests_url = scan.sample_requests.url

    if scan.target.notify_on_scan:
        logger.info('Trying to send notifications, scan {}'.format(scan.uid))
        send_scan_start_notification.apply_async((scan,))

    logger.info('Starting scanner, type: ' + str(scan.scanner_type))

    if scan.scanner_type == Scan.SCANNER_BURP2:
        metrics_client.try_send_metric("start_scan_burp2")
        agent_task = RemoteBurpTask().apply_async((profile, report_path, profile_file, sample_requests_url,
                                                   sample_requests_format, scan.ignore_time_limit),
                                                  queue=burp2_queue, link=result_parser_task)
    elif scan.scanner_type == Scan.SCANNER_BURP:
        metrics_client.try_send_metric("start_scan_burp")
        agent_task = remote_burp_call.apply_async((profile, report_path, profile_file, sample_requests_url,
                                                   sample_requests_format, scan.ignore_time_limit),
                                                  queue=task_queue, link=result_parser_task)
    else:
        # without this branch agent_task stays unbound and the task would crash
        # below, leaving the scan stuck in the in-progress state
        logger.error('Unsupported scanner type: ' + str(scan.scanner_type))
        fail_scan(scan, "config error (unsupported scanner type)")
        return False

    logger.info('Scanner task id ' + agent_task.id)

    scan.task_id = str(agent_task.id)
    scan.save()
    return True


@app.task(bind=True, ignore_result=True)
def stop_remote_scan(self, scan_id):
    """Revoke a scan's agent task and ask workers to kill its scanner process.

    :param scan_id: primary key of the Scan to stop.
    :return: False when the scan has no task id recorded, True otherwise.
    """
    scan_record = Scan.objects.get(pk=scan_id)
    if scan_record.task_id:
        # revoke the task, then broadcast the pid file so the worker that
        # owns the scanner process can kill it
        app.AsyncResult(scan_record.task_id).revoke()
        kill_arguments = {'pidfile': scan_record.pidfile}
        app.control.broadcast('kill_scanner', arguments=kill_arguments)
        return True
    return False


@Panel.register
def kill_scanner(state, **kwargs):
    """Remote control command: kill the process tree named by a pid file.

    Expects a ``pidfile`` keyword argument with the path of a file holding
    the scanner pid.

    :return: True when the pid was read and killpstree returned without
        raising, False on any error.
    """
    pidfile_path = kwargs.get('pidfile')
    try:
        with open(pidfile_path) as pid_fd:
            scanner_pid = int(pid_fd.read())
            killpstree(scanner_pid)
    except Exception:
        return False
    else:
        return True


@task_failure.connect
def task_failure_error(einfo=None, **kwargs):
    """Send the traceback of a failed Celery task to the error tracer.

    :param einfo: exception info object supplied by the task_failure signal;
        may be None or carry an empty traceback, in which case a placeholder
        trace is sent instead.
    """
    tracer = ErrorTracer(settings.ERRORTRACER_URL)
    tracer.init_session()
    # einfo defaults to None, so read the traceback defensively instead of
    # raising AttributeError inside the signal handler
    trace_text = getattr(einfo, 'traceback', None)
    tracer.try_send_trace(trace_text or ' File "./tasks.py", line 1337, in empty_einfo.traceback')
