package sessionviewer

import (
	"fmt"
	"net/http"
	"strconv"
	"time"

	sec "code.justin.tv/amzn/StarfruitSECTwirp"
	cc "code.justin.tv/event-engineering/carrot-control/pkg/rpc"
	csa "code.justin.tv/event-engineering/carrot-stream-analysis/pkg/rpc"
	"code.justin.tv/event-engineering/starfruit-support-portal/app/util"
	"github.com/aws/aws-sdk-go/aws/arn"
	"github.com/golang/protobuf/ptypes"
	"github.com/sirupsen/logrus"
	goji "goji.io"
	"goji.io/pat"
)

// handler serves the session viewer HTTP API. Each route method reads its
// parameters from the request URL and forwards the call to carrot control,
// writing the result back through the util helpers.
type handler struct {
	mux    *goji.Mux          // router holding the session viewer routes
	cc     cc.CarrotControl   // downstream client every request is forwarded to
	logger logrus.FieldLogger // used for a Debugf line before each downstream call
}

// NewSessionViewerHandler builds an http.Handler that serves the session viewer
// API by forwarding each request to the carrot control service.
//
// Routes, relative to wherever the returned handler is mounted, are matched in
// the order listed:
//
//	/:customer_id/:region/:content_id                             -> ListSessions
//	/:customer_id/:region/:content_id/:session_id                 -> GetSessionDetail
//	/:customer_id/:region/:content_id/:session_id/events          -> GetSessionEvents
//	/:customer_id/:region/:content_id/:session_id/limit_breaches  -> GetLimitBreaches
func NewSessionViewerHandler(carrotControl cc.CarrotControl, logger logrus.FieldLogger) http.Handler {
	h := &handler{
		cc:     carrotControl,
		logger: logger,
		mux:    goji.NewMux(),
	}

	routes := []struct {
		pattern string
		serve   func(http.ResponseWriter, *http.Request)
	}{
		{"/:customer_id/:region/:content_id", h.ListSessions},
		{"/:customer_id/:region/:content_id/:session_id", h.GetSessionDetail},
		{"/:customer_id/:region/:content_id/:session_id/events", h.GetSessionEvents},
		{"/:customer_id/:region/:content_id/:session_id/limit_breaches", h.GetLimitBreaches},
	}

	// Registration order is preserved from the table above.
	for _, route := range routes {
		h.mux.HandleFunc(pat.New(route.pattern), route.serve)
	}

	return h.mux
}

// ListSessions responds with the channel sessions recorded for one customer's
// content in a region. The list is paged by time. Each response covers a 14-day
// window that ends either at the current time or, when the optional `before`
// query parameter (Unix seconds) is present, one second before that instant.
//
// A `before` value that is not a base-10 integer is reported through
// util.HandleError with status 400. Window bounds that cannot be converted to a
// protobuf Timestamp are reported the same way.
func (s *handler) ListSessions(writer http.ResponseWriter, request *http.Request) {
	// Width of the time window covered by a single page.
	const pageSpan = 14 * 24 * time.Hour

	var (
		customerID = pat.Param(request, "customer_id")
		region     = pat.Param(request, "region")
		contentID  = pat.Param(request, "content_id")
	)

	var end time.Time
	if raw := request.URL.Query().Get("before"); raw == "" {
		end = time.Now()
	} else {
		secs, parseErr := strconv.ParseInt(raw, 10, 64)
		if parseErr != nil {
			util.HandleError(writer, http.StatusBadRequest, fmt.Errorf("Expected `before` to be a Unix timestamp, got %v", raw))
			return
		}
		// Step back one second so this page ends strictly before `before`.
		end = time.Unix(secs-1, 0)
	}
	start := end.Add(-pageSpan)

	startProto, err := ptypes.TimestampProto(start)
	if err != nil {
		util.HandleError(writer, http.StatusBadRequest, err)
		return
	}
	endProto, err := ptypes.TimestampProto(end)
	if err != nil {
		util.HandleError(writer, http.StatusBadRequest, err)
		return
	}

	s.logger.Debugf("Calling GetChannelSessions with customerID %v and contentID %v", customerID, contentID)
	sessionsReq := &csa.GetChannelSessionsRequest{
		CustomerId: customerID,
		Region:     region,
		ContentId:  contentID,
		Start:      startProto,
		End:        endProto,
	}
	resp, err := s.cc.GetChannelSessions(request.Context(), sessionsReq)
	util.HandleResponse(writer, resp, err)
}

// GetSessionDetail responds with the data for a single broadcast session.
// It uses carrot control's GetSessionData call, keyed by the customer, region,
// content and session IDs taken from the URL path.
func (s *handler) GetSessionDetail(writer http.ResponseWriter, request *http.Request) {
	detailReq := &csa.GetSessionDataRequest{
		CustomerId: pat.Param(request, "customer_id"),
		Region:     pat.Param(request, "region"),
		ContentId:  pat.Param(request, "content_id"),
		SessionId:  pat.Param(request, "session_id"),
	}

	s.logger.Debugf("Calling GetSessionData with customerID %v contentID %v and sessionID %v",
		detailReq.CustomerId, detailReq.ContentId, detailReq.SessionId)

	resp, err := s.cc.GetSessionData(request.Context(), detailReq)
	util.HandleResponse(writer, resp, err)
}

// GetSessionEvents responds with the stream events recorded for one broadcast
// session. The channel is addressed by an IVS channel ARN of the form
// arn:aws:ivs:<region>:<customer_id>:channel/<content_id>. The request is sent
// to carrot control in the region taken from the URL.
func (s *handler) GetSessionEvents(writer http.ResponseWriter, request *http.Request) {
	var (
		customerID = pat.Param(request, "customer_id")
		region     = pat.Param(request, "region")
		contentID  = pat.Param(request, "content_id")
		sessionID  = pat.Param(request, "session_id")
	)

	s.logger.Debugf("Calling ListStreamEvents with customerID %v contentID %v and sessionID %v", customerID, contentID, sessionID)

	// The customer ID doubles as the AWS account that owns the channel.
	channelARN := arn.ARN{
		Partition: "aws",
		Service:   "ivs",
		Region:    region,
		AccountID: customerID,
		Resource:  "channel/" + contentID,
	}.String()

	eventsReq := &sec.ListStreamEventsRequest{
		BroadcastSessionId: sessionID,
		ChannelArn:         channelARN,
	}
	resp, err := s.cc.ListStreamEvents(request.Context(), &cc.ListStreamEventsRequest{
		Region:  region,
		Request: eventsReq,
	})
	util.HandleResponse(writer, resp, err)
}

// GetLimitBreaches responds with the limit-breach events for a broadcast
// session, requested ten at a time. The optional `page_token` query parameter
// is forwarded unchanged as the downstream PageToken. The content ID from the
// path is only logged; it is not part of the downstream request.
func (s *handler) GetLimitBreaches(writer http.ResponseWriter, request *http.Request) {
	customerID, sessionID := pat.Param(request, "customer_id"), pat.Param(request, "session_id")

	s.logger.Debugf("Calling ListLimitBreachEvents with customerID %v contentID %v and sessionID %v",
		customerID, pat.Param(request, "content_id"), sessionID)

	breachReq := &sec.ListLimitBreachEventsRequest{
		CustomerId: customerID,
		SessionId:  sessionID,
		PageSize:   10,
		PageToken:  request.URL.Query().Get("page_token"),
	}

	resp, err := s.cc.ListLimitBreachEvents(request.Context(), &cc.ListLimitBreachEventsRequest{
		Region:  pat.Param(request, "region"),
		Request: breachReq,
	})
	util.HandleResponse(writer, resp, err)
}
