package stores

import (
	"context"
	"database/sql"
	"fmt"

	"a.yandex-team.ru/drive/library/go/gosql"
)

// txKey is the private context key under which WithTx stores a transaction.
// Being an unexported type, it cannot collide with keys set by other packages.
type txKey struct{}

// WithTx returns a child of ctx that carries tx. Store methods that resolve
// their runner through getRunner execute their statements on tx instead of
// the connection pool.
func WithTx(ctx context.Context, tx *sql.Tx) context.Context {
	var key txKey
	return context.WithValue(ctx, key, tx)
}

// GetTx returns the transaction attached to ctx by WithTx, or nil when ctx
// carries none.
func GetTx(ctx context.Context) *sql.Tx {
	// A failed assertion yields the zero value, i.e. a nil *sql.Tx.
	tx, _ := ctx.Value(txKey{}).(*sql.Tx)
	return tx
}

// RowReader iterates over query results decoded into values of type T.
//
// Typical use: call Next until it returns false, read each value with Row,
// then check Err and release resources with Close.
type RowReader[T any] interface {
	// Next advances to the following row and reports whether a row was
	// successfully read.
	Next() bool
	// Row returns the current row value.
	Row() T
	// Err returns the error, if any, that stopped iteration.
	Err() error
	// Close releases the underlying result set.
	Close() error
}

// ObjectPtr constrains a type parameter to be exactly *T while also exposing
// ID accessors. This lets generic code hold a T by value and still call the
// pointer-receiver methods ObjectID/SetObjectID through PT(&value).
type ObjectPtr[T any, ID any] interface {
	*T
	// ObjectID returns the object's identifier.
	ObjectID() ID
	// SetObjectID assigns the object's identifier.
	SetObjectID(ID)
}

// SelectOption customizes the SELECT query built by SelectObjects: it receives
// the query built so far and returns the modified query.
type SelectOption func(gosql.SelectQuery) gosql.SelectQuery

// WithWhere returns an option that passes where to SelectQuery.Where.
// NOTE(review): whether repeated WithWhere options are AND-ed or replace each
// other depends on gosql's Where implementation — confirm before stacking them.
func WithWhere(where gosql.BoolExpr) SelectOption {
	apply := func(q gosql.SelectQuery) gosql.SelectQuery {
		return q.Where(where)
	}
	return apply
}

// WithLimit returns an option that passes limit to SelectQuery.Limit.
func WithLimit(limit int) SelectOption {
	apply := func(q gosql.SelectQuery) gosql.SelectQuery {
		return q.Limit(limit)
	}
	return apply
}

// ObjectStore is a generic CRUD store for objects of type T identified by
// values of type ID. PT is the pointer type *T carrying the ID accessors.
type ObjectStore[T any, ID any, PT ObjectPtr[T, ID]] interface {
	// CreateObject persists a new object; implementations may assign its ID.
	CreateObject(ctx context.Context, object PT) error
	// UpdateObject writes object's state over the stored object with the same ID.
	UpdateObject(ctx context.Context, object PT) error
	// DeleteObject removes the object identified by id.
	DeleteObject(ctx context.Context, id ID) error
	// SelectObjects returns a reader over stored objects, shaped by options.
	// The caller is responsible for closing the returned reader.
	SelectObjects(ctx context.Context, options ...SelectOption) (RowReader[T], error)
}

// simpleObjectStore implements SimpleObjectStore over a single table whose
// objects are identified by an int64 column.
type simpleObjectStore[T any, PT ObjectPtr[T, int64]] struct {
	// db is the database handle. Outside a transaction, writes go through
	// db.DB and reads through db.RO.
	db *gosql.DB
	// table and id name the backing table and its identifier column.
	table, id string
	// columns lists the column names derived from T, in the exact order
	// SelectObjects expects them in a result set.
	columns []string
}

// LockStore locks the store's table inside the transaction carried by ctx
// (see WithTx) and fails when ctx carries no transaction.
//
// Only the Postgres driver issues a statement (LOCK TABLE without a mode,
// which Postgres treats as ACCESS EXCLUSIVE, held until the transaction
// ends). For every other driver this is a no-op that returns nil.
func (s *simpleObjectStore[T, PT]) LockStore(ctx context.Context) error {
	tx := GetTx(ctx)
	if tx == nil {
		return fmt.Errorf("cannot lock store without transaction")
	}
	if s.db.Driver != gosql.PostgresDriver {
		return nil
	}
	// NOTE(review): %q applies Go string quoting, which matches SQL identifier
	// quoting only for plain table names (no embedded quotes or backslashes).
	statement := fmt.Sprintf("LOCK TABLE %q", s.table)
	_, err := tx.ExecContext(ctx, statement)
	return err
}

// CreateObject inserts object into the store's table and writes the id
// generated by the database back into object via SetObjectID.
//
// The INSERT is built from object's fields, passing s.id to
// gosql.StructNameValues (presumably so the id column is left out and the
// database assigns it). On Postgres the new id is read through a
// "RETURNING <id>" clause. On other drivers the statement must affect exactly
// one row, and the id is taken from sql.Result.LastInsertId.
func (s *simpleObjectStore[T, PT]) CreateObject(ctx context.Context, object PT) error {
	names, args := gosql.StructNameValues(object, false, s.id)
	insert := s.db.Insert(s.table).Names(names...).Values(args...)
	runner := getRunner(ctx, s.db.DB)
	if s.db.Driver == gosql.PostgresDriver {
		// Postgres drivers generally don't support LastInsertId, so the id is
		// requested explicitly. NOTE(review): %q is Go quoting, which is valid
		// SQL identifier quoting only for plain column names.
		query := insert.String() + fmt.Sprintf(" RETURNING %q", s.id)
		var newID int64
		if err := runner.QueryRowContext(ctx, query, args...).Scan(&newID); err != nil {
			return err
		}
		object.SetObjectID(newID)
		return nil
	}
	result, err := runner.ExecContext(ctx, insert.String(), args...)
	if err != nil {
		return err
	}
	affected, err := result.RowsAffected()
	if err != nil {
		return err
	}
	if affected != 1 {
		return fmt.Errorf("invalid amount of affected rows: %d", affected)
	}
	newID, err := result.LastInsertId()
	if err != nil {
		return err
	}
	object.SetObjectID(newID)
	return nil
}

// UpdateObject overwrites the row whose id column equals object.ObjectID()
// with object's current field values (s.id is passed to StructNameValues,
// presumably to keep the id column out of the SET list).
//
// It returns sql.ErrNoRows when no row was affected and an error when more
// than one row was. In the latter case the UPDATE has already executed, so a
// surrounding transaction should be rolled back.
//
// NOTE(review): some drivers (e.g. MySQL without CLIENT_FOUND_ROWS) count
// only rows whose values actually changed, so writing identical values may
// surface as sql.ErrNoRows — confirm for the drivers in use.
func (s *simpleObjectStore[T, PT]) UpdateObject(ctx context.Context, object PT) error {
	names, args := gosql.StructNameValues(object, false, s.id)
	update := s.db.Update(s.table).Names(names...).Values(args...)
	query, params := update.Where(gosql.Column(s.id).Equal(object.ObjectID())).Build()
	result, err := getRunner(ctx, s.db.DB).ExecContext(ctx, query, params...)
	if err != nil {
		return err
	}
	affected, err := result.RowsAffected()
	if err != nil {
		return err
	}
	switch {
	case affected < 1:
		return sql.ErrNoRows
	case affected > 1:
		return fmt.Errorf("updated %d objects", affected)
	default:
		return nil
	}
}

// DeleteObject removes the row whose id column equals id.
//
// It returns sql.ErrNoRows when nothing was deleted and an error when more
// than one row was. In the latter case the DELETE has already executed, so a
// surrounding transaction should be rolled back.
func (s *simpleObjectStore[T, PT]) DeleteObject(ctx context.Context, id int64) error {
	query, params := s.db.Delete(s.table).
		Where(gosql.Column(s.id).Equal(id)).
		Build()
	result, err := getRunner(ctx, s.db.DB).ExecContext(ctx, query, params...)
	if err != nil {
		return err
	}
	affected, err := result.RowsAffected()
	if err != nil {
		return err
	}
	switch {
	case affected < 1:
		return sql.ErrNoRows
	case affected > 1:
		return fmt.Errorf("deleted %d objects", affected)
	default:
		return nil
	}
}

// SelectObjects queries the store's table and returns a reader that decodes
// each row into a T.
//
// The base query selects s.columns ordered by the id column. Options are
// applied to it in order, and nil options are skipped. The query runs on the
// transaction carried by ctx when present, otherwise on s.db.RO (presumably a
// read-only pool).
//
// The result's column list must match s.columns exactly. If it does not, the
// rows are closed and an error is returned. On success the caller owns the
// returned reader and must Close it.
func (s *simpleObjectStore[T, PT]) SelectObjects(ctx context.Context, options ...SelectOption) (RowReader[T], error) {
	builder := s.db.Select(s.table).Names(s.columns...).OrderBy(s.id)
	for _, option := range options {
		if option == nil {
			continue
		}
		builder = option(builder)
	}
	query, values := builder.Build()
	rows, err := getRunner(ctx, s.db.RO).QueryContext(ctx, query, values...)
	if err != nil {
		return nil, err
	}
	if err := checkColumns(rows, s.columns); err != nil {
		// Nobody else holds rows yet; an unclosed *sql.Rows keeps its
		// connection checked out of the pool. The column mismatch is the
		// error worth reporting, so the Close error is dropped.
		_ = rows.Close()
		return nil, err
	}
	return newRowReader[T](rows), nil
}

// DB returns the database handle the store was created with.
func (s *simpleObjectStore[T, PT]) DB() *gosql.DB {
	return s.db
}

// Table returns the name of the table backing the store.
func (s *simpleObjectStore[T, PT]) Table() string {
	return s.table
}

// IDColumn returns the name of the column holding object ids.
func (s *simpleObjectStore[T, PT]) IDColumn() string {
	return s.id
}

// SimpleObjectStore is an ObjectStore keyed by int64 ids and backed by a
// single database table, exposing the table metadata it was built with.
type SimpleObjectStore[T any, PT ObjectPtr[T, int64]] interface {
	ObjectStore[T, int64, PT]
	// LockStore locks the store's table using the transaction carried by ctx
	// (see WithTx). The implementation from NewSimpleObjectStore fails when
	// ctx has no transaction.
	LockStore(ctx context.Context) error
	// DB returns the database connection pool.
	DB() *gosql.DB
	// Table returns the name of the database table.
	Table() string
	// IDColumn returns the name of the id column.
	IDColumn() string
}

// NewSimpleObjectStore returns a SimpleObjectStore over table in db, where id
// names the int64 identifier column. The selected column list is computed
// once, from the zero value of T, via gosql.StructNames.
func NewSimpleObjectStore[T any, PT ObjectPtr[T, int64]](db *gosql.DB, table, id string) SimpleObjectStore[T, PT] {
	var zero T
	columns := gosql.StructNames(zero)
	store := &simpleObjectStore[T, PT]{
		db:      db,
		table:   table,
		id:      id,
		columns: columns,
	}
	return store
}

// ScanRow reads exactly one row from r into *row and always closes r.
//
// It returns:
//   - sql.ErrNoRows when r yields no rows;
//   - errTooManyRows when r yields more than one row;
//   - any error r reports while reading, including one raised while probing
//     for a second row.
//
// *row may already have been overwritten when an error is returned. The
// error from r.Close is ignored because the outcome of the read is already
// decided by then.
func ScanRow[T any](r RowReader[T], row *T) error {
	defer func() {
		_ = r.Close()
	}()
	if !r.Next() {
		if err := r.Err(); err != nil {
			return err
		}
		return sql.ErrNoRows
	}
	*row = r.Row()
	if r.Next() {
		return errTooManyRows
	}
	// Next also returns false when reading or decoding the following row
	// failed. Without this check, such an error would be reported as success.
	return r.Err()
}

// rowReader implements RowReader on top of *sql.Rows. Every row is scanned
// into the single reused value row through refs, the pointers into row's
// fields obtained from gosql.StructValues.
type rowReader[T any] struct {
	rows *sql.Rows
	// err holds the first Scan failure. Once set, Next keeps returning false.
	err error
	// row is the scan target, reused for every row.
	row T
	// refs contains pointers for each field in row.
	refs []any
}

// Next advances to the next row and scans it into r.row. It returns false
// when the result set is exhausted, when advancing failed (see Err), or when
// Scan failed.
//
// A Scan failure is sticky: later calls return false without touching the
// result set. Otherwise a subsequent successful Scan would overwrite r.err
// with nil and Err would lose the failure.
func (r *rowReader[T]) Next() bool {
	if r.err != nil {
		return false
	}
	if !r.rows.Next() {
		return false
	}
	r.err = r.rows.Scan(r.refs...)
	return r.err == nil
}

// Row returns the most recently scanned row (a shallow copy of r.row).
func (r *rowReader[T]) Row() T {
	return r.row
}

// Close closes the underlying *sql.Rows.
func (r *rowReader[T]) Close() error {
	return r.rows.Close()
}

// Err reports the error that ended iteration. An error from the underlying
// rows takes precedence over a Scan failure recorded by Next.
func (r *rowReader[T]) Err() error {
	if err := r.rows.Err(); err != nil {
		return err
	}
	return r.err
}

// newRowReader wraps rows in a rowReader. Scan assigns columns by position,
// so callers must already have checked that the result's columns line up
// with T's fields (SelectObjects does this via checkColumns).
func newRowReader[T any](rows *sql.Rows) *rowReader[T] {
	r := &rowReader[T]{rows: rows}
	// refs must point into the row field of the reader that is returned, not
	// into a temporary copy.
	r.refs = gosql.StructValues(&r.row, true)
	return r
}

// Sentinel errors shared by the helpers in this file.
var (
	// errInvalidColumns is returned by checkColumns when a result set's
	// columns differ from the expected list in count, order or name.
	errInvalidColumns = fmt.Errorf("result has invalid column sequence")
	// errTooManyRows is returned by ScanRow when more than one row is found.
	errTooManyRows = fmt.Errorf("too many rows")
)

// checkColumns verifies that rows reports exactly the columns in cols, in the
// same order and spelling. Any mismatch yields errInvalidColumns. An error
// from reading the column list is returned unchanged.
func checkColumns(rows *sql.Rows, cols []string) error {
	got, err := rows.Columns()
	if err != nil {
		return err
	}
	if len(got) != len(cols) {
		return errInvalidColumns
	}
	for i, want := range cols {
		if got[i] != want {
			return errInvalidColumns
		}
	}
	return nil
}

// getRunner picks where a statement runs. It uses the transaction attached to
// ctx by WithTx when there is one, and the fallback runner r otherwise.
func getRunner(ctx context.Context, r gosql.Runner) gosql.Runner {
	tx := GetTx(ctx)
	if tx == nil {
		return r
	}
	return tx
}
