package ypclient

import (
	"a.yandex-team.ru/infra/allocation-ctl/pkg/metricsutil"
	"a.yandex-team.ru/infra/allocation-ctl/pkg/yp/api"
	"a.yandex-team.ru/infra/allocation-ctl/pkg/yp/rpc"
	"a.yandex-team.ru/yp/go/proto/ypapi"
	"a.yandex-team.ru/yp/go/yp"
	"context"
	"fmt"
	"golang.org/x/sync/errgroup"
	"time"
)

// PodsGetter is implemented by clients that can hand out a PodInterface.
type PodsGetter interface {
	// Pods returns the PodInterface used to list and update pod objects.
	Pods() PodInterface
}

// PodInterface lists pods and writes back their ISS specs.
type PodInterface interface {
	// List returns the pods selected according to opts.
	List(ctx context.Context, opts api.ListOptions) ([]*api.Pod, error)
	// UpdateISSSpecList writes the ISS spec of every pod in pods back to YP.
	// withCAS makes each update conditional on the pod's SpecTimestamp
	// (as done by podsClient).
	UpdateISSSpecList(ctx context.Context, pods []*api.Pod, withCAS bool) error
}

// podsClient implements PodInterface on top of the YP RPC client.
//
// client issues the SelectPods, GetPods and UpdateObjects calls; metrics is
// optional (may be nil) and, when set, records List duration and result size.
type podsClient struct {
	client  rpc.RPCInterface
	metrics *podsMetrics
}

// newPods returns a podsClient wired to the RPC client and pod metrics of c.
func newPods(c *YPClient) *podsClient {
	pc := new(podsClient)
	pc.client = c.rpcClient
	pc.metrics = c.podsMetrics
	return pc
}

// extractPodID decodes the pod meta of the current row of rsp and returns the
// pod's ID, or the decoding error.
func extractPodID(rsp rpc.SelectObjectsResponseInterface) (string, error) {
	meta := new(ypapi.TPodMeta)
	if err := rsp.Fill(meta); err != nil {
		return "", err
	}
	return meta.GetId(), nil
}

// extractPod decodes the current row of rsp into a newly allocated api.Pod.
//
// The row is filled into fresh Meta, Spec and Status messages, in that order.
// When opts.FetchTimestamps is set, the row's timestamps are read as well.
// Exactly one timestamp per entry of opts.Selectors is required, and the one
// at opts.SpecIdx is stored as the pod's SpecTimestamp.
//
// An error is returned if any of the following happens:
//   - decoding fails;
//   - the timestamps cannot be read;
//   - their count differs from the number of selectors;
//   - opts.SpecIdx does not address one of them.
func extractPod(rsp rpc.GetObjectsResponseInterface, opts api.ListOptions) (*api.Pod, error) {
	p := &api.Pod{
		TPod: ypapi.TPod{
			Meta:   &ypapi.TPodMeta{},
			Spec:   &ypapi.TPodSpec{},
			Status: &ypapi.TPodStatus{},
		},
	}
	if err := rsp.Fill(p.Meta, p.Spec, p.Status); err != nil {
		return nil, err
	}
	if !opts.FetchTimestamps {
		return p, nil
	}
	ts, err := rsp.Timestamps()
	if err != nil {
		return nil, fmt.Errorf("failed fetching timestamps: %w", err)
	}
	// The code assumes one timestamp per requested selector, in selector order.
	if len(ts) != len(opts.Selectors) {
		return nil, fmt.Errorf("timestamps count mismatched: got %d != %d", len(ts), len(opts.Selectors))
	}
	// Check the index explicitly. An out-of-range SpecIdx would otherwise panic
	// on the slice access below instead of surfacing as an error.
	specIdx := int(opts.SpecIdx)
	if specIdx < 0 || specIdx >= len(ts) {
		return nil, fmt.Errorf("spec selector index %d out of range [0, %d)", specIdx, len(ts))
	}
	p.SpecTimestamp = ts[specIdx]
	return p, nil
}

// getPods issues a single GetPods request for req and decodes every returned
// row with extractPod.
// The first RPC or decoding error aborts the call and is returned with a nil
// slice.
func (c *podsClient) getPods(ctx context.Context, req *yp.GetPodsRequest, opts api.ListOptions) ([]*api.Pod, error) {
	resp, rpcErr := c.client.GetPods(ctx, *req)
	if rpcErr != nil {
		return nil, rpcErr
	}

	pods := make([]*api.Pod, 0, len(req.IDs))
	for resp.Next() {
		pod, decodeErr := extractPod(resp, opts)
		if decodeErr != nil {
			return nil, decodeErr
		}
		pods = append(pods, pod)
	}
	return pods, nil
}

// produceGetPodsTasks selects the IDs of pods matching opts.Filter page by
// page and pushes them to tasksCh in batches of opts.GetObjectsBatchSize IDs.
//
// Each SelectPods request asks for up to opts.SelectIDsBatchSize IDs and
// follows the continuation token of the previous page. IDs are accumulated
// across page boundaries, so every batch except possibly the last is full.
//
// It returns:
//   - nil once a page comes back short or empty;
//   - the first RPC or decoding error;
//   - ctx.Err() if ctx is cancelled while a batch waits to be sent.
//
// It does not close tasksCh.
func (c *podsClient) produceGetPodsTasks(ctx context.Context, tasksCh chan<- []string, opts api.ListOptions) error {
	batchSize := opts.GetObjectsBatchSize
	if batchSize < 1 {
		// A negative capacity would make make() panic; zero already behaved
		// like single-ID batches, so clamp to that.
		batchSize = 1
	}

	// send hands a batch to the workers unless ctx is cancelled first.
	send := func(batch []string) error {
		select {
		case tasksCh <- batch:
			return nil
		case <-ctx.Done():
			return ctx.Err()
		}
	}

	req := yp.SelectPodsRequest{
		Format:    yp.PayloadFormatProto,
		Filter:    opts.Filter,
		Selectors: api.DefaultSelectIDsSelectors,
		Limit:     opts.SelectIDsBatchSize,
	}
	toSend := make([]string, 0, batchSize)
	for {
		rsp, err := c.client.SelectPods(ctx, req)
		if err != nil {
			return err
		}
		for rsp.Next() {
			podID, err := extractPodID(rsp)
			if err != nil {
				return err
			}
			toSend = append(toSend, podID)
			if int32(len(toSend)) >= batchSize {
				if err := send(toSend); err != nil {
					return err
				}
				// The sent slice now belongs to a worker; start a new one
				// rather than reusing its backing array.
				toSend = make([]string, 0, batchSize)
			}
		}
		// A short page marks the end of the selection. An empty page also ends
		// it, which prevents an endless loop when SelectIDsBatchSize <= 0.
		if rsp.Count() == 0 || int32(rsp.Count()) < opts.SelectIDsBatchSize {
			break
		}
		req.ContinuationToken = rsp.ContinuationToken()
	}
	if len(toSend) > 0 {
		return send(toSend)
	}
	return nil
}

// processGetPodsTasks is a worker loop. It takes batches of pod IDs from
// tasksCh and fetches those pods from YP with opts.Selectors (plus timestamps
// when opts.FetchTimestamps is set). Each fetched batch is sent to resultsCh.
//
// It returns:
//   - nil once tasksCh is closed;
//   - the first fetch error;
//   - ctx.Err() when ctx is cancelled while waiting for a task or for room in
//     resultsCh.
func (c *podsClient) processGetPodsTasks(ctx context.Context, tasksCh <-chan []string, resultsCh chan<- []*api.Pod, opts api.ListOptions) error {
	req := &yp.GetPodsRequest{
		Selectors:       opts.Selectors,
		FetchTimestamps: opts.FetchTimestamps,
		Format:          yp.PayloadFormatProto,
	}
	for {
		// Stop waiting as soon as ctx is cancelled (e.g. a sibling worker
		// failed). Otherwise we would idle until the producer sends another
		// batch or closes the channel.
		select {
		case <-ctx.Done():
			return ctx.Err()
		case ids, ok := <-tasksCh:
			if !ok {
				return nil
			}
			req.IDs = ids
		}

		pods, err := c.getPods(ctx, req, opts)
		if err != nil {
			return err
		}

		select {
		case <-ctx.Done():
			return ctx.Err()
		case resultsCh <- pods:
		}
	}
}

// List returns every pod selected by opts.
//
// A producer goroutine selects matching pod IDs page by page and hands them
// out in batches over tasksCh. opts.GetObjectsThreadsCount worker goroutines
// (at least one) fetch each batch and forward the pods to resultsCh, which the
// calling goroutine drains.
//
// List waits for all goroutines to finish. The first error cancels the shared
// context and is returned together with a nil slice. When metrics are
// configured, the call duration and the number of returned pods are recorded.
func (c *podsClient) List(ctx context.Context, opts api.ListOptions) ([]*api.Pod, error) {
	if c.metrics != nil {
		defer metricsutil.MeasureSecondsSince(c.metrics.listTime, time.Now())
	}

	// Without at least one worker nothing drains tasksCh. resultsCh would close
	// immediately, and List would return a silently empty result. Once the
	// producer filled the channel buffer, g.Wait would deadlock on it.
	workers := opts.GetObjectsThreadsCount
	if workers < 1 {
		workers = 1
	}

	g, ctx := errgroup.WithContext(ctx)

	tasksCh := make(chan []string, 100)
	resultsCh := make(chan []*api.Pod, 100)

	// Producer: closing tasksCh when it returns lets the workers' loops end.
	g.Go(func() error {
		defer close(tasksCh)
		return c.produceGetPodsTasks(ctx, tasksCh, opts)
	})

	// Workers: resultsCh is closed only after every worker has returned, which
	// terminates the drain loop below.
	g.Go(func() error {
		defer close(resultsCh)
		workersGroup, workersCtx := errgroup.WithContext(ctx)
		for i := 0; i < workers; i++ {
			workersGroup.Go(func() error {
				return c.processGetPodsTasks(workersCtx, tasksCh, resultsCh, opts)
			})
		}
		return workersGroup.Wait()
	})

	rv := make([]*api.Pod, 0)
	for pods := range resultsCh {
		rv = append(rv, pods...)
	}
	if err := g.Wait(); err != nil {
		return nil, err
	}
	if c.metrics != nil {
		c.metrics.count.Update(float64(len(rv)))
	}
	return rv, nil
}

// UpdateISSSpecList sends one UpdateObjects request. The request sets
// /spec/iss of every pod in pods to that pod's Spec.Iss.
//
// With withCAS, each update also carries a /spec timestamp prerequisite equal
// to the pod's SpecTimestamp. This is intended as a compare-and-swap guard
// against concurrent spec changes.
func (c *podsClient) UpdateISSSpecList(ctx context.Context, pods []*api.Pod, withCAS bool) error {
	updates := make([]yp.UpdateObject, len(pods))
	for i, pod := range pods {
		update := &updates[i]
		update.ObjectType = yp.ObjectTypePod
		update.ObjectID = pod.GetMeta().GetId()
		update.SetUpdates = []yp.SetObjectUpdate{{
			Path:   "/spec/iss",
			Object: pod.GetSpec().GetIss(),
		}}
		if withCAS {
			update.AttributeTimestampPrerequisites = []yp.PrerequisiteObjectUpdate{{
				Path:      "/spec",
				Timestamp: pod.SpecTimestamp,
			}}
		}
	}
	return c.client.UpdateObjects(ctx, yp.UpdateObjectsRequest{Objects: updates})
}
