package kv

import (
	"context"
	"time"
)

type (
	// crossDcClient serves read requests by querying MainClient right away and,
	// after RetryTimeout, every client in RetryClients (presumably replicas in
	// other data centres — confirm with callers). The first successful answer
	// wins.
	crossDcClient struct {
		MainClient   Client        // queried immediately for every request
		RetryClients []Client      // each queried after RetryTimeout unless the request already finished
		RetryTimeout time.Duration // delay before the retry clients are queried
	}

	// getFunc performs a single read request against one client.
	getFunc func(ctx context.Context, client Client) (map[string]string, error)

	// getResult carries the outcome of one getFunc invocation back to the
	// goroutine collecting results in doWithRetries.
	getResult struct {
		Records map[string]string
		Error   error
	}
)

// Lookup fetches a single key via LookupMany on the underlying clients (see
// doWithRetries for the main/retry fan-out).
//
// It returns a pointer to the value when key is present in the result, a nil
// pointer with a nil error when it is absent, and a nil pointer with the
// underlying error when every attempt failed.
func (client *crossDcClient) Lookup(ctx context.Context, key string) (*string, error) {
	records, err := client.doWithRetries(ctx, func(reqCtx context.Context, target Client) (map[string]string, error) {
		return target.LookupMany(reqCtx, []string{key})
	})
	if err != nil {
		return nil, err
	}

	// One comma-ok lookup; v is a fresh local copy, so returning its address
	// does not alias the map.
	if v, ok := records[key]; ok {
		return &v, nil
	}
	return nil, nil
}

// LookupMany fetches keys by calling LookupMany on the main client and, after
// RetryTimeout, on the retry clients; the first successful result is returned.
func (client *crossDcClient) LookupMany(ctx context.Context, keys []string) (map[string]string, error) {
	lookup := func(reqCtx context.Context, target Client) (map[string]string, error) {
		return target.LookupMany(reqCtx, keys)
	}
	return client.doWithRetries(ctx, lookup)
}

// Select runs query by calling Select on the main client and, after
// RetryTimeout, on the retry clients; the first successful result is returned.
func (client *crossDcClient) Select(ctx context.Context, query string) (map[string]string, error) {
	run := func(reqCtx context.Context, target Client) (map[string]string, error) {
		return target.Select(reqCtx, query)
	}
	return client.doWithRetries(ctx, run)
}

// NewCrossDcClient returns a ReadOnlyClient that sends every request to
// mainClient immediately and, once retryTimeout has elapsed, to each of
// retryClients as well, answering with the first successful result.
//
// retryClients is stored as-is (not copied).
func NewCrossDcClient(mainClient Client, retryClients []Client, retryTimeout time.Duration) ReadOnlyClient {
	c := new(crossDcClient)
	c.MainClient = mainClient
	c.RetryClients = retryClients
	c.RetryTimeout = retryTimeout
	return c
}

// doWithRetries runs f against MainClient immediately and, after RetryTimeout,
// against every client in RetryClients concurrently. It returns the first
// successful result.
//
// If every attempt fails, the error (and records) of the last attempt to
// report is returned. If ctx is cancelled before any attempt succeeds,
// ctx.Err() is returned right away without waiting for in-flight attempts.
// Those attempts observe the cancellation through the derived context, which
// is also cancelled when this function returns.
func (client *crossDcClient) doWithRetries(ctx context.Context, f getFunc) (map[string]string, error) {
	requestCount := 1 + len(client.RetryClients)
	// Buffered for every attempt so each goroutine can deliver its result
	// without blocking, even after this function has returned; this keeps the
	// goroutines from leaking.
	c := make(chan getResult, requestCount)

	ctx, cancel := context.WithCancel(ctx)
	// Cancelling on return stops retry attempts that are still waiting and
	// signals running ones that their result is no longer needed.
	defer cancel()

	go func() {
		records, err := f(ctx, client.MainClient)
		c <- getResult{records, err}
	}()

	for _, retryClient := range client.RetryClients {
		retryClient := retryClient // per-iteration copy (pre-Go 1.22 loop semantics)
		go func() {
			// Unlike time.Sleep, this wait ends as soon as the request is
			// finished or the caller cancels, so the goroutine exits promptly.
			timer := time.NewTimer(client.RetryTimeout)
			defer timer.Stop()
			select {
			case <-ctx.Done():
			case <-timer.C:
			}
			// Re-check after the wait: if both cases were ready, select may
			// have picked the timer even though the context is already done.
			if err := ctx.Err(); err != nil {
				c <- getResult{nil, err}
				return
			}

			records, err := f(ctx, retryClient)
			c <- getResult{records, err}
		}()
	}

	var result getResult
	for received := 0; received < requestCount; received++ {
		select {
		case result = <-c:
			if result.Error == nil {
				return result.Records, nil
			}
		case <-ctx.Done():
			// Caller gave up: don't block on attempts that may not honour ctx.
			return nil, ctx.Err()
		}
	}
	return result.Records, result.Error
}
