package healthcheckbalancer

import (
	"math/rand"
	"sort"
	"sync"

	"go.uber.org/atomic"
	"google.golang.org/grpc/balancer"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"

	"a.yandex-team.ru/library/go/core/log"
)

// closestInstancePicker routes each pick to the healthy host with the lowest
// latency reported by hostListManager. Its Pick method has the signature of the
// gRPC balancer.Picker interface.
type closestInstancePicker struct {
	// subConnsListByHost maps a host name to the sub-connections that belong to it.
	// It is filled once in newClosestInstancePicker and only read by Pick.
	subConnsListByHost map[string]*subConnsList
	// hostListManager supplies per-host health (isHealthy) and latency (getLatency).
	hostListManager    hostListManager

	// mu serializes calls to Pick.
	mu     sync.Mutex
	logger log.Logger
}

// subConnsList holds the sub-connections of a single host together with a
// counter that Get uses to rotate through them round-robin.
type subConnsList struct {
	// index is the rotation counter. newSubConnsList seeds it with a random
	// non-negative value and Get advances it on every call.
	index    *atomic.Int32
	subConns []balancer.SubConn
}

// newSubConnsList wraps subConns in a subConnsList whose rotation counter starts
// at a random offset, so freshly built lists do not all begin on their first entry.
func newSubConnsList(subConns ...balancer.SubConn) *subConnsList {
	start := rand.Int31()
	return &subConnsList{
		index:    atomic.NewInt32(start),
		subConns: subConns,
	}
}

// Get returns the next sub-connection in round-robin order.
//
// The counter is advanced with a single atomic Inc, so concurrent callers each
// observe a distinct counter value (a separate Load followed by Inc would let two
// callers read the same slot). Inc returns the post-increment value, so one is
// subtracted to use the pre-increment value as the rotation position.
//
// The int32 counter starts at a random non-negative value and wraps to a negative
// number once it passes math.MaxInt32. Reinterpreting it as uint32 before the modulo
// keeps the index non-negative; a signed modulo of a wrapped counter would yield a
// negative index and panic.
//
// subConns must be non-empty; an empty list makes the modulo panic.
func (scl *subConnsList) Get() balancer.SubConn {
	n := scl.index.Inc() - 1
	return scl.subConns[uint32(n)%uint32(len(scl.subConns))]
}

// newClosestInstancePicker groups subConns by their host (keeping each host's
// sub-connections in input order) and returns a picker over those groups that
// consults hostListManager for health and latency.
func newClosestInstancePicker(subConns []namedSubConn, hostListManager hostListManager, logger log.Logger) *closestInstancePicker {
	grouped := map[string][]balancer.SubConn{}
	for _, named := range subConns {
		grouped[named.host] = append(grouped[named.host], named.subConn)
	}

	lists := make(map[string]*subConnsList, len(grouped))
	for host, conns := range grouped {
		lists[host] = newSubConnsList(conns...)
	}

	return &closestInstancePicker{
		subConnsListByHost: lists,
		hostListManager:    hostListManager,
		logger:             logger,
	}
}

// Pick chooses the healthy host with the smallest latency reported by
// hostListManager and returns that host's next sub-connection in rotation.
// When no host is healthy it returns a status error with codes.Unavailable.
func (p *closestInstancePicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
	p.mu.Lock()
	defer p.mu.Unlock()

	candidates := make([]string, 0, len(p.subConnsListByHost))
	for host := range p.subConnsListByHost {
		if !p.hostListManager.isHealthy(host) {
			continue
		}
		candidates = append(candidates, host)
	}

	if len(candidates) == 0 {
		const msg = "no healthy host has been found"
		p.logger.Debug(msg)
		return balancer.PickResult{}, status.Error(codes.Unavailable, msg)
	}

	// A stable sort over a few dozen hosts is cheap. Ties keep the map iteration
	// order, which Go leaves unspecified, so equally fast hosts are not always
	// resolved the same way.
	byLatency := func(a, b int) bool {
		return p.hostListManager.getLatency(candidates[a]) < p.hostListManager.getLatency(candidates[b])
	}
	sort.SliceStable(candidates, byLatency)

	closest := candidates[0]
	p.logger.Debug("Picked host", log.String("host", closest))
	conns := p.subConnsListByHost[closest]
	return balancer.PickResult{SubConn: conns.Get()}, nil
}
