package server

import (
	"bufio"
	"bytes"
	"fmt"
	"io"
	"net"
	"os"
	"strconv"
	"strings"
	"time"
)

var (
	// prefix is the byte sequence that opens a PROXY protocol (v1) header.
	// Every new connection is checked for it to decide whether the peer
	// addresses come from such a header rather than from the socket.
	// prefixLen is its length, i.e. how many bytes are sniffed up front
	// from each connection.
	prefix    = []byte("PROXY ")
	prefixLen = len(prefix)
)

// FDConn is a net.Conn that can also return an *os.File for the connection
// via File. NewMultiListener feeds accepted *net.TCPConn values through this
// interface, and every wrapper in this file (ProxyConn, TimeoutConn,
// peekedConn) preserves it.
type FDConn interface {
	net.Conn
	File() (*os.File, error)
}

// ProxyConn is a connection whose peer addresses were supplied by a PROXY
// protocol header rather than read from the socket.
//
// SrcAddr and DstAddr are the source and destination endpoints parsed from
// that header. When SrcAddr is nil, RemoteAddr reports the embedded
// connection's own peer address.
type ProxyConn struct {
	FDConn
	SrcAddr *net.TCPAddr
	DstAddr *net.TCPAddr
}

var _ FDConn = (*ProxyConn)(nil)
var _ net.Conn = (*ProxyConn)(nil)

// RemoteAddr returns the client address advertised in the PROXY header,
// wrapped in a *proxyAddr so callers can also recover the header's
// destination address through its ProxyAddr method. Without a header
// address it falls back to the embedded connection's RemoteAddr.
func (pc *ProxyConn) RemoteAddr() net.Addr {
	if pc.SrcAddr != nil {
		return &proxyAddr{pc.SrcAddr, pc.DstAddr}
	}
	// Delegate to the embedded connection explicitly: pc.RemoteAddr() would
	// call this very method again and recurse until the stack overflows.
	return pc.FDConn.RemoteAddr()
}

// ProxyAddr is implemented by the net.Addr that ProxyConn.RemoteAddr returns
// when a PROXY header supplied the client address. RemoteAddr yields that
// client address and ProxyAddr the header's destination address.
type ProxyAddr interface {
	RemoteAddr() net.Addr
	ProxyAddr() net.Addr
}

// proxyAddr behaves as the client address carried in a PROXY header (Network
// and String describe it) while also remembering the header's destination
// address in proxy.
type proxyAddr struct {
	net.Addr
	proxy net.Addr
}

// RemoteAddr returns the client address taken from the PROXY header.
func (a *proxyAddr) RemoteAddr() net.Addr { return a.Addr }

// ProxyAddr returns the destination address taken from the PROXY header.
func (a *proxyAddr) ProxyAddr() net.Addr { return a.proxy }

// Network reports the network name of the client address.
func (a *proxyAddr) Network() string { return a.Addr.Network() }

// String formats the client address.
func (a *proxyAddr) String() string { return a.Addr.String() }

// TimeoutConn is an FDConn with an idle timeout: before every Read and Write
// the connection's read/write deadline is pushed to Timeout from now.
//
// A Timeout <= 0 disables deadline handling entirely.
type TimeoutConn struct {
	FDConn
	Timeout time.Duration
}

// Read extends the deadline and then reads from the underlying connection.
func (t *TimeoutConn) Read(b []byte) (int, error) {
	t.extendDeadline()
	return t.FDConn.Read(b)
}

// Write extends the deadline and then writes to the underlying connection.
func (t *TimeoutConn) Write(b []byte) (int, error) {
	t.extendDeadline()
	return t.FDConn.Write(b)
}

// extendDeadline sets the read and write deadline to Timeout from now.
//
// It runs before each operation rather than after, so the very first Read on
// a fresh connection is bounded too. Otherwise a peer that connects and never
// sends anything blocks the reader indefinitely. A non-positive Timeout means
// "no timeout": time.Now().Add(0) would instead make every subsequent
// operation fail immediately.
func (t *TimeoutConn) extendDeadline() {
	if t.Timeout <= 0 {
		return
	}
	// Deliberately ignored: SetDeadline typically fails only when the
	// connection is already unusable, which the following Read/Write reports.
	_ = t.SetDeadline(time.Now().Add(t.Timeout))
}

// multiListener holds the two listeners built by NewMultiListener: one that
// yields connections sniffed as RTMP and one that yields everything else.
type multiListener struct {
	rtmpListener net.Listener
	httpListener net.Listener
}

// outListener is a net.Listener whose Accept returns connections delivered on
// ch. Every other method (Close, Addr) is promoted from the embedded
// listener. NewMultiListener gives both outListeners the same underlying
// listener, so closing either one closes it for both.
type outListener struct {
	net.Listener
	ch chan net.Conn
}

// peekedConn is an FDConn whose Read is served from reader instead of the
// embedded connection. This lets bytes already consumed while sniffing be
// delivered again ahead of the rest of the stream; all other methods come
// from the embedded FDConn.
type peekedConn struct {
	FDConn
	reader io.Reader
}

var _ FDConn = (*peekedConn)(nil)
var _ net.Conn = (*peekedConn)(nil)

// sniffConnection decides whether conn carries RTMP or HTTP and delivers it to
// rtmpCh or httpCh respectively.
//
// Any PROXY protocol header is consumed first by sniffProxyProto. The bytes
// read to make the decision are replayed ahead of the remaining stream, so
// the receiver sees the payload from its first byte.
//
// NOTE(review): when sniffProxyProto fails it has already closed the
// connection, yet the connection is still delivered (usually to httpCh);
// the receiver will only see read errors — confirm this is acceptable.
func sniffConnection(conn FDConn, rtmpCh, httpCh chan net.Conn) {
	head, inner := sniffProxyProto(conn)

	// RTMP handshakes open with the version byte 0x03; anything else is
	// routed to the HTTP side.
	var out chan net.Conn
	switch head[0] {
	case 0x03:
		out = rtmpCh
	default:
		out = httpCh
	}

	out <- &peekedConn{inner, io.MultiReader(bytes.NewReader(head), inner)}
}

// sniffProxyProto reads the first prefixLen bytes of conn and, if they open a
// PROXY protocol v1 header, consumes and parses that header.
//
// It returns the sniffed bytes together with the connection to keep reading
// from:
//   - no header: the first prefixLen bytes of conn and conn itself; the caller
//     must replay those bytes ahead of any further reads.
//   - valid header: conn is wrapped in a ProxyConn carrying the parsed
//     addresses and the function recurses, so the returned bytes are the first
//     prefixLen bytes that follow the header(s).
//   - any read or parse failure: conn is closed and returned together with
//     the bytes read by this call's initial sniff.
func sniffProxyProto(conn FDConn) ([]byte, FDConn) {
	// A v1 header line, from "PROXY " through the terminating CRLF, is at
	// most 107 bytes long.
	const maxHeaderLen = 107

	initialSniff := make([]byte, prefixLen)
	if _, err := io.ReadFull(conn, initialSniff); err != nil {
		conn.Close()
		return initialSniff, conn
	}
	if !bytes.Equal(initialSniff, prefix) {
		return initialSniff, conn
	}

	// Read only up to the end of the header line instead of a fixed byte
	// count. A short header followed by a short payload (e.g. a small HTTP
	// request) must not block waiting for bytes the client will never send.
	// The buffer size caps the line length: without a '\n' inside it,
	// ReadSlice fails with bufio.ErrBufferFull. Payload bytes buffered past
	// the newline stay in br and are replayed through the returned conn.
	br := bufio.NewReaderSize(conn, maxHeaderLen-prefixLen)
	line, err := br.ReadSlice('\n')
	// The line must end in CRLF. HasSuffix also rejects a bare "\n"
	// (e.g. "PROXY \n"), which must never be sliced to a negative length.
	if err != nil || !bytes.HasSuffix(line, []byte("\r\n")) {
		conn.Close()
		return initialSniff, conn
	}

	// Strip the CRLF. ReadSlice's result is only valid until the next read
	// from br, so it is copied into a string right away.
	header := string(initialSniff) + string(line[:len(line)-2])
	clientIP, ingestProxy, err := getProxyInfoFromHeader(header)
	if err != nil {
		conn.Close()
		fmt.Println(err)
		return initialSniff, conn
	}
	proxyConn := &ProxyConn{conn, clientIP, ingestProxy}

	// NOTE(review): recursing accepts any number of stacked headers, and the
	// last one parsed determines RemoteAddr. A client behind a proxy could
	// therefore send its own PROXY line right after the proxy's header and
	// have its claimed address win — confirm stacked headers are needed.
	return sniffProxyProto(&peekedConn{proxyConn, br})
}

// getProxyInfoFromHeader parses a PROXY protocol v1 header line whose
// trailing CRLF has already been removed, for example:
//
//	PROXY TCP4 192.168.0.1 192.168.0.11 56324 443
//
// It returns the source endpoint as clientIP and the destination endpoint as
// ingestProxy. Only the TCP4 and TCP6 families are accepted; each address
// must be written in the declared family and each port must be in 0-65535.
// On error both addresses are nil.
//
// NOTE(review): the v1 spec also defines "PROXY UNKNOWN", which receivers are
// expected to accept while keeping the real socket addresses; it is rejected
// here as an unhandled address type.
func getProxyInfoFromHeader(header string) (clientIP, ingestProxy *net.TCPAddr, err error) {
	// Fields: PROXY <family> <src addr> <dst addr> <src port> <dst port>
	parts := strings.Split(header, " ")
	if len(parts) != 6 {
		return nil, nil, fmt.Errorf("Invalid header line: %s", header)
	}
	family := parts[1]
	switch family {
	case "TCP4", "TCP6":
	default:
		return nil, nil, fmt.Errorf("Unhandled address type: %s", family)
	}

	src, err := parseProxyEndpoint(family, "source", parts[2], parts[4])
	if err != nil {
		return nil, nil, err
	}
	dst, err := parseProxyEndpoint(family, "destination", parts[3], parts[5])
	if err != nil {
		return nil, nil, err
	}
	return src, dst, nil
}

// parseProxyEndpoint parses one address/port pair of a PROXY v1 header.
// role ("source" or "destination") is only used in error messages.
func parseProxyEndpoint(family, role, host, port string) (*net.TCPAddr, error) {
	ip := net.ParseIP(host)
	if ip == nil {
		return nil, fmt.Errorf("Invalid %s ip: %s", role, host)
	}
	// ParseIP accepts both families; the header must use the one it declares
	// (TCP6 addresses are written with colons, TCP4 ones without).
	if strings.Contains(host, ":") != (family == "TCP6") {
		return nil, fmt.Errorf("Invalid %s ip for %s: %s", role, family, host)
	}
	// A 16-bit unsigned parse rejects signs and anything outside 0-65535,
	// which strconv.Atoi would have let through.
	p, err := strconv.ParseUint(port, 10, 16)
	if err != nil {
		return nil, fmt.Errorf("Invalid %s port: %s", role, port)
	}
	return &net.TCPAddr{IP: ip, Port: int(p)}, nil
}

// NewMultiListener lets a single listening socket serve both RTMP and HTTP.
//
// A background goroutine accepts connections from l and wraps each
// *net.TCPConn in a TimeoutConn using timeout. It then sniffs each one in its
// own goroutine (sniffConnection) and routes it to either the RTMP or the HTTP
// listener of the returned multiListener. Both returned listeners share l for
// Close and Addr.
//
// The accept goroutine exits on the first error from l.Accept.
// NOTE(review): that error is not propagated; Accept on the returned
// listeners keeps blocking afterwards. Confirm callers do not rely on Accept
// returning once l is closed.
func NewMultiListener(l net.Listener, timeout time.Duration) *multiListener {
	rtmpCh := make(chan net.Conn, 5)
	httpCh := make(chan net.Conn, 5)
	go func() {
		for {
			conn, err := l.Accept()
			if err != nil {
				return
			}
			tcpConn, ok := conn.(*net.TCPConn)
			if !ok {
				// Only *net.TCPConn is handled here. Drop this connection
				// (closing it so it does not leak) and keep serving, rather
				// than stopping the accept loop for everyone.
				conn.Close()
				continue
			}
			go sniffConnection(&TimeoutConn{tcpConn, timeout}, rtmpCh, httpCh)
		}
	}()
	return &multiListener{
		rtmpListener: &outListener{l, rtmpCh},
		httpListener: &outListener{l, httpCh},
	}
}

// RtmpListener returns the listener that yields connections whose first
// payload byte was sniffed as 0x03 (an RTMP handshake).
func (m *multiListener) RtmpListener() net.Listener { return m.rtmpListener }

// HttpListener returns the listener that yields every other sniffed
// connection.
func (m *multiListener) HttpListener() net.Listener { return m.httpListener }

// Accept blocks until a sniffed connection is routed to this listener and
// returns it.
//
// NOTE(review): it never returns an error, and nothing in this file closes
// ch, so Accept does not unblock once the underlying listener stops
// accepting. Confirm callers do not depend on Accept failing at shutdown.
func (o *outListener) Accept() (net.Conn, error) {
	return <-o.ch, nil
}

// Read serves p from reader instead of the embedded connection, so bytes
// consumed while sniffing are delivered before the rest of the stream.
func (c *peekedConn) Read(p []byte) (int, error) {
	n, err := c.reader.Read(p)
	return n, err
}
