// +build test

package model

import (
	"fmt"
	"math/rand"
	"strings"
	"testing"
	"time"

	"ting/util"
	. "ting/util/types"

	uuid "github.com/satori/go.uuid"
)

// TestQuestion is the entry point for the Question model tests. It opens a
// single test database and runs each group of subtests against it, in order,
// closing the database once every group has finished.
func TestQuestion(t *testing.T) {
	db := InitTestDB(t)
	defer db.Close()

	// Each suite factory binds the shared database and returns the subtest body.
	suites := []struct {
		name  string
		build func(*DB) func(*testing.T)
	}{
		{"FindQuestion", findQuestionTests},
		{"ListQuestions", listQuestionsTests},
		{"CreateQuestion", createQuestionTests},
		{"ResponseCounts", responseCountsTests},
		{"Update", questionUpdateTests},
		{"Delete", questionDeleteTests},
	}
	for _, s := range suites {
		t.Run(s.name, s.build(db))
	}
}

// findQuestionTests returns subtests for DB.FindQuestion and
// DB.FindQuestionForAnswer. Each lookup is checked against a question that
// exists and against the ID -1, which is expected to yield (nil, nil).
func findQuestionTests(db *DB) func(t *testing.T) {
	return func(t *testing.T) {
		t.Run("Exists", func(t *testing.T) {
			want := db.NewTestQuestion(t, Attr{"Settings": QuestionSettings{EnableChat: true}})
			got, err := db.FindQuestion(want.ID)
			switch {
			case err != nil:
				t.Fatalf("unexpected error when finding question %d: %s", want.ID, err.Error())
			case got == nil:
				t.Fatalf("received nil for existing question %d", want.ID)
			default:
				got.AssertEqual(t, want)
			}
		})

		t.Run("DoesNotExist", func(t *testing.T) {
			got, err := db.FindQuestion(-1)
			if err != nil {
				t.Fatalf("unexpected error when finding non-existent question: %s", err.Error())
			}
			if got != nil {
				t.Fatalf("found a question that shouldn't exist: %#v", *got)
			}
		})

		t.Run("ForAnswer", func(t *testing.T) {
			t.Run("Exists", func(t *testing.T) {
				want := db.NewTestQuestion(t, nil)
				total := len(want.AnswerIDs)
				// Every answer of the question must resolve back to that same question.
				for idx, answerID := range want.AnswerIDs {
					prefix := fmt.Sprintf("answerID %d (%d/%d)", answerID, idx+1, total)
					got, err := db.FindQuestionForAnswer(answerID)
					switch {
					case err != nil:
						t.Fatalf("(%s) unexpected error: %s", prefix, err)
					case got == nil:
						t.Fatalf("(%s) expected to find question %d", prefix, want.ID)
					case got.ID != want.ID:
						t.Fatalf("(%s) found wrong question: %d; expected %d", prefix, got.ID, want.ID)
					}
				}
			})

			t.Run("DoesNotExist", func(t *testing.T) {
				got, err := db.FindQuestionForAnswer(-1)
				if err != nil {
					t.Fatalf("unexpected error: %s", err)
				}
				if got != nil {
					t.Fatalf("found a question that shouldn't exist: %#v", *got)
				}
			})
		})
	}
}

// listQuestionsTests returns subtests for DB.ListQuestions. It checks an
// unknown channel (ID -1) and an empty channel, both of which must yield a
// non-nil empty slice. It then checks a populated channel with no bounds, a
// "from" bound, an "until" bound, and both bounds. The same questions are also
// created in a second channel, so every query also verifies that results are
// restricted to the requested channel.
func listQuestionsTests(db *DB) func(t *testing.T) {
	return func(t *testing.T) {
		t.Run("NoChannel", func(t *testing.T) {
			if qs, err := db.ListQuestions(-1, nil, nil); err != nil {
				t.Fatalf("unexpected error: %s", err.Error())
			} else if qs == nil {
				t.Fatal("nil result despite nil error")
			} else if len(qs) > 0 {
				t.Fatalf("unexpected results (expected none): %#v", qs)
			}
		})

		goodChannelID := db.NewTestChannel(t, nil).ID

		t.Run("NoQuestions", func(t *testing.T) {
			if qs, err := db.ListQuestions(goodChannelID, nil, nil); err != nil {
				t.Fatalf("unexpected error: %s", err.Error())
			} else if qs == nil {
				t.Fatal("nil result despite nil error")
			} else if len(qs) > 0 {
				t.Fatalf("unexpected results (expected none): %#v", qs)
			}
		})

		// For a good channel ID and a bad channel ID, create:
		// [0] = question with no active time
		// [1] = question active 24 hours ago
		// [2] = question active now
		// [3] = question active 24 hours from now
		badChannelID := db.NewTestChannel(t, nil).ID
		var goodIDs [4]int
		goodIDs[0] = db.NewTestQuestion(t, Attr{"ChannelID": goodChannelID}).ID
		db.NewTestQuestion(t, Attr{"ChannelID": badChannelID})
		for i := -1; i <= 1; i++ {
			from, until := questionTimes(i)
			qAttrs := Attr{
				"ActiveFrom":  from,
				"ActiveUntil": until,
			}
			goodIDs[i+2] = db.NewTestQuestion(t, qAttrs.With("ChannelID", goodChannelID)).ID
			db.NewTestQuestion(t, qAttrs.With("ChannelID", badChannelID))
		}
		now := time.Now()
		todayStart, todayEnd := startOfDay(now), endOfDay(now)

		// Channel IDs are integers, so mismatches are reported with %d; %q
		// would render them as quoted character literals.
		t.Run("All", func(t *testing.T) {
			qs, err := db.ListQuestions(goodChannelID, nil, nil)
			if err != nil {
				t.Fatalf("unexpected error: %s", err.Error())
			}
			qIDs := questionIDs(qs)
			expIDs := goodIDs[:]
			if !util.IntsSortEqual(qIDs, expIDs) {
				t.Fatalf("incorrect results: %v; expected %v", qIDs, expIDs)
			}
			for i, q := range qs {
				if q.ChannelID != goodChannelID {
					t.Errorf("question %d has wrong Channel ID: %d; expected %d", i+1, q.ChannelID, goodChannelID)
				}
			}
		})

		t.Run("FromBound", func(t *testing.T) {
			qs, err := db.ListQuestions(goodChannelID, &todayStart, nil)
			if err != nil {
				t.Fatalf("unexpected error: %s", err.Error())
			}
			qIDs := questionIDs(qs)
			expIDs := []int{goodIDs[2], goodIDs[3]}
			if !util.IntsSortEqual(qIDs, expIDs) {
				t.Fatalf("incorrect results: %v; expected %v", qIDs, expIDs)
			}
			for i, q := range qs {
				if q.ChannelID != goodChannelID {
					t.Errorf("question %d has wrong Channel ID: %d; expected %d", i+1, q.ChannelID, goodChannelID)
				}
			}
		})

		t.Run("UntilBound", func(t *testing.T) {
			qs, err := db.ListQuestions(goodChannelID, nil, &todayEnd)
			if err != nil {
				t.Fatalf("unexpected error: %s", err.Error())
			}
			qIDs := questionIDs(qs)
			expIDs := []int{goodIDs[1], goodIDs[2]}
			if !util.IntsSortEqual(qIDs, expIDs) {
				t.Fatalf("incorrect results: %v; expected %v", qIDs, expIDs)
			}
			for i, q := range qs {
				if q.ChannelID != goodChannelID {
					t.Errorf("question %d has wrong Channel ID: %d; expected %d", i+1, q.ChannelID, goodChannelID)
				}
			}
		})

		t.Run("BothBound", func(t *testing.T) {
			qs, err := db.ListQuestions(goodChannelID, &todayStart, &todayEnd)
			if err != nil {
				t.Fatalf("unexpected error: %s", err.Error())
			}
			qIDs := questionIDs(qs)
			expIDs := []int{goodIDs[2]}
			if !util.IntsSortEqual(qIDs, expIDs) {
				t.Fatalf("incorrect results: %v; expected %v", qIDs, expIDs)
			}
			for i, q := range qs {
				if q.ChannelID != goodChannelID {
					t.Errorf("question %d has wrong Channel ID: %d; expected %d", i+1, q.ChannelID, goodChannelID)
				}
			}
		})
	}
}

// createQuestionTests returns subtests for DB.CreateQuestion. Ordinary
// creation is already exercised by db.NewTestQuestion, so only the special
// case of rejecting a question whose active window overlaps an existing
// question in the same channel is covered here.
func createQuestionTests(db *DB) func(t *testing.T) {
	return func(t *testing.T) {
		t.Run("Overlap", func(t *testing.T) {
			channel := db.NewTestChannel(t, nil)
			now := time.Now()
			// Existing question active from two hours ago until two hours from
			// now; every candidate window below intersects it.
			existing := db.NewTestQuestion(t, Attr{
				"ChannelID":   channel.ID,
				"ActiveFrom":  now.Add(-2 * time.Hour),
				"ActiveUntil": now.Add(+2 * time.Hour),
			})
			wantErr := UserErrorf("question overlaps existing questions").
				WithMeta("conflicts", []int{existing.ID})
			wantJSON := util.ToJSON(t, wantErr)

			// Active times are left unset here and filled in per case.
			template := NewQuestion{
				ChannelID: channel.ID,
				Text:      "foo",
				Settings:  QuestionSettings{},
				Answers:   []NewAnswer{{Text: "foo-a1"}, {Text: "foo-a2"}, {Text: "foo-a3"}},
			}

			overlaps := []struct {
				name          string
				before, after time.Duration
			}{
				{"Left", -3 * time.Hour, 0},
				{"Right", 0, 3 * time.Hour},
				{"Inner", -1 * time.Hour, 1 * time.Hour},
				{"Outer", -3 * time.Hour, 3 * time.Hour},
			}
			for _, tc := range overlaps {
				tc := tc
				t.Run(tc.name, func(t *testing.T) {
					from, until := now.Add(tc.before), now.Add(tc.after)
					candidate := template
					candidate.ActiveFrom, candidate.ActiveUntil = &from, &until

					q, err := db.CreateQuestion(candidate)
					if err == nil {
						t.Fatalf("unexpected success: %#v", q)
					}
					if gotJSON := util.ToJSON(t, err); gotJSON != wantJSON {
						t.Fatalf("wrong error: %s\nexpected %s", gotJSON, wantJSON)
					}
				})
			}
		})
	}
}

// responseCountsTests returns a subtest that records a series of answers from
// distinct users on one active question. After each answer it re-fetches the
// question and checks that:
//   - the question still has the same number of answers,
//   - its total response count equals the number of answers recorded so far,
//   - each answer's response count matches the expected tally.
func responseCountsTests(db *DB) func(t *testing.T) {
	return func(t *testing.T) {
		// Number of user answers to record; also used in failure messages so
		// the reported progress ("i/N") stays accurate.
		const numResponses = 20

		question := db.NewTestQuestion(t, Attr{"Active": true, "Answers": 5})
		expCounts := make(map[int]int, len(question.AnswerIDs))
		for _, answerID := range question.AnswerIDs {
			expCounts[answerID] = 0
		}
		// No need to stress-test the UUID generator: derive per-response user
		// IDs from a single random base.
		baseUserID := uuid.NewV4().String()
		userID := func(i int) string { return fmt.Sprintf("%s:%d", baseUserID, i) }

		for i := 1; i <= numResponses; i++ {
			answerID := question.AnswerIDs[rand.Intn(len(question.AnswerIDs))]
			expCounts[answerID]++
			if _, err := question.AddUserAnswer(userID(i), answerID, true); err != nil {
				t.Fatalf("unexpected error adding answer %d/%d to question %d: %s", i, numResponses, question.ID, err)
			}

			q, err := db.FindQuestion(question.ID)
			switch {
			case err != nil:
				t.Fatalf("unexpected error re-fetching question %d after answer %d/%d: %s", question.ID, i, numResponses, err)
			case q == nil:
				t.Fatalf("question %d stopped existing after answer %d/%d", question.ID, i, numResponses)
			case len(q.Answers) != len(expCounts):
				t.Fatalf("wrong answer count for question %d: %d; expected %d\nfound IDs: %v, expected IDs: %v",
					question.ID, len(q.Answers), len(expCounts), q.AnswerIDs, question.AnswerIDs)
			case q.Responses != i:
				t.Fatalf("wrong response count for question %d: %d; expected %d", question.ID, q.Responses, i)
			}

			for _, answer := range q.Answers {
				if expCount, found := expCounts[answer.ID]; !found {
					t.Fatalf("found unexpected answer for question %d: %d", question.ID, answer.ID)
				} else if answer.Responses != expCount {
					t.Fatalf("wrong response count for answer %d: %d; expected: %d", answer.ID, answer.Responses, expCount)
				}
			}
		}
	}
}

// questionDeleteTests returns subtests for Question.Delete. Questions created
// without active times, or scheduled one day in the future, must be deletable.
// Questions active now, or active one day in the past, must be rejected with a
// "cannot delete <state> question" error and remain unchanged in the database.
func questionDeleteTests(db *DB) func(t *testing.T) {
	return func(t *testing.T) {
		for _, testCase := range []struct {
			name     string
			inactive bool // create the question without active times
			dayDelta int  // day offset passed to questionTimes when !inactive
			ok       bool // whether Delete is expected to succeed
		}{
			{name: "Inactive", inactive: true, ok: true},
			{name: "Future", dayDelta: 1, ok: true},
			{name: "Active", dayDelta: 0, ok: false},
			{name: "Past", dayDelta: -1, ok: false},
		} {
			testCase := testCase
			t.Run(testCase.name, func(t *testing.T) {
				state := strings.ToLower(testCase.name)

				var qOrig *Question
				if testCase.inactive {
					qOrig = db.NewTestQuestion(t, nil)
				} else {
					from, until := questionTimes(testCase.dayDelta)
					qOrig = db.NewTestQuestion(t, Attr{"ActiveFrom": &from, "ActiveUntil": &until})
				}

				err := qOrig.Delete()
				if testCase.ok {
					if err != nil {
						t.Fatalf("unexpected error: %s", err)
					}
				} else {
					expErr := fmt.Sprintf("cannot delete %s question", state)
					// Report an unexpected success explicitly; formatting a nil
					// error with %q would print "%!q(<nil>)".
					if err == nil {
						t.Fatalf("unexpected success; expected error %q", expErr)
					} else if msg := err.Error(); msg != expErr {
						t.Fatalf("wrong error: %q; expected %q", msg, expErr)
					}
				}

				qFind, err := db.FindQuestion(qOrig.ID)
				if err != nil {
					t.Fatalf("unexpected error trying to re-fetch deleted question: %s", err)
				}
				if testCase.ok {
					if qFind != nil {
						t.Fatal("question was not deleted")
					}
					return
				}
				if qFind == nil {
					t.Fatalf("%s question was deleted", state)
				}
				// A rejected delete must leave the question untouched.
				qFind.AssertEqual(t, qOrig)
			})
		}
	}
}

// questionUpdateTests returns subtests for Question.Update, covering:
//   - no-op updates, guarded by a trigger that rejects any UPDATE of the row,
//   - unrecognized fields on the question and on its answers,
//   - attempts to change the immutable id,
//   - text changes,
//   - active-time changes, including a conflict with an overlapping question,
//   - replacement of the answer list.
func questionUpdateTests(db *DB) func(t *testing.T) {
	return func(t *testing.T) {
		t.Run("NoChange", func(t *testing.T) {
			// Create a question with a trigger that raises an exception on any
			// attempt to UPDATE its row in "questions", so an update that should
			// be a no-op errors out if it writes to that row.
			qPre := db.NewTestQuestion(t, nil)
			triggerName := fmt.Sprintf("lock_question_%d", qPre.ID)
			if _, err := db.Exec(fmt.Sprintf(
				`CREATE FUNCTION %s_fun() RETURNS trigger LANGUAGE plpgsql AS $BODY$ `+
					`BEGIN RAISE EXCEPTION 'Question %d is locked'; END; $BODY$;`,
				triggerName, qPre.ID,
			)); err != nil {
				t.Fatalf("error creating trigger function %q: %s", triggerName+"_fun()", err)
			}
			// Register cleanup as soon as each object exists, so that a failure
			// while creating the trigger does not leak the function.
			defer func() {
				if _, err := db.Exec(fmt.Sprintf(`DROP FUNCTION %s_fun;`, triggerName)); err != nil {
					t.Logf("error deleting trigger function %q: %s", triggerName+"_fun", err)
				}
			}()
			if _, err := db.Exec(fmt.Sprintf(
				`CREATE TRIGGER %s BEFORE UPDATE ON "questions" FOR EACH ROW `+
					`WHEN (OLD.id = %d) EXECUTE PROCEDURE %s_fun ();`,
				triggerName, qPre.ID, triggerName,
			)); err != nil {
				t.Fatalf("error creating trigger %q: %s", triggerName, err)
			}
			// Deferred calls run last-in-first-out, so the trigger is dropped
			// before the function it depends on.
			defer func() {
				if _, err := db.Exec(fmt.Sprintf(`DROP TRIGGER %s ON "questions";`, triggerName)); err != nil {
					t.Logf("error deleting trigger %q: %s", triggerName, err)
				}
			}()

			t.Run("Empty", func(t *testing.T) {
				if qPost, err := qPre.Update(StringMap{}); err != nil {
					t.Fatalf("unexpected error: %s", err)
				} else if qPost == nil {
					t.Fatal("expected non-nil result")
				} else {
					qPost.AssertEqual(t, qPre)
				}
			})

			t.Run("IgnoredOnly", func(t *testing.T) {
				updates := StringMap{
					"responses":   30,
					"answer_ids":  []int{13, 17, 19},
					"answered":    true,
					"user_answer": 17,
				}
				if qPost, err := qPre.Update(updates); err != nil {
					t.Fatalf("unexpected error: %s", err)
				} else if qPost == nil {
					t.Fatal("expected non-nil result")
				} else {
					qPost.AssertEqual(t, qPre)
				}
			})

			t.Run("Identical", func(t *testing.T) {
				updates := util.MapFromJSON(t, util.ToJSON(t, qPre))
				if qPost, err := qPre.Update(updates); err != nil {
					t.Fatalf("unexpected error: %s", err)
				} else if qPost == nil {
					t.Fatal("expected non-nil result")
				} else {
					// Compare against the original; comparing qPost with itself
					// would always pass.
					qPost.AssertEqual(t, qPre)
				}
			})
		})

		t.Run("BadFields", func(t *testing.T) {
			t.Run("Question", func(t *testing.T) {
				qPre := db.NewTestQuestion(t, nil)
				// Field order in the message is not fixed, so accept either order.
				exp1 := `unrecognized fields: "foo", "bar"`
				exp2 := `unrecognized fields: "bar", "foo"`
				if _, err := qPre.Update(StringMap{"foo": 10, "bar": 20}); err == nil {
					t.Fatal("unexpected success")
				} else if msg := err.Error(); msg != exp1 && msg != exp2 {
					t.Fatalf("wrong error message: %q", msg)
				}
			})

			t.Run("Answer", func(t *testing.T) {
				qPre := db.NewTestQuestion(t, nil)
				updates := util.MapFromJSON(t, fmt.Sprintf(`{"answers": [{"id": %d, "foo": 10, "bar": 20}]}`, qPre.AnswerIDs[1]))
				exp1 := `answer 1: unrecognized fields: "foo", "bar"`
				exp2 := `answer 1: unrecognized fields: "bar", "foo"`
				if _, err := qPre.Update(updates); err == nil {
					t.Fatal("unexpected success")
				} else if msg := err.Error(); msg != exp1 && msg != exp2 {
					t.Fatalf("wrong error message: %q", msg)
				}
			})
		})

		t.Run("ChangeStatic", func(t *testing.T) {
			qPre := db.NewTestQuestion(t, nil)
			// ID is in update map as float64 because that's what `json.Unmarshal` decodes numbers in maps to.
			if _, err := qPre.Update(StringMap{"id": float64(qPre.ID + 1)}); err == nil {
				t.Fatal("unexpected success")
			} else if msg := err.Error(); msg != "id: cannot be modified" {
				t.Fatalf("wrong error message: %q", msg)
			}
		})

		t.Run("Text", func(t *testing.T) {
			qPre := db.NewTestQuestion(t, nil)
			if qPost, err := qPre.Update(StringMap{"text": "new question text"}); err != nil {
				t.Fatalf("unexpected error: %s", err)
			} else if qPost == nil {
				t.Fatal("expected non-nil result")
			} else {
				qPre.Text = "new question text"
				qPost.AssertEqual(t, qPre)
			}
		})

		t.Run("ActiveTimes", func(t *testing.T) {
			t.Run("XorNil", func(t *testing.T) {
				qPre := db.NewTestQuestion(t, nil)
				if qPost, err := qPre.Update(StringMap{"active_from": util.Timestamp(1)}); err == nil {
					t.Fatalf("unexpected success: %#v", qPost)
				} else if msg := err.Error(); msg != "active_from and active_until must either be both nil or both non-nil" {
					t.Fatalf("wrong error message: %q", msg)
				}
			})

			t.Run("ToNil", func(t *testing.T) {
				from, until := util.Timestamp(+1), util.Timestamp(+2)
				qPre := db.NewTestQuestion(t, Attr{"ActiveFrom": from, "ActiveUntil": until})
				if qPost, err := qPre.Update(StringMap{"active_from": nil, "active_until": nil}); err != nil {
					t.Fatalf("unexpected error: %s", err)
				} else if qPost == nil {
					t.Fatal("expected non-nil result")
				} else {
					qPre.ActiveFrom, qPre.ActiveUntil = nil, nil
					qPost.AssertEqual(t, qPre)
				}
			})

			t.Run("FromNil", func(t *testing.T) {
				qPre := db.NewTestQuestion(t, nil)
				from, until := util.Timestamp(+1), util.Timestamp(+2)
				var tsErr error
				if qPost, err := qPre.Update(StringMap{"active_from": from, "active_until": until}); err != nil {
					t.Fatalf("unexpected error: %s", err)
				} else if qPost == nil {
					t.Fatal("expected non-nil result")
				} else if qPre.ActiveFrom, tsErr = util.ParseTimep(from); tsErr != nil {
					t.Fatalf("error re-parsing active_from (%q): %s", from, tsErr)
				} else if qPre.ActiveUntil, tsErr = util.ParseTimep(until); tsErr != nil {
					t.Fatalf("error re-parsing active_until (%q): %s", until, tsErr)
				} else {
					qPost.AssertEqual(t, qPre)
				}
			})

			t.Run("Conflict", func(t *testing.T) {
				now := time.Now().UTC()
				from, until := now.Add(1*time.Hour), now.Add(2*time.Hour)
				qConflict := db.NewTestQuestion(t, Attr{
					"ActiveFrom":  from.Format(time.RFC3339),
					"ActiveUntil": until.Format(time.RFC3339),
				})

				// New window lies strictly inside the conflicting question's window.
				qPre := db.NewTestQuestion(t, Attr{"ChannelID": qConflict.ChannelID})
				newFrom, newUntil := from.Add(10*time.Minute), until.Add(-10*time.Minute)
				qPost, err := qPre.Update(StringMap{
					"active_from":  newFrom.Format(time.RFC3339),
					"active_until": newUntil.Format(time.RFC3339),
				})
				if err == nil {
					t.Fatalf("unexpected success: %#v", qPost)
				}
				if msg := err.Error(); msg != "question overlaps existing questions" {
					t.Fatalf("wrong error: %q", msg)
				}
				// Checked type assertion: a missing or differently-typed entry is
				// reported as a test failure instead of panicking.
				rawConflicts := err.Meta()["conflicts"]
				if conflicts, ok := rawConflicts.([]int); !ok || len(conflicts) != 1 || conflicts[0] != qConflict.ID {
					t.Fatalf("wrong conflicting Question IDs: %v; expected [%d]", rawConflicts, qConflict.ID)
				}
			})
		})

		t.Run("Answers", func(t *testing.T) {
			// Keep the middle answer (with new text), drop the other two, add one.
			qPre := db.NewTestQuestion(t, Attr{"Answers": 3})
			updates := StringMap{"answers": []interface{}{
				map[string]interface{}{"id": qPre.AnswerIDs[1], "text": "existing question"},
				map[string]interface{}{"text": "new question"},
			}}
			if qPost, err := qPre.Update(updates); err != nil {
				t.Fatalf("unexpected error: %s", err)
			} else if qPost == nil {
				t.Fatal("expected non-nil result")
			} else if len(qPost.AnswerIDs) != 2 {
				t.Fatalf("wrong number of answers in result: %d; answers: %#v", len(qPost.AnswerIDs), qPost.Answers)
			} else if qPost.AnswerIDs[0] != qPre.AnswerIDs[1] {
				t.Fatalf("answer %d was not kept: %v", qPre.AnswerIDs[1], qPost.AnswerIDs)
			} else if qPost.AnswerIDs[1] == qPre.AnswerIDs[0] || qPost.AnswerIDs[1] == qPre.AnswerIDs[2] {
				t.Fatalf("answer %d was kept", qPost.AnswerIDs[1])
			} else if text := qPost.Answers[0].Text; text != "existing question" {
				t.Fatalf("kept answer was not modified: %q", text)
			} else if text = qPost.Answers[1].Text; text != "new question" {
				t.Fatalf("new answer has wrong text: %q", text)
			}
		})
	}
}
