package ru.yandex.intranet.d.web.log;

import java.net.InetSocketAddress;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.TimeUnit;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.MDC;
import org.springframework.core.Ordered;
import org.springframework.http.HttpHeaders;
import org.springframework.lang.NonNull;
import org.springframework.stereotype.Component;
import org.springframework.web.reactive.HandlerMapping;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain;
import org.springframework.web.util.pattern.PathPattern;
import reactor.core.publisher.Mono;
import reactor.util.context.Context;

import ru.yandex.intranet.d.metrics.HttpServerMetrics;
import ru.yandex.intranet.d.util.MdcTaskDecorator;
import ru.yandex.intranet.d.util.OneShotStopWatch;

/**
 * Access log web filter.
 *
 * <p>Assigns a log id to every incoming exchange and propagates it through the Reactor context.
 * Once request processing terminates (complete, error or cancel), it writes a single line to the
 * {@code ACCESS_LOG} logger and reports the request completion to {@link HttpServerMetrics}.
 *
 * @author Dmitriy Timashov <dm-tim@yandex-team.ru>
 */
@Component
public class AccessLogFilter implements WebFilter, Ordered {

    private static final Logger ACCESS_LOG = LoggerFactory.getLogger("ACCESS_LOG");

    private final AccessLogAttributesProducer accessLogAttributesProducer;
    private final HttpServerMetrics httpServerMetrics;

    public AccessLogFilter(AccessLogAttributesProducer accessLogAttributesProducer,
                           HttpServerMetrics httpServerMetrics) {
        this.accessLogAttributesProducer = accessLogAttributesProducer;
        this.httpServerMetrics = httpServerMetrics;
    }

    /**
     * Returns {@link Ordered#HIGHEST_PRECEDENCE} so this filter is ordered before other ordered
     * filters in the chain.
     */
    @Override
    public int getOrder() {
        return Ordered.HIGHEST_PRECEDENCE;
    }

    /**
     * Assigns a log id to the exchange, runs the rest of the chain and writes the access log
     * when the resulting {@code Mono} terminates for any reason (doFinally).
     *
     * @param exchange current exchange
     * @param chain    remaining filter chain
     * @return completion signal of the downstream chain
     */
    @Override
    @NonNull
    public Mono<Void> filter(@NonNull ServerWebExchange exchange, @NonNull WebFilterChain chain) {
        // NOTE(review): presumably the stopwatch starts timing on construction; confirm in OneShotStopWatch.
        OneShotStopWatch stopwatch = new OneShotStopWatch();
        accessLogAttributesProducer.addLogId(exchange);
        return chain.filter(exchange)
                .doFinally(signal -> writeLog(exchange, stopwatch))
                .contextWrite(buildContextWithLogId(exchange));
    }

    /**
     * Builds a Reactor context carrying the exchange's log id under
     * {@link AccessLogAttributesProducer#LOG_ID}.
     */
    private Context buildContextWithLogId(ServerWebExchange exchange) {
        return Context.of(AccessLogAttributesProducer.LOG_ID, accessLogAttributesProducer.getLogId(exchange));
    }

    /**
     * Writes one access log line for the exchange and reports the request to the HTTP metrics.
     *
     * <p>Request attributes are written into the formatted message and also exposed as
     * {@code access_*} MDC entries for the duration of the logging call. Afterwards the MDC the
     * thread had before is restored, including the case where it had none.
     *
     * @param exchange  the exchange whose processing has terminated
     * @param stopwatch stopwatch created when the exchange entered this filter
     */
    private void writeLog(ServerWebExchange exchange, OneShotStopWatch stopwatch) {
        long elapsedMillis = stopwatch.elapsed(TimeUnit.MILLISECONDS);
        Optional<String> uid = accessLogAttributesProducer.getUid(exchange);
        Optional<String> tvmServiceId = accessLogAttributesProducer.getTvmServiceId(exchange)
                .map(Objects::toString);
        String query = exchange.getRequest().getURI().getQuery();
        String path = exchange.getRequest().getURI().getPath() + (query != null ? "?" + query : "");
        String method = exchange.getRequest().getMethodValue();
        PathPattern pathPattern = exchange.getAttribute(HandlerMapping.BEST_MATCHING_PATTERN_ATTRIBUTE);
        Optional<String> pattern = Optional.ofNullable(pathPattern).map(PathPattern::getPatternString);
        Optional<String> sourceIp = getSourceIp(exchange).map(this::prepareIp);
        Optional<String> remoteIp = getRemoteHost(exchange).map(this::prepareIp);
        Optional<String> statusCode = Optional.ofNullable(exchange.getResponse().getRawStatusCode())
                .map(Objects::toString);
        // A negative value means the response has no known Content-Length.
        long maybeContentLength = exchange.getResponse().getHeaders().getContentLength();
        Optional<String> contentLength = maybeContentLength >= 0
                ? Optional.of(String.valueOf(maybeContentLength)) : Optional.empty();
        Optional<String> oauthClientId = accessLogAttributesProducer.getOAuthClientId(exchange);
        Optional<String> oauthClientName = accessLogAttributesProducer.getOAuthClientName(exchange);
        Optional<String> requestId = getHeader(exchange, "X-Request-ID");
        Optional<String> referer = getHeader(exchange, HttpHeaders.REFERER);
        Optional<String> userAgent = getHeader(exchange, HttpHeaders.USER_AGENT);
        Optional<String> forwardedFor = getHeader(exchange, "X-Forwarded-For");
        String logId = accessLogAttributesProducer.getLogId(exchange);

        Map<String, String> mdcMap = new HashMap<>();
        mdcMap.put("access_protocol", "HTTP");
        sourceIp.ifPresent(v -> mdcMap.put("access_source_ip", v));
        remoteIp.ifPresent(v -> mdcMap.put("access_remote_ip", v));
        mdcMap.put("access_http_method", method);
        mdcMap.put("access_http_path", path);
        statusCode.ifPresent(v -> mdcMap.put("access_http_status", v));
        contentLength.ifPresent(v -> mdcMap.put("access_response_size", v));
        mdcMap.put("access_response_time_ms", String.valueOf(elapsedMillis));
        uid.ifPresent(v -> mdcMap.put("access_uid", v));
        tvmServiceId.ifPresent(v -> mdcMap.put("access_tvm_id", v));
        oauthClientId.ifPresent(v -> mdcMap.put("access_oauth_id", v));
        oauthClientName.ifPresent(v -> mdcMap.put("access_oauth_name", v));
        requestId.ifPresent(v -> mdcMap.put("access_http_request_id", v));
        mdcMap.put(MdcTaskDecorator.LOG_ID_MDC_KEY, logId);
        userAgent.ifPresent(v -> mdcMap.put("access_http_user_agent", v));
        referer.ifPresent(v -> mdcMap.put("access_http_referer", v));
        forwardedFor.ifPresent(v -> mdcMap.put("access_http_forwarded_for", v));
        pattern.ifPresent(v -> mdcMap.put("access_http_endpoint", v));

        // May be null when the current thread has no MDC at all.
        Map<String, String> previousMdc = MDC.getCopyOfContextMap();
        try {
            MDC.setContextMap(mdcMap);
            ACCESS_LOG.info("{} HTTP {} {} {} {} {} {} {}", sourceIp.orElse("-"), method, path,
                    statusCode.orElse("-"), contentLength.orElse("-"), elapsedMillis,
                    uid.orElse("-"), tvmServiceId.orElse("-"));
        } finally {
            restoreMdc(previousMdc);
        }
        httpServerMetrics.onRequestCompletion(statusCode.orElse(null), elapsedMillis,
                maybeContentLength >= 0 ? maybeContentLength : null);
    }

    /**
     * Restores the current thread's MDC to a snapshot taken with {@link MDC#getCopyOfContextMap()}.
     *
     * <p>A {@code null} snapshot means the thread had no MDC, so the MDC is simply cleared.
     * {@link MDC#setContextMap(Map)} accepts {@code null} only since SLF4J 2.0.0. Older adapters
     * throw a {@code NullPointerException}, which would abort {@link #writeLog} before the
     * metrics are reported.
     *
     * @param previousMdc snapshot to restore, possibly {@code null}
     */
    private static void restoreMdc(Map<String, String> previousMdc) {
        if (previousMdc == null) {
            MDC.clear();
        } else {
            // setContextMap clears the existing map before copying the snapshot in.
            MDC.setContextMap(previousMdc);
        }
    }

    /**
     * Resolves the client IP, preferring a non-blank {@code X-Forwarded-For-Y} request header over
     * the host string of the connection's remote address.
     */
    private Optional<String> getSourceIp(ServerWebExchange exchange) {
        return getHeader(exchange, "X-Forwarded-For-Y").or(() -> getRemoteHost(exchange));
    }

    /**
     * Returns the host string of the connection's remote address, if the request exposes one.
     */
    private Optional<String> getRemoteHost(ServerWebExchange exchange) {
        return Optional.ofNullable(exchange.getRequest().getRemoteAddress())
                .map(InetSocketAddress::getHostString);
    }

    /**
     * Returns the first value of the given request header, or empty when it is missing or blank.
     */
    private Optional<String> getHeader(ServerWebExchange exchange, String header) {
        String value = exchange.getRequest().getHeaders().getFirst(header);
        if (value != null && !value.isBlank()) {
            return Optional.of(value);
        }
        return Optional.empty();
    }

    /**
     * Strips a trailing scope suffix from an IP string, e.g. {@code 2a02:6b8:0:e00:0:0:0:1a%0}
     * becomes {@code 2a02:6b8:0:e00:0:0:0:1a}. Everything from the last {@code '%'} onwards is
     * removed. {@code null}, empty and suffix-free values are returned unchanged.
     */
    private String prepareIp(String value) {
        if (value == null || value.isEmpty()) {
            return value;
        }
        int lastSeparatorIndex = value.lastIndexOf('%');
        return lastSeparatorIndex >= 0 ? value.substring(0, lastSeparatorIndex) : value;
    }

}
