package ru.yandex.intranet.d.web.errors;

import java.util.Map;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.boot.autoconfigure.web.WebProperties;
import org.springframework.boot.autoconfigure.web.reactive.error.AbstractErrorWebExceptionHandler;
import org.springframework.boot.web.error.ErrorAttributeOptions;
import org.springframework.boot.web.reactive.error.ErrorAttributes;
import org.springframework.context.ApplicationContext;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.lang.NonNull;
import org.springframework.util.StringUtils;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.server.RouterFunction;
import org.springframework.web.reactive.function.server.ServerRequest;
import org.springframework.web.reactive.function.server.ServerResponse;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;

import ru.yandex.intranet.d.web.log.AccessLogAttributesProducer;

import static org.springframework.web.reactive.function.server.RequestPredicates.all;
import static org.springframework.web.reactive.function.server.RouterFunctions.route;

/**
 * Web exception handler.
 *
 * @author Dmitriy Timashov <dm-tim@yandex-team.ru>
 */
public class YaErrorWebExceptionHandler extends AbstractErrorWebExceptionHandler {

    private static final Logger LOG = LoggerFactory.getLogger(YaErrorWebExceptionHandler.class);

    private final AccessLogAttributesProducer accessLogAttributesProducer;

    public YaErrorWebExceptionHandler(ErrorAttributes errorAttributes, WebProperties.Resources resourceProperties,
                                      ApplicationContext applicationContext,
                                      AccessLogAttributesProducer accessLogAttributesProducer) {
        super(errorAttributes, resourceProperties, applicationContext);
        this.accessLogAttributesProducer = accessLogAttributesProducer;
    }

    @Override
    public Mono<Void> handle(ServerWebExchange exchange, Throwable throwable) {
        return super.handle(exchange, throwable).contextWrite(ctx ->
                ctx.put(AccessLogAttributesProducer.LOG_ID, accessLogAttributesProducer.getLogId(exchange)));
    }

    @Override
    protected RouterFunction<ServerResponse> getRoutingFunction(ErrorAttributes errorAttributes) {
        return route(all(), this::renderErrorResponse);
    }

    @Override
    protected void logError(ServerRequest request, ServerResponse response, Throwable throwable) {
        if (LOG.isDebugEnabled()) {
            LOG.debug(formatError(throwable, request));
        }
        if (HttpStatus.resolve(response.rawStatusCode()) != null
                && response.statusCode().equals(HttpStatus.INTERNAL_SERVER_ERROR)) {
            LOG.error("500 Server Error for " + formatRequest(request), throwable);
        }
    }

    @NonNull
    private Mono<ServerResponse> renderErrorResponse(ServerRequest request) {
        Map<String, Object> error = getErrorAttributes(request,
                ErrorAttributeOptions.of(ErrorAttributeOptions.Include.MESSAGE));
        return ServerResponse.status(getHttpStatus(error)).contentType(MediaType.APPLICATION_JSON)
                .body(BodyInserters.fromValue(error));
    }

    private int getHttpStatus(Map<String, Object> errorAttributes) {
        return (int) errorAttributes.get("status");
    }

    private String formatError(Throwable ex, ServerRequest request) {
        String reason = ex.getClass().getSimpleName() + ": " + ex.getMessage();
        return "Resolved [" + reason + "] for HTTP " + request.methodName() + " " + request.path();
    }

    private String formatRequest(ServerRequest request) {
        String rawQuery = request.uri().getRawQuery();
        String query = StringUtils.hasText(rawQuery) ? "?" + rawQuery : "";
        return "HTTP " + request.methodName() + " \"" + request.path() + query + "\"";
    }

}
