package ru.yandex.intranet.d.grpc.security;

import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.Optional;

import io.grpc.Grpc;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.inprocess.InProcessSocketAddress;
import org.springframework.context.annotation.Profile;
import org.springframework.security.core.Authentication;
import org.springframework.stereotype.Component;
import reactor.core.publisher.Mono;

import ru.yandex.intranet.d.grpc.interceptors.AccessLogServerCall;
import ru.yandex.intranet.d.web.security.blackbox.BlackboxAuthChecker;
import ru.yandex.intranet.d.web.security.model.YaAuthenticationToken;
import ru.yandex.intranet.d.web.security.model.YaPrincipal;
import ru.yandex.intranet.d.web.security.tvm.TvmTicketChecker;

/**
 * Converts the authentication headers of an incoming gRPC call into a Spring Security
 * {@link Authentication}.
 *
 * <p>Credentials are checked in the following order of precedence:
 * <ol>
 *     <li>a TVM user ticket together with a TVM service ticket
 *     ({@code X-Ya-User-Ticket} + {@code X-Ya-Service-Ticket});</li>
 *     <li>a TVM service ticket alone ({@code X-Ya-Service-Ticket});</li>
 *     <li>an OAuth token ({@code Authorization: OAuth <token>}), checked together with the caller IP.</li>
 * </ol>
 * If none of these credentials is present, an empty {@link Mono} is returned. When a checker emits
 * an authentication, the principal's attributes are attached to the call for access logging
 * (only if the call is an {@link AccessLogServerCall}).
 *
 * @author Dmitriy Timashov <dm-tim@yandex-team.ru>
 */
@Profile({"dev", "testing", "production"})
@Component("yaGrpcAuthenticationConverter")
public class YaGrpcAuthenticationConverter implements GrpcAuthenticationConverter {

    // Metadata keys are immutable and thread-safe; build them once instead of on every request.
    private static final Metadata.Key<String> USER_TICKET_KEY =
            Metadata.Key.of("X-Ya-User-Ticket", Metadata.ASCII_STRING_MARSHALLER);
    private static final Metadata.Key<String> SERVICE_TICKET_KEY =
            Metadata.Key.of("X-Ya-Service-Ticket", Metadata.ASCII_STRING_MARSHALLER);
    private static final Metadata.Key<String> AUTHORIZATION_KEY =
            Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER);
    private static final String OAUTH_SCHEME_PREFIX = "OAuth ";
    // Used when the remote address is unknown or in-process.
    private static final String DEFAULT_USER_IP = "127.0.0.1";

    private final TvmTicketChecker tvmTicketChecker;
    private final BlackboxAuthChecker blackboxAuthChecker;

    public YaGrpcAuthenticationConverter(TvmTicketChecker tvmTicketChecker, BlackboxAuthChecker blackboxAuthChecker) {
        this.tvmTicketChecker = tvmTicketChecker;
        this.blackboxAuthChecker = blackboxAuthChecker;
    }

    /**
     * Reads credentials from the call headers and delegates their validation to the matching checker.
     *
     * @param call    incoming gRPC call; used to obtain the remote address and to store access-log attributes
     * @param headers request metadata carrying the credentials
     * @return the authentication produced by the checker, or an empty {@link Mono} if no credentials were sent
     */
    @Override
    public Mono<Authentication> readAuthentication(ServerCall<?, ?> call, Metadata headers) {
        String userTicket = headers.get(USER_TICKET_KEY);
        String serviceTicket = headers.get(SERVICE_TICKET_KEY);
        if (userTicket != null && serviceTicket != null) {
            return tvmTicketChecker.checkUser(userTicket, serviceTicket)
                    .doOnSuccess(auth -> rememberAuth(call, auth)).cast(Authentication.class);
        }
        if (serviceTicket != null) {
            return tvmTicketChecker.checkService(serviceTicket)
                    .doOnSuccess(auth -> rememberAuth(call, auth)).cast(Authentication.class);
        }
        // The Authorization header is only consulted when no TVM service ticket is present.
        String oauthToken = Optional.ofNullable(headers.get(AUTHORIZATION_KEY))
                .flatMap(this::getToken).orElse(null);
        if (oauthToken != null) {
            String userIp = getRemoteIp(call).orElse(DEFAULT_USER_IP);
            return blackboxAuthChecker.checkOauthToken(oauthToken, userIp)
                    .doOnSuccess(auth -> rememberAuth(call, auth)).cast(Authentication.class);
        }
        return Mono.empty();
    }

    /**
     * Extracts the token from an {@code Authorization} header value of the form {@code OAuth <token>}.
     *
     * @param value raw header value, non-null
     * @return the trimmed token, or empty if the scheme is not OAuth or the token is blank
     */
    private Optional<String> getToken(String value) {
        // RFC 7235 section 2.1: authentication scheme names are case-insensitive.
        if (!value.regionMatches(true, 0, OAUTH_SCHEME_PREFIX, 0, OAUTH_SCHEME_PREFIX.length())) {
            return Optional.empty();
        }
        String token = value.substring(OAUTH_SCHEME_PREFIX.length()).trim();
        // Do not send a blank token to Blackbox.
        return token.isEmpty() ? Optional.empty() : Optional.of(token);
    }

    /**
     * Determines the caller address of the call from its transport attributes.
     *
     * @param call incoming gRPC call
     * @return the remote host string, or empty for in-process calls or when no address is available
     */
    private Optional<String> getRemoteIp(ServerCall<?, ?> call) {
        SocketAddress socketAddress = call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR);
        if (socketAddress == null || socketAddress instanceof InProcessSocketAddress) {
            return Optional.empty();
        }
        if (socketAddress instanceof InetSocketAddress) {
            InetSocketAddress inetSocketAddress = (InetSocketAddress) socketAddress;
            return Optional.of(inetSocketAddress.getHostString()).map(this::prepareIp);
        }
        // NOTE(review): for other address types toString() may not be an IP literal — confirm Blackbox accepts it.
        return Optional.of(socketAddress.toString()).map(this::prepareIp);
    }

    /**
     * Removes everything from the last {@code '%'} onwards, i.e. an IPv6 zone/scope id
     * such as {@code "fe80::1%eth0" -> "fe80::1"}.
     *
     * @param value host string
     * @return the host string without a zone suffix
     */
    private String prepareIp(String value) {
        int scopeSeparatorIndex = value.lastIndexOf('%');
        return scopeSeparatorIndex >= 0 ? value.substring(0, scopeSeparatorIndex) : value;
    }

    /**
     * Stores the authenticated principal's attributes on the call so the access log can report them.
     *
     * @param call           incoming gRPC call
     * @param authentication value emitted by the checker; {@code null} when the checker completed empty
     */
    private void rememberAuth(ServerCall<?, ?> call, YaAuthenticationToken authentication) {
        if (authentication == null) {
            return;
        }
        if (!(call instanceof AccessLogServerCall)) {
            // The call is not wrapped by the access log interceptor: there is nowhere to store the
            // attributes. Skip instead of failing a successful authentication with ClassCastException.
            return;
        }
        YaPrincipal principal = (YaPrincipal) authentication.getPrincipal();
        GrpcAuthenticationAttributes attributes = new GrpcAuthenticationAttributes(principal.getUid()
                .orElse(null), principal.getTvmServiceId().orElse(null),
                principal.getOAuthClientId().orElse(null), principal.getOAuthClientName().orElse(null));
        ((AccessLogServerCall<?, ?>) call).setAuthenticationAttributes(attributes);
    }

}
