package rtmp

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"net/url"
	"strings"
	"time"

	logging "code.justin.tv/event-engineering/golibs/pkg/logging"
	"github.com/pkg/errors"

	gortmp "code.justin.tv/event-engineering/gortmp/pkg/rtmp"
	rtmp "code.justin.tv/event-engineering/rtmp/pkg/common"
)

// Timing and buffering parameters used by Pusher.
const (
	// writeIdleTimeoutSeconds is passed to rtmp.Connect as the connection's write
	// idle timeout. The keep-alive ping interval must stay shorter than this.
	writeIdleTimeoutSeconds = 2

	// pingTickerIntervalSeconds is how often doPublish sends a keep-alive ping.
	pingTickerIntervalSeconds = 1

	// bufferNumTags is the capacity of the channel between bufferTags and
	// doPublish; tags arriving while it is full are dropped.
	bufferNumTags = 10

	// statusUpdateTickerIntervalSeconds is the period reserved for periodic status
	// updates. NOTE(review): nothing in this file reacts to a ticker of this period.
	statusUpdateTickerIntervalSeconds = 15
)

// Pusher manages an RTMP connection to an RTMP server and publishes a stream to it.
//
// Fields:
//   - URL and Conn are populated by Push: the parsed target URL and the open connection.
//   - ms is the media stream whose tags are forwarded; tm tracks outstanding command
//     transactions (connect, createStream, publish) and the callbacks for their replies.
//   - errCh receives asynchronous failures from the listen and doPublish goroutines.
//   - closed is set by Close and polled by the background loops so they stop.
//   - stopErr records why the publish pipeline stopped; lastTs is the most recent
//     timestamp recorded by doPublish, used by Publish for tags with a zero timestamp.
//
// NOTE(review): closed, stopErr and lastTs are read and written from several
// goroutines (listen, bufferTags, doPublish, and callers of Close/Publish/IsClosed)
// without any synchronization, which is a data race; consider sync/atomic or a mutex.
type Pusher struct {
	URL    *url.URL
	Conn   gortmp.BasicConn
	ms     gortmp.MediaStream
	tm     *rtmp.TransactionMap
	errCh  chan (error)
	logger logging.Logger
	closed bool
	// stopErr signals that something went wrong in the publish pipeline, and what.
	stopErr error
	lastTs  uint32
}

// NewPusher returns a Pusher that reports asynchronous failures on errCh and logs
// through logger. No connection is made until Push is called.
func NewPusher(errCh chan error, logger logging.Logger) *Pusher {
	return &Pusher{logger: logger, errCh: errCh}
}

// Push dials rtmpURL and starts the publish handshake. It sends connect, then a
// SetChunkSize of 4096, and then starts a goroutine that reads server responses.
// The remaining steps (createStream, publish, and streaming tags from ms) are driven
// by those responses, and later failures are reported on errCh. Push itself returns
// an error only if connecting or sending the initial messages fails.
func (p *Pusher) Push(rtmpURL string, ms gortmp.MediaStream) error {
	target, conn, app, err := rtmp.Connect(rtmpURL, 0, writeIdleTimeoutSeconds)
	if err != nil {
		return err
	}

	p.URL, p.Conn = target, conn
	p.tm = rtmp.NewTransactionMap()
	p.ms = ms

	// connect opens the handshake; its reply is routed to connectResult.
	txID := p.tm.New(p.connectResult)
	connectCmd := rtmp.DefaultConnectCommand(app, p.URL.String(), txID)
	p.logger.Debugf("Writing connect msg: %+v", connectCmd)
	if sendErr := rtmp.Send(p.Conn, connectCmd); sendErr != nil {
		return errors.Wrap(sendErr, "Connect message failed")
	}

	p.logger.Debug("Writing chunk size")
	chunkSize := gortmp.SetChunkSizeMessage{Size: 4096}
	if sendErr := rtmp.Send(p.Conn, chunkSize); sendErr != nil {
		return errors.Wrap(sendErr, "Set Chunk Size message failed")
	}

	go p.listen()
	return nil
}

// Publish injects an FLV tag into the media stream given to Push, which is useful
// for metadata. A tag whose Timestamp is zero is stamped with p.lastTs, the most
// recent timestamp recorded by doPublish, before being published.
func (p *Pusher) Publish(tag gortmp.FlvTag) error {
	if tag.Timestamp != 0 {
		return p.ms.Publish(&tag)
	}

	tag.Timestamp = p.lastTs
	return p.ms.Publish(&tag)
}

// listen reads messages from the server and dispatches them via handleMessage until
// the pusher is closed or an error occurs. Read and handling failures are reported on
// errCh, except a read failure after Close: that one is the expected consequence of
// Close shutting Conn down, not a fault.
func (p *Pusher) listen() {
	for !p.closed {
		raw, err := p.Conn.Read()
		if err != nil {
			// Don't report the error our own Close caused. If errCh is unbuffered and no
			// longer drained by the owner, the send would also block this goroutine forever.
			if p.closed {
				return
			}
			p.errCh <- errors.Wrap(err, "Failed to read from connection")
			return
		}

		if err := p.handleMessage(raw); err != nil {
			p.errCh <- errors.Wrap(err, "Failed to handle message")
			return
		}
	}
}

// handleMessage parses a raw message from the server. Command replies (result, error
// and onStatus) are handed to the transaction map, which invokes the callback
// registered for their transaction ID. Every other message type is ignored.
func (p *Pusher) handleMessage(raw *gortmp.RawMessage) error {
	parsed, err := rtmp.ParseMessage(raw)
	if err != nil {
		return err
	}

	// The cases cannot be merged: each needs its concrete type to reach TransactionID.
	switch cmd := parsed.(type) {
	case gortmp.ResultCommand:
		return p.tm.Handle(cmd.TransactionID, cmd)
	case gortmp.ErrorCommand:
		return p.tm.Handle(cmd.TransactionID, cmd)
	case gortmp.OnStatusCommand:
		return p.tm.Handle(cmd.TransactionID, cmd)
	}

	return nil
}

// Close marks the pusher as closed and closes its connection, if Push opened one.
// The background loops check the closed flag at the top of each iteration. Calling
// Close again after the first call does nothing. A failure to close the connection
// is logged, not returned.
func (p *Pusher) Close() {
	if p.closed {
		return
	}
	p.closed = true

	if p.Conn == nil {
		return
	}
	if closeErr := p.Conn.Close(); closeErr != nil {
		p.logger.Warn("Error closing output connection", closeErr)
	}
}

// IsClosed reports whether the pusher has been marked closed (Close sets this flag).
func (p *Pusher) IsClosed() (closed bool) {
	closed = p.closed
	return closed
}

// connectResult handles the server's reply to connect. On a successful result it
// sends createStream and registers createStreamResult for the reply.
//
// handleMessage passes ErrorCommand replies to the transaction map just like results,
// so any non-result reply is treated as a failure. This stops the handshake from
// continuing after the server rejects the connect.
func (p *Pusher) connectResult(msg gortmp.Message) error {
	if _, ok := msg.(gortmp.ResultCommand); !ok {
		return fmt.Errorf("Got invalid result command for Connect: %#v", msg)
	}

	cmd := gortmp.CreateStreamCommand{TransactionID: p.tm.New(p.createStreamResult)}
	return rtmp.Send(p.Conn, cmd)
}

// createStreamResult handles the server's reply to createStream. It takes the new
// stream ID from the result's Info field and derives the stream name from the last
// path segment of p.URL. Any raw query string is appended verbatim. It then sends
// publish and registers publishResult for the reply.
//
// It returns an error if the reply is not a result, carries no numeric stream ID, or
// the URL has no stream name (an empty final path segment).
func (p *Pusher) createStreamResult(msg gortmp.Message) error {
	result, ok := msg.(gortmp.ResultCommand)
	if !ok {
		return fmt.Errorf("Got invalid result command for CreateStream: %#v", msg)
	}

	p.logger.Debugf("CreateStreamResult: %#v", result)

	streamID, ok := result.Info.(float64)
	if !ok {
		return fmt.Errorf("Invalid result command for CreateStream: %#v", result)
	}

	// The stream name is everything after the final "/". An empty remainder (no path,
	// or a path ending in "/") means no stream name was supplied.
	streamName := p.URL.Path[strings.LastIndex(p.URL.Path, "/")+1:]
	if streamName == "" {
		return fmt.Errorf("Unable to parse stream name: %s", p.URL.Path)
	}

	if p.URL.RawQuery != "" {
		streamName = streamName + "?" + p.URL.RawQuery
	}

	cmd := gortmp.PublishCommand{
		StreamID:      uint32(streamID),
		TransactionID: p.tm.New(p.publishResult),
		Name:          streamName,
	}

	return rtmp.Send(p.Conn, cmd)
}

// publishResult handles the server's onStatus reply to publish. Any OnStatusCommand
// starts the publish loop for its StreamID in a new goroutine. Any other message type
// is rejected.
//
// NOTE(review): the status itself is not inspected, so a failure status would still
// start publishing. Confirm whether OnStatusCommand exposes a code that should be checked.
func (p *Pusher) publishResult(msg gortmp.Message) error {
	status, isStatus := msg.(gortmp.OnStatusCommand)
	if isStatus {
		go p.doPublish(status.StreamID)
		return nil
	}

	return fmt.Errorf("Got invalid status command for Publish: %#v", msg)
}

// bufferTags forwards tags from the media stream subscription (inputTags) to the
// publish loop (outputTags) without ever blocking on outputTags. When outputTags is
// full, tags are dropped and logged at debug level, so a slow destination cannot
// stall the shared media stream.
//
// It returns when the pusher is closed, when inputTags is closed, or once a stop
// error has been recorded in p.stopErr.
func (p *Pusher) bufferTags(inputTags, outputTags chan *gortmp.FlvTag) {
	tagsDropped := 0
	for !p.closed {
		curTag, gotMsg := <-inputTags
		if !gotMsg {
			// Keep the first recorded failure. The publish loop may already have stopped on
			// a write error and unsubscribed; if that closes inputTags, overwriting
			// p.stopErr here would hide the real cause.
			if p.stopErr == nil {
				p.stopErr = errors.New("Failed to read from input channel")
			}
			return
		}

		select {
		case outputTags <- curTag:
			if tagsDropped > 0 {
				p.logger.Debugf("send buffer no longer full, resuming after %v dropped tags", tagsDropped)
				tagsDropped = 0
			}
		default:
			if tagsDropped == 0 {
				p.logger.Debug("send buffer full, dropping tags")
			}
			tagsDropped++
		}

		if p.stopErr != nil {
			return
		}
	}
}

// doPublish streams tags from the media stream to the server on streamID until the
// pusher is closed or an error occurs. Any failure that stops it is reported on errCh.
//
// Tags arrive through bufferTags, which drops tags rather than blocking the shared
// media stream when this connection falls behind. Timestamp resets in the source
// (e.g. the input reconnecting) are detected and rebased so output timestamps keep
// increasing. A ping is sent every pingTickerIntervalSeconds so the connection's
// short write idle timeout does not fire while the input is stalled.
func (p *Pusher) doPublish(streamID uint32) {
	inputTags, err := p.ms.Subscribe()
	if err != nil {
		p.errCh <- errors.Wrap(err, "Error subscribing to media stream")
		return
	}

	// Loop detection state. lastTs is the last timestamp written to the connection.
	// lastLoopStartTs is the offset added to source timestamps since the most recent
	// detected reset.
	var lastTs uint32
	var lastLoopStartTs uint32

	// Reading inputTags directly could block everything else subscribed to the media
	// stream if this connection goes down. bufferTags reads it without blocking and
	// forwards into this bounded buffer, dropping tags once bufferNumTags are pending.
	// bufferTags is the only sender on outputTags.
	outputTags := make(chan *gortmp.FlvTag, bufferNumTags)
	go p.bufferTags(inputTags, outputTags)

	// The write idle timeout is short so broken output connections are detected quickly.
	// Something must therefore be sent even when the input stalls (e.g. while the input
	// stream restarts), so a ping goes out on every tick.
	pingTicker := time.NewTicker(time.Second * pingTickerIntervalSeconds)

publishLoop:
	for !p.closed {
		select {
		case curTag, gotMsg := <-outputTags:
			if !gotMsg {
				// Defensive: outputTags is never closed in this file.
				p.stopErr = errors.New("Failed to read from buffer channel")
				break publishLoop
			}

			// Handle timestamp resets in the source, so the source can be reconnected
			// without killing the destination.
			if curTag.Timestamp+lastLoopStartTs < lastTs {
				p.logger.Debugf("Input loop detected, adjusting timestamps raw:%v loopStart:%v last:%v", curTag.Timestamp, lastLoopStartTs, lastTs)
				lastLoopStartTs = lastTs
			}

			raw := &gortmp.RawMessage{
				ChunkStreamID: gortmp.CS_ID_DATA,
				Timestamp:     curTag.Timestamp + lastLoopStartTs,
				Type:          curTag.Type,
				StreamID:      streamID,
				Data:          bytes.NewBuffer(curTag.Bytes),
			}
			lastTs = raw.Timestamp

			// Publish stamps untimed tags with p.lastTs and re-injects them into the media
			// stream. They presumably come back through this subscription and get
			// lastLoopStartTs added here. Recording the source-side timestamp (not the
			// rebased one) keeps an injected tag from being offset twice after an input
			// loop, which would also trigger a spurious loop detection.
			p.lastTs = curTag.Timestamp

			if err := p.Conn.Write(raw); err != nil {
				p.stopErr = errors.Wrap(err, "Error during publish")
				break publishLoop
			}

			if err := p.Conn.Flush(); err != nil {
				p.stopErr = errors.Wrap(err, "Connection error during publish")
				break publishLoop
			}

		case <-pingTicker.C:
			p.sendPing(streamID)
		}

		// bufferTags may have recorded a failure (e.g. its input channel was closed).
		if p.stopErr != nil {
			break
		}
	}

	// Snapshot the failure before cleanup. Unsubscribing can make bufferTags record
	// its own error, and that must not replace the reason this loop stopped.
	stopErr := p.stopErr

	p.ms.Unsubscribe(inputTags)
	pingTicker.Stop()

	// outputTags is intentionally left open. bufferTags may still be running and
	// sending on it, and a send on a closed channel panics even inside a select with a
	// default case. The channel is simply garbage once bufferTags returns.

	if stopErr != nil {
		p.errCh <- stopErr
	}
}

// sendPing writes a user control Ping Request event on streamID and flushes it.
// doPublish calls it periodically so the connection is never idle long enough to
// trip its short write idle timeout while the input is stalled. Failures are only
// logged at debug level. A broken connection still surfaces through the tag
// write/flush path or listen's read.
//
// The event data is a 2-byte event type followed by a 4-byte timestamp, matching the
// user control event layout in the RTMP specification. The timestamp is the current
// Unix time in seconds, truncated to 32 bits.
//
// NOTE(review): the RTMP spec says user control messages should use message stream
// ID 0; this sends on streamID. Confirm whether the target server cares.
func (p *Pusher) sendPing(streamID uint32) {
	payload := make([]byte, 6)
	binary.BigEndian.PutUint16(payload[0:2], uint16(gortmp.PING_REQUEST_EVENT))
	binary.BigEndian.PutUint32(payload[2:6], uint32(time.Now().Unix()))

	ping := &gortmp.RawMessage{
		ChunkStreamID: gortmp.CS_ID_USER_CONTROL,
		Type:          gortmp.USER_CONTROL_MESSAGE,
		StreamID:      streamID,
		Data:          bytes.NewBuffer(payload),
	}

	if err := p.Conn.Write(ping); err != nil {
		p.logger.Debug("Error sending ping ", err)
		return
	}

	// Write alone may only buffer; doPublish flushes after every tag write. Without a
	// flush the ping could sit unsent exactly when it is needed: while no tags are
	// flowing.
	if err := p.Conn.Flush(); err != nil {
		p.logger.Debug("Error flushing ping ", err)
	}
}
