package api

import (
	"context"
	"fmt"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"goji.io/pattern"

	"code.justin.tv/cb/dashy/internal/auth"
	"code.justin.tv/cb/dashy/internal/clients"
	"code.justin.tv/cb/dashy/internal/clients/zephyr"
	"code.justin.tv/common/goauthorization"
	"code.justin.tv/common/jwt/claim"
	user "code.justin.tv/web/users-service/models"
	"github.com/stretchr/testify/suite"
)

// validateTimeRangeSuite groups the tests for the validateTimeRange middleware.
type validateTimeRangeSuite struct{ suite.Suite }

// TestValidateTimeRangeSuite is the `go test` entry point that runs every
// test method defined on validateTimeRangeSuite.
func TestValidateTimeRangeSuite(t *testing.T) {
	s := new(validateTimeRangeSuite)
	suite.Run(t, s)
}

// TestBadRequest_NoStartTime checks that a request without a start_time query
// parameter is rejected with 400 and never reaches the wrapped handler.
func (s *validateTimeRangeSuite) TestBadRequest_NoStartTime() {
	path := "/no_start_time"
	recorder := httptest.NewRecorder()

	req, err := http.NewRequest(http.MethodGet, path, nil)
	s.Require().NoError(err)

	// Report reaching the wrapped handler as an ordinary test failure rather
	// than panicking: a panic surfaces as a crash and, depending on the testify
	// version, can abort the remaining tests.
	innerHandler := func(w http.ResponseWriter, req *http.Request) {
		s.Fail("inner handler must not be called when start_time is missing")
	}

	validateTimeRange(innerHandler).ServeHTTP(recorder, req)

	s.Equal(http.StatusBadRequest, recorder.Code)
}

// TestBadRequest_NoEndTime checks that a request carrying start_time but no
// end_time query parameter is rejected with 400 and never reaches the wrapped
// handler.
func (s *validateTimeRangeSuite) TestBadRequest_NoEndTime() {
	path := "/no_end_time?start_time=2017-07-01T07:00:00Z"
	recorder := httptest.NewRecorder()

	req, err := http.NewRequest(http.MethodGet, path, nil)
	s.Require().NoError(err)

	// Fail explicitly instead of panicking so the failure is reported against
	// this test rather than as a crash.
	innerHandler := func(w http.ResponseWriter, req *http.Request) {
		s.Fail("inner handler must not be called when end_time is missing")
	}

	validateTimeRange(innerHandler).ServeHTTP(recorder, req)

	s.Equal(http.StatusBadRequest, recorder.Code)
}

// TestBadRequest_InvalidTimeRange checks that an end_time earlier than
// start_time is rejected with 400, that the response body mentions
// "must be before", and that the wrapped handler is never reached.
func (s *validateTimeRangeSuite) TestBadRequest_InvalidTimeRange() {
	path := "/end_before_start?start_time=3000-07-01T07:00:00Z&end_time=2000-07-01T07:00:00Z"
	recorder := httptest.NewRecorder()

	req, err := http.NewRequest(http.MethodGet, path, nil)
	s.Require().NoError(err)

	// Fail explicitly instead of panicking so the failure is reported against
	// this test rather than as a crash.
	innerHandler := func(w http.ResponseWriter, req *http.Request) {
		s.Fail("inner handler must not be called when end_time precedes start_time")
	}

	validateTimeRange(innerHandler).ServeHTTP(recorder, req)

	s.Equal(http.StatusBadRequest, recorder.Code)
	s.Contains(recorder.Body.String(), "must be before")
}

// TestSuccess_BeforeZephyrStartTime checks that a time range lying entirely
// before zephyr.Start is accepted. It also checks that both bounds stored in
// the request context under contextKeyTimeRange are clamped to zephyr.Start.
func (s *validateTimeRangeSuite) TestSuccess_BeforeZephyrStartTime() {
	startTime := zephyr.Start.Add(-24 * time.Hour)
	endTime := zephyr.Start.Add(-12 * time.Hour)
	path := "/with_query_parameters"
	recorder := httptest.NewRecorder()

	req, err := http.NewRequest(http.MethodGet, path, nil)
	s.Require().NoError(err)

	query := req.URL.Query()
	query.Set("start_time", startTime.Format(time.RFC3339))
	query.Set("end_time", endTime.Format(time.RFC3339))
	req.URL.RawQuery = query.Encode()

	// httptest.ResponseRecorder reports 200 until something is written. The
	// status check alone therefore cannot prove the middleware forwarded the
	// request, and the context assertions below would silently never run.
	called := false
	innerHandler := func(w http.ResponseWriter, req *http.Request) {
		called = true

		reqTimeRange, ok := req.Context().Value(contextKeyTimeRange).(timeRange)
		if !s.True(ok, "time range missing from request context") {
			return
		}
		// Compare instants with time.Time.Equal so that a different *Location
		// or a monotonic clock reading cannot cause a spurious mismatch.
		s.True(zephyr.Start.Equal(reqTimeRange.startTime),
			"start_time: want %v, got %v", zephyr.Start, reqTimeRange.startTime)
		s.True(zephyr.Start.Equal(reqTimeRange.endTime),
			"end_time: want %v, got %v", zephyr.Start, reqTimeRange.endTime)

		w.WriteHeader(http.StatusOK)
	}

	validateTimeRange(innerHandler).ServeHTTP(recorder, req)

	s.True(called, "inner handler was not called")
	s.Equal(http.StatusOK, recorder.Code)
}

// TestSuccess_WithContext checks that a valid range after zephyr.Start is
// forwarded unchanged. The wrapped handler must see the parsed start and end
// times in the request context under contextKeyTimeRange.
func (s *validateTimeRangeSuite) TestSuccess_WithContext() {
	startTime := time.Date(2100, 1, 1, 0, 0, 0, 0, time.UTC)
	endTime := time.Date(3000, 1, 1, 0, 0, 0, 0, time.UTC)
	path := "/with_query_parameters"
	recorder := httptest.NewRecorder()

	req, err := http.NewRequest(http.MethodGet, path, nil)
	s.Require().NoError(err)

	query := req.URL.Query()
	query.Set("start_time", startTime.Format(time.RFC3339))
	query.Set("end_time", endTime.Format(time.RFC3339))
	req.URL.RawQuery = query.Encode()

	// The recorder defaults to 200, so track explicitly that the wrapped
	// handler (and therefore the context assertions) actually ran.
	called := false
	innerHandler := func(w http.ResponseWriter, req *http.Request) {
		called = true

		reqTimeRange, ok := req.Context().Value(contextKeyTimeRange).(timeRange)
		if !s.True(ok, "time range missing from request context") {
			return
		}
		// Compare instants rather than struct representations of time.Time.
		s.True(startTime.Equal(reqTimeRange.startTime),
			"start_time: want %v, got %v", startTime, reqTimeRange.startTime)
		s.True(endTime.Equal(reqTimeRange.endTime),
			"end_time: want %v, got %v", endTime, reqTimeRange.endTime)

		w.WriteHeader(http.StatusOK)
	}

	validateTimeRange(innerHandler).ServeHTTP(recorder, req)

	s.True(called, "inner handler was not called")
	s.Equal(http.StatusOK, recorder.Code)
}

// validateQueryParamChannelIDsSuite groups the tests for the
// validateQueryParamChannelIDs middleware.
type validateQueryParamChannelIDsSuite struct{ suite.Suite }

// TestValidateQueryParamChannelIDsSuite is the `go test` entry point that runs
// every test method defined on validateQueryParamChannelIDsSuite.
func TestValidateQueryParamChannelIDsSuite(t *testing.T) {
	s := new(validateQueryParamChannelIDsSuite)
	suite.Run(t, s)
}

// TestBadRequest checks that a request without a channel_ids query parameter
// is rejected with 400 and never reaches the wrapped handler.
func (s *validateQueryParamChannelIDsSuite) TestBadRequest() {
	path := "/no_query_parameter"
	recorder := httptest.NewRecorder()

	req, err := http.NewRequest(http.MethodGet, path, nil)
	s.Require().NoError(err)

	// Fail explicitly instead of panicking so the failure is reported against
	// this test rather than as a crash.
	innerHandler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		s.Fail("inner handler must not be called when channel_ids is missing")
	})

	validateQueryParamChannelIDs(innerHandler).ServeHTTP(recorder, req)

	s.Equal(http.StatusBadRequest, recorder.Code)
}

// TestSuccessWithContext checks that channel_ids "1,1,2,2,3,3" is accepted. It
// expects the wrapped handler to find the de-duplicated IDs [1 2 3], in that
// order, as a []int64 under contextKeyChannelIDs.
func (s *validateQueryParamChannelIDsSuite) TestSuccessWithContext() {
	path := "/with_query_parameter?channel_ids=1,1,2,2,3,3"
	recorder := httptest.NewRecorder()

	req, err := http.NewRequest(http.MethodGet, path, nil)
	s.Require().NoError(err)

	// The recorder defaults to 200, so track explicitly that the wrapped
	// handler (and therefore the context assertions) actually ran.
	called := false
	innerHandler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		called = true

		channelIDs, ok := req.Context().Value(contextKeyChannelIDs).([]int64)
		s.True(ok, "channel IDs missing from request context")
		s.Equal([]int64{1, 2, 3}, channelIDs)

		w.WriteHeader(http.StatusOK)
	})

	validateQueryParamChannelIDs(innerHandler).ServeHTTP(recorder, req)

	s.True(called, "inner handler was not called")
	s.Equal(http.StatusOK, recorder.Code)
}

// mockDecoder is a test double standing in for goauthorization.Decoder. Each
// field holds the canned result returned by the corresponding method.
type mockDecoder struct {
	parseToken  *goauthorization.AuthorizationToken // returned by ParseToken
	parseErr    error                               // returned by ParseToken
	validateErr error                               // returned by Validate
}

// Decode is a stub: it ignores its argument and always returns a nil token
// together with a nil error.
func (d *mockDecoder) Decode(string) (*goauthorization.AuthorizationToken, error) {
	return nil, nil
}
// ParseToken ignores the request and hands back the preconfigured token and
// error from the mock's parseToken and parseErr fields.
func (d *mockDecoder) ParseToken(*http.Request) (*goauthorization.AuthorizationToken, error) {
	token, err := d.parseToken, d.parseErr
	return token, err
}
// Validate ignores the token and claims and returns the preconfigured
// validateErr, letting tests simulate capability checks passing or failing.
func (d *mockDecoder) Validate(*goauthorization.AuthorizationToken, goauthorization.CapabilityClaims) error {
	err := d.validateErr
	return err
}

// authorizeChannelViewStatsSuite groups the tests for the
// Server.authorizeChannelViewStats middleware.
type authorizeChannelViewStatsSuite struct{ suite.Suite }

// TestAuthorizeChannelViewStatsSuite is the `go test` entry point that runs
// every test method defined on authorizeChannelViewStatsSuite.
func TestAuthorizeChannelViewStatsSuite(t *testing.T) {
	s := new(authorizeChannelViewStatsSuite)
	suite.Run(t, s)
}

// mockUsers is a test double standing in for users.Service. GetUserByID
// returns the canned values stored in its fields.
type mockUsers struct {
	properties *user.Properties // user record handed back by GetUserByID
	err        error            // error handed back by GetUserByID
}

// GetUserByID ignores the context and user ID and returns the preconfigured
// properties and error.
func (u *mockUsers) GetUserByID(context.Context, string) (*user.Properties, error) {
	props, err := u.properties, u.err
	return props, err
}

// TestFailureInvalidToken checks that when the auth decoder fails to parse a
// token, the middleware responds 403 and never invokes the wrapped handler.
func (s *authorizeChannelViewStatsSuite) TestFailureInvalidToken() {
	server := &Server{
		authDecoder: &auth.Decoder{
			Decoder: &mockDecoder{
				parseErr: fmt.Errorf("bad parse"),
			},
		},
	}

	path := "/invalid_token"
	recorder := httptest.NewRecorder()

	req, err := http.NewRequest(http.MethodGet, path, nil)
	s.Require().NoError(err)

	// The recorder ignores any WriteHeader after the first. A middleware that
	// wrote 403 but forgot to return (and so still called next) would go
	// unnoticed by the status check alone, so reaching this handler fails the
	// test directly.
	innerHandler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		s.Fail("inner handler must not be called for an unparsable token")
	})

	server.authorizeChannelViewStats(innerHandler).ServeHTTP(recorder, req)

	s.Equal(http.StatusForbidden, recorder.Code)
}

// TestSuccessWithClaims checks that when the token both parses and validates,
// the middleware forwards the request to the wrapped handler.
func (s *authorizeChannelViewStatsSuite) TestSuccessWithClaims() {
	server := &Server{
		authDecoder: &auth.Decoder{
			Decoder: &mockDecoder{
				parseToken:  &goauthorization.AuthorizationToken{},
				parseErr:    nil,
				validateErr: nil,
			},
		},
	}

	path := "/valid_claims"
	recorder := httptest.NewRecorder()

	req, err := http.NewRequest(http.MethodGet, path, nil)
	s.Require().NoError(err)
	// N.B. - the middleware uses pat.Param so the channel id has to be passed in the context
	req = req.WithContext(context.WithValue(req.Context(), pattern.Variable("channel_id"), "123"))

	// The recorder defaults to 200, so a 200 status alone does not prove the
	// request was forwarded; record the call explicitly.
	called := false
	innerHandler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		called = true
		w.WriteHeader(http.StatusOK)
	})

	server.authorizeChannelViewStats(innerHandler).ServeHTTP(recorder, req)

	s.True(called, "inner handler was not called")
	s.Equal(http.StatusOK, recorder.Code)
}

// TestSuccessWithoutClaimsUserIsChannel checks that a token which fails
// capability validation is still allowed through when its subject equals the
// channel_id path variable.
func (s *authorizeChannelViewStatsSuite) TestSuccessWithoutClaimsUserIsChannel() {
	// Clients is intentionally left unset; presumably the middleware must not
	// consult the users service when the subject matches the channel.
	server := &Server{
		authDecoder: &auth.Decoder{
			Decoder: &mockDecoder{
				parseToken: &goauthorization.AuthorizationToken{
					Claims: goauthorization.TokenClaims{
						Sub: claim.Sub{Sub: "11222333"},
					},
				},
				parseErr:    nil,
				validateErr: fmt.Errorf("bad validate"),
			},
		},
	}

	path := "/user_is_channel"
	recorder := httptest.NewRecorder()

	req, err := http.NewRequest(http.MethodGet, path, nil)
	s.Require().NoError(err)
	req = req.WithContext(context.WithValue(req.Context(), pattern.Variable("channel_id"), "11222333"))

	// The recorder defaults to 200, so record the call explicitly.
	called := false
	innerHandler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		called = true
		w.WriteHeader(http.StatusOK)
	})

	server.authorizeChannelViewStats(innerHandler).ServeHTTP(recorder, req)

	s.True(called, "inner handler was not called")
	s.Equal(http.StatusOK, recorder.Code)
}

// TestSuccessWithoutClaimsUserIsStaff checks that a token which fails
// capability validation, and whose subject differs from channel_id, is still
// allowed through when the users service reports the user as an admin.
func (s *authorizeChannelViewStatsSuite) TestSuccessWithoutClaimsUserIsStaff() {
	isAdmin := true
	server := &Server{
		authDecoder: &auth.Decoder{
			Decoder: &mockDecoder{
				parseToken: &goauthorization.AuthorizationToken{
					Claims: goauthorization.TokenClaims{
						Sub: claim.Sub{Sub: "332211"},
					},
				},
				parseErr:    nil,
				validateErr: fmt.Errorf("bad validate"),
			},
		},
		Clients: &clients.Clients{
			Users: &mockUsers{
				properties: &user.Properties{
					Admin: &isAdmin,
				},
			},
		},
	}

	path := "/user_is_staff"
	recorder := httptest.NewRecorder()

	req, err := http.NewRequest(http.MethodGet, path, nil)
	s.Require().NoError(err)
	req = req.WithContext(context.WithValue(req.Context(), pattern.Variable("channel_id"), "11222333"))

	// The recorder defaults to 200, so record the call explicitly.
	called := false
	innerHandler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		called = true
		w.WriteHeader(http.StatusOK)
	})

	server.authorizeChannelViewStats(innerHandler).ServeHTTP(recorder, req)

	s.True(called, "inner handler was not called")
	s.Equal(http.StatusOK, recorder.Code)
}

// TestFailureUsersService checks that when capability validation fails, the
// subject differs from channel_id, and the users service lookup errors, the
// middleware responds 403 and never invokes the wrapped handler.
func (s *authorizeChannelViewStatsSuite) TestFailureUsersService() {
	server := &Server{
		authDecoder: &auth.Decoder{
			Decoder: &mockDecoder{
				parseToken: &goauthorization.AuthorizationToken{
					Claims: goauthorization.TokenClaims{
						Sub: claim.Sub{Sub: "332211"},
					},
				},
				parseErr:    nil,
				validateErr: fmt.Errorf("bad validate"),
			},
		},
		Clients: &clients.Clients{
			Users: &mockUsers{
				err: fmt.Errorf("the internet is on fire"),
			},
		},
	}

	path := "/users_service_broken"
	recorder := httptest.NewRecorder()

	req, err := http.NewRequest(http.MethodGet, path, nil)
	s.Require().NoError(err)
	req = req.WithContext(context.WithValue(req.Context(), pattern.Variable("channel_id"), "11222333"))

	// A second WriteHeader is ignored by the recorder, so a missing `return`
	// after writing 403 would not show up in the status code; fail directly
	// if the wrapped handler is reached.
	innerHandler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		s.Fail("inner handler must not be called when the users service fails")
	})

	server.authorizeChannelViewStats(innerHandler).ServeHTTP(recorder, req)

	s.Equal(http.StatusForbidden, recorder.Code)
}

// TestFailureUnauthorized checks that a user who fails capability validation,
// is not the channel, and is not flagged as admin receives a 403. It also
// checks that the wrapped handler is never invoked.
func (s *authorizeChannelViewStatsSuite) TestFailureUnauthorized() {
	server := &Server{
		authDecoder: &auth.Decoder{
			Decoder: &mockDecoder{
				parseToken: &goauthorization.AuthorizationToken{
					Claims: goauthorization.TokenClaims{
						Sub: claim.Sub{Sub: "456789"},
					},
				},
				parseErr:    nil,
				validateErr: fmt.Errorf("bad validate"),
			},
		},
		Clients: &clients.Clients{
			Users: &mockUsers{
				// Zero-value properties: Admin is nil, i.e. not staff.
				properties: &user.Properties{},
			},
		},
	}

	path := "/user_is_not_staff"
	recorder := httptest.NewRecorder()

	req, err := http.NewRequest(http.MethodGet, path, nil)
	s.Require().NoError(err)
	req = req.WithContext(context.WithValue(req.Context(), pattern.Variable("channel_id"), "11222333"))

	// A second WriteHeader is ignored by the recorder, so a missing `return`
	// after writing 403 would not show up in the status code; fail directly
	// if the wrapped handler is reached.
	innerHandler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		s.Fail("inner handler must not be called for an unauthorized user")
	})

	server.authorizeChannelViewStats(innerHandler).ServeHTTP(recorder, req)

	s.Equal(http.StatusForbidden, recorder.Code)
}
