package crypto

import (
	"crypto"
	"crypto/rand"
	"crypto/sha256"
	"crypto/subtle"
	"encoding/asn1"
	"encoding/base64"
	"encoding/binary"
	"errors"
	"math/big"
)

// Sizes and markers of the PSS-style encoding. SHA-256 is assumed throughout,
// so the salt is exactly as long as the digest.
const (
	hashSize        = 32       // SHA-256 digest length in bytes
	saltSize        = hashSize // salt length in bytes; verification requires exactly this
	signatureEnding = 0xBC     // trailer byte that ends every encoded message
)

// Small big.Int values shared across the package; init assigns them once.
// They are shared pointers, so never pass them as the receiver of a mutating
// big.Int method.
var (
	bigZero, bigOne, bigMinusOne *big.Int
	bigTwo, bigThree, bigFour    *big.Int
)

// RWPublicKey is a Rabin-Williams public key.
//
// It is decoded with encoding/asn1 (see RWPublicKeyFromBin), which maps struct
// fields to ASN.1 SEQUENCE elements in declaration order, so the field list is
// part of the wire format.
//
// N is the public modulus; verification squares signatures modulo N.
type RWPublicKey struct {
	N *big.Int
}

// RWPrivateKey is a Rabin-Williams private key together with the precomputed
// values used by signPaddedDigest.
//
// Keep the field order in sync with the ASN.1 layout read by
// RWPrivateKeyFromBin (N, P, Q, IQMP, DQ, DP, TwoMP, TwoMQ).
//
// Fields, as used by signPaddedDigest:
//   - PubKey: holds the modulus N that the final signature is reduced by.
//   - P, Q: moduli of the two CRT halves of signing; presumably the prime
//     factors of N (NOTE(review): never checked in this file).
//   - IQMP: multiplies the P-side difference when recombining the halves;
//     presumably Q^-1 mod P — TODO confirm.
//   - DQ, DP: exponents applied to m mod Q and m mod P.
//   - TwoMP, TwoMQ: multipliers applied modulo P and Q on one signing branch;
//     their derivation is not shown here — TODO document.
type RWPrivateKey struct {
	PubKey RWPublicKey
	P      *big.Int
	Q      *big.Int
	IQMP   *big.Int
	DQ     *big.Int
	DP     *big.Int
	TwoMP  *big.Int
	TwoMQ  *big.Int
}

// RWVerifyInterface is the verification API of a Rabin-Williams public key.
//
// VerifyMsg takes the message as a string. It matches (*RWPublicKey).VerifyMsg,
// which takes a string and hashes it with SHA-256. The previous []string
// parameter meant *RWPublicKey did not satisfy this interface.
type RWVerifyInterface interface {
	// Verify checks signature against a precomputed digest.
	Verify(digest []byte, signature []byte) (bool, error)
	// VerifyMsg hashes message and checks signature against the result.
	VerifyMsg(message string, signature []byte) (bool, error)
}

// RWSignInterface is the signing API of a Rabin-Williams private key.
type RWSignInterface interface {
	// Sign signs an arbitrary message (RWPrivateKey hashes it with SHA-256).
	Sign(message []byte) ([]byte, error)
	// SignDigest signs an already computed digest.
	SignDigest(digest []byte) ([]byte, error)
}

// init assigns the shared small big.Int values (bigTwo, bigThree and bigFour
// are used as exponents by signPaddedDigest).
func init() {
	bigZero = big.NewInt(0)
	bigOne = big.NewInt(1)
	bigMinusOne = big.NewInt(-1)
	bigTwo = big.NewInt(2)
	bigThree = big.NewInt(3)
	bigFour = big.NewInt(4)
}

// RWPublicKeyFromBin decodes a DER-encoded public key: an ASN.1 SEQUENCE whose
// single INTEGER element is the modulus N.
//
// It returns an error when data is not valid DER for that structure, when
// bytes remain after the SEQUENCE, or when N is not positive.
func RWPublicKeyFromBin(data []byte) (*RWPublicKey, error) {
	var rw RWPublicKey
	rest, err := asn1.Unmarshal(data, &rw)
	if err != nil {
		return nil, err
	}
	// asn1.Unmarshal only consumes the first element; anything after it is
	// not part of a well-formed key.
	if len(rest) != 0 {
		return nil, errors.New("trailing data after public key")
	}
	// A zero or negative modulus cannot be used as a modulus by the
	// verification arithmetic.
	if rw.N == nil || rw.N.Sign() <= 0 {
		return nil, errors.New("invalid public key")
	}

	return &rw, nil
}

// RWPrivateKeyFromBin decodes a DER-encoded private key: an ASN.1 SEQUENCE of
// eight INTEGERs in the order N, P, Q, IQMP, DQ, DP, TwoMP, TwoMQ.
//
// It returns an error when data is not valid DER for that structure, when
// bytes remain after the SEQUENCE, or when N, P or Q is not positive (they are
// used as moduli when signing, and a zero modulus makes math/big panic).
func RWPrivateKeyFromBin(data []byte) (*RWPrivateKey, error) {
	// The wire format is flat, whereas RWPrivateKey nests N inside PubKey, so
	// decode into a mirror struct first. encoding/asn1 matches fields by
	// position, so this field order is the wire order.
	var raw struct {
		N     *big.Int
		P     *big.Int
		Q     *big.Int
		IQMP  *big.Int
		DQ    *big.Int
		DP    *big.Int
		TwoMP *big.Int
		TwoMQ *big.Int
	}

	rest, err := asn1.Unmarshal(data, &raw)
	if err != nil {
		return nil, err
	}
	if len(rest) != 0 {
		return nil, errors.New("trailing data after private key")
	}
	for _, modulus := range []*big.Int{raw.N, raw.P, raw.Q} {
		if modulus == nil || modulus.Sign() <= 0 {
			return nil, errors.New("invalid private key")
		}
	}

	return &RWPrivateKey{
		PubKey: RWPublicKey{N: raw.N},
		P:      raw.P,
		Q:      raw.Q,
		IQMP:   raw.IQMP,
		DQ:     raw.DQ,
		DP:     raw.DP,
		TwoMP:  raw.TwoMP,
		TwoMQ:  raw.TwoMQ,
	}, nil
}

// RWPublicKeyFromString decodes value as unpadded base64url and parses the
// resulting bytes with RWPublicKeyFromBin.
func RWPublicKeyFromString(value string) (*RWPublicKey, error) {
	raw, decodeErr := base64.RawURLEncoding.DecodeString(value)
	if decodeErr != nil {
		return nil, decodeErr
	}
	return RWPublicKeyFromBin(raw)
}

// RWPrivateKeyFromString decodes value as unpadded base64url and parses the
// resulting bytes with RWPrivateKeyFromBin.
func RWPrivateKeyFromString(value string) (*RWPrivateKey, error) {
	raw, decodeErr := base64.RawURLEncoding.DecodeString(value)
	if decodeErr != nil {
		return nil, decodeErr
	}
	return RWPrivateKeyFromBin(raw)
}

// internalVerifyPSSR checks that signbytes is a well-formed PSS encoding of
// digest. It follows the EMSA-PSS verification steps of RFC 8017, using
// SHA-256 for both the hash and MGF1 and requiring a salt of exactly saltSize
// bytes.
//
// signbytes is the big-endian message representative recovered from a
// signature (see internalApply). It is left-padded with zeros to the byte
// length of N and must not be longer than that. When bitlen(N)-1 is a multiple
// of 8, the leading byte must be zero and is not part of the encoded message.
//
// Layout after that byte: maskedDB || H || 0xBC, where DB = maskedDB XOR
// MGF1(H) must be zero padding, a 0x01 separator, then the salt.
//
// Results:
//   - an error for structurally malformed encodings;
//   - (false, nil) when the salt has the wrong length or H does not match;
//   - (true, nil) only when SHA-256(0x00 x 8 || digest || salt) equals H.
func (rw *RWPublicKey) internalVerifyPSSR(digest []byte, signbytes []byte) (bool, error) {
	msbits := (rw.N.BitLen() - 1) & 0x07
	emlen := len(rw.N.Bytes())

	if len(signbytes) > emlen {
		return false, errors.New("internalVerifyPSSR(). invalid signature format - 0")
	}
	if len(signbytes) < emlen {
		signbytes = append(make([]byte, emlen-len(signbytes)), signbytes...)
	}

	// Bits of the first byte at or above bit position bitlen(N)-1 must be clear.
	if (signbytes[0] & (0xff << uint(msbits))) != 0 {
		return false, errors.New("internalVerifyPSSR(). invalid signature format - 1")
	}

	// The whole first byte is outside the encoded message in this case.
	if msbits == 0 {
		signbytes = signbytes[1:]
	}

	if len(signbytes) < saltSize+hashSize+2 {
		return false, errors.New("internalVerifyPSSR(). invalid signature format - 2")
	}

	if signbytes[len(signbytes)-1] != signatureEnding {
		return false, errors.New("internalVerifyPSSR(). invalid signature format - 3")
	}

	// Take the lengths from the (possibly shortened) encoded message rather
	// than from N. Otherwise the msbits == 0 case is off by one byte.
	maskedDBLen := uint32(len(signbytes) - hashSize - 1)
	hashbytes := signbytes[maskedDBLen : len(signbytes)-1]
	if len(hashbytes) != hashSize {
		return false, errors.New("internalVerifyPSSR(). invalid signature format - 4")
	}

	db := mgf1(hashbytes, maskedDBLen)
	for i := 0; i < int(maskedDBLen); i++ {
		db[i] ^= signbytes[i]
	}

	if msbits != 0 {
		db[0] &= 0xFF >> (8 - uint(msbits))
	}

	// Skip the zero padding; the first non-zero byte must be the 0x01 separator.
	var i int
	for i = 0; i < int(maskedDBLen-1); i++ {
		if db[i] != 0 {
			break
		}
	}

	if db[i] != 0x01 {
		return false, errors.New("internalVerifyPSSR(). invalid signature format - 5")
	}
	i++

	// Everything after the separator is the salt; only saltSize-byte salts are accepted.
	if (maskedDBLen - uint32(i)) != saltSize {
		return false, nil
	}
	salt := db[i:]

	// hash.Hash.Write never returns an error; the panics are unreachable guards.
	hash := sha256.New()
	if _, err := hash.Write([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}); err != nil {
		panic(err)
	}
	if _, err := hash.Write(digest); err != nil {
		panic(err)
	}
	if _, err := hash.Write(salt); err != nil {
		panic(err)
	}

	// Compare against H only; the trailing 0xBC is not part of the hash.
	// ConstantTimeCompare returns 1 for equal inputs, and 0 otherwise or when
	// the lengths differ.
	return subtle.ConstantTimeCompare(hash.Sum(nil), hashbytes) == 1, nil
}

// bytecounter encodes cnt as 4 big-endian bytes. This is the counter that
// mgf1 appends to the seed.
func bytecounter(cnt uint32) []byte {
	var buf [4]byte
	binary.BigEndian.PutUint32(buf[:], cnt)
	return buf[:]
}

// mgf1 is the MGF1 mask generation function instantiated with SHA-256
// (RFC 8017, appendix B.2.1; https://en.wikipedia.org/wiki/Mask_generation_function).
//
// It returns the first needed bytes of
// SHA-256(seed || C(0)) || SHA-256(seed || C(1)) || ...,
// where C(i) = bytecounter(i).
func mgf1(seed []byte, needed uint32) []byte {
	mask := make([]byte, 0)
	for counter := uint32(0); uint32(len(mask)) < needed; counter++ {
		h := sha256.New()
		// hash.Hash.Write never returns an error.
		if _, err := h.Write(seed); err != nil {
			panic(err)
		}
		if _, err := h.Write(bytecounter(counter)); err != nil {
			panic(err)
		}
		// Sum appends the block digest to mask.
		mask = h.Sum(mask)
	}
	return mask[:needed]
}

// internalApply applies the public Rabin-Williams operation to signature and
// returns the candidate message representative.
//
// signature is read as a big-endian integer s, which must lie in
// [0, (N-1)/2]. With t = s^2 mod N, the first matching rule picks the result:
//
//	t   ≡ 12 (mod 16) -> t
//	t   ≡ 6  (mod 8)  -> 2t
//	N-t ≡ 12 (mod 16) -> N-t
//	N-t ≡ 6  (mod 8)  -> 2(N-t)
//
// 2t and 2(N-t) are not reduced modulo N. If no rule matches, an error is
// returned.
func (rw *RWPublicKey) internalApply(signature []byte) (*big.Int, error) {
	s := new(big.Int).SetBytes(signature)
	if s.Cmp(rw.N) >= 0 {
		return nil, errors.New("internalApply(). invalid signature format - 1")
	}

	halfN := new(big.Int).Sub(rw.N, big.NewInt(1))
	halfN.Rsh(halfN, 1)
	if s.Cmp(halfN) > 0 {
		return nil, errors.New("internalApply(). invalid signature format - 2")
	}

	// SetBytes never yields a negative value; kept as a defensive guard.
	if s.Sign() < 0 {
		return nil, errors.New("internalApply(). invalid signature format - 3")
	}

	square := new(big.Int).Mul(s, s)
	square.Mod(square, rw.N)
	complement := new(big.Int).Sub(rw.N, square)

	sqLow := square.Uint64() & 0x0f
	compLow := complement.Uint64() & 0x0f

	switch {
	case sqLow == 12:
		return square, nil
	case sqLow&0x07 == 6:
		return new(big.Int).Lsh(square, 1), nil
	case compLow == 12:
		return complement, nil
	case compLow&0x07 == 6:
		return new(big.Int).Lsh(complement, 1), nil
	default:
		return nil, errors.New("internalApply(). invalid signature format - 4")
	}
}

// Verify checks a Rabin-Williams signature over a precomputed digest.
//
// It recovers the message representative with internalApply, left-pads it
// with zeros to at least len(signature) bytes, and checks the encoding with
// internalVerifyPSSR. Empty inputs and any decoding error yield
// (false, error); otherwise the verification result is returned.
func (rw *RWPublicKey) Verify(digest []byte, signature []byte) (bool, error) {
	if len(digest) == 0 || len(signature) == 0 {
		return false, errors.New("Verify(). invalid signature format")
	}

	representative, applyErr := rw.internalApply(signature)
	if applyErr != nil {
		return false, applyErr
	}

	encoded := representative.Bytes()
	if missing := len(signature) - len(encoded); missing > 0 {
		encoded = append(make([]byte, missing), encoded...)
	}

	valid, verifyErr := rw.internalVerifyPSSR(digest, encoded)
	if verifyErr != nil {
		return false, verifyErr
	}
	return valid, nil
}

// VerifyMsg verifies signature over message. It hashes the message with
// SHA-256 and checks the resulting digest with Verify.
func (rw *RWPublicKey) VerifyMsg(message string, signature []byte) (bool, error) {
	hashed := sha256.Sum256([]byte(message))
	return rw.Verify(hashed[:], signature)
}

// Sign signs data. It hashes data with SHA-256 and signs the digest with
// SignDigest.
func (rw *RWPrivateKey) Sign(data []byte) ([]byte, error) {
	hashed := sha256.Sum256(data)
	return rw.SignDigest(hashed[:])
}

// SignDigest signs a precomputed 32-byte SHA-256 digest. The digest is
// encoded with addPSSRPadding (which draws a random salt), and the encoding
// is signed with signPaddedDigest.
//
// Disclaimer: this code is neither optimized nor secure/constant-time, so use
// it only for testing purposes.
func (rw *RWPrivateKey) SignDigest(data []byte) ([]byte, error) {
	if len(data) != hashSize {
		return nil, errors.New("data must be 32 bytes long")
	}

	encoded, padErr := rw.PubKey.addPSSRPadding(data)
	if padErr != nil {
		return nil, padErr
	}
	return rw.signPaddedDigest(encoded)
}

// signPaddedDigest computes the Rabin-Williams signature of an already encoded
// message representative dgst (big-endian bytes, as produced by
// addPSSRPadding).
//
// The representative m must satisfy m < N and m ≡ 12 (mod 16). The root is
// computed in two steps:
//   - it is computed separately modulo P and modulo Q with the exponents DP
//     and DQ, optionally multiplied by TwoMP/TwoMQ;
//   - the halves are recombined with IQMP and Q into a value modulo N, which
//     is then squared.
//
// The smaller of that value v and N-v is returned, left-padded with zeros to
// len(N.Bytes()) bytes. The arithmetic is not constant-time.
//
// Errors:
//   - "invalid private key" when a key component is missing or a modulus is
//     not positive (math/big would otherwise panic);
//   - "bad digest value" when m >= N;
//   - "bad padding" when m mod 16 != 12.
func (rw *RWPrivateKey) signPaddedDigest(dgst []byte) ([]byte, error) {
	// Every field below is used as an operand or modulus, so reject incomplete
	// keys up front instead of panicking halfway through.
	if rw.P == nil || rw.Q == nil || rw.PubKey.N == nil ||
		rw.DP == nil || rw.DQ == nil || rw.IQMP == nil ||
		rw.TwoMP == nil || rw.TwoMQ == nil {
		return nil, errors.New("invalid private key")
	}
	if rw.P.Sign() <= 0 || rw.Q.Sign() <= 0 || rw.PubKey.N.Sign() <= 0 {
		return nil, errors.New("invalid private key")
	}

	m := new(big.Int).SetBytes(dgst)
	if m.Cmp(rw.PubKey.N) >= 0 {
		return nil, errors.New("bad digest value")
	}
	// The encoding ends in 0xBC, whose low nibble is 12.
	if m.Uint64()&0x0f != 12 {
		return nil, errors.New("bad padding")
	}

	u := new(big.Int)
	v := new(big.Int)
	tmp := new(big.Int)
	tmp2 := new(big.Int)

	// CRT split: residues of m modulo each factor.
	mQ := new(big.Int).Mod(m, rw.Q)
	mP := new(big.Int).Mod(m, rw.P)

	// u = mQ^DQ mod Q. If u^4 does not reproduce mQ, continue with -m on the
	// P side (mP := P - mP); the Q side keeps u as computed.
	u.Exp(mQ, rw.DQ, rw.Q)
	tmp.Exp(u, bigFour, rw.Q)
	if tmp.Cmp(mQ) != 0 {
		mP.Sub(rw.P, mP)
	}

	// With w = mP^DP mod P: tmp = w^4 * mP^2 mod P decides below whether the
	// TwoMP/TwoMQ multipliers apply, and v becomes w^3 * mP mod P.
	v.Exp(mP, rw.DP, rw.P)

	tmp2.Exp(mP, bigTwo, rw.P)
	tmp.Exp(v, bigFour, rw.P)

	tmp.Mul(tmp, tmp2)
	tmp.Mod(tmp, rw.P)

	v.Exp(v, bigThree, rw.P)
	v.Mul(v, mP)
	v.Mod(v, rw.P)

	if tmp.Cmp(mP) != 0 {
		u.Mul(u, rw.TwoMQ)
		u.Mod(u, rw.Q)

		v.Mul(v, rw.TwoMP)
		v.Mod(v, rw.P)
	}

	// Recombine: v = u + Q*(((v - u) mod P) * IQMP mod P), reduced mod N.
	// This is Garner's CRT formula if IQMP = Q^-1 mod P
	// (NOTE(review): confirm against key generation).
	v.Sub(v, u)
	v.Mod(v, rw.P)

	v.Mul(v, rw.IQMP)
	v.Mod(v, rw.P)
	v.Mul(v, rw.Q)

	v.Add(v, u)
	v.Mod(v, rw.PubKey.N)

	v.Exp(v, bigTwo, rw.PubKey.N)

	// Emit min(v, N-v): internalApply rejects values above (N-1)/2.
	tmp.Sub(rw.PubKey.N, v)
	if tmp.Cmp(v) >= 0 {
		tmp = v
	}

	result := tmp.Bytes()
	if keyLen := len(rw.PubKey.N.Bytes()); len(result) < keyLen {
		result = append(make([]byte, keyLen-len(result)), result...)
	}

	return result, nil
}

// addPSSRPadding builds the PSS-style encoded message that SignDigest signs.
//
// It follows the EMSA-PSS layout of RFC 8017, with SHA-256 for both the hash
// and MGF1 and a fresh random saltSize-byte salt:
//
//	H  = SHA-256(0x00 x 8 || digest || salt)
//	DB = 0x00 ... 0x00 || 0x01 || salt
//	EM = (DB XOR MGF1(H)) || H || signatureEnding
//
// The bits of EM at or above bit position bitlen(N)-1 are cleared, so EM < N.
// When bitlen(N)-1 is a multiple of 8, EM is one byte shorter than N and a
// zero byte is prepended, matching how internalVerifyPSSR drops the leading
// byte in that case. The result is always len(N.Bytes()) bytes long.
//
// digest is hashed into H as given; SignDigest ensures it is 32 bytes.
// Errors are returned for a nil N, for a key too short to hold the encoding,
// and for a failure of the random source.
func (rw *RWPublicKey) addPSSRPadding(digest []byte) ([]byte, error) {
	if rw.N == nil {
		return nil, errors.New("value of N can't be nil")
	}

	hLen := crypto.SHA256.Size()
	keyLen := len(rw.N.Bytes())
	msbits := (rw.N.BitLen() - 1) & 0x07

	emLen := keyLen
	if msbits == 0 {
		// The top byte of N's length carries no encoded bits; it stays zero.
		emLen--
	}

	// The encoding needs room for H, the salt, the 0x01 separator and the
	// trailer byte. A shorter key would make the offsets below negative.
	if emLen < hLen+saltSize+2 {
		return nil, errors.New("public key is too short")
	}

	salt := make([]byte, saltSize)
	if _, err := rand.Read(salt); err != nil {
		return nil, err
	}

	hashInstance := crypto.SHA256.New()
	for _, part := range [][]byte{make([]byte, 8), digest, salt} {
		if _, err := hashInstance.Write(part); err != nil {
			panic(err) // hash.Hash.Write never returns an error
		}
	}
	sum := hashInstance.Sum(nil)

	maskedLen := emLen - hLen - 1
	em := mgf1(sum, uint32(maskedLen))
	em = append(em, sum...)
	em = append(em, signatureEnding)

	// Fold DB into the mask. The 0x01 separator and the salt sit immediately
	// before H; everything earlier is zero padding, so it needs no XOR.
	start := emLen - hLen - saltSize - 2
	em[start] ^= 0x01
	for i, b := range salt {
		em[start+1+i] ^= b
	}

	if msbits != 0 {
		em[0] &= 0xff >> (8 - uint(msbits))
	}

	if emLen < keyLen {
		em = append([]byte{0x00}, em...)
	}

	return em, nil
}
