package main

import (
	"bytes"
	"code.justin.tv/video/gortmp/pkg/amf"
	"code.justin.tv/video/gortmp/pkg/flv"
	rtmplog "code.justin.tv/video/gortmp/pkg/log"
	"code.justin.tv/video/gortmp/pkg/rtmp"
	goctx "context"
	"flag"
	"fmt"
	"log"
	"net"
	"net/url"
	"reflect"
	"strings"
	"sync"
	"time"
)

var rtmpURL = flag.String("rtmp", "", "rtmp connect url")
var fileName = flag.String("filename", "", "file to publish")

type TransactionHandler func(rtmp.Message) error
type TransactionMap struct {
	mu    sync.Mutex
	m     map[uint32]TransactionHandler
	txnid uint32
}

func NewTransactionMap() *TransactionMap {
	return &TransactionMap{
		m: make(map[uint32]TransactionHandler),
	}
}

func (tm *TransactionMap) New(h TransactionHandler) uint32 {
	tm.mu.Lock()
	defer tm.mu.Unlock()

	id := tm.txnid
	tm.txnid++
	tm.m[id] = h

	return id
}

func (tm *TransactionMap) Handle(id uint32, msg rtmp.Message) error {
	if h, ok := tm.m[id]; ok {
		delete(tm.m, id)
		return h(msg)
	}
	return nil
}

type RtmpPublisher struct {
	URL  *url.URL
	Conn rtmp.BasicConn
	File *flv.FlvFileReader
	tm   *TransactionMap
}

func (p *RtmpPublisher) ConnectResult(msg rtmp.Message) error {
	log.Printf("ConnectResult")

	cmd := rtmp.CreateStreamCommand{TransactionID: p.tm.New(p.CreateStreamResult)}
	return p.Send(cmd)
}

func (p *RtmpPublisher) CreateStreamResult(msg rtmp.Message) error {
	log.Printf("CreateStreamResult")

	result, ok := msg.(rtmp.ResultCommand)
	if !ok {
		return fmt.Errorf("Got invalid result command for CreateStream: %#v", msg)
	}

	log.Printf("CreateStreamResult: %#v", result)

	log.Printf("TypeOf(result.Info)=%s", reflect.TypeOf(result.Info).Name())

	streamID, ok := result.Info.(float64)
	if !ok {
		return fmt.Errorf("Invalid result commend for CreateStream: %#v", result)
	}

	path := strings.Split(p.URL.Path, "/")
	if len(path) == 0 {
		return fmt.Errorf("Unable to parse stream name: %s", p.URL.Path)
	}

	cmd := rtmp.PublishCommand{
		StreamID:      uint32(streamID),
		TransactionID: p.tm.New(p.PublishResult),
		Name:          path[len(path)-1],
	}

	return p.Send(cmd)
}

func (p *RtmpPublisher) PublishResult(msg rtmp.Message) error {
	log.Printf("PublishResult")

	result, ok := msg.(rtmp.OnStatusCommand)
	if !ok {
		return fmt.Errorf("Got invalid status command for Publish: %#v", msg)
	}

	go p.DoPublish(result.StreamID)

	return nil
}

func (p *RtmpPublisher) DoPublish(streamID uint32) {
	defer p.Conn.Close()

	start := time.Now()
	firstTs := uint32(0)
	firstTsSet := false

	for {
		header, data, err := p.File.ReadTag()
		if err != nil {
			log.Printf("Error reading flv data: %s", err)
			return
		}

		log.Printf("ts=%d,type=%d", header.Timestamp, header.TagType)

		if !firstTsSet {
			firstTs = header.Timestamp
			firstTsSet = true
		}

		clockDelta := time.Since(start)
		tsDelta := time.Duration(header.Timestamp-firstTs) * time.Millisecond

		rtDrift := tsDelta - clockDelta
		if rtDrift < 500*time.Millisecond {
			time.Sleep(rtDrift)
		}

		msg := &rtmp.RawMessage{
			ChunkStreamID: 5, // arbitrarily chosen
			Timestamp:     header.Timestamp,
			Type:          header.TagType,
			StreamID:      streamID,
			Data:          bytes.NewBuffer(data),
		}

		if err := p.Conn.Write(msg); err != nil {
			log.Printf("Error during publish: %s", err)
			return
		}

		if err := p.Conn.Flush(); err != nil {
			log.Printf("Connection error during publish: %s", err)
			return
		}
	}
}

func (p *RtmpPublisher) Handle() error {
	// first just send the connect command
	p.tm = NewTransactionMap()

	tcURL := *p.URL
	tcURL.Path = "app"
	connectCmd := rtmp.ConnectCommand{
		TransactionID: p.tm.New(p.ConnectResult),
		Properties: amf.Object{
			"app":           "app",
			"flashVer":      "LNX 9,0,124,2",
			"tcUrl":         tcURL.String(),
			"fpad":          false,
			"capabilities":  15,
			"audioCodecs":   4071,
			"videoCodecs":   252,
			"videoFunction": 1,
		},
	}

	log.Printf("Writing connect msg: %+v", connectCmd)
	if err := p.Send(connectCmd); err != nil {
		return err
	}

	if err := p.Send(rtmp.SetChunkSizeMessage{Size: 4096}); err != nil {
		return err
	}

	for {
		raw, err := p.Conn.Read()
		if err != nil {
			return err
		}

		if err := p.HandleMessage(raw); err != nil {
			return err
		}
	}
}

func (p *RtmpPublisher) HandleMessage(raw *rtmp.RawMessage) error {
	msg, err := p.parseMessage(raw)
	if err != nil {
		return err
	}

	switch msg := msg.(type) {
	case rtmp.ResultCommand:
		return p.tm.Handle(msg.TransactionID, msg)
	case rtmp.ErrorCommand:
		return p.tm.Handle(msg.TransactionID, msg)
	case rtmp.OnStatusCommand:
		return p.tm.Handle(msg.TransactionID, msg)
	case *rtmp.RawMessage:
		raw := *msg
		raw.Data = nil
		log.Printf("Raw: %#v", raw)
	default:
		log.Printf("Received msg: %#v", msg)
	}
	return nil
}

func (p *RtmpPublisher) parseMessage(raw *rtmp.RawMessage) (rtmp.Message, error) {
	if raw.ChunkStreamID == rtmp.CS_ID_PROTOCOL_CONTROL {
		var msg rtmp.Message
		var err error

		if raw.Type == rtmp.USER_CONTROL_MESSAGE {
			msg, err = rtmp.ParseEvent(raw)
		} else {
			msg, err = rtmp.ParseControlMessage(raw)
		}

		return msg, err
	}

	switch raw.Type {
	case rtmp.COMMAND_AMF0:
		fallthrough
	case rtmp.COMMAND_AMF3:
		return rtmp.ParseCommand(raw)
	case rtmp.USER_CONTROL_MESSAGE:
		return rtmp.ParseEvent(raw)
	default:
		return raw, nil
	}
}

func (p *RtmpPublisher) Send(msg rtmp.Message) error {
	raw, err := msg.RawMessage()
	if err != nil {
		return err
	}

	if err := p.Conn.Write(raw); err != nil {
		return err
	}

	return p.Conn.Flush()
}

func dial(host string) (net.Conn, error) {
	if _, err := net.ResolveTCPAddr("tcp", host); err != nil {
		if _, ok := err.(net.InvalidAddrError); ok {
			host = net.JoinHostPort(host, "1935")
		}
	}

	return net.Dial("tcp", host)
}

func main() {
	rtmplog.SetLogLevel(rtmplog.LogDebug)
	log.SetFlags(log.Lmicroseconds)
	flag.Parse()

	parsed, err := url.Parse(*rtmpURL)
	if err != nil {
		log.Fatalf("Invalid rtmp url: %s", err)
	}

	flvFile, err := flv.OpenFile(*fileName)
	if err != nil {
		log.Fatalf("Unable to open flv file %s: %s", *fileName, err)
	}

	conn, err := dial(parsed.Host)
	if err != nil {
		log.Fatal(err)
	}

	log.Printf("Connected to %s, performing handshake", parsed.Host)
	if _, err := rtmp.Handshake(goctx.Background(), conn); err != nil {
		log.Fatalf("Handshake failed: %s", err)
	}

	player := &RtmpPublisher{
		URL:  parsed,
		Conn: rtmp.NewBasicConn(conn),
		File: flvFile,
	}

	if err := player.Handle(); err != nil {
		log.Fatalf("Error playing back stream: %s", err)
	}
}
