package ru.yandex.qe.bus.features.log;

import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;

import org.apache.cxf.Bus;
import org.apache.cxf.feature.AbstractFeature;
import org.apache.cxf.interceptor.InterceptorProvider;
import org.apache.cxf.message.Message;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.MDC;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.annotation.Required;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;

import ru.yandex.qe.logging.security.PrivateHeaderSecurityGuard;

/**
 * CXF feature that installs request/response logging interceptors and carries a request id
 * between the {@code X-qe-bus-request-id} protocol header, the CXF exchange and the SLF4J
 * {@link MDC}.
 *
 * <p>{@link LogInterceptFilter} beans are looked up in the Spring {@link ApplicationContext}
 * when the feature initializes a provider and are passed to the inbound/outbound interceptors.
 * Header values can be masked through the configured {@link PrivateHeaderSecurityGuard}s via
 * {@link #securePrivateData(Map)}.
 */
public class LogFeature extends AbstractFeature implements ApplicationContextAware {

    private static final Logger LOG = LoggerFactory.getLogger(LogFeature.class);

    private static final String REQUEST_ID_HEADER_KEY = "X-qe-bus-request-id";
    private static final String HOST_NAME_HEADER_KEY = "X-qe-hostname";
    private static final String USER_MDC_KEY = "X-qe-user";

    /** Value sent in the {@code X-qe-hostname} header; {@code null} until {@link #setHostName} is called. */
    private String hostName;
    /** Prefix prepended to request ids by {@link #putHostnameAndReqIdWithAppNameInHeaders}. */
    private final String appName;

    /** Guards applied, in list order, to every header's values by {@link #securePrivateData}. */
    private final List<PrivateHeaderSecurityGuard> headerSecurityGuards;

    /** Spring context used to discover filters; stays {@code null} if never injected. */
    private ApplicationContext ctx;

    /** Log line formatter; Info by default, switchable to Trace via {@link #setTraceLevel}. */
    private LogInterceptFormatter formatter = new LogInterceptFormatter.Info();

    /**
     * Creates the feature.
     *
     * @param appName              application name prepended to outgoing request ids
     * @param headerSecurityGuards guards used to mask private header values;
     *                             {@code null} is treated as "no guards"
     */
    public LogFeature(String appName, List<PrivateHeaderSecurityGuard> headerSecurityGuards) {
        this.appName = appName;
        // Normalise a missing guard list so securePrivateData never iterates over null.
        this.headerSecurityGuards = headerSecurityGuards != null
                ? headerSecurityGuards
                : Collections.<PrivateHeaderSecurityGuard>emptyList();
    }

    /**
     * Sets the value sent in the {@code X-qe-hostname} header of outgoing messages.
     *
     * @param hostName host name to advertise; when {@code null} the header is not sent
     */
    public void setHostName(final String hostName) {
        this.hostName = hostName;
    }

    /**
     * Selects the log line formatter.
     *
     * @param isTrace {@code true} for {@link LogInterceptFormatter.Trace}, {@code false} for
     *                {@link LogInterceptFormatter.Info}
     */
    @Required
    public void setTraceLevel(boolean isTrace) {
        this.formatter = isTrace ? new LogInterceptFormatter.Trace() : new LogInterceptFormatter.Info();
    }

    /**
     * Registers the logging interceptors on the given provider.
     *
     * <p>The discovered filters are shared by the receive-side and post-logic interceptors.
     * The "ending" interceptor of each direction is registered as a single instance on both
     * the regular and the fault chain.
     */
    @Override
    protected void initializeProvider(InterceptorProvider provider, Bus bus) {
        final Set<LogInterceptFilter> filters = discoverFilters();

        provider.getInInterceptors().add(new OnReceiveInLogInterceptor(this, filters));
        final PostInvokeInLogInterceptor endingInLogInterceptor = new PostInvokeInLogInterceptor(this);
        provider.getInInterceptors().add(endingInLogInterceptor);
        provider.getInFaultInterceptors().add(endingInLogInterceptor);

        provider.getOutInterceptors().add(new PostLogicOutLogInterceptor(this, filters));
        final SetupEndingOutLogInterceptor endingOutLogInterceptor = new SetupEndingOutLogInterceptor(this);
        provider.getOutInterceptors().add(endingOutLogInterceptor);
        provider.getOutFaultInterceptors().add(endingOutLogInterceptor);
    }

    /**
     * Reads the request id stored on the message's exchange.
     *
     * @return the stored value's {@code toString()}, or {@code null} if absent
     */
    protected String findReqIdInExchange(Message message) {
        final Object requestId = message.getExchange().get(REQUEST_ID_HEADER_KEY);
        // A key explicitly mapped to null is treated the same as a missing key.
        return requestId != null ? requestId.toString() : null;
    }

    /** Stores the request id on the message's exchange. */
    protected void putReqIdInExchange(Message message, String requestId) {
        message.getExchange().put(REQUEST_ID_HEADER_KEY, requestId);
    }

    /** @return the request id currently held in the SLF4J MDC, or {@code null} */
    protected String findReqIdInMDC() {
        return MDC.get(REQUEST_ID_HEADER_KEY);
    }

    /**
     * Puts the request id into the MDC and, when the Spring Security context holds an
     * authentication with a non-null principal, the principal's {@code toString()} under the
     * user key.
     */
    protected void putReqIdInMDC(String requestId) {
        MDC.put(REQUEST_ID_HEADER_KEY, requestId);
        final Authentication auth = SecurityContextHolder.getContext().getAuthentication();
        final Object principal = auth != null ? auth.getPrincipal() : null;
        // getPrincipal() may return null; skip the user entry instead of failing the request.
        if (principal != null) {
            MDC.put(USER_MDC_KEY, principal.toString());
        }
    }

    /** Removes both the request id and the user entries from the MDC. */
    protected void removeReqIdFromMDC() {
        MDC.remove(REQUEST_ID_HEADER_KEY);
        MDC.remove(USER_MDC_KEY);
    }

    /**
     * Reads the first {@code X-qe-bus-request-id} value from the message's protocol headers.
     *
     * @return the first header value, or {@code null} if there are no headers, no such header,
     *         or the header has no values
     */
    @SuppressWarnings("unchecked")
    protected String findReqIdInHeaders(Message message) {
        final Map<String, List<String>> headers =
                (Map<String, List<String>>) message.get(Message.PROTOCOL_HEADERS);
        if (headers == null) {
            return null;
        }
        final List<String> values = headers.get(REQUEST_ID_HEADER_KEY);
        // The key may be present with a null or empty value list.
        if (values == null || values.isEmpty()) {
            return null;
        }
        return values.get(0);
    }

    /**
     * Writes the request id and (if configured) the host name into the message's protocol
     * headers, creating the header map when the message has none. Existing values for these
     * two headers are replaced.
     */
    @SuppressWarnings("unchecked")
    protected void putHostnameAndReqIdInHeaders(Message message, String requestId) {
        Map<String, List<String>> headersMap = (Map<String, List<String>>) message.get(Message.PROTOCOL_HEADERS);
        if (headersMap == null) {
            headersMap = new HashMap<>();
            message.put(Message.PROTOCOL_HEADERS, headersMap);
        }
        headersMap.put(REQUEST_ID_HEADER_KEY, Collections.singletonList(requestId));
        // Without a configured host name there is nothing meaningful to send.
        if (hostName != null) {
            headersMap.put(HOST_NAME_HEADER_KEY, Collections.singletonList(hostName));
        }
    }

    /** Same as {@link #putHostnameAndReqIdInHeaders} but sends {@code appName + "-" + requestId}. */
    protected void putHostnameAndReqIdWithAppNameInHeaders(Message message, String requestId) {
        final String reqHeaderValue = appName + "-" + requestId;
        putHostnameAndReqIdInHeaders(message, reqHeaderValue);
    }

    /**
     * Returns {@code requestId} unchanged if non-null; otherwise generates a random UUID,
     * prefixed with {@code <authentication name>-} when an authenticated name is available.
     */
    protected String generateReqIdIfNull(String requestId) {
        if (requestId != null) {
            return requestId;
        }
        final Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
        final String auth = authentication != null ? authentication.getName() : null;
        return (auth != null ? auth + "-" : "") + UUID.randomUUID().toString();
    }

    /**
     * Starts a log line for the message using the configured formatter.
     *
     * @param requestId currently unused; kept for subclass/caller compatibility
     * @return a builder pre-filled by the formatter, ready for further appends
     */
    protected StringBuilder preConstructLogLine(Message message, String requestId) {
        final StringBuilder logBuilder = new StringBuilder();

        formatter.formatLine(message, logBuilder, this);

        return logBuilder;
    }

    /**
     * Discovers user-defined {@link LogInterceptFilter} beans in the Spring context.
     *
     * @return a mutable set of discovered filters; empty when no context was injected
     */
    private Set<LogInterceptFilter> discoverFilters() {
        if (ctx == null) {
            // Feature built outside Spring: run without user filters instead of failing with an NPE.
            LOG.warn("{} has no ApplicationContext; user log request filters are not discovered", this);
            return new HashSet<>();
        }
        final Set<LogInterceptFilter> filters =
                new HashSet<>(ctx.getBeansOfType(LogInterceptFilter.class).values());

        if (!filters.isEmpty()) {
            LOG.debug("{} discovered user log request filters: {}", this, filters);
        }

        return filters;
    }

    /**
     * Returns a copy of the given headers with every header's values passed through all
     * configured guards in order (each guard receives the previous guard's output).
     *
     * @param headersMap headers to secure; {@code null} yields an empty map
     * @return a new map, in the input's iteration order; the input is not modified
     */
    public Map<String, List<String>> securePrivateData(final Map<String, List<String>> headersMap) {
        // LinkedHashMap keeps the input's header order so logged header lists are stable.
        final Map<String, List<String>> result = new LinkedHashMap<>();
        if (headersMap == null) {
            return result;
        }
        for (Map.Entry<String, List<String>> entry : headersMap.entrySet()) {
            final String headerName = entry.getKey();
            List<String> securedHeaderValues = entry.getValue();
            for (PrivateHeaderSecurityGuard headerGuard : headerSecurityGuards) {
                securedHeaderValues = headerGuard.secure(headerName, securedHeaderValues);
            }
            result.put(headerName, securedHeaderValues);
        }
        return result;
    }

    /** Stores the Spring context used by {@link #discoverFilters()}. */
    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        ctx = applicationContext;
    }
}
