package cacher

import (
	"context"
	"fmt"
	"path"
	"time"

	"a.yandex-team.ru/kikimr/public/sdk/go/ydb"
	"a.yandex-team.ru/kikimr/public/sdk/go/ydb/table"
	"a.yandex-team.ru/library/go/core/xerrors"
	"a.yandex-team.ru/library/go/yandex/tvm"
	"a.yandex-team.ru/security/libs/go/ydbtvm"
)

// ErrNotFound is returned by LookupEntry when the lookup query yields no row
// for the requested source and hash.
var (
	ErrNotFound = xerrors.New("record not found")
)

type (
	Cacher struct {
		ctx                 context.Context
		sp                  *table.SessionPool
		selectEntryQuery    string
		upsertEntryQuery    string
		cleanupEntriesQuery string
	}

	Options struct {
		Database string
		Path     string
		Endpoint string
	}

	Entry struct {
		Source  string
		Hash    string
		Epoch   string
		Results []byte
	}
)

func New(ctx context.Context, tvmClient tvm.Client, opts Options) (*Cacher, error) {
	config := &ydb.DriverConfig{
		Database: opts.Database,
		Credentials: &ydbtvm.TvmCredentials{
			DstID:     ydbtvm.YDBClientID,
			TvmClient: tvmClient,
		},
	}

	driver, err := (&ydb.Dialer{
		DriverConfig: config,
	}).Dial(ctx, opts.Endpoint)

	if err != nil {
		return nil, fmt.Errorf("dial error: %v", err)
	}

	tableClient := table.Client{
		Driver: driver,
	}

	sp := table.SessionPool{
		IdleThreshold: 10 * time.Second,
		Builder:       &tableClient,
	}

	ydbPath := path.Join(opts.Database, opts.Path)
	err = createTables(ctx, &sp, ydbPath)
	if err != nil {
		return nil, fmt.Errorf("can't create tables: %v", err)
	}

	return &Cacher{
		ctx:                 ctx,
		sp:                  &sp,
		selectEntryQuery:    selectEntryQuery(ydbPath),
		upsertEntryQuery:    upsertEntryQuery(ydbPath),
		cleanupEntriesQuery: cleanupEntriesQuery(ydbPath),
	}, nil
}

func (c *Cacher) Close(ctx context.Context) error {
	return c.sp.Close(ctx)
}

// LookupEntry returns the cached "results" value stored for (source, hash).
//
// The query runs in a stale read-only transaction that is committed in the
// same call. Per the YDB docs, such a read may return slightly outdated data.
//
// ErrNotFound is returned when the query yields no row. Any other error comes
// from YDB: statement preparation, execution, or result scanning. The returned
// slice is always nil when the error is non-nil.
func (c *Cacher) LookupEntry(ctx context.Context, source, hash string) ([]byte, error) {
	tx := table.TxControl(
		table.BeginTx(table.WithStaleReadOnly()),
		table.CommitTx(),
	)

	var res *table.Result
	err := table.Retry(ctx, c.sp,
		table.OperationFunc(func(ctx context.Context, s *table.Session) error {
			stmt, err := s.Prepare(ctx, c.selectEntryQuery)
			if err != nil {
				return err
			}

			_, res, err = stmt.Execute(ctx, tx, table.NewQueryParameters(
				table.ValueParam("$source", ydb.UTF8Value(source)),
				table.ValueParam("$hash", ydb.UTF8Value(hash)),
			))
			return err
		}),
	)
	if err != nil {
		return nil, err
	}

	if !res.NextSet() || !res.NextRow() {
		// Distinguish a scanning failure from a genuinely empty result.
		if err := res.Err(); err != nil {
			return nil, err
		}
		return nil, ErrNotFound
	}

	res.SeekItem("results")
	results := res.OString()
	if err := res.Err(); err != nil {
		// Don't hand back a partially decoded value alongside an error.
		return nil, err
	}
	return results, nil
}

// UpsertEntry executes the prepared upsert statement with the four fields of
// entry as parameters. The statement runs in a serializable read-write
// transaction that is committed in the same call. The whole operation is
// wrapped in table.Retry, which may re-run it on retryable failures.
//
// The statement text comes from upsertEntryQuery. Presumably it writes or
// replaces the row keyed by (Source, Hash); confirm against that query.
func (c *Cacher) UpsertEntry(ctx context.Context, entry Entry) error {
	writeTx := table.TxControl(
		table.BeginTx(table.WithSerializableReadWrite()),
		table.CommitTx(),
	)

	upsert := func(ctx context.Context, session *table.Session) error {
		prepared, err := session.Prepare(ctx, c.upsertEntryQuery)
		if err != nil {
			return err
		}

		params := table.NewQueryParameters(
			table.ValueParam("$source", ydb.UTF8Value(entry.Source)),
			table.ValueParam("$hash", ydb.UTF8Value(entry.Hash)),
			table.ValueParam("$epoch", ydb.UTF8Value(entry.Epoch)),
			table.ValueParam("$results", ydb.StringValue(entry.Results)),
		)
		_, _, err = prepared.Execute(ctx, writeTx, params)
		return err
	}

	return table.Retry(ctx, c.sp, table.OperationFunc(upsert))
}

// CleanupEntries executes the prepared cleanup statement with source and
// epoch as parameters. The statement runs in a serializable read-write
// transaction that is committed in the same call, wrapped in table.Retry.
//
// The statement text comes from cleanupEntriesQuery. NOTE(review): presumably
// it deletes rows for source whose epoch differs from epoch; confirm against
// that query.
func (c *Cacher) CleanupEntries(ctx context.Context, source, epoch string) error {
	writeTx := table.TxControl(
		table.BeginTx(table.WithSerializableReadWrite()),
		table.CommitTx(),
	)

	cleanup := table.OperationFunc(func(ctx context.Context, session *table.Session) error {
		prepared, err := session.Prepare(ctx, c.cleanupEntriesQuery)
		if err != nil {
			return err
		}

		_, _, err = prepared.Execute(ctx, writeTx, table.NewQueryParameters(
			table.ValueParam("$source", ydb.UTF8Value(source)),
			table.ValueParam("$epoch", ydb.UTF8Value(epoch)),
		))
		return err
	})

	return table.Retry(ctx, c.sp, cleanup)
}

// createTables issues CreateTable for "<prefix>/yaudit_cache" through
// table.Retry. The table has these columns:
//   - source, hash, epoch: Optional<Utf8>
//   - results: Optional<String>
//
// The primary key is (source, hash). A failure is wrapped with the table name.
func createTables(ctx context.Context, sp *table.SessionPool, prefix string) error {
	const tableName = "yaudit_cache"
	tablePath := prefix + "/" + tableName

	create := func(ctx context.Context, s *table.Session) error {
		return s.CreateTable(ctx, tablePath,
			table.WithColumn("source", ydb.Optional(ydb.TypeUTF8)),
			table.WithColumn("hash", ydb.Optional(ydb.TypeUTF8)),
			table.WithColumn("epoch", ydb.Optional(ydb.TypeUTF8)),
			table.WithColumn("results", ydb.Optional(ydb.TypeString)),
			table.WithPrimaryKeyColumn("source", "hash"),
		)
	}

	if err := table.Retry(ctx, sp, table.OperationFunc(create)); err != nil {
		return xerrors.Errorf("failed to create yaudit_cache table: %w", err)
	}
	return nil
}
