package leviathan

import (
	"context"
	"time"

	"code.justin.tv/safety/datastore/models"

	"github.com/Masterminds/squirrel"
	"github.com/pkg/errors"
)

// CreatePartnerEscalationOverrides upserts one row per user ID into the partner
// escalation list with the given override value.
//
// Users that are not yet on the list are inserted with created_at and
// updated_at set to the current time. For users that are already on the list
// (duplicate key), the ON DUPLICATE KEY UPDATE suffix overwrites override and
// updated_at and leaves the other columns untouched.
//
// An empty userIDs slice is a no-op and returns (0, nil).
//
// The returned count is whatever the driver reports via RowsAffected.
// NOTE(review): on MySQL an upsert reports 1 per inserted row and 2 per updated
// row, so this is presumably not a plain "number of users" — confirm before
// relying on it.
func (t *Transaction) CreatePartnerEscalationOverrides(ctx context.Context, userIDs []string, override *bool) (int64, error) {
	// squirrel cannot build an INSERT without at least one row of values, so an
	// empty input would otherwise surface as a confusing SQL conversion error.
	if len(userIDs) == 0 {
		return 0, nil
	}

	// A single timestamp is shared by every row and by the update clause so the
	// whole batch is stamped consistently.
	now := time.Now()

	// The zero-value struct is only used to obtain the column list.
	columns, _ := toColumnsAndValues(models.EscalationPartner{})
	q := squirrel.Insert(tablePartnerEscalationList).Columns(columns...)
	for _, userID := range userIDs {
		_, values := toColumnsAndValues(models.EscalationPartner{
			UserID:    userID,
			Override:  override,
			CreatedAt: now,
			UpdatedAt: now,
		})
		q = q.Values(values...)
	}

	q = q.Suffix("ON DUPLICATE KEY UPDATE override = ?, updated_at = ?", override, now)

	sql, args, err := q.ToSql()
	if err != nil {
		return 0, errors.Wrap(err, msgSQLConversion)
	}

	res, err := t.tx.ExecContext(ctx, sql, args...)
	if err != nil {
		return 0, errors.Wrap(err, msgInsertContext)
	}

	affected, err := res.RowsAffected()
	if err != nil {
		return affected, errors.Wrap(err, msgRetrieveRowsAffected)
	}
	return affected, nil
}

// DeletePartnerEscalationEntries removes the rows whose user_id matches one of
// userIDs from the partner escalation list. The override column is not part of
// the filter, so entries are removed whether or not they carry an override.
//
// It returns the number of rows the driver reports as affected.
func (t *Transaction) DeletePartnerEscalationEntries(ctx context.Context, userIDs []string) (int64, error) {
	stmt, params, err := squirrel.Delete(tablePartnerEscalationList).
		Where(squirrel.Eq{"user_id": userIDs}).
		ToSql()
	if err != nil {
		return 0, errors.Wrap(err, msgSQLConversion)
	}

	outcome, err := t.tx.ExecContext(ctx, stmt, params...)
	if err != nil {
		return 0, errors.Wrap(err, msgDeleteContext)
	}

	deleted, err := outcome.RowsAffected()
	if err != nil {
		return deleted, errors.Wrap(err, msgRetrieveRowsAffected)
	}
	return deleted, nil
}

// ReplacePartnerEscalationList swaps the batch-added part of the partner
// escalation list for the given users.
//
// It runs two statements on the transaction:
//  1. delete every row whose override is NULL;
//  2. insert one row per user ID with a NULL override, using INSERT IGNORE.
//
// NOTE(review): INSERT IGNORE presumably skips users that still have a row
// after step 1 (i.e. those with an explicit override) — confirm against the
// table's unique key.
//
// The returned ReplacementInfo carries the affected-row counts of the delete
// (Removed) and the insert (Created). It is never nil; on error it holds
// whatever counts were collected before the failure. An empty userIDs slice
// is a no-op: nothing is deleted and a zero-valued ReplacementInfo is returned.
func (t *Transaction) ReplacePartnerEscalationList(ctx context.Context, userIDs []string) (*models.ReplacementInfo, error) {
	info := &models.ReplacementInfo{}
	if len(userIDs) == 0 {
		return info, nil
	}

	// Step 1: clear out every entry that has no override set.
	purgeSQL, purgeArgs, err := squirrel.Delete(tablePartnerEscalationList).
		Where(squirrel.Eq{"override": nil}).
		ToSql()
	if err != nil {
		return info, errors.Wrap(err, msgSQLConversion)
	}

	purged, err := t.tx.ExecContext(ctx, purgeSQL, purgeArgs...)
	if err != nil {
		return info, errors.Wrap(err, msgDeleteContext)
	}
	if info.Removed, err = purged.RowsAffected(); err != nil {
		return info, errors.Wrap(err, msgRetrieveRowsAffected)
	}

	// Step 2: add the new batch; every row shares one timestamp.
	stamp := time.Now()
	cols, _ := toColumnsAndValues(models.EscalationPartner{})
	batch := squirrel.Insert(tablePartnerEscalationList).Options("IGNORE").Columns(cols...)
	for i := range userIDs {
		_, row := toColumnsAndValues(models.EscalationPartner{
			UserID:    userIDs[i],
			Override:  nil,
			CreatedAt: stamp,
			UpdatedAt: stamp,
		})
		batch = batch.Values(row...)
	}

	insertSQL, insertArgs, err := batch.ToSql()
	if err != nil {
		return info, errors.Wrap(err, msgSQLConversion)
	}

	inserted, err := t.tx.ExecContext(ctx, insertSQL, insertArgs...)
	if err != nil {
		return info, errors.Wrap(err, msgInsertContext)
	}
	if info.Created, err = inserted.RowsAffected(); err != nil {
		return info, errors.Wrap(err, msgRetrieveRowsAffected)
	}

	return info, nil
}

// EscalationPartnersByID loads the partner escalation list rows whose user_id
// matches one of userIDs. Users without a row are simply absent from the
// returned slice.
func (t *Transaction) EscalationPartnersByID(ctx context.Context, userIDs []string) ([]*models.EscalationPartner, error) {
	query, params, err := squirrel.Select("*").
		From(tablePartnerEscalationList).
		Where(squirrel.Eq{"user_id": userIDs}).
		ToSql()
	if err != nil {
		return nil, errors.Wrap(err, msgSQLConversion)
	}

	var partners []*models.EscalationPartner
	if err = t.tx.SelectContext(ctx, &partners, query, params...); err != nil {
		return partners, errors.Wrap(err, msgSelectContext)
	}
	return partners, nil
}

// EscalationPartnersPage returns one page of the partner escalation list.
// Rows are ordered by user_id; offset rows are skipped and at most limit rows
// are returned.
func (t *Transaction) EscalationPartnersPage(ctx context.Context, limit uint64, offset uint64) ([]*models.EscalationPartner, error) {
	page := squirrel.Select("*").
		From(tablePartnerEscalationList).
		OrderBy("user_id").
		Limit(limit).
		Offset(offset)

	query, params, err := page.ToSql()
	if err != nil {
		return nil, errors.Wrap(err, msgSQLConversion)
	}

	var partners []*models.EscalationPartner
	if err = t.tx.SelectContext(ctx, &partners, query, params...); err != nil {
		return partners, errors.Wrap(err, msgSelectContext)
	}
	return partners, nil
}

// EscalationPartnersCount returns the total number of rows in the partner
// escalation list.
//
// It runs SELECT COUNT(1) AS total against the table and reads the value from
// the Total field of the first scanned row. If the query unexpectedly yields
// no rows, an error is returned instead of panicking on an out-of-range index.
func (t *Transaction) EscalationPartnersCount(ctx context.Context) (int64, error) {
	q := squirrel.Select("COUNT(1) AS total").From(tablePartnerEscalationList)

	sql, args, err := q.ToSql()
	if err != nil {
		return 0, errors.Wrap(err, msgSQLConversion)
	}

	var result []*models.PageInfo
	if err = t.tx.SelectContext(ctx, &result, sql, args...); err != nil {
		return 0, errors.Wrap(err, msgSelectContext)
	}

	// An aggregate without GROUP BY should always produce exactly one row, but
	// guard the index so a misbehaving driver or mock cannot crash the caller.
	if len(result) == 0 || result[0] == nil {
		return 0, errors.New("partner escalation count query returned no rows")
	}
	return result[0].Total, nil
}
