package server

import (
	"context"
	"fmt"
	"net"
	"net/http"
	"sync"
	"time"

	rtmpctx "code.justin.tv/event-engineering/gortmp/pkg/context"
	gortmp "code.justin.tv/event-engineering/gortmp/pkg/rtmp"

	"code.justin.tv/event-engineering/carrot-rtmp-recorder/pkg/ms2s3"
	rpc "code.justin.tv/event-engineering/carrot-rtmp-recorder/pkg/rpc"
	"code.justin.tv/event-engineering/carrot-rtmp-recorder/pkg/svc"
	"code.justin.tv/event-engineering/rtmp/pkg/server"
	"github.com/aws/aws-sdk-go/service/s3/s3iface"
	"github.com/google/uuid"
	"github.com/sirupsen/logrus"
)

// Server is the lifecycle interface returned by New.
type Server interface {
	// Start begins accepting connections; it returns an error if the
	// listening socket cannot be created.
	Start() error
	// Stop shuts the server down.
	Stop()
	// Stopped reports whether Stop has run to completion.
	Stopped() bool
}

// endpointConfig holds the settings for one stream key, built by
// configUpdater from the region's endpoint config.
type endpointConfig struct {
	// EndpointID identifies the endpoint when reporting status through
	// UpdateEndpointStatus.
	EndpointID  string
	// KeyPrefix is prepended to the generated "<uuid>.flv" S3 object key.
	KeyPrefix   string
	// MaxDuration is how long a stream may run before it is force-closed.
	MaxDuration time.Duration
}

// rtmpServer is the Server implementation returned by New. It accepts RTMP
// connections on port 1935, records streams whose name matches a configured
// stream key to S3, and keeps its endpoint config refreshed via configUpdater.
type rtmpServer struct {
	// logger is shared with the handler and the S3 recorder.
	logger     logrus.FieldLogger
	// region is passed to GetRegionEndpointConfig to select the endpoint config.
	region     string
	// stopped is set by Stop and read by configUpdater and Stopped.
	// NOTE(review): read and written from different goroutines without any
	// synchronisation; this is a data race under -race.
	stopped    bool
	// ln is the TCP listener on 0.0.0.0:1935 created by Start.
	ln         net.Listener
	// sv is the RTMP server created by Start; also used to disconnect clients.
	sv         *server.Server
	// recorder writes incoming media streams to the destination S3 bucket.
	recorder   ms2s3.MediaStreamtoS3
	// operations fetches endpoint config and receives endpoint status updates.
	operations svc.WorkerOperations
	// wg is waited on by Stop.
	// NOTE(review): no wg.Add is visible here, so Wait appears to return
	// immediately — confirm whether anything else in the package adds to it.
	wg         sync.WaitGroup
	// config maps stream key to endpoint settings. It is replaced wholesale by
	// configUpdater and must only be accessed while holding configLock.
	config     map[string]*endpointConfig
	// configLock guards config.
	configLock sync.Mutex
}

// New builds a Server that records to the S3 bucket destBucketName using the
// given S3 client, and reads its endpoint config for region from operations.
// The returned server does nothing until Start is called.
func New(destBucketName, region string, s3 s3iface.S3API, operations svc.WorkerOperations, logger logrus.FieldLogger) Server {
	srv := new(rtmpServer)
	srv.logger = logger
	srv.region = region
	srv.recorder = ms2s3.New(s3, destBucketName, logger)
	srv.operations = operations
	return srv
}

// configUpdater refreshes the endpoint config for s.region every 10 seconds
// until it sees s.stopped set. On each successful fetch it:
//   - skips items whose EndpointExpires (Unix seconds) is already in the past,
//   - reports endpoints that vanished from the config as EndpointStopped,
//   - reports endpoints that are new to the config as EndpointReady,
//   - replaces s.config with the freshly built map.
//
// The status updates and the swap all happen while configLock is held.
// A failed fetch is logged and the previous config stays in effect.
//
// NOTE(review): s.stopped is read without synchronisation while Stop writes it.
func (s *rtmpServer) configUpdater() {
	const pollInterval = 10 * time.Second

	for !s.stopped {
		switch items, err := s.operations.GetRegionEndpointConfig(s.region); {
		case err != nil:
			s.logger.WithError(err).Error("Failed to retrieve config")
		default:
			fresh := make(map[string]*endpointConfig)
			for _, item := range items {
				expired := time.Now().After(time.Unix(item.EndpointExpires, 0))
				if expired {
					continue
				}
				fresh[item.StreamKey] = &endpointConfig{
					EndpointID:  item.ID,
					KeyPrefix:   item.KeyPrefix,
					MaxDuration: time.Duration(item.MaxDuration) * time.Second,
				}
			}

			s.configLock.Lock()
			for key, previous := range s.config {
				if _, kept := fresh[key]; !kept {
					s.operations.UpdateEndpointStatus(previous.EndpointID, rpc.EndpointStatus_EndpointStopped)
				}
			}
			for key, added := range fresh {
				if _, known := s.config[key]; !known {
					s.operations.UpdateEndpointStatus(added.EndpointID, rpc.EndpointStatus_EndpointReady)
				}
			}
			s.config = fresh
			s.configLock.Unlock()
		}

		time.Sleep(pollInterval)
	}
}

// Start creates the RTMP media server and binds TCP 0.0.0.0:1935. The
// listener is wrapped by NewMultiListener, which splits it into an RTMP
// listener and an HTTP listener. The HTTP side answers every request with a
// fixed healthcheck body.
//
// Endpoint config polling (configUpdater) starts only after the port has been
// bound. A failed Start therefore does not leave a goroutine reporting
// endpoints as ready for a server that cannot accept streams.
//
// Start returns the net.Listen error if the port cannot be bound. Errors from
// the serve loops happen asynchronously and are logged rather than returned.
func (s *rtmpServer) Start() error {
	ms := server.NewMediaServer(&handler{
		logger:          s.logger,
		server:          s,
		cancelKillTimer: make(map[string]chan interface{}),
	})

	sv := server.NewServer(server.ServerConfig{
		Handler: ms,
	})

	s.sv = sv

	// Listen for incoming connections
	ln, err := net.Listen("tcp", net.JoinHostPort("0.0.0.0", "1935"))
	if err != nil {
		return err
	}

	s.ln = ln

	// Only start advertising endpoints once we can actually accept streams.
	go s.configUpdater()

	mln := NewMultiListener(ln, 10*time.Second)

	go func(listener net.Listener) {
		if err := sv.Serve(listener); err != nil {
			s.logger.Warnf("Server error: %s", err)
		} else {
			// err is nil here; formatting it would log "%!s(<nil>)".
			s.logger.Info("Server exited cleanly")
		}
	}(mln.RtmpListener())

	// Timeouts stop slow or idle clients from pinning healthcheck connections.
	healthcheck := &http.Server{
		Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
			fmt.Fprint(w, "Would you like a carrot?")
		}),
		ReadTimeout:  10 * time.Second,
		WriteTimeout: 10 * time.Second,
	}

	go func(listener net.Listener) {
		if err := healthcheck.Serve(listener); err != nil {
			s.logger.WithError(err).Warn("Healthcheck listener errored")
		}
	}(mln.HttpListener())

	return nil
}

// Stop closes the RTMP server and the TCP listener, waits on s.wg, then marks
// the server as stopped so configUpdater exits after its current poll.
//
// Stop is safe to call after a failed Start. Start assigns s.sv before
// attempting net.Listen, so s.ln may still be nil, and calling Close on a nil
// net.Listener would panic.
func (s *rtmpServer) Stop() {
	if s.sv != nil {
		s.sv.Close()
	}
	if s.ln != nil {
		// Best-effort shutdown: the listener may already be closed, and there
		// is nothing useful to do with the error here.
		_ = s.ln.Close()
	}
	s.wg.Wait()
	s.stopped = true
}

// Stopped reports whether Stop has finished. The flag is read without
// synchronisation.
func (s *rtmpServer) Stopped() bool { return s.stopped }

// handler is passed to server.NewMediaServer in Start. Its callbacks apply the
// owning rtmpServer's endpoint config to incoming media streams.
type handler struct {
	logger          logrus.FieldLogger
	// server is the owning rtmpServer (config, status RPCs, recorder, sv).
	server          *rtmpServer
	// killTimerMu guards cancelKillTimer.
	killTimerMu     sync.Mutex
	// cancelKillTimer maps stream name to the channel used to cancel that
	// stream's max-duration kill timer.
	cancelKillTimer map[string]chan (interface{})
}

// CloseConnection disconnects the client whose remote address is carried in
// ctx. If ctx has no remote address, it does nothing.
func (h *handler) CloseConnection(ctx context.Context) {
	addr, found := rtmpctx.GetRemoteAddr(ctx)
	if !found {
		return
	}
	h.server.sv.Disconnect(addr)
}

// OnMediaStreamCreated handles a newly created media stream.
//
// The stream name taken from ctx is looked up as a stream key in the current
// endpoint config. If ctx has no stream name, or the name is not configured,
// the connection and stream are closed immediately. For a recognised stream it:
//   - reports the endpoint as EndpointRecording,
//   - starts recording to S3 under "<KeyPrefix><uuid>.flv",
//   - arms a kill timer of endpoint.MaxDuration. When the timer fires, it
//     closes the connection and stream and, if the stream key is still
//     configured, reports the endpoint as EndpointReady.
//
// The timer can be cancelled by sending on the channel stored in
// h.cancelKillTimer[streamName].
func (h *handler) OnMediaStreamCreated(ctx context.Context, ms gortmp.MediaStream) {
	h.logger.Info("MediaStream Created")

	streamName, ok := rtmpctx.GetStreamName(ctx)
	if !ok {
		h.logger.Error("Failed to get stream name")

		h.CloseConnection(ctx)
		ms.Close()
		return
	}

	h.server.configLock.Lock()
	endpoint, ok := h.server.config[streamName]
	h.server.configLock.Unlock()

	if !ok {
		h.logger.Infof("Got unrecognised stream name %v", streamName)
		h.CloseConnection(ctx)
		ms.Close()
		return
	}

	h.server.operations.UpdateEndpointStatus(endpoint.EndpointID, rpc.EndpointStatus_EndpointRecording)

	go h.server.recorder.RecordToS3(ctx, ms, fmt.Sprintf("%v%v.flv", endpoint.KeyPrefix, uuid.New().String()))

	// Register the cancel channel before starting the timer goroutine. If the
	// goroutine registered it, as before, a destroy callback could run first
	// and find no channel. Buffered so that a single cancel send never blocks,
	// even after the timer has already fired.
	cancelKillTimer := make(chan interface{}, 1)
	h.killTimerMu.Lock()
	h.cancelKillTimer[streamName] = cancelKillTimer
	h.killTimerMu.Unlock()

	// Kill the stream once it has run for the endpoint's maximum duration.
	go func() {
		killStream := time.NewTimer(endpoint.MaxDuration)
		defer killStream.Stop()

		select {
		case <-killStream.C:
			h.logger.Infof("Killing stream after %v", endpoint.MaxDuration)
			h.CloseConnection(ctx)
			ms.Close()

			// If the endpoint is still current, set it back to "ready" state
			h.server.configLock.Lock()
			if current, ok := h.server.config[streamName]; ok {
				h.server.operations.UpdateEndpointStatus(current.EndpointID, rpc.EndpointStatus_EndpointReady)
			}
			h.server.configLock.Unlock()
		case <-cancelKillTimer:
		}
	}()
}

// OnMediaStreamDestroyed handles the end of a media stream. If the stream's
// name is still a configured stream key, it reports the endpoint as
// EndpointReady. It then cancels the stream's kill timer, if one was
// registered, and removes it from h.cancelKillTimer.
//
// Streams rejected in OnMediaStreamCreated (e.g. unrecognised names) never
// register a timer. Indexing the map for them yields a nil channel, and a send
// on a nil channel blocks forever. Previously that send happened while holding
// killTimerMu, which deadlocked every later create/destroy callback. The entry
// is now checked for presence first.
func (h *handler) OnMediaStreamDestroyed(ctx context.Context, ms gortmp.MediaStream) {
	h.logger.Info("MediaStream Destroyed")

	streamName, ok := rtmpctx.GetStreamName(ctx)
	if !ok {
		h.logger.Error("Failed to get stream name")
		return
	}

	// If the endpoint is still current, set it back to "ready" state
	h.server.configLock.Lock()
	if endpoint, ok := h.server.config[streamName]; ok {
		h.server.operations.UpdateEndpointStatus(endpoint.EndpointID, rpc.EndpointStatus_EndpointReady)
	}
	h.server.configLock.Unlock()

	// Take the cancel channel out of the map; killTimerMu is held only for the
	// map access, not for the status RPC above.
	h.killTimerMu.Lock()
	cancel, registered := h.cancelKillTimer[streamName]
	delete(h.cancelKillTimer, streamName)
	h.killTimerMu.Unlock()

	if !registered {
		return
	}

	// Kill the stream cancel timer. Non-blocking as a safeguard: the channel
	// is created with a buffer of one, so this send normally succeeds even if
	// the timer goroutine has already exited.
	select {
	case cancel <- "would you like a carrot?":
	default:
	}
}

// Handle forwards msg unchanged to the receiver's own Handle method and returns
// its result; ctx is not consulted.
func (h *handler) Handle(ctx context.Context, r gortmp.Receiver, msg gortmp.Message) error {
	err := r.Handle(msg)
	return err
}
