package caller

import (
	"io"
	"net/http"
	"time"

	"github.com/cactus/go-statsd-client/statsd"

	"code.justin.tv/sse/malachai/pkg/config"
	"code.justin.tv/sse/malachai/pkg/internal/inventory"
	"code.justin.tv/sse/malachai/pkg/log"
	"code.justin.tv/sse/malachai/pkg/registration"
	"code.justin.tv/sse/malachai/pkg/s2s"
	"code.justin.tv/sse/malachai/pkg/s2s/internal"
	"code.justin.tv/sse/malachai/pkg/signature"
	"code.justin.tv/systems/sandstorm/manager"
	"code.justin.tv/systems/sandstorm/queue"
)

// Statsd metric names reported while signing outgoing requests.
const (
	// metricNameRequestSignatureAttempted is the counter for signing attempts.
	metricNameRequestSignatureAttempted = "request.signature.attempted"
	// metricNameRequestSignatureSuccess is the counter for successful signatures.
	metricNameRequestSignatureSuccess = "request.signature.success"
	// metricNameRequestSignatureFailure is the counter for failed signatures.
	metricNameRequestSignatureFailure = "request.signature.failure"
	// metricNameRequestSigned is the timing metric for how long signing took.
	metricNameRequestSigned = "request.signed"
)

// RoundTripper implements http.RoundTripper. It signs requests using the
// configured caller service's private key, retrieved from sandstorm, then
// forwards the signed request to an inner http.RoundTripper.
type RoundTripper struct {
	// inner receives each request after it has been signed.
	inner     http.RoundTripper
	// signer adds the signature headers to outgoing requests.
	signer    *signature.Signer
	// sandstorm is the secret manager the private key is read from; Close
	// stops its update listener.
	sandstorm manager.API
	// statter receives the request-signing metrics.
	statter   statsd.Statter
	// logger receives debug/error output from signing and sending.
	logger    log.S2SLogger
}

// NewRoundTripper returns a signing RoundTripper whose signed requests are
// forwarded through http.DefaultTransport. It is shorthand for calling
// NewWithCustomRoundTripper with http.DefaultTransport as the inner transport.
func NewRoundTripper(callerName string, cfg *Config, logger log.S2SLogger) (rt *RoundTripper, err error) {
	return NewWithCustomRoundTripper(callerName, cfg, http.DefaultTransport, logger)
}

// NewWithCustomRoundTripper builds a signing RoundTripper that forwards each
// signed request to inner. A nil inner falls back to http.DefaultTransport, a
// nil logger to a no-op logger, and a nil cfg to a zero Config.
//
// Setup runs these steps and returns the first error encountered:
//   - sets cfg.callerName and fills cfg defaults (a non-nil cfg supplied by
//     the caller is mutated in place),
//   - looks up the caller's service registration,
//   - records this instance in the inventory,
//   - builds a sandstorm manager and, unless cfg.DisableSecretRotationListener
//     is set, starts listening for sandstorm updates,
//   - builds the statsd client,
//   - loads the private key once so it is cached.
//
// If a step fails after the sandstorm listener has started, the listener is
// stopped before returning. Without this, it would leak, because the caller
// never receives a RoundTripper whose Close could stop it.
func NewWithCustomRoundTripper(callerName string, cfg *Config, inner http.RoundTripper, logger log.S2SLogger) (rt *RoundTripper, err error) {
	if logger == nil {
		logger = &log.NoOpLogger{}
	}
	if inner == nil {
		// Without a transport, the first RoundTrip would panic; mirror NewRoundTripper.
		inner = http.DefaultTransport
	}

	if cfg == nil {
		cfg = &Config{}
	}
	cfg.callerName = callerName
	if err = cfg.FillDefaults(); err != nil {
		return
	}

	logger.Debugf("config options: %#v", cfg)
	logger.Debugf("configuring registrar with: %#v", cfg.RegistrationConfig)
	reg, err := registration.New(cfg.RegistrationConfig, logger)
	if err != nil {
		return
	}

	logger.Debugf("retrieving service registration for caller name '%s'", cfg.callerName)
	caller, err := reg.Get(cfg.callerName)
	if err != nil {
		return
	}
	logger.Debugf("service registered is: %#v", caller)

	// inventory.New returns no error, so there is nothing to check here.
	inventoryClient := inventory.New(&inventory.Config{
		Environment: cfg.Environment,
		RoleArn:     caller.RoleArn,
	})

	instanceID := internal.NewInstanceID()
	err = inventoryClient.Put(&inventory.Instance{
		ServiceID:   caller.ID,
		ServiceName: caller.Name,
		InstanceID:  instanceID,
		Version:     s2s.Version,
	})
	if err != nil {
		return
	}

	logger.Debugf("configuring sandstorm manager using caller role '%s', sandstorm role '%s'", caller.RoleArn, caller.SandstormRoleArn)
	sandstorm := manager.New(manager.Config{
		AWSConfig: config.AWSConfig(
			cfg.Region,
			cfg.roleArn,
			caller.RoleArn,
			caller.SandstormRoleArn),
		KeyID:     cfg.SandstormKMSKeyID,
		TableName: cfg.SandstormSecretsTableName,
		Queue: queue.Config{
			TopicArn: cfg.SandstormTopicArn,
		},
		InstanceID: instanceID,
	})

	if !cfg.DisableSecretRotationListener {
		logger.Debug("listening for sandstorm updates")
		if err = sandstorm.ListenForUpdates(); err != nil {
			logger.Error("error starting sandstorm queue listener: " + err.Error())
			return
		}
		// err is the named result, so this observes the final outcome of the
		// constructor and tears the listener down on any later failure.
		defer func() {
			if err == nil {
				return
			}
			if stopErr := sandstorm.StopListeningForUpdates(); stopErr != nil {
				logger.Error("failed to stop sandstorm queue listener: " + stopErr.Error())
			}
		}()
	}

	statter, err := cfg.statsd(cfg.callerName)
	if err != nil {
		logger.Error("failed to initialize statsd client: " + err.Error())
		return
	}

	privateKey := &privateKeyStorer{
		Sandstorm:           sandstorm,
		SandstormSecretName: caller.SandstormSecretName,
		Logger:              logger,
	}
	// Load the key into the cache now so a misconfiguration fails here
	// rather than on the first request.
	if _, _, err = privateKey.Key(); err != nil {
		return
	}

	logger.Debug("configuring roundtripper")
	rt = &RoundTripper{
		inner: inner,
		signer: &signature.Signer{
			CallerID:   caller.ID,
			Method:     defaultSigningMethod,
			PrivateKey: privateKey.Key,
		},
		sandstorm: sandstorm,
		logger:    logger,
		statter:   statter,
	}
	return
}

// Close cleans up the sandstorm notification pipeline by stopping the
// sandstorm update listener. The "stopped" message is logged only after the
// listener has actually stopped; if stopping fails, the error is returned
// and nothing is logged here.
func (rt *RoundTripper) Close() (err error) {
	if err = rt.sandstorm.StopListeningForUpdates(); err != nil {
		return
	}
	rt.logger.Infof("stopped listening for sandstorm updates")
	return
}

// SetLogger replaces the logger used by the RoundTripper. A nil logger is
// replaced with a no-op logger, matching the constructor's default, so that
// later calls to RoundTrip do not dereference a nil logger.
func (rt *RoundTripper) SetLogger(logger log.S2SLogger) {
	if logger == nil {
		logger = &log.NoOpLogger{}
	}
	rt.logger = logger
}

// SetInnerRoundTripper replaces the RoundTripper that receives signed
// requests. A nil value falls back to http.DefaultTransport (the transport
// NewRoundTripper uses), so later calls to RoundTrip do not panic.
func (rt *RoundTripper) SetInnerRoundTripper(inner http.RoundTripper) {
	if inner == nil {
		inner = http.DefaultTransport
	}
	rt.inner = inner
}

// RoundTrip implements http.RoundTripper. It signs req, reports signing
// metrics, and forwards the signed request to the inner RoundTripper.
//
// Requests with a body have that body wrapped in a seekable reader first, so
// the signer can consume it and it can be rewound before sending. If
// RoundTrip fails before handing the request to the inner RoundTripper, it
// closes the caller's request body, as the http.RoundTripper contract
// requires.
//
// NOTE(review): signing adds headers to req and the body is replaced by a
// wrapper, so the caller's *http.Request is mutated. The http.RoundTripper
// contract asks implementations not to modify the request; consider signing
// a clone instead.
func (rt *RoundTripper) RoundTrip(req *http.Request) (resp *http.Response, err error) {
	startTime := time.Now()

	if req.Body == nil {
		rt.logger.Debug("signing an empty body request")
		err = rt.signWithMetrics(req, startTime)
	} else {
		err = rt.signRequestWithBody(req, startTime)
	}
	if err != nil {
		return
	}

	resp, err = rt.inner.RoundTrip(req)
	if err != nil {
		rt.logger.Error("failed to send request: " + err.Error())
	}
	return
}

// signRequestWithBody makes req's body seekable, signs req, and rewinds the
// body so the inner RoundTripper sends it from the start. On any error the
// caller's original body is closed, because the inner RoundTripper will
// never take ownership of it.
func (rt *RoundTripper) signRequestWithBody(req *http.Request, startTime time.Time) (err error) {
	originalBody := req.Body
	defer func() {
		if err != nil {
			// The close error is irrelevant next to err, which is what
			// callers need to see.
			_ = originalBody.Close()
		}
	}()

	seeker, err := internal.WrapRequestBodyWithSeeker(req, 0)
	if err != nil {
		rt.logger.Error("failure wrapping request body: " + err.Error())
		return
	}

	rt.logger.Debugf("signing a request with body, content-length: %d", req.ContentLength)
	if err = rt.signWithMetrics(req, startTime); err != nil {
		return
	}

	// Signing read the body; rewind it so the full payload is sent.
	if _, err = seeker.Seek(0, io.SeekStart); err != nil {
		rt.logger.Error("failed to build signed request: " + err.Error())
	}
	return
}

// signWithMetrics signs req and reports the outcome to statsd. It increments
// the "attempted" counter before signing, then either the "failure" counter
// or the "success" counter plus the elapsed time since startTime. Failures
// to emit stats are logged and otherwise ignored; the returned error is the
// signing error.
func (rt *RoundTripper) signWithMetrics(req *http.Request, startTime time.Time) error {
	rt.incStat(metricNameRequestSignatureAttempted)

	if err := rt.signer.SignRequest(req); err != nil {
		rt.logger.Error("failed to sign request: " + err.Error())
		rt.incStat(metricNameRequestSignatureFailure)
		return err
	}

	rt.incStat(metricNameRequestSignatureSuccess)
	if err := rt.statter.TimingDuration(metricNameRequestSigned, time.Since(startTime), 1.0); err != nil {
		rt.logger.Error("failed to send stat: " + err.Error())
	}
	return nil
}

// incStat increments the named statsd counter by one. An emission error is
// logged, not returned.
func (rt *RoundTripper) incStat(name string) {
	if err := rt.statter.Inc(name, 1, 1.0); err != nil {
		rt.logger.Error("failed to send stat: " + err.Error())
	}
}

// SignRequest signs req with the configured signer, adding the signature
// header values to it. Unlike RoundTrip, it neither reports metrics nor sends
// the request.
func (rt *RoundTripper) SignRequest(req *http.Request) (err error) {
	err = rt.signer.SignRequest(req)
	return err
}

// SignRequestWithHashedBody signs req using a precomputed body hash instead
// of reading the body, and adds the signature header values to req. Like
// SignRequest, it neither reports metrics nor sends the request.
func (rt *RoundTripper) SignRequestWithHashedBody(req *http.Request, hashedBody []byte) (err error) {
	err = rt.signer.SignRequestWithHashedBody(req, hashedBody)
	return err
}
