package ypresolver

import (
	"context"
	"errors"
	"fmt"
	"time"

	"google.golang.org/grpc"
	"google.golang.org/grpc/resolver"

	pb "a.yandex-team.ru/infra/yp_service_discovery/api"
	resolverconst "a.yandex-team.ru/infra/yp_service_discovery/golang/resolver"
	"a.yandex-team.ru/infra/yp_service_discovery/golang/resolver/grpcresolver"
	"a.yandex-team.ru/library/go/core/log"
)

// YPScheme is the URL scheme that the resolver builder in this package
// reports from Scheme and that BuildServiceFQDN puts in front of a service
// name.
const YPScheme = "ypservice"

func BuildServiceFQDN(serviceName string) string {
	return fmt.Sprintf("%s:///%s", YPScheme, serviceName)
}

// ypResolverBuilder is the resolver.Builder implementation returned by
// NewYPResolverBuilder. Its fields are filled with defaults there and may be
// overridden through ypResolverBuilderOpt options; Build reads them when a
// resolver is created.
type ypResolverBuilder struct {
	Clusters            []string       // clusters queried, in order, for endpoints of the service
	ConnectionTimeout   time.Duration  // deadline for the whole resolution started by Build
	YpConnectionTimeout time.Duration  // deadline for one dial attempt to the service discovery endpoint
	logger              log.Structured // optional; nil disables logging
}

// ypResolverBuilderOpt is a functional option that mutates a
// ypResolverBuilder. NewYPResolverBuilder applies options in the order given,
// after the defaults have been set.
type ypResolverBuilderOpt func(*ypResolverBuilder)

// WithLogger returns an option that makes the builder use l as its logger.
func WithLogger(l log.Structured) ypResolverBuilderOpt {
	setLogger := func(b *ypResolverBuilder) {
		b.logger = l
	}
	return setLogger
}

// WithConnectionTimeout returns an option that sets the builder's
// ConnectionTimeout, i.e. the deadline Build puts on the whole resolution.
func WithConnectionTimeout(timeout time.Duration) ypResolverBuilderOpt {
	setTimeout := func(b *ypResolverBuilder) {
		b.ConnectionTimeout = timeout
	}
	return setTimeout
}

// WithYpConnectionTimeout returns an option that sets the builder's
// YpConnectionTimeout, i.e. the deadline of a single dial attempt to the
// service discovery endpoint.
func WithYpConnectionTimeout(timeout time.Duration) ypResolverBuilderOpt {
	setTimeout := func(b *ypResolverBuilder) {
		b.YpConnectionTimeout = timeout
	}
	return setTimeout
}

// WithClusters returns an option that replaces the builder's cluster list
// with clusters. A nil or empty slice is ignored, leaving whatever list the
// builder already has. The slice is stored as is, without copying.
func WithClusters(clusters []string) ypResolverBuilderOpt {
	return func(b *ypResolverBuilder) {
		if len(clusters) > 0 {
			b.Clusters = clusters
		}
	}
}

// NewYPResolverBuilder creates a resolver.Builder for the YPScheme scheme.
//
// Defaults, applied before opts: every cluster in
// resolverconst.AvailableClusters, a 30 second ConnectionTimeout and a
// 1 second YpConnectionTimeout. Options are then applied in argument order,
// so a later option overrides an earlier one.
func NewYPResolverBuilder(opts ...ypResolverBuilderOpt) resolver.Builder {
	builder := new(ypResolverBuilder)
	builder.Clusters = resolverconst.AvailableClusters
	builder.ConnectionTimeout = 30 * time.Second
	builder.YpConnectionTimeout = 1 * time.Second

	for i := range opts {
		opts[i](builder)
	}

	return builder
}

// Build creates a ypResolver for target and resolves it before returning.
//
// The resolution runs under a context that expires after
// b.ConnectionTimeout. target.Endpoint is used as the service name and
// b.Clusters as the clusters to query. If the resolution fails, its error is
// returned and no resolver is created. opts is not used.
func (b *ypResolverBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions) (resolver.Resolver, error) {
	resolveCtx, stop := context.WithTimeout(context.Background(), b.ConnectionTimeout)
	defer stop()

	yr := &ypResolver{
		clientConn:          cc,
		ypConnectionTimeout: b.YpConnectionTimeout,
	}
	if err := yr.start(resolveCtx, target.Endpoint, b.Clusters, b.logger); err != nil {
		return nil, err
	}
	return yr, nil
}

// Scheme returns YPScheme, the scheme this builder handles.
func (b *ypResolverBuilder) Scheme() string {
	return YPScheme
}

// ypResolver is the resolver created by ypResolverBuilder.Build. It resolves
// the service in start and pushes the result to clientConn; its ResolveNow
// and Close methods do nothing.
type ypResolver struct {
	addresses           []string            // "host:port" of every ready endpoint collected by start
	clientConn          resolver.ClientConn // receives the resolved addresses via UpdateState
	ypConnectionTimeout time.Duration       // deadline for one dial attempt to the service discovery endpoint
}

// start resolves serviceName in every cluster listed in clusters and pushes
// the addresses of all ready endpoints to r.clientConn.
//
// It first dials the service discovery endpoint. Each attempt is bounded by
// r.ypConnectionTimeout; an attempt that times out is retried until the dial
// succeeds, fails with a different error, or ctx itself is done. The
// connection and the grpcresolver built on top of it are closed before start
// returns.
//
// logger may be nil; when set, every non-nil cluster response is logged at
// debug level.
//
// Errors:
//   - ServiceDiscoveryNotAvailableError: the service discovery endpoint could
//     not be dialed (including ctx expiring while retrying).
//   - YPResolverNotAvailableError: grpcresolver.New failed.
//   - UnknownServiceNameError: ResolveEndpoints failed for a cluster.
//   - ServiceHasNoHostsError: no ready endpoint was found in any cluster.
//   - otherwise, whatever r.clientConn.UpdateState returns.
func (r *ypResolver) start(ctx context.Context, serviceName string, clusters []string, logger log.Structured) (err error) {
	// dial performs a single bounded connection attempt. It is a function of
	// its own so that the per-attempt context is released as soon as the
	// attempt finishes, not when start returns. Because of grpc.WithBlock the
	// context only governs the dial itself, so cancelling it after
	// DialContext has returned does not affect the connection.
	dial := func() (*grpc.ClientConn, error) {
		dialCtx, cancel := context.WithTimeout(ctx, r.ypConnectionTimeout)
		defer cancel()
		return grpc.DialContext(
			dialCtx,
			fmt.Sprintf("%s:%s", resolverconst.ServiceDiscoveryHostProd, resolverconst.ServiceDiscoveryGRPCPort),
			grpc.WithInsecure(),
			grpc.WithBlock(),
			grpc.FailOnNonTempDialError(true),
		)
	}

	var grpcConn *grpc.ClientConn
	for {
		grpcConn, err = dial()
		if err == nil || !errors.Is(err, context.DeadlineExceeded) {
			break
		}
		// The attempt timed out. Retry only while the caller's deadline
		// allows it: once ctx is done every further attempt fails
		// immediately with the same error, so without this check the loop
		// would spin forever.
		if ctx.Err() != nil {
			break
		}
	}
	if err != nil {
		return ServiceDiscoveryNotAvailableError{err}
	}
	defer grpcConn.Close()

	grpcClient := pb.NewTServiceDiscoveryServiceClient(grpcConn)
	rsOpts := []grpcresolver.ResolverOpt{
		grpcresolver.WithGRPCClient(grpcClient, grpcConn),
	}
	if logger != nil {
		rsOpts = append(rsOpts, grpcresolver.WithLogger(logger))
	}
	rs, err := grpcresolver.New(rsOpts...)
	if err != nil {
		return YPResolverNotAvailableError{err}
	}
	defer rs.Close()

	r.addresses = []string{}

	for _, dc := range clusters {
		response, err := rs.ResolveEndpoints(ctx, dc, serviceName)
		if err != nil {
			return UnknownServiceNameError{err, serviceName}
		}
		if logger != nil && response != nil {
			logger.Debug(
				"service discovery response",
				log.String("cluster", dc),
				log.Reflect("response", response),
			)
		}
		if response != nil && response.EndpointSet != nil {
			for _, endpoint := range response.EndpointSet.Endpoints {
				// Endpoints that are not ready are left out of the result.
				if !endpoint.Ready {
					continue
				}

				r.addresses = append(r.addresses, fmt.Sprintf("%v:%v", endpoint.FQDN, endpoint.Port))
			}
		}
	}

	if len(r.addresses) == 0 {
		return ServiceHasNoHostsError{serviceName}
	}

	addrs := make([]resolver.Address, len(r.addresses))
	for i, s := range r.addresses {
		addrs[i] = resolver.Address{Addr: s}
	}

	return r.clientConn.UpdateState(resolver.State{Addresses: addrs})
}

// ResolveNow is a no-op: it ignores o and does not trigger a new resolution.
func (*ypResolver) ResolveNow(o resolver.ResolveNowOptions) {}

// Close is a no-op. The connections opened by start are closed by start
// itself before it returns.
func (*ypResolver) Close() {}
