package ru.yandex.chemodan.grpc.server.interceptors;

import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.Arrays;
import java.util.Objects;

import io.grpc.Context;
import io.grpc.Grpc;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;

import ru.yandex.bolts.collection.Cf;
import ru.yandex.bolts.collection.Option;
import ru.yandex.bolts.collection.SetF;
import ru.yandex.chemodan.grpc.GrpcContextCommonKeys;
import ru.yandex.chemodan.grpc.TvmGrpcHeaders;
import ru.yandex.inside.passport.tvm2.Tvm2;
import ru.yandex.inside.passport.tvm2.common.Tvm2AuthenticationHelper;
import ru.yandex.inside.passport.tvm2.exceptions.MissingTvmServiceTicketException;
import ru.yandex.inside.passport.tvm2.exceptions.TvmBaseException;
import ru.yandex.inside.passport.tvm2.web.Tvm2CheckingMode;
import ru.yandex.misc.log.mlf.Logger;
import ru.yandex.misc.log.mlf.LoggerFactory;

public class Tvm2GrpcServerInterceptor implements ServerInterceptor {

    private static final Logger logger = LoggerFactory.getLogger(Tvm2GrpcServerInterceptor.class);

    private final Tvm2AuthenticationHelper tvm2AuthenticationHelper;

    private final Tvm2CheckingMode tvm2Mode;

    private final SetF<String> localAddresses = getLocalAddresses();

    public Tvm2GrpcServerInterceptor(Tvm2 tvm2, Tvm2CheckingMode tvm2Mode) {
        this.tvm2AuthenticationHelper = new Tvm2AuthenticationHelper(tvm2);
        this.tvm2Mode = tvm2Mode;
    }

    @Override
    public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call,
            Metadata headers, ServerCallHandler<ReqT, RespT> next)
    {
        Option<String> uidO = Option.ofNullable(headers.get(TvmGrpcHeaders.TVM_UID));
        if (!tvm2AuthenticationHelper.isInitialized() ||
                isFromLocalAddress(call, headers) ||
                isHealthCheck(call) ||
                checkAuthentication(call, headers, uidO)) {
            return next.startCall(call, headers);
        }
        throw new StatusRuntimeException(Status.UNAUTHENTICATED);
    }

    private <ReqT, RespT> boolean checkAuthentication(ServerCall<ReqT, RespT> call,
            Metadata headers, Option<String> uidO)
    {
        try {
            tvm2AuthenticationHelper.checkAuthentication(Option.ofNullable(headers.get(
                    TvmGrpcHeaders.TVM2_SERVICE_TICKET)),
                    Option.ofNullable(headers.get(TvmGrpcHeaders.TVM2_USER_TIKET)), uidO,
                    srcClientId -> headers.put(TvmGrpcHeaders.TVM_CLIENT_ID, String.valueOf(srcClientId)));
            return true;
        } catch (MissingTvmServiceTicketException mtse) {
            logWarn(call, headers, mtse);
            return tvm2Mode == Tvm2CheckingMode.IF_PRESENT || tvm2Mode == Tvm2CheckingMode.ONLY_WARNINGS;
        } catch (TvmBaseException e) {
            logWarn(call, headers, e);
            return tvm2Mode == Tvm2CheckingMode.ONLY_WARNINGS;
        }
    }

    private <ReqT, RespT> void logWarn(ServerCall<ReqT, RespT> call, Metadata metadata, TvmBaseException e) {
        Context context = Context.current();
        logger.warn(
                "Access forbidden for address = {}, gRPC method name = {}, ycrid={}, rid={}, metadata={}, exception={}",
                Option.ofNullable(call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR)),
                call.getMethodDescriptor().getFullMethodName(),
                Option.ofNullable(GrpcContextCommonKeys.YCRID.get(context)).getOrElse("-"),
                Option.ofNullable(GrpcContextCommonKeys.REQUEST_ID.get(context)).getOrElse("-"),
                LoggingGrpcServerInterceptor.getHeaders(metadata),
                e.getMessage());
    }

    private <ReqT, RespT> boolean isFromLocalAddress(ServerCall<ReqT, RespT> call, Metadata headers) {
        return localAddresses.containsTs(getClientIpFromCall(call)) && headers.get(TvmGrpcHeaders.TVM_REQUIRED) == null;
    }

    private <ReqT, RespT> String getClientIpFromCall(ServerCall<ReqT, RespT> call) {
        String address = Option.ofNullable(call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR))
                .map(Objects::toString).getOrElse("");
        if (address.startsWith("/")) {
            address = address.substring(1);
        }
        return address.replaceAll(":\\d+$", "");
    }

    private SetF<String> getLocalAddresses() {
        try {
            SetF<String> localAddresses = Cf.hashSet(InetAddress.getLocalHost().getHostAddress());
//            skip all hosts ifaces
            Arrays.stream(InetAddress.getAllByName("localhost"))
                    .map(InetAddress::getHostAddress)
                    .forEach(localAddresses::add);
            logger.info("skip check for ifaces: {}", localAddresses);
            return localAddresses;
        } catch (IOException e) {
            logger.warn("Unable to lookup local addresses");
            return Cf.set();
        }
    }

    private <ReqT, RespT> boolean isHealthCheck(ServerCall<ReqT, RespT> call) {
        return call.getMethodDescriptor().getServiceName().equals("grpc.health.v1.Health");
    }

}
