package ru.yandex.solomon.http.filters;

import java.util.EnumSet;

import com.google.common.base.Charsets;
import org.springframework.core.Ordered;
import org.springframework.core.annotation.Order;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.core.io.buffer.PooledDataBuffer;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.stereotype.Component;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain;
import reactor.core.publisher.Mono;


/**
 * Rejects suspicious requests before they reach any other filter or handler.
 *
 * <p>A request is answered with {@code 403 Forbidden} and a short plain-text explanation when:
 * <ul>
 *     <li>its HTTP method is not one of {@link #ALLOWED_HTTP_METHODS}, or</li>
 *     <li>its path is not normalized (see {@link #isNormalized(String)}).</li>
 * </ul>
 * Otherwise the request is passed to the rest of the chain unchanged.
 *
 * <p>Based on <a href="https://github.com/spring-projects/spring-security/blob/master/web/src/main/java/org/springframework/security/web/firewall/StrictHttpFirewall.java">StrictHttpFirewall</a>
 *
 * @author Sergey Polovko
 */
@Component
@Order(Ordered.HIGHEST_PRECEDENCE)
public class HttpFirewallFilter implements WebFilter {

    private static final EnumSet<HttpMethod> ALLOWED_HTTP_METHODS = EnumSet.of(
            HttpMethod.GET,
            HttpMethod.POST,
            HttpMethod.PUT,
            HttpMethod.DELETE,
            HttpMethod.OPTIONS,
            HttpMethod.PATCH);

    /** Content type of rejection bodies; the charset matches how {@link #sendReject} encodes them. */
    private static final MediaType TEXT_PLAIN_UTF8 = new MediaType(MediaType.TEXT_PLAIN, Charsets.UTF_8);

    @Override
    public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
        ServerHttpRequest request = exchange.getRequest();
        ServerHttpResponse response = exchange.getResponse();

        HttpMethod method = request.getMethod();
        if (method == null || !ALLOWED_HTTP_METHODS.contains(method)) {
            var content = String.format("The request was rejected because the HTTP method %s is not allowed", method);
            return sendReject(response, content);
        }

        // NOTE(review): the path is inspected as returned by getPath().value(); percent-encoded
        // forms such as "%2e%2e" are not decoded here — confirm they are rejected elsewhere.
        if (!isNormalized(request.getPath().value())) {
            return sendReject(response, "The request was rejected because the URL was not normalized.");
        }

        return chain.filter(exchange);
    }

    /**
     * Completes the exchange with {@code 403 Forbidden} and {@code content} as a UTF-8 text body.
     *
     * @param response the response to write to
     * @param content  human-readable rejection reason
     * @return a {@link Mono} that completes when the body has been written
     */
    private static Mono<Void> sendReject(ServerHttpResponse response, String content) {
        // Encode once: Content-Length must be the number of bytes on the wire,
        // not the number of UTF-16 chars in the String.
        byte[] bytes = content.getBytes(Charsets.UTF_8);

        response.setStatusCode(HttpStatus.FORBIDDEN);
        response.getHeaders().setContentType(TEXT_PLAIN_UTF8);
        response.getHeaders().setContentLength(bytes.length);

        DataBuffer body = response.bufferFactory().wrap(bytes);

        // Release the buffer if it is discarded without being written (e.g. on cancellation).
        return response.writeWith(Mono.just(body))
                .doOnDiscard(PooledDataBuffer.class, DataBufferUtils::release);
    }

    /**
     * Checks whether a path is normalized.
     *
     * <p>A path is considered NOT normalized when it contains an empty segment ({@code "//"},
     * anywhere, including at the very start) or any segment that is exactly {@code "."} or
     * {@code ".."} (e.g. {@code "/a/./b"}, {@code "/a/.."}, {@code "./a"}).
     *
     * @param path the path to test; {@code null} is treated as normalized
     * @return {@code true} if the path contains none of the sequences above
     */
    private static boolean isNormalized(String path) {
        if (path == null) {
            return true;
        }

        // Any "//" is rejected, including at index 0: a leading "//" can be interpreted
        // as a network-path (protocol-relative) reference.
        if (path.contains("//")) {
            return false;
        }

        // Walk segments from the end of the path. 'end' is the exclusive end of the current
        // segment, 'slash' the index of the '/' before it (-1 for the first segment).
        for (int end = path.length(); end > 0;) {
            int slash = path.lastIndexOf('/', end - 1);
            int gap = end - slash;

            if (gap == 2 && path.charAt(slash + 1) == '.') {
                // segment is exactly "."  (".", "/./" or "/.")
                return false;
            } else if (gap == 3 && path.charAt(slash + 1) == '.' && path.charAt(slash + 2) == '.') {
                // segment is exactly ".."
                return false;
            }

            end = slash;
        }

        return true;
    }
}
