package burp;

import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.net.MalformedURLException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import com.google.gson.Gson;
import com.google.gson.JsonParseException;
import com.jayway.awaitility.Awaitility;
import com.jayway.awaitility.core.ConditionTimeoutException;
import com.yandex.burp.extensions.EntryPointDeduplicator;
import com.yandex.burp.extensions.MollyProxyListener;
import com.yandex.burp.extensions.MollyRequestResponseHandler;
import com.yandex.burp.extensions.MollyStatsListener;
import com.yandex.burp.extensions.ScanLogger;
import com.yandex.burp.extensions.StatsCounter;
import com.yandex.burp.extensions.auth.IMollyAuthAdapter;
import com.yandex.burp.extensions.auth.YandexAuthAdapter;
import com.yandex.burp.extensions.config.BurpMollyScannerConfig;
import com.yandex.burp.extensions.config.MollyAuthConfig;
import com.yandex.burp.extensions.config.MollyConfig;


/**
 * MollyBurp: a Burp Suite extension that runs one unattended scan and then exits Burp.
 *
 * <p>{@link #registerExtenderCallbacks} does all of the work synchronously. It loads the Molly
 * JSON config and optionally authenticates. It registers the Molly HTTP, proxy and stats
 * listeners and sends the configured initial URL to the Spider. It then waits for proxy traffic
 * and for the tracked active scans to finish, and actively scans postponed entry points if
 * enough time is left. Finally it writes an XML report to the configured path and calls
 * {@code exitSuite(false)}.
 *
 * <p>Issues reported through {@link #newScanIssue} are filtered before being stored for the
 * report. The filters are: ignored issue ids, foreign hosts, the per-type limit and the
 * per-issue-type whitelists.
 */
public class BurpExtender implements IBurpExtender,
        IScannerListener,
        IExtensionStateListener {
    /** Poll period, in seconds, used while waiting for scans. */
    private static final int TIME_STEP = 15;
    /** Budget, in seconds, that waitForScanners uses when asked for exactly 0 (12 hours). */
    private static final int DEFAULT_SCAN_BUDGET = 3600 * 12;
    /** How often, in seconds, waitForProxyStart re-checks the last proxy request timestamp. */
    private static final int PROXY_POLL_SECONDS = 30;
    /* TODO: move scanners count and max scanners count to config */
    private static final int MAX_PENDING_SCANNERS = 10;
    private static final int MAX_ACTIVE_SCANNERS = 20;
    /** Config file used when neither --molly-config nor MOLLY_CONFIG is set. */
    private static final String DEFAULT_CONFIG_PATH = "burp_molly_config.json";

    /* Burp issue type ids that get special whitelist handling in newScanIssue. */
    private static final int ISSUE_CROSS_DOMAIN_SCRIPT = 0x00500500; // 5244160
    private static final int ISSUE_CORS = 0x00200600;
    private static final int ISSUE_CROSSDOMAIN_XML = 0x00200400;     // 2098176

    private IBurpExtenderCallbacks callbacks;
    private IExtensionHelpers helpers;
    /*
     * Accepted issues. Guarded by synchronized (issues): scanner callbacks append to it and the
     * report writer snapshots it, and these may run on different threads.
     */
    private ArrayList<IScanIssue> issues;
    /* Number of accepted issues per issue name; updated only while holding the issues lock. */
    private ConcurrentHashMap<String, Integer> issueStat;
    public List<IScanQueueItem> scanners;
    public EntryPointDeduplicator deduper;
    /* Synchronized list; iterate only over a snapshot taken under its lock. */
    private List<IHttpRequestResponse> postponedEntryPoints;
    /* Seconds spent polling scanners (TIME_STEP per poll). */
    private int scanTime;
    /* Number of scans observed as "finished" and removed from scanners. */
    private int totalScanners;
    /* Set once any waitForScanners call hits its timeout. */
    private boolean timeout;
    private PrintWriter stdout;
    private BurpMollyScannerConfig extConfig;
    private IMollyAuthAdapter authenticator;
    private ScanLogger scanLogger;
    private StatsCounter statsCounter;

    /** Returns the {@code burp_active_scanner} section of the loaded Molly config. */
    public BurpMollyScannerConfig getExtConfig() {
        return extConfig;
    }

    /** Returns the configured auth adapter, or {@code null} if none was configured. */
    public IMollyAuthAdapter getAuthenticator() {
        return authenticator;
    }

    /** Queues a request to be actively scanned after the main scan phase, if time allows. */
    public void postponeEntryPoint(IHttpRequestResponse messageInfo) {
        synchronized (postponedEntryPoints) {
            postponedEntryPoints.add(messageInfo);
        }
    }

    /** Registers an active scan so the waits below track its completion. */
    public void trackScanner(IScanQueueItem scan) {
        synchronized (scanners) {
            scanners.add(scan);
        }
    }

    //
    // implement IBurpExtender
    //

    /**
     * Runs the whole scan and exits Burp. This call blocks until the scan budget is spent or all
     * tracked scans finish.
     *
     * <p>On any configuration error it prints a message and calls {@code exitSuite(false)}
     * without scanning.
     */
    @Override
    public void registerExtenderCallbacks(IBurpExtenderCallbacks callbacks) {
        /* keep a reference to our callbacks object */
        this.callbacks = callbacks;
        this.helpers = callbacks.getHelpers();
        this.timeout = false;
        this.totalScanners = 0;
        this.scanTime = 0;

        /* obtain our output stream */
        this.stdout = new PrintWriter(callbacks.getStdout(), true);

        /* set our extension name */
        callbacks.setExtensionName("MollyBurp");

        stdout.println("Extension was loaded");

        MollyConfig mollyConfig = loadConfig(resolveConfigPath());
        if (mollyConfig == null) {
            abort("Error loading extension config");
            return;
        }

        this.extConfig = mollyConfig.getBurpActiveScanner();
        if (extConfig == null) {
            abort("Error loading extension config");
            return;
        }

        this.issues = new ArrayList<>();
        this.issueStat = new ConcurrentHashMap<>();
        this.scanners = Collections.synchronizedList(new ArrayList<IScanQueueItem>());
        /* Synchronized: listeners append to it while the postponed-scan phase reads it. */
        this.postponedEntryPoints = Collections.synchronizedList(new ArrayList<IHttpRequestResponse>());
        this.deduper = new EntryPointDeduplicator(callbacks);

        if (extConfig.getReportPath() == null) {
            abort("No report path configured");
            return;
        }

        if (extConfig.getEntryPoint() == null) {
            abort("No initial_url configured");
            return;
        }

        if (!setUpAuthenticator(extConfig.getAuthConfig())) {
            abort("Auth config error. Invalid username/password?");
            return;
        }

        try {
            extConfig.setInitialURL(new URL(extConfig.getEntryPoint()));
        } catch (MalformedURLException e) {
            abort("Invalid initial URL " + extConfig.getEntryPoint());
            return;
        }

        /* TODO: refactor and move this to separate logic */
        this.scanLogger = new ScanLogger(callbacks, extConfig);

        registerListeners();

        if (!ensureInScope()) {
            abort("Invalid initial URL " + extConfig.getEntryPoint());
            return;
        }

        callbacks.sendToSpider(extConfig.getInitialURL());

        int maxTime = extConfig.getScanTimeout() * 60;
        int maxProxyWaitTime = extConfig.getProxyTimeout() * 60;

        /* we wait for external requests via Proxy in case if no entry point found */
        waitForProxyStart(maxProxyWaitTime);

        /* main work happens meanwhile */
        waitForScanners(maxTime - maxProxyWaitTime);

        /* Budget left for postponed entry points; the same quantity bounds the final wait. */
        int remaining = maxTime - maxProxyWaitTime - scanTime;

        /* run scan for postponed entry points */
        if (!timeout && remaining > TIME_STEP * 2 && totalScanners < MAX_PENDING_SCANNERS) {
            scanPostponedEntryPoints();
        }

        /*
         * Wait for postponed entry points to be scanned. Never pass exactly 0: waitForScanners
         * treats 0 as "use the 12 hour default", while an exhausted budget should mean a short
         * final wait (the same as for a negative budget).
         */
        waitForScanners(Math.max(1, remaining));

        if (authenticator != null) {
            authenticator.doLogout(null);
        }

        if (extConfig.getEnableStats()) {
            synchronized (issues) {
                issues.add(statsCounter.generateIssue());
            }
        }

        writeReport();

        scanLogger.finish();
        callbacks.exitSuite(false);
    }

    /** Prints {@code message} and asks Burp to exit. */
    private void abort(String message) {
        stdout.println(message);
        callbacks.exitSuite(false);
    }

    /**
     * Picks the config path. The first {@code --molly-config=PATH} command-line argument takes
     * priority over the {@code MOLLY_CONFIG} environment variable. If both are missing or empty,
     * the path falls back to {@value #DEFAULT_CONFIG_PATH}.
     */
    private String resolveConfigPath() {
        String configPath = System.getenv("MOLLY_CONFIG");

        for (String arg : callbacks.getCommandLineArguments()) {
            String[] kv = arg.split("=", 2);
            if (kv.length == 2 && kv[0].equals("--molly-config")) {
                configPath = kv[1];
                break;
            }
        }

        if (configPath == null || configPath.isEmpty()) {
            /* fallback to default config */
            configPath = DEFAULT_CONFIG_PATH;
        }
        return configPath;
    }

    /**
     * Reads and parses the UTF-8 JSON config at {@code configPath}.
     *
     * @return the parsed config, or {@code null} if the file cannot be read, is not valid JSON,
     *     or is empty (Gson returns {@code null} for empty input)
     */
    private MollyConfig loadConfig(String configPath) {
        try {
            String configJSON = new String(Files.readAllBytes(Paths.get(configPath)), StandardCharsets.UTF_8);
            return new Gson().fromJson(configJSON, MollyConfig.class);
        } catch (IOException | JsonParseException | java.nio.file.InvalidPathException e) {
            stdout.println("Cannot load config " + configPath + ": " + e);
            return null;
        }
    }

    /**
     * Creates and, if needed, logs in the auth adapter selected by the config.
     *
     * <p>A missing auth config or provider, or an unknown provider, means no authentication.
     *
     * @return {@code false} only if an expected login failed
     */
    private boolean setUpAuthenticator(MollyAuthConfig authConfig) {
        if (authConfig == null || authConfig.getAuthProvider() == null) {
            return true;
        }
        /* TODO: use reflections? */
        switch (authConfig.getAuthProvider().toUpperCase(Locale.ROOT)) {
            case "YANDEX":
                authenticator = new YandexAuthAdapter(callbacks, authConfig);
                return !authenticator.isAuthExpected() || authenticator.doAuth(null);
            default:
                return true;
        }
    }

    /** Registers the Molly HTTP, proxy and optional stats listeners, and this object's listeners. */
    private void registerListeners() {
        /* register custom class as an HTTP listener */
        callbacks.registerHttpListener(new MollyRequestResponseHandler(callbacks,
                extConfig, authenticator, scanners, deduper, postponedEntryPoints, scanLogger));
        /* register custom class as an Proxy listener */
        callbacks.registerProxyListener(new MollyProxyListener(callbacks, extConfig, authenticator, deduper));

        if (extConfig.getEnableStats()) {
            this.statsCounter = new StatsCounter();
            callbacks.registerHttpListener(new MollyStatsListener(callbacks, statsCounter));
        }

        /* register ourselves as a Scanner listener */
        callbacks.registerScannerListener(this);

        /* register ourselves as an extension state listener */
        callbacks.registerExtensionStateListener(this);
    }

    /**
     * Makes sure the scan target is in Burp's scope.
     *
     * <p>With the complex-domain flag set, this is the initial URL itself. Otherwise it is the
     * site root, because many checks run there.
     *
     * @return {@code false} if the root URL cannot be built
     */
    private boolean ensureInScope() {
        URL initialUrl = extConfig.getInitialURL();
        if (extConfig.getComplexDomainFlag()) {
            if (!callbacks.isInScope(initialUrl)) {
                callbacks.includeInScope(initialUrl);
            }
            return true;
        }
        try {
            URL rootUrl = new URL(initialUrl.getProtocol(), initialUrl.getHost(), effectivePort(initialUrl), "/");
            if (!callbacks.isInScope(rootUrl)) {
                callbacks.includeInScope(rootUrl);
            }
            return true;
        } catch (MalformedURLException e) {
            return false;
        }
    }

    /** Returns the URL's explicit port, or the protocol's default port if none is given. */
    private static int effectivePort(URL url) {
        return url.getPort() == -1 ? url.getDefaultPort() : url.getPort();
    }

    /**
     * Starts active scans against the initial host for postponed entry points.
     *
     * <p>Duplicates, as judged by the deduplicator, are skipped. No new scans are started once
     * more than {@value #MAX_ACTIVE_SCANNERS} are tracked.
     */
    private void scanPostponedEntryPoints() {
        List<IHttpRequestResponse> pending;
        synchronized (postponedEntryPoints) {
            /* snapshot: listeners may keep appending while we scan */
            pending = new ArrayList<>(postponedEntryPoints);
        }

        URL initialUrl = extConfig.getInitialURL();
        boolean useHttps = initialUrl.getProtocol().equals("https");

        for (IHttpRequestResponse messageInfo : pending) {
            if (scanners.size() > MAX_ACTIVE_SCANNERS) {
                break;
            }
            /* Full-dub detection */
            if (deduper.isFullDuplicate(messageInfo)) continue;

            /* Half-dub detection */
            if (deduper.isHalfDuplicate(messageInfo)) continue;

            /* Do not scan URLs with same parameters twice (?) */
            if (deduper.isDuplicateURL(messageInfo)) continue;

            IScanQueueItem scan = callbacks.doActiveScan(
                    initialUrl.getHost(),
                    effectivePort(initialUrl),
                    useHttps,
                    messageInfo.getRequest());
            trackScanner(scan);
        }
    }

    /** Writes an XML report with a snapshot of the accepted issues to the configured path. */
    private void writeReport() {
        if (issues == null) {
            return;
        }
        IScanIssue[] snapshot;
        synchronized (issues) {
            snapshot = issues.toArray(new IScanIssue[0]);
        }
        callbacks.generateScanReport("XML", snapshot, new File(extConfig.getReportPath()));
    }

    //
    // implement IScannerListener
    //

    /**
     * Stores a reported issue for the final report unless a filter rejects it.
     *
     * <p>The filters are: the issue id is in the ignore list; its host does not contain the
     * initial URL's host; the per-type limit ({@code getMaxIssuesByType() > 0}) is already
     * reached; or a type-specific whitelist applies.
     */
    @Override
    public void newScanIssue(IScanIssue issue) {
        if (issue == null) return;

        List<Integer> ignoreIssueIds = extConfig.getIgnoredIssueIds();
        if (ignoreIssueIds != null && ignoreIssueIds.contains(issue.getIssueType())) return;

        /* XXX: test if it works (substring match, not an exact host comparison) */
        if (!issue.getHttpService().getHost().contains(extConfig.getInitialURL().getHost())) {
            return;
        }

        /* Do not store more than X issues of one type - prevent huuuge reports */
        if (isIssueTypeLimitReached(issue.getIssueName())) return;

        if (isWhitelisted(issue)) return;

        synchronized (issues) {
            /* re-check under the lock so concurrent callbacks cannot exceed the limit */
            if (isIssueTypeLimitReached(issue.getIssueName())) return;
            issueStat.merge(issue.getIssueName(), 1, Integer::sum);
            issues.add(issue);
        }
    }

    /** Returns true when a positive per-type limit is configured and already reached for this name. */
    private boolean isIssueTypeLimitReached(String issueName) {
        int limit = extConfig.getMaxIssuesByType();
        return limit > 0 && issueStat.getOrDefault(issueName, 0) >= limit;
    }

    /** Applies the per-issue-type whitelists from the config; returns true to drop the issue. */
    private boolean isWhitelisted(IScanIssue issue) {
        switch (issue.getIssueType()) {
            case ISSUE_CROSS_DOMAIN_SCRIPT:
                return isCrossDomainScriptWhitelisted(issue);
            case ISSUE_CORS:
                return isCorsIssueWhitelisted(issue);
            case ISSUE_CROSSDOMAIN_XML:
                return extConfig.getCrossdomainXmlWhitelist() != null
                        && extConfig.getCrossdomainXmlWhitelist().contains(issue.getHttpService().getHost());
            case 0x00500300:
                /* TODO: get cookie name and whitelist */
                return false;
            default:
                return false;
        }
    }

    /**
     * Returns true if the issue detail contains {@code "https://" + d} for any whitelisted
     * entry {@code d}.
     */
    private boolean isCrossDomainScriptWhitelisted(IScanIssue issue) {
        List<String> whitelist = extConfig.getCrossdomainJsWhitelist();
        String detail = issue.getIssueDetail();
        if (whitelist == null || detail == null) {
            return false;
        }
        for (String domain : whitelist) {
            /*
             * NOTE(review): the entry is compiled as a regex, so "." matches any character.
             * TODO: grep URIs then parse and match only domains
             */
            if (Pattern.compile("https://" + domain).matcher(detail).find()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns true (drop the issue) when a CORS issue has no messages, or when every message
     * targets a host in the public CORS whitelist.
     */
    private boolean isCorsIssueWhitelisted(IScanIssue issue) {
        IHttpRequestResponse[] messages = issue.getHttpMessages();
        if (messages == null) {
            return true;
        }
        for (IHttpRequestResponse message : messages) {
            String host = message.getHttpService().getHost();
            if (extConfig.getPublicCorsWhitelist() == null
                    || !extConfig.getPublicCorsWhitelist().contains(host)) {
                return false;
            }
        }
        return true;
    }

    //
    // implement IExtensionStateListener
    //
    @Override
    public void extensionUnloaded() {
        stdout.println("Extension was unloaded");
    }

    /**
     * Blocks until {@code maxTime} seconds have passed since the last proxy request timestamp.
     *
     * <p>The timestamp is initialised to "now" if no request was recorded yet. It is re-checked
     * every {@value #PROXY_POLL_SECONDS} seconds, so new proxy traffic extends the wait.
     */
    private void waitForProxyStart(int maxTime) {
        if (extConfig.getLastProxyRequestTimestamp() == 0) {
            /* if no requests were sent we wait maxTime starting from current timestamp */
            extConfig.setLastProxyRequestTimestamp(java.time.Instant.now().getEpochSecond());
        }

        while (java.time.Instant.now().getEpochSecond() - extConfig.getLastProxyRequestTimestamp() < maxTime) {
            try {
                Thread.sleep(TimeUnit.SECONDS.toMillis(PROXY_POLL_SECONDS));
            } catch (InterruptedException ignored) {
                /*
                 * Best effort: stop waiting and continue with the scan phase. The interrupt flag
                 * is deliberately not restored, so the following Awaitility waits still run and
                 * the report is still written.
                 */
                return;
            }
        }
    }

    /**
     * Polls every {@value #TIME_STEP} seconds until all tracked scans are finished, or until
     * {@code maxTime} seconds pass.
     *
     * <p>A budget of exactly 0 means {@value #DEFAULT_SCAN_BUDGET} seconds. A budget of
     * {@value #TIME_STEP} or less, including a negative one, is raised to two poll periods.
     * Each unsuccessful poll also rewrites the report, so partial results are on disk. On
     * timeout, {@code timeout} is set.
     */
    private void waitForScanners(int maxTime) {
        int budget = maxTime;
        /* set max scan time to 12 hours */
        if (budget == 0) {
            budget = DEFAULT_SCAN_BUDGET;
        }
        if (budget <= TIME_STEP) {
            budget = TIME_STEP * 2;
        }
        try {
            Awaitility.with().timeout(budget, TimeUnit.SECONDS)
                    .and().with().pollDelay(TIME_STEP, TimeUnit.SECONDS)
                    .and().with().pollInterval(TIME_STEP, TimeUnit.SECONDS)
                    .await()
                    .until(new Callable<Boolean>() {
                        @Override
                        public Boolean call() {
                            scanTime += TIME_STEP;
                            if (reapFinishedScanners()) {
                                return true;
                            }
                            stdout.println("Scanners: " + scanners.size());
                            /* report written outside the scanners lock to avoid blocking trackScanner */
                            writeReport();
                            return false;
                        }
                    });
        } catch (ConditionTimeoutException e) {
            /* exiting anyway */
            stdout.println("timeout!");
            timeout = true;
        }
    }

    /**
     * Removes finished scans from the front of {@code scanners}, stopping at the first unfinished
     * one, and counts each removal in {@code totalScanners}.
     *
     * @return true if no unfinished scan remains
     */
    private boolean reapFinishedScanners() {
        synchronized (scanners) {
            Iterator<IScanQueueItem> it = scanners.iterator();
            while (it.hasNext()) {
                if (!"finished".equals(it.next().getStatus())) {
                    return false;
                }
                it.remove();
                totalScanners += 1;
            }
            return true;
        }
    }
}
