package models

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"sync"
	"time"

	"a.yandex-team.ru/drive/library/go/gosql"
	"a.yandex-team.ru/library/go/core/log"
	zsql "a.yandex-team.ru/zootopia/library/go/db/sql"
)

// LeaderLockName is the name under which the leader lock is stored.
const LeaderLockName = "leader"

// Lock represents distributed lock.
//
// A Lock mirrors one row of the lock table managed by LockStore: GetTx scans
// the columns below into it, and acquire/ping/release use its fields in the
// WHERE clauses of their UPDATE statements.
type Lock struct {
	// ID is the row identifier used to target UPDATEs.
	ID       int    `db:"id"`
	// Name is the lock name looked up by GetTx.
	Name     string `db:"name"`
	// HostID identifies the current holder; release writes NULL here.
	// NOTE(review): NInt presumably models a nullable int (see has/eqOp) — confirm.
	HostID   NInt   `db:"host_id"`
	// PingTime is the last successful ping as Unix seconds; release resets it to 0.
	PingTime int64  `db:"ping_time"`
}

const (
	// lockTimeout is how long a lock stays valid after its last successful
	// ping; acquire may hand the lock to another host once it has elapsed.
	lockTimeout = 30 * time.Second
	// safeLockTimeout (two thirds of lockTimeout) is how long WithLock lets
	// the guarded function run without a successful ping before cancelling
	// it, leaving a margin before the lock actually expires.
	safeLockTimeout = 2 * lockTimeout / 3
)

// pingTime returns the moment of the last successful ping, decoding
// PingTime as whole Unix seconds.
func (l Lock) pingTime() time.Time {
	seconds := l.PingTime
	return time.Unix(seconds, 0)
}

// endTime returns the moment the lock expires: lockTimeout after the last
// ping. acquire only takes over a lock once this moment has passed.
func (l Lock) endTime() time.Time {
	lastPing := l.pingTime()
	return lastPing.Add(lockTimeout)
}

// safeEndTime returns the moment, safeLockTimeout after the last ping, past
// which the holder should stop relying on the lock. It precedes endTime so
// that work can be cancelled before another host is allowed to acquire it.
func (l Lock) safeEndTime() time.Time {
	lastPing := l.pingTime()
	return lastPing.Add(safeLockTimeout)
}

// LockStore represents store for distributed locks.
//
// Locks are rows of a single SQL table. Mutations (acquire, ping, release)
// each run in their own transaction and repeat the observed id, ping_time and
// host_id in their WHERE clause, so a concurrent change by another host makes
// the UPDATE match no rows instead of overwriting that host's state.
type LockStore struct {
	// db is the database that holds the lock table.
	db     *gosql.DB
	// table is the lock table name; it is interpolated into queries with %q.
	table  string
	// logger receives the errors and warnings reported by WithLock.
	logger log.Logger
}

// GetTx returns lock by name.
//
// It reads the single row whose "name" column equals name through the given
// runner (acquire passes its transaction here). Any query or scan error is
// returned unchanged together with a zero Lock; for a missing row this is
// presumably sql.ErrNoRows — confirm against gosql.Runner.
func (s *LockStore) GetTx(tx gosql.Runner, name string) (Lock, error) {
	query := fmt.Sprintf(
		`SELECT "id", "name", "host_id", "ping_time" FROM %q WHERE "name" = $1 LIMIT 1`,
		s.table,
	)
	var found Lock
	err := tx.QueryRow(query, name).Scan(
		&found.ID, &found.Name, &found.HostID, &found.PingTime,
	)
	if err != nil {
		return Lock{}, err
	}
	return found, nil
}

var (
	// ErrLockAcquired is returned by acquire (and so by WithLock) when the
	// lock is still held by an unexpired lease, or another host claimed the
	// row between our read and our update.
	ErrLockAcquired = errors.New("lock already acquired")
	// ErrLockReleased is returned when the lock row no longer matches the
	// state this holder knows about, i.e. it was already released or taken
	// over.
	ErrLockReleased = errors.New("lock already released")
)

// WithLock runs function with specified lock on specified host.
//
// If lock is already acquired you will get ErrLockAcquired error.
//
// While fn runs, a background goroutine pings the lock every second to keep
// it alive. fn receives a child of ctx that is cancelled when fn returns, when
// ctx is cancelled, when a ping reports ErrLockReleased (the row no longer
// matches this holder), or when no ping has succeeded for safeLockTimeout.
// fn should therefore watch its context and stop promptly once it is done.
//
// WithLock returns fn's result only after the ping goroutine has exited and
// the release of the lock has been attempted; a failed release is logged,
// not returned.
func (s *LockStore) WithLock(
	ctx context.Context, name string, host int,
	fn func(context.Context) error,
) error {
	// First of all we should try to acquire lock.
	lock, err := s.acquire(ctx, name, host)
	if err != nil {
		return err
	}
	// Deferred calls run in reverse order: on return we first cancel the ping
	// loop, then wait for it to exit, and only then release the lock. The wait
	// guarantees the goroutine no longer mutates lock (via ping) when release
	// reads it.
	defer func() {
		if err := s.release(lock); err != nil {
			s.logger.Error(
				"Unable to release lock",
				log.String("lock_name", name),
				log.Int("host_id", host),
				log.Error(err),
			)
		}
	}()
	// Child context for fn; cancelled by the ping loop or when fn returns.
	lockCtx, cancel := context.WithCancel(ctx)
	var waiter sync.WaitGroup
	waiter.Add(1)
	// We should hang function until the ping loop has stopped.
	defer waiter.Wait()
	go func() {
		defer waiter.Done()
		ticker := time.NewTicker(time.Second)
		defer ticker.Stop()
		// Whatever makes the loop exit, fn must stop relying on the lock.
		defer cancel()
		for {
			select {
			case <-lockCtx.Done():
				return
			case <-ticker.C:
				// The last successful ping is too old: the lock may expire
				// soon, so cancel fn rather than risk overlapping a new holder.
				if time.Now().After(lock.safeEndTime()) {
					s.logger.Error(
						"Lock is lost after safeLockTimeout",
						log.String("lock_name", name),
						log.Int("host_id", host),
					)
					return
				}
				err := s.ping(lockCtx, &lock)
				if err == nil {
					continue
				}
				// fn has finished (or ctx was cancelled) while the ping was in
				// flight; that failure is expected, so exit without logging.
				if lockCtx.Err() != nil {
					return
				}
				// errors.Is also matches the sentinels if the transaction
				// helper wraps them.
				if !errors.Is(err, sql.ErrTxDone) {
					s.logger.Warn(
						"Unable to ping lock",
						log.String("lock_name", name),
						log.Int("host_id", host),
						log.Error(err),
					)
				}
				if errors.Is(err, ErrLockReleased) {
					// Lock is already released so we should exit loop.
					return
				}
			}
		}
	}()
	// We should cancel ping loop when function ends.
	defer cancel()
	// Now we can run our function with child context.
	return fn(lockCtx)
}

// acquire tries to acquire lock right now.
//
// In one transaction it reads the current row for name. If the previous
// lease (its ping time plus lockTimeout) has strictly passed, it claims the
// row for host with the current Unix time as ping time. The UPDATE repeats
// the observed id, ping_time and host_id in its WHERE clause, so if another
// host changed the row in between, nothing matches.
//
// If lock can not be acquired, you will get ErrLockAcquired. Any other error
// (reading the row, executing the update, the transaction itself) is returned
// as-is. On any error the returned Lock is zero.
func (s *LockStore) acquire(
	ctx context.Context, name string, host int,
) (Lock, error) {
	candidate := Lock{
		Name:     name,
		HostID:   NInt(host),
		PingTime: time.Now().Unix(),
	}
	claim := func(tx *sql.Tx) error {
		current, err := s.GetTx(tx, name)
		if err != nil {
			return err
		}
		// Only take over a lease that has already run out.
		if expired := candidate.pingTime().After(current.endTime()); !expired {
			return ErrLockAcquired
		}
		args := []interface{}{
			candidate.HostID, candidate.PingTime, current.ID, current.PingTime,
		}
		if current.HostID.has() {
			args = append(args, current.HostID)
		}
		query := fmt.Sprintf(
			`UPDATE %q SET "host_id" = $1, "ping_time" = $2`+
				` WHERE "id" = $3 AND "ping_time" = $4`+
				` AND "host_id" %s`,
			s.table, current.HostID.eqOp(5),
		)
		res, err := tx.Exec(query, args...)
		if err != nil {
			return err
		}
		n, err := res.RowsAffected()
		switch {
		case err != nil:
			return err
		case n != 1:
			// Someone else updated the row between our read and our write.
			return ErrLockAcquired
		}
		candidate.ID = current.ID
		return nil
	}
	if err := zsql.WithTxContext(ctx, s.db, nil, claim); err != nil {
		return Lock{}, err
	}
	return candidate, nil
}

// release tries to release lock immediately.
//
// It resets the row to the free state (host_id NULL, ping_time 0), guarded by
// a WHERE clause on lock's id, ping time and host, so a row that was already
// released or taken over is left untouched and ErrLockReleased is returned.
// The attempt runs on a fresh background context whose deadline is
// lock.endTime(), so it does not depend on any caller context and gives up
// once the lock would have expired anyway.
func (s *LockStore) release(lock Lock) error {
	deadlineCtx, stop := context.WithDeadline(context.Background(), lock.endTime())
	defer stop()
	unlock := func(tx *sql.Tx) error {
		const query = `UPDATE %q SET "host_id" = NULL, "ping_time" = 0` +
			` WHERE "id" = $1 AND "ping_time" = $2` +
			` AND "host_id" %s`
		args := []interface{}{lock.ID, lock.PingTime}
		if lock.HostID.has() {
			args = append(args, lock.HostID)
		}
		res, err := tx.Exec(fmt.Sprintf(query, s.table, lock.HostID.eqOp(3)), args...)
		if err != nil {
			return err
		}
		n, err := res.RowsAffected()
		if err != nil {
			return err
		}
		if n != 1 {
			return ErrLockReleased
		}
		return nil
	}
	return zsql.WithTxContext(deadlineCtx, s.db, nil, unlock)
}

// ping tries to update lock PingTime.
//
// It sets the row's ping_time to the current Unix time, guarded by a WHERE
// clause on lock's id, ping time and host so that it only succeeds while this
// holder still owns the row; otherwise ErrLockReleased is returned. The
// attempt is bounded by lock.safeEndTime().
//
// On failure lock will be not changed: lock.PingTime is updated only after
// the transaction helper has reported success. Assigning it inside the
// callback, before the commit, would leave lock out of sync with the
// database whenever the commit itself fails, and every later ping/release
// would then miss the row.
// NOTE(review): assumes zsql.WithTxContext commits after the callback
// returns nil and reports commit failures — confirm.
func (s *LockStore) ping(ctx context.Context, lock *Lock) error {
	ctx, cancel := context.WithDeadline(ctx, lock.safeEndTime())
	defer cancel()
	// now is recorded inside the transaction (as the update uses it) but only
	// published to lock once the whole transaction has succeeded.
	var now int64
	if err := zsql.WithTxContext(ctx, s.db, nil, func(tx *sql.Tx) error {
		now = time.Now().Unix()
		values := []interface{}{now, lock.ID, lock.PingTime}
		if lock.HostID.has() {
			values = append(values, lock.HostID)
		}
		result, err := tx.Exec(
			fmt.Sprintf(
				`UPDATE %q SET "ping_time" = $1`+
					` WHERE "id" = $2 AND "ping_time" = $3`+
					` AND "host_id" %s`,
				s.table, lock.HostID.eqOp(4),
			),
			values...,
		)
		if err != nil {
			return err
		}
		affected, err := result.RowsAffected()
		if err != nil {
			return err
		}
		if affected != 1 {
			return ErrLockReleased
		}
		return nil
	}); err != nil {
		return err
	}
	lock.PingTime = now
	return nil
}

// NewLockStore creates a new instance of lock store.
//
// db holds the lock table, table is its name (interpolated into SQL with %q)
// and logger receives the diagnostics emitted by WithLock.
func NewLockStore(
	db *gosql.DB, table string, logger log.Logger,
) *LockStore {
	store := new(LockStore)
	store.db = db
	store.table = table
	store.logger = logger
	return store
}
