package selectrows

import (
	"context"
	"fmt"
	"sync"
	"time"

	"go.uber.org/atomic"
	"golang.org/x/xerrors"

	"a.yandex-team.ru/passport/infra/tools/ytbench/internal/config"
	"a.yandex-team.ru/yt/go/yt"
	"a.yandex-team.ru/yt/go/yt/ythttp"
)

// params holds the settings of one select-rows benchmark run.
type params struct {
	// threads is the number of concurrent worker goroutines started by run;
	// each worker creates its own YT client.
	threads  uint32
	// duration is the length of the load phase in seconds: collectData
	// reports once per second and stops after this many ticks.
	duration uint32
	// query is passed as-is to SelectRows by the connectivity check and by
	// every worker.
	query    string
}

// threadsCtx is the state shared between run, collectData and the worker
// goroutines executing runThread.
type threadsCtx struct {
	// stop is set by run once the measurement period is over; workers check
	// it between queries.
	stop      *atomic.Bool
	// wgInited reaches zero when every worker has created its YT client.
	wgInited  *sync.WaitGroup
	// wgStopped reaches zero when every worker has left its query loop.
	wgStopped *sync.WaitGroup
	// start is closed by run to release all initialized workers at once.
	start     chan bool

	// query is the SelectRows query text.
	// NOTE(review): workers read prms.query via doQuery, not this field.
	query string

	// Counters accumulated by all workers: successful queries, failed
	// queries, and total time spent per query (for both outcomes).
	succed   *atomic.Uint64
	errors   *atomic.Uint64
	duration *atomic.Duration
}

// run benchmarks SelectRows against the YT cluster described by cfg.
//
// It first performs a single query (runOnce) so that connection or query
// errors are reported before any load starts. It then starts prms.threads
// worker goroutines, each with its own client. It waits until all of them
// are initialized and releases them simultaneously. Per-second statistics
// are printed for prms.duration seconds. Finally it stops the workers and
// prints the per-second averages over the whole run.
//
// It returns the error from the initial check, or nil.
func run(cfg *config.Config, prms *params) error {
	if err := runOnce(cfg, prms); err != nil {
		return err
	}

	ctx := &threadsCtx{
		stop:      atomic.NewBool(false),
		wgInited:  &sync.WaitGroup{},
		wgStopped: &sync.WaitGroup{},
		start:     make(chan bool),
		query:     prms.query,
		succed:    atomic.NewUint64(0),
		errors:    atomic.NewUint64(0),
		duration:  atomic.NewDuration(0),
	}

	ctx.wgInited.Add(int(prms.threads))
	ctx.wgStopped.Add(int(prms.threads))
	for idx := uint32(0); idx < prms.threads; idx++ {
		go runThread(cfg, ctx, prms)
	}
	// Release all workers at once, only after every one has its client ready.
	ctx.wgInited.Wait()
	close(ctx.start)
	fmt.Println()

	collectData(ctx, prms)

	// Workers observe the flag between queries, so in-flight requests are
	// finished (and counted) before wgStopped is released.
	ctx.stop.Store(true)
	ctx.wgStopped.Wait()
	fmt.Println()

	// collectData always waits for at least one tick, so a zero duration
	// still corresponds to roughly one second of load. Clamp the divisor to
	// avoid an integer divide-by-zero panic.
	seconds := uint64(prms.duration)
	if seconds == 0 {
		seconds = 1
	}

	fmt.Printf("Avg per second: %s\n",
		formatMetrics(
			ctx.succed.Load()/seconds,
			ctx.errors.Load()/seconds,
			ctx.duration.Load()/time.Duration(seconds),
		))

	return nil
}

// collectData prints, once per second, how many queries succeeded and failed
// since the previous tick and their average duration. It returns after
// prms.duration ticks. At least one tick is always taken, because the count
// is checked only after a tick has been reported.
func collectData(ctx *threadsCtx, prms *params) {
	spent := uint32(0)
	heartbeat := time.NewTicker(1 * time.Second)
	// Release the ticker's resources once reporting is finished.
	defer heartbeat.Stop()

	prevSucced := ctx.succed.Load()
	prevErrors := ctx.errors.Load()
	prevDuration := ctx.duration.Load()

	for range heartbeat.C {
		succed := ctx.succed.Load()
		errors := ctx.errors.Load()
		dur := ctx.duration.Load()

		// Report only the delta accumulated since the previous tick.
		fmt.Printf("Per second: %s\n",
			formatMetrics(
				succed-prevSucced,
				errors-prevErrors,
				dur-prevDuration,
			))
		prevSucced = succed
		prevErrors = errors
		prevDuration = dur

		spent++
		if spent >= prms.duration {
			return
		}
	}
}

// formatMetrics renders a one-line summary of a batch of queries.
//
// The total request count is succed+errors. The average duration is dur
// divided by that total, or zero when no requests were made (this avoids a
// division by zero).
func formatMetrics(succed uint64, errors uint64, dur time.Duration) string {
	const layout = "requests: %d (succed: %d, errors: %d); avg duration: %s"

	total := succed + errors
	var avg time.Duration
	if total != 0 {
		avg = dur / time.Duration(total)
	}

	return fmt.Sprintf(layout, total, succed, errors, avg)
}

// runOnce executes the benchmark query a single time as a sanity check. On
// success it completes the line as "Checking connection... OK". On failure
// it terminates the progress line and returns the error from client creation
// or from the query unchanged.
func runOnce(cfg *config.Config, prms *params) error {
	fmt.Print("Checking connection...")

	yc, err := initClient(cfg)
	if err == nil {
		err = doQuery(yc, prms)
	}
	if err != nil {
		// End the half-printed progress line so later output starts cleanly.
		fmt.Println()
		return err
	}

	fmt.Println(" OK")
	return nil
}

// initClient creates an HTTP YT client for the proxy cfg.Cluster,
// authenticated with cfg.OAuthToken.
//
// The error from ythttp.NewClient is wrapped with %w, so callers can still
// inspect it with errors.Is or errors.As.
func initClient(cfg *config.Config) (yt.Client, error) {
	yc, err := ythttp.NewClient(&yt.Config{
		Proxy: cfg.Cluster,
		Token: cfg.OAuthToken,
	})
	if err != nil {
		return nil, xerrors.Errorf("failed to create YT client: %w", err)
	}

	return yc, nil
}

// runThread is the body of one benchmark worker goroutine.
//
// It creates its own YT client and panics if that fails. It signals
// readiness through wgInited and blocks until start is closed. It then runs
// the query repeatedly until stop is set, counting each outcome and adding
// its latency to the shared counters. wgStopped is released when the loop
// exits.
func runThread(cfg *config.Config, ctx *threadsCtx, prms *params) {
	client, err := initClient(cfg)
	if err != nil {
		panic(err)
	}
	defer ctx.wgStopped.Done()

	ctx.wgInited.Done()
	<-ctx.start

	for !ctx.stop.Load() {
		began := time.Now()
		outcome := ctx.succed
		if doQuery(client, prms) != nil {
			outcome = ctx.errors
		}
		outcome.Add(1)
		// Latency is accumulated for failed queries as well.
		ctx.duration.Add(time.Since(began))
	}
}

// doQuery runs prms.query through SelectRows and reads every returned row,
// decoding each one into a fresh map.
//
// It returns the first error encountered while starting the query, scanning
// a row, or iterating the reader. The query uses context.Background(), so
// this function applies no cancellation or timeout.
func doQuery(yc yt.Client, prms *params) error {
	options := &yt.SelectRowsOptions{}
	reader, err := yc.SelectRows(context.Background(), prms.query, options)
	if err != nil {
		return xerrors.Errorf("failed to create reader from YT: %w", err)
	}
	// The Close error is deliberately discarded (best-effort cleanup); read
	// failures are reported through Scan and reader.Err() below.
	defer func() { _ = reader.Close() }()

	for reader.Next() {
		m := map[string]interface{}{}
		if err := reader.Scan(&m); err != nil {
			return err
		}
	}

	return reader.Err()
}
