package grpcresolver

import (
	"context"
	"io"
	"os"
	"sync"
	"time"

	"github.com/gofrs/uuid"
	"go.uber.org/atomic"
	"google.golang.org/grpc"

	pb "a.yandex-team.ru/infra/yp_service_discovery/api"
	"a.yandex-team.ru/infra/yp_service_discovery/golang/resolver"
	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/core/log/nop"
	"a.yandex-team.ru/library/go/core/xerrors"
)

// Compile-time checks that *Resolver implements the interfaces it is used through.
var (
	_ io.Closer         = (*Resolver)(nil)
	_ resolver.Resolver = (*Resolver)(nil)
)

type (
	// ResolveEndpointsRUIDFunc computes the value stored in the Ruid field of an
	// outgoing ResolveEndpoints request; it receives the otherwise-filled request.
	ResolveEndpointsRUIDFunc func(*pb.TReqResolveEndpoints) string
	// ResolvePodsRUIDFunc computes the value stored in the Ruid field of an
	// outgoing ResolvePods request; it receives the otherwise-filled request.
	ResolvePodsRUIDFunc func(*pb.TReqResolvePods) string
	// ResolveNodeRUIDFunc computes the value stored in the Ruid field of an
	// outgoing ResolveNode request; it receives the otherwise-filled request.
	ResolveNodeRUIDFunc func(*pb.TReqResolveNode) string
)

type Resolver struct {
	serviceURI string
	clientName string

	grpcConn   *grpc.ClientConn
	grpcClient pb.TServiceDiscoveryServiceClient

	dialed atomic.Bool
	dialMu sync.Mutex

	resolveEndpointsRUIDFunc ResolveEndpointsRUIDFunc
	resolvePodsRUIDFunc      ResolvePodsRUIDFunc
	resolveNodeRUIDFunc      ResolveNodeRUIDFunc

	logger log.Structured
}

// New returns gRPC resolver instance.
//
// Defaults: the production service discovery host and gRPC port, a client name
// of the form user@hostname, a no-op logger, and a fresh random UUIDv4 as the
// Ruid of every request. Options are applied afterwards and may override any of
// these. If options supply both a gRPC connection and a client, lazy dialing is
// skipped. The returned error is currently always nil.
func New(opts ...ResolverOpt) (*Resolver, error) {
	// uuid.Must panics if random UUID generation fails.
	randomRUID := func() string {
		return uuid.Must(uuid.NewV4()).String()
	}

	r := &Resolver{
		serviceURI: resolver.ServiceDiscoveryHostProd + ":" + resolver.ServiceDiscoveryGRPCPort,
		clientName: getClientName(),
		logger:     new(nop.Logger),

		resolveEndpointsRUIDFunc: func(*pb.TReqResolveEndpoints) string { return randomRUID() },
		resolvePodsRUIDFunc:      func(*pb.TReqResolvePods) string { return randomRUID() },
		resolveNodeRUIDFunc:      func(*pb.TReqResolveNode) string { return randomRUID() },
	}

	for _, apply := range opts {
		apply(r)
	}

	// A connection plus client injected via options means nothing to dial later.
	r.dialed.Store(r.grpcConn != nil && r.grpcClient != nil)

	return r, nil
}

// Close closes the underlying gRPC connection, whether it was injected via
// options or dialed lazily. It returns nil when there is no connection to close.
func (r *Resolver) Close() error {
	r.logger.Debug("closing YP SD gRPC connection")

	// dialGRPCLazy assigns grpcConn while holding dialMu. Taking the same lock
	// here avoids a data race on the field. It also ensures that a dial already
	// in flight finishes first, so its connection is closed rather than leaked.
	r.dialMu.Lock()
	defer r.dialMu.Unlock()

	if r.grpcConn == nil {
		return nil
	}
	return r.grpcConn.Close()
}

// ResolveEndpoints asks service discovery for the endpoint set endpointSet in
// the given cluster. It dials the gRPC connection first if needed, and converts
// the protobuf reply into a resolver.ResolveEndpointsResponse.
func (r *Resolver) ResolveEndpoints(ctx context.Context, cluster, endpointSet string) (*resolver.ResolveEndpointsResponse, error) {
	if err := r.dialGRPCLazy(ctx); err != nil {
		return nil, err
	}

	req := &pb.TReqResolveEndpoints{
		ClusterName:   cluster,
		EndpointSetId: endpointSet,
		ClientName:    r.clientName,
	}
	// the RUID generator sees the request with all other fields populated
	req.Ruid = r.resolveEndpointsRUIDFunc(req)
	r.logger.Debug("sending resolve request", log.Any("request", req))

	pbResp, err := r.grpcClient.ResolveEndpoints(ctx, req)
	if err != nil {
		return nil, xerrors.Errorf("failed to resolve endpoint-set %s@%s: %w", endpointSet, cluster, err)
	}
	r.logger.Debug("unmarshaling response", log.Any("response", pbResp))

	out := new(resolver.ResolveEndpointsResponse)
	if err := out.UnmarshalProto(pbResp); err != nil {
		return nil, err
	}
	return out, nil
}

// ResolvePods asks service discovery for the pods of pod set podSet in the
// given cluster. It dials the gRPC connection first if needed, and converts the
// protobuf reply into a resolver.ResolvePodsResponse.
func (r *Resolver) ResolvePods(ctx context.Context, cluster, podSet string) (*resolver.ResolvePodsResponse, error) {
	if err := r.dialGRPCLazy(ctx); err != nil {
		return nil, err
	}

	req := &pb.TReqResolvePods{
		ClusterName: cluster,
		PodSetId:    podSet,
		ClientName:  r.clientName,
	}
	// the RUID generator sees the request with all other fields populated
	req.Ruid = r.resolvePodsRUIDFunc(req)
	r.logger.Debug("sending resolve request", log.Any("request", req))

	pbResp, err := r.grpcClient.ResolvePods(ctx, req)
	if err != nil {
		return nil, xerrors.Errorf("failed to resolve pod-set %s@%s: %w", podSet, cluster, err)
	}
	r.logger.Debug("unmarshaling response", log.Any("response", pbResp))

	out := new(resolver.ResolvePodsResponse)
	if err := out.UnmarshalProto(pbResp); err != nil {
		return nil, err
	}
	return out, nil
}

// ResolveNode asks service discovery about node in the given cluster. It dials
// the gRPC connection first if needed, and converts the protobuf reply into a
// resolver.ResolveNodeResponse.
func (r *Resolver) ResolveNode(ctx context.Context, cluster, node string) (*resolver.ResolveNodeResponse, error) {
	if err := r.dialGRPCLazy(ctx); err != nil {
		return nil, err
	}

	// send actual gRPC request
	request := &pb.TReqResolveNode{
		ClusterName: cluster,
		NodeId:      node,
		ClientName:  r.clientName,
	}
	// the RUID generator sees the request with all other fields populated
	request.Ruid = r.resolveNodeRUIDFunc(request)

	r.logger.Debug("sending resolve request", log.Any("request", request))
	res, err := r.grpcClient.ResolveNode(ctx, request)
	if err != nil {
		// this is a node lookup, so report it as such rather than "pod-set"
		return nil, xerrors.Errorf("failed to resolve node %s@%s: %w", node, cluster, err)
	}

	r.logger.Debug("unmarshaling response", log.Any("response", res))
	var response resolver.ResolveNodeResponse
	if err := response.UnmarshalProto(res); err != nil {
		return nil, err
	}

	return &response, nil
}

// dialGRPCLazy makes sure r.grpcConn and r.grpcClient are set, dialing
// r.serviceURI at most once across concurrent callers. The dial is blocking,
// insecure, and bounded by whichever comes first: ctx or a one-second timeout.
func (r *Resolver) dialGRPCLazy(ctx context.Context) error {
	// lock-free fast path once a client exists
	if r.dialed.Load() {
		return nil
	}

	r.dialMu.Lock()
	defer r.dialMu.Unlock()

	// re-check: another caller may have dialed while we waited for the lock
	if r.dialed.Load() {
		return nil
	}

	r.logger.Debug("creating gRPC client since none given", log.String("uri", r.serviceURI))

	dialCtx, cancel := context.WithTimeout(ctx, time.Second)
	defer cancel()

	dialOpts := []grpc.DialOption{
		grpc.WithInsecure(),
		grpc.WithBlock(),
		grpc.FailOnNonTempDialError(true),
	}
	conn, err := grpc.DialContext(dialCtx, r.serviceURI, dialOpts...)
	if err != nil {
		return xerrors.Errorf("failed to connect to %s: %w", r.serviceURI, err)
	}

	r.grpcConn, r.grpcClient = conn, pb.NewTServiceDiscoveryServiceClient(conn)

	// publish only after both fields are assigned
	r.dialed.Store(true)
	return nil
}

// getClientName builds "user@host". The user is taken from SUDO_USER, then
// USER, falling back to "go_resolver". The host is os.Hostname(), or "unknown"
// if that fails.
func getClientName() string {
	user := "go_resolver"
	if v := os.Getenv("SUDO_USER"); v != "" {
		user = v
	} else if v := os.Getenv("USER"); v != "" {
		user = v
	}

	host, err := os.Hostname()
	if err != nil {
		host = "unknown"
	}

	return user + "@" + host
}
