package grpcresolver

import (
	"context"
	"fmt"
	"net"
	"sort"
	"sync"
	"time"

	"github.com/cenkalti/backoff/v4"
	"google.golang.org/grpc/attributes"
	"google.golang.org/grpc/resolver"

	ypResolver "a.yandex-team.ru/infra/yp_service_discovery/golang/resolver"
)

// Compile-time assertion that *Resolver implements gRPC's resolver.Resolver.
var _ resolver.Resolver = (*Resolver)(nil)

// Resolver is a gRPC resolver.Resolver that resolves a YP endpoint set through
// a ypResolver.Resolver and feeds the resulting addresses to a gRPC ClientConn.
type Resolver struct {
	// t holds the resolution settings (endpoint set, clusters, intervals, ...).
	t target
	// ctx bounds the lifetime of the watcher goroutine and of every lookup;
	// cancel is invoked by Close (presumably ctx's own CancelFunc — verify at construction).
	ctx    context.Context
	cancel context.CancelFunc
	// cc receives resolution results and errors.
	cc resolver.ClientConn
	// resolver performs the actual YP endpoint-set lookups.
	resolver ypResolver.Resolver

	// rn is signalled by ResolveNow to request an immediate re-resolution.
	rn chan struct{}

	// wg lets Close block until the watcher goroutine has returned.
	wg sync.WaitGroup
}

// ypResponse pairs a successful YP resolution with the cluster it was obtained from.
type ypResponse struct {
	cluster string
	yp      *ypResolver.ResolveEndpointsResponse
}

// ResolveNow requests an immediate re-resolution from the watcher goroutine.
// The options are ignored. The signal is sent without blocking: if r.rn cannot
// accept it right away, the request is simply dropped.
func (r *Resolver) ResolveNow(resolver.ResolveNowOptions) {
	var signal struct{}
	select {
	case r.rn <- signal:
		// Request handed to (or queued for) the watcher.
	default:
		// r.rn is not ready to accept; drop the request rather than block the caller.
	}
}

// Close stops the watcher goroutine and blocks until it has returned.
//
// r.rn is deliberately NOT closed. ResolveNow is the sender on that channel,
// and closing it here would make a ResolveNow that races with (or follows)
// Close panic with "send on closed channel". A second Close would also panic
// with "close of closed channel". Once the watcher has exited nothing receives
// from r.rn any more, so the channel is left to the garbage collector.
// ResolveNow's select/default send never blocks on it.
//
// Close is safe to call more than once: cancelling a context is idempotent and
// wg.Wait returns immediately once the watcher is gone.
func (r *Resolver) Close() {
	r.cancel()
	r.wg.Wait()
}

// watcher is the resolution loop. It runs until r.ctx is cancelled and calls
// r.wg.Done on exit so that Close can wait for it.
//
// A resolution is triggered:
//   - immediately on start;
//   - r.t.freq after a successful update;
//   - after an exponential backoff (base interval between backoffMinInterval
//     and r.t.resRate) after a failed resolution or a ClientConn error;
//   - whenever ResolveNow signals r.rn.
//
// After every attempt the loop sleeps r.t.resRate, so resolutions never run
// more often than once per r.t.resRate.
//
// NOTE(review): the sleep is as long as the backoff's MaxInterval, so the
// effective retry delay is roughly max(r.t.resRate, backoff). Confirm this is
// the intended policy.
func (r *Resolver) watcher() {
	defer r.wg.Done()

	expBackoff := backoff.NewExponentialBackOff()
	// Make it unstoppable: with MaxElapsedTime == 0, NextBackOff never returns
	// backoff.Stop, so the value passed to Reset below is always a real delay.
	expBackoff.MaxElapsedTime = 0
	expBackoff.InitialInterval = backoffMinInterval
	expBackoff.MaxInterval = r.t.resRate
	// Re-apply InitialInterval; NewExponentialBackOff already reset itself with its defaults.
	expBackoff.Reset()

	// A zero-duration timer fires right away, so the first resolution is immediate.
	resolveTimer := time.NewTimer(0)
	defer resolveTimer.Stop()

	for {
		select {
		case <-r.ctx.Done():
			return
		case <-resolveTimer.C:
		case <-r.rn:
			// Stop the pending timer before it is Reset below. If it has already
			// fired, its channel may still hold the tick; drain it so the next
			// iteration does not wake up on a stale value.
			if !resolveTimer.Stop() {
				<-resolveTimer.C
			}
		}

		state, incomplete := r.lookup()

		// Close may have cancelled r.ctx while lookup was running. In that case
		// the result is most likely a spurious failure caused by the cancelled
		// context, and it must not be pushed to the ClientConn during shutdown.
		if r.ctx.Err() != nil {
			return
		}

		var err error
		if len(state.Addresses) == 0 {
			logger.Warningf("empty addresses list for endpointSet set: %s", r.t.endpointSet)
			switch {
			case incomplete:
				err = fmt.Errorf("unable to resolve endpoint set %s on all of the clusters: %s", r.t.endpointSet, r.t.clusters)
			case !r.t.allowEmpty:
				err = fmt.Errorf("no alive endpoints on endpointSet set %s was resolved on all of the clusters: %s", r.t.endpointSet, r.t.clusters)
			}
			// Otherwise an empty address list is explicitly allowed and is pushed as-is.
		}

		if err != nil {
			// Report the resolution error to the underlying grpc.ClientConn.
			r.cc.ReportError(err)
		} else {
			err = r.cc.UpdateState(state)
		}

		if err != nil {
			// Resolution error, or the ClientConn rejected the state: retry after backoff.
			resolveTimer.Reset(expBackoff.NextBackOff())
		} else {
			// Success: wait for the regular refresh period or a ResolveNow.
			expBackoff.Reset()
			resolveTimer.Reset(r.t.freq)
		}

		// Sleep to prevent excessive re-resolutions. ResolveNow requests arriving
		// meanwhile stay in r.rn (presumably buffered; verify where r.rn is created).
		sleep := time.NewTimer(r.t.resRate)
		select {
		case <-sleep.C:
		case <-r.ctx.Done():
			sleep.Stop()
			return
		}
	}
}

// lookup resolves r.t.endpointSet in each cluster of r.t.clusters, one cluster
// after another. Each request is bounded by r.t.timeout and derived from r.ctx.
// It then converts every ready endpoint into a gRPC address.
//
// The returned bool is true when at least one cluster could not be resolved,
// either because of a request error or a non-OK resolve status. Such clusters
// are logged and skipped. Endpoints that are not ready, or cannot be converted
// into an address, are logged and skipped too.
func (r *Resolver) lookup() (resolver.State, bool) {
	// fetch queries a single cluster; the bool is false when the cluster must be skipped.
	fetch := func(cluster string) (*ypResolver.ResolveEndpointsResponse, bool) {
		ctx, cancel := context.WithTimeout(r.ctx, r.t.timeout)
		defer cancel()

		resp, err := r.resolver.ResolveEndpoints(ctx, cluster, r.t.endpointSet)
		if err != nil {
			logger.Errorf("failed to resolve %s@%s endpointSet set: %v", r.t.endpointSet, cluster, err)
			return nil, false
		}
		if resp.ResolveStatus != ypResolver.StatusEndpointOK {
			logger.Warningf("failed to resolve %s@%s endpointSet set: status not OK: %d", r.t.endpointSet, cluster, resp.ResolveStatus)
			return nil, false
		}
		return resp, true
	}

	// Phase 1: query every cluster and keep only the successful answers.
	incomplete := false
	resolved := make([]ypResponse, 0, len(r.t.clusters))
	for _, cluster := range r.t.clusters {
		resp, ok := fetch(cluster)
		if !ok {
			incomplete = true
			continue
		}
		resolved = append(resolved, ypResponse{cluster: cluster, yp: resp})
	}

	// Phase 2: turn the ready endpoints of every answer into addresses.
	var addrs []resolver.Address
	for _, item := range resolved {
		for _, ep := range item.yp.EndpointSet.Endpoints {
			if !ep.Ready {
				logger.Infof("skip not ready endpointSet: %s", ep.ID)
				continue
			}

			addr, err := r.endpointToAddress(item.cluster, ep)
			if err != nil {
				logger.Errorf("skip invalid endpointSet %s: %v", ep.ID, err)
				continue
			}

			addrs = append(addrs, addr)
		}
	}

	return resolver.State{Addresses: addrs}, incomplete
}

// endpointToAddress converts a YP endpoint from the given cluster into a gRPC
// address.
//
// The IPv6 address is preferred. The IPv4 address is used only when IPv6 is
// nil or unspecified. ServerName is taken from r.t.serverName. The cluster,
// the endpoint ID and the endpoint's labels (sorted) are attached as address
// attributes.
//
// The endpoint itself is not modified: labels are sorted on a private copy.
// e points into a resolver response that may be shared with other readers
// (e.g. a response cache inside the YP resolver; TODO confirm), and sorting it
// in place would be a hidden mutation and a potential data race.
//
// An error is returned when the endpoint has neither a usable IPv6 nor IPv4
// address.
func (r *Resolver) endpointToAddress(cluster string, e *ypResolver.Endpoint) (resolver.Address, error) {
	var host string
	switch {
	case e.IPv6 != nil && !e.IPv6.Equal(net.IPv6unspecified):
		host = e.IPv6.String()
	case e.IPv4 != nil && !e.IPv4.Equal(net.IPv4zero):
		host = e.IPv4.String()
	default:
		return resolver.Address{}, fmt.Errorf("endpointSet %s@%s doesn't have an IPv6 or IPv4 address", e.ID, cluster)
	}
	// JoinHostPort brackets the host only when it contains a colon (a real IPv6
	// literal), which yields the canonical "host:port" / "[host]:port" form.
	addr := net.JoinHostPort(host, fmt.Sprintf("%d", e.Port))

	// Sort a copy so the caller's endpoint is left untouched. Slicing with
	// [:0:0] preserves the nil-ness of the original slice.
	labels := append(e.Labels[:0:0], e.Labels...)
	sort.Strings(labels)

	return resolver.Address{
		Addr:       addr,
		ServerName: r.t.serverName,
		Attributes: attributes.New(attrClusterKey, cluster).
			WithValue(attrIDKey, e.ID).
			WithValue(attrLabelsKey, labelsSlice(labels)),
	}, nil
}
