package api

import (
	"fmt"
	"io/ioutil"
	"math/big"
	"net/http"
	"net/http/httptest"

	"code.justin.tv/cb/roster/internal/api/mocks"
	"code.justin.tv/cb/roster/internal/authorization"
	"code.justin.tv/cb/roster/internal/clients/telemetryhook"
	. "github.com/onsi/ginkgo"
	. "github.com/onsi/gomega"
)

// Specs for request validation: each spec sends a request whose path carries a
// malformed identifier through server.ServeHTTP and expects a 400 Bad Request
// response whose body mentions the offending parameter.
var _ = Describe("requestValidationMiddleware", func() {
	var (
		server   *Server
		recorder *httptest.ResponseRecorder
	)

	// expectBadRequest serves req against the per-spec server and asserts that
	// the recorded response is a 400 Bad Request whose body contains message.
	// The assertions use an offset of 1 so that a failure is reported at the
	// calling spec rather than inside this helper.
	expectBadRequest := func(req *http.Request, message string) {
		server.ServeHTTP(recorder, req)

		// Capture the response once and always release its body, even when
		// one of the assertions below fails.
		resp := recorder.Result()
		defer resp.Body.Close()

		ExpectWithOffset(1, resp.StatusCode).To(Equal(http.StatusBadRequest))

		body, err := ioutil.ReadAll(resp.Body)
		ExpectWithOffset(1, err).NotTo(HaveOccurred())
		ExpectWithOffset(1, string(body)).To(ContainSubstring(message))
	}

	// A fresh recorder and server are built for every spec so that no state
	// leaks from one spec into the next.
	BeforeEach(func() {
		recorder = httptest.NewRecorder()
		server = NewServer(&ServerParams{
			AuthDecoder:      &authorization.Decoder{},
			DBWriter:         &mocks.DBWriter{},
			Users:            &mocks.Users{},
			TelemetryHandler: &telemetryhook.NoopClient{},
		})
	})

	Describe("validateNumericTeamID", func() {
		It("fails with Bad Request when the team id is invalid", func() {
			path := fmt.Sprintf("/v1/teams/%s/channels/%s/membership", "invalid-team-id", "456")
			req, err := http.NewRequest(http.MethodPatch, path, nil)
			Expect(err).NotTo(HaveOccurred())

			expectBadRequest(req, "invalid team id (must be numeric)")
		})
	})

	Describe("validateNumericChannelID", func() {
		It("fails with Bad Request when the channel id is invalid", func() {
			path := fmt.Sprintf("/v1/teams/%s/channels/%s/membership", "123", "invalid-channel-id")
			req, err := http.NewRequest(http.MethodPatch, path, nil)
			Expect(err).NotTo(HaveOccurred())

			expectBadRequest(req, "invalid channel id (must be numeric)")
		})
	})

	Describe("validateNumericUserID", func() {
		It("fails with Bad Request when the user id is invalid", func() {
			// 9223372036854775807 is the largest int64; adding one yields a
			// string of digits that no longer fits in an int64. The sum is
			// written into a fresh big.Int so the operand is not mutated.
			maxInt64 := big.NewInt(9223372036854775807)
			overflow := new(big.Int).Add(maxInt64, big.NewInt(1))

			path := fmt.Sprintf("/v1/users/%s/teams", overflow.String())
			req, err := http.NewRequest(http.MethodPost, path, nil)
			Expect(err).NotTo(HaveOccurred())

			expectBadRequest(req, "invalid user id (must be numeric)")
		})
	})
})
