package ru.yandex.travel.http;

import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import javax.servlet.DispatcherType;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import io.opentracing.Tracer;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger;
import org.springframework.boot.web.servlet.error.DefaultErrorAttributes;
import org.springframework.http.HttpHeaders;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.util.MultiValueMap;
import org.springframework.web.context.request.ServletWebRequest;
import org.springframework.web.servlet.HandlerInterceptor;
import org.springframework.web.util.ContentCachingRequestWrapper;
import org.springframework.web.util.UriComponentsBuilder;

/**
 * Spring MVC interceptor that records a request and its response as a single JSON line on the
 * {@code ru.yandex.travel.http.ReqAns} logger.
 *
 * <p>An event is created in {@link #preHandle} (only for {@link DispatcherType#REQUEST} dispatches)
 * and stored as a request attribute under {@link #ATTR_NAME}. It is completed and written in
 * {@link #afterCompletion}. Request and response bodies are captured only when the request is a
 * {@link ContentCachingRequestWrapper} and the response is a
 * {@code CustomContentCachingResponseWrapper}. NOTE(review): these wrappers are presumably installed
 * by a servlet filter elsewhere; confirm in the application configuration.
 *
 * <p>Values of the request headers named in {@code headersToMask} are partially masked before logging.
 */
@Slf4j
public class ReqAnsLoggerInterceptor implements HandlerInterceptor {
    /** Request attribute under which the in-flight {@link ReqAnsEvent} is stored. */
    public static final String ATTR_NAME = "REQANS_LOGGED_REQUEST";
    /** Request attribute whose value, if present at completion, is logged as {@code additionalInfo}. */
    public static final String ADDITIONAL_INFO_ATTR_NAME = "REQANS_ADDITIONAL_INFO";
    private static final String REQUEST_ID_HEADER = "X-Request-Id";
    private static final Logger httpLog = org.slf4j.LoggerFactory.getLogger("ru.yandex.travel.http.ReqAns");

    private final ObjectMapper objectMapper;
    private final Tracer tracer;
    private final List<String> headersToMask;

    /**
     * @param objectMapper  mapper used to parse captured bodies as JSON and to serialize the event
     * @param tracer        optional tracer; when non-null, its active span is tagged with the request id
     * @param headersToMask names of request headers whose values are masked; {@code null} means none
     */
    public ReqAnsLoggerInterceptor(ObjectMapper objectMapper, Tracer tracer, List<String> headersToMask) {
        this.objectMapper = objectMapper;
        this.tracer = tracer;
        this.headersToMask = headersToMask != null ? ImmutableList.copyOf(headersToMask) : List.of();
    }

    /**
     * Creates the log event for a {@link DispatcherType#REQUEST} dispatch and stores it under
     * {@link #ATTR_NAME}. Other dispatch types (forward, include, error, async) are skipped.
     *
     * @return always {@code true}; this interceptor never blocks handling
     */
    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        if (request.getDispatcherType() != DispatcherType.REQUEST) {
            return true;
        }
        ReqAnsEvent loggedRequest = new ReqAnsEvent(request);
        String requestId = loggedRequest.getRequestId();
        // Only tag when there is an id to tag with; a null tag value carries no information.
        if (tracer != null && requestId != null) {
            var activeSpan = tracer.scopeManager().activeSpan();
            if (activeSpan != null) {
                activeSpan.setTag(REQUEST_ID_HEADER, requestId);
            }
        }
        if (request.getAttribute(ATTR_NAME) == null) {
            request.setAttribute(ATTR_NAME, loggedRequest);
        } else {
            log.warn("Duplicate logging event, this should not happen");
        }
        return true;
    }

    /**
     * Completes the event stored by {@link #preHandle} with the request body, the response data and the
     * optional {@link #ADDITIONAL_INFO_ATTR_NAME} attribute, then writes it as JSON to the ReqAns logger.
     * Logs a warning if no event is stored on the request.
     */
    @Override
    public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler,
                                Exception ex) throws Exception {
        ReqAnsEvent loggedRequest = (ReqAnsEvent) request.getAttribute(ATTR_NAME);
        if (loggedRequest == null) {
            log.warn("No logging event");
            return;
        }
        loggedRequest.addRequestBody(request);
        loggedRequest.addResponse(request, response);
        loggedRequest.additionalInfo = request.getAttribute(ADDITIONAL_INFO_ATTR_NAME);
        httpLog.info(objectMapper.writerFor(ReqAnsEvent.class).writeValueAsString(loggedRequest));
    }


    /**
     * One request/response pair. It is serialized by Jackson through the Lombok-generated getters plus
     * {@link #getTimeTaken()}.
     */
    @Getter
    class ReqAnsEvent {
        private String requestId;
        private String method;
        private String url;
        private MultiValueMap<String, String> queryParams;
        private String remoteIp;
        private HttpHeaders requestHeaders;
        private long requestContentLength;
        private long requestTimestamp;
        private Object requestBody;
        private Object additionalInfo;

        private long responseContentLength;
        private int statusCode;
        private long responseTimestamp;
        private Object responseBody;
        private String errorReason;
        private Map<String, Object> error;

        /**
         * Captures request metadata: request id, method, URI, query parameters, remote address,
         * (masked) headers and the creation timestamp.
         */
        public ReqAnsEvent(HttpServletRequest request) {
            HttpHeaders headers = new ServletServerHttpRequest(request).getHeaders();
            // Read before masking so the real id is kept even if X-Request-Id is configured to be masked.
            this.requestId = headers.getFirst(REQUEST_ID_HEADER);
            headersToMask.forEach(header -> {
                if (headers.containsKey(header)) {
                    List<String> values = headers.remove(header);
                    headers.put(header, values.stream().map(value -> mask(value, 10, 4)).collect(Collectors.toList()));
                }
            });
            this.url = request.getRequestURI();
            this.method = request.getMethod();
            this.queryParams = parseQueryParams(request);
            this.remoteIp = request.getRemoteAddr();
            this.requestHeaders = headers;
            this.requestTimestamp = System.currentTimeMillis();
        }

        /**
         * Records the request content length and, when the request is a
         * {@link ContentCachingRequestWrapper}, the cached body (parsed as a JSON object if possible).
         */
        public void addRequestBody(HttpServletRequest request) throws UnsupportedEncodingException {
            // getContentLengthLong avoids the int overflow (-1) of getContentLength for bodies > 2 GB.
            this.requestContentLength = request.getContentLengthLong();
            if (request instanceof ContentCachingRequestWrapper) {
                ContentCachingRequestWrapper requestWrapper = (ContentCachingRequestWrapper) request;
                byte[] bodyBytes = requestWrapper.getContentAsByteArray();
                if (bodyBytes.length > 0) {
                    this.requestBody = parseBody(decode(bodyBytes, requestWrapper.getCharacterEncoding()));
                }
            }
        }

        /**
         * Records the status code, completion timestamp and any error attributes of the request. When
         * the response is a {@code CustomContentCachingResponseWrapper`, it also records the content size,
         * error reason and cached body (parsed as a JSON object if possible).
         */
        public void addResponse(HttpServletRequest request, HttpServletResponse response) throws UnsupportedEncodingException {
            // Status and timestamp are available on any response. Setting them unconditionally keeps
            // getTimeTaken() meaningful when the response is not the caching wrapper.
            this.statusCode = response.getStatus();
            this.responseTimestamp = System.currentTimeMillis();
            if (response instanceof CustomContentCachingResponseWrapper) {
                CustomContentCachingResponseWrapper responseWrapper = (CustomContentCachingResponseWrapper) response;
                this.responseContentLength = responseWrapper.getContentSize();
                this.errorReason = responseWrapper.getErrorReason();
                byte[] bodyBytes = responseWrapper.getContentAsByteArray();
                if (bodyBytes.length > 0) {
                    this.responseBody = parseBody(decode(bodyBytes, responseWrapper.getCharacterEncoding()));
                }
            }
            var errorAttrs = new DefaultErrorAttributes();
            ServletWebRequest webRequest = new ServletWebRequest(request);
            if (errorAttrs.getError(webRequest) != null) {
                // includeStackTrace = true: the stack trace is written to the log.
                this.error = errorAttrs.getErrorAttributes(webRequest, true);
            }
        }

        /**
         * Parses the query string into a multimap.
         *
         * @return the parsed parameters, or {@code null} when there is no query string or it cannot be parsed
         */
        private MultiValueMap<String, String> parseQueryParams(HttpServletRequest request) {
            String queryString = request.getQueryString();
            if (queryString == null) {
                // Without this guard the URL would become "...?null" and a bogus "null" parameter would be logged.
                return null;
            }
            try {
                return UriComponentsBuilder.fromUriString(request.getRequestURL() + "?" + queryString)
                        .build()
                        .getQueryParams();
            } catch (IllegalArgumentException ignored) {
                // Best effort: an unparseable URL must not prevent the event from being logged.
                return null;
            }
        }

        /**
         * Decodes captured body bytes with the declared encoding, falling back to UTF-8 when none is
         * declared. {@code ServletRequest.getCharacterEncoding()} may return null, and
         * {@code new String(bytes, (String) null)} would throw a NullPointerException.
         */
        private String decode(byte[] bytes, String encoding) throws UnsupportedEncodingException {
            return encoding != null ? new String(bytes, encoding) : new String(bytes, StandardCharsets.UTF_8);
        }

        /** Returns the body parsed as a JSON object (a {@link Map}), or the raw string if it is not one. */
        private Object parseBody(String body) {
            try {
                return objectMapper.readerFor(Map.class).readValue(body);
            } catch (IOException ignored) {
                return body;
            }
        }

        /**
         * Keeps the first {@code startLength} and last {@code endLength} characters and replaces the
         * middle with {@code ****}. Values shorter than twice the revealed length are hidden entirely,
         * so at most half of a value is ever revealed.
         */
        private String mask(String input, int startLength, int endLength) {
            if (Strings.isNullOrEmpty(input)) {
                return "<none>";
            }
            if (input.length() < (startLength + endLength) * 2) {
                return "<too short to mask>";
            }
            String begin = input.substring(0, startLength);
            String end = input.substring(input.length() - endLength);
            return begin + "****" + end;
        }

        /** Milliseconds between event creation and {@link #addResponse}. */
        public long getTimeTaken() {
            return responseTimestamp - requestTimestamp;
        }
    }
}
