package client

import (
	"context"
	"crypto/tls"
	tally "github.com/uber-go/tally/v4"
	prometheus "github.com/uber-go/tally/v4/prometheus"
	"go.temporal.io/sdk/client"
	"go.uber.org/zap"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/metadata"
	"net"
	"time"
)

// headersProvider supplies the per-request headers attached to outgoing
// calls. It carries a single credential that is sent as the Authorization
// header.
type headersProvider struct {
	// oauthToken is used verbatim as the Authorization header value.
	oauthToken string
}

// GetHeaders returns a freshly allocated map holding only the Authorization
// header, set to the stored token. The context is not consulted and the
// returned error is always nil.
func (h *headersProvider) GetHeaders(ctx context.Context) (map[string]string, error) {
	return map[string]string{"Authorization": h.oauthToken}, nil
}

// Option customizes a client.Options value in place. getOptions applies the
// supplied Options, in order, after it has filled in its defaults.
type Option func(*client.Options)

// WithMetricsScope returns an Option that sets the MetricsScope field of the
// client options to metricsScope.
func WithMetricsScope(metricsScope tally.Scope) Option {
	setScope := func(options *client.Options) {
		options.MetricsScope = metricsScope
	}
	return setScope
}

// getTLSConfig builds the client-side TLS configuration for hostPort:
// TLS 1.2 as the minimum version, "h2" as the only ALPN protocol, and
// ServerName set to the host portion of hostPort.
//
// hostPort is normally "host:port". When it carries no port (for example
// "example.com" or "[::1]") the host is still extracted instead of being
// silently dropped. If no host can be determined at all (for example a
// target with a URI scheme such as "dns:///host:7233"), ServerName is left
// empty.
func getTLSConfig(hostPort string) *tls.Config {
	host, _, err := net.SplitHostPort(hostPort)
	if err != nil {
		// The usual failure is a missing port. Retrying with a placeholder
		// port lets SplitHostPort do the host parsing (including stripping
		// IPv6 brackets) without hand-rolled string handling.
		host, _, err = net.SplitHostPort(hostPort + ":0")
		if err != nil {
			// Still unparseable: do not guess a server name.
			host = ""
		}
	}
	return &tls.Config{
		MinVersion: tls.VersionTLS12,
		NextProtos: []string{"h2"},
		ServerName: host,
	}
}

// getOptions assembles the client.Options shared by every constructor in
// this file: the target address, a headers provider carrying oauthToken, the
// TLS configuration derived from hostPort, and the data converter returned
// by GetDefaultSwatDataConverter. Each Option in opts is then applied in
// order, so callers can override any of those defaults.
func getOptions(hostPort, oauthToken string, opts ...Option) *client.Options {
	options := new(client.Options)
	options.HostPort = hostPort
	options.HeadersProvider = &headersProvider{oauthToken: oauthToken}
	options.ConnectionOptions.TLS = getTLSConfig(hostPort)
	options.DataConverter = GetDefaultSwatDataConverter()

	for i := range opts {
		opts[i](options)
	}
	return options
}

// NewSdkNamespaceClient creates an SDK namespace client for hostPort using
// the options produced by getOptions (oauthToken for the Authorization
// header, plus any caller-supplied opts).
func NewSdkNamespaceClient(hostPort, oauthToken string, opts ...Option) (client.NamespaceClient, error) {
	return client.NewNamespaceClient(*getOptions(hostPort, oauthToken, opts...))
}

// NewSdkClient creates an SDK client for hostPort bound to namespace. The
// options come from getOptions; the namespace is set after opts have been
// applied, so it takes precedence over any namespace an Option may have set.
func NewSdkClient(hostPort, oauthToken, namespace string, opts ...Option) (client.Client, error) {
	options := *getOptions(hostPort, oauthToken, opts...)
	options.Namespace = namespace
	return client.NewClient(options)
}

// copy-paste from https://github.com/temporalio/samples-go/blob/cac8bb0ee971be97cd2e1f3d48ddfd247a21b7ec/metrics/worker/main.go#L40

// NewPrometheusScope creates a tally root scope that reports through a
// Prometheus reporter built from c. Reporter errors are written to logger at
// Info level. The scope uses the "temporal" prefix, the Prometheus default
// separator, the package-level sanitizeOptions, and a one-second reporting
// interval.
//
// An error is returned only when the reporter cannot be created.
func NewPrometheusScope(c prometheus.Configuration, logger *zap.Logger) (tally.Scope, error) {
	logReporterError := func(err error) {
		logger.Info("tally prometheus reporter error", zap.Error(err))
	}

	reporter, err := c.NewReporter(prometheus.ConfigurationOptions{OnError: logReporterError})
	if err != nil {
		return nil, err
	}

	// The second value returned by NewRootScope is intentionally discarded:
	// this function's signature offers no way to hand it to the caller.
	rootScope, _ := tally.NewRootScope(
		tally.ScopeOptions{
			Prefix:          "temporal",
			Separator:       prometheus.DefaultSeparator,
			SanitizeOptions: &sanitizeOptions,
			CachedReporter:  reporter,
		},
		time.Second,
	)
	return rootScope, nil
}

// Tally sanitizer options that satisfy Prometheus naming restrictions.
// Sanitizing happens at the tally emission level, so the metric names used in
// code may differ from the names that are actually emitted. In the current
// implementation '-' and '.' are replaced with '_'.
var (
	// safeCharacters lists the non-alphanumeric characters allowed through
	// unchanged.
	safeCharacters = []rune{'_'}

	// sanitizeOptions applies one and the same rule to metric names, tag keys
	// and tag values: alphanumerics plus safeCharacters are kept, everything
	// else becomes tally.DefaultReplacementCharacter.
	sanitizeOptions = func() tally.SanitizeOptions {
		allowed := tally.ValidCharacters{
			Ranges:     tally.AlphanumericRange,
			Characters: safeCharacters,
		}
		return tally.SanitizeOptions{
			NameCharacters:       allowed,
			KeyCharacters:        allowed,
			ValueCharacters:      allowed,
			ReplacementCharacter: tally.DefaultReplacementCharacter,
		}
	}()
)

// /copy-paste

// copy-paste from tctl/factory.go
// headersProviderInterceptor returns a unary client interceptor that asks
// headersProvider for headers before each call and appends every returned
// key/value pair to the call's outgoing gRPC metadata. If the provider
// returns an error the call is not invoked and that error is returned.
func headersProviderInterceptor(headersProvider headersProvider) grpc.UnaryClientInterceptor {
	intercept := func(
		ctx context.Context,
		method string,
		req, reply interface{},
		cc *grpc.ClientConn,
		invoker grpc.UnaryInvoker,
		opts ...grpc.CallOption,
	) error {
		extra, err := headersProvider.GetHeaders(ctx)
		if err != nil {
			return err
		}

		outgoing := ctx
		for name, value := range extra {
			outgoing = metadata.AppendToOutgoingContext(outgoing, name, value)
		}
		return invoker(outgoing, method, req, reply, cc, opts...)
	}
	return intercept
}

// CreateGRPCConnection dials hostPort and returns the raw gRPC connection.
// Every unary call made on the connection carries oauthToken in its
// Authorization header via headersProviderInterceptor.
//
// Transport security follows the options built by getOptions: when they
// carry a TLS configuration the connection uses TLS credentials derived from
// hostPort, otherwise it falls back to an insecure transport.
func CreateGRPCConnection(hostPort, oauthToken string) (*grpc.ClientConn, error) {
	var transport grpc.DialOption
	if getOptions(hostPort, oauthToken).ConnectionOptions.TLS == nil {
		transport = grpc.WithInsecure()
	} else {
		transport = grpc.WithTransportCredentials(credentials.NewTLS(getTLSConfig(hostPort)))
	}

	authHeaders := grpc.WithChainUnaryInterceptor(
		headersProviderInterceptor(headersProvider{oauthToken: oauthToken}),
	)

	conn, err := grpc.Dial(hostPort, transport, authHeaders)
	if err != nil {
		return nil, err
	}
	return conn, nil
}

// /copy-paste
