package ru.yandex.logbroker2;

import java.io.IOException;
import java.time.Instant;
import java.util.Date;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import java.util.logging.Level;

import com.google.common.util.concurrent.ListenableFuture;
import com.google.protobuf.Timestamp;
import io.grpc.ManagedChannel;
import io.grpc.internal.DnsNameResolverProvider;
import io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.NettyChannelBuilder;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.SignatureAlgorithm;
import io.netty.channel.ChannelOption;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import yandex.cloud.priv.iam.v1.IamTokenServiceGrpc;
import yandex.cloud.priv.iam.v1.PITS.CreateIamTokenRequest;
import yandex.cloud.priv.iam.v1.PITS.CreateIamTokenResponse;

import ru.yandex.kikimr.persqueue.auth.Credentials;
import ru.yandex.logbroker2.config.ImmutableIamJwtConfig;
import ru.yandex.logger.PrefixedLogger;

/**
 * Supplies logbroker {@link Credentials} backed by an IAM token.
 *
 * <p>The token is obtained by signing a short-lived JWT (PS256) with the key from
 * {@link ImmutableIamJwtConfig} and exchanging it via the {@code IamTokenService.create} gRPC
 * call. The received token is cached and refreshed once it expires in fewer than
 * {@link #REFRESH_THRESHOLD_SECONDS} seconds.
 *
 * <p>Refreshes are serialized on {@code tokenCache}; a thread that waited for the lock reuses
 * a token that another thread has just obtained instead of requesting a new one.
 */
public class IamJwtCredentialsProvider implements Supplier<Credentials> {
    /** Lifetime of the signed JWT sent to the IAM service. */
    private static final long TICKET_EXPIRE_SECONDS = TimeUnit.MINUTES.toSeconds(10);
    /** Timeout in milliseconds for connecting and for a single token request. */
    private static final int TIMEOUT = (int) TimeUnit.SECONDS.toMillis(5);
    /** A cached token is refreshed once it expires in fewer than this many seconds. */
    private static final long REFRESH_THRESHOLD_SECONDS = TimeUnit.MINUTES.toSeconds(1);
    /** Pause between refresh attempts while no unexpired token is available. */
    private static final long NO_TOKEN_RETRY_DELAY_MILLIS = TimeUnit.SECONDS.toMillis(10);

    private final PrefixedLogger logger;
    private final ImmutableIamJwtConfig iamConfig;
    private final IamTokenServiceGrpc.IamTokenServiceFutureStub stub;
    private final AtomicReference<TokenBox> tokenCache = new AtomicReference<>();

    /**
     * Creates the provider and the gRPC channel to the IAM service described by
     * {@code iamConfig}. No token is requested until the first {@link #get()} call.
     *
     * @param logger logger for token lifecycle events and failures
     * @param iamConfig IAM endpoint, channel settings, retry policy and JWT signing parameters
     * @throws IOException if the TLS context cannot be built
     */
    public IamJwtCredentialsProvider(
        final PrefixedLogger logger,
        final ImmutableIamJwtConfig iamConfig)
        throws IOException
    {
        this.logger = logger;
        this.iamConfig = iamConfig;

        NettyChannelBuilder channelBuilder = NettyChannelBuilder
            .forAddress(iamConfig.host().getHostName(), iamConfig.host().getPort())
            .nameResolverFactory(new DnsNameResolverProvider());
        // Non-positive keep-alive values leave the gRPC defaults untouched.
        if (iamConfig.keepAliveTime() > 0) {
            channelBuilder.keepAliveTime(
                iamConfig.keepAliveTime(),
                TimeUnit.MILLISECONDS);
        }

        if (iamConfig.keepAliveTimeout() > 0) {
            channelBuilder.keepAliveTimeout(
                iamConfig.keepAliveTimeout(),
                TimeUnit.MILLISECONDS);
        }

        if (iamConfig.retries().count() > 0) {
            channelBuilder
                .enableRetry()
                .maxRetryAttempts(iamConfig.retries().count());
        }

        channelBuilder.keepAliveWithoutCalls(iamConfig.keepAliveWithoutCalls());
        channelBuilder.userAgent(iamConfig.userAgent());
        if (iamConfig.https().trustManagerFactory() != null) {
            // Verify the server certificate against the configured trust store. Trusting any
            // certificate would expose the signed JWT and the issued IAM token to a MITM.
            channelBuilder.sslContext(
                GrpcSslContexts.forClient()
                    .trustManager(iamConfig.https().trustManagerFactory())
                    .build());
        } else {
            channelBuilder.usePlaintext();
        }

        // NOTE(review): Netty honors SO_TIMEOUT only on OIO transports — confirm it has an
        // effect with the transport in use here.
        channelBuilder =
            channelBuilder
                .withOption(ChannelOption.SO_TIMEOUT, TIMEOUT)
                .withOption(ChannelOption.CONNECT_TIMEOUT_MILLIS, TIMEOUT);

        ManagedChannel channel = channelBuilder.build();
        stub = IamTokenServiceGrpc.newFutureStub(channel);
    }

    /**
     * Requests a new IAM token by exchanging a freshly signed JWT.
     *
     * <p>Makes up to {@code retries().count() + 1} attempts, sleeping
     * {@code retries().interval()} between them. Each attempt carries a gRPC deadline of
     * {@link #TIMEOUT} ms, and a failed or timed-out attempt is cancelled so the abandoned RPC
     * does not keep running in the background.
     *
     * @return the received token and its expiration
     * @throws InterruptedException if interrupted while waiting; interruption is not retried
     * @throws Exception the failure of the last attempt when all attempts fail
     */
    // JdkObsolete is suppressed for the java.util.Date values passed to the JWT builder.
    @SuppressWarnings(value = "JdkObsolete")
    private TokenBox createToken() throws Exception {
        Instant now = Instant.now();

        logger.info("Getting Iam token");
        String encodedToken = Jwts.builder()
            .setHeaderParam("kid", iamConfig.keyId())
            .setIssuer(iamConfig.serviceAccountId())
            .setAudience(iamConfig.audience())
            .setIssuedAt(Date.from(now))
            .setExpiration(Date.from(now.plusSeconds(TICKET_EXPIRE_SECONDS)))
            .signWith(SignatureAlgorithm.PS256, iamConfig.privateKey())
            .compact();
        CreateIamTokenRequest request =
            CreateIamTokenRequest.newBuilder().setJwt(encodedToken).build();

        int maxRetries = Math.max(iamConfig.retries().count(), 0);
        // Always assigned before the loop exits: it runs at least once (maxRetries >= 0) and
        // every iteration either returns, throws, or records its failure here.
        Exception lastFailure = null;
        for (int attempt = 0; attempt <= maxRetries; attempt++) {
            if (attempt > 0) {
                Thread.sleep(iamConfig.retries().interval());
            }
            ListenableFuture<CreateIamTokenResponse> future = stub
                .withDeadlineAfter(TIMEOUT, TimeUnit.MILLISECONDS)
                .create(request);
            try {
                CreateIamTokenResponse response =
                    future.get(TIMEOUT, TimeUnit.MILLISECONDS);
                logger.info("Iam token received");
                return new TokenBox(response.getExpiresAt(), response.getIamToken());
            } catch (InterruptedException e) {
                future.cancel(true);
                // Let the caller decide how to react; retrying would swallow the interrupt.
                throw e;
            } catch (Exception e) {
                // No-op for a completed future; stops an RPC we gave up waiting for.
                future.cancel(true);
                logger.log(Level.WARNING, "Temp failure receive iam token " + attempt, e);
                lastFailure = e;
            }
        }

        throw lastFailure;
    }

    /**
     * Returns credentials carrying an IAM token, refreshing the cached token when it is
     * missing or expires in fewer than {@link #REFRESH_THRESHOLD_SECONDS} seconds.
     *
     * <p>If a refresh fails while the cached token has not expired yet, the cached token is
     * returned. Otherwise the refresh is retried every {@link #NO_TOKEN_RETRY_DELAY_MILLIS} ms
     * until it succeeds, blocking the caller (and other refreshing threads) meanwhile.
     *
     * @return the credentials, or {@code null} if the thread was interrupted before an
     *     unexpired token became available; the interrupt flag is restored in that case
     */
    @Override
    public Credentials get() {
        TokenBox box = tokenCache.get();
        if (!needsRefresh(box)) {
            return Credentials.iamToken(box.token());
        }

        synchronized (tokenCache) {
            // Another thread may have refreshed the token while this one waited for the lock.
            box = tokenCache.get();
            if (!needsRefresh(box)) {
                return Credentials.iamToken(box.token());
            }
            if (box != null) {
                logger.info("Token expires in " + box.expireInSeconds() + " seconds, updating");
            }

            while (true) {
                try {
                    box = createToken();
                    tokenCache.set(box);
                    break;
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    logger.log(Level.WARNING, "Failed to get iam ticket", e);
                    if (isUnexpired(box)) {
                        break;
                    }
                    return null;
                } catch (Exception e) {
                    logger.log(Level.WARNING, "Failed to get iam ticket", e);
                    if (isUnexpired(box)) {
                        // The refresh failed, but the cached token still works for now.
                        break;
                    }

                    try {
                        Thread.sleep(NO_TOKEN_RETRY_DELAY_MILLIS);
                    } catch (InterruptedException ie) {
                        Thread.currentThread().interrupt();
                        logger.log(Level.WARNING, "Failed to get iam ticket", ie);
                        return null;
                    }
                }
            }
        }
        return Credentials.iamToken(box.token());
    }

    /** Returns whether {@code box} is absent or close enough to expiry to be refreshed. */
    private static boolean needsRefresh(final TokenBox box) {
        return box == null || box.expireInSeconds() < REFRESH_THRESHOLD_SECONDS;
    }

    /** Returns whether {@code box} holds a token whose expiration is still in the future. */
    private static boolean isUnexpired(final TokenBox box) {
        return box != null && box.expireInSeconds() > 0;
    }

    /** Immutable pair of an IAM token and its expiration, kept in whole seconds. */
    private static final class TokenBox {
        /** Expiration as protobuf {@code Timestamp} seconds; sub-second precision is dropped. */
        private final long expiration;
        private final String token;

        public TokenBox(final Timestamp expiration, final String token) {
            this.expiration = expiration.getSeconds();
            this.token = token;
        }

        public long expiration() {
            return expiration;
        }

        /** Seconds left until expiration relative to the system clock; negative once expired. */
        public long expireInSeconds() {
            return expiration - System.currentTimeMillis() / 1000;
        }

        public String token() {
            return token;
        }
    }
}
