package cache

import (
	"context"
	"fmt"
	"path"
	"time"

	"a.yandex-team.ru/kikimr/public/sdk/go/ydb"
	"a.yandex-team.ru/kikimr/public/sdk/go/ydb/table"
)

// tableName is the YDB table holding cached records. It is resolved relative
// to the database path: queries set it via PRAGMA TablePathPrefix and
// createTables joins it onto the database prefix.
const tableName = "dns_db"

// Client reads and writes cache records stored in the tableName table of a
// YDB database. Construct it with NewClient.
type Client struct {
	// ctx is the context passed to NewClient; every later operation
	// (LookupRecord, InsertRecord, Reset) runs with it.
	ctx context.Context
	// sp is the session pool all queries are executed through.
	sp *table.SessionPool
	// selectQuery and replaceQuery are the YQL texts, pre-rendered with the
	// database path as TablePathPrefix.
	selectQuery, replaceQuery string
}

// selectQuery renders the YQL statement that reads created_at, updated_at and
// data for a single $key from tableName, with pathPrefix used as the
// TablePathPrefix. At most one row is returned (LIMIT 1).
func selectQuery(pathPrefix string) string {
	const queryTemplate = "\n" +
		"PRAGMA TablePathPrefix(\"%s\");\n" +
		"DECLARE $key AS String;\n" +
		"\n" +
		"SELECT\n" +
		"\tcreated_at, updated_at, data\n" +
		"FROM\n" +
		"\t%s\n" +
		"WHERE\n" +
		"\tkey = $key\n" +
		"LIMIT 1;"
	return fmt.Sprintf(queryTemplate, pathPrefix, tableName)
}

// replaceQuery renders the YQL REPLACE INTO statement that writes one full row
// (key, request, created_at, updated_at, data) into tableName, with
// pathPrefix used as the TablePathPrefix.
//
// Parameter types are declared with bare type names (String, Int64), the same
// form selectQuery uses, rather than the quoted ("String") form.
func replaceQuery(pathPrefix string) string {
	return fmt.Sprintf(`
PRAGMA TablePathPrefix("%s");

DECLARE $key AS String;
DECLARE $request AS String;
DECLARE $created_at AS Int64;
DECLARE $updated_at AS Int64;
DECLARE $data AS String;

REPLACE INTO %s
  (key, request, created_at, updated_at, data)
VALUES
  ($key, $request, $created_at, $updated_at, $data);`, pathPrefix, tableName)
}

// NewClient dials the YDB endpoint for database using authToken, creates the
// cache table via createTables, and returns a Client ready for use.
//
// ctx is used for dialing and table creation. It is also stored in the
// returned Client and reused by every later operation, so it should remain
// valid for the Client's lifetime.
//
// Errors are wrapped with %w so callers can inspect the underlying YDB error.
// If table creation fails, the session pool and driver opened here are
// released before returning, so a failed call does not leak them.
func NewClient(ctx context.Context, authToken string, database string, endpoint string) (*Client, error) {
	config := &ydb.DriverConfig{
		Credentials: ydb.AuthTokenCredentials{AuthToken: authToken},
		Database:    database,
	}

	driver, err := (&ydb.Dialer{
		DriverConfig: config,
	}).Dial(ctx, endpoint)
	if err != nil {
		return nil, fmt.Errorf("dial error: %w", err)
	}

	sp := &table.SessionPool{
		IdleThreshold: 10 * time.Second,
		Builder:       &table.Client{Driver: driver},
	}

	if err := createTables(ctx, sp, database); err != nil {
		// Neither the pool nor the driver is reachable by the caller on this
		// path, so release them here. Cleanup is best-effort: the table
		// creation error is the one worth reporting.
		_ = sp.Close(ctx)
		_ = driver.Close()
		return nil, fmt.Errorf("create tables error: %w", err)
	}

	return &Client{
		ctx:          ctx,
		sp:           sp,
		selectQuery:  selectQuery(database),
		replaceQuery: replaceQuery(database),
	}, nil
}

// Reset closes the Client's session pool using the context captured by
// NewClient and returns the error, if any, from closing it.
//
// NOTE(review): the YDB driver dialed in NewClient is not stored on the
// Client, so it is not closed here — confirm whether that is intended.
func (c *Client) Reset() error {
	pool := c.sp
	return pool.Close(c.ctx)
}

// LookupRecord returns the record cached under key, or (nil, nil) when no row
// matches.
//
// The prepared selectQuery runs in an online read-only transaction that is
// committed together with the query. The call goes through table.Retry using
// the context stored in the Client. If the result reports an error after a
// row was decoded, both the decoded record and that error are returned.
func (c *Client) LookupRecord(key string) (result *Record, err error) {
	readOnlyTx := table.TxControl(
		table.BeginTx(table.WithOnlineReadOnly()),
		table.CommitTx(),
	)

	var res *table.Result
	query := func(ctx context.Context, s *table.Session) error {
		stmt, prepareErr := s.Prepare(ctx, c.selectQuery)
		if prepareErr != nil {
			return prepareErr
		}
		var execErr error
		_, res, execErr = stmt.Execute(ctx, readOnlyTx, table.NewQueryParameters(
			table.ValueParam("$key", ydb.StringValue([]byte(key))),
		))
		return execErr
	}

	if err = table.Retry(c.ctx, c.sp, table.OperationFunc(query)); err != nil {
		return nil, err
	}

	// No result set or no row means the key is not cached.
	if !res.NextSet() || !res.NextRow() {
		return nil, res.Err()
	}

	// Columns are read in SELECT order: created_at, updated_at, data.
	res.SeekItem("created_at")
	createdAt := res.OInt64()
	res.NextItem()
	updatedAt := res.OInt64()
	res.NextItem()
	data := res.OString()

	return &Record{
		CreatedAt: time.Unix(createdAt, 0),
		UpdatedAt: time.Unix(updatedAt, 0),
		Data:      data,
	}, res.Err()
}

// InsertRecord writes record under key with REPLACE INTO, overwriting any
// existing row. upstreamReq is stored in the "request" column.
//
// Timestamps are stored as Unix seconds, so the sub-second parts of
// record.CreatedAt and record.UpdatedAt are dropped. A nil record is rejected
// with an error rather than panicking inside the retry operation, where a
// session has already been taken from the pool.
func (c *Client) InsertRecord(key string, upstreamReq string, record *Record) error {
	if record == nil {
		return fmt.Errorf("insert record %q: nil record", key)
	}

	// Serializable read-write transaction, committed together with the query.
	writeTx := table.TxControl(
		table.BeginTx(table.WithSerializableReadWrite()),
		table.CommitTx(),
	)

	// The parameters do not depend on the session, so they are built once
	// instead of on every retry attempt.
	params := table.NewQueryParameters(
		table.ValueParam("$key", ydb.StringValue([]byte(key))),
		table.ValueParam("$request", ydb.StringValue([]byte(upstreamReq))),
		table.ValueParam("$created_at", ydb.Int64Value(record.CreatedAt.Unix())),
		table.ValueParam("$updated_at", ydb.Int64Value(record.UpdatedAt.Unix())),
		table.ValueParam("$data", ydb.StringValue(record.Data)),
	)

	return table.Retry(c.ctx, c.sp,
		table.OperationFunc(func(ctx context.Context, s *table.Session) error {
			stmt, err := s.Prepare(ctx, c.replaceQuery)
			if err != nil {
				return err
			}
			_, _, err = stmt.Execute(ctx, writeTx, params)
			return err
		}),
	)
}

// createTables creates the tableName table under prefix, with optional
// columns key (the primary key), request, created_at, updated_at and data.
// The CreateTable call is wrapped in table.Retry over sp.
//
// NOTE(review): NewClient calls this on every start-up; this code does not
// special-case an already-existing table, so it relies on YDB's CreateTable
// behaviour in that case — confirm.
func createTables(ctx context.Context, sp *table.SessionPool, prefix string) (err error) {
	tablePath := path.Join(prefix, tableName)

	create := func(opCtx context.Context, s *table.Session) error {
		return s.CreateTable(opCtx, tablePath,
			table.WithColumn("key", ydb.Optional(ydb.TypeString)),
			table.WithColumn("request", ydb.Optional(ydb.TypeString)),
			table.WithColumn("created_at", ydb.Optional(ydb.TypeInt64)),
			table.WithColumn("updated_at", ydb.Optional(ydb.TypeInt64)),
			table.WithColumn("data", ydb.Optional(ydb.TypeString)),
			table.WithPrimaryKeyColumn("key"),
		)
	}

	return table.Retry(ctx, sp, table.OperationFunc(create))
}
