package ru.yandex.direct.common.logging;

import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;

import javax.annotation.Nullable;
import javax.annotation.ParametersAreNonnullByDefault;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import com.fasterxml.jackson.databind.JsonNode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.InvalidMediaTypeException;
import org.springframework.http.MediaType;
import org.springframework.web.util.ContentCachingRequestWrapper;
import org.springframework.web.util.WebUtils;

import ru.yandex.direct.tracing.Trace;
import ru.yandex.direct.tracing.util.ThreadUsedResources;
import ru.yandex.direct.tracing.util.ThreadUsedResourcesProvider;
import ru.yandex.direct.tracing.util.TraceUtil;
import ru.yandex.direct.utils.JsonUtils;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Strings.isNullOrEmpty;
import static ru.yandex.direct.tvm.TvmIntegration.TVM_SERVICE_ID_ATTRIBUTE_NAME;
import static ru.yandex.direct.utils.JsonUtils.MAPPER;

/**
 * Base servlet filter that emits one log entry per HTTP request.
 * <p>
 * The filter wraps the request in a {@link ContentCachingRequestWrapper} and the response in a
 * {@code ProxyResponseWrapper} (unless they already are such wrappers), lets the rest of the chain run,
 * and then fills the current {@link LogRecord} with request/response data and writes it as JSON
 * to the logger whose name is passed to the constructor.
 * <p>
 * Subclasses create the log record in {@link #initLogRecord} and may adjust it in {@link #beforeLogCall}.
 *
 * @param <T> concrete log record type kept in the {@link LogRecordHolder}
 */
@ParametersAreNonnullByDefault
public abstract class LoggingFilter<T extends LogRecord> implements Filter {
    /** Charset used to decode request/response bodies before they are put into the log record. */
    private static final Charset LOG_BODY_CHARSET = StandardCharsets.UTF_8;
    private final Logger requestLogger;

    /** Settings used when the request carries no configuration of its own (see {@link #logCall}). */
    private final LoggingSettings loggingDefaults;
    protected final LogRecordHolder<T> logRecordHolder;

    private final ThreadUsedResourcesProvider threadUsedResourcesProvider;

    /**
     * @param loggerName      name of the SLF4J logger the records are written to; must not be null or empty
     * @param loggingDefaults settings applied when no per-request configuration is found
     * @param logRecordHolder holder from which the current log record is read
     * @throws IllegalArgumentException if {@code loggerName} is null or empty
     */
    public LoggingFilter(String loggerName, LoggingSettings loggingDefaults, LogRecordHolder<T> logRecordHolder) {
        checkArgument(!isNullOrEmpty(loggerName), "loggerName cannot be empty");
        this.requestLogger = LoggerFactory.getLogger(loggerName);
        this.loggingDefaults = loggingDefaults;
        this.logRecordHolder = logRecordHolder;
        this.threadUsedResourcesProvider = ThreadUsedResourcesProvider.instance();
    }

    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        // no initialization
    }

    @Override
    public void destroy() {
        // no finalization
    }

    /**
     * Creates the log record for the current request. After this call
     * {@code logRecordHolder.getLogRecord()} must return a non-null record.
     *
     * @param requestId    span id of the current trace
     * @param requestToUse the (wrapped) request being processed
     */
    protected abstract void initLogRecord(long requestId, HttpServletRequest requestToUse);

    /**
     * Hook invoked after the filter chain has run and right before the record is logged.
     * Does nothing by default.
     *
     * @param request the (wrapped) request being processed
     */
    protected void beforeLogCall(HttpServletRequest request) { }

    /**
     * Wraps request and response, runs the rest of the chain and logs the call afterwards,
     * whether or not the chain completed normally.
     * <p>
     * If the chain fails and logging fails as well, the chain's exception is the one propagated;
     * the logging failure is attached to it as a suppressed exception. If only logging fails,
     * its exception propagates.
     *
     * @throws NullPointerException if {@link #initLogRecord} did not initialize the log record
     */
    @Override
    @SuppressWarnings("CheckReturnValue")
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
            throws IOException, ServletException {
        HttpServletRequest requestToUse = (HttpServletRequest) request;
        HttpServletResponse responseToUse = (HttpServletResponse) response;
        if (!(request instanceof ContentCachingRequestWrapper)) {
            requestToUse = new ContentCachingRequestWrapper(requestToUse);
        }

        if (!(response instanceof ProxyResponseWrapper)) {
            responseToUse = new ProxyResponseWrapper(responseToUse);
        }

        ThreadUsedResources startResources = threadUsedResourcesProvider.getCurrentThreadCpuTime();
        initLogRecord(Trace.current().getSpanId(), requestToUse);
        checkNotNull(logRecordHolder.getLogRecord(), "log record must be initialized");

        Throwable chainFailure = null;
        try {
            chain.doFilter(requestToUse, responseToUse);
            responseToUse.flushBuffer();
        } catch (IOException | ServletException | RuntimeException | Error e) {
            // Remember the primary failure so that a logging failure below cannot replace it.
            chainFailure = e;
            throw e;
        } finally {
            try {
                beforeLogCall(requestToUse);
                logCall(requestToUse, responseToUse, startResources);
            } catch (RuntimeException loggingFailure) {
                if (chainFailure == null) {
                    throw loggingFailure;
                }
                // addSuppressed rejects self-suppression, hence the identity check.
                if (loggingFailure != chainFailure) {
                    chainFailure.addSuppressed(loggingFailure);
                }
            }
        }
    }

    /**
     * Fills the current log record and writes it to the request logger.
     * <p>
     * Effective settings are the defaults, overridden by the configuration found for the request (if any).
     * Nothing is logged when the settings say logging is not needed for this response.
     *
     * @param request        the request passed down the chain
     * @param response       the response passed down the chain
     * @param startResources thread resource snapshot taken before the chain was invoked
     */
    private void logCall(HttpServletRequest request, HttpServletResponse response, ThreadUsedResources startResources) {
        LoggingSettings settings = LoggingConfigurationUtils.findConfig(request)
                .map(loggingDefaults::withConfig)
                .orElse(loggingDefaults);
        if (!settings.enabled().isNeeded(response)) {
            return;
        }

        LogRecord logRecord = logRecordHolder.getLogRecord();

        if (settings.logRequestBody().isNeeded(response)) {
            // May be null if no ContentCachingRequestWrapper is found in the wrapper chain.
            ContentCachingRequestWrapper requestWrapper =
                    WebUtils.getNativeRequest(request, ContentCachingRequestWrapper.class);
            setRequestData(settings, requestWrapper, logRecord);
        }

        if (settings.logResponseBody().isNeeded(response)) {
            ProxyResponseWrapper responseWrapper = WebUtils.getNativeResponse(response, ProxyResponseWrapper.class);
            if (responseWrapper != null) {
                logRecord.setResponse(responseWrapper.getContentString(LOG_BODY_CHARSET));
            }
        }

        logRecord.setHttpStatus(response.getStatus());
        logRecord.setRuntime(Trace.current().elapsed());
        ThreadUsedResources currentResources = threadUsedResourcesProvider.getCurrentThreadCpuTime();
        logRecord.setCpuUserTime(
                TraceUtil.secondsFromNanoseconds(currentResources.getCpuTime() - startResources.getCpuTime()));
        logRecord.setTvmServiceId((Integer) request.getAttribute(TVM_SERVICE_ID_ATTRIBUTE_NAME));

        requestLogger.info("{} {}", logRecord.getPrefixLogTime(), JsonUtils.toJson(logRecord));
    }

    /**
     * Puts the request body into the log record.
     * <p>
     * Nothing is set when the wrapper is absent, the {@code Content-Type} is not JSON,
     * or no body bytes have been cached. A body that parses as JSON (and fits into
     * {@code settings.maxRequestBodySize()}) is stored as a {@link JsonNode}; otherwise the first
     * {@code maxRequestBodySize} bytes are stored as a string decoded with {@link #LOG_BODY_CHARSET}.
     *
     * @param settings       effective logging settings for the request
     * @param requestWrapper caching wrapper holding the body bytes, or {@code null} if there is none
     * @param logRecord      record to fill
     */
    private void setRequestData(LoggingSettings settings, @Nullable ContentCachingRequestWrapper requestWrapper,
                                LogRecord logRecord) {
        if (requestWrapper == null) {
            return;
        }

        if (!hasJsonBody(requestWrapper)) {
            return;
        }

        byte[] bytes = requestWrapper.getContentAsByteArray();
        if (bytes.length == 0) {
            return;
        }
        JsonNode jsonBody = safeParseBodyOrNull(settings, bytes);
        if (jsonBody != null) {
            logRecord.setRequestObject(jsonBody);
        } else {
            // Oversized or unparsable body: log a raw prefix limited by maxRequestBodySize.
            int length = Math.min(bytes.length, settings.maxRequestBodySize());
            String requestObject = new String(bytes, 0, length, LOG_BODY_CHARSET);
            logRecord.setRequestObject(requestObject);
        }
    }

    /**
     * @return {@code true} if the request's {@code Content-Type} indicates a JSON body;
     * {@code false} otherwise, including when the value is rejected as an invalid media type.
     */
    private boolean hasJsonBody(ServletRequest request) {
        try {
            return MediaType.APPLICATION_JSON.includes(MediaType.valueOf(request.getContentType()));
        } catch (InvalidMediaTypeException e) {
            return false;
        }
    }

    /**
     * @return the JSON request body parsed from {@code bytes}, or {@code null} if the body is larger than
     * {@code settings.maxRequestBodySize()} or parsing fails with an {@link IOException}.
     */
    @Nullable
    private JsonNode safeParseBodyOrNull(
            LoggingSettings settings,
            byte[] bytes
    ) {
        try {
            if (bytes.length <= settings.maxRequestBodySize()) {
                return MAPPER.readTree(bytes);
            }
        } catch (IOException e) {
            return null;
        }
        return null;
    }
}
