package ru.yandex.direct.api.v5.security;

import java.io.IOException;
import java.io.OutputStream;
import java.net.InetAddress;
import java.util.Optional;

import javax.annotation.Nonnull;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import com.google.common.net.InetAddresses;
import org.apache.commons.lang3.StringUtils;
import org.jetbrains.annotations.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Lazy;
import org.springframework.security.authentication.CredentialsExpiredException;
import org.springframework.security.authentication.DisabledException;
import org.springframework.stereotype.Component;

import ru.yandex.direct.api.v5.context.ApiContext;
import ru.yandex.direct.api.v5.context.ApiContextHolder;
import ru.yandex.direct.api.v5.security.exception.BadCredentialsException;
import ru.yandex.direct.api.v5.security.exception.NoRegistrationException;
import ru.yandex.direct.api.v5.security.exception.TokenAbsentOrHasInvalidFormatException;
import ru.yandex.direct.api.v5.security.ticket.TvmUserTicketAuthProvider;
import ru.yandex.direct.api.v5.security.token.ApiTokenAuthProvider;
import ru.yandex.direct.api.v5.security.token.DirectApiTokenAuthRequest;
import ru.yandex.direct.api.v5.ws.ApiMessage;
import ru.yandex.direct.api.v5.ws.exceptionresolver.ApiExceptionResolver;
import ru.yandex.direct.api.v5.ws.json.JsonMessageFactory;
import ru.yandex.direct.api.v5.ws.soap.SoapMessageFactory;
import ru.yandex.direct.common.util.HttpUtil;
import ru.yandex.direct.env.Environment;
import ru.yandex.direct.env.EnvironmentType;
import ru.yandex.direct.tracing.Trace;
import ru.yandex.direct.tvm.TvmIntegration;
import ru.yandex.direct.tvm.TvmService;

import static org.apache.http.HttpHeaders.CONTENT_TYPE;
import static org.springframework.util.MimeTypeUtils.APPLICATION_JSON_VALUE;
import static org.springframework.util.MimeTypeUtils.TEXT_XML_VALUE;
import static ru.yandex.direct.api.v5.ws.WsConstants.HEADER_REQUEST_ID;
import static ru.yandex.direct.common.util.HttpUtil.HEADER_X_PROXY_REAL_IP;
import static ru.yandex.direct.common.util.HttpUtil.getHeaderValue;

/**
 * Servlet filter that authenticates the caller before the request reaches the rest of the chain.
 * <p>
 * If authentication fails, an error message is written to the response immediately and the
 * remaining filters are not invoked: further processing makes no sense with invalid credentials.
 */
@Lazy
@Component
public class AuthenticationFilter implements Filter {
    private static final Logger logger = LoggerFactory.getLogger(AuthenticationFilter.class);
    static final String XML_CONTENT_TYPE = TEXT_XML_VALUE + ";charset=utf-8";

    private final ApiContextHolder apiContextHolder;
    private final TvmUserTicketAuthProvider tvmUserTicketAuthProvider;
    private final ApiTokenAuthProvider apiTokenAuthProvider;
    private final JsonMessageFactory jsonMessageFactory;
    private final SoapMessageFactory soapMessageFactory;
    private final ApiExceptionResolver apiExceptionResolver;
    private final TvmIntegration tvmIntegration;
    private final TvmService blackboxTvmService;

    /**
     * Creates the filter.
     * <p>
     * The blackbox TVM service used for token authentication is picked once, here:
     * {@link TvmService#BLACKBOX_PROD} for production/prestable environments and
     * {@link TvmService#BLACKBOX_MIMINO} for every other environment.
     */
    @Autowired
    public AuthenticationFilter(ApiContextHolder apiContextHolder,
                                TvmUserTicketAuthProvider tvmUserTicketAuthProvider,
                                ApiTokenAuthProvider apiTokenAuthProvider,
                                JsonMessageFactory jsonMessageFactory,
                                SoapMessageFactory soapMessageFactory,
                                ApiExceptionResolver apiExceptionResolver,
                                EnvironmentType environmentType,
                                TvmIntegration tvmIntegration) {
        this.apiContextHolder = apiContextHolder;
        this.tvmUserTicketAuthProvider = tvmUserTicketAuthProvider;
        this.apiTokenAuthProvider = apiTokenAuthProvider;
        this.jsonMessageFactory = jsonMessageFactory;
        this.soapMessageFactory = soapMessageFactory;
        this.apiExceptionResolver = apiExceptionResolver;
        this.tvmIntegration = tvmIntegration;
        if (environmentType.isProductionOrPrestable()) {
            this.blackboxTvmService = TvmService.BLACKBOX_PROD;
        } else {
            this.blackboxTvmService = TvmService.BLACKBOX_MIMINO;
        }
    }

    /**
     * Does nothing: the filter needs no servlet-level initialization.
     */
    @Override
    public void init(FilterConfig filterConfig) {
        // no initialization
    }

    /**
     * Authenticates the request and stores the result in the current {@link ApiContext}.
     * <p>
     * A non-empty user ticket in the credentials selects TVM user-ticket authentication;
     * otherwise the request is authenticated by API token, using a freshly obtained blackbox
     * TVM ticket. Any {@link RuntimeException} raised along the way (including while resolving
     * the client IP) is logged, turned into an error response, and stops the chain.
     *
     * @param request  incoming request; must be an {@link HttpServletRequest}
     * @param response outgoing response; must be an {@link HttpServletResponse}
     * @param chain    remaining filters, invoked only after successful authentication
     */
    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
            throws IOException, ServletException {
        HttpServletRequest httpRequest = (HttpServletRequest) request;
        HttpServletResponse httpResponse = (HttpServletResponse) response;
        ApiContext context = apiContextHolder.get();

        try {
            InetAddress clientIp = getUserIpAddress(httpRequest);
            DirectApiCredentials credentials = new DirectApiCredentials(httpRequest, clientIp);
            if (StringUtils.isEmpty(credentials.getUserTicket())) {
                // No user ticket supplied: authenticate by API token.
                String blackboxTicket = tvmIntegration.getTicket(blackboxTvmService);
                DirectApiTokenAuthRequest tokenAuthRequest =
                        new DirectApiTokenAuthRequest(credentials, blackboxTicket);
                context.setAuthRequest(apiTokenAuthProvider.authenticate(tokenAuthRequest));
            } else {
                context.setAuthRequest(tvmUserTicketAuthProvider.authenticate(credentials));
            }
        } catch (RuntimeException failure) {
            logAuthenticationFailure(failure);
            handleException(httpRequest, httpResponse, failure);
            return;
        }

        chain.doFilter(request, response);
    }

    /**
     * Logs an authentication failure: with the full stack trace at error level when
     * {@link #shouldWriteFullStackTrace(RuntimeException)} says so, otherwise only the
     * exception message at info level.
     */
    private void logAuthenticationFailure(RuntimeException failure) {
        if (!shouldWriteFullStackTrace(failure)) {
            logger.info(failure.getMessage());
            return;
        }
        logger.error("Authentication exception caught", failure);
    }

    /**
     * Determines the IP address to attribute the request to.
     * <p>
     * The value of the {@code HEADER_X_PROXY_REAL_IP} header is used only when the header is
     * present and {@link HttpUtil#acceptXProxyRealApi} approves it for the calling TVM service
     * in the current environment; in all other cases the remote address of the request is used.
     *
     * @throws IllegalArgumentException presumably, if the accepted header value is not a valid
     *                                  IP literal (raised by {@link InetAddresses#forString})
     */
    @Nullable
    private InetAddress getUserIpAddress(HttpServletRequest httpServletRequest) {
        Optional<String> forwardedIp = getHeaderValue(httpServletRequest, HEADER_X_PROXY_REAL_IP);
        if (!forwardedIp.isPresent()) {
            return HttpUtil.getRemoteAddress(httpServletRequest);
        }

        TvmService callerService = getTvmService(httpServletRequest);
        boolean headerTrusted = HttpUtil.acceptXProxyRealApi(callerService, Environment.getCached());
        if (!headerTrusted) {
            return HttpUtil.getRemoteAddress(httpServletRequest);
        }
        return InetAddresses.forString(forwardedIp.get());
    }

    /**
     * Resolves the TVM service from the service-ticket header of the request.
     *
     * @return the resolved service, or {@code null} when the request carries no service ticket
     */
    private TvmService getTvmService(HttpServletRequest httpServletRequest) {
        return HttpUtil.getHeaderValue(httpServletRequest, TvmIntegration.SERVICE_TICKET_HEADER)
                .map(tvmIntegration::getTvmService)
                .orElse(null);
    }

    /**
     * Tells whether it is worth writing the full stack trace of the caught exception to the log.
     * <p>
     * The exception types listed here are logged by message only; anything else gets a full
     * stack trace.
     */
    private boolean shouldWriteFullStackTrace(RuntimeException e) {
        boolean loggedByMessageOnly = e instanceof TokenAbsentOrHasInvalidFormatException
                || e instanceof BadCredentialsException
                || e instanceof CredentialsExpiredException
                || e instanceof DisabledException
                || e instanceof NoRegistrationException;
        return !loggedByMessageOnly;
    }

    /**
     * Writes an error response describing {@code exception}.
     * <p>
     * The response format is JSON when the request URI starts with {@code /json/} (or is
     * {@code null}) and SOAP/XML otherwise. The "should charge units" flag of the current API
     * context is switched off, the HTTP status is always set to 200, and the error itself is
     * carried in the message body produced by the exception resolver. A failure to write the
     * body is logged and not propagated.
     */
    private void handleException(@Nonnull HttpServletRequest request, @Nonnull HttpServletResponse response,
                                 @Nonnull Exception exception) {
        String requestUri = request.getRequestURI();
        boolean jsonRequested = requestUri == null || requestUri.startsWith("/json/");

        ApiMessage errorMessage;
        String contentType;
        if (jsonRequested) {
            errorMessage = jsonMessageFactory.createWebServiceMessage();
            contentType = APPLICATION_JSON_VALUE;
        } else {
            errorMessage = soapMessageFactory.createWebServiceMessage();
            contentType = XML_CONTENT_TYPE;
        }
        apiContextHolder.get().setShouldChargeUnitsForRequest(false);
        apiExceptionResolver.resolveException(errorMessage, exception);

        response.setStatus(HttpServletResponse.SC_OK);
        response.addHeader(CONTENT_TYPE, contentType);
        response.addHeader(HEADER_REQUEST_ID, String.valueOf(Trace.current().getSpanId()));
        response.addHeader("X-Java-Implementation", Boolean.TRUE.toString());
        try (OutputStream body = response.getOutputStream()) {
            errorMessage.writeTo(body);
        } catch (IOException ioFailure) {
            logger.error("Exception in writing exception", ioFailure);
        }
    }

    /**
     * Does nothing: the filter holds no resources that need releasing.
     */
    @Override
    public void destroy() {
        // no finalization
    }
}
