package storage

import (
	"context"
	"fmt"
	"path"
	"time"

	"github.com/gofrs/uuid"
	"github.com/opentracing/opentracing-go"

	"a.yandex-team.ru/kikimr/public/sdk/go/ydb"
	"a.yandex-team.ru/kikimr/public/sdk/go/ydb/table"
	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/travel/library/go/vault"
)

const (
	// OffersTableName is the name of the YDB table holding train offers.
	// CreateTable joins it with Config.Database to build the table path, and
	// the YQL statements in this file interpolate it directly as the table name.
	OffersTableName = "train_offers"
)

// Config holds the settings NewStorage uses to connect to YDB.
type Config struct {
	// Endpoint is the YDB address passed to the dialer.
	Endpoint  string
	// Database is the YDB database name. It is also the path prefix under
	// which CreateTable creates the offers table.
	Database  string
	// Timeout is used as the dial timeout and as the driver's request and
	// stream timeouts.
	Timeout   time.Duration
	// SecretID and SecretKey identify the Yav secret (and the key within it)
	// whose value is used as the YDB auth token.
	SecretID  string
	SecretKey string
}

// DefaultConfig provides a one-second timeout and the Yav secret/key that hold
// the YDB token. Endpoint and Database are left empty; set them before passing
// the config to NewStorage.
// NOTE(review): this is a mutable package-level value; copy it before
// modifying so other users are not affected.
var DefaultConfig = Config{
	Timeout:   time.Second,
	SecretID:  "sec-01cjvsaf2hrdkfdr97c30dv3t4",
	SecretKey: "yql-token",
}

// Storage reads and writes train offers in a YDB table.
// Build it with NewStorage; the zero value is unusable because the session
// pool, retryer and transaction controls are nil.
type Storage struct {
	// Logger is stored by NewStorage.
	// NOTE(review): not referenced by any method visible in this file.
	Logger      log.Logger
	// Config is the configuration the storage was created with; CreateTable
	// reads Config.Database to build the table path.
	Config      *Config
	// SessionPool supplies YDB sessions to Retryer; Close closes it.
	SessionPool *table.SessionPool
	// Retryer runs every table operation, taking sessions from SessionPool.
	Retryer     *table.Retryer
	// ReadTx is the transaction control used by Get and GetMany (NewStorage
	// configures it as online read-only with commit).
	ReadTx      *table.TransactionControl
	// WriteTx is the transaction control used by Save (NewStorage configures
	// it as serializable read-write with commit).
	WriteTx     *table.TransactionControl
}

// NewStorage connects to YDB and returns a ready-to-use Storage.
//
// Steps:
//   - fetch the auth token from Yav using config.SecretID and config.SecretKey;
//   - dial config.Endpoint, using config.Timeout as the dial, request and
//     stream timeout;
//   - set up a session pool (SizeLimit 3, IdleThreshold one minute);
//   - build the read (online read-only) and write (serializable read-write)
//     transaction controls, both committing on completion;
//   - create a retryer that makes up to ydb.DefaultMaxRetries attempts with no
//     delay between them.
//
// config must be non-nil; the returned Storage keeps a reference to it.
// NOTE(review): the dialed driver is not stored, so nothing ever closes it.
func NewStorage(ctx context.Context, config *Config, logger log.Logger) (*Storage, error) {
	secrets := vault.NewYavSecretsResolver()
	token, err := secrets.GetSecretValue(config.SecretID, config.SecretKey)
	if err != nil {
		return nil, fmt.Errorf("get ydb token error: %w", err)
	}

	driverConf := new(ydb.DriverConfig)
	driverConf.Credentials = ydb.AuthTokenCredentials{AuthToken: token}
	driverConf.Database = config.Database
	driverConf.RequestTimeout, driverConf.StreamTimeout = config.Timeout, config.Timeout

	driver, err := (&ydb.Dialer{DriverConfig: driverConf, Timeout: config.Timeout}).Dial(ctx, config.Endpoint)
	if err != nil {
		return nil, fmt.Errorf("dial error: %w", err)
	}

	client := table.Client{Driver: driver}
	// A single pool instance is shared by the Storage and its Retryer.
	pool := &table.SessionPool{
		Builder:            &client,
		SizeLimit:          3,
		IdleThreshold:      time.Minute,
		KeepAliveBatchSize: -1,
	}

	// The backoff channel fires immediately, so retries are not spaced out.
	noDelay := ydb.BackoffFunc(func(int) <-chan time.Time { return time.After(0) })

	return &Storage{
		Logger:      logger,
		Config:      config,
		SessionPool: pool,
		ReadTx:      table.TxControl(table.BeginTx(table.WithOnlineReadOnly()), table.CommitTx()),
		WriteTx:     table.TxControl(table.BeginTx(table.WithSerializableReadWrite()), table.CommitTx()),
		Retryer: &table.Retryer{
			MaxRetries:      ydb.DefaultMaxRetries,
			Backoff:         noDelay,
			SessionProvider: pool,
			RetryChecker:    ydb.RetryChecker{RetryNotFound: true},
		},
	}, nil
}

// Close shuts down the storage's YDB session pool and returns its error.
// NOTE(review): only the session pool is closed here; the driver dialed in
// NewStorage is not.
func (s *Storage) Close(ctx context.Context) error {
	pool := s.SessionPool
	return pool.Close(ctx)
}

// CreateTable creates the offers table at <Config.Database>/<OffersTableName>.
// The table has optional columns token (Utf8), created (Timestamp) and
// data (String), with token as the primary key. The operation runs through
// s.Retryer.
func (s *Storage) CreateTable(ctx context.Context) error {
	tablePath := path.Join(s.Config.Database, OffersTableName)
	create := table.OperationFunc(func(ctx context.Context, session *table.Session) error {
		return session.CreateTable(ctx, tablePath,
			table.WithColumn("token", ydb.Optional(ydb.TypeUTF8)),
			table.WithColumn("created", ydb.Optional(ydb.TypeTimestamp)),
			table.WithColumn("data", ydb.Optional(ydb.TypeString)),
			table.WithPrimaryKeyColumn("token"),
		)
	})

	if err := s.Retryer.Do(ctx, create); err != nil {
		return fmt.Errorf("ydb create table error: %w", err)
	}
	return nil
}

// insertQuery inserts one row into the offers table. It takes the parameters
// $token (Utf8), $created (Timestamp) and $data (String). Used by Save.
var insertQuery = fmt.Sprintf(`
DECLARE $token AS "Utf8";
DECLARE $created AS "Timestamp";
DECLARE $data AS "String";

INSERT INTO %s (token, created, data)
VALUES ($token, $created, $data);
`, OffersTableName)

// Save stores data as a new offer row keyed by token, with the current time
// recorded in the created column.
//
// The statement is prepared and executed inside s.Retryer using the write
// transaction control. It is a YQL INSERT, so a row with the same token must
// not already exist.
func (s *Storage) Save(ctx context.Context, data []byte, token uuid.UUID) error {
	span, ctx := opentracing.StartSpanFromContext(ctx, "storage.Save")
	defer span.Finish()

	// Captured once so every retry attempt writes the same creation time.
	created := ydb.Time(time.Now()).Timestamp()

	insert := func(ctx context.Context, session *table.Session) error {
		stmt, prepErr := session.Prepare(ctx, insertQuery)
		if prepErr != nil {
			return prepErr
		}
		params := table.NewQueryParameters(
			table.ValueParam("$token", ydb.UTF8Value(token.String())),
			table.ValueParam("$created", ydb.TimestampValue(created)),
			table.ValueParam("$data", ydb.StringValue(data)),
		)
		_, _, execErr := stmt.Execute(ctx, s.WriteTx, params)
		return execErr
	}

	if err := s.Retryer.Do(ctx, table.OperationFunc(insert)); err != nil {
		return fmt.Errorf("ydb insert error: %w", err)
	}
	return nil
}

// selectQuery fetches the data column of the row whose token equals the
// $token (Utf8) parameter. Used by Get.
var selectQuery = fmt.Sprintf(`
DECLARE $token AS "Utf8";

SELECT data
FROM %s
WHERE token = $token;
`, OffersTableName)

// Get returns the stored payload of the offer identified by token.
//
// The lookup runs through s.Retryer using the read transaction control.
// Errors:
//   - a "ydb select error" if the query fails;
//   - a "ydb read data error" if reading the result fails, so read failures
//     are not reported as not-found or returned as a bogus payload;
//   - a "not found offer" error if no row matches.
func (s *Storage) Get(ctx context.Context, token uuid.UUID) (data []byte, err error) {
	span, ctx := opentracing.StartSpanFromContext(ctx, "storage.Get")
	defer span.Finish()
	var res *table.Result
	err = s.Retryer.Do(ctx, table.OperationFunc(func(ctx context.Context, session *table.Session) (err error) {
		stmt, err := session.Prepare(ctx, selectQuery)
		if err != nil {
			return err
		}
		_, res, err = stmt.Execute(ctx, s.ReadTx, table.NewQueryParameters(
			table.ValueParam("$token", ydb.UTF8Value(token.String())),
		))
		return err
	}))
	if err != nil {
		return nil, fmt.Errorf("ydb select error: %w", err)
	}
	// token is the table's primary key, so at most one row can match: take the
	// first row found.
	for res.NextSet() {
		if res.NextRow() {
			res.SeekItem("data")
			data = res.OString()
			if readErr := res.Err(); readErr != nil {
				return nil, fmt.Errorf("ydb read data error: %w", readErr)
			}
			return data, nil
		}
	}
	// Iteration may have stopped because of an error rather than an empty
	// result; surface it instead of claiming the offer does not exist.
	if readErr := res.Err(); readErr != nil {
		return nil, fmt.Errorf("ydb read data error: %w", readErr)
	}
	return nil, fmt.Errorf("not found offer: %v", token)
}

// selectManyQuery fetches the data column of every row whose token is in the
// $tokens (List<Utf8>) parameter. Used by GetMany.
// NOTE(review): there is no ORDER BY and the token column is not selected, so
// results cannot be matched back to the input tokens.
var selectManyQuery = fmt.Sprintf(`
DECLARE $tokens AS "List<Utf8>";

SELECT data
FROM %s
WHERE token in $tokens;
`, OffersTableName)

// GetMany returns the stored payloads for all offers whose token is in tokens.
//
// Tokens without a stored offer are skipped, so the result may be shorter than
// tokens. Only the data column is selected and the query has no ORDER BY, so
// the result order is not tied to the order of tokens. An empty tokens slice
// returns (nil, nil) without querying YDB. Errors are reported as:
//   - "ydb select error" if the query fails;
//   - "ydb read data error" if reading the result fails.
func (s *Storage) GetMany(ctx context.Context, tokens []uuid.UUID) (datas [][]byte, err error) {
	span, ctx := opentracing.StartSpanFromContext(ctx, "storage.GetMany")
	defer span.Finish()

	// Nothing to fetch. An empty ydb.ListValue() also has no item type to match
	// the declared List<Utf8> parameter (NOTE(review): confirm against the SDK).
	if len(tokens) == 0 {
		return nil, nil
	}

	// Converted once; every retry attempt only reads this slice.
	paramTokens := make([]ydb.Value, len(tokens))
	for i, t := range tokens {
		paramTokens[i] = ydb.UTF8Value(t.String())
	}

	var res *table.Result
	err = s.Retryer.Do(ctx, table.OperationFunc(func(ctx context.Context, session *table.Session) (err error) {
		stmt, err := session.Prepare(ctx, selectManyQuery)
		if err != nil {
			return err
		}
		_, res, err = stmt.Execute(ctx, s.ReadTx, table.NewQueryParameters(
			table.ValueParam("$tokens", ydb.ListValue(paramTokens...)),
		))
		return err
	}))
	if err != nil {
		return nil, fmt.Errorf("ydb select error: %w", err)
	}

	for res.NextSet() {
		for res.NextRow() {
			res.SeekItem("data")
			datas = append(datas, res.OString())

			if readErr := res.Err(); readErr != nil {
				return nil, fmt.Errorf("ydb read data error: %w", readErr)
			}
		}
	}
	// Iteration may have stopped because of an error rather than the end of the
	// result; don't return a silently truncated list.
	if readErr := res.Err(); readErr != nil {
		return nil, fmt.Errorf("ydb read data error: %w", readErr)
	}

	return datas, nil
}
