package service

import (
	"fmt"
	"regexp"
	"time"

	"github.com/golang/protobuf/ptypes"
	"github.com/golang/protobuf/ptypes/wrappers"
	"github.com/pkg/errors"
	"github.com/twitchtv/twirp"

	"context"

	"encoding/json"

	"strings"

	"net/http"

	"net/url"

	"code.justin.tv/video/lvsapi/internal/auth"
	"code.justin.tv/video/lvsapi/internal/awsutils"
	"code.justin.tv/video/lvsapi/internal/caching"
	"code.justin.tv/video/lvsapi/internal/constants"
	"code.justin.tv/video/lvsapi/internal/digestion"
	"code.justin.tv/video/lvsapi/internal/logging"
	"code.justin.tv/video/lvsapi/internal/metrics"
	"code.justin.tv/video/lvsapi/internal/usher"
	"code.justin.tv/video/lvsapi/internal/viewcounts"
	"code.justin.tv/video/lvsapi/rpc/lvs"
	"code.justin.tv/video/lvsapi/streamkey"
)

const (
	// maxContentIDLength bounds content IDs: IsValidContentID accepts only IDs
	// strictly shorter than this many characters.
	maxContentIDLength = 40

	// defaultStreamkeyExpiry is the stream key TTL, in seconds, applied when a
	// CreateStreamKey request leaves ttl_seconds at zero.
	defaultStreamkeyExpiry = 300

	// defaultStreamKeyParameterValue is reported in place of optional stream key
	// metadata fields that were left empty.
	defaultStreamKeyParameterValue = "Not Specified"

	// rtmpServerUrl is the RTMP ingest endpoint returned to broadcasters.
	rtmpServerUrl = "rtmp://rtmplive.twitch.tv/app"

	// playBackUrlTemplate is the HLS playback URL; the %s verb receives the
	// channel name.
	playBackUrlTemplate = "https://usher.ttvnw.net/api/lvs/hls/%s.m3u8?allow_source=true&player_backend=mediaplayer"
)

// contentIDRegex matches runs of characters permitted in a content ID: ASCII
// letters, digits, '-' and '_'. It is unanchored; IsValidContentID requires the
// whole ID to be covered by a single match.
var contentIDRegex = regexp.MustCompile(`([A-Za-z0-9\-\_]+)`)

// service is the implementation returned by New as an lvs.LiveVideoService.
// Every field is a backend dependency injected through New.
type service struct {
	// secretSource is passed to streamkey.Encrypt / streamkey.Decrypt.
	secretSource     streamkey.SecretSource
	// s3APIs is used to look up the region of a customer-supplied S3 bucket.
	s3APIs           awsutils.S3APIs
	// cache serves GetStream and ListStreams lookups.
	cache            caching.CacheInterface
	// metrics receives per-request timing data via recordMetric.
	metrics          *metrics.Client
	// dgnAPIs is the digestion backend used for AddMetadata and DeleteStream.
	dgnAPIs          digestion.DigestionBackend
	// viewcountBackend supplies per-stream viewer counts.
	viewcountBackend viewcounts.ViewcountsCache
}

// New builds the LVS implementation of lvs.LiveVideoService
// (code.justin.tv.video.lvsapi.rpc.lvs.LiveVideoService) from its backends: the
// stream key secret source, the stream cache, the S3 helpers, the metrics
// client, the digestion backend and the viewcount cache.
func New(ss streamkey.SecretSource, cache caching.CacheInterface, s3APIBackend awsutils.S3APIs, m *metrics.Client, dgn digestion.DigestionBackend, viewcountsBackend viewcounts.ViewcountsCache) lvs.LiveVideoService {
	svc := new(service)
	svc.secretSource = ss
	svc.cache = cache
	svc.s3APIs = s3APIBackend
	svc.metrics = m
	svc.dgnAPIs = dgn
	svc.viewcountBackend = viewcountsBackend
	return svc
}

// LVSMeta is the JSON document carried in an usher stream's LVSMetadata string.
// It holds the full S3 VOD URL, the SNS notification endpoint and the VOD
// manifest URL. extractS3VodUrlFromMeta and extractVodManifestUrlFromMeta decode
// it when building lvs.Stream responses.
type LVSMeta struct {
	S3VodUrl       string `json:"s3_vod_url"`
	SnsEndpoint    string `json:"sns_endpoint"`
	VodManifestUrl string `json:"vod_manifest_url"`
}

// IsValidContentID validates a contentID only has letter, numbers, underscores or dashes
// and is smaller than 40 characters
func IsValidContentID(contentID string) bool {
	parsed := contentIDRegex.FindAllString(contentID, -1)

	if len(parsed) != 1 {
		return false
	}

	// ContentID Regex should only find a single match.
	if parsed[0] == contentID && (len(parsed[0]) < maxContentIDLength) {
		return true
	}

	return false
}

// CreateStreamKey validates the request, builds an encrypted stream key for the
// calling LVS customer, and returns the RTMP ingest details plus a playback URL.
//
// Each validation failure is reported as a twirp InvalidArgument error:
//   - content_id must satisfy IsValidContentID
//   - the caller must present an LVS customer identity (auth.LvsCustomerID)
//   - ttl_seconds (defaultStreamkeyExpiry when 0) must yield a valid timestamp
//   - s3_bucket, when set, must be a bare bucket name located in us-west-2
//   - sns_notification_endpoint, when set, must pass awsutils.ValidateSnsArn
//   - cdn_url, when set, must parse as a request URI
//   - latency_mode must be empty, "normal" or "low"
//
// A failure to encrypt the key is reported as a twirp Internal error.
func (s *service) CreateStreamKey(ctx context.Context, req *lvs.CreateStreamKeyRequest) (*lvs.CreateStreamKeyResponse, error) {
	logging.Debug(ctx, "Received Request: %+v", req)
	var customerId string
	startTime := time.Now()
	defer func() {
		// customerId is read when the deferred call runs, so the metric carries
		// the customer once authentication has succeeded.
		s.recordMetric(ctx, customerId, "CreateStreamKey", "Time", metrics.UnitMillisecond, float64(time.Since(startTime))/float64(time.Millisecond))
	}()

	if !IsValidContentID(req.GetContentId()) {
		return nil, twirp.InvalidArgumentError("content_id", fmt.Sprintf("must match regex '%s'", contentIDRegex))
	}

	var ok bool
	if customerId, ok = auth.LvsCustomerID(ctx); !ok {
		return nil, twirp.InvalidArgumentError("customer_id", "invalid request cert")
	}

	seconds := req.GetTtlSeconds()
	if seconds == 0 {
		seconds = defaultStreamkeyExpiry
	}
	expiration, err := ptypes.TimestampProto(time.Now().Add(time.Duration(seconds) * time.Second))
	if err != nil {
		return nil, twirp.InvalidArgumentError("ttl_seconds", "invalid ttl")
	}

	// The S3 bucket is optional; it is only validated when provided.
	if s3Bucket := req.GetS3Bucket(); s3Bucket != "" {
		if err := s.validateStreamKeyS3Bucket(ctx, s3Bucket); err != nil {
			return nil, err
		}
	}

	if snsEndpoint := req.GetSnsNotificationEndpoint(); snsEndpoint != "" {
		if err := awsutils.ValidateSnsArn(snsEndpoint); err != nil {
			return nil, twirp.InvalidArgumentError("sns_notification_endpoint", err.Error())
		}
	}

	if cdnURL := req.GetCdnUrl(); cdnURL != "" {
		// Reject anything that is not a well-formed absolute URL / path.
		if _, err := url.ParseRequestURI(cdnURL); err != nil {
			return nil, twirp.InvalidArgumentError("cdn_url", err.Error())
		}
	}

	if !isValidLatencyMode(req.GetLatencyMode()) {
		return nil, twirp.InvalidArgumentError("latency_mode", "acceptable values are normal or low, leave it empty for normal latency mode")
	}

	logging.Debug(ctx, "CustomerID: %s, ContentID: %s", customerId, req.GetContentId())
	channelName := getChannelNameLvsStreams(customerId, req.GetContentId())
	privData := &streamkey.PrivateData{
		ExpirationTime:          expiration,
		ContentId:               req.GetContentId(),
		S3Bucket:                req.GetS3Bucket(),
		SnsNotificationEndpoint: req.GetSnsNotificationEndpoint(),
		S3Prefix:                req.GetS3Prefix(),
		CustomerId:              customerId,
		ChannelName:             channelName,
		LowLatencyMode:          req.GetEnableLowLatency(),
		CdnUrl:                  req.GetCdnUrl(),
		LatencyMode:             req.GetLatencyMode(),
	}
	streamKey := streamkey.NewV1(customerId, privData)
	encryptedStreamKey, err := streamkey.Encrypt(ctx, s.secretSource, streamKey)
	if err != nil {
		logging.HandlerError(ctx, err)
		return nil, twirp.InternalError("unable to encrypt stream key")
	}

	// Optional metadata fields are echoed back with a placeholder when empty.
	return &lvs.CreateStreamKeyResponse{
		RtmpIngestUrl: fmt.Sprintf("%s/%s", rtmpServerUrl, encryptedStreamKey),
		PlaybackUrl:   fmt.Sprintf(playBackUrlTemplate, channelName),
		RtmpServer:    rtmpServerUrl,
		Streamkey:     encryptedStreamKey,
		CustomerId:    customerId,
		ContentId:     req.GetContentId(),
		StreamkeyMetadata: &lvs.StreamKeyMetadata{
			S3Bucket:                getStreamKeyParameter(req.GetS3Bucket()),
			S3Prefix:                getStreamKeyParameter(req.GetS3Prefix()),
			SnsNotificationEndpoint: getStreamKeyParameter(req.GetSnsNotificationEndpoint()),
			ExpirationTime:          expiration,
			EnableLowLatency:        req.GetEnableLowLatency(),
			CdnUrl:                  getStreamKeyParameter(req.GetCdnUrl()),
			LatencyMode:             req.GetLatencyMode(),
		},
	}, nil
}

// validateStreamKeyS3Bucket checks that s3Bucket is a bare bucket name (not an
// ARN or an http/https URL) whose region, as reported by S3, is us-west-2.
// It returns a twirp InvalidArgument error for the first problem found, or nil.
func (s *service) validateStreamKeyS3Bucket(ctx context.Context, s3Bucket string) error {
	if strings.HasPrefix(s3Bucket, constants.S3_ARN_PREFIX) {
		return twirp.InvalidArgumentError("s3_bucket", "please provide just the bucket name not an ARN")
	}

	if strings.HasPrefix(s3Bucket, constants.HTTP_PREFIX) ||
		strings.HasPrefix(s3Bucket, constants.HTTPS_PREFIX) {
		return twirp.InvalidArgumentError("s3_bucket", "please provide just the bucket name not URL")
	}

	// Asking S3 for the bucket's region doubles as an existence/validity check;
	// known error texts are mapped onto user-facing messages.
	region, err := s.s3APIs.GetBucketRegion(ctx, s3Bucket)
	if err != nil {
		switch msg := err.Error(); {
		case strings.Contains(msg, "BadRequest"):
			return twirp.InvalidArgumentError("s3_bucket", "s3 bucket name is invalid")
		case strings.Contains(msg, "NotFound"):
			return twirp.InvalidArgumentError("s3_bucket", "bucket not found")
		}

		// Log the bucket name (the region result is meaningless on error).
		logging.Info(ctx, "GetBucketRegion call for bucket :%s threw error :%+v", s3Bucket, err)
		return twirp.InvalidArgumentError("s3_bucket", "unable to verify the region for the given bucket")
	}

	if region != constants.US_WEST2_REGION {
		return twirp.InvalidArgumentError("s3_bucket", "please provide a bucket in us-west-2 region")
	}
	return nil
}

// CheckAuth verifies that authentication is functional: it echoes back the LVS
// customer ID resolved from the caller's credentials, or fails with
// twirp.Unauthenticated when no customer ID can be resolved.
func (s *service) CheckAuth(ctx context.Context, req *lvs.CheckAuthRequest) (*lvs.CheckAuthResponse, error) {
	logging.Debug(ctx, "Received Request: %+v", req)

	if customerID, ok := auth.LvsCustomerID(ctx); ok {
		return &lvs.CheckAuthResponse{LvsCustomerId: customerID}, nil
	}
	return nil, twirp.NewError(twirp.Unauthenticated, "missing client auth")
}

// GetStream fetches metadata for the calling customer's stream identified by
// content_id, annotated with its current viewer count.
//
// An invalid cert or content_id yields InvalidArgument. Cache lookup errors are
// returned as-is. Viewer counts are best-effort: a lookup failure is logged and
// the stream is reported with zero viewers instead of failing the request.
func (s *service) GetStream(ctx context.Context, req *lvs.GetStreamRequest) (*lvs.GetStreamResponse, error) {
	logging.Debug(ctx, "Received Request: %+v", req)

	var customerId string
	startTime := time.Now()
	defer func() {
		s.recordMetric(ctx, customerId, "GetStream", "Time", metrics.UnitMillisecond, float64(time.Since(startTime))/float64(time.Millisecond))
	}()
	var ok bool
	if customerId, ok = auth.LvsCustomerID(ctx); !ok {
		return nil, twirp.InvalidArgumentError("customer_id", "invalid request cert")
	}

	// IsValidContentID also rejects the empty string, so no separate length check is needed.
	contentId := req.GetContentId()
	if !IsValidContentID(contentId) {
		return nil, twirp.InvalidArgumentError("content_id", fmt.Sprintf("must match regex '%s'", contentIDRegex))
	}

	us, err := s.cache.GetStream(ctx, customerId, contentId)
	if err != nil {
		logging.HandlerError(ctx, err)
		return nil, err
	}
	logging.Debug(ctx, "response: %v", us)

	count, err := s.viewcountBackend.GetCount(ctx, us.StreamId)
	if err != nil {
		// Previously this error was silently overwritten; surface it, but keep
		// the request best-effort with an explicit zero count.
		logging.Info(ctx, "viewcount lookup for stream %v failed: %+v", us.StreamId, err)
		count = 0
	}

	stream, err := streamFromUsherResponse(customerId, contentId, us, count)
	if err != nil {
		return nil, twirp.InternalErrorWith(err)
	}

	return &lvs.GetStreamResponse{
		Stream: stream,
	}, nil
}

// ListStreams returns the calling customer's streams, filtered by req.Type
// (constants.STREAM_TYPE_ALL, STREAM_TYPE_LIVE or STREAM_TYPE_PREPARING), each
// annotated with its current viewer count.
//
// A missing or unknown type yields InvalidArgument; an empty result after
// filtering yields NotFound. Viewer counts are best-effort: a lookup failure is
// logged and that stream is reported with zero viewers.
func (s *service) ListStreams(ctx context.Context, req *lvs.ListStreamsRequest) (*lvs.ListStreamsResponse, error) {
	logging.Debug(ctx, "Received Request: %+v", req)
	var customerId string
	startTime := time.Now()
	defer func() {
		s.recordMetric(ctx, customerId, "ListStreams", "Time", metrics.UnitMillisecond, float64(time.Since(startTime))/float64(time.Millisecond))
	}()

	var ok bool
	if customerId, ok = auth.LvsCustomerID(ctx); !ok {
		return nil, twirp.InvalidArgumentError("customer_id", "invalid request cert")
	}

	requestedStreamType := req.Type
	if requestedStreamType == "" {
		return nil, twirp.InvalidArgumentError("type", "required argument missing")
	}

	if requestedStreamType != constants.STREAM_TYPE_ALL && requestedStreamType != constants.STREAM_TYPE_PREPARING &&
		requestedStreamType != constants.STREAM_TYPE_LIVE {
		return nil, twirp.InvalidArgumentError("type", "accepted values for argument should be one of (all , live or preparing) ")
	}

	stl, err := s.cache.ListStreams(ctx, customerId)
	if err != nil {
		logging.HandlerError(ctx, err)
		return nil, err
	}

	logging.Debug(ctx, "response: %v", stl)

	response := &lvs.ListStreamsResponse{}

	// Filter streams based on the request type. The loop variables are named
	// so they no longer shadow the receiver s.
	for _, usherStream := range stl {
		if !isRequestedStreamType(requestedStreamType, usherStream.Status) {
			continue
		}

		count, err := s.viewcountBackend.GetCount(ctx, usherStream.StreamId)
		if err != nil {
			logging.Info(ctx, "viewcount lookup for stream %v failed: %+v", usherStream.StreamId, err)
			count = 0
		}

		// &usherStream is only used within this synchronous call.
		lvsStream, err := streamFromUsherResponse(customerId, usherStream.ContentId, &usherStream, count)
		if err != nil {
			return nil, twirp.InternalErrorWith(err)
		}
		response.Streams = append(response.Streams, lvsStream)
	}

	if len(response.Streams) == 0 {
		return nil, twirp.NewError(twirp.NotFound, "No Streams found")
	}

	return response, nil
}

// DecodeStreamKey decrypts a stream key owned by the calling customer and
// returns the identifiers and metadata stored inside it.
//
// The customer ID embedded in the key is compared with the caller's identity
// before decryption:
//   - a malformed key yields InvalidArgument
//   - a key belonging to another customer yields PermissionDenied
//   - a decryption failure yields Unauthenticated
//
// Metadata fields are returned exactly as stored in the key (no placeholder
// substitution for empty values).
func (s *service) DecodeStreamKey(ctx context.Context, req *lvs.DecodeStreamKeyRequest) (*lvs.DecodeStreamKeyResponse, error) {
	logging.Debug(ctx, "Received Request: %+v", req)

	callerID, authed := auth.LvsCustomerID(ctx)
	if !authed {
		return nil, twirp.InvalidArgumentError("customer_id", "invalid request cert")
	}

	rawKey := req.GetStreamkey()

	ownerID, err := streamkey.GetCustomerIdFromKey(rawKey)
	if err != nil {
		return nil, twirp.NewError(twirp.InvalidArgument, "streamkey is invalid")
	}
	if ownerID != callerID {
		return nil, twirp.NewError(twirp.PermissionDenied, "Not Authorized to decrypt streamkey")
	}

	decoded, err := streamkey.Decrypt(ctx, s.secretSource, rawKey)
	if err != nil {
		return nil, twirp.NewError(twirp.Unauthenticated, "failed to decode streamkey")
	}

	return &lvs.DecodeStreamKeyResponse{
		CustomerId: decoded.Priv.GetCustomerId(),
		ContentId:  decoded.Priv.GetContentId(),
		StreamkeyMetadata: &lvs.StreamKeyMetadata{
			S3Bucket:                decoded.Priv.GetS3Bucket(),
			S3Prefix:                decoded.Priv.GetS3Prefix(),
			SnsNotificationEndpoint: decoded.Priv.GetSnsNotificationEndpoint(),
			ExpirationTime:          decoded.Priv.GetExpirationTime(),
			EnableLowLatency:        decoded.Priv.GetLowLatencyMode(),
			CdnUrl:                  decoded.Priv.GetCdnUrl(),
			LatencyMode:             decoded.Priv.GetLatencyMode(),
		},
	}, nil
}

// recordMetric enqueues one metric data point for clientId/method. An enqueue
// failure is logged at info level and otherwise ignored, so metrics never
// change the outcome of the request being measured.
func (s *service) recordMetric(ctx context.Context, clientId, method, name, unit string, value float64) {
	if err := s.metrics.Record(clientId, method, name, unit, value); err != nil {
		logging.Info(ctx, "Failed to queue up the metric: %v", err)
	}
}

// streamFromUsherResponse converts an usher stream record into an lvs.Stream for
// the given customer/content pair, reporting channelViewCount as ViewerCount.
//
// Status is constants.STATUS_LIVE when usher reports an active transcode and
// constants.STATUS_PREPARING otherwise. Health fields are currently hardcoded
// (see TODO below). It returns an error if us is nil or its start time cannot be
// represented as a protobuf timestamp.
func streamFromUsherResponse(CustomerID string, ContentID string, us *usher.UsherStreamResponse, channelViewCount uint64) (*lvs.Stream, error) {
	if us == nil {
		return nil, errors.New("usher stream response is nil")
	}

	startTime := time.Unix(us.StartedOn, 0)
	streamStartTime, err := ptypes.TimestampProto(startTime)
	if err != nil {
		return nil, errors.Wrapf(err, "failed to convert timestamp %s to proto", startTime)
	}

	// Anything other than an active transcode is surfaced as "preparing".
	streamStatus := constants.STATUS_PREPARING
	if us.Status == constants.TRANSCODE_ACTIVE {
		streamStatus = constants.STATUS_LIVE
	}

	//TODO - We need to plug in starvation data into usher and use that for making decisions if stream is healthy
	return &lvs.Stream{
		CustomerId:     CustomerID,
		ContentId:      ContentID,
		CdnPlaybackUrl: fmt.Sprintf(playBackUrlTemplate, getChannelNameLvsStreams(CustomerID, ContentID)),
		// NOTE(review): counts above math.MaxInt64 would wrap in this conversion.
		ViewerCount:     &wrappers.Int64Value{Value: int64(channelViewCount)},
		StartTime:       streamStartTime,
		DurationSeconds: int64(time.Since(startTime).Seconds()),
		S3VodUrl:        extractS3VodUrlFromMeta(us.LVSMetadata),
		VodManifestUrl:  extractVodManifestUrlFromMeta(us.LVSMetadata),
		Status:          streamStatus,
		HealthStatus:    constants.HEALTH_STATUS_STABLE,
		HealthReason:    "Healthy stream",
	}, nil
}

// extractS3VodUrlFromMeta returns the "s3_vod_url" field of the JSON-encoded
// LVS metadata, or "" when the metadata cannot be decoded.
func extractS3VodUrlFromMeta(lvsMeta string) string {
	var meta LVSMeta
	if json.Unmarshal([]byte(lvsMeta), &meta) != nil {
		return ""
	}
	return meta.S3VodUrl
}

// extractVodManifestUrlFromMeta returns the "vod_manifest_url" field of the
// JSON-encoded LVS metadata, or "" when the metadata cannot be decoded.
func extractVodManifestUrlFromMeta(lvsMeta string) string {
	var meta LVSMeta
	if json.Unmarshal([]byte(lvsMeta), &meta) != nil {
		return ""
	}
	return meta.VodManifestUrl
}

// getChannelNameLvsStreams builds the usher channel name for an LVS stream in
// the form "lvs.<customerId>.<contentId>".
func getChannelNameLvsStreams(customerId string, contentId string) string {
	return "lvs." + customerId + "." + contentId
}

// isRequestedStreamType reports whether a stream with usher status streamStatus
// should be included for the requested type:
//   - "all" matches every stream
//   - "live" matches constants.TRANSCODE_ACTIVE
//   - "preparing" matches constants.TRANSCODE_PENDING
func isRequestedStreamType(reqType string, streamStatus string) bool {
	return reqType == constants.STREAM_TYPE_ALL ||
		(reqType == constants.STREAM_TYPE_LIVE && streamStatus == constants.TRANSCODE_ACTIVE) ||
		(reqType == constants.STREAM_TYPE_PREPARING && streamStatus == constants.TRANSCODE_PENDING)
}

// maxMetadataSize is the largest metadata payload, in bytes, that AddLiveMetadata accepts.
const maxMetadataSize = 1 << 10 // 1 KiB

// AddLiveMetadata inserts a metadata payload into the caller's live stream
// identified by content_id, via the digestion backend.
//
// Errors:
//   - InvalidArgument for a bad cert, an invalid content_id, or an empty or oversized payload
//   - NotFound when digestion reports a twirp NotFound for the channel
//   - Internal for any other digestion failure
func (s *service) AddLiveMetadata(ctx context.Context, req *lvs.AddLiveMetadataRequest) (*lvs.AddLiveMetadataResponse, error) {
	logging.Debug(ctx, "Received Request: %+v", req)

	var customerId string
	startTime := time.Now()
	defer func() {
		s.recordMetric(ctx, customerId, "AddLiveMetadata", "Time", metrics.UnitMillisecond,
			float64(time.Since(startTime))/float64(time.Millisecond))
	}()

	var ok bool
	if customerId, ok = auth.LvsCustomerID(ctx); !ok {
		return nil, twirp.InvalidArgumentError("customer_id", "invalid request cert")
	}

	// IsValidContentID also rejects the empty string, so no separate length check is needed.
	contentId := req.GetContentId()
	if !IsValidContentID(contentId) {
		return nil, twirp.InvalidArgumentError("content_id", fmt.Sprintf("must match regex '%s'",
			contentIDRegex))
	}

	md := req.GetMetadata()
	if md == "" {
		return nil, twirp.InvalidArgumentError("metadata", "is required")
	}

	// len counts bytes, which is the unit maxMetadataSize is expressed in.
	if len(md) > maxMetadataSize {
		return nil, twirp.InvalidArgumentError("metadata", fmt.Sprintf("size is too large (%d of %d bytes)", len(md), maxMetadataSize))
	}

	channelName := getChannelNameLvsStreams(customerId, contentId)
	if err := s.dgnAPIs.AddMetadata(ctx, channelName, md); err != nil {
		if twerr, isTwirp := err.(twirp.Error); isTwirp && twerr.Code() == twirp.NotFound {
			return nil, twirp.NotFoundError("channel not found")
		}
		logging.Info(ctx, "Digestion AddMetadata call for channel :%s threw error :%+v", channelName, err)
		return nil, twirp.InternalErrorWith(err)
	}

	return &lvs.AddLiveMetadataResponse{}, nil
}

// StopStream asks digestion to remove the caller's stream identified by
// content_id, which disconnects its RTMP session.
//
// Errors:
//   - InvalidArgument for a bad cert or an invalid content_id
//   - NotFound when digestion indicates the stream does not exist
//   - Internal for any other digestion failure
func (s *service) StopStream(ctx context.Context, req *lvs.StopStreamRequest) (*lvs.StopStreamResponse, error) {
	logging.Debug(ctx, "Received Request: %+v", req)

	var customerId string
	startTime := time.Now()
	defer func() {
		s.recordMetric(ctx, customerId, "StopStream", "Time", metrics.UnitMillisecond,
			float64(time.Since(startTime))/float64(time.Millisecond))
	}()

	var ok bool
	if customerId, ok = auth.LvsCustomerID(ctx); !ok {
		return nil, twirp.InvalidArgumentError("customer_id", "invalid request cert")
	}

	// IsValidContentID also rejects the empty string, so no separate length check is needed.
	contentId := req.GetContentId()
	if !IsValidContentID(contentId) {
		return nil, twirp.InvalidArgumentError("content_id", fmt.Sprintf("must match regex '%s'",
			contentIDRegex))
	}

	channelName := getChannelNameLvsStreams(customerId, contentId)
	if err := s.dgnAPIs.DeleteStream(ctx, channelName); err != nil {
		logging.Info(ctx, "Digestion DeleteStream call for channel :%s threw error :%+v", channelName, err)

		// Prefer the structured twirp code (as AddLiveMetadata does). Keep the
		// legacy substring match for errors that only carry the HTTP status in
		// their text.
		// NOTE(review): the substring match can misfire if "404" appears in the
		// message for another reason (e.g. inside the channel name).
		twerr, isTwirp := err.(twirp.Error)
		if (isTwirp && twerr.Code() == twirp.NotFound) ||
			strings.Contains(err.Error(), fmt.Sprintf("%d", http.StatusNotFound)) {
			return nil, twirp.NotFoundError("No streams with given content_id")
		}
		return nil, twirp.InternalErrorWith(err)
	}

	return &lvs.StopStreamResponse{}, nil
}

// getStreamKeyParameter returns inputParam unchanged, or
// defaultStreamKeyParameterValue when inputParam is empty, so optional stream
// key metadata is never reported as a blank string.
func getStreamKeyParameter(inputParam string) string {
	if inputParam != "" {
		return inputParam
	}
	return defaultStreamKeyParameterValue
}

// isValidLatencyMode reports whether latencyMode is an accepted value:
// "normal", "low", or empty (which means normal latency).
// Matching is case-sensitive.
func isValidLatencyMode(latencyMode string) bool {
	switch latencyMode {
	case "", "normal", "low":
		return true
	default:
		return false
	}
}
