// Copyright 2013, zhangpeihao All rights reserved.

package rtmp

import (
	"bytes"
	"code.justin.tv/video/gortmp/pkg/log"
	"context"
	"crypto/hmac"
	"crypto/rand"
	"crypto/sha256"
	"encoding/binary"
	"fmt"
	"io"
	"io/ioutil"
	"net"
)

// Sizes used by the RTMP handshake code in this file.
const (
	// HANDSHAKE_SIZE is the byte length of each C1/C2/S1/S2 handshake packet.
	// DIGEST_LENGTH is the byte length of the HMAC-SHA256 digest embedded in a
	// signed packet. KEY_LENGTH is not referenced by the handshake code visible
	// in this file.
	HANDSHAKE_SIZE = 1536
	DIGEST_LENGTH  = 32
	KEY_LENGTH     = 128

	// NOTE(review): the constants below are not referenced by the code visible
	// in this file; RTMP_SIG_SIZE and SHA256_DIGEST_LENGTH carry the same
	// values as HANDSHAKE_SIZE and DIGEST_LENGTH above. Confirm whether other
	// files in the package use them before removing.
	RTMP_SIG_SIZE          = 1536
	RTMP_LARGE_HEADER_SIZE = 12
	SHA256_DIGEST_LENGTH   = 32
)

// HMAC keys used to sign and verify handshake digests.
var (
	// GENUINE_FMS_KEY is the 68-byte server-side key: the 36 ASCII bytes
	// "Genuine Adobe Flash Media Server 001" followed by 32 fixed bytes.
	// The code in this file keys the S1 digest with the first 36 bytes and
	// derives the S2 key with all 68 bytes.
	GENUINE_FMS_KEY = []byte{
		0x47, 0x65, 0x6e, 0x75, 0x69, 0x6e, 0x65, 0x20,
		0x41, 0x64, 0x6f, 0x62, 0x65, 0x20, 0x46, 0x6c,
		0x61, 0x73, 0x68, 0x20, 0x4d, 0x65, 0x64, 0x69,
		0x61, 0x20, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72,
		0x20, 0x30, 0x30, 0x31, // Genuine Adobe Flash Media Server 001
		0xf0, 0xee, 0xc2, 0x4a, 0x80, 0x68, 0xbe, 0xe8,
		0x2e, 0x00, 0xd0, 0xd1, 0x02, 0x9e, 0x7e, 0x57,
		0x6e, 0xec, 0x5d, 0x2d, 0x29, 0x80, 0x6f, 0xab,
		0x93, 0xb8, 0xe6, 0x36, 0xcf, 0xeb, 0x31, 0xae,
	}
	// GENUINE_FP_KEY is the 62-byte client-side key: the 30 ASCII bytes
	// "Genuine Adobe Flash Player 001" followed by the same 32 fixed bytes
	// that end GENUINE_FMS_KEY. The code in this file keys the C1 digest
	// with the first 30 bytes and derives the C2 key with all 62 bytes.
	GENUINE_FP_KEY = []byte{
		0x47, 0x65, 0x6E, 0x75, 0x69, 0x6E, 0x65, 0x20,
		0x41, 0x64, 0x6F, 0x62, 0x65, 0x20, 0x46, 0x6C,
		0x61, 0x73, 0x68, 0x20, 0x50, 0x6C, 0x61, 0x79,
		0x65, 0x72, 0x20, 0x30, 0x30, 0x31, /* Genuine Adobe Flash Player 001 */
		0xF0, 0xEE, 0xC2, 0x4A, 0x80, 0x68, 0xBE, 0xE8,
		0x2E, 0x00, 0xD0, 0xD1, 0x02, 0x9E, 0x7E, 0x57,
		0x6E, 0xEC, 0x5D, 0x2D, 0x29, 0x80, 0x6F, 0xAB,
		0x93, 0xB8, 0xE6, 0x36, 0xCF, 0xEB, 0x31, 0xAE,
	}
)

// readWriter glues an independent io.Reader and io.Writer into a single
// io.ReadWriter. SHandshake uses it to pair the tee'd reader with the raw
// connection so the server-side helpers read through the tee while writing
// straight to the connection. It is built positionally (reader first, writer
// second), so the field order must not change.
type readWriter struct {
	io.Reader
	io.Writer
}

// calculateHMACsha256 returns the 32-byte HMAC-SHA256 of msgBytes keyed with
// key. A non-nil error is returned only if feeding the message into the hash
// reports one, in which case the digest is nil.
func calculateHMACsha256(msgBytes []byte, key []byte) ([]byte, error) {
	mac := hmac.New(sha256.New, key)
	if _, err := mac.Write(msgBytes); err != nil {
		return nil, err
	}

	digest := mac.Sum(nil)
	return digest, nil
}

// sumBytes adds up every byte of buf as an unsigned value and returns the
// total. An empty or nil slice sums to zero.
func sumBytes(buf []byte) uint32 {
	total := uint32(0)
	for i := 0; i < len(buf); i++ {
		total += uint32(buf[i])
	}
	return total
}

// getDigestOffset0 returns the digest position for validation scheme 0:
// the sum of bytes 8..11 of buf, reduced modulo 728, plus 12.
func getDigestOffset0(buf []byte) uint32 {
	seed := sumBytes(buf[8:12])
	return seed%728 + 12
}

// getDigestOffset1 returns the digest position for validation scheme 1:
// the sum of bytes 772..775 of buf, reduced modulo 728, plus 776.
func getDigestOffset1(buf []byte) uint32 {
	seed := sumBytes(buf[772:776])
	return seed%728 + 776
}

func getDigestOffset(buf []byte, scheme int) uint32 {
	switch scheme {
	case 0:
		return getDigestOffset0(buf)
	case 1:
		return getDigestOffset1(buf)
	default:
		return getDigestOffset0(buf)
	}
}

// validate tries validation schemes 0 and 1, in that order, against buf and
// returns the first scheme whose digest check succeeds. If neither succeeds
// it returns scheme 0 together with an error that lists the failure message
// of every attempted scheme.
func validate(buf []byte) (int, error) {
	var failures []string

	for _, scheme := range [...]int{0, 1} {
		err := validateScheme(buf, scheme)
		if err == nil {
			return scheme, nil
		}
		failures = append(failures, err.Error())
	}

	return 0, fmt.Errorf("all validation schemes failed: %v", failures)
}

func validateScheme(buf []byte, scheme int) error {
	var digestOffset uint32
	switch scheme {
	case 0:
		digestOffset = getDigestOffset0(buf)
	case 1:
		digestOffset = getDigestOffset1(buf)
	default:
		return fmt.Errorf("validateScheme(%d): invlaid scheme", scheme)
	}

	tempBuffer := make([]byte, HANDSHAKE_SIZE-DIGEST_LENGTH)
	copy(tempBuffer, buf[:digestOffset])
	copy(tempBuffer[digestOffset:], buf[digestOffset+DIGEST_LENGTH:])

	hash, err := calculateHMACsha256(tempBuffer, GENUINE_FP_KEY[:30])
	if err != nil {
		return fmt.Errorf("validateScheme(%d) calculateHMAC failed: %s", scheme, err)
	}

	if bytes.Compare(buf[digestOffset:digestOffset+DIGEST_LENGTH], hash) != 0 {
		return fmt.Errorf("validateScheme(%d) digest comparison failed", scheme)
	}

	return nil
}

// doSimpleHandshake is a placeholder for the unsigned ("simple") handshake.
// It ignores both arguments and always reports that the handshake is not
// implemented.
func doSimpleHandshake(conn net.Conn, c1 []byte) error {
	err := fmt.Errorf("simple handshake not implemented")
	return err
}

// Handshake performs the client side of the RTMP handshake over conn.
//
// It writes C0 and C1, where C1 carries an HMAC-SHA256 digest placed at the
// scheme 0 offset and keyed with the first 30 bytes of GENUINE_FP_KEY. It
// then reads S0, S1 and S2, checks the S1 digest (scheme 0, keyed with the
// first 36 bytes of GENUINE_FMS_KEY), and finally writes a C2 whose trailing
// digest is keyed with a value derived from the S1 hash and GENUINE_FP_KEY.
//
// An unexpected S0 value or an S1 digest mismatch is only logged as a
// warning; the handshake still proceeds.
//
// The returned slice contains every byte read from conn during the handshake
// and is returned on failure as well as on success. The error is non-nil when
// random data cannot be generated, a digest cannot be computed, or reading
// from or writing to conn fails.
func Handshake(ctx context.Context, conn net.Conn) ([]byte, error) {
	log := log.FromContext(ctx, "handshake")

	// Everything read from conn is mirrored into buf so it can be handed back
	// to the caller.
	buf := &bytes.Buffer{}
	tr := io.TeeReader(conn, buf)

	c0c1 := make([]byte, HANDSHAKE_SIZE+1)
	c0 := c0c1[0:1]
	c1 := c0c1[1:]

	// Set the rtmp connection type
	c0[0] = 0x03

	// first 8 bytes are timestamp + version (left as zero)
	if _, err := rand.Read(c1[8:]); err != nil {
		return buf.Bytes(), fmt.Errorf("failed to create C1 bytes: %s", err)
	}

	// get the digest position
	c1off := getDigestOffset(c1, 0)

	// crop out digest bytes: the digest is computed over C1 with the digest
	// slot removed
	c1bytes := make([]byte, HANDSHAKE_SIZE-DIGEST_LENGTH)
	copy(c1bytes, c1[:c1off])
	copy(c1bytes[c1off:], c1[c1off+DIGEST_LENGTH:])

	c1hash, err := calculateHMACsha256(c1bytes, GENUINE_FP_KEY[:30])
	if err != nil {
		return buf.Bytes(), fmt.Errorf("failed to calculate C1 digest: %s", err)
	}

	// write the hash into our output data
	copy(c1[c1off:], c1hash)

	log.Debugf("writing C0 C1")
	if _, err := conn.Write(c0c1); err != nil {
		return buf.Bytes(), err
	}

	log.Debugf("reading S0 S1 S2")
	s0s1s2 := make([]byte, 2*HANDSHAKE_SIZE+1)
	if _, err := io.ReadFull(tr, s0s1s2); err != nil {
		return buf.Bytes(), err
	}

	s0 := s0s1s2[0]
	s1 := s0s1s2[1 : HANDSHAKE_SIZE+1]
	// S2 (s0s1s2[1+HANDSHAKE_SIZE:]) is read but not inspected.

	if s0 != 0x03 {
		log.Warnf("invalid connection type: %d", s0)
	}

	// calculate s1 digest offset
	s1off := getDigestOffset(s1, 0)

	// Rebuild S1 without its digest slot. The tail must land after the head
	// (at s1off), not on top of it, or the computed hash can never match.
	s1bytes := make([]byte, HANDSHAKE_SIZE-DIGEST_LENGTH)
	copy(s1bytes, s1[:s1off])
	copy(s1bytes[s1off:], s1[s1off+DIGEST_LENGTH:])

	s1hash, err := calculateHMACsha256(s1bytes, GENUINE_FMS_KEY[:36])
	if err != nil {
		return buf.Bytes(), fmt.Errorf("failed to create S1 hash: %s", err)
	}

	if !bytes.Equal(s1[s1off:s1off+DIGEST_LENGTH], s1hash) {
		log.Warnf("S1 digest mismatch!")
	}

	// C2 is random data followed by a digest over that data.
	c2 := make([]byte, HANDSHAKE_SIZE)
	c2bytes := c2[:HANDSHAKE_SIZE-DIGEST_LENGTH]
	if _, err := rand.Read(c2bytes); err != nil {
		return buf.Bytes(), fmt.Errorf("failed to create C2 bytes: %s", err)
	}

	// The C2 signing key is derived from the S1 hash using the full client key.
	c2key, err := calculateHMACsha256(s1hash, GENUINE_FP_KEY)
	if err != nil {
		return buf.Bytes(), fmt.Errorf("failed to create C2 key: %s", err)
	}

	c2hash, err := calculateHMACsha256(c2bytes, c2key)
	if err != nil {
		return buf.Bytes(), fmt.Errorf("failed to create C2 hash: %s", err)
	}

	copy(c2[HANDSHAKE_SIZE-DIGEST_LENGTH:], c2hash)

	log.Debugf("writing C2")
	if _, err := conn.Write(c2); err != nil {
		return buf.Bytes(), err
	}

	return buf.Bytes(), nil
}

// SHandshake performs the server side of the RTMP handshake over conn.
//
// It reads C0 and C1, then tries to validate the C1 digest. When validation
// succeeds it runs the signed handshake with the detected scheme; otherwise
// it logs the validation failure and falls back to the echo handshake. An
// unexpected C0 value or a zero first version byte in C1 is only logged as a
// warning.
//
// The returned slice contains every byte read from conn through the tee
// reader and is returned on failure as well as on success. The error is the
// read error for C0/C1, or whatever the chosen handshake routine returns.
func SHandshake(ctx context.Context, conn net.Conn) ([]byte, error) {
	log := log.FromContext(ctx, "shandshake")

	// Mirror everything read from the connection so it can be returned.
	captured := &bytes.Buffer{}
	tr := io.TeeReader(conn, captured)

	var c0 [1]byte
	if _, err := io.ReadFull(tr, c0[:]); err != nil {
		return captured.Bytes(), err
	}
	if c0[0] != 0x03 {
		log.Warnf("unsupported handshake type: %x", c0)
	}

	c1 := make([]byte, HANDSHAKE_SIZE)
	if _, err := io.ReadFull(tr, c1); err != nil {
		return captured.Bytes(), err
	}
	log.Debugf("read c1")

	// Bytes 4..7 of C1 hold the client version; only the first is checked.
	if version := c1[4:8]; version[0] == 0 {
		log.Warnf("unversioned flash client detected: %x %x %x %x", version[0], version[1], version[2], version[3])
	}

	// Reads go through the tee, writes go straight to the connection.
	peer := &readWriter{tr, conn}

	scheme, err := validate(c1)
	if err == nil {
		log.Debugf("performing signed handshake")
		err = doSignedSHandshake(c1, scheme, peer)
	} else {
		log.Warnf("failed to get validation scheme: %s", err)

		log.Debugf("performing echo handshake")
		err = doEchoSHandshake(c1, peer)
	}

	log.Debugf("SHandshake Complete")
	return captured.Bytes(), err
}

// doSignedSHandshake sends a digest-signed S0+S1+S2 response for the given
// C1 and validation scheme, and discards the client's C2.
//
// S1 gets 0x01020304 in bytes 4..7, random data after that, and an
// HMAC-SHA256 digest (keyed with the first 36 bytes of GENUINE_FMS_KEY)
// written at the scheme's digest offset. S2 is random data followed by a
// digest keyed with the HMAC of the C1 digest under the first 68 bytes of
// GENUINE_FMS_KEY.
//
// The response is written and HANDSHAKE_SIZE bytes are read from rw
// concurrently. The function returns as soon as either side reports an
// error, or nil once both have finished.
func doSignedSHandshake(c1 []byte, scheme int, rw io.ReadWriter) error {
	// One contiguous buffer so S0, S1 and S2 go out in a single Write.
	out := make([]byte, 1+2*HANDSHAKE_SIZE)
	s1 := out[1 : 1+HANDSHAKE_SIZE]
	s2 := out[1+HANDSHAKE_SIZE:]

	// S0: handshake type
	out[0] = 0x03

	// S1: bytes 0..3 stay zero, bytes 4..7 carry a fixed non-zero value,
	// the rest is random.
	binary.BigEndian.PutUint32(s1[4:8], 0x01020304)
	if _, err := rand.Read(s1[8:]); err != nil {
		return fmt.Errorf("failed to create S1 bytes: %s", err)
	}

	// Sign S1 over its contents with the digest slot removed.
	s1pos := getDigestOffset(s1, scheme)
	s1msg := make([]byte, HANDSHAKE_SIZE-DIGEST_LENGTH)
	copy(s1msg, s1[:s1pos])
	copy(s1msg[s1pos:], s1[s1pos+DIGEST_LENGTH:])

	s1digest, err := calculateHMACsha256(s1msg, GENUINE_FMS_KEY[:36])
	if err != nil {
		return fmt.Errorf("failed to create S1 hash: %s", err)
	}
	copy(s1[s1pos:], s1digest)

	// The client's digest seeds the key used to sign S2.
	c1pos := getDigestOffset(c1, scheme)
	c1digest := c1[c1pos : c1pos+DIGEST_LENGTH]

	s2body := s2[:HANDSHAKE_SIZE-DIGEST_LENGTH]
	if _, err := rand.Read(s2body); err != nil {
		return fmt.Errorf("failed to create S2 bytes: %s", err)
	}

	s2key, err := calculateHMACsha256(c1digest, GENUINE_FMS_KEY[:68])
	if err != nil {
		return fmt.Errorf("failed to create S2 key: %s", err)
	}

	s2digest, err := calculateHMACsha256(s2body, s2key)
	if err != nil {
		return fmt.Errorf("failed to create S2 hash: %s", err)
	}
	copy(s2[HANDSHAKE_SIZE-DIGEST_LENGTH:], s2digest)

	// Buffered for both goroutines so neither blocks on send if this
	// function returns early.
	results := make(chan error, 2)

	go func() {
		_, werr := rw.Write(out)
		results <- werr
	}()

	go func() {
		// ignore client response
		_, rerr := io.CopyN(ioutil.Discard, rw, HANDSHAKE_SIZE)
		results <- rerr
	}()

	for pending := 2; pending > 0; pending-- {
		if err := <-results; err != nil {
			return fmt.Errorf("connection error: %s", err)
		}
	}

	return nil
}

// doEchoSHandshake performs the unsigned server handshake: it writes S0,
// a mostly random S1 and then the client's own C1 back as S2, while
// concurrently reading and discarding HANDSHAKE_SIZE bytes (the client's C2)
// from rw.
//
// It returns the raw error if random data cannot be generated, a
// "connection error" as soon as either the read or the write side fails, or
// nil once both have finished.
func doEchoSHandshake(c1 []byte, rw io.ReadWriter) error {
	// S0: handshake type
	s0 := []byte{0x03}

	// s1[0:4] == timestamp offset == 0
	// s1[4:8] == 0 to meet spec
	s1 := make([]byte, HANDSHAKE_SIZE)
	if _, err := rand.Read(s1[8:]); err != nil {
		return err
	}

	// Buffered for both goroutines so neither blocks on send if this
	// function returns early.
	results := make(chan error, 2)

	// discard C2 (really doesn't matter)
	go func() {
		_, rerr := io.CopyN(ioutil.Discard, rw, HANDSHAKE_SIZE)
		results <- rerr
	}()

	// Send S0, S1 and the echoed C1 as three separate writes, stopping at
	// the first failure.
	go func() {
		for _, part := range [][]byte{s0, s1, c1} {
			if _, werr := rw.Write(part); werr != nil {
				results <- werr
				return
			}
		}

		results <- nil
	}()

	for pending := 2; pending > 0; pending-- {
		if err := <-results; err != nil {
			return fmt.Errorf("connection error: %s", err)
		}
	}

	return nil
}
