package odin

import (
	"crypto/rsa"
	"crypto/tls"
	"crypto/x509"
	"encoding/json"
	"encoding/xml"
	"errors"
	"fmt"
	"time"
)

// DefaultOdinEverywherePath is the filesystem path that SwitchToOdinEverywhere
// selects when the caller supplies no override.
const DefaultOdinEverywherePath = "/var/odin/"

// DefaultOdinEndpoint is the daemon query URL that SwitchToDaemon selects when
// the caller supplies no override.
const DefaultOdinEndpoint = "http://localhost:2009/query"

// Shared backend instances; the Switch* functions point impl at one of them.
var odinDaemonImplementation = &odinDaemon{}

var odinEverywhereImplementation = &odinEverywhere{}

// impl is the backend that every package-level retrieval function delegates
// to. It starts out as the daemon implementation.
//
// Odin Daemon runs in Prod and Odin Everywhere runs in EC2. Because the two
// modes are mutually exclusive, a single package-level selection is enough.
// NOTE: the Switch* functions assign impl without any synchronization.
var impl OdinImpl = odinDaemonImplementation

// SwitchToOdinEverywhere
// Optional argument is to override the default path on the filesystem
func SwitchToOdinEverywhere(params ...string) {
	if len(params) > 0 {
		// First parameter is a path override
		odinEverywherePath = params[0]
	} else {
		odinEverywherePath = DefaultOdinEverywherePath
	}

	impl = odinEverywhereImplementation
}

// Switch to odin daemon implementation
// Optional argument is to override the default Odin endpoint
func SwitchToDaemon(params ...string) {
	if len(params) > 0 {
		// First parameter is an endpoint override
		Endpoint = params[0]
	} else {
		Endpoint = DefaultOdinEndpoint
	}

	impl = odinDaemonImplementation
}

// SetDaemonTimeout sets the Timeout of the package-level client. Presumably
// this is the HTTP client the daemon implementation uses to query Endpoint
// (TODO confirm). It does not change which implementation is active.
func SetDaemonTimeout(timeout time.Duration) {
	client.Timeout = timeout
}

// SwitchToCustom installs customImpl as the backend behind every
// package-level retrieval function, replacing the daemon or Odin Everywhere
// implementation. customImpl is stored exactly as given and is not checked
// for nil.
func SwitchToCustom(customImpl OdinImpl) {
	impl = customImpl
}

// OdinImpl is a backend that can fetch raw Odin material. The package-level
// functions delegate to whichever implementation is currently selected.
//
// Retrieve returns the material of materialType in materialSet. A nil serial
// requests the latest serial; a non-nil serial requests that specific one.
type OdinImpl interface {
	Retrieve(materialSet, materialType string, serial *int64) (*RawMaterial, error)
}

// Material is the metadata attached to every decoded key material. Serial is
// the material's Odin serial number. NotBefore and NotAfter bound its
// validity window and are copied from the corresponding RawMaterial
// timestamps during decoding.
type Material struct {
	Serial    int64
	NotBefore time.Time
	NotAfter  time.Time
}

// PrivateKeyMaterial is an RSA private key decoded from a PKCS#8-wrapped
// PrivateKey material, together with that material's metadata.
type PrivateKeyMaterial struct {
	PrivateKey *rsa.PrivateKey
	Material
}

// PublicKeyMaterial is an RSA public key decoded from a PKIX-encoded
// PublicKey material, together with that material's metadata.
type PublicKeyMaterial struct {
	PublicKey *rsa.PublicKey
	Material
}

// SymmetricKeyMaterial holds the raw bytes of a SymmetricKey material,
// together with that material's metadata.
type SymmetricKeyMaterial struct {
	SymmetricKey []byte
	Material
}

// decodeSymmetricKey packages the bytes of a SymmetricKey material with its
// serial and validity window. The returned key shares its backing array with
// c.MaterialData; no copy is made. This function never fails. The error result
// exists only for symmetry with the other decode helpers.
func decodeSymmetricKey(c *RawMaterial) (*SymmetricKeyMaterial, error) {
	meta := Material{
		Serial:    c.MaterialSerial,
		NotBefore: c.NotBefore.Time,
		NotAfter:  c.NotAfter.Time,
	}
	return &SymmetricKeyMaterial{SymmetricKey: c.MaterialData, Material: meta}, nil
}

// SymmetricKeyBySerial fetches the SymmetricKey material with the given
// serial from materialSet. Retrieval errors are returned unchanged.
func SymmetricKeyBySerial(materialSet string, serial int64) (*SymmetricKeyMaterial, error) {
	raw, err := impl.Retrieve(materialSet, "SymmetricKey", &serial)
	if err != nil {
		return nil, err
	}
	return decodeSymmetricKey(raw)
}

// SymmetricKey fetches the latest SymmetricKey material from materialSet.
// Retrieval errors are returned unchanged.
func SymmetricKey(materialSet string) (*SymmetricKeyMaterial, error) {
	raw, err := impl.Retrieve(materialSet, "SymmetricKey", nil)
	if err != nil {
		return nil, err
	}
	return decodeSymmetricKey(raw)
}

// CredentialPair returns the latest principal and credential stored in
// materialSet, each decoded as a string.
func CredentialPair(materialSet string) (principal, credential string, err error) {
	principal, credential, _, _, err = credentialPairBySerial(materialSet, nil, nil)
	return principal, credential, err
}

// CredentialPairAndSerial returns the latest principal and credential stored
// in materialSet, together with the serial each was retrieved at.
func CredentialPairAndSerial(materialSet string) (principal, credential string, principalSerial, credentialSerial int64, err error) {
	principal, credential, principalSerial, credentialSerial, err = credentialPairBySerial(materialSet, nil, nil)
	return principal, credential, principalSerial, credentialSerial, err
}

// CredentialPairBySerial returns the principal and credential stored in
// materialSet at the requested serials, each decoded as a string.
func CredentialPairBySerial(materialSet string, principalSerial int64, credentialSerial int64) (principal, credential string, err error) {
	principal, credential, _, _, err = credentialPairBySerial(materialSet, &principalSerial, &credentialSerial)
	return principal, credential, err
}

func credentialPairBySerial(materialSet string, principalSerial *int64, credentialSerial *int64) (principal, credential string, pSerial, cSerial int64, err error) {
	var c, p *RawMaterial
	c, err = impl.Retrieve(materialSet, "Credential", principalSerial)
	if err != nil {
		return
	}
	p, err = impl.Retrieve(materialSet, "Principal", credentialSerial)
	if err != nil {
		return
	}
	credential = string(c.MaterialData)
	cSerial = c.MaterialSerial
	principal = string(p.MaterialData)
	pSerial = p.MaterialSerial
	return
}

func Certificate(materialSet string) (certificate *tls.Certificate, err error) {
	var c, p *RawMaterial
	c, err = impl.Retrieve(materialSet, "Certificate", nil)
	if err != nil {
		return
	}
	p, err = impl.Retrieve(materialSet, "PrivateKey", nil)
	if err != nil {
		return
	}
	certDER := c.MaterialData
	keyDER := p.MaterialData
	var key interface{}
	key, err = x509.ParsePKCS8PrivateKey(keyDER)
	if err != nil {
		return
	}
	var leaf *x509.Certificate
	leaf, err = x509.ParseCertificate(certDER)
	if err != nil {
		return
	}
	certificate = &tls.Certificate{
		Certificate: [][]byte{certDER},
		PrivateKey:  key,
		Leaf:        leaf,
	}
	return
}

// decodePrivateKey parses a PrivateKey material as a PKCS#8-wrapped RSA
// private key and attaches the material's serial and validity window. Parse
// errors are returned as-is. A PKCS#8 key of any other algorithm is rejected.
func decodePrivateKey(p *RawMaterial) (*PrivateKeyMaterial, error) {
	parsed, err := x509.ParsePKCS8PrivateKey(p.MaterialData)
	if err != nil {
		return nil, err
	}
	rsaKey, isRSA := parsed.(*rsa.PrivateKey)
	if !isRSA {
		return nil, errors.New("found non-RSA private key in PKCS#8 wrapping")
	}
	return &PrivateKeyMaterial{
		PrivateKey: rsaKey,
		Material: Material{
			Serial:    p.MaterialSerial,
			NotBefore: p.NotBefore.Time,
			NotAfter:  p.NotAfter.Time,
		},
	}, nil
}

// decodePublicKey parses a PublicKey material as a PKIX (SubjectPublicKeyInfo)
// DER-encoded RSA public key and attaches the material's serial and validity
// window. Parse errors are returned as-is. A key of any other algorithm is
// rejected with an error that names the type actually found.
func decodePublicKey(p *RawMaterial) (*PublicKeyMaterial, error) {
	parsed, err := x509.ParsePKIXPublicKey(p.MaterialData)
	if err != nil {
		return nil, err
	}
	key, ok := parsed.(*rsa.PublicKey)
	if !ok {
		// The message is lower-case per Go error-string convention, mirrors
		// decodePrivateKey, and names the concrete type so a misprovisioned
		// material set is easy to diagnose.
		return nil, fmt.Errorf("found non-RSA public key in PKIX wrapping: %T", parsed)
	}
	return &PublicKeyMaterial{
		PublicKey: key,
		Material: Material{
			Serial:    p.MaterialSerial,
			NotBefore: p.NotBefore.Time,
			NotAfter:  p.NotAfter.Time,
		},
	}, nil
}

// PrivateKeyBySerial fetches the PrivateKey material with the given serial
// from materialSet and decodes it as a PKCS#8-wrapped RSA private key.
func PrivateKeyBySerial(materialSet string, serial int64) (*PrivateKeyMaterial, error) {
	raw, err := impl.Retrieve(materialSet, "PrivateKey", &serial)
	if err != nil {
		return nil, err
	}
	return decodePrivateKey(raw)
}

// PrivateKey fetches the latest PrivateKey material from materialSet and
// decodes it as a PKCS#8-wrapped RSA private key.
func PrivateKey(materialSet string) (*PrivateKeyMaterial, error) {
	raw, err := impl.Retrieve(materialSet, "PrivateKey", nil)
	if err != nil {
		return nil, err
	}
	return decodePrivateKey(raw)
}

// PublicKeyBySerial fetches the PublicKey material with the given serial from
// materialSet and decodes it as a PKIX-encoded RSA public key.
func PublicKeyBySerial(materialSet string, serial int64) (*PublicKeyMaterial, error) {
	raw, err := impl.Retrieve(materialSet, "PublicKey", &serial)
	if err != nil {
		return nil, err
	}
	return decodePublicKey(raw)
}

// PublicKey fetches the latest PublicKey material from materialSet and
// decodes it as a PKIX-encoded RSA public key.
func PublicKey(materialSet string) (*PublicKeyMaterial, error) {
	raw, err := impl.Retrieve(materialSet, "PublicKey", nil)
	if err != nil {
		return nil, err
	}
	return decodePublicKey(raw)
}

// RawMaterial is a single Odin material as returned by an OdinImpl, before
// any type-specific decoding. The struct tags let it be decoded from either
// JSON or XML. NotBefore and NotAfter are stored as epoch-second timestamps.
//
// MaterialData holds the material bytes that the decode helpers in this
// package consume. Base64MaterialData receives the JSON "materialData" field.
// Nothing visible here decodes that field into MaterialData, so presumably the
// JSON-backed implementation does (TODO confirm).
type RawMaterial struct {
	MaterialType   string          `json:"materialType" xml:"type,attr"`
	NotBefore      RawMaterialTime `json:"notBefore" xml:"NotBefore"`
	NotAfter       RawMaterialTime `json:"notAfter" xml:"NotAfter"`
	MaterialData   []byte
	MaterialSerial int64  `json:"materialSerial" xml:"serial,attr"`
	MaterialName   string `json:"materialName"`

	Active bool `xml:"Active"`

	Base64MaterialData string `json:"materialData"`
}

// IsActive reports whether the material is flagged Active and the current
// time lies within its validity window [NotBefore, NotAfter]. Both ends of
// the window are inclusive.
func (x *RawMaterial) IsActive() bool {
	now := time.Now()
	// "Not before" and "not after" are inclusive bounds, as in X.509 validity.
	// Only instants strictly outside the window are rejected.
	return x.Active && !now.Before(x.NotBefore.Time) && !now.After(x.NotAfter.Time)
}

// RawMaterialTime wraps time.Time so that timestamps encoded as an integer
// number of seconds since the Unix epoch can be decoded directly from XML or
// JSON.
type RawMaterialTime struct {
	time.Time
}

// UnmarshalXML decodes an element whose character data is an integer count
// of seconds since the Unix epoch.
func (m *RawMaterialTime) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
	var epochTime int64
	if err := d.DecodeElement(&epochTime, &start); err != nil {
		return err
	}
	m.Time = time.Unix(epochTime, 0)
	return nil
}

// UnmarshalJSON decodes a JSON number of seconds since the Unix epoch.
//
// A JSON null is a no-op and leaves m unchanged, following the encoding/json
// Unmarshaler convention. Previously, null decoded as 0 and silently became
// the epoch instant (1970-01-01) instead of "no time".
func (m *RawMaterialTime) UnmarshalJSON(b []byte) error {
	if string(b) == "null" {
		return nil
	}
	var epochTime int64
	if err := json.Unmarshal(b, &epochTime); err != nil {
		return err
	}
	m.Time = time.Unix(epochTime, 0)
	return nil
}

// Serial is an Odin material serial number. Passing a Serial as an option to
// Retrieve requests that specific serial instead of the latest one.
type Serial int64

// Retrieve fetches raw Odin material directly. It is useful for cases that
// the typed helpers in this package do not cover.
//
// The only supported option is a Serial, which selects a specific material
// serial. Without one, the material with the latest serial is retrieved. If
// several Serial options are given, the last one wins. An option of any other
// type is a programming error and causes a panic.
//
// For example:
//
//	// retrieve the latest credential
//	odin.Retrieve("cool-service-RSA-Chain", "Credential")
//
//	// retrieve the credential for given serial
//	odin.Retrieve("cool-service-RSA-Chain", "Credential", odin.Serial(5))
func Retrieve(materialSet, materialType string, opts ...interface{}) (*RawMaterial, error) {
	var serial *int64
	for _, opt := range opts {
		s, ok := opt.(Serial)
		if !ok {
			panic(fmt.Errorf("unsupported material retrieval option type: %T", opt))
		}
		v := int64(s)
		serial = &v
	}
	return impl.Retrieve(materialSet, materialType, serial)
}
