package main

import (
	"bufio"
	"encoding/binary"
	"flag"
	"fmt"
	"io"
	"io/ioutil"
	"net"

	"code.justin.tv/video/gortmp/pkg/rtmp"
)

// Command-line flags: the local address accepted RTMP clients connect to, and
// the upstream address each accepted connection is relayed to.
var (
	listen = flag.String("listen", "0.0.0.0:1935", "listen")
	remote = flag.String("remote", "live.justin.tv:1935", "remote")
)

// main listens on -listen. For every accepted client it dials -remote and
// hands the pair to proxy in a new goroutine. Temporary accept errors are
// retried immediately (no backoff); any other accept error ends the program.
func main() {
	flag.Parse()

	ln, err := net.Listen("tcp", *listen)
	if err != nil {
		fmt.Printf("Failed to open listen socket: %s\n", err)
		return
	}
	defer ln.Close()

	fmt.Printf("Listening - %s\n", *listen)
	for {
		client, err := ln.Accept()
		if err != nil {
			if netErr, ok := err.(net.Error); ok && netErr.Temporary() {
				continue
			}
			fmt.Printf("Listen socket error: %s\n", err)
			return
		}

		// Named upstream rather than remote so the -remote flag is not shadowed.
		upstream, err := net.Dial("tcp", *remote)
		if err != nil {
			fmt.Printf("Failed to open remote connection: %s\n", err)
			// upstream is a nil interface here, and calling Close on it would
			// panic. The accepted client is what must be released.
			client.Close()
			continue
		}

		go proxy(client, upstream)
	}
}

// RtmpLogger prints the RTMP events it is notified about, prefixing message
// output with tag.
type RtmpLogger struct {
	tag string
}

// OnMessage prints msg via logMessage under the logger's tag.
func (l RtmpLogger) OnMessage(msg *rtmp.RawMessage) {
	prefix := l.tag
	logMessage(prefix, msg)
}

// OnWindowFull prints the inCount value it is called with.
func (l RtmpLogger) OnWindowFull(inCount uint32) {
	const format = "OnWindowFull: %d\n"
	fmt.Printf(format, inCount)
}

// Sniffer is the stream type returned by NewSniffer. It has the same method
// set as io.ReadCloser.
type Sniffer interface {
	io.Reader
	io.Closer
}

// sniffer implements Sniffer. Reads are served from tr. Close shuts both ends
// (pa, pb) of the mirror pipe created by NewSniffer; the underlying source
// reader is left open.
type sniffer struct {
	tr     io.Reader
	pa, pb net.Conn
}

// Read reads from the tee reader.
func (sn *sniffer) Read(p []byte) (int, error) {
	n, err := sn.tr.Read(p)
	return n, err
}

// Close closes pa and then pb. Their Close errors are ignored, and Close
// always returns nil.
func (sn *sniffer) Close() error {
	for _, end := range []net.Conn{sn.pa, sn.pb} {
		end.Close()
	}
	return nil
}

// NewSniffer returns a Sniffer whose reads yield exactly the bytes read from r.
// Every byte is also mirrored into one end of an in-memory net.Pipe. The other
// end of the pipe feeds an rtmp chunk stream reader running in its own
// goroutine.
//
// net.Pipe is synchronous with no internal buffering: each mirror write (done
// inside the TeeReader) blocks until the far end reads it. If nothing keeps
// reading the far end, reads from the Sniffer, and therefore the proxied
// stream, stall forever. So once the chunk stream reader's Read returns, the
// rest of the mirrored data is drained and discarded.
//
// NOTE(review): tag is not used here. Presumably parsed messages were meant to
// be logged under it; confirm against the rtmp package API.
func NewSniffer(r io.Reader, tag string) Sniffer {
	a, b := net.Pipe()
	tr := io.TeeReader(r, a)

	// Nothing in this function writes to b, so this only parks until the pipe
	// is closed.
	go io.Copy(ioutil.Discard, a)

	br := bufio.NewReader(b)
	chunkStreamReader := rtmp.NewChunkStreamReader(br)

	go func() {
		chunkStreamReader.Read()
		// Parsing has stopped (for whatever reason). Keep consuming the mirror
		// so the TeeReader's writes into the pipe never block the proxy. After
		// Close this returns immediately with a closed-pipe error.
		io.Copy(ioutil.Discard, br)
	}()

	return &sniffer{
		tr: tr,
		pa: a,
		pb: b,
	}
}

// logMessage prints a summary of a non-media RTMP message under tag.
//
// Audio and video messages are skipped. Every other message is dumped with
// %#v, then its payload is decoded for the protocol-control, command and
// user-control types. The payload presumably comes from the network, so a
// payload too short for its type is reported as a parse error instead of
// panicking. An unrecovered panic in any goroutine would terminate the whole
// proxy process.
func logMessage(tag string, msg *rtmp.RawMessage) {
	if msg.Type == rtmp.AUDIO_TYPE || msg.Type == rtmp.VIDEO_TYPE {
		return
	}
	fmt.Printf("%s - %#v\n", tag, *msg)

	switch msg.Type {
	case rtmp.SET_CHUNK_SIZE:
		logUint32Payload(tag, "SetChunkSize", msg.Data.Bytes())
	case rtmp.ABORT_MESSAGE:
		logUint32Payload(tag, "AbortMessage", msg.Data.Bytes())
	case rtmp.ACKNOWLEDGEMENT:
		logUint32Payload(tag, "Acknowledgement", msg.Data.Bytes())
	case rtmp.WINDOW_ACKNOWLEDGEMENT_SIZE:
		logUint32Payload(tag, "WindowAcknowledgementSize", msg.Data.Bytes())
	case rtmp.SET_PEER_BANDWIDTH:
		// Unlike the cases above, these reads consume msg.Data.
		var bandwidth uint32
		var limit uint8
		if err := binary.Read(msg.Data, binary.BigEndian, &bandwidth); err != nil {
			fmt.Printf("%s - SetPeerBandwidth parse error: %s\n", tag, err)
			return
		}
		if err := binary.Read(msg.Data, binary.BigEndian, &limit); err != nil {
			fmt.Printf("%s - SetPeerBandwidth parse error: %s\n", tag, err)
			return
		}
		fmt.Printf("%s - SetPeerBandwidth: %d %d\n", tag, bandwidth, limit)
	case rtmp.COMMAND_AMF3, rtmp.COMMAND_AMF0:
		cmd, err := rtmp.ParseCommand(msg)
		if err != nil {
			fmt.Printf("%s - rtmp.ParseCommand error: %s\n", tag, err)
			return
		}
		fmt.Printf("%s - %#v\n", tag, cmd)
	case rtmp.USER_CONTROL_MESSAGE:
		logUserControl(tag, msg)
	}
}

// logUint32Payload prints the leading big-endian uint32 of payload under
// label. If payload is shorter than four bytes, it prints a parse error
// instead. It does not modify payload.
func logUint32Payload(tag, label string, payload []byte) {
	if len(payload) < 4 {
		fmt.Printf("%s - %s parse error: payload is %d bytes, want at least 4\n", tag, label, len(payload))
		return
	}
	fmt.Printf("%s - %s: %d\n", tag, label, binary.BigEndian.Uint32(payload))
}

// logUserControl decodes the user control event carried in msg and prints one
// line describing it. Decoding failures (for example a truncated payload) are
// printed rather than returned.
func logUserControl(tag string, msg *rtmp.RawMessage) {
	if err := printUserControl(tag, msg); err != nil {
		fmt.Printf("%s - user control parse error: %s\n", tag, err)
	}
}

// printUserControl reads a big-endian uint16 event type from msg.Data, then
// the event's big-endian uint32 argument(s), and prints the event. It consumes
// msg.Data and returns the first read error. Unknown events are printed
// without reading any argument.
func printUserControl(tag string, msg *rtmp.RawMessage) error {
	readUint32 := func() (uint32, error) {
		var v uint32
		err := binary.Read(msg.Data, binary.BigEndian, &v)
		return v, err
	}

	var event uint16
	if err := binary.Read(msg.Data, binary.BigEndian, &event); err != nil {
		return err
	}

	var label string
	switch event {
	case rtmp.EVENT_STREAM_BEGIN:
		label = "StreamBegin"
	case rtmp.EVENT_STREAM_EOF:
		label = "StreamEOF"
	case rtmp.EVENT_STREAM_DRY:
		label = "StreamDry"
	case rtmp.EVENT_STREAM_IS_RECORDED:
		label = "StreamIsRecorded"
	case rtmp.EVENT_PING_REQUEST:
		label = "PingRequest"
	case rtmp.EVENT_PING_RESPONSE:
		label = "PingResponse"
	case rtmp.EVENT_SET_BUFFER_LENGTH:
		// The only event here with two arguments: stream ID, then buffer length.
		streamID, err := readUint32()
		if err != nil {
			return err
		}
		bufferLen, err := readUint32()
		if err != nil {
			return err
		}
		fmt.Printf("%s - SetBufferLength: %d %d\n", tag, streamID, bufferLen)
		return nil
	default:
		fmt.Printf("%s - unknown user control: %d\n", tag, event)
		return nil
	}

	// Every labelled event above carries a single uint32 (stream ID or timestamp).
	value, err := readUint32()
	if err != nil {
		return err
	}
	fmt.Printf("%s - %s: %d\n", tag, label, value)
	return nil
}

// handshake relays the fixed-size RTMP handshake between client and remote.
// It copies exactly handshakeSize bytes in each direction concurrently:
// C0+C1+C2 from the client and S0+S1+S2 from the server, each a 1-byte
// version followed by two 1536-byte chunks.
//
// It returns nil once both directions have copied the full handshake. If
// either direction fails, it returns that first error immediately rather than
// waiting for the other direction, which may be blocked in a read forever (for
// example when one peer hangs up before sending anything). The caller must
// close both connections on error to unblock that goroutine. The result
// channel is buffered so the goroutine can then deliver its result and exit.
func handshake(client net.Conn, remote net.Conn) error {
	const handshakeSize = 1 + 1536 + 1536

	done := make(chan error, 2)
	relay := func(dst, src net.Conn) {
		_, err := io.CopyN(dst, src, handshakeSize)
		done <- err
	}
	go relay(client, remote)
	go relay(remote, client)

	for i := 0; i < 2; i++ {
		if err := <-done; err != nil {
			return err
		}
	}
	return nil
}

// proxy runs the RTMP handshake between client and remote, then relays
// traffic in both directions through sniffers that mirror each stream into an
// RTMP chunk parser. If the handshake fails, both connections are closed.
// Otherwise, each time a relay's io.Copy returns, the connection that relay
// was writing to is closed. That should also cause the opposite relay's reads
// to fail. Once both relays have finished, the sniffers are closed by the
// deferred calls.
func proxy(client net.Conn, remote net.Conn) {
	fmt.Printf("Beginning proxy...\n")

	if err := handshake(client, remote); err != nil {
		fmt.Printf("Handshake error: %s\n", err)
		client.Close()
		remote.Close()
		return
	}

	finished := make(chan net.Conn)

	clientSniffer := NewSniffer(client, "CLIENT")
	remoteSniffer := NewSniffer(remote, "REMOTE")
	defer clientSniffer.Close()
	defer remoteSniffer.Close()

	// relay streams src into dst, then reports dst as the side to close.
	relay := func(dst net.Conn, src io.Reader) {
		io.Copy(dst, src)
		finished <- dst
	}
	go relay(client, remoteSniffer)
	go relay(remote, clientSniffer)

	for pending := 2; pending > 0; pending-- {
		conn := <-finished
		conn.Close()
	}

	fmt.Printf("Finishing proxy...\n")
}
