package postgres

import (
	"context"
	"database/sql"
	"fmt"
	"strconv"
	"time"

	"github.com/lib/pq"
	"go.uber.org/zap/zapcore"

	db "code.justin.tv/eventbus/controlplane/internal/db"
	dbErr "code.justin.tv/eventbus/controlplane/internal/db"
	"github.com/jmoiron/sqlx"
)

// LeasesTableName is the table holding lease records (it has at least the
// columns id and expires_at). leaseCreateQuery inserts a lease whose
// expires_at is $1 and returns its generated id; leaseDeleteQuery deletes
// the lease whose id is $1.
//
// NOTE(review): both queries hard-code aws_leases instead of using
// LeasesTableName; keep them in sync if the table is ever renamed.
const (
	LeasesTableName  = `aws_leases`
	leaseCreateQuery = `
INSERT INTO aws_leases (expires_at)
VALUES ($1)
RETURNING id`

	// leaseDeleteQuery removes a single lease row by id ($1).
	leaseDeleteQuery = `
DELETE FROM aws_leases
WHERE id = $1`
)

// PostgresAWSLease is a time-limited lease on one row (resourceID) of a
// resource table (table). The lease is recorded as a row in aws_leases,
// and the resource row points at it through its aws_lease_id column.
// Leases are created by AcquirePGLease; Release detaches and deletes the
// lease row and always cancels the lease context.
//
// Fields:
//   - AWSLeaseID: id of the lease's row in aws_leases.
//   - ExpiresAt: expiry time reported by Expires and checked by Expired.
//   - ctx, cancelFn: the deadline-bounded lease context created in
//     AcquirePGLease and its cancel function; Release runs its queries
//     under ctx and defers cancelFn.
//   - db: the connection pool Release opens its transaction on.
//   - resourceID, table: identify the leased resource row.
type PostgresAWSLease struct {
	AWSLeaseID int
	ExpiresAt  time.Time

	ctx        context.Context
	cancelFn   context.CancelFunc
	db         *sqlx.DB
	resourceID int
	table      string
}

// AcquirePGLease tries to take an exclusive, time-limited lease on the row
// with id resourceID in table. It runs a single transaction that:
//   - reads the row's current aws_lease_id and that lease's expires_at;
//   - returns dbErr.ErrLeaseUnavailable if that lease has not yet expired;
//   - otherwise detaches and deletes an expired lease, inserts a new lease
//     row expiring at now+timeout, and attaches it to the resource row.
//
// The attach is guarded by "aws_lease_id IS NULL", so if another writer
// attached a lease concurrently, zero rows are affected and
// dbErr.ErrLeaseUnavailable is returned.
//
// The returned context is derived from parent with the lease deadline. On
// success it is owned by the returned lease (Release cancels it); on error
// it has already been canceled. Errors from the database are returned
// unwrapped (e.g. sql.ErrNoRows if the resource row does not exist).
//
// table is interpolated into SQL with fmt.Sprintf, so it must be a trusted
// identifier and never caller-supplied input.
func (pg *PostgresDB) AcquirePGLease(parent context.Context, resourceID int, table string, timeout time.Duration) (*PostgresAWSLease, context.Context, error) {
	deadline := time.Now().Add(timeout)
	ctxLease, cancel := context.WithDeadline(parent, deadline)

	// cancel is handed to the returned lease on success. On every failure
	// path it is called here so the derived context is released.
	acquired := false
	defer func() {
		if !acquired {
			cancel()
		}
	}()

	tx, err := pg.writer.BeginTxx(ctxLease, nil)
	if err != nil {
		return nil, ctxLease, err
	}
	// Roll back explicitly on every early return rather than relying on
	// context cancellation to do it. After a successful Commit this is a
	// no-op that returns sql.ErrTxDone, so the error is deliberately
	// ignored. Deferred after the cancel defer, so it runs before cancel.
	defer func() { _ = tx.Rollback() }()

	var rowID int
	var existingLeaseID sql.NullInt64
	var expiresAt pq.NullTime
	getLeaseQuery := fmt.Sprintf("SELECT t1.id, t1.aws_lease_id, l.expires_at FROM %s AS t1 LEFT JOIN %s AS l ON l.id = t1.aws_lease_id WHERE t1.id = $1", table, LeasesTableName)
	if err = tx.QueryRowxContext(ctxLease, getLeaseQuery, resourceID).Scan(&rowID, &existingLeaseID, &expiresAt); err != nil {
		return nil, ctxLease, err
	}

	// NOTE(review): if aws_lease_id is set but the LEFT JOIN found no lease
	// row (expiresAt invalid), nothing is cleaned up here and the guarded
	// attach below reports ErrLeaseUnavailable. Confirm a foreign key makes
	// this state impossible.
	if existingLeaseID.Valid && expiresAt.Valid {
		if time.Now().Before(expiresAt.Time) {
			return nil, ctxLease, dbErr.ErrLeaseUnavailable
		}

		// The current lease has expired. Detach it from the resource, but only
		// if it is still the one attached, and then delete its row.
		clearQuery := fmt.Sprintf("UPDATE %s SET aws_lease_id = NULL WHERE id = $1 AND aws_lease_id = $2", table)
		if _, err = tx.ExecContext(ctxLease, clearQuery, rowID, existingLeaseID.Int64); err != nil {
			return nil, ctxLease, err
		}
		if _, err = tx.ExecContext(ctxLease, leaseDeleteQuery, existingLeaseID.Int64); err != nil {
			return nil, ctxLease, err
		}
	}

	var leaseID int
	if err = tx.QueryRowxContext(ctxLease, leaseCreateQuery, deadline).Scan(&leaseID); err != nil {
		return nil, ctxLease, err
	}

	// The IS NULL guard makes this attach a no-op when aws_lease_id is
	// non-NULL, i.e. someone else is holding the lease.
	attachQuery := fmt.Sprintf("UPDATE %s SET aws_lease_id = $1 WHERE id = $2 and aws_lease_id IS NULL", table)
	res, err := tx.ExecContext(ctxLease, attachQuery, leaseID, resourceID)
	if err != nil {
		return nil, ctxLease, err
	}
	numRows, err := res.RowsAffected()
	if err != nil {
		return nil, ctxLease, err
	}
	if numRows == 0 {
		return nil, ctxLease, dbErr.ErrLeaseUnavailable
	}

	if err = tx.Commit(); err != nil {
		return nil, ctxLease, err
	}

	acquired = true
	return &PostgresAWSLease{
		AWSLeaseID: leaseID,
		// Use the same instant that was stored in aws_leases.expires_at and
		// set as ctxLease's deadline. Recomputing time.Now().Add(timeout)
		// here would make Expired() lag behind the database and the context
		// by however long the transaction took.
		ExpiresAt:  deadline,
		ctx:        ctxLease,
		cancelFn:   cancel,
		db:         pg.writer,
		resourceID: resourceID,
		table:      table,
	}, ctxLease, nil
}

// Acquire adapts AcquirePGLease to the db.AWSLease interface.
//
// When no lease was obtained, it returns a literal nil interface.
// Returning the *PostgresAWSLease directly would wrap a nil pointer in a
// non-nil db.AWSLease, which defeats callers' `lease == nil` checks.
func (pg *PostgresDB) Acquire(parent context.Context, resourceID int, table string, timeout time.Duration) (db.AWSLease, context.Context, error) {
	pgLease, leaseCtx, err := pg.AcquirePGLease(parent, resourceID, table, timeout)

	var lease db.AWSLease // stays a nil interface unless a lease exists
	if pgLease != nil {
		lease = pgLease
	}
	return lease, leaseCtx, err
}

// Release ends the lease. In one transaction it detaches the lease from the
// resource row and deletes the lease's aws_leases row. The lease context is
// always canceled on return, whether or not the release succeeded.
//
// The detach only clears aws_lease_id if it still points at this lease.
// If this lease had already expired and another holder attached its own
// lease, that lease is left untouched.
//
// Release runs its queries under the lease context, so it fails once that
// context is done (for example after the lease deadline). In that case the
// stale lease is only cleaned up when the resource is next acquired, since
// AcquirePGLease deletes expired leases.
func (l *PostgresAWSLease) Release() error {
	defer l.cancelFn()

	tx, err := l.db.BeginTxx(l.ctx, nil)
	if err != nil {
		return err
	}
	// Roll back on any early return. After a successful Commit this is a
	// no-op that returns sql.ErrTxDone, which is deliberately ignored.
	// Deferred after cancelFn, so it runs first.
	defer func() { _ = tx.Rollback() }()

	resourceQuery := fmt.Sprintf("UPDATE %s SET aws_lease_id = NULL WHERE id = $1 AND aws_lease_id = $2", l.table)
	if _, err = tx.ExecContext(l.ctx, resourceQuery, l.resourceID, l.AWSLeaseID); err != nil {
		return err
	}

	if _, err = tx.ExecContext(l.ctx, leaseDeleteQuery, l.AWSLeaseID); err != nil {
		return err
	}

	return tx.Commit()
}

// Expires returns the lease's recorded expiry time (its ExpiresAt field).
func (l *PostgresAWSLease) Expires() time.Time {
	expiry := l.ExpiresAt
	return expiry
}

// Expired reports whether ExpiresAt lies strictly before the current time.
func (l *PostgresAWSLease) Expired() bool {
	return l.ExpiresAt.Before(time.Now())
}

// String renders the lease's AWSLeaseID as a base-10 integer. It is also
// the value logged under "leaseID" by MarshalLogObject.
func (l *PostgresAWSLease) String() string {
	return strconv.FormatInt(int64(l.AWSLeaseID), 10)
}

// MarshalLogObject encodes the lease for zap structured logging. Its
// signature matches zapcore.ObjectMarshaler. It writes expiresAt, expired
// and leaseID fields. A nil lease is encoded as a single placeholder field
// instead of being dereferenced.
func (l *PostgresAWSLease) MarshalLogObject(enc zapcore.ObjectEncoder) error {
	if l == nil {
		enc.AddString("<nil lease>", "<nil>")
		// Without this return the accessors below would dereference the
		// nil receiver and panic.
		return nil
	}
	enc.AddTime("expiresAt", l.Expires())
	enc.AddBool("expired", l.Expired())
	enc.AddString("leaseID", l.String())
	return nil
}
