package model

import (
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"time"

	"github.com/jmoiron/sqlx"
)

// Question is a row of the "questions" table, optionally enriched with its
// answers and with per-user state.
//
// The db-tagged fields are scanned from `SELECT * FROM "questions"` (see
// FindQuestion). ActiveFrom and ActiveUntil are either both nil (unscheduled)
// or both set when created through CreateQuestion, which stores them in UTC.
// NOTE(review): Responses is read from the "responses" column but never
// written in this file — presumably a response counter maintained elsewhere;
// confirm.
type Question struct {
	ID          int              `db:"id"            json:"id"`
	ChannelID   int              `db:"channel_id"    json:"channel_id"`
	Text        string           `db:"text"          json:"text"`
	ActiveFrom  *time.Time       `db:"active_from"   json:"active_from"`
	ActiveUntil *time.Time       `db:"active_until"  json:"active_until"`
	Responses   int              `db:"responses"     json:"responses"`
	Settings    QuestionSettings `db:"settings"      json:"settings"`

	// Populated by FindQuestion from the "answers" table, ordered by answer ID;
	// AnswerIDs[i] is always Answers[i].ID.
	AnswerIDs []int     `db:"-"  json:"answer_ids"`
	Answers   []*Answer `db:"-"  json:"answers"`

	// Only exists when a question lookup has a user context (`Question.ForUser()`):
	Answered   *bool       `db:"-"  json:"answered,omitempty"`
	UserAnswer *UserAnswer `db:"-"  json:"user_answer,omitempty"`

	// Database handle set by FindQuestion and used by methods such as Delete.
	db *DB
}

// QuestionSettings holds per-question options. It is persisted in the
// "settings" column as JSON (CreateQuestion inserts it with a `::jsonb` cast).
type QuestionSettings struct {
	EnableChat bool `json:"enable_chat"`
}

// Compile-time check that QuestionSettings can be scanned by database/sql.
var _ sql.Scanner = (*QuestionSettings)(nil)

// Scan implements the database/sql.Scanner interface (`lib/pq` doesn't know
// how to scan `JSONB` into a struct on its own).
//
// A NULL column (nil src) resets qs to its zero value. []byte and string
// sources are decoded as JSON into a fresh value, so keys absent from the
// JSON do not keep values from a previous scan; on a decode error qs is left
// unchanged. Any other source type is reported as an error rather than
// causing a panic.
func (qs *QuestionSettings) Scan(src interface{}) error {
	var data []byte
	switch v := src.(type) {
	case nil:
		*qs = QuestionSettings{}
		return nil
	case []byte:
		data = v
	case string:
		data = []byte(v)
	default:
		return fmt.Errorf("cannot scan %T into QuestionSettings", src)
	}

	var decoded QuestionSettings
	if err := json.Unmarshal(data, &decoded); err != nil {
		return err
	}
	*qs = decoded
	return nil
}

// Answer is a row of the "answers" table, belonging to the question with ID
// QuestionID.
//
// FindQuestion loads answers and sets Question, a back-pointer to the owning
// question; it is excluded from JSON because the question already embeds its
// answers. EmoteID and Match are nullable columns (nil when unset).
// NOTE(review): Responses is read from the "responses" column but never
// written in this file — presumably a per-answer counter; confirm.
type Answer struct {
	ID         int       `db:"id"           json:"id"`
	QuestionID int       `db:"question_id"  json:"question_id"`
	Question   *Question `db:"-"            json:"-"`
	Text       string    `db:"text"         json:"text"`
	EmoteID    *int      `db:"emote_id"     json:"emote_id"`
	Match      *string   `db:"match"        json:"match"`
	IsCorrect  bool      `db:"is_correct"   json:"is_correct"`
	Responses  int       `db:"responses"    json:"responses"`

	// Database handle set by FindQuestion.
	db *DB
}

// Convenience methods

// FindAnswer returns the loaded answer of q whose ID equals id, or nil when
// none matches. Only q.Answers is searched; no database query is made.
func (q Question) FindAnswer(id int) *Answer {
	for idx := range q.Answers {
		if candidate := q.Answers[idx]; candidate.ID == id {
			return candidate
		}
	}
	return nil
}

// HasCorrect reports whether any of q's loaded answers is marked correct.
// A question with no answers has no correct answer.
func (q Question) HasCorrect() bool {
	found := false
	for i := 0; i < len(q.Answers) && !found; i++ {
		found = q.Answers[i].IsCorrect
	}
	return found
}

// Find & List

// FindQuestion loads the question with the given ID along with its answers,
// ordered by answer ID. It returns (nil, nil) when no question has that ID.
//
// The returned question and each of its answers carry db for use by their
// methods (e.g. Delete), and every answer's Question field points back at the
// returned question.
func (db *DB) FindQuestion(id int) (*Question, Error) {
	var question Question
	err := db.Get(&question, `SELECT * FROM "questions" WHERE "id" = $1 LIMIT 1`, id)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		// Not found is not an error for callers; they check for a nil question.
		return nil, nil
	case err != nil:
		return nil, DBError(err)
	}
	question.db = db

	if err := db.Select(&question.Answers, `SELECT * FROM "answers" WHERE "question_id" = $1 ORDER BY "id" ASC`, id); err != nil {
		return nil, DBErrorf("error getting answers: %s", err)
	}
	question.AnswerIDs = make([]int, len(question.Answers))
	for i, answer := range question.Answers {
		answer.Question = &question
		answer.db = db
		question.AnswerIDs[i] = answer.ID
	}

	return &question, nil
}

// FindQuestionForAnswer returns the question owning the answer with ID
// answerID, loaded as by FindQuestion. It returns (nil, nil) when no answer
// has that ID.
func (db *DB) FindQuestionForAnswer(answerID int) (*Question, Error) {
	var qID int
	err := db.Get(&qID, `SELECT "question_id" FROM "answers" WHERE "id" = $1 LIMIT 1`, answerID)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil, nil
	case err != nil:
		return nil, DBError(err)
	}
	return db.FindQuestion(qID)
}

// ListQuestions returns the questions of channelID ordered by ID, each loaded
// with its answers via FindQuestion (so two queries per question).
//
// When activeFrom is non-nil, only questions whose active_from is at or after
// it are returned; when activeUntil is non-nil, only questions whose
// active_until is at or before it are returned. Both bounds are converted to
// UTC first. Questions with a NULL value in a filtered column are excluded by
// that filter, since SQL comparisons against NULL are never true.
//
// A question deleted between the ID query and its load is skipped rather
// than returned as a nil entry. On success the result is never nil.
func (db *DB) ListQuestions(channelID int, activeFrom, activeUntil *time.Time) ([]*Question, Error) {
	query := `SELECT "id" FROM "questions" WHERE "channel_id" = $1`
	queryArgs := []interface{}{channelID}
	// Only placeholder numbers are formatted into the SQL; values stay bound.
	if activeFrom != nil {
		queryArgs = append(queryArgs, activeFrom.UTC())
		query += fmt.Sprintf(` AND "active_from" >= $%d`, len(queryArgs))
	}
	if activeUntil != nil {
		queryArgs = append(queryArgs, activeUntil.UTC())
		query += fmt.Sprintf(` AND "active_until" <= $%d`, len(queryArgs))
	}
	query += ` ORDER BY "id" ASC`

	var ids []int
	if err := db.Select(&ids, query, queryArgs...); err != nil {
		return nil, DBErrorf("error enumerating question IDs for channel %d: %s", channelID, err)
	}

	questions := make([]*Question, 0, len(ids))
	for i, id := range ids {
		question, err := db.FindQuestion(id)
		if err != nil {
			return nil, err.Prefix("error fetching question %d (%d/%d) for channel %d", id, i+1, len(ids), channelID)
		}
		if question == nil {
			// Deleted concurrently after the ID query ran.
			continue
		}
		questions = append(questions, question)
	}

	return questions, nil
}

// Delete removes the question and all of its answers in one transaction.
//
// A question whose ActiveFrom is already in the past cannot be deleted; a
// user error describes it as "active" (no end time, or end time not yet
// reached) or "past". Unscheduled questions and ones scheduled in the future
// may be deleted. q must have been loaded via FindQuestion so that q.db is set.
func (q *Question) Delete() Error {
	// A single timestamp keeps the "started" and "still active" checks consistent.
	now := time.Now()
	if q.ActiveFrom != nil && q.ActiveFrom.Before(now) {
		when := "past"
		// A nil ActiveUntil means an open-ended window, which is still active.
		if q.ActiveUntil == nil || q.ActiveUntil.After(now) {
			when = "active"
		}
		return UserErrorf("cannot delete %s question", when)
	}

	tx, err := q.db.Beginx()
	if err != nil {
		return DBErrorf("error initializing transaction: %s", err)
	}
	// Rollback after a successful Commit only returns sql.ErrTxDone, so its
	// error is deliberately ignored.
	defer tx.Rollback()

	// Answers are deleted first (presumably answers.question_id references
	// questions.id — confirm against the schema).
	if _, err = tx.Exec(`DELETE FROM "answers" WHERE "question_id" = $1`, q.ID); err != nil {
		return DBErrorf("error deleting answers: %s", err)
	}
	if _, err = tx.Exec(`DELETE FROM "questions" WHERE "id" = $1`, q.ID); err != nil {
		return DBErrorf("error deleting question: %s", err)
	}
	if err = tx.Commit(); err != nil {
		return DBErrorf("error committing transaction: %s", err)
	}

	return nil
}

// Create

// NewQuestion is the input to CreateQuestion (the JSON tags suggest it is
// decoded from a request body — confirm against callers). CreateQuestion runs
// Validate before touching the database. ActiveFrom and ActiveUntil must be
// both nil or both set; when set, CreateQuestion converts them to UTC and
// rejects a window overlapping another question in the same channel.
type NewQuestion struct {
	ChannelID   int              `json:"channel_id"`
	Text        string           `json:"text"`
	ActiveFrom  *time.Time       `json:"active_from"`
	ActiveUntil *time.Time       `json:"active_until"`
	Settings    QuestionSettings `json:"settings"`
	Answers     []NewAnswer      `json:"answers"`
}

// NewAnswer describes one answer of a NewQuestion. Text is required (see
// NewAnswer.Validate). EmoteID and Match are optional; nil values are passed
// to the INSERT as-is (i.e. stored as NULL).
type NewAnswer struct {
	Text      string  `json:"text"`
	EmoteID   *int    `json:"emote_id"`
	Match     *string `json:"match"`
	IsCorrect bool    `json:"is_correct"`
}

// Validate reports the first problem that makes q unusable for CreateQuestion,
// or nil if q is valid. It checks for:
//   - a positive channel ID and non-empty text;
//   - an active window that is either fully unset or fully set, with
//     active_until not before active_from;
//   - at least one answer, each passing NewAnswer.Validate. Answer errors are
//     wrapped with their 1-based position.
func (q NewQuestion) Validate() error {
	switch {
	case q.ChannelID <= 0:
		return errors.New("invalid channel_id")
	case q.Text == "":
		return errors.New("missing text")
	case (q.ActiveFrom == nil) != (q.ActiveUntil == nil):
		return errors.New("active_from and active_until must be both null or both non-null")
	case q.ActiveFrom != nil && q.ActiveUntil.Before(*q.ActiveFrom):
		// Both are non-nil here (previous case). An inverted window is
		// meaningless and is not handled by the query in overlappingQuestionIDs.
		return errors.New("active_until must not be before active_from")
	case len(q.Answers) < 1:
		return errors.New("missing answers")
	}
	for i, answer := range q.Answers {
		if err := answer.Validate(); err != nil {
			return fmt.Errorf("answer %d: %w", i+1, err)
		}
	}
	return nil
}

// Validate reports whether a can be inserted as an answer. The only current
// requirement is a non-empty Text.
func (a NewAnswer) Validate() error {
	if len(a.Text) > 0 {
		return nil
	}
	return errors.New("missing text")
}

// overlappingQuestionIDs returns, in ascending order, the IDs of questions in
// channelID whose active window overlaps [from, until]. The query runs on tx.
//
// Four kinds of overlaps are possible:
//
//	1. <-new->     2. <---new--->  3.   <-new->    4.    <-new->
//	      <-old->       <-old->       <---old--->     <-old->
//
// They are covered by two tests:
//
//	1 & 2: newStart <= oldStart <= newEnd
//	3 & 4: oldStart <= newStart <= oldEnd
//
// Bounds are inclusive, so windows that merely touch count as overlapping.
// Questions whose active_from is NULL never match.
func overlappingQuestionIDs(tx *sqlx.Tx, channelID int, from, until time.Time) (ids []int, err error) {
	const query = `SELECT "id" FROM "questions" WHERE "channel_id" = $1 AND ` +
		`(("active_from" >= $2 AND "active_from" <= $3) OR ("active_from" <= $2 AND "active_until" >= $2)) ` +
		`ORDER BY "id" ASC`

	err = tx.Select(&ids, query, channelID, from, until)
	if err == sql.ErrNoRows {
		// Defensive: zero matching rows is not an error for this check.
		return ids, nil
	}
	return ids, err
}

// CreateQuestion validates newQ, makes sure its channel exists, and inserts
// the question and its answers in a single transaction. It returns the newly
// created question as loaded by FindQuestion.
//
// Failures are reported as:
//   - a UserError wrapping the Validate error when newQ is invalid;
//   - a user error with meta "conflicts" (the overlapping question IDs) when
//     newQ's active window overlaps another question in the same channel;
//   - DB errors for serialization, query, or transaction failures.
//
// NOTE(review): the overlap check and the INSERT share a transaction, but at
// the database's default isolation level two concurrent creates could both
// pass the check — confirm whether a constraint guards against this.
func (db *DB) CreateQuestion(newQ NewQuestion) (*Question, Error) {
	if err := newQ.Validate(); err != nil {
		return nil, UserError(err)
	}
	settingsBuf, err := json.Marshal(&newQ.Settings)
	if err != nil {
		return nil, DBErrorf("error serializing question settings to JSON: %s", err)
	}
	settingsJSON := string(settingsBuf)

	// Ensure channel exists.
	if _, err := db.FindOrCreateChannel(newQ.ChannelID); err != nil {
		// ChannelID is an int: %q would render it as a quoted rune, so use %d.
		return nil, err.Prefix("error looking up channel %d", newQ.ChannelID)
	}

	tx, err := db.Beginx()
	if err != nil {
		return nil, DBErrorf("error initiating transaction: %s", err)
	}
	// Rollback after a successful Commit is a no-op returning sql.ErrTxDone.
	defer tx.Rollback()

	// Validate guarantees both bounds are set or neither is.
	if newQ.ActiveFrom != nil && newQ.ActiveUntil != nil {
		// Ensure times are in UTC, because `lib/pq` blithely ignores `time.Location()` and assumes UTC.
		// newQ is a copy, so re-pointing its fields does not affect the caller.
		fromUTC, untilUTC := newQ.ActiveFrom.UTC(), newQ.ActiveUntil.UTC()
		newQ.ActiveFrom, newQ.ActiveUntil = &fromUTC, &untilUTC

		if ids, err := overlappingQuestionIDs(tx, newQ.ChannelID, *newQ.ActiveFrom, *newQ.ActiveUntil); err != nil {
			return nil, DBErrorf("error checking for overlapping questions: %s", err.Error())
		} else if len(ids) > 0 {
			return nil, UserErrorf("question overlaps existing questions").WithMeta("conflicts", ids)
		}
	}

	var id int
	if err := tx.Get(&id,
		`INSERT INTO "questions" ("channel_id", "text", "active_from", "active_until", "settings") `+
			`VALUES ($1, $2, $3, $4, $5::jsonb) RETURNING "id"`,
		newQ.ChannelID, newQ.Text, newQ.ActiveFrom, newQ.ActiveUntil, settingsJSON,
	); err != nil {
		return nil, DBErrorf("error inserting question: %s", err)
	}

	for i, answer := range newQ.Answers {
		if err := createAnswer(tx, id, answer); err != nil {
			return nil, DBErrorf("error inserting answer #%d: %s", i+1, err.Error())
		}
	}

	if err = tx.Commit(); err != nil {
		return nil, DBErrorf("error committing transaction: %s", err)
	}

	return db.FindQuestion(id)
}

// createAnswer inserts one answer for questionID using tx and returns the
// Exec error, if any. The new row's ID is not read back.
func createAnswer(tx *sqlx.Tx, questionID int, answer NewAnswer) error {
	const stmt = `INSERT INTO "answers" ("question_id", "text", "emote_id", "match", "is_correct") VALUES ($1, $2, $3, $4, $5)`
	args := []interface{}{questionID, answer.Text, answer.EmoteID, answer.Match, answer.IsCorrect}
	if _, err := tx.Exec(stmt, args...); err != nil {
		return err
	}
	return nil
}
