package com.yandex.burp.extensions.auth;

import java.net.MalformedURLException;
import java.net.URL;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import burp.IBurpExtenderCallbacks;
import burp.ICookie;
import burp.IExtensionHelpers;
import burp.IHttpRequestResponse;
import burp.IParameter;
import burp.IRequestInfo;
import burp.IResponseInfo;
import com.google.common.net.InternetDomainName;
import com.yandex.burp.extensions.config.MollyAuthConfig;

/* XXX: use logging instead of
 * import java.io.PrintWriter;
 */

/**
 * {@link IMollyAuthAdapter} for Yandex services.
 *
 * <p>The schema comes from {@link MollyAuthConfig#getAuthSchema()} and is case-insensitive:
 * <ul>
 *   <li>{@code PASSPORT}: logs in to the configured auth host with the configured username and
 *       password, keeps the {@code Session_id} and {@code sessionid2} cookies it receives, and
 *       adds them to requests for hosts in the same registrable domain as the auth host;</li>
 *   <li>{@code OAUTH}: replaces every {@code Authorization} header of a request with
 *       {@code "Authorization: " + authPassword}.</li>
 * </ul>
 * Any other (or missing) schema means "no authentication".
 *
 * <p>The adapter keeps mutable session state and performs no synchronization.
 *
 * <p>Created by ezaitov on 03.02.2017.
 */
public class YandexAuthAdapter implements IMollyAuthAdapter {
    /** Passport logins stop being attempted once more than this many have failed in a row. */
    private static final int MAX_AUTH_TRIES = 2;
    private static final int HTTPS_PORT = 443;
    private static final String SESSION_ID_COOKIE = "Session_id";
    private static final String SESSION_ID2_COOKIE = "sessionid2";
    private static final String AUTHORIZATION_PREFIX = "AUTHORIZATION:";
    private static final Pattern CSRF_PATTERN =
            Pattern.compile("<input type=\"hidden\" name=\"csrf_token\" value=\"(.*?)\"");

    private final MollyAuthConfig authConfig;
    private final IBurpExtenderCallbacks callbacks;
    private final IExtensionHelpers helpers;
    /** Passport cookies from the last successful login; empty until a login succeeds. */
    private final List<ICookie> sessionCookies;
    /** Consecutive failed Passport logins; reset to 0 on a successful login. */
    private int authFailures;

    public YandexAuthAdapter(IBurpExtenderCallbacks callbacks, MollyAuthConfig authConfig) {
        this.callbacks = callbacks;
        this.helpers = callbacks.getHelpers();
        this.authConfig = authConfig;
        this.authFailures = 0;
        this.sessionCookies = new ArrayList<>();
    }

    /**
     * Tells whether the configuration holds enough data to authenticate.
     *
     * @return {@code true} for {@code OAUTH}; for {@code PASSPORT}, {@code true} only when the
     *         username, password and auth host are all non-empty; {@code false} otherwise
     */
    public boolean isAuthExpected() {
        switch (authSchema()) {
            case "PASSPORT":
                return !isNullOrEmpty(authConfig.getAuthUsername())
                        && !isNullOrEmpty(authConfig.getAuthPassword())
                        && !isNullOrEmpty(authConfig.getAuthHost());
            case "OAUTH":
                return true;
            default:
                return false;
        }
    }

    /**
     * Authenticates, or authenticates a single request, according to the configured schema.
     *
     * @param messageInfo request to authenticate in place, or {@code null} to only perform
     *                    the schema's login step
     * @return {@code true} if the request (or session) is considered authenticated
     */
    public boolean doAuth(IHttpRequestResponse messageInfo) {
        switch (authSchema()) {
            case "PASSPORT":
                return doPassportAuth(messageInfo);
            case "OAUTH":
                return doOAuth(messageInfo);
            default:
                return false;
        }
    }

    /**
     * PASSPORT branch of {@link #doAuth}.
     *
     * <p>For a request, returns {@code true} without modifying it when it already carries a
     * Passport cookie, targets the auth host itself, or targets a host outside the auth host's
     * registrable domain. When a login has already succeeded, it adds the stored session
     * cookies. Otherwise (and always when {@code messageInfo} is {@code null}) it logs in,
     * giving up once more than {@link #MAX_AUTH_TRIES} logins have failed in a row.
     */
    private boolean doPassportAuth(IHttpRequestResponse messageInfo) {
        if (messageInfo != null) {
            /* Do we have cookies set? */
            if (isAuthenticated(messageInfo)) {
                return true;
            }
            // Must be checked before dereferencing the service below.
            if (messageInfo.getHttpService() == null) {
                return false;
            }
            String targetHost = messageInfo.getHttpService().getHost();

            /* Do we analyze Passport request? */
            if (targetHost.equals(authConfig.getAuthHost())) {
                return true;
            }
            /* Do we scan domain which does not have Passport? Say it's ok even if it is not */
            if (!isSamePrivateDomain(authConfig.getAuthHost(), targetHost)) {
                return true;
            }
            // Reuse the session from an earlier login; an empty list means we must log in first.
            if (!sessionCookies.isEmpty()) {
                applySessionCookies(messageInfo);
                return true;
            }
        }

        if (authFailures > MAX_AUTH_TRIES) {
            return false;
        }
        if (!loginToPassport()) {
            authFailures++;
            return false;
        }
        authFailures = 0;
        if (messageInfo != null) {
            applySessionCookies(messageInfo);
        }
        return true;
    }

    /**
     * Performs the Passport login and, on success, replaces {@link #sessionCookies} with the
     * {@code sessionid2} and {@code Session_id} cookies from the response.
     *
     * @return {@code true} if both cookies were received
     */
    private boolean loginToPassport() {
        byte[] authResp;
        try {
            authResp = callbacks.makeHttpRequest(authConfig.getAuthHost(), HTTPS_PORT, true, buildAuthRequest());
        } catch (RuntimeException e) {
            // Network or Burp-side failure: report it as a failed login attempt.
            return false;
        }
        if (authResp == null) {
            return false;
        }

        List<ICookie> respCookies = helpers.analyzeResponse(authResp).getCookies();
        Optional<ICookie> sessionId2 = findCookie(respCookies, SESSION_ID2_COOKIE);
        Optional<ICookie> sessionId = findCookie(respCookies, SESSION_ID_COOKIE);
        if (!sessionId2.isPresent() || !sessionId.isPresent()) {
            return false;
        }

        /* TODO: remove it as soon as passport will be ready */
        sessionCookies.clear();
        sessionCookies.add(sessionId2.get());
        sessionCookies.add(sessionId.get());
        return true;
    }

    /**
     * OAUTH branch of {@link #doAuth}: drops every existing {@code Authorization} header and
     * appends {@code "Authorization: " + authPassword}, keeping the request body unchanged.
     * Always returns {@code true}.
     */
    private boolean doOAuth(IHttpRequestResponse messageInfo) {
        /* TODO: check if OAuth token valid? */
        if (messageInfo == null) {
            return true;
        }
        byte[] request = messageInfo.getRequest();
        IRequestInfo requestInfo = helpers.analyzeRequest(request);

        List<String> newHeaders = new ArrayList<>();
        for (String header : requestInfo.getHeaders()) {
            // Locale.ROOT: a locale-sensitive upper-casing (e.g. Turkish dotted I) would miss the prefix.
            if (!header.toUpperCase(Locale.ROOT).startsWith(AUTHORIZATION_PREFIX)) {
                newHeaders.add(header);
            }
        }
        newHeaders.add("Authorization: " + authConfig.getAuthPassword());

        int bodyOffset = requestInfo.getBodyOffset();
        byte[] body = bodyOffset < request.length
                ? Arrays.copyOfRange(request, bodyOffset, request.length)
                : new byte[0];
        messageInfo.setRequest(helpers.buildHttpMessage(newHeaders, body));
        return true;
    }

    /** Logout is not implemented for this adapter; the call is a no-op. */
    public void doLogout(IHttpRequestResponse messageInfo) {
    }

    /**
     * Builds the Passport login POST request.
     *
     * <p>It first fetches the {@code /auth} page. If that page contains a CSRF token, the token
     * and the cookies set by that response are added to the login request.
     */
    private byte[] buildAuthRequest() {
        byte[] loginPage = callbacks.makeHttpRequest(authConfig.getAuthHost(), HTTPS_PORT, true,
                helpers.buildHttpRequest(getPassportURL()));

        byte[] rawRequest = helpers.buildHttpRequest(getAuthURL());
        rawRequest = helpers.addParameter(rawRequest,
                helpers.buildParameter("mode", "auth", IParameter.PARAM_URL));
        // Toggle the method (presumably GET -> POST) before adding body parameters.
        rawRequest = helpers.toggleRequestMethod(rawRequest);
        rawRequest = addBodyParameter(rawRequest, "login", helpers.urlEncode(authConfig.getAuthUsername()));
        rawRequest = addBodyParameter(rawRequest, "passwd", helpers.urlEncode(authConfig.getAuthPassword()));
        rawRequest = addBodyParameter(rawRequest, "from", "passport");
        rawRequest = addBodyParameter(rawRequest, "display", "page");
        rawRequest = addBodyParameter(rawRequest, "retpath", helpers.urlEncode(getProfileURL().toString()));
        /* TODO: set correct timestamp or remove it */
        rawRequest = addBodyParameter(rawRequest, "timestamp", "123");

        if (loginPage != null) {
            Matcher matcher = CSRF_PATTERN.matcher(helpers.bytesToString(loginPage));
            if (matcher.find()) {
                rawRequest = addBodyParameter(rawRequest, "csrf_token", helpers.urlEncode(matcher.group(1)));
                /* ok, we got CSRF token, need to preserve cookies as well */
                rawRequest = addCookies(rawRequest, helpers.analyzeResponse(loginPage).getCookies());
            }
        }
        return rawRequest;
    }

    private URL getAuthURL() {
        return authHostUrl("/passport");
    }

    private URL getProfileURL() {
        return authHostUrl("/profile");
    }

    private URL getPassportURL() {
        return authHostUrl("/auth");
    }

    /** Returns {@code https://<authHost>:443<path>}, or {@code null} if it is malformed. */
    private URL authHostUrl(String path) {
        try {
            return new URL("https", authConfig.getAuthHost(), HTTPS_PORT, path);
        } catch (MalformedURLException e) {
            return null;
        }
    }

    /**
     * Tells whether a request targets the auth host and has a parameter whose value
     * contains "logout" (case-insensitive).
     */
    public boolean isLogoutRequest(IHttpRequestResponse messageInfo) {
        String authHost = authConfig.getAuthHost();
        if (authHost == null) {
            return false;
        }
        IRequestInfo requestInfo = helpers.analyzeRequest(messageInfo);
        URL url = requestInfo.getUrl();
        // Compare the host itself: matching anywhere in the URL string would also hit paths/queries.
        if (url == null || !authHost.equalsIgnoreCase(url.getHost())) {
            return false;
        }
        return requestInfo.getParameters().stream()
                .map(IParameter::getValue)
                .anyMatch(value -> value != null && value.toUpperCase(Locale.ROOT).contains("LOGOUT"));
    }

    /**
     * Tells whether a request already carries credentials.
     *
     * @return for {@code PASSPORT}: {@code true} if the request has a {@code sessionid2} or
     *         {@code Session_id} cookie (or the auth URL cannot be built); {@code true} for
     *         {@code OAUTH}; {@code false} for any other schema
     */
    public boolean isAuthenticated(IHttpRequestResponse messageInfo) {
        switch (authSchema()) {
            case "PASSPORT":
                if (getAuthURL() == null) {
                    return true;
                }
                return helpers.analyzeRequest(messageInfo).getParameters().stream()
                        .filter(p -> p.getType() == IParameter.PARAM_COOKIE)
                        .anyMatch(p -> SESSION_ID2_COOKIE.equals(p.getName())
                                || SESSION_ID_COOKIE.equals(p.getName()));
            case "OAUTH":
                return true;
            default:
                return false;
        }
    }

    /** Upper-cased auth schema (locale-independent), or "" when none is configured. */
    private String authSchema() {
        String schema = authConfig.getAuthSchema();
        return schema == null ? "" : schema.toUpperCase(Locale.ROOT);
    }

    private static boolean isNullOrEmpty(String value) {
        return value == null || value.isEmpty();
    }

    /**
     * Tells whether two hosts share the same registrable ("top private") domain. Hosts Guava
     * cannot classify (null, IP literals, hosts not under a public suffix) never match.
     */
    private static boolean isSamePrivateDomain(String first, String second) {
        if (first == null || second == null) {
            return false;
        }
        try {
            return InternetDomainName.from(first).topPrivateDomain()
                    .equals(InternetDomainName.from(second).topPrivateDomain());
        } catch (IllegalArgumentException | IllegalStateException e) {
            return false;
        }
    }

    private static Optional<ICookie> findCookie(List<ICookie> cookies, String name) {
        return cookies.stream().filter(c -> name.equals(c.getName())).findFirst();
    }

    /** Adds the stored Passport session cookies to the request in place. */
    private void applySessionCookies(IHttpRequestResponse messageInfo) {
        messageInfo.setRequest(addCookies(messageInfo.getRequest(), sessionCookies));
    }

    /** Returns a copy of {@code request} with every cookie added as a Cookie parameter. */
    private byte[] addCookies(byte[] request, List<ICookie> cookies) {
        byte[] result = request;
        for (ICookie cookie : cookies) {
            result = helpers.addParameter(result,
                    helpers.buildParameter(cookie.getName(), cookie.getValue(), IParameter.PARAM_COOKIE));
        }
        return result;
    }

    private byte[] addBodyParameter(byte[] request, String name, String value) {
        return helpers.addParameter(request, helpers.buildParameter(name, value, IParameter.PARAM_BODY));
    }
}
