package status

import (
	"context"
	"sort"
	"time"

	"github.com/cenkalti/backoff/v4"
	"github.com/jackc/pgx/v4"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/core/xerrors"
	"a.yandex-team.ru/travel/avia/shared_flights/lib/go/direction"
	"a.yandex-team.ru/travel/avia/shared_flights/lib/go/logger"
	"a.yandex-team.ru/travel/avia/shared_flights/status_importer/internal/objects"
	"a.yandex-team.ru/travel/avia/shared_flights/status_importer/pkg/logging/yt/updatestatuslog"
	"a.yandex-team.ru/travel/library/go/errutil"
)

// BatchSender is the only database dependency of tStatusWriter: it sends a pgx
// batch and returns a handle for reading the per-statement results.
// NOTE(review): presumably satisfied by the pgx v4 connection/pool types;
// confirm at the construction site.
type BatchSender interface {
	// SendBatch sends b and returns its results; executeBatch reads one result
	// per queued statement and then closes the returned handle.
	SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults
}

// tStatusWriter writes statuses to the database in pgx batches, retrying the
// statements that fail (see write / executeBatch).
type tStatusWriter struct {
	// batchSender executes the queued batches.
	batchSender BatchSender
	// objects is not referenced by the methods visible in this section.
	objects *objects.Objects
	// minibatchTimeout bounds each individual batch execution attempt.
	minibatchTimeout time.Duration
}

func (sw tStatusWriter) Write(ctx context.Context, ss []tStatusWithArguments, updateResults UpdateResults) {
	// Statuses are written in consecutive groups of at most groupSize entries;
	// each group is handed to write, and progress is logged at debug level.
	const groupSize = 256
	total := len(ss)
	for start := 0; start < total; start += groupSize {
		stop := start + groupSize
		if stop > total {
			stop = total
		}
		sw.write(ctx, ss[start:stop], updateResults)
		logger.Logger().Debugf("Done %d/%d", stop, total)
	}
}

// write stores one group of statuses, retrying the ones that fail.
//
// Each attempt queues every still-pending status into a single pgx batch and
// executes it under sw.minibatchTimeout. Statuses whose statement executed are
// removed from the pending set by executeBatch; the rest are retried after an
// exponential back-off delay, for at most maxAttempts attempts in total. The
// retry loop stops early when ctx is done or the back-off policy gives up.
// Every status ends up reported into updateResults as Success,
// IgnoredByUpdatedAt or Error.
func (sw tStatusWriter) write(ctx context.Context, ss []tStatusWithArguments, updateResults UpdateResults) {
	const maxAttempts = 16

	// Pending statuses keyed by their index in ss; executeBatch deletes the ones
	// it managed to execute, so whatever remains after the loop has failed.
	statuses := make(map[int]*tStatusWithArguments, len(ss))
	for i := range ss {
		statuses[i] = &ss[i]
	}

	written := make(Statuses, 0, len(ss))
	ignoredByUpdated := make(Statuses, 0, len(ss))
	exponentialBackOff := backoff.NewExponentialBackOff()
	for attempt := 0; len(statuses) > 0 && attempt < maxAttempts; attempt++ {
		if attempt > 0 {
			logger.Logger().Warn("Failed to write some statuses. Sleeping and retrying", log.Int("attempt", attempt))
			// Actually wait out the back-off interval so retries do not hammer a
			// struggling database back-to-back; bail out on cancellation.
			if err := waitRetryDelay(ctx, exponentialBackOff); err != nil {
				logger.Logger().Error("Giving up on retrying statuses", log.Int("attempt", attempt), log.Error(err))
				break
			}
		}

		batch := &pgx.Batch{}
		queuedStatuses := sw.queueStatuses(batch, statuses)
		if len(queuedStatuses) == 0 {
			continue
		}
		if len(queuedStatuses) != batch.Len() {
			sw.logStatusCountMismatchError(queuedStatuses, batch)
			continue
		}

		timeoutCtx, cancel := context.WithTimeout(ctx, sw.minibatchTimeout)
		if err := sw.executeBatch(timeoutCtx, batch, queuedStatuses, statuses, &written, &ignoredByUpdated); err != nil {
			errorOrWarn(maxAttempts-attempt <= 1, "Execution error", log.Int("attempt", attempt), log.Error(err))
		}
		cancel()
	}

	// Report the leftovers in input order rather than random map order.
	errorStatuses := make(Statuses, 0, len(statuses))
	for _, i := range orderedKeys(statuses) {
		errorStatuses = append(errorStatuses, statuses[i].status)
	}

	updateResults.appendStatuses(updatestatuslog.Success, written)
	updateResults.appendStatuses(updatestatuslog.IgnoredByUpdatedAt, ignoredByUpdated)
	updateResults.appendStatuses(updatestatuslog.Error, errorStatuses)
}

// waitRetryDelay sleeps for the next interval produced by b. It returns a
// non-nil error, without waiting, when b signals backoff.Stop, and returns
// early with an error when ctx is done before the interval elapses.
func waitRetryDelay(ctx context.Context, b backoff.BackOff) error {
	delay := b.NextBackOff()
	if delay == backoff.Stop {
		return xerrors.Errorf("retry back-off stopped")
	}
	timer := time.NewTimer(delay)
	defer timer.Stop()
	select {
	case <-ctx.Done():
		return xerrors.Errorf("waiting to retry: %w", ctx.Err())
	case <-timer.C:
		return nil
	}
}

// executeBatch sends batch and reads exactly one result per entry of
// queuedStatuses, in queue order.
//
// For every statement that executed without error, the matching status is
// removed from statuses and appended to written (rows affected > 0) or to
// ignoredByUpdated (no rows affected). Statuses whose statement failed stay in
// statuses so the caller can retry them. All exec errors and the batch close
// error are combined into the returned error, wrapped with "executeBatch".
//
// NOTE(review): if the batch runs as one implicit transaction on the server, a
// failing statement may roll back earlier statements that reported success
// here — confirm against the pgx v4 batch semantics.
func (sw tStatusWriter) executeBatch(
	ctx context.Context,
	batch *pgx.Batch,
	queuedStatuses []int,
	statuses map[int]*tStatusWithArguments,
	written *Statuses,
	ignoredByUpdated *Statuses,
) (err error) {
	defer errutil.Wrap(&err, "executeBatch")
	br := sw.batchSender.SendBatch(ctx, batch)
	defer func() {
		if brErr := br.Close(); brErr != nil {
			err = errOrGroup(err, xerrors.Errorf("batch close: %w", brErr))
		}
	}()

	for _, i := range queuedStatuses {
		// Consume the result before anything else: results arrive in queue order,
		// so skipping an Exec would attribute every later result to the wrong status.
		tag, execErr := br.Exec()
		if execErr != nil {
			err = errOrGroup(err, xerrors.Errorf("exec: %w", execErr))
			continue
		}
		status, exists := statuses[i]
		if !exists || status == nil {
			logger.Logger().Error("Programming error: got empty status from map")
			continue
		}
		if tag.RowsAffected() > 0 {
			*written = append(*written, status.status)
		} else {
			*ignoredByUpdated = append(*ignoredByUpdated, status.status)
		}
		delete(statuses, i)
	}
	return err
}

func (sw tStatusWriter) queueStatuses(batch *pgx.Batch, statuses map[int]*tStatusWithArguments) (queuedStatuses []int) {
	// Queue statements in ascending key order and return the keys in that same
	// order, so results can later be matched back to their statuses.
	keys := orderedKeys(statuses)
	for idx := 0; idx < len(keys); idx++ {
		key := keys[idx]
		sw.queueStatus(batch, statuses[key])
		queuedStatuses = append(queuedStatuses, key)
	}
	return queuedStatuses
}

func (sw tStatusWriter) queueStatus(batch *pgx.Batch, s *tStatusWithArguments) {
	// The direction string has already been validated, so the parse error is
	// deliberately discarded here.
	dir, _ := direction.FromString(s.status.Direction)
	_, statement := directionStatement(dir)
	batch.Queue(statement, s.arguments...)
}

// orderedKeys returns the keys of m sorted in ascending order.
func orderedKeys(m map[int]*tStatusWithArguments) []int {
	keys := make([]int, len(m))
	pos := 0
	for k := range m {
		keys[pos] = k
		pos++
	}
	sort.IntSlice(keys).Sort()
	return keys
}

// errorOrWarn logs msg with fields at Error level when errNotWarn is true and
// at Warn level otherwise.
func errorOrWarn(errNotWarn bool, msg string, fields ...log.Field) {
	// Presumably skips this helper's frame so the reported caller is the call site.
	l := logger.Logger().AddCallerSkip(1)
	if !errNotWarn {
		l.Warn(msg, fields...)
		return
	}
	l.Error(msg, fields...)
}

// logStatusCountMismatchError reports, at Fatal level, that the number of
// statements in batch differs from the number of statuses we queued.
// NOTE(review): Fatal usually terminates the process — confirm for this logger.
func (sw tStatusWriter) logStatusCountMismatchError(queuedStatuses []int, batch *pgx.Batch) {
	fields := []log.Field{
		log.Int("batch.Len()", batch.Len()),
		log.Int("ours", len(queuedStatuses)),
	}
	logger.Logger().Fatal("Programming error: incorrect queued status count", fields...)
}
