package searchresultscache

import (
	"context"
	"errors"
	"fmt"
	"time"

	"github.com/opentracing/opentracing-go"

	ydbLib "a.yandex-team.ru/kikimr/public/sdk/go/ydb"
	"a.yandex-team.ru/kikimr/public/sdk/go/ydb/table"
	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/travel/avia/library/go/searchcontext"
)

// SearchResultsCache looks up cached search results in a YDB table.
// Instances are built by NewSearchResultsCache, which sets every field.
//
// Fields:
//   - logger: receives request/scan timing logs and session warm-up errors.
//   - sessionPool: source of YDB sessions; warmed up by fillSessionPool.
//   - transactionControl: transaction settings used to execute every lookup
//     (NewSearchResultsCache configures a stale read-only transaction that
//     is committed).
//   - allVariantsByQIDQuery: the lookup query text with the table name
//     already substituted.
//   - preparedSessionsCount: how many sessions fillSessionPool prepares.
//   - retryer: runs each lookup's prepare+execute operation.
type SearchResultsCache struct {
	logger                log.Logger
	sessionPool           *table.SessionPool
	transactionControl    *table.TransactionControl
	allVariantsByQIDQuery string
	preparedSessionsCount int
	retryer               *table.Retryer
}

// Config describes where the cached search results live in YDB.
//
// The `config` tags name the keys the configuration loader reads (they look
// like environment variable names — confirm against the loader), and the
// `yaml` tags name the corresponding YAML keys. Within this file only Table
// (substituted into the lookup query template) and PreparedSessionsCount
// (how many sessions NewSearchResultsCache warms up) are used. Token, Cluster
// and Database are presumably consumed wherever the YDB connection and
// session pool are built.
type Config struct {
	Token    string `config:"YDB_TOKEN,required" yaml:"token"`
	Cluster  string `config:"SEARCH_RESULTS_CACHE_YDB_CLUSTER,required" yaml:"cluster"`
	Database string `config:"SEARCH_RESULTS_CACHE_YDB_DATABASE,required" yaml:"database"`
	Table    string `config:"SEARCH_RESULTS_CACHE_YDB_TABLE,required" yaml:"table"`

	PreparedSessionsCount int `config:"SEARCH_RESULTS_CACHE_YDB_PREPARED_SESSIONS_COUNT,required" yaml:"prepared_sessions_count"`
}

// DefaultConfig points at the prestable YDB cluster and the testing
// search_results database, and warms up 10 sessions. Token is left empty,
// so a token still has to be supplied separately (its `config` tag marks it
// as required).
var DefaultConfig = Config{
	PreparedSessionsCount: 10,
	Table:                 "results",
	Database:              "/ru-prestable/ticket/testing/search_results",
	Cluster:               "ydb-ru-prestable.yandex.net:2135",
}

// NewSearchResultsCache builds a cache that reads from config.Table using
// sessions taken from sessionPool.
//
// Lookups run in a stale read-only transaction that is committed, and go
// through a table.Retryer with MaxRetries set to 3 and ydb's default backoff.
// Before returning, the constructor warms up config.PreparedSessionsCount
// sessions by preparing the lookup query on them. If that warm-up fails, its
// error is returned unchanged and no cache is produced.
func NewSearchResultsCache(logger log.Logger, sessionPool *table.SessionPool, config Config) (*SearchResultsCache, error) {
	staleReadOnlyTx := table.TxControl(table.BeginTx(table.WithStaleReadOnly()), table.CommitTx())

	cache := &SearchResultsCache{
		logger:                logger,
		sessionPool:           sessionPool,
		transactionControl:    staleReadOnlyTx,
		allVariantsByQIDQuery: fmt.Sprintf(allVariantsByQIDQuery, config.Table),
		preparedSessionsCount: config.PreparedSessionsCount,
		retryer: &table.Retryer{
			MaxRetries:      3,
			Backoff:         ydbLib.DefaultBackoff,
			SessionProvider: sessionPool,
			RetryChecker:    ydbLib.RetryChecker{RetryNotFound: false},
		},
	}

	if err := cache.fillSessionPool(); err != nil {
		return nil, err
	}
	return cache, nil
}

const (
	// allVariantsByQIDQuery is the YQL lookup template for one search key.
	// It returns every cached row whose key columns all match the declared
	// parameters and whose expires_at is later than $unixtime, i.e. rows that
	// have not expired yet. %[1]s is the table name; NewSearchResultsCache
	// fills it in from Config.Table via fmt.Sprintf.
	allVariantsByQIDQuery = `
		DECLARE $point_from AS Utf8;
		DECLARE $point_to AS Utf8;
		DECLARE $date_forward AS Uint32;
		DECLARE $date_backward AS Uint32;
		DECLARE $klass AS Uint8;
		DECLARE $passengers AS Uint32;
		DECLARE $national_version AS Utf8;
		DECLARE $lang AS Utf8;
		DECLARE $unixtime AS Uint32;

		SELECT
			point_from,
			point_to,
			date_forward,
			date_backward,
			klass,
			passengers,
			national_version,
			lang,
			partner_code,
			meta,
			variants,
			created_at,
			expires_at
		FROM
			%[1]s
		WHERE
			point_from = $point_from AND
			point_to = $point_to AND
			date_forward = $date_forward AND
			date_backward = $date_backward AND
			klass = $klass AND
			passengers = $passengers AND
			national_version = $national_version AND
			lang = $lang AND
			expires_at > $unixtime;
	`
)

// errFillSessionPool is returned by fillSessionPool once the shared budget of
// session-creation and statement-preparation attempts is exhausted. It is a
// sentinel, so callers can match it with errors.Is.
var errFillSessionPool = errors.New("couldn't fill YDB sessions pool")

// GetSearchResultsByQueryKey fetches the cached search results for queryKey
// whose expires_at is later than the current Unix time.
//
// The key components are converted to the column encodings used by the table
// (buildDate, buildClass, buildPassengers), and lang is always "any". The
// whole lookup is traced under a "YDB request" span tagged with the
// converted key. Rows are decoded by readSearchResults.
func (cache *SearchResultsCache) GetSearchResultsByQueryKey(ctx context.Context, queryKey searchcontext.QKey) ([]*SearchResult, error) {
	span, ctx := opentracing.StartSpanFromContext(ctx, "YDB request")
	defer span.Finish()

	const lang = "any"
	var (
		forward    = buildDate(queryKey.DateForward)
		backward   = buildDate(queryKey.DateBackward)
		klass      = buildClass(queryKey.Class)
		passengers = buildPassengers(queryKey.Adults, queryKey.Children, queryKey.Infants)
		now        = uint32(time.Now().UTC().Unix())
	)

	// Record the lookup key on the span.
	span.SetTag("point_from", queryKey.PointFromKey)
	span.SetTag("point_to", queryKey.PointToKey)
	span.SetTag("date_forward", forward)
	span.SetTag("date_backward", backward)
	span.SetTag("klass", queryKey.Class)
	span.SetTag("passengers", passengers)
	span.SetTag("national_version", queryKey.NationalVersion)
	span.SetTag("lang", lang)
	span.SetTag("unixtime", now)

	params := table.NewQueryParameters(
		table.ValueParam("$point_from", ydbLib.UTF8Value(queryKey.PointFromKey)),
		table.ValueParam("$point_to", ydbLib.UTF8Value(queryKey.PointToKey)),
		table.ValueParam("$date_forward", ydbLib.Uint32Value(forward)),
		table.ValueParam("$date_backward", ydbLib.Uint32Value(backward)),
		table.ValueParam("$klass", ydbLib.Uint8Value(klass)),
		table.ValueParam("$passengers", ydbLib.Uint32Value(passengers)),
		table.ValueParam("$national_version", ydbLib.UTF8Value(queryKey.NationalVersion)),
		table.ValueParam("$lang", ydbLib.UTF8Value(lang)),
		table.ValueParam("$unixtime", ydbLib.Uint32Value(now)),
	)

	result, err := cache.performRequest(ctx, cache.allVariantsByQIDQuery, params)
	if err != nil {
		return nil, err
	}
	return cache.readSearchResults(result, ctx)
}

// buildPassengers packs the passenger counts into the single integer stored
// in the passengers column: adults*100 + children*10 + infants
// (2 adults, 1 child, 0 infants -> 210).
//
// Every operand is widened to uint32 before the arithmetic. Done in uint8,
// 3*100 would wrap around to 44 and produce a wrong lookup key. The encoding
// is only unambiguous while each count is below 10.
func buildPassengers(adults uint8, children uint8, infants uint8) uint32 {
	return uint32(adults)*100 + uint32(children)*10 + uint32(infants)
}

// buildClass maps a service class name to the value stored in the klass
// column. Exactly "economy" maps to 1; every other string, including "" and
// differently-cased spellings, maps to 2.
func buildClass(class string) uint8 {
	switch class {
	case "economy":
		return 1
	default:
		return 2
	}
}

// buildDate encodes date as the number of whole days between the Unix epoch
// and that instant (the Unix timestamp integer-divided by 86400). This is the
// value compared with the date_forward/date_backward columns. The zero
// time.Time encodes as 0, presumably meaning "no return date" — confirm with
// the cache writer.
//
// NOTE(review): the division works on the absolute instant, so midnight in a
// non-UTC location lands on the previous UTC day. Confirm that callers pass
// UTC dates.
func buildDate(date time.Time) uint32 {
	if date.IsZero() {
		return 0
	}
	const dayInSeconds int64 = 24 * 60 * 60
	days := date.Unix() / dayInSeconds
	return uint32(days)
}

// performRequest prepares statement on a session and executes it with
// operationParams under cache.transactionControl.
//
// The prepare+execute pair is one operation handed to cache.retryer, which
// supplies the session and presumably re-invokes the operation on retryable
// errors. Each invocation is traced as "Performing YDB request", with
// "Prepare statement" and "Execute statement" child spans, and its execution
// time is logged. The result captured by the last invocation is returned
// together with the retryer's error.
func (cache *SearchResultsCache) performRequest(
	ctx context.Context,
	statement string,
	operationParams *table.QueryParameters,
) (*table.Result, error) {
	var result *table.Result

	attempt := func(attemptCtx context.Context, session *table.Session) error {
		attemptSpan, attemptCtx := opentracing.StartSpanFromContext(attemptCtx, "Performing YDB request")
		defer attemptSpan.Finish()

		prepareSpan, _ := opentracing.StartSpanFromContext(attemptCtx, "Prepare statement")
		prepared, err := session.Prepare(attemptCtx, statement)
		prepareSpan.Finish()
		if err != nil {
			return fmt.Errorf("couldn't prepare ydb statement: %w", err)
		}

		executeSpan, _ := opentracing.StartSpanFromContext(attemptCtx, "Execute statement")
		defer executeSpan.Finish()

		startedAt := time.Now()
		_, result, err = prepared.Execute(attemptCtx, cache.transactionControl, operationParams)
		cache.logger.Info("ydb request execution time", log.Int64("durationMs", time.Since(startedAt).Milliseconds()))
		if err != nil {
			return fmt.Errorf("couldn't execute ydb statement: %w", err)
		}
		return nil
	}

	err := cache.retryer.Do(ctx, table.OperationFunc(attempt))
	return result, err
}

// readSearchResults decodes every row of every result set in res into a
// SearchResult.
//
// Rows whose Variants or Meta payload carries a decode error are skipped, not
// treated as fatal. How many were skipped is reported in the "ydb response
// scan" log together with the scan duration. The function fails if a row
// cannot be scanned, or if res reports an error after iteration stops, so a
// truncated read is never returned as a complete result.
func (cache *SearchResultsCache) readSearchResults(res *table.Result, ctx context.Context) (searchResults []*SearchResult, err error) {
	span, _ := opentracing.StartSpanFromContext(ctx, "Reading YDB results")
	defer span.Finish()
	start := time.Now()
	skipped := 0
	for res.NextSet() {
		for res.NextRow() {
			searchResult := NewSearchResult(ctx)
			if scanErr := searchResult.Scan(res); scanErr != nil {
				return nil, fmt.Errorf("couldn't scan ydb row: %w", scanErr)
			}
			// A row with an undecodable payload is unusable, but the other
			// rows of the response are still valid.
			if searchResult.Variants.Error != nil || searchResult.Meta.Error != nil {
				skipped++
				continue
			}
			searchResults = append(searchResults, searchResult)
		}
	}
	// NextSet/NextRow return false both at the end of the data and on
	// failure; only Err tells the two apart.
	if resErr := res.Err(); resErr != nil {
		return nil, fmt.Errorf("couldn't read ydb result: %w", resErr)
	}
	cache.logger.Info("ydb response scan", log.Int64("durationMs", time.Since(start).Milliseconds()), log.Int("skippedRows", skipped))

	return searchResults, nil
}

// fillSessionPool warms up YDB sessions.
//
// It creates preparedSessionsCount sessions with sessionPool.Create and
// prepares the lookup query on each one via prepareStatements. Failed session
// creations and failed preparations draw from a single shared budget of
// 4*preparedSessionsCount attempts. When that budget reaches zero, the
// function returns errFillSessionPool. Every call shares a context bounded by
// a two-minute timeout.
//
// NOTE(review): sessions obtained from Create are not handed back to the pool
// here. Presumably Create already registers them with the pool — confirm
// against the SessionPool documentation.
func (cache *SearchResultsCache) fillSessionPool() error {
	warmUpCtx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
	defer cancel()

	budget := cache.preparedSessionsCount * 4
	queries := []string{cache.allVariantsByQIDQuery}

	for ready := 0; ready < cache.preparedSessionsCount; {
		if session, err := cache.sessionPool.Create(warmUpCtx); err != nil {
			cache.logger.Error("an error occurred during creating session", log.Error(err))
			budget--
		} else {
			budget = cache.prepareStatements(warmUpCtx, session, queries, budget)
			ready++
		}
		if budget == 0 {
			return errFillSessionPool
		}
	}

	cache.logger.Info("ydb session pool has been filled", log.Int("sessionsCount", cache.preparedSessionsCount))
	return nil
}

// prepareStatements prepares each query in queries on session, in order.
//
// Every failed Prepare call is logged and consumes one unit of
// remainedAttempts. A query is retried until it succeeds or the attempts run
// out, and once they run out the remaining queries are not tried. The leftover
// attempt count is returned. A return value of 0 therefore means some query
// was left unprepared, or that no attempts were available to begin with.
func (cache *SearchResultsCache) prepareStatements(
	ctx context.Context,
	session *table.Session,
	queries []string,
	remainedAttempts int,
) int {
	for _, query := range queries {
		for ; remainedAttempts > 0; remainedAttempts-- {
			_, err := session.Prepare(ctx, query)
			if err == nil {
				break
			}
			cache.logger.Error("an error occurred during preparation query", log.String("query", query), log.Error(err))
		}
	}
	return remainedAttempts
}
