package ru.yandex.direct.common.tracing;

import javax.servlet.http.HttpServletRequest;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.ui.ModelMap;
import org.springframework.web.context.request.NativeWebRequest;
import org.springframework.web.context.request.WebRequest;
import org.springframework.web.context.request.WebRequestInterceptor;

import ru.yandex.direct.tracing.Trace;
import ru.yandex.direct.tracing.TraceGuard;
import ru.yandex.direct.tracing.TraceHelper;
import ru.yandex.direct.tracing.util.TraceUtil;

/**
 * Web request interceptor that opens a {@link TraceGuard} before a request is handled
 * and closes it once the request has completed.
 *
 * <p>The trace is built from the {@code TraceUtil.X_YANDEX_TRACE} request header, the service
 * reported by {@link TraceHelper#getService()} and, when the underlying servlet request is
 * available, its path info.
 *
 * <p>The open guard is kept in a {@link ThreadLocal} between the callbacks.
 * NOTE(review): this assumes {@code preHandle} and {@code afterCompletion} run on the same
 * thread for a given request — confirm for asynchronous request handling.
 */
@SuppressWarnings("ThreadLocalUsage")
public class TraceContextInterceptor implements WebRequestInterceptor {
    /** Guard opened by {@link #preHandle} on the current thread and not yet closed. */
    private final ThreadLocal<TraceGuard> currentGuard = new ThreadLocal<>();

    @Autowired
    private TraceHelper traceHelper;

    /**
     * Builds a trace for the incoming request and stores the guard created for it
     * so that {@link #afterCompletion} can close it later.
     *
     * @param request the request about to be handled
     */
    @Override
    public void preHandle(WebRequest request) throws Exception {
        String traceHeader = request.getHeader(TraceUtil.X_YANDEX_TRACE);
        String pathInfo = resolvePathInfo(request);
        Trace incomingTrace = TraceUtil.traceFromHeader(traceHeader, traceHelper.getService(), pathInfo);
        TraceGuard openedGuard = traceHelper.guard(incomingTrace);
        currentGuard.set(openedGuard);
    }

    /**
     * Intentionally a no-op: the guard is released in {@link #afterCompletion}.
     */
    @Override
    public void postHandle(WebRequest request, ModelMap model) throws Exception {
        // nothing to do here, cleanup happens in afterCompletion
    }

    /**
     * Closes the guard stored for the current thread, if there is one.
     *
     * @param request the completed request (not used)
     * @param ex      exception raised during handling, if any (not used)
     */
    @Override
    public void afterCompletion(WebRequest request, Exception ex) throws Exception {
        TraceGuard openedGuard = currentGuard.get();
        if (openedGuard == null) {
            return;
        }
        // Drop the thread-local entry first so it is cleared even if close() throws.
        currentGuard.remove();
        openedGuard.close();
    }

    /**
     * Returns the path info of the underlying servlet request, or {@code null} when the request
     * is not a {@link NativeWebRequest} or exposes no {@link HttpServletRequest}.
     */
    private static String resolvePathInfo(WebRequest request) {
        if (!(request instanceof NativeWebRequest)) {
            return null;
        }
        HttpServletRequest servletRequest =
                ((NativeWebRequest) request).getNativeRequest(HttpServletRequest.class);
        return servletRequest == null ? null : servletRequest.getPathInfo();
    }
}
