package db

import (
	"context"
	"database/sql"
	"fmt"

	"code.justin.tv/cb/roster/internal/postgres"
	"github.com/pkg/errors"
	log "github.com/sirupsen/logrus"
)

// DisplayPositioner maintains the display_order column of team-scoped rows in
// a database table.
//
// Normalize removes gaps between a team's display positions using a
// transaction supplied by the caller. Update changes the display position of
// one channel within a team and adjusts the other channels to match.
type DisplayPositioner interface {
	Normalize(ctx context.Context, tx *sql.Tx, table, teamID string) error
	Update(ctx context.Context, table, teamID, channelID string, desiredPosition uint) error
}

// displayPositionHelper contains functions to change display position of
// channels within a team and maintain consistency of other channels.
// Its Update and Normalize methods match the DisplayPositioner interface.
type displayPositionHelper struct {
	// db is used by Update to begin the transaction all of its statements run in.
	db *postgres.DB
}

// Update moves the row for channelID among teamID's rows of table to
// desiredPosition and shifts the rows between its old and new positions by
// one so the ordering stays contiguous. Positions are normalized to 0..n-1
// first, and a desiredPosition beyond the last position is clamped to it.
//
// All statements run in a single transaction that is rolled back on any
// failure. sql.ErrNoRows is returned unwrapped when the channel has no row in
// the team; other failures are wrapped with context.
func (h *displayPositionHelper) Update(ctx context.Context, table, teamID, channelID string, desiredPosition uint) error {
	tx, err := h.db.BeginTx(ctx, nil)
	if err != nil {
		return errors.Wrapf(err, "db: failed to begin transaction for updating row display position in %s", table)
	}

	// Roll back on every exit path, including panics and a failed Commit.
	// After Commit (successful or not) the transaction is finished and
	// Rollback reports sql.ErrTxDone, which is expected and not logged.
	defer func() {
		if rbErr := tx.Rollback(); rbErr != nil && rbErr != sql.ErrTxDone {
			log.WithError(rbErr).WithFields(log.Fields{
				"team_id":    teamID,
				"channel_id": channelID,
			}).Errorf("db: failed to rollback transaction for updating row display position in %s", table)
		}
	}()

	if err = h.Normalize(ctx, tx, table, teamID); err != nil {
		return err
	}

	rowID, currentPosition, err := idAndDisplayPosition(ctx, tx, table, teamID, channelID)
	if err != nil {
		return err
	}

	maxPosition, err := maxDisplayPositionInTeam(ctx, tx, table, teamID)
	if err != nil {
		return err
	}

	if desiredPosition > maxPosition {
		desiredPosition = maxPosition
	}

	// When the row is already where it should be, only the normalization
	// above needs to be committed.
	if desiredPosition != currentPosition {
		// Park the row one past the current maximum so its neighbours can shift
		// without landing on its old position.
		// NOTE(review): this is meant to avoid unique-constraint conflicts; a
		// non-deferrable unique index on (team_id, display_order) would still be
		// checked row by row during the bulk shift below — confirm the schema.
		if err = changeDisplayPosition(ctx, tx, table, rowID, maxPosition+1); err != nil {
			return err
		}

		if desiredPosition < currentPosition {
			err = incrementDisplayPositionsBetween(ctx, tx, table, teamID, desiredPosition, currentPosition)
		} else {
			err = decrementDisplayPositionsBetween(ctx, tx, table, teamID, currentPosition, desiredPosition)
		}
		if err != nil {
			return err
		}

		if err = changeDisplayPosition(ctx, tx, table, rowID, desiredPosition); err != nil {
			return err
		}
	}

	if err = tx.Commit(); err != nil {
		return errors.Wrapf(err, "db: failed to commit transaction for updating row display position in %s", table)
	}

	return nil
}

// Normalize eliminates gaps in channel display orders: every row of table
// belonging to teamID is renumbered to the contiguous sequence 0..n-1 while
// keeping its relative order. It runs on the caller's transaction tx.
//
// Rows that share a display_order are ordered by id, so repeated calls
// produce the same result. Rows already holding their target value are left
// untouched, avoiding needless row rewrites.
func (h *displayPositionHelper) Normalize(ctx context.Context, tx *sql.Tx, table, teamID string) error {
	statement := fmt.Sprintf(`
		UPDATE %[1]s
		SET display_order = normalized.position - 1
		FROM (
			SELECT
				id,
				ROW_NUMBER() OVER (ORDER BY display_order ASC, id ASC) AS position
			FROM %[1]s
			WHERE team_id = $1
		) AS normalized
		WHERE %[1]s.id = normalized.id
		AND %[1]s.display_order IS DISTINCT FROM normalized.position - 1
	`, table)

	if _, err := tx.ExecContext(ctx, statement, teamID); err != nil {
		return errors.Wrapf(err, "db: failed to update display positions for normalization in %s", table)
	}

	return nil
}

// idAndDisplayPosition looks up, through tx, the row of table whose team_id
// is teamID and whose user_id is channelID, and returns its id and
// display_order. A missing row is reported as the unwrapped sql.ErrNoRows so
// callers can compare against it directly; other errors are wrapped.
func idAndDisplayPosition(ctx context.Context, tx *sql.Tx, table, teamID, channelID string) (string, uint, error) {
	query := fmt.Sprintf(`
		SELECT id, display_order
		FROM %s
		WHERE team_id = $1
		AND user_id = $2
	`, table)

	var (
		rowID    string
		position uint
	)

	scanErr := tx.QueryRowContext(ctx, query, teamID, channelID).Scan(&rowID, &position)
	if scanErr == sql.ErrNoRows {
		return "", 0, sql.ErrNoRows
	}
	if scanErr != nil {
		return "", 0, errors.Wrapf(scanErr, "db: failed to scan row from %s for display position", table)
	}

	return rowID, position, nil
}

// changeDisplayPosition sets display_order to desiredPosition on the row of
// table whose id is membershipID, using the caller's transaction tx.
func changeDisplayPosition(ctx context.Context, tx *sql.Tx, table, membershipID string, desiredPosition uint) error {
	query := fmt.Sprintf(`
		UPDATE %s
		SET display_order = $1
		WHERE id = $2
	`, table)

	if _, execErr := tx.ExecContext(ctx, query, desiredPosition, membershipID); execErr != nil {
		return errors.Wrapf(execErr, "db: failed to update row display position in %s", table)
	}
	return nil
}

// incrementDisplayPositionsBetween adds one to display_order for every row of
// teamID in table whose display_order lies in the half-open range
// [desired, current). Update uses it when a row moves to an earlier position.
//
// NOTE(review): a non-deferrable unique index on (team_id, display_order)
// would be checked row by row during this UPDATE and could reject it — confirm
// the schema.
func incrementDisplayPositionsBetween(ctx context.Context, tx *sql.Tx, table, teamID string, desired, current uint) error {
	query := fmt.Sprintf(`
		UPDATE %s
		SET display_order = display_order + 1
		WHERE display_order >= $1
		AND display_order < $2
		AND team_id = $3
	`, table)

	_, execErr := tx.ExecContext(ctx, query, desired, current, teamID)
	if execErr == nil {
		return nil
	}
	return errors.Wrapf(execErr, "db: failed to increment display positions of surrounding rows in %s", table)
}

// decrementDisplayPositionsBetween subtracts one from display_order for every
// row of teamID in table whose display_order lies in the half-open range
// (current, desired]. Update uses it when a row moves to a later position.
//
// NOTE(review): a non-deferrable unique index on (team_id, display_order)
// would be checked row by row during this UPDATE and could reject it — confirm
// the schema.
func decrementDisplayPositionsBetween(ctx context.Context, tx *sql.Tx, table, teamID string, current, desired uint) error {
	query := fmt.Sprintf(`
		UPDATE %s
		SET display_order = display_order - 1
		WHERE display_order > $1
		AND display_order <= $2
		AND team_id = $3
	`, table)

	_, execErr := tx.ExecContext(ctx, query, current, desired, teamID)
	if execErr == nil {
		return nil
	}
	return errors.Wrapf(execErr, "db: failed to decrement display positions of surrounding rows in %s", table)
}

// maxDisplayPositionInTeam returns the largest display_order among the rows of
// table that belong to teamID, read through the caller's transaction tx.
//
// MAX over zero rows yields NULL, which cannot be scanned into a uint, so the
// value is read into a sql.NullInt64. A team with no rows is reported as the
// unwrapped sql.ErrNoRows rather than an opaque conversion error. A negative
// maximum, which a uint cannot represent, is rejected explicitly.
func maxDisplayPositionInTeam(ctx context.Context, tx *sql.Tx, table, teamID string) (uint, error) {
	statement := fmt.Sprintf("SELECT MAX(display_order) FROM %s WHERE team_id = $1", table)

	var maxDisplayPosition sql.NullInt64

	if err := tx.QueryRowContext(ctx, statement, teamID).Scan(&maxDisplayPosition); err != nil {
		return 0, errors.Wrapf(err, "db: failed to scan for maximum display order from %s", table)
	}

	if !maxDisplayPosition.Valid {
		return 0, sql.ErrNoRows
	}

	if maxDisplayPosition.Int64 < 0 {
		return 0, errors.Errorf("db: negative maximum display order %d in %s", maxDisplayPosition.Int64, table)
	}

	return uint(maxDisplayPosition.Int64), nil
}
