package healthcheckbalancer

import (
	"context"
	"math"
	"sync"
	"time"

	"google.golang.org/grpc"
	"google.golang.org/grpc/balancer"
	"google.golang.org/grpc/health/grpc_health_v1"
	"google.golang.org/grpc/resolver"

	"a.yandex-team.ru/library/go/core/log"
	tvmutil "a.yandex-team.ru/travel/library/go/tvm"
)

// hostState is the outcome of the most recent health check of a single host,
// as produced by healthCheck.
type hostState struct {
	// lastChecked is the wall-clock time at which the check result was built.
	lastChecked time.Time
	// isHealthy is true only when the health service answered SERVING.
	isHealthy   bool
	// lastLatency is the elapsed time of the check (dial included) for a
	// healthy host; unhealthy results carry time.Duration(math.MaxInt64).
	lastLatency time.Duration
}

// healthCheckBalancer wraps another balancer.Balancer. Balancer API calls are
// delegated to baseBalancer. A background goroutine (startPolling/loop/poll)
// health-checks every host in hosts and records the results, which isHealthy
// and getLatency expose.
type healthCheckBalancer struct {
	// mu protects the mutable state below: hosts, hostStatuses, isPolling and
	// healthyHostFound. Readers take RLock, writers take Lock.
	mu               sync.RWMutex
	// cfg holds the health-check settings (interval, timeout, service name,
	// healthy-state TTL) read by the polling code.
	cfg              healthCheckBalancerCfg
	// baseBalancer is the wrapped balancer that receives every balancer call.
	baseBalancer     balancer.Balancer
	// hosts is the list of targets dialed on each poll round (see updateHosts).
	hosts            []string
	// hostStatuses maps a host string to its latest health-check result.
	hostStatuses     map[string]hostState
	// isPolling keeps the polling loop running while true.
	isPolling        bool
	// healthyHostFound becomes true after the first healthy result. It is not
	// reset anywhere in the code visible here. Until it is set, isHealthy
	// treats every host as healthy.
	healthyHostFound bool
	logger           log.Logger
	// tvmHelper is optional. When non-nil, its client interceptor is added to
	// health-check connections.
	tvmHelper        *tvmutil.TvmHelper
	// serviceTvmID is passed to tvmHelper.GRPCClientInterceptor.
	serviceTvmID     uint32
}

// addrAttribute is the resolver.Address attribute key under which
// replicateAddresses stores each copy's replica index. The distinct values
// presumably stop the wrapped balancer from treating the copies as one
// address. TODO confirm against the base balancer's address de-duplication.
type addrAttribute struct{}

// UpdateClientConnState forwards a resolver/config update to the wrapped
// balancer. When the update carries a *loadBalancerConfig, every resolved
// address is first replicated according to that config (replicateAddresses).
func (hcb *healthCheckBalancer) UpdateClientConnState(state balancer.ClientConnState) error {
	lbCfg, isLBConfig := state.BalancerConfig.(*loadBalancerConfig)
	if !isLBConfig {
		return hcb.baseBalancer.UpdateClientConnState(state)
	}
	state.ResolverState.Addresses = hcb.replicateAddresses(lbCfg, state)
	return hcb.baseBalancer.UpdateClientConnState(state)
}

// ResolverError hands name-resolution errors straight to the wrapped
// balancer; this wrapper adds no handling of its own.
func (hcb *healthCheckBalancer) ResolverError(err error) {
	base := hcb.baseBalancer
	base.ResolverError(err)
}

// UpdateSubConnState forwards SubConn state transitions unchanged to the
// wrapped balancer.
func (hcb *healthCheckBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) {
	base := hcb.baseBalancer
	base.UpdateSubConnState(sc, state)
}

// startPolling marks the balancer as polling and launches the background
// health-check goroutine (loop). Calling it while polling is already active
// is a no-op, so repeated calls cannot spawn duplicate pollers that would
// multiply health-check traffic.
//
// NOTE(review): a goroutine from an earlier start that has not yet observed
// stopPolling (for example, it is still sleeping) keeps running if polling is
// restarted before it wakes. Fully preventing that needs a per-run stop signal.
func (hcb *healthCheckBalancer) startPolling() {
	hcb.mu.Lock()
	defer hcb.mu.Unlock()
	if hcb.isPolling {
		return
	}
	hcb.isPolling = true
	go hcb.loop()
}

// stopPolling clears the polling flag. The polling loop checks the flag only
// at the top of each iteration, so an in-flight poll round and the following
// sleep are not interrupted.
func (hcb *healthCheckBalancer) stopPolling() {
	hcb.mu.Lock()
	hcb.isPolling = false
	hcb.mu.Unlock()
}

// loop is the body of the polling goroutine started by startPolling. Each
// iteration reads isPolling under the read lock, runs one poll round and then
// sleeps for cfg.healthCheckInterval. It returns once stopPolling has cleared
// isPolling. Because the flag is only checked between rounds, shutdown can
// lag by up to one poll round plus one interval.
func (hcb *healthCheckBalancer) loop() {
	for {
		hcb.mu.RLock()
		polling := hcb.isPolling
		// Release the read lock before acting on the flag. Leaving the loop
		// while still holding RLock would block every later writer (setStatus,
		// updateHosts, startPolling, stopPolling) forever.
		hcb.mu.RUnlock()
		if !polling {
			return
		}
		hcb.poll()
		time.Sleep(hcb.cfg.healthCheckInterval)
	}
}

// poll runs a single health-check round. It snapshots hcb.hosts under the
// read lock, then checks each host sequentially without holding the lock and
// stores each result via setStatus. Any healthy result latches
// healthyHostFound to true; poll never resets it.
func (hcb *healthCheckBalancer) poll() {
	hcb.mu.RLock()
	hosts := hcb.hosts
	hcb.mu.RUnlock()

	for _, host := range hosts {
		hostStatus := hcb.healthCheck(host)
		hcb.setStatus(host, hostStatus)
		if hostStatus.isHealthy {
			// isHealthy and getLatency read healthyHostFound under mu, so it must
			// also be written under mu. It is set only after the status is
			// stored, so a reader that sees healthyHostFound == true can already
			// see this host as healthy.
			hcb.mu.Lock()
			hcb.healthyHostFound = true
			hcb.mu.Unlock()
		}
	}
}

// healthCheck performs one gRPC health check against host and returns the
// result.
//
// Each call dials a fresh connection with these options: blocking dial,
// insecure transport, and gRPC's built-in client health checking disabled.
// It then issues a grpc.health.v1 Check RPC. The service name comes from
// cfg.healthCheckServiceNameFunc(host) when that func is set, otherwise from
// cfg.healthCheckServiceName. Dial and RPC share one cfg.healthCheckTimeout
// deadline. When tvmHelper is set, its client interceptor for serviceTvmID is
// installed on the connection.
//
// Only a SERVING response yields a healthy state, and its latency covers dial
// plus RPC. Dial and RPC errors are logged at debug level. They, and any
// non-SERVING status, produce an unhealthy state whose latency is
// time.Duration(math.MaxInt64).
//
// NOTE(review): grpc.WithInsecure is deprecated in recent grpc-go releases in
// favour of grpc.WithTransportCredentials(insecure.NewCredentials()).
func (hcb *healthCheckBalancer) healthCheck(host string) hostState {
	ctx, cancel := context.WithTimeout(context.Background(), hcb.cfg.healthCheckTimeout)
	defer cancel()
	started := time.Now()

	failed := func() hostState {
		return hostState{
			lastChecked: time.Now(),
			isHealthy:   false,
			lastLatency: time.Duration(math.MaxInt64),
		}
	}

	opts := []grpc.DialOption{
		grpc.WithInsecure(),
		grpc.WithBlock(),
		grpc.WithDisableHealthCheck(),
	}
	if hcb.tvmHelper != nil {
		interceptor := hcb.tvmHelper.GRPCClientInterceptor(hcb.serviceTvmID)
		opts = append(opts, grpc.WithUnaryInterceptor(interceptor))
	}

	conn, dialErr := grpc.DialContext(ctx, host, opts...)
	if dialErr != nil {
		hcb.logger.Debug("Unable to contact host", log.String("host", host), log.Error(dialErr))
		return failed()
	}
	defer conn.Close()

	service := hcb.cfg.healthCheckServiceName
	if nameFor := hcb.cfg.healthCheckServiceNameFunc; nameFor != nil {
		service = nameFor(host)
	}
	req := &grpc_health_v1.HealthCheckRequest{Service: service}

	resp, checkErr := grpc_health_v1.NewHealthClient(conn).Check(ctx, req)
	if checkErr != nil {
		hcb.logger.Debug("Error when health-checking host", log.String("host", host), log.Error(checkErr))
		return failed()
	}
	if resp.GetStatus() != grpc_health_v1.HealthCheckResponse_SERVING {
		return failed()
	}
	return hostState{
		lastChecked: time.Now(),
		isHealthy:   true,
		lastLatency: time.Since(started),
	}
}

// getStatus returns the latest recorded health-check result for host and
// whether one exists. A host that has never been polled yields the zero
// hostState and false.
func (hcb *healthCheckBalancer) getStatus(host string) (state hostState, hasValue bool) {
	hcb.mu.RLock()
	defer hcb.mu.RUnlock()
	if recorded, ok := hcb.hostStatuses[host]; ok {
		return recorded, true
	}
	return hostState{}, false
}

// setStatus records state as the latest health-check result for host,
// overwriting any previous entry, while holding the write lock.
func (hcb *healthCheckBalancer) setStatus(host string, state hostState) {
	hcb.mu.Lock()
	defer hcb.mu.Unlock()

	statuses := hcb.hostStatuses
	statuses[host] = state
}

// updateHosts replaces the set of hosts checked by subsequent poll rounds.
// The slice is stored as-is (not copied), so callers should not modify it
// afterwards.
func (hcb *healthCheckBalancer) updateHosts(hosts []string) {
	hcb.mu.Lock()
	hcb.hosts = hosts
	hcb.mu.Unlock()
}

// isHealthy reports whether host may currently be selected.
//
// Until some host has passed a health check (healthyHostFound), every host is
// reported healthy so the picker always has candidates. After that, a host is
// healthy only if its latest recorded check passed and is no older than
// cfg.healthyStateTimeout. Hosts with no recorded status are unhealthy.
func (hcb *healthCheckBalancer) isHealthy(host string) bool {
	hcb.mu.RLock()
	defer hcb.mu.RUnlock()

	if hcb.healthyHostFound {
		status, known := hcb.hostStatuses[host]
		if !known || !status.isHealthy {
			return false
		}
		return time.Since(status.lastChecked) <= hcb.cfg.healthyStateTimeout
	}
	// No host has passed yet: treat all of them as healthy so the picker is
	// never left without a choice.
	return true
}

// getLatency returns the latency measured by host's latest recorded check.
//
// It returns time.Duration(math.MaxInt64) in two cases: no host has passed a
// check yet (so every host compares equal), or host has no recorded status.
// Hosts whose last check failed report whatever latency was stored for them;
// healthCheck stores that same maximum value for failures.
func (hcb *healthCheckBalancer) getLatency(host string) time.Duration {
	const unknownLatency = time.Duration(math.MaxInt64)

	hcb.mu.RLock()
	defer hcb.mu.RUnlock()

	if !hcb.healthyHostFound {
		return unknownLatency
	}
	status, known := hcb.hostStatuses[host]
	if !known {
		return unknownLatency
	}
	return status.lastLatency
}

// Close logs the shutdown, asks the polling loop to stop and then closes the
// wrapped balancer.
func (hcb *healthCheckBalancer) Close() {
	logger := hcb.logger
	logger.Debug("Closing the balancer")

	hcb.stopPolling()

	base := hcb.baseBalancer
	base.Close()
}

// replicateAddresses expands every resolved address into
// config.ConnectionsPerHost copies. Each copy's Attributes is tagged with its
// replica index (0, 1, ...) under the addrAttribute{} key. Presumably this
// lets the wrapped balancer open several connections to the same host
// (TODO confirm). The original address order is preserved, and the replicas
// of one address are contiguous.
//
// A ConnectionsPerHost below 1 is treated as 1. A value of zero would
// silently drop every address and leave nothing to connect to, and a negative
// value would make the slice allocation panic.
func (hcb *healthCheckBalancer) replicateAddresses(config *loadBalancerConfig, state balancer.ClientConnState) []resolver.Address {
	connectionsPerHost := int(config.ConnectionsPerHost)
	if connectionsPerHost < 1 {
		connectionsPerHost = 1
	}
	source := state.ResolverState.Addresses
	addresses := make([]resolver.Address, 0, connectionsPerHost*len(source))
	for _, address := range source {
		for replica := 0; replica < connectionsPerHost; replica++ {
			replicated := address
			replicated.Attributes = replicated.Attributes.WithValue(addrAttribute{}, replica)
			addresses = append(addresses, replicated)
		}
	}
	return addresses
}
