package com.yandex.burp.extensions.plugins.audit;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import burp.IBurpExtenderCallbacks;
import burp.IExtensionHelpers;
import burp.IHttpRequestResponse;
import burp.IRequestInfo;
import burp.IResponseInfo;
import burp.IScanIssue;
import burp.IScannerInsertionPoint;
import com.google.gson.Gson;
import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import com.google.gson.JsonPrimitive;
import com.yandex.burp.extensions.plugins.CustomScanIssue;
import com.yandex.burp.extensions.plugins.Utils;
import com.yandex.burp.extensions.plugins.config.BurpMollyPackConfig;

/**
 * Audit check for server-side prototype pollution through JSON and URL-encoded request bodies.
 *
 * <p>The base request body is parsed into a Gson tree (URL-encoded bodies are first rebuilt into
 * nested objects from {@code a[b][c]=v} style names). Every object and array in the tree is seeded
 * with a {@code "__proto__"} entry carrying a random name/value pair, and the result is re-sent as
 * {@code application/json}. Two signals are reported:
 * <ul>
 *   <li>the random {@code name=value} pair appears in the attack response (High / Firm);</li>
 *   <li>the attack response status is 500, 502, 503 or 504 (Information / Tentative).</li>
 * </ul>
 *
 * <p>Created by ilyaon on 18.07.2019
 */
public class PrototypePollutionPlugin implements IAuditPlugin {

    private final IBurpExtenderCallbacks callbacks;
    private final IExtensionHelpers helpers;
    private final Gson gson;

    private static final int ISSUE_TYPE = 0x00101080;
    private static final String ISSUE_NAME = "Prototype pollution attack";
    private static final String SEVERITY = "High";
    private static final String CONFIDENCE = "Firm";

    // Used by the status-code heuristic. A 5xx after re-sending the body as JSON is weak evidence
    // (the server may simply reject the rewritten body), so report it at Burp's lowest confidence.
    private static final String CONFIDENCE_LOW = "Tentative";
    private static final String SEVERITY_LOW = "Information";

    private static final String ATTACK_DETAILS = "Prototype pollution attack " +
            "(https://github.com/HoLyVieR/prototype-pollution-nsec18/" +
            "blob/master/paper/JavaScript_prototype_pollution_attack_in_NodeJS.pdf)";

    private static final List<Byte> contentTypes = Arrays.asList(
            IRequestInfo.CONTENT_TYPE_URL_ENCODED,
            // IRequestInfo.CONTENT_TYPE_MULTIPART, TODO: enable, sometimes used for json
            IRequestInfo.CONTENT_TYPE_JSON
    );

    private static final String JSON_CONTENT_TYPE = "application/json; charset=utf-8";
    private static final String CONTENT_TYPE = "Content-Type";

    private static final List<Integer> errorCodes = Arrays.asList(
            500,
            502,
            503,
            504
    );

    private static final List<Byte> insertionPoints = Arrays.asList(
            IScannerInsertionPoint.INS_PARAM_NAME_BODY, // for urlencoded
            IScannerInsertionPoint.INS_ENTIRE_BODY // for json
    );

    /**
     * @param callbacks Burp callbacks, used for helpers, sending requests and markers
     * @param extConfig extension configuration (currently unused by this plugin)
     */
    public PrototypePollutionPlugin(IBurpExtenderCallbacks callbacks, BurpMollyPackConfig extConfig) {
        this.callbacks = callbacks;
        this.helpers = callbacks.getHelpers();
        this.gson = new Gson();
    }

    /**
     * Sends one polluted copy of the base request and reports the issues it reveals.
     *
     * @param baseRequestResponse the original request/response; only scanned when its response
     *                            status is 200 and its body is URL-encoded or JSON
     * @param insertionPoint      used only for its type: body-parameter-name (URL-encoded) and
     *                            entire-body (JSON) insertion points are handled
     * @return the issues found (possibly empty), or {@code null} when the request is not
     *         applicable, its body cannot be parsed, or the attack received no response
     */
    public List<IScanIssue> doScan(IHttpRequestResponse baseRequestResponse, IScannerInsertionPoint insertionPoint) {

        if (!insertionPoints.contains(insertionPoint.getInsertionPointType())) {
            return null;
        }

        byte[] baseRequest = baseRequestResponse.getRequest();
        byte[] baseResponse = baseRequestResponse.getResponse();

        // An item may lack a response; guard before handing it to analyzeResponse.
        if (baseRequest == null || baseResponse == null) {
            return null;
        }

        IRequestInfo request = helpers.analyzeRequest(baseRequest);
        IResponseInfo response = helpers.analyzeResponse(baseResponse);

        if (request == null || response == null || response.getStatusCode() != 200) {
            return null;
        }

        byte contentType = request.getContentType();

        if (!contentTypes.contains(contentType)) {
            return null;
        }

        byte[] requestBody = Arrays.copyOfRange(baseRequest, request.getBodyOffset(), baseRequest.length);

        JsonElement json;

        try {
            json = toJson(contentType, helpers.bytesToString(requestBody));
        } catch (RuntimeException exception) {
            // Malformed JSON or an unsupported bracket syntax in a parameter name: skip this request.
            callbacks.printError("PrototypePollutionPlugin: cannot parse request body: " + exception);
            return null;
        }

        RandomPayload randomPayload = new RandomPayload();

        if (json.isJsonObject()) {
            poisonJson(json.getAsJsonObject(), randomPayload, true);
        } else if (json.isJsonArray()) {
            poisonJson(json.getAsJsonArray(), randomPayload, true);
        } else {
            // A bare primitive or null body has nowhere to hold a __proto__ key.
            return null;
        }

        String payload = gson.toJson(json);

        List<String> headers = request.getHeaders();
        // The rebuilt body is always JSON, even when the original was URL-encoded.
        Utils.setHeaderValue(headers, CONTENT_TYPE, JSON_CONTENT_TYPE);

        byte[] attackRequest = helpers.buildHttpMessage(headers, helpers.stringToBytes(payload));
        IHttpRequestResponse attack = callbacks.makeHttpRequest(baseRequestResponse.getHttpService(),
                attackRequest);

        if (attack.getResponse() == null) {
            return null;
        }

        List<IScanIssue> issues = new ArrayList<>();

        // detect by marker: the random name=value pair reflected in the response
        String attackResponseText = helpers.bytesToString(attack.getResponse());
        Matcher matcher = randomPayload.getMarker().matcher(attackResponseText);

        if (matcher.find()) {
            List<int[]> responseMarkers = Collections.singletonList(new int[]{
                    matcher.start(),
                    matcher.end()
            });
            issues.add(buildIssue(attack, responseMarkers, SEVERITY, CONFIDENCE));
        }

        // detect by error code
        int statusCode = helpers.analyzeResponse(attack.getResponse()).getStatusCode();

        if (errorCodes.contains(statusCode)) {
            issues.add(buildIssue(attack, null, SEVERITY_LOW, CONFIDENCE_LOW));
        }

        return issues;
    }

    /**
     * Wraps an attack request/response into a scan issue. No request markers are applied.
     *
     * @param attack          the attack request/response to report
     * @param responseMarkers response highlight ranges ({@code int[2]} start/end pairs), or
     *                        {@code null} for none, as accepted by {@code applyMarkers}
     * @param severity        Burp severity string
     * @param confidence      Burp confidence string
     * @return the issue
     */
    private IScanIssue buildIssue(IHttpRequestResponse attack, List<int[]> responseMarkers,
                                  String severity, String confidence) {
        return new CustomScanIssue(attack.getHttpService(),
                helpers.analyzeRequest(attack).getUrl(),
                new IHttpRequestResponse[]{callbacks.applyMarkers(attack, null, responseMarkers)},
                ATTACK_DETAILS, ISSUE_TYPE, ISSUE_NAME, severity, confidence,
                "", "", "");
    }

    /**
     * Parses a request body into a Gson tree.
     *
     * @param contentType Burp content-type constant of the request
     * @param body        the request body as a string
     * @return a {@link JsonObject} for URL-encoded bodies; for JSON bodies, whatever
     *         {@link JsonParser} returns (object, array, primitive or JSON null)
     * @throws RuntimeException if the content type is unsupported or the body cannot be parsed
     */
    private JsonElement toJson(byte contentType, String body) {
        if (contentType == IRequestInfo.CONTENT_TYPE_URL_ENCODED) {
            return convertUrlencodedToJson(body);
        } else if (contentType == IRequestInfo.CONTENT_TYPE_JSON) {
            return new JsonParser().parse(body);
        } else {
            throw new IllegalArgumentException("Unknown Content-Type");
        }
    }

    /** Rebuilds a URL-encoded body as a nested JSON object (see {@link #splitQuery}). */
    private JsonObject convertUrlencodedToJson(String body) {
        return splitQuery(body);
    }

    private static final String PROTO = "__proto__";

    /** Builds {@code {name: value}} for the payload. */
    private static JsonObject newPoison(RandomPayload payload) {
        JsonObject poison = new JsonObject();
        poison.addProperty(payload.name, payload.value);
        return poison;
    }

    /** Builds {@code {"__proto__": {name: value}}} for the payload. */
    private static JsonObject newProtoWrapper(RandomPayload payload) {
        JsonObject wrapper = new JsonObject();
        wrapper.add(PROTO, newPoison(payload));
        return wrapper;
    }

    /**
     * Recursively seeds an object with the pollution payload, mutating it in place.
     *
     * <p>Nested objects and arrays are descended into. Each primitive value is replaced by
     * {@code {"__proto__": {name: value}}}. JSON nulls are left as they are. The object itself
     * receives a {@code "__proto__": {name: value}} entry.
     *
     * @param object     object to mutate
     * @param payload    the random name/value pair to inject
     * @param skipTokens when true, string values recognised by {@link Utils#isYandexCSRF} are kept
     */
    private void poisonJson(JsonObject object, RandomPayload payload, boolean skipTokens) {

        // Iterate over a snapshot of the keys: values are replaced while walking.
        for (String key : getKeys(object)) {
            JsonElement element = object.get(key);

            if (element.isJsonObject()) {
                poisonJson(element.getAsJsonObject(), payload, skipTokens);
            } else if (element.isJsonArray()) {
                poisonJson(element.getAsJsonArray(), payload, skipTokens);
            } else if (element.isJsonPrimitive()) {
                JsonPrimitive primitive = element.getAsJsonPrimitive();

                if (skipTokens && primitive.isString() && Utils.isYandexCSRF(primitive.getAsString())) {
                    continue;
                }

                object.add(key, newProtoWrapper(payload));
            }
        }

        // Added once per object, after the keys are processed, so empty objects are seeded too.
        object.add(PROTO, newPoison(payload));
    }

    /** Returns the object's keys in iteration order, as an independent copy. */
    private Set<String> getKeys(JsonObject object) {
        Set<String> keys = new LinkedHashSet<>();
        for (Map.Entry<String, JsonElement> entry : object.entrySet()) {
            keys.add(entry.getKey());
        }
        return keys;
    }

    /**
     * Recursively seeds an array with the pollution payload, mutating it in place.
     *
     * <p>Nested objects and arrays are descended into. Primitive elements are left untouched, and
     * one {@code {"__proto__": {name: value}}} element is appended.
     *
     * @param array      array to mutate
     * @param payload    the random name/value pair to inject
     * @param skipTokens forwarded to nested objects
     */
    private void poisonJson(JsonArray array, RandomPayload payload, boolean skipTokens) {

        for (int i = 0; i < array.size(); i++) {
            JsonElement element = array.get(i);

            if (element.isJsonArray()) {
                poisonJson(element.getAsJsonArray(), payload, skipTokens);
            } else if (element.isJsonObject()) {
                poisonJson(element.getAsJsonObject(), payload, skipTokens);
            }
        }

        array.add(newProtoWrapper(payload));
    }

    /**
     * Rebuilds a URL-encoded body into a JSON object.
     *
     * <p>For example, {@code a=1&b[c]=2} becomes {@code {"a":"1","b":{"c":"2"}}}. All values are
     * URL-decoded strings. Pairs without "=" or with an empty name are ignored.
     *
     * @throws RuntimeException if a parameter name has malformed brackets (see {@link #getSegments})
     */
    private JsonObject splitQuery(String body) {
        JsonObject root = new JsonObject();

        for (String pair : body.split("&")) {
            int idx = pair.indexOf('=');

            if (idx <= 0) {
                continue;
            }

            List<String> segments = getSegments(pair.substring(0, idx));

            if (segments.isEmpty()) {
                continue;
            }

            String value = pair.length() > idx + 1
                    ? helpers.urlDecode(pair.substring(idx + 1))
                    : "";

            merge(segments, root, value);
        }

        return root;
    }

    /**
     * Inserts {@code value} into {@code root} at the path given by {@code segments}, creating
     * intermediate objects as needed.
     *
     * <p>For example, {@code ["a", "b"]} sets {@code root.a.b}. When a key already holds a
     * primitive, it is converted to an array that holds the old value and receives the new one.
     * When the path runs through an array, a fresh object is appended to that array.
     *
     * @return {@code root}, mutated in place
     */
    private static JsonObject merge(List<String> segments, JsonObject root, String value) {
        JsonElement current = root;
        // Parents of objects created below; only JsonObjects are ever added.
        List<JsonElement> history = new ArrayList<>();

        for (int i = 0; i < segments.size(); i++) {
            String segment = segments.get(i);

            if (current.isJsonObject()) {

                JsonObject object = (JsonObject) current;

                if (object.get(segment) != null) {
                    JsonElement next = object.get(segment);

                    if (next.isJsonPrimitive()) {
                        // repeated key: promote the existing primitive to an array
                        JsonArray array = new JsonArray();
                        object.remove(segment);
                        array.add(next);
                        object.add(segment, array);
                        current = array;
                    } else {
                        current = next;
                    }

                } else if (segment.equals("")) {
                    // NOTE(review): in practice getSegments() yields "" only as the first segment
                    // (a name starting with '['). history is still empty then, so this get(-1)
                    // throws IndexOutOfBoundsException and doScan skips the request.
                    JsonElement lastElement = history.get(history.size() - 1);
                    String prevSegment = segments.get(i - 1);

                    if (lastElement.isJsonObject()) {
                        JsonObject lastObject = (JsonObject) lastElement;
                        JsonElement element = lastObject.get(prevSegment);
                        lastObject.remove(prevSegment);
                        JsonArray array = new JsonArray();

                        if (!element.equals(new JsonObject()))
                            array.add(element);

                        lastObject.add(prevSegment, array);
                        current = array;

                    } else if (lastElement.isJsonArray()) {
                        // NOTE(review): unreachable (history holds only objects); lastArray is a
                        // throw-away array, so this branch looks unfinished.
                        JsonObject newObject = new JsonObject();
                        JsonArray lastArray = new JsonArray();

                        lastArray.add(newObject);
                        current = newObject;
                    }

                } else if (i + 1 < segments.size()) {
                    // intermediate segment: create the nested object
                    JsonObject newObject = new JsonObject();
                    object.add(segment, newObject);
                    history.add(current);
                    current = newObject;
                }

            } else if (current.isJsonArray()) {
                if (segment.equals("")) {
                    continue;
                }
                JsonArray array = (JsonArray) current;
                JsonObject newObject = new JsonObject();
                array.add(newObject);
                current = newObject;
            }
        }

        if (current.isJsonObject()) {
            JsonObject object = (JsonObject) current;
            object.addProperty(segments.get(segments.size() - 1), value);
        } else if (current.isJsonArray()) {
            JsonArray array = (JsonArray) current;
            array.add(new JsonPrimitive(value));
        }

        return root;
    }

    private static final char START_BRACKET = '[';
    private static final char END_BRACKET = ']';

    /**
     * Splits a bracketed parameter name into URL-decoded path segments.
     *
     * <p>For example, {@code a[b][c]} becomes {@code [a, b, c]}. Empty brackets produce no
     * segment, so {@code a[]} becomes {@code [a]}. A name without brackets yields a single segment.
     *
     * @throws RuntimeException ("SyntaxError") on a '[' inside brackets or an unmatched ']'
     */
    // NOTE(review): characters between ']' and the next '[' are not rejected. They are prepended
    // to the next bracketed segment (a[b]c[d] -> [a, b, cd]). An unterminated trailing bracket and
    // any text after the last ']' are silently dropped.
    private List<String> getSegments(String input) {
        int startNested = input.indexOf(START_BRACKET);

        if (startNested == -1) {
            return Collections.singletonList(helpers.urlDecode(input));
        }

        List<String> segments = new ArrayList<>();

        segments.add(helpers.urlDecode(input.substring(0, startNested)));

        StringBuilder buffer = new StringBuilder();

        boolean inBrackets = false;

        for (int i = startNested; i < input.length(); i++) {
            if (input.charAt(i) == START_BRACKET) {
                // (a) current char is [, should not be in brackets

                if (inBrackets)
                    throw new RuntimeException("SyntaxError");

                inBrackets = true;

            } else if (input.charAt(i) == END_BRACKET) {
                // (b) current char is ], should be in brackets

                if (!inBrackets)
                    throw new RuntimeException("SyntaxError");

                inBrackets = false;

                if (buffer.length() > 0) {
                    // flush buffer on end of sequence
                    segments.add(helpers.urlDecode(buffer.toString()));
                    buffer.delete(0, buffer.length());
                }

            } else {
                // (c) any other char is added to the buffer
                buffer.append(input.charAt(i));
            }
        }

        return segments;
    }

    /**
     * A random property name/value pair to inject under {@code __proto__}, plus the pattern that
     * detects its reflection.
     */
    private static final class RandomPayload {

        final String name;
        final String value;
        private final Pattern marker;

        RandomPayload() {
            this.name = randomToken();
            this.value = randomToken();
            // name and value are lowercase hex digits only, so they need no regex quoting.
            this.marker = Pattern.compile(String.format("%s=[\"']?%s[\"']?", name, value));
        }

        /** First 10 hex digits of a random UUID (dashes removed). */
        private static String randomToken() {
            return UUID.randomUUID().toString().replaceAll("-", "").substring(0, 10);
        }

        /** Matches {@code name=value}, with the value optionally wrapped in single or double quotes. */
        Pattern getMarker() {
            return marker;
        }
    }
}
