package collecting

import (
	"fmt"
	"sort"
	"sync"

	"golang.org/x/sync/errgroup"

	"a.yandex-team.ru/travel/library/go/renderer"
)

// ParallelCollector runs a set of registered block providers concurrently
// (see Collect) and gathers the blocks they produce.
//
// providers[i] and errorHandlers[i] belong to the same registration:
// AddProvider appends to both slices together, and Collect looks up the
// handler for provider i by the same index.
//
// NOTE(review): nothing here is guarded by a mutex, so providers are
// presumably all registered before Collect is called — confirm at call sites.
type ParallelCollector struct {
	// providers produce one block each; called from separate goroutines in Collect.
	providers []func() (interface{}, error)
	// errorHandlers[i] decides what happens to a failure of providers[i].
	errorHandlers []func(error) error
}

// NewParallelCollector returns an empty collector with no providers
// registered; add them with AddProvider before calling Collect.
func NewParallelCollector() *ParallelCollector {
	collector := new(ParallelCollector)
	return collector
}

// AddProvider registers a block provider together with its error handler.
//
// During Collect, blockProviderFunc runs in its own goroutine. onError is
// called with either the provider's error or an error reporting that the
// provider returned a nil block. Its return value becomes that goroutine's
// result:
//   - a non-nil error makes Collect fail (errgroup reports the first one);
//   - nil suppresses the failure, and the block is omitted from the result.
//
// If onError is nil, the error is propagated unchanged. Without this default,
// Collect would call a nil func inside a worker goroutine, and that panic
// would crash the whole process.
//
// AddProvider is not synchronized; do not call it concurrently with itself
// or with Collect.
func (pc *ParallelCollector) AddProvider(blockProviderFunc func() (interface{}, error), onError func(error) error) {
	if onError == nil {
		onError = func(err error) error { return err }
	}
	// Append to both slices together so providers[i] keeps pairing with errorHandlers[i].
	pc.providers = append(pc.providers, blockProviderFunc)
	pc.errorHandlers = append(pc.errorHandlers, onError)
}

func (pc *ParallelCollector) Collect() ([]interface{}, error) {
	blockByIdx := sync.Map{}
	eg := errgroup.Group{}
	for i, blockProvider := range pc.providers {
		blockIdx := i
		provider := blockProvider
		eg.Go(
			func() error {
				if block, err := provider(); err == nil {
					if block == nil {
						err = fmt.Errorf("block for %+v is nil", blockIdx)
						return pc.errorHandlers[blockIdx](err)
					}
					blockByIdx.Store(blockIdx, block)
				} else {
					return pc.errorHandlers[blockIdx](err)
				}
				return nil
			},
		)
	}
	err := eg.Wait()
	if err != nil {
		return nil, err
	}

	// As blocks were collected in parallel, they must be ordered in accordance with the order of adding providers
	return pc.getSortedBlocks(&blockByIdx), nil
}

func (pc *ParallelCollector) getSortedBlocks(blockByIdx *sync.Map) []interface{} {
	blocks := make([]interface{}, 0, len(pc.providers))
	indices := make([]int, 0, len(pc.providers))
	blockByIdx.Range(
		func(key, value interface{}) bool {
			blocks = append(blocks, value.(renderer.Block))
			indices = append(indices, key.(int))
			return true
		},
	)
	sort.SliceStable(
		indices,
		func(i, j int) bool {
			if indices[i] < indices[j] {
				blocks[i], blocks[j] = blocks[j], blocks[i]
				return true
			}
			return false
		},
	)
	return blocks
}
