package locks

import (
	"context"
	"errors"
	"fmt"
	"time"

	"github.com/ydb-platform/ydb-go-sdk/v3/table"
	"github.com/ydb-platform/ydb-go-sdk/v3/table/options"
	"github.com/ydb-platform/ydb-go-sdk/v3/table/types"

	"a.yandex-team.ru/tasklet/experimental/internal/yandex/xydb"
)

// LockState is a snapshot of one row of the locks table (tableLocks).
type LockState struct {
	// LockedBy is the owner recorded in locked_by. NULL columns (e.g. after
	// Unlock, or when no row exists) are read back as the empty string.
	LockedBy       string
	// LockedUntil is when the current ownership expires; once it is in the
	// past, TryLock lets a different owner take the lock.
	LockedUntil    time.Time
	// SequenceNumber is kept when the same owner re-acquires the lock and is
	// incremented (starting at 1) each time the lock passes to a new owner.
	SequenceNumber uint64
}

// LockID names a lock; it is stored in the lock_id column of both the locks
// and the excludes tables.
type LockID string

// Names of the YDB tables used by ydbLocksRepo; their schemas are defined in
// ResetLocksTable and ResetExcludesTable.
const (
	tableLocks    = "locker_state"
	tableExcludes = "locker_excludes"
)

// LocksRepo persists named, time-limited locks together with per-cluster
// excludes that stop a given cluster from acquiring a given lock.
//
//go:generate mockery --name=LocksRepo --inpackage --case underscore
type LocksRepo interface {
	// TryLock attempts to acquire (or extend) lock id for owner until the
	// given time, unless cluster is currently excluded from it. It returns
	// the lock state after the attempt; compare LockedBy with owner to learn
	// whether the lock is held.
	TryLock(ctx context.Context, id LockID, owner, cluster string, until time.Time) (LockState, error)
	// Unlock releases lock id if it is currently held by owner.
	Unlock(ctx context.Context, id LockID, owner string) error

	// Exclude prevents cluster from acquiring lock id until the given time.
	Exclude(ctx context.Context, id LockID, cluster string, until time.Time) error
	// DropExclude removes the exclude for (id, cluster). The YDB
	// implementation treats an empty id or cluster as "match any".
	DropExclude(ctx context.Context, id LockID, cluster string) error

	// ResetLocksTable resets the table backing the locks.
	ResetLocksTable() error
	// ResetExcludesTable resets the table backing the excludes.
	ResetExcludesTable() error
}

// The queries in this file are mostly copied from juggler/server/banshee/lib/locker.cpp;
// if you change anything here, remember to update the originals there as well.

// ydbLocksRepo implements LocksRepo on top of the YDB tables tableLocks and
// tableExcludes, accessed through an xydb.Client.
type ydbLocksRepo struct {
	ydb *xydb.Client
}

// NewLocksRepo returns a LocksRepo that stores its state in YDB through the
// given client.
func NewLocksRepo(ydb *xydb.Client) LocksRepo {
	repo := &ydbLocksRepo{ydb: ydb}
	return repo
}

// TryLock attempts to acquire (or extend) lock id on behalf of owner until the
// given time and returns the resulting lock state.
//
// If cluster has an active exclude for id (see Exclude), no acquisition is
// attempted and the currently stored state is returned unchanged.
//
// Otherwise the lock is written when:
//   - it has no row yet,
//   - it has no owner,
//   - it has expired (locked_until < now), or
//   - it is already held by owner.
//
// SequenceNumber is kept when owner re-acquires its own lock and is
// incremented (starting at 1) whenever ownership changes. If someone else
// holds the lock, that holder's state is returned. Callers learn whether they
// own the lock by comparing LockState.LockedBy with owner.
//
// On error the returned LockState is the zero value.
func (repo *ydbLocksRepo) TryLock(ctx context.Context, id LockID, owner, cluster string, until time.Time) (
	LockState,
	error,
) {
	if excluded, err := repo.isExcluded(ctx, id, cluster); err != nil {
		return LockState{}, err
	} else if excluded {
		return repo.readCurrentState(ctx, id)
	}
	query := fmt.Sprintf(
		`--!syntax_v1
	DECLARE $lock_id AS Utf8;
	DECLARE $new_owner AS Utf8;
	DECLARE $lock_until AS Timestamp;
	DECLARE $now AS Timestamp;

	$new_data = (
		SELECT
			$lock_id AS lock_id,
			$new_owner AS locked_by,
			$lock_until AS locked_until
	);

	$prev_data = (
		SELECT locked_by, locked_until, sequence_number
		FROM %s
		WHERE lock_id = $lock_id
	);

	$upsert_data = (
		SELECT
			new.lock_id AS lock_id,
			new.locked_by AS locked_by,
			new.locked_until AS locked_until,
			CASE
				WHEN existing.locked_by = $new_owner THEN existing.sequence_number
				ELSE COALESCE(existing.sequence_number + 1u, 1u)
			END AS sequence_number
		FROM $new_data AS new
		LEFT JOIN %s AS existing
		ON new.lock_id = existing.lock_id
		WHERE (existing.lock_id IS NULL
			OR existing.locked_by IS NULL
			OR existing.locked_until < $now
			OR existing.locked_by = $new_owner)
	);

	UPSERT INTO %s
	SELECT lock_id, locked_by, locked_until, sequence_number FROM $upsert_data;

	SELECT locked_by, locked_until, sequence_number FROM $upsert_data;
	SELECT locked_by, locked_until, sequence_number FROM $prev_data;
	`, tableLocks, tableLocks, tableLocks,
	)
	query = repo.ydb.QueryPrefix() + query

	var ls LockState
	err := repo.ydb.DB.Table().Do(
		ctx,
		func(c context.Context, s table.Session) error {
			// Do may run this closure more than once. Start every attempt from a
			// clean state so nothing scanned by a failed attempt leaks out.
			ls = LockState{}
			_, res, err := s.Execute(
				c, repo.ydb.WriteTxControl, query, table.NewQueryParameters(
					table.ValueParam("$lock_id", types.UTF8Value(string(id))),
					table.ValueParam("$new_owner", types.UTF8Value(owner)),
					table.ValueParam("$lock_until", types.TimestampValueFromTime(until)),
					table.ValueParam("$now", types.TimestampValueFromTime(time.Now())),
				),
			)
			if err != nil {
				return err
			}
			defer res.Close()
			// The result sets arrive in this order:
			//   1. $upsert_data: the row just written. It is empty when someone
			//      else holds the lock.
			//   2. $prev_data: the stored row, which in that case describes the
			//      current holder.
			// The first non-empty result set is the state to report.
			for res.NextResultSet(c) {
				if res.NextRow() {
					return res.ScanWithDefaults(&ls.LockedBy, &ls.LockedUntil, &ls.SequenceNumber)
				}
			}
			// Surface any error recorded by the result rather than masking it
			// with the generic message below.
			if err = res.Err(); err != nil {
				return err
			}
			return errors.New("no result set parsed")
		},
	)
	if err != nil {
		return LockState{}, err
	}
	return ls, nil
}

// Unlock releases lock id only if it is currently held by owner. It clears
// locked_by and locked_until on the matching row. The row itself and its
// sequence_number are left in place. When the lock is held by someone else,
// or has no row, the UPDATE matches nothing and the table is unchanged.
func (repo *ydbLocksRepo) Unlock(ctx context.Context, id LockID, owner string) error {
	stmt := repo.ydb.QueryPrefix() + fmt.Sprintf(
		`--!syntax_v1
		DECLARE $lock_id AS Utf8;
		DECLARE $owner AS Utf8;

		UPDATE %s
		SET locked_by = NULL, locked_until = NULL
		WHERE lock_id == $lock_id AND locked_by == $owner
	`, tableLocks,
	)
	lockID := types.UTF8Value(string(id))
	holder := types.UTF8Value(owner)
	return repo.ydb.ExecuteWriteQuery(ctx, stmt,
		table.ValueParam("$lock_id", lockID),
		table.ValueParam("$owner", holder),
	)
}

// readCurrentState returns the stored state of lock id without modifying it.
// If no row is returned, the zero LockState is reported together with
// whatever error the result recorded (nil in the normal "no such lock" case).
// On error the returned LockState is always the zero value.
func (repo *ydbLocksRepo) readCurrentState(ctx context.Context, id LockID) (LockState, error) {
	query := fmt.Sprintf(
		`
	DECLARE $lock_id AS Utf8;

	SELECT locked_by, locked_until, sequence_number
	FROM %s
	WHERE lock_id = $lock_id
	`, tableLocks,
	)
	query = repo.ydb.QueryPrefix() + query
	res, err := repo.ydb.ExecuteReadQuery(
		ctx,
		query,
		table.ValueParam("$lock_id", types.UTF8Value(string(id))),
	)
	if err != nil {
		return LockState{}, err
	}
	if !res.NextResultSet(ctx) || !res.NextRow() {
		return LockState{}, res.Err()
	}
	var ls LockState
	if err := res.ScanWithDefaults(&ls.LockedBy, &ls.LockedUntil, &ls.SequenceNumber); err != nil {
		// Do not hand back a partially scanned state alongside the error.
		return LockState{}, err
	}
	return ls, nil
}

// isExcluded reports whether cluster has at least one exclude row for lock id
// whose `until` is still in the future (compared with the local clock). An
// empty result is treated as "not excluded", carrying any error the result
// recorded.
func (repo *ydbLocksRepo) isExcluded(ctx context.Context, id LockID, cluster string) (bool, error) {
	stmt := repo.ydb.QueryPrefix() + fmt.Sprintf(
		`
	DECLARE $lock_id AS Utf8;
	DECLARE $cluster AS Utf8;
	DECLARE $now AS Timestamp;

	SELECT COUNT(1) as excludes_count
	FROM %s
	WHERE lock_id = $lock_id
		AND cluster = $cluster
		AND until > $now
	`, tableExcludes,
	)
	res, err := repo.ydb.ExecuteReadQuery(
		ctx,
		stmt,
		table.ValueParam("$lock_id", types.UTF8Value(string(id))),
		table.ValueParam("$cluster", types.UTF8Value(cluster)),
		table.ValueParam("$now", types.TimestampValueFromTime(time.Now())),
	)
	if err != nil {
		return false, err
	}
	if !res.NextResultSet(ctx) || !res.NextRow() {
		return false, res.Err()
	}
	var active uint64
	if err = res.ScanWithDefaults(&active); err != nil {
		return false, err
	}
	return active > 0, nil
}

// Exclude records that cluster must not acquire lock id before `until`.
// TryLock checks this table first. Because (lock_id, cluster) is the primary
// key and the write is an UPSERT, calling Exclude again for the same pair
// simply replaces the deadline.
func (repo *ydbLocksRepo) Exclude(ctx context.Context, id LockID, cluster string, until time.Time) error {
	stmt := repo.ydb.QueryPrefix() + fmt.Sprintf(
		`
	DECLARE $lock_id AS Utf8;
	DECLARE $cluster AS Utf8;
	DECLARE $until AS Timestamp;

	UPSERT INTO %s
	(
		lock_id,
		cluster,
		until
	)
	VALUES (
		$lock_id,
		$cluster,
		$until
	)
	`, tableExcludes,
	)
	lockID := types.UTF8Value(string(id))
	clusterName := types.UTF8Value(cluster)
	deadline := types.TimestampValueFromTime(until)
	return repo.ydb.ExecuteWriteQuery(ctx, stmt,
		table.ValueParam("$lock_id", lockID),
		table.ValueParam("$cluster", clusterName),
		table.ValueParam("$until", deadline),
	)
}

// DropExclude deletes exclude rows matching lock id and cluster. An empty
// string acts as a wildcard for either argument: an empty cluster drops the
// excludes of id for every cluster, and an empty id drops every exclude of
// cluster. If both are empty, the whole excludes table is cleared.
func (repo *ydbLocksRepo) DropExclude(ctx context.Context, id LockID, cluster string) error {
	stmt := repo.ydb.QueryPrefix() + fmt.Sprintf(
		`
	DECLARE $lock_id AS Utf8;
	DECLARE $cluster AS Utf8;

	DELETE FROM %s
	WHERE ($cluster = "" OR cluster = $cluster)
		AND ($lock_id = "" OR lock_id = $lock_id)
	`, tableExcludes,
	)
	lockID := types.UTF8Value(string(id))
	clusterName := types.UTF8Value(cluster)
	return repo.ydb.ExecuteWriteQuery(ctx, stmt,
		table.ValueParam("$lock_id", lockID),
		table.ValueParam("$cluster", clusterName),
	)
}

// ResetLocksTable resets the locks table through xydb.Client.ResetTable
// (presumably drop-and-recreate — confirm in xydb). The schema has lock_id as
// primary key plus locked_by, locked_until and sequence_number, all optional.
// The interface passes no context, so the call gets its own 10-second
// deadline.
func (repo *ydbLocksRepo) ResetLocksTable() error {
	const resetTimeout = 10 * time.Second
	ctx, cancel := context.WithTimeout(context.Background(), resetTimeout)
	defer cancel()

	return repo.ydb.ResetTable(ctx, tableLocks,
		options.WithColumn("lock_id", types.Optional(types.TypeUTF8)),
		options.WithColumn("locked_by", types.Optional(types.TypeUTF8)),
		options.WithColumn("locked_until", types.Optional(types.TypeTimestamp)),
		options.WithColumn("sequence_number", types.Optional(types.TypeUint64)),
		options.WithPrimaryKeyColumn("lock_id"),
	)
}

// ResetExcludesTable resets the excludes table through xydb.Client.ResetTable
// (presumably drop-and-recreate — confirm in xydb). Rows are keyed by
// (lock_id, cluster). A TTL on the `until` column lets YDB expire each row
// one day after its deadline, so stale excludes clean themselves up. Like
// ResetLocksTable, the call is bounded by a 10-second deadline because the
// interface passes no context.
func (repo *ydbLocksRepo) ResetExcludesTable() error {
	const resetTimeout = 10 * time.Second
	ctx, cancel := context.WithTimeout(context.Background(), resetTimeout)
	defer cancel()

	ttl := options.TimeToLiveSettings{
		ColumnName:         "until",
		ExpireAfterSeconds: 24 * 60 * 60,
	}
	return repo.ydb.ResetTable(ctx, tableExcludes,
		options.WithColumn("lock_id", types.Optional(types.TypeUTF8)),
		options.WithColumn("cluster", types.Optional(types.TypeUTF8)),
		options.WithColumn("until", types.Optional(types.TypeTimestamp)),
		options.WithPrimaryKeyColumn("lock_id", "cluster"),
		options.WithTimeToLiveSettings(ttl),
	)
}
