package api

import (
	"bytes"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"net/http/httptest"

	v1 "code.justin.tv/cb/roster/api/v1"
	"code.justin.tv/cb/roster/internal/api/mocks"
	"code.justin.tv/cb/roster/internal/clients/telemetryhook"
	"code.justin.tv/cb/roster/internal/db"
	"code.justin.tv/web/users-service/models"
	. "github.com/onsi/ginkgo"
	. "github.com/onsi/gomega"
	"github.com/stretchr/testify/mock"
)

// Specs for the PUT /v1/channels/{channelID}/teams/{teamID}/membership route.
// Every spec drives the handler through Server.ServeHTTP with mocked
// dependencies and inspects the recorded HTTP response.
var _ = Describe("PutV1ChannelMembership", func() {
	// validBody is a well-formed payload that carries both required fields. It
	// is used by specs that want to fail on something other than the body.
	const validBody = `{
		"revenue_revealed": true,
		"stats_revealed": false
	}`

	var (
		cache    *mocks.Cache
		dbWriter *mocks.DBWriter
		pushy    *mocks.Pushy
		users    *mocks.Users
		server   *Server
		recorder *httptest.ResponseRecorder

		channelID string
		teamID    string
	)

	// membershipPath formats the route under test for the given identifiers.
	membershipPath := func(channel, team string) string {
		return fmt.Sprintf("/v1/channels/%s/teams/%s/membership", channel, team)
	}

	// servePut sends a PUT request carrying body through the server and
	// captures the response in the shared recorder. The offset makes a failure
	// point at the calling spec instead of this helper.
	servePut := func(path string, body *bytes.Buffer) {
		req, err := http.NewRequest(http.MethodPut, path, body)
		ExpectWithOffset(1, err).NotTo(HaveOccurred())

		server.ServeHTTP(recorder, req)
	}

	// encodeBody serializes a typed request payload to JSON.
	encodeBody := func(payload *v1.PutChannelMembershipRequestBody) *bytes.Buffer {
		buffer := new(bytes.Buffer)
		err := json.NewEncoder(buffer).Encode(payload)
		ExpectWithOffset(1, err).NotTo(HaveOccurred())

		return buffer
	}

	// Fresh mocks, recorder and server for every spec so that no expectation
	// or response leaks between them.
	BeforeEach(func() {
		recorder = httptest.NewRecorder()
		cache = &mocks.Cache{}
		dbWriter = &mocks.DBWriter{}
		pushy = &mocks.Pushy{}
		users = &mocks.Users{}

		server = NewServer(&ServerParams{
			Cache:            cache,
			DBWriter:         dbWriter,
			Pushy:            pushy,
			Users:            users,
			TelemetryHandler: &telemetryhook.NoopClient{},
		})

		teamID = "123"
		channelID = "456"
	})

	Context("when the request body is malformed", func() {
		It("fails with Bad Request", func() {
			servePut(membershipPath(channelID, teamID), bytes.NewBufferString("not valid json"))

			Expect(recorder.Result().StatusCode).To(Equal(http.StatusBadRequest))
			Expect(recorder.Body.String()).To(ContainSubstring("invalid request body"))
		})
	})

	It("fails with Bad Request when channel ID is invalid", func() {
		servePut(membershipPath("=D", teamID), bytes.NewBufferString(validBody))

		Expect(recorder.Result().StatusCode).To(Equal(http.StatusBadRequest))
		Expect(recorder.Body.String()).To(ContainSubstring("invalid channel id (must be numeric)"))
	})

	It("fails with Bad Request when team ID is invalid", func() {
		servePut(membershipPath(channelID, "=D"), bytes.NewBufferString(validBody))

		Expect(recorder.Result().StatusCode).To(Equal(http.StatusBadRequest))
		Expect(recorder.Body.String()).To(ContainSubstring("invalid team id (must be numeric)"))
	})

	It("fails with Bad Request when 'revenue_revealed' is missing", func() {
		body := bytes.NewBufferString(`{
			"stats_revealed": false
		}`)

		servePut(membershipPath(channelID, teamID), body)

		Expect(recorder.Result().StatusCode).To(Equal(http.StatusBadRequest))
		Expect(recorder.Body.String()).To(ContainSubstring("'revenue_revealed' is required"))
	})

	It("fails with Bad Request when 'stats_revealed' is missing", func() {
		body := bytes.NewBufferString(`{
			"revenue_revealed": true
		}`)

		servePut(membershipPath(channelID, teamID), body)

		Expect(recorder.Result().StatusCode).To(Equal(http.StatusBadRequest))
		Expect(recorder.Body.String()).To(ContainSubstring("'stats_revealed' is required"))
	})

	Context("when the request parameters are valid, without changing primary team status", func() {
		revenueRevealed := true
		statsRevealed := false

		// The request is sent in JustBeforeEach so that the nested BeforeEach
		// blocks can register their mock expectations first.
		JustBeforeEach(func() {
			payload := &v1.PutChannelMembershipRequestBody{
				RevenueRevealed: &revenueRevealed,
				StatsRevealed:   &statsRevealed,
			}

			servePut(membershipPath(channelID, teamID), encodeBody(payload))
		})

		Context("when db does not find membership record", func() {
			BeforeEach(func() {
				dbWriter.On("UpdateMembership", mock.Anything, teamID, channelID, revenueRevealed, statsRevealed).
					Return(db.ErrNoMembershipForUpdate)
			})

			It("returns Not Found", func() {
				dbWriter.AssertExpectations(GinkgoT())

				Expect(recorder.Result().StatusCode).To(Equal(http.StatusNotFound))
				Expect(recorder.Body.String()).To(ContainSubstring("membership not found"))
			})
		})

		Context("when db fails to update the membership record", func() {
			BeforeEach(func() {
				dbWriter.On("UpdateMembership", mock.Anything, teamID, channelID, revenueRevealed, statsRevealed).
					Return(errors.New("some db error"))
			})

			It("returns Internal Server Error", func() {
				dbWriter.AssertExpectations(GinkgoT())

				Expect(recorder.Result().StatusCode).To(Equal(http.StatusInternalServerError))
				Expect(recorder.Body.String()).To(ContainSubstring("db: failed to update membership"))
			})
		})

		Context("when db updates the membership record", func() {
			BeforeEach(func() {
				dbWriter.On("UpdateMembership", mock.Anything, teamID, channelID, revenueRevealed, statsRevealed).
					Return(nil)

				// NOTE(review): these cache expectations are registered but no
				// spec asserts them. Consider cache.AssertExpectations once it
				// is confirmed that the handler clears the cache synchronously.
				cache.On("ClearChannelMemberships", mock.Anything, channelID).Return(nil)
				cache.On("ClearAllTeamMembershipsForTeam", mock.Anything, teamID).Return(nil)
			})

			It("returns No Content", func() {
				dbWriter.AssertExpectations(GinkgoT())
				pushy.AssertExpectations(GinkgoT())

				Expect(recorder.Result().StatusCode).To(Equal(http.StatusNoContent))
			})
		})
	})

	Context("when the request parameters are valid, with changing primary team status", func() {
		// primary is assigned by the nested BeforeEach blocks. The payload
		// holds a pointer to it and is encoded in JustBeforeEach, so each spec
		// sends the value chosen by its own context.
		var primary bool

		revenueRevealed := true
		statsRevealed := false

		BeforeEach(func() {
			dbWriter.On("UpdateMembership", mock.Anything, teamID, channelID, revenueRevealed, statsRevealed).
				Return(nil)
		})

		JustBeforeEach(func() {
			payload := &v1.PutChannelMembershipRequestBody{
				Primary:         &primary,
				RevenueRevealed: &revenueRevealed,
				StatsRevealed:   &statsRevealed,
			}

			servePut(membershipPath(channelID, teamID), encodeBody(payload))
		})

		Context("when setting the team membership to be primary", func() {
			BeforeEach(func() {
				primary = true
			})

			Context("when users service fails to update", func() {
				BeforeEach(func() {
					users.On("Set", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("users service error"))
				})

				It("returns Internal Server Error", func() {
					dbWriter.AssertExpectations(GinkgoT())
					users.AssertExpectations(GinkgoT())
					pushy.AssertExpectations(GinkgoT())

					Expect(recorder.Result().StatusCode).To(Equal(http.StatusInternalServerError))
					Expect(recorder.Body.String()).To(ContainSubstring("users service: failed to update channel's primary team id"))
				})
			})

			Context("when users service updates successfully", func() {
				// Numeric form of teamID ("123"); the expected channel ID below
				// is likewise the numeric form of channelID ("456").
				primaryTeamID := uint64(123)

				BeforeEach(func() {
					users.On("Set", mock.Anything, models.UpdateChannelProperties{
						ID:            uint64(456),
						PrimaryTeamID: &primaryTeamID,
					}, mock.Anything).Return(nil)

					cache.On("ClearChannelMemberships", mock.Anything, channelID).Return(nil)
					cache.On("ClearAllTeamMembershipsForTeam", mock.Anything, teamID).Return(nil)
				})

				It("returns No Content", func() {
					dbWriter.AssertExpectations(GinkgoT())
					users.AssertExpectations(GinkgoT())
					pushy.AssertExpectations(GinkgoT())

					Expect(recorder.Result().StatusCode).To(Equal(http.StatusNoContent))
				})
			})
		})

		Context("when setting the team membership to not be primary", func() {
			BeforeEach(func() {
				primary = false

				cache.On("ClearChannelMemberships", mock.Anything, channelID).Return(nil)
				cache.On("ClearAllTeamMembershipsForTeam", mock.Anything, teamID).Return(nil)
			})

			It("returns No Content", func() {
				dbWriter.AssertExpectations(GinkgoT())

				Expect(recorder.Result().StatusCode).To(Equal(http.StatusNoContent))
			})
		})
	})
})
