package healthcheckbalancer

import (
	"encoding/json"
	"time"

	"google.golang.org/grpc/balancer"
	"google.golang.org/grpc/balancer/base"
	"google.golang.org/grpc/serviceconfig"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/core/log/nop"
	"a.yandex-team.ru/travel/library/go/tvm"
)

// BalancerName is the balancer name reported by HealthCheckBalancerBuilder.Name
// and passed to the underlying base balancer builder in Build.
const BalancerName = "healthcheck_balancer"

// BalancingMethod selects how the picker chooses a host. The value is stored
// on the picker builder by WithBalancingMethod; the picking logic itself is
// not part of this file.
type BalancingMethod uint

const (
	// BalancingMethodUnknown is the zero value of BalancingMethod.
	BalancingMethodUnknown BalancingMethod = iota
	// BalancingMethodRoundRobin presumably rotates requests over the available
	// hosts — confirm against the picker implementation.
	BalancingMethodRoundRobin
	// BalancingMethodChooseClosest presumably prefers the nearest host —
	// confirm against the picker implementation.
	BalancingMethodChooseClosest
	// BalancingMethodChooseLexicographicallyFirst presumably picks the host
	// that sorts first by name — confirm against the picker implementation.
	BalancingMethodChooseLexicographicallyFirst
)

// HealthCheckServiceNameFunc maps a host to a service name (presumably the
// health-check service name to use for that host — confirm against the
// polling code).
//
// For testing purposes only.
type HealthCheckServiceNameFunc func(host string) string

// healthCheckBalancerBuilderOpt is a functional option: NewBalancerBuilder
// calls every option with the builder under construction so it can override
// the defaults.
type healthCheckBalancerBuilderOpt func(*HealthCheckBalancerBuilder)

// healthCheckBalancerCfg holds the health-check settings that Build copies
// into every balancer it creates.
//
// Fields (defaults are the ones set by NewBalancerBuilder):
//   - healthCheckTimeout: default 1s; presumably the deadline of a single
//     health-check call — confirm against the polling code.
//   - healthCheckInterval: default 2s; presumably the pause between polling
//     rounds — confirm against the polling code.
//   - healthyStateTimeout: default 30s; presumably how long a host stays
//     healthy without a fresh successful check — confirm against the polling
//     code.
//   - healthCheckServiceName: default empty string.
//   - healthCheckServiceNameFunc: default nil; set via
//     WithHealthCheckServiceNameFunc.
type healthCheckBalancerCfg struct {
	healthCheckTimeout         time.Duration
	healthCheckInterval        time.Duration
	healthyStateTimeout        time.Duration
	healthCheckServiceName     string
	healthCheckServiceNameFunc HealthCheckServiceNameFunc // for testing purposes only
}

// HealthCheckBalancerBuilder builds health-check-aware balancers on top of the
// gRPC base balancer. Create it with NewBalancerBuilder.
//
// The embedded balancer.Builder is not assigned by NewBalancerBuilder; Name
// and Build are implemented by the pointer-receiver methods of this type.
//
// Fields:
//   - pickerBuilder: picker builder settings (logger, balancing method) used
//     for the base balancer created in Build.
//   - cfg: health-check settings copied into every built balancer.
//   - logger: logger shared with the picker builder and built balancers;
//     defaults to a no-op logger.
//   - tvmProvider: optional; when non-nil, Build calls it to obtain the TVM
//     helper and TVM ID for the new balancer.
type HealthCheckBalancerBuilder struct {
	balancer.Builder
	pickerBuilder *healthCheckPickerBuilder
	cfg           healthCheckBalancerCfg
	logger        log.Logger
	tvmProvider   func() (helper *tvm.TvmHelper, tvmID uint32)
}

// NewBalancerBuilder creates a HealthCheckBalancerBuilder with the default
// settings and then applies opts in the order given.
//
// Defaults: health-check timeout 1s, health-check interval 2s, healthy-state
// timeout 30s, empty health-check service name, no-op logger, no TVM provider
// and an empty picker builder.
func NewBalancerBuilder(opts ...healthCheckBalancerBuilderOpt) *HealthCheckBalancerBuilder {
	// healthCheckServiceName and tvmProvider are left at their zero values
	// ("" and nil).
	builder := &HealthCheckBalancerBuilder{
		pickerBuilder: &healthCheckPickerBuilder{},
		logger:        &nop.Logger{},
		cfg: healthCheckBalancerCfg{
			healthCheckTimeout:  time.Second,
			healthCheckInterval: 2 * time.Second,
			healthyStateTimeout: 30 * time.Second,
		},
	}

	for i := range opts {
		opts[i](builder)
	}

	// The picker builder logs through whatever logger the options left in place.
	builder.pickerBuilder.logger = builder.logger

	return builder
}

// WithLogger returns an option that replaces the builder's logger with l.
// NewBalancerBuilder hands the resulting logger to the picker builder after
// all options have been applied.
func WithLogger(l log.Logger) healthCheckBalancerBuilderOpt {
	setLogger := func(builder *HealthCheckBalancerBuilder) {
		builder.logger = l
	}
	return setLogger
}

// WithHealthCheckTimeout returns an option that stores timeout as the
// builder's health-check timeout, overriding the 1s default set by
// NewBalancerBuilder.
func WithHealthCheckTimeout(timeout time.Duration) healthCheckBalancerBuilderOpt {
	setTimeout := func(builder *HealthCheckBalancerBuilder) {
		builder.cfg.healthCheckTimeout = timeout
	}
	return setTimeout
}

// WithHealthyStateTimeout returns an option that stores timeout as the
// builder's healthy-state timeout, overriding the 30s default set by
// NewBalancerBuilder.
func WithHealthyStateTimeout(timeout time.Duration) healthCheckBalancerBuilderOpt {
	setTimeout := func(builder *HealthCheckBalancerBuilder) {
		builder.cfg.healthyStateTimeout = timeout
	}
	return setTimeout
}

// WithHealthCheckInterval returns an option that stores interval as the
// builder's health-check interval, overriding the 2s default set by
// NewBalancerBuilder.
func WithHealthCheckInterval(interval time.Duration) healthCheckBalancerBuilderOpt {
	setInterval := func(builder *HealthCheckBalancerBuilder) {
		builder.cfg.healthCheckInterval = interval
	}
	return setInterval
}

// WithBalancingMethod returns an option that sets the balancing method on the
// builder's picker builder. NewBalancerBuilder creates the picker builder
// before it applies any option.
func WithBalancingMethod(method BalancingMethod) healthCheckBalancerBuilderOpt {
	setMethod := func(builder *HealthCheckBalancerBuilder) {
		picker := builder.pickerBuilder
		picker.balancingMethod = method
	}
	return setMethod
}

// WithHealthCheckServiceName returns an option that stores
// healthCheckServiceName in the builder's health-check settings; the default
// is the empty string.
func WithHealthCheckServiceName(healthCheckServiceName string) healthCheckBalancerBuilderOpt {
	setServiceName := func(builder *HealthCheckBalancerBuilder) {
		builder.cfg.healthCheckServiceName = healthCheckServiceName
	}
	return setServiceName
}

// WithHealthCheckServiceNameFunc returns an option that stores fn in the
// builder's health-check settings. HealthCheckServiceNameFunc is meant for
// testing purposes only.
func WithHealthCheckServiceNameFunc(fn HealthCheckServiceNameFunc) healthCheckBalancerBuilderOpt {
	setNameFunc := func(builder *HealthCheckBalancerBuilder) {
		builder.cfg.healthCheckServiceNameFunc = fn
	}
	return setNameFunc
}

// WithTVMProvider returns an option that installs tvmHelperFactory as the
// builder's TVM provider. When a provider is set, Build calls it to obtain the
// TVM helper and TVM ID for each balancer it creates.
func WithTVMProvider(tvmHelperFactory func() (helper *tvm.TvmHelper, tvmID uint32)) healthCheckBalancerBuilderOpt {
	setProvider := func(builder *HealthCheckBalancerBuilder) {
		builder.tvmProvider = tvmHelperFactory
	}
	return setProvider
}

// Name implements "grpc/balancer.Builder" interface.
// It always reports the package-level BalancerName.
func (b *HealthCheckBalancerBuilder) Name() string {
	return BalancerName
}

// Build implements "grpc/balancer.Builder" interface.
//
// It creates a base balancer for cc, wraps it into a healthCheckBalancer that
// carries a copy of the builder's health-check settings, obtains the TVM
// helper and TVM ID from the configured provider (if any), and calls
// startPolling on the new balancer before returning it.
//
// gRPC calls Build once per ClientConn, so every balancer gets its own copy of
// the picker builder. Binding hostListManager on the shared b.pickerBuilder
// would repoint the pickers of previously built balancers at the most recently
// built one, and would write the shared field without synchronization.
func (b *HealthCheckBalancerBuilder) Build(cc balancer.ClientConn, opt balancer.BuildOptions) balancer.Balancer {
	// Shallow copy: keeps the settings chosen at construction time (logger,
	// balancing method) while leaving the shared picker builder untouched.
	pickerBuilder := *b.pickerBuilder

	baseBalancerBuilder := base.NewBalancerBuilder(BalancerName, &pickerBuilder, base.Config{HealthCheck: false})
	baseBalancer := baseBalancerBuilder.Build(cc, opt)
	bb := &healthCheckBalancer{
		baseBalancer: baseBalancer,
		cfg:          b.cfg,
		hostStatuses: make(map[string]hostState),
		logger:       b.logger,
	}
	if b.tvmProvider != nil {
		bb.tvmHelper, bb.serviceTvmID = b.tvmProvider()
	}
	// Bind the new balancer to its private picker builder only.
	pickerBuilder.hostListManager = bb

	bb.startPolling()

	return bb
}

// ParseConfig implements "grpc/balancer.ConfigParser" interface.
//
// The raw JSON in c is unmarshalled into a loadBalancerConfig whose
// ConnectionsPerHost is preset to 1 (acting as the default when the JSON does
// not override it), and the result is then checked with Validate. Decoding and
// validation errors are returned unchanged, with a nil config.
//
// The receiver is a pointer, matching Name and Build, so the method set of
// *HealthCheckBalancerBuilder is declared consistently and a call does not
// copy the whole builder.
func (b *HealthCheckBalancerBuilder) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, error) {
	config := &loadBalancerConfig{ConnectionsPerHost: 1}
	if err := json.Unmarshal(c, config); err != nil {
		return nil, err
	}
	if err := config.Validate(); err != nil {
		return nil, err
	}
	return config, nil
}
