package leviathan

import (
	"context"
	"time"

	"code.justin.tv/safety/datastore/models"
	"github.com/Masterminds/squirrel"
	"github.com/pkg/errors"
)

// CreateEnforcement inserts the given enforcement into the enforcements table.
//
// The enforcement is received by value; its CreatedAt and UpdatedAt fields are
// both overwritten with the same current timestamp before the insert statement
// is built. On success it returns a pointer to the ID reported by the
// database driver's LastInsertId for the executed statement.
func (t *Transaction) CreateEnforcement(ctx context.Context, enforcement models.Enforcement) (*int64, error) {
	timestamp := time.Now()
	enforcement.CreatedAt, enforcement.UpdatedAt = timestamp, timestamp

	query, params, err := toInsertStatement(tableEnforcements, enforcement)
	if err != nil {
		return nil, errors.Wrap(err, msgSQLConversion)
	}

	res, err := t.tx.ExecContext(ctx, query, params...)
	if err != nil {
		return nil, errors.Wrap(err, msgInsertContext)
	}

	insertedID, err := res.LastInsertId()
	if err != nil {
		return nil, errors.Wrap(err, msgRetrieveLastID)
	}
	return &insertedID, nil
}

// Enforcement looks up a single enforcement by id.
//
// It delegates to Enforcements with a one-element id list. The result is
// (nil, nil) when no row matches, the matching record when exactly one row
// matches, and an error when more than one row comes back for the id.
func (t *Transaction) Enforcement(ctx context.Context, id int64) (*models.Enforcement, error) {
	matches, err := t.Enforcements(ctx, []int64{id})
	if err != nil {
		return nil, err
	}

	switch len(matches) {
	case 0:
		// Not found is reported as a nil record without an error.
		return nil, nil
	case 1:
		return matches[0], nil
	default:
		return nil, errors.Errorf("Found more than one enforcement of id %d", id)
	}
}

// Enforcements selects every enforcement row whose id is in ids.
//
// Failures while building the SQL are wrapped with msgSQLConversion, and
// failures while running the select are wrapped with msgSelectContext.
func (t *Transaction) Enforcements(ctx context.Context, ids []int64) ([]*models.Enforcement, error) {
	query, params, err := squirrel.Select("*").
		From(tableEnforcements).
		Where(squirrel.Eq{"id": ids}).
		ToSql()
	if err != nil {
		return nil, errors.Wrap(err, msgSQLConversion)
	}

	var rows []*models.Enforcement
	if err = t.tx.SelectContext(ctx, &rows, query, params...); err != nil {
		return nil, errors.Wrap(err, msgSelectContext)
	}
	return rows, nil
}

// UpdateEnforcement writes the given enforcement back to the enforcements
// table.
//
// The enforcement is received by value; its UpdatedAt field is set to the
// current time before the update statement is built. The result of the exec
// call is discarded, so only build and execution errors are reported.
func (t *Transaction) UpdateEnforcement(ctx context.Context, enforcement models.Enforcement) error {
	enforcement.UpdatedAt = time.Now()

	query, params, err := toUpdateStatement(tableEnforcements, enforcement)
	if err != nil {
		return errors.Wrap(err, msgSQLConversion)
	}

	if _, err = t.tx.ExecContext(ctx, query, params...); err != nil {
		return errors.Wrap(err, msgUpdateContext)
	}
	return nil
}

// EnforcementPage fetches one page of enforcements together with page info.
//
// Two statements are executed in order:
//  1. a select over the enforcements table with the given limit and offset,
//     passed through sortBy and then filterBy;
//  2. a "count(*) as total" select over the same table, passed through
//     filterBy only, whose rows are scanned into models.PageInfo values.
//
// An error is returned if either statement fails to build or execute, or if
// the count statement does not yield exactly one row.
func (t *Transaction) EnforcementPage(
	ctx context.Context,
	filter *models.EnforcementFilter,
	sort *models.EnforcementSort,
	limit uint64,
	offset uint64) ([]*models.Enforcement, *models.PageInfo, error) {

	var err error

	// Page query: limit/offset first, then sort and filter helpers.
	pageQuery := squirrel.Select("*").From(tableEnforcements).Limit(limit).Offset(offset)
	if pageQuery, err = sortBy(pageQuery, sort); err != nil {
		return nil, nil, err
	}
	if pageQuery, err = filterBy(pageQuery, filter); err != nil {
		return nil, nil, err
	}

	query, params, err := pageQuery.ToSql()
	if err != nil {
		return nil, nil, errors.Wrap(err, msgSQLConversion)
	}

	var rows []*models.Enforcement
	if err = t.tx.SelectContext(ctx, &rows, query, params...); err != nil {
		return nil, nil, errors.Wrap(err, msgSelectContext)
	}

	// Count query: same filter, no sort/limit/offset.
	countQuery := squirrel.Select("count(*) as total").From(tableEnforcements)
	if countQuery, err = filterBy(countQuery, filter); err != nil {
		return nil, nil, err
	}

	query, params, err = countQuery.ToSql()
	if err != nil {
		return nil, nil, errors.Wrap(err, msgSQLConversion)
	}

	var totals []*models.PageInfo
	if err = t.tx.SelectContext(ctx, &totals, query, params...); err != nil {
		return nil, nil, errors.Wrap(err, msgSelectContext)
	}

	if len(totals) != 1 {
		return nil, nil, errors.New("Incorrect number of results returned for page info")
	}

	return rows, totals[0], nil
}
