package ru.yandex.qe.bus.features;

import java.net.InetAddress;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.UnknownHostException;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import javax.servlet.http.Cookie;

import org.apache.cxf.Bus;
import org.apache.cxf.feature.AbstractFeature;
import org.apache.cxf.interceptor.Fault;
import org.apache.cxf.interceptor.InterceptorProvider;
import org.apache.cxf.message.Message;
import org.apache.cxf.phase.AbstractPhaseInterceptor;
import org.apache.cxf.phase.Phase;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Required;

import ru.yandex.qe.spring.RequestIdentity;

/**
 * CXF feature that propagates the identity of the current request to outgoing client calls.
 *
 * <p>For outgoing requestor messages (see {@code OutPropagationInterceptor}):
 * <ul>
 *   <li>if the current {@link RequestIdentity} has neither a remote address nor an
 *       {@code Authorization} header and a default OAuth token is configured, only
 *       {@code Authorization: OAuth <token>} is set;</li>
 *   <li>otherwise the session and SSL-session cookies are copied (only for hosts under the
 *       configured domain), the {@code Authorization} header is copied, and the remote address is
 *       appended to {@code X-Forwarded-For} unless it is absent or a loopback address.</li>
 * </ul>
 *
 * <p>For incoming requestor messages (see {@code InPropagationInterceptor}) every
 * {@code X-Yandex-Refresh-Cookie} header value is stored into the current {@link RequestIdentity},
 * which is then re-entered.
 *
 * @author lvovich
 */
public class PropagationFeature extends AbstractFeature {

    private static final Logger LOG = LoggerFactory.getLogger(PropagationFeature.class);

    private String domain;

    private String defaultOauthToken;

    @Override
    protected void initializeProvider(final InterceptorProvider provider, final Bus bus) {
        // propagation interceptor must be the first, to come before LoggingFeature, otherwise we would see not updated headers in log
        provider.getOutInterceptors().add(0, new OutPropagationInterceptor());
        provider.getInInterceptors().add(new InPropagationInterceptor());
        LOG.info("out interceptor added");
    }

    /**
     * Sets the OAuth token used when the outgoing call is made outside of any HTTP request
     * and no {@code Authorization} header is available.
     *
     * @param defaultOauthToken token value without the {@code "OAuth "} prefix; may be {@code null}
     */
    public void setDefaultOauthToken(final String defaultOauthToken) {
        this.defaultOauthToken = defaultOauthToken;
    }

    /**
     * Sets the domain whose sub-domains may receive the propagated session cookies.
     *
     * @param domain domain without a leading dot, e.g. {@code example.com}; a host matches when
     *               it ends with {@code "." + domain} (compared case-insensitively)
     */
    @Required
    public void setDomain(final String domain) {
        this.domain = domain;
    }

    /**
     * Builds the {@code Cookie} header value for an outgoing call to {@code requestUri}.
     *
     * @return {@code "name=value;"} pairs for the session and SSL-session cookies, or an empty
     *         string when the target host is outside the configured domain or no cookies exist
     */
    private String buildCookieHeader(final String requestUri, final RequestIdentity requestIdentity) {
        final StringBuilder cookies = new StringBuilder();
        // propagate cookies only to correct domain
        if (isInPropagationDomain(requestUri)) {
            appendCookie(cookies, requestIdentity.getSessionCookie());
            appendCookie(cookies, requestIdentity.getSslSessionCookie());
        }
        return cookies.toString();
    }

    /**
     * Checks whether the host of {@code requestUri} is a sub-domain of the configured domain.
     * Unparseable URIs, URIs without a host and an unset domain never match.
     */
    private boolean isInPropagationDomain(final String requestUri) {
        final String host;
        try {
            host = new URI(requestUri).getHost();
        } catch (URISyntaxException e) {
            LOG.warn("Error parsing URI {}", requestUri, e);
            return false;
        }
        // host names are case-insensitive, so compare the suffix ignoring case
        return host != null && domain != null && endsWithIgnoreCase(host, "." + domain);
    }

    /** Appends {@code name=value;} for {@code cookie} to {@code target}; does nothing for {@code null}. */
    private static void appendCookie(final StringBuilder target, final Cookie cookie) {
        if (cookie != null) {
            target.append(cookie.getName()).append('=').append(cookie.getValue()).append(';');
        }
    }

    /** Locale-independent, case-insensitive {@link String#endsWith}. */
    private static boolean endsWithIgnoreCase(final String value, final String suffix) {
        // a negative offset (value shorter than suffix) makes regionMatches return false
        return value.regionMatches(true, value.length() - suffix.length(), suffix, 0, suffix.length());
    }

    /**
     * Returns {@code true} when no {@code X-Forwarded-For} entry should be added for
     * {@code remoteAddress}: it is absent, or it resolves to a loopback address. This happens in a
     * development environment, where the front end calls the backend via the loopback interface.
     * Unresolvable addresses are still forwarded.
     */
    private static boolean isMissingOrLoopback(final String remoteAddress) {
        if (remoteAddress == null || remoteAddress.isEmpty()) {
            // InetAddress.getByName would map these to loopback anyway; make that explicit
            return true;
        }
        try {
            // NOTE(review): a host name (rather than an IP literal) here triggers a DNS lookup
            return InetAddress.getByName(remoteAddress).isLoopbackAddress();
        } catch (UnknownHostException e) {
            // unresolvable address: treat it as a regular remote peer
            return false;
        }
    }

    /** Adds identity-related headers to outgoing client requests. */
    private class OutPropagationInterceptor extends AbstractPhaseInterceptor<Message> {

        private OutPropagationInterceptor() {
            super(Phase.POST_LOGICAL);
        }

        @Override
        public void handleMessage(final Message message) throws Fault {
            if (!isRequestor(message)) {
                return;
            }
            final Map<String, List<String>> headersMap = getOrCreateHeaders(message);
            final String requestUri = String.valueOf(message.get(Message.REQUEST_URI));
            final RequestIdentity requestIdentity = RequestIdentity.get();
            final String remoteAddress = requestIdentity.getRemoteAddress();
            if (remoteAddress == null && requestIdentity.getHeader("Authorization") == null
                    && defaultOauthToken != null) {
                // not HTTP-request and not Robot context: use the configured default oauth token only
                headersMap.put("Authorization", Collections.singletonList("OAuth " + defaultOauthToken));
                return;
            }

            final String cookies = buildCookieHeader(requestUri, requestIdentity);
            if (!cookies.isEmpty()) {
                headersMap.put("Cookie", Collections.singletonList(cookies));
            }

            final String authorization = requestIdentity.getHeader("Authorization");
            if (authorization != null) {
                headersMap.put("Authorization", Collections.singletonList(authorization));
            }

            if (isMissingOrLoopback(remoteAddress)) {
                return;
            }
            final String forwardedFor = requestIdentity.getHeader("X-Forwarded-For");
            final String forward = forwardedFor == null ? remoteAddress : forwardedFor + "," + remoteAddress;
            headersMap.put("X-Forwarded-For", Collections.singletonList(forward));
        }

        /** Returns the message's protocol headers, creating and attaching an empty map if absent. */
        @SuppressWarnings("unchecked") // CXF stores PROTOCOL_HEADERS as Map<String, List<String>>
        private Map<String, List<String>> getOrCreateHeaders(final Message message) {
            Map<String, List<String>> headers = (Map<String, List<String>>) message.get(Message.PROTOCOL_HEADERS);
            if (headers == null) {
                headers = new HashMap<>();
                message.put(Message.PROTOCOL_HEADERS, headers);
            }
            return headers;
        }
    }

    /** Copies {@code X-Yandex-Refresh-Cookie} response headers into the current request identity. */
    private class InPropagationInterceptor extends AbstractPhaseInterceptor<Message> {

        private InPropagationInterceptor() {
            super(Phase.RECEIVE);
        }

        @Override
        public void handleMessage(final Message message) throws Fault {
            if (!isRequestor(message)) {
                return;
            }
            @SuppressWarnings("unchecked") // CXF stores PROTOCOL_HEADERS as Map<String, List<String>>
            final Map<String, List<String>> headersMap = (Map<String, List<String>>) message.get(Message.PROTOCOL_HEADERS);
            if (headersMap == null) {
                return;
            }
            final List<String> refreshHeaders = headersMap.get("X-Yandex-Refresh-Cookie");
            final RequestIdentity newIdentity = RequestIdentity.get();
            if (refreshHeaders == null || refreshHeaders.isEmpty()) {
                return;
            }
            for (final String header : refreshHeaders) {
                newIdentity.setHeader("X-Yandex-Refresh-Cookie", header);
            }
            // re-enter only when at least one refresh header was applied
            newIdentity.enter();
        }
    }
}
