package ru.yandex.http.util.server;

import java.net.InetAddress;
import java.net.NetworkInterface;
import java.security.SecureRandom;
import java.util.Arrays;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.function.Consumer;
import java.util.logging.Level;
import java.util.logging.Logger;

import org.apache.http.HttpEntity;
import org.apache.http.HttpEntityEnclosingRequest;
import org.apache.http.HttpRequest;
import org.apache.http.HttpRequestInterceptor;
import org.apache.http.HttpStatus;
import org.apache.http.HttpVersion;
import org.apache.http.message.BasicHttpRequest;
import org.apache.http.protocol.HttpContext;
import org.apache.http.protocol.HttpCoreContext;

import ru.yandex.client.tvm2.AuthResult;
import ru.yandex.client.tvm2.Tvm2ServiceContextRenewalTask;
import ru.yandex.client.tvm2.UserAuthResult;
import ru.yandex.collection.PatternMap;
import ru.yandex.function.DuplexConsumer;
import ru.yandex.function.GenericAutoCloseable;
import ru.yandex.http.util.RequestToString;
import ru.yandex.http.util.request.RequestHandlerMapper;
import ru.yandex.http.util.request.RequestInfo;
import ru.yandex.http.util.request.function.RequestFunctionValue;
import ru.yandex.io.StringBuilderWriter;
import ru.yandex.logger.IdGenerator;
import ru.yandex.logger.PrefixedLogger;
import ru.yandex.passport.tvmauth.CheckedServiceTicket;
import ru.yandex.passport.tvmauth.CheckedUserTicket;
import ru.yandex.stater.RequestsStater;
import ru.yandex.util.string.StringUtils;
import ru.yandex.util.timesource.TimeSource;
import ru.yandex.util.timesource.TimeSource;

/**
 * HTTP request interceptor that sets up per-request session state.
 *
 * <p>For every request it generates a session id and resolves the logger,
 * access logger and stater configured for the request. It then applies the
 * configured limiter and, unless disabled, validates TVM service and user
 * tickets. Results are published to the {@link HttpContext} under the
 * {@code HttpServer.*} attribute keys. They are also published to the
 * {@link LoggingServerConnection} via {@code setRequestContext}.
 */
public class SessionContext implements HttpRequestInterceptor {
    /**
     * Synthetic request used only to look up a dedicated "global" stater in
     * the configured staters.
     */
    private static final HttpRequest GLOBAL_STATER_REQUEST =
        new BasicHttpRequest(
            RequestHandlerMapper.GET,
            "/_global_stater_",
            HttpVersion.HTTP_1_1);

    private static final String PROCESSING = "Processing request: ";
    private static final String FOR_CONNECTION = " for connection: ";
    private static final String AUTHORIZATION_FAILED =
        "Authorization failed: ";
    private static final String USER_AUTHORIZATION_FAILED =
        "User authorization failed: ";

    /** Canonical host name of the local host, resolved at class load. */
    public static final String HOSTNAME;

    /**
     * Addresses treated as "loopback" for limiter/auth bypass purposes.
     * These are all loopback addresses, plus every local interface address
     * whose canonical host name equals {@link #HOSTNAME}. Populated once in
     * the static initializer below.
     */
    private static final Set<InetAddress> LOOPBACK_ADDRESSES = new HashSet<>();
    static {
        // 1. Get current host canonical host name
        // 2. Iterate over all network interfaces and all their addresses
        // 3. Populate set with loopback addresses and addresses which
        //    canonical hostname are same as current host hostname
        // NOTE(review): getCanonicalHostName() may perform a reverse DNS
        // lookup per address, so class initialization can be slow when DNS
        // is slow.
        try {
            HOSTNAME = InetAddress.getLocalHost().getCanonicalHostName();
            Enumeration<NetworkInterface> ifaces =
                NetworkInterface.getNetworkInterfaces();
            // May return null if there are no interfaces at all
            if (ifaces != null) {
                while (ifaces.hasMoreElements()) {
                    NetworkInterface iface = ifaces.nextElement();
                    Enumeration<InetAddress> addrs = iface.getInetAddresses();
                    while (addrs.hasMoreElements()) {
                        InetAddress addr = addrs.nextElement();
                        if (addr.isLoopbackAddress()
                            || HOSTNAME.equals(addr.getCanonicalHostName()))
                        {
                            LOOPBACK_ADDRESSES.add(addr);
                        }
                    }
                }
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    // Session id = host part + timestamp part (both in INSTANCE_ID) +
    // random per-request part; the three lengths sum to 15.
    public static final int TIMESTAMP_ID_LENGTH = 6;
    public static final int SESSION_ID_LENGTH = 6;
    public static final int HOST_ID_LENGTH =
        15 - TIMESTAMP_ID_LENGTH - SESSION_ID_LENGTH;

    /**
     * Per-process session id prefix. It is built from a hash of
     * {@link #HOSTNAME} and the time at which this class was initialized.
     */
    public static final String INSTANCE_ID =
        new IdGenerator(
            HOSTNAME.hashCode() & 0xffffffffL,
            HOST_ID_LENGTH,
            IdGenerator.MAX_RADIX)
            .next()
        + new IdGenerator(
            TimeSource.INSTANCE.currentTimeMillis(),
            TIMESTAMP_ID_LENGTH,
            IdGenerator.MAX_RADIX)
          .next();

    /** Generates the per-request suffix of the session id. */
    private final IdGenerator idGenerator = new IdGenerator(
        new SecureRandom().nextLong(),
        SESSION_ID_LENGTH,
        IdGenerator.MAX_RADIX);
    private final ServerConfigProvider<ImmutableBaseServerConfig, BaseServerDynamicConfig> configProvider;
    // NOTE(review): captured once at construction, while request staters
    // are read from configProvider on every request. If the provider can
    // reload dynamic config, limiters here will not see the update. Confirm
    // this is intended.
    private final BaseServerDynamicConfig dynamicConfig;
    private final RequestHandlerMapper<?> requestHandlerMapper;
    private final Tvm2ServiceContextRenewalTask serviceContextRenewalTask;
    private final PatternMap<RequestInfo, PrefixedLogger> loggers;
    private final PatternMap<RequestInfo, Logger> accessLoggers;
    /** Global stater, or {@code null} if none is explicitly configured. */
    private final Consumer<ru.yandex.stater.RequestInfo> globalStater;
    private final PatternMap<RequestInfo, ImmutableAuthConfig> auths;
    private final Set<String> hiddenHeaders;

    /**
     * Creates an interceptor from the server configuration.
     *
     * @param requestHandlerMapper used to create dummy handlers for
     *     rejected or unauthorized requests
     * @param serviceContextRenewalTask used to check TVM service and user
     *     tickets
     * @param provider source of the static and dynamic server configuration
     */
    public SessionContext(
        final RequestHandlerMapper<?> requestHandlerMapper,
        final Tvm2ServiceContextRenewalTask serviceContextRenewalTask,
        final ServerConfigProvider<ImmutableBaseServerConfig, BaseServerDynamicConfig> provider)
    {
        this.configProvider = provider;
        this.dynamicConfig = provider.dynamicConfig();
        ImmutableBaseServerConfig staticConfig = provider.staticConfig();

        this.requestHandlerMapper = requestHandlerMapper;
        this.serviceContextRenewalTask = serviceContextRenewalTask;

        loggers = staticConfig.loggers().preparedLoggers();
        accessLoggers = staticConfig.loggers().preparedAccessLoggers();

        PatternMap<RequestInfo, RequestsStater> staters =
            staticConfig.staters().preparedStaters();
        RequestsStater globalStater =
            staters.get(new RequestInfo(GLOBAL_STATER_REQUEST));
        // Resolving to the asterisk entry (presumably the catch-all) means
        // no dedicated global stater was configured.
        if (globalStater == staters.asterisk()) {
            this.globalStater = null;
        } else {
            this.globalStater = globalStater;
        }

        auths = staticConfig.auths().auths();

        hiddenHeaders = staticConfig.hiddenHeaders();
    }

    /**
     * Prepares session state for the incoming request.
     *
     * <p>Sets the {@code HttpServer.SESSION_ID}, {@code REQUEST_INFO} and
     * {@code LOGGER} context attributes. It also registers the access
     * logger, stater and limiter resources with the connection.
     *
     * <p>A dummy handler is stored under
     * {@code HttpServer.OVERRIDDEN_HANDLER} in any of these cases: the
     * limiter key cannot be computed, the limiter rejects the request,
     * strict service authorization fails, or user authorization fails.
     *
     * <p>On successful authorization, the TVM source id and/or user uid are
     * stored in the context, together with a combined session user string.
     *
     * @param request incoming request
     * @param context request execution context; must contain a
     *     {@link LoggingServerConnection} under
     *     {@link HttpCoreContext#HTTP_CONNECTION}
     */
    @Override
    public void process(final HttpRequest request, final HttpContext context) {
        String id = INSTANCE_ID + idGenerator.next();
        RequestInfo requestInfo = new RequestInfo(request);
        Logger logger = loggers.get(requestInfo).addPrefix(id);
        Logger accessLogger = accessLoggers.get(requestInfo);
        RequestsStater stater =
            configProvider.requestsStaters().get(requestInfo);
        Consumer<ru.yandex.stater.RequestInfo> requestInfoConsumer;
        // Staters whose prefix starts with "ignore" are not mirrored into
        // the global stater.
        if (globalStater != null && !stater.prefix().startsWith("ignore")) {
            requestInfoConsumer = new DuplexConsumer<>(globalStater, stater);
        } else {
            requestInfoConsumer = stater;
        }

        Limiter limiter =
            dynamicConfig.limiters().preparedLimiters().get(requestInfo);

        LoggingServerConnection conn = (LoggingServerConnection)
            context.getAttribute(HttpCoreContext.HTTP_CONNECTION);

        RequestFunctionValue limiterKey;
        Object overriddenHandler = null;
        if (limiter.enabled()
            && !(limiter.bypassLoopback() && isLoopback(conn)))
        {
            try {
                limiterKey = limiter.key(request);
            } catch (ExecutionException e) {
                logger.log(
                    Level.SEVERE,
                    "Limiter key function execution error",
                    e);
                // NOTE(review): the full stack trace becomes the 503
                // response body; confirm exposing it to clients is intended.
                StringBuilderWriter sbw = new StringBuilderWriter();
                e.printStackTrace(sbw);
                overriddenHandler =
                    requestHandlerMapper.dummyHandler(
                        HttpStatus.SC_SERVICE_UNAVAILABLE,
                        sbw.toString());
                limiterKey = null;
                limiter = null;
            }
        } else {
            limiter = null;
            limiterKey = null;
        }
        context.setAttribute(HttpServer.SESSION_ID, id);
        context.setAttribute(HttpServer.REQUEST_INFO, requestInfo);
        context.setAttribute(HttpServer.LOGGER, logger);

        if (logger.isLoggable(Level.FINE)) {
            String connString = conn.toString();
            StringBuilder sb = requestToStringBuilder(
                PROCESSING,
                request,
                FOR_CONNECTION.length() + connString.length());
            sb.append(FOR_CONNECTION);
            sb.append(connString);
            logger.fine(sb.toString());
        }

        GenericAutoCloseable<RuntimeException> resourcesReleaser;
        if (limiter == null) {
            resourcesReleaser = null;
        } else {
            LimiterResult result =
                limiter.acquire(requestContentLength(request), limiterKey);
            resourcesReleaser = result.resourcesReleaser();
            String message = result.message();
            if (message != null) {
                logger.warning("Request rejected by limiter. " + message);
                overriddenHandler = requestHandlerMapper.dummyHandler(
                    limiter.errorStatusCode(),
                    message);
            }
        }
        conn.setRequestContext(
            accessLogger,
            requestInfoConsumer,
            resourcesReleaser,
            context);
        if (overriddenHandler != null) {
            context.setAttribute(
                HttpServer.OVERRIDDEN_HANDLER,
                overriddenHandler);
            return;
        }

        String authInfo = null;
        boolean authFailed = false;
        ImmutableAuthConfig auth = auths.get(requestInfo);
        if (!auth.disabled()
            && !(auth.bypassLoopback() && isLoopback(conn)))
        {
            AuthResult authResult =
                serviceContextRenewalTask.checkAuthorization(
                    request,
                    auth.headerName(),
                    auth.allowedSrcs());
            CheckedServiceTicket ticket = authResult.ticket();
            if (ticket == null) {
                String errorDescription = authResult.errorDescription();
                authInfo = "serv:UNAUTHORIZED";
                logger.warning(
                    AUTHORIZATION_FAILED + errorDescription
                    + ", strict auth: " + auth.strict());
                if (auth.strict()) {
                    authFailed = true;
                    context.setAttribute(
                        HttpServer.OVERRIDDEN_HANDLER,
                        requestHandlerMapper.dummyHandler(
                            HttpStatus.SC_UNAUTHORIZED,
                            errorDescription));
                } else {
                    // Non-strict mode: let the request through, only
                    // attach a warning.
                    context.setAttribute(
                        HttpServer.WARNING_MESSAGE,
                        AUTHORIZATION_FAILED + errorDescription);
                }
            } else {
                int src = ticket.getSrc();
                context.setAttribute(HttpServer.TVM_SRC_ID, src);
                authInfo = "serv:" + src;
            }
        }

        String userAuthInfo = null;
        UserTicketPresence userTicketPresence = auth.userTicketPresence();
        if (userTicketPresence != UserTicketPresence.IGNORED) {
            UserAuthResult userAuthResult =
                serviceContextRenewalTask.checkUserAuthorization(
                    request,
                    auth.userTicketHeaderName(),
                    userTicketPresence == UserTicketPresence.REQUIRED);
            if (userAuthResult != null) {
                String errorDescription = userAuthResult.errorDescription();
                if (errorDescription == null) {
                    CheckedUserTicket ticket = userAuthResult.ticket();
                    long uid = ticket.getDefaultUid();
                    if (uid == 0L) {
                        errorDescription =
                            "default uid is 0, all ticket uids: "
                            + Arrays.toString(ticket.getUids());
                    } else {
                        context.setAttribute(HttpServer.TVM_USER_UID, uid);
                        userAuthInfo = "user:" + uid;
                    }
                }
                if (errorDescription != null) {
                    userAuthInfo = "user:UNAUTHORIZED";
                    logger.warning(
                        USER_AUTHORIZATION_FAILED + errorDescription);
                    if (!authFailed) {
                        // Do not suppress original error if service ticket
                        // validation already failed
                        context.setAttribute(
                            HttpServer.OVERRIDDEN_HANDLER,
                            requestHandlerMapper.dummyHandler(
                                HttpStatus.SC_UNAUTHORIZED,
                                USER_AUTHORIZATION_FAILED + errorDescription));
                    }
                }
            }
        }
        String sessionUser =
            StringUtils.concatOrSet(authInfo, '/', userAuthInfo);
        if (sessionUser != null) {
            context.setAttribute(HttpServer.SESSION_USER, sessionUser);
        }
    }

    /**
     * Checks whether the connection's remote address is one of
     * {@link #LOOPBACK_ADDRESSES}.
     *
     * @param conn connection of the current request
     * @return {@code true} if the remote address is treated as loopback
     */
    private static boolean isLoopback(final LoggingServerConnection conn) {
        return LOOPBACK_ADDRESSES.contains(conn.getRemoteAddress());
    }

    /**
     * Returns the request body length to pass to the limiter.
     *
     * @param request incoming request
     * @return the entity's {@code getContentLength()} value, or {@code -1}
     *     when the request cannot carry an entity or has none attached
     *     ({@code getEntity()} may return {@code null})
     */
    private static long requestContentLength(final HttpRequest request) {
        if (request instanceof HttpEntityEnclosingRequest) {
            HttpEntity entity =
                ((HttpEntityEnclosingRequest) request).getEntity();
            if (entity != null) {
                return entity.getContentLength();
            }
        }
        return -1L;
    }

    /**
     * Formats a request for logging, hiding the configured headers.
     *
     * @param prefix text placed before the request representation
     * @param request request to format
     * @param suffixLength length of the text the caller will append
     *     afterwards (presumably used to presize the builder)
     * @return builder holding the prefix and the formatted request
     */
    public StringBuilder requestToStringBuilder(
        final String prefix,
        final HttpRequest request,
        final int suffixLength)
    {
        return RequestToString.requestToStringBuilder(
            prefix,
            request,
            suffixLength,
            hiddenHeaders);
    }
}

