package api

import (
	"bytes"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"net/http/httptest"

	v1 "code.justin.tv/cb/roster/api/v1"
	"code.justin.tv/cb/roster/internal/api/mocks"
	"code.justin.tv/cb/roster/internal/clients/telemetryhook"
	"code.justin.tv/cb/roster/internal/db"
	"code.justin.tv/web/users-service/client/channels"
	"code.justin.tv/web/users-service/models"
	. "github.com/onsi/ginkgo"
	. "github.com/onsi/gomega"
	"github.com/stretchr/testify/mock"
)

// Specs for the POST /v1/teams/{teamID}/memberships endpoint served by Server.
// Every spec drives the handler through server.ServeHTTP with an
// httptest.ResponseRecorder, using mocked cache, DB and Users Service
// dependencies configured per-context.
var _ = Describe("PostV1TeamMemberships", func() {
	var (
		mockedCache *mocks.Cache
		dbReader    *mocks.DBReader
		dbWriter    *mocks.DBWriter
		users       *mocks.Users
		server      *Server
		recorder    *httptest.ResponseRecorder

		// Fixtures reset in the top-level BeforeEach; nested BeforeEach blocks
		// may override them before the request is sent.
		teamID, channelID              string
		revenueRevealed, statsRevealed bool
	)

	// membershipsPath returns the endpoint URL for the current teamID fixture.
	// It is evaluated when a spec runs, so it always sees the reset value.
	membershipsPath := func() string {
		return fmt.Sprintf("/v1/teams/%s/memberships", teamID)
	}

	// postMemberships JSON-encodes reqBody, POSTs it to membershipsPath and
	// records the handler's response in recorder.
	postMemberships := func(reqBody v1.PostTeamMembershipsRequestBody) {
		buffer := new(bytes.Buffer)
		err := json.NewEncoder(buffer).Encode(&reqBody)
		Expect(err).NotTo(HaveOccurred())

		req, err := http.NewRequest(http.MethodPost, membershipsPath(), buffer)
		Expect(err).NotTo(HaveOccurred())

		server.ServeHTTP(recorder, req)
	}

	BeforeEach(func() {
		recorder = httptest.NewRecorder()
		mockedCache = &mocks.Cache{}
		dbReader = &mocks.DBReader{}
		dbWriter = &mocks.DBWriter{}
		users = &mocks.Users{}

		server = NewServer(&ServerParams{
			Cache:            mockedCache,
			DBReader:         dbReader,
			DBWriter:         dbWriter,
			Users:            users,
			TelemetryHandler: &telemetryhook.NoopClient{},
		})

		teamID = "123"
		channelID = "999"
		revenueRevealed = true
		statsRevealed = true
	})

	It("fails with Bad Request when request body is invalid", func() {
		// An empty body is not valid JSON, so it cannot go through postMemberships.
		req, err := http.NewRequest(http.MethodPost, membershipsPath(), bytes.NewReader([]byte{}))
		Expect(err).NotTo(HaveOccurred())

		server.ServeHTTP(recorder, req)

		Expect(recorder.Result().StatusCode).To(Equal(http.StatusBadRequest))
		Expect(recorder.Body.String()).To(ContainSubstring("invalid request body"))
	})

	It("fails with Bad Request when channel ID is invalid", func() {
		postMemberships(v1.PostTeamMembershipsRequestBody{
			ChannelID: "",
		})

		Expect(recorder.Result().StatusCode).To(Equal(http.StatusBadRequest))
		Expect(recorder.Body.String()).To(ContainSubstring("channel ID cannot be empty"))
	})

	It("fails with Bad Request when 'revenue_revealed' is invalid", func() {
		// RevenueRevealed is left nil.
		postMemberships(v1.PostTeamMembershipsRequestBody{
			ChannelID: channelID,
		})

		Expect(recorder.Result().StatusCode).To(Equal(http.StatusBadRequest))
		Expect(recorder.Body.String()).To(ContainSubstring("revenue revealed must be true or false"))
	})

	It("fails with Bad Request when 'stats_revealed' is invalid", func() {
		// RevenueRevealed is valid; StatsRevealed is left nil.
		postMemberships(v1.PostTeamMembershipsRequestBody{
			ChannelID:       channelID,
			RevenueRevealed: &revenueRevealed,
		})

		Expect(recorder.Result().StatusCode).To(Equal(http.StatusBadRequest))
		Expect(recorder.Body.String()).To(ContainSubstring("stats revealed must be true or false"))
	})

	Context("when all request parameters are valid", func() {
		// JustBeforeEach runs after every nested BeforeEach, so each context's
		// mock expectations are registered before the request is served.
		JustBeforeEach(func() {
			postMemberships(v1.PostTeamMembershipsRequestBody{
				ChannelID:       channelID,
				RevenueRevealed: &revenueRevealed,
				StatsRevealed:   &statsRevealed,
			})
		})

		AfterEach(func() {
			dbReader.AssertExpectations(GinkgoT())
		})

		Context("when the team does not exist", func() {
			BeforeEach(func() {
				dbReader.On("GetTeamByID", mock.Anything, teamID).Return(db.Team{}, db.ErrNoTeam)
			})

			It("returns Not Found", func() {
				Expect(recorder.Result().StatusCode).To(Equal(http.StatusNotFound))
				Expect(recorder.Body.String()).To(ContainSubstring("team with ID 123 not found"))
			})
		})

		Context("when the DB errors while querying for the team", func() {
			BeforeEach(func() {
				dbReader.On("GetTeamByID", mock.Anything, teamID).Return(db.Team{}, errors.New("🔥"))
			})

			It("returns Internal Server Error", func() {
				Expect(recorder.Result().StatusCode).To(Equal(http.StatusInternalServerError))
				Expect(recorder.Body.String()).To(ContainSubstring("db: failed to query team"))
			})
		})

		Context("when the team is found", func() {
			BeforeEach(func() {
				dbReader.On("GetTeamByID", mock.Anything, teamID).Return(db.Team{
					ID: teamID,
				}, nil)
			})

			AfterEach(func() {
				users.AssertExpectations(GinkgoT())
			})

			Context("when the channel is not found in Users Service", func() {
				BeforeEach(func() {
					users.On("GetByIDAndParams", mock.Anything, channelID, mock.Anything, mock.Anything).
						Return(nil, &channels.ErrChannelNotFound{})
				})

				It("returns Unprocessable Entity", func() {
					Expect(recorder.Result().StatusCode).To(Equal(http.StatusUnprocessableEntity))
					Expect(recorder.Body.String()).To(ContainSubstring("Channel not found"))
				})
			})

			Context("when request to the Users Service fails", func() {
				BeforeEach(func() {
					users.On("GetByIDAndParams", mock.Anything, channelID, mock.Anything, mock.Anything).
						Return(nil, errors.New("❌"))
				})

				It("fails with Internal Server Error", func() {
					Expect(recorder.Result().StatusCode).To(Equal(http.StatusInternalServerError))
					Expect(recorder.Body.String()).To(ContainSubstring("users service: failed to look up channel"))
				})
			})

			Context("when the Users Service returns the channel", func() {
				BeforeEach(func() {
					// The lookup is expected to filter out deleted, TOS-violating
					// and DMCA-violating channels.
					params := &models.ChannelFilterParams{
						NotDeleted:      true,
						NoTOSViolation:  true,
						NoDMCAViolation: true,
					}

					users.On("GetByIDAndParams", mock.Anything, channelID, params, mock.Anything).Return(&channels.Channel{
						ID: 999,
					}, nil)
				})

				Context("when the channel is already a member in DB", func() {
					BeforeEach(func() {
						dbReader.On("GetMembership", mock.Anything, teamID, channelID).Return(db.Membership{}, nil)
					})

					It("returns Unprocessable Entity", func() {
						Expect(recorder.Result().StatusCode).To(Equal(http.StatusUnprocessableEntity))
						Expect(recorder.Body.String()).To(ContainSubstring("channel already has an existing team membership"))
					})
				})

				Context("when DB errors while querying for a membership", func() {
					BeforeEach(func() {
						dbReader.On("GetMembership", mock.Anything, teamID, channelID).
							Return(db.Membership{}, errors.New("🙈"))
					})

					It("returns Internal Server Error", func() {
						Expect(recorder.Result().StatusCode).To(Equal(http.StatusInternalServerError))
						Expect(recorder.Body.String()).To(ContainSubstring("db: failed to query membership"))
					})
				})

				Context("when DB returns no matching membership", func() {
					BeforeEach(func() {
						dbReader.On("GetMembership", mock.Anything, teamID, channelID).
							Return(db.Membership{}, db.ErrNoMembership)
					})

					AfterEach(func() {
						dbWriter.AssertExpectations(GinkgoT())
					})

					Context("when DB creates no membership", func() {
						BeforeEach(func() {
							dbWriter.On("CreateMembership", mock.Anything, mock.Anything).
								Return(db.ErrNoMembershipCreated)
						})

						It("returns Conflict", func() {
							Expect(recorder.Result().StatusCode).To(Equal(http.StatusConflict))
							Expect(recorder.Body.String()).To(ContainSubstring("failed to create membership"))
						})
					})

					Context("when DB fails to create membership", func() {
						BeforeEach(func() {
							dbWriter.On("CreateMembership", mock.Anything, mock.Anything).Return(errors.New("🐅"))
						})

						It("returns Internal Server Error", func() {
							Expect(recorder.Result().StatusCode).To(Equal(http.StatusInternalServerError))
							Expect(recorder.Body.String()).To(ContainSubstring("db: failed to insert a membership record"))
						})
					})

					Context("when DB successfully creates the membership", func() {
						var membership db.Membership

						BeforeEach(func() {
							// The exact record the handler is expected to persist.
							membership = db.Membership{
								ChannelID:       channelID,
								TeamID:          teamID,
								RevenueRevealed: revenueRevealed,
								StatsRevealed:   statsRevealed,
							}

							dbWriter.On("CreateMembership", mock.Anything, membership).Return(nil)
						})

						Context("when Users service fails to update the channel's primary team ID", func() {
							BeforeEach(func() {
								users.On("Set", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("🐢"))
							})

							It("returns Internal Server Error", func() {
								Expect(recorder.Result().StatusCode).To(Equal(http.StatusInternalServerError))
								Expect(recorder.Body.String()).To(ContainSubstring("users service: failed to update channel's primary team ID"))
							})
						})

						Context("when Users service succeeds in updating the channel's primary team ID", func() {
							BeforeEach(func() {
								// Numeric forms of the teamID ("123") and channelID ("999") fixtures.
								primaryTeamID := uint64(123)

								params := models.UpdateChannelProperties{
									ID:            999,
									PrimaryTeamID: &primaryTeamID,
								}

								users.On("Set", mock.Anything, params, mock.Anything).Return(nil)
								mockedCache.On("ClearChannelMemberships", mock.Anything, channelID).Return(nil)
								mockedCache.On("ClearAllTeamMembershipsForTeam", mock.Anything, teamID).Return(nil)
							})

							It("returns Created", func() {
								Expect(recorder.Result().StatusCode).To(Equal(http.StatusCreated))
							})
						})
					})
				})
			})
		})
	})
})
