package twitchclient

import (
	"context"
	"crypto/tls"
	"fmt"
	"net"
	"net/http/httptrace"
	"regexp"
	"time"

	"github.com/cactus/go-statsd-client/statsd"
)

// Final stat-name segments recording the outcome of a traced step.
const (
	// failureSuffix marks a step that returned an error.
	failureSuffix = "failure"
	// successSuffix marks a step that completed without error.
	successSuffix = "success"
)

// Scratch state shared between the paired httptrace hooks below. withStats
// allocates one value of each type per call and passes pointers to the hook
// constructors, so a "start" hook can record a timestamp that the matching
// "done" hook later reads.
type (
	// getConnectionData holds the time the GetConn hook fired.
	getConnectionData struct {
		getConnectionStart time.Time
	}

	// dialData holds the time the ConnectStart hook fired.
	dialData struct {
		dialStart time.Time
	}

	// dnsData holds the host used as a stat-name segment (set by the GetConn
	// and DNSStart hooks) and the time the DNSStart hook fired.
	dnsData struct {
		host     string
		dnsStart time.Time
	}

	// tlsData holds the time the TLSHandshakeStart hook fired.
	tlsData struct {
		tlsHandshakeStart time.Time
	}
)

// withStats returns a copy of ctx carrying an httptrace.ClientTrace whose hooks
// report connection-acquisition, DNS, dial, TLS-handshake and request-write
// metrics to stats.
//
// dnsPrefix is the leading stat-name segment for DNS metrics, logger receives
// host:port parse errors from the GetConn hook, and sampleRate is forwarded to
// every hook constructor that accepts one. The scratch state is allocated fresh
// on each call and shared between each pair of start/done hooks.
//
// NOTE(review): the net/http/httptrace documentation states that ClientTrace
// hooks may be called concurrently and after the request has finished; the
// scratch state here is not synchronized — confirm this is acceptable.
func withStats(ctx context.Context, stats statsd.Statter, dnsPrefix string, logger Logger, sampleRate float32) context.Context {
	var (
		dnsState     dnsData
		tlsState     tlsData
		dialState    dialData
		getConnState getConnectionData
	)

	trace := &httptrace.ClientTrace{
		GetConn:           hookGetConn(logger, &dnsState, &getConnState),
		GotConn:           hookGotConn(stats, &dnsState, &getConnState, sampleRate),
		DNSStart:          hookDNSStart(&dnsState),
		DNSDone:           hookDNSDone(stats, dnsPrefix, &dnsState, sampleRate),
		ConnectStart:      hookConnectStart(&dialState),
		ConnectDone:       hookConnectDone(stats, &dnsState, &dialState, sampleRate),
		WroteRequest:      hookWroteRequest(stats, &dnsState),
		TLSHandshakeStart: hookTLSHandshakeStart(&tlsState),
		TLSHandshakeDone:  hookTLSHandshakeDone(stats, &dnsState, &tlsState, sampleRate),
	}
	return httptrace.WithClientTrace(ctx, trace)
}

// hookGetConn returns a GetConn hook that stamps the start of connection
// acquisition and stores the host portion of hostPort in dnsData.host for use
// in stat names. If hostPort cannot be split into host and port, the error is
// passed to logger and dnsData.host is left unchanged (the start time has
// already been recorded).
func hookGetConn(logger Logger, dnsData *dnsData, getConnectionData *getConnectionData) func(hostPort string) {
	onGetConn := func(hostPort string) {
		getConnectionData.getConnectionStart = time.Now()

		host, _, splitErr := net.SplitHostPort(hostPort)
		if splitErr == nil {
			dnsData.host = host
			return
		}
		logger.Log(splitErr)
	}
	return onGetConn
}

// hookGotConn returns a GotConn hook that reports, under
// get_connection.<host> (host sanitized by statf), all with sampleRate:
//
//   - success: time elapsed since the GetConn hook fired;
//   - reused: a counter, when the connection was reused;
//   - idle_time and was_idle: the idle duration and a counter, when the
//     connection had been idle.
//
// Nothing is reported unless a start time and a host were recorded, so
// inaccurate stats are not sent when hookGetConn or hookDNSStart did not run.
func hookGotConn(stats statsd.Statter, dnsData *dnsData, getConnectionData *getConnectionData, sampleRate float32) func(httptrace.GotConnInfo) {
	return func(info httptrace.GotConnInfo) {
		missingState := getConnectionData == nil || dnsData == nil
		if missingState || getConnectionData.getConnectionStart.IsZero() || dnsData.host == "" {
			return
		}

		host := dnsData.host
		waited := time.Since(getConnectionData.getConnectionStart)
		_ = stats.TimingDuration(statf("get_connection.%s.success", host), waited, sampleRate)

		if info.Reused {
			_ = stats.Inc(statf("get_connection.%s.reused", host), 1, sampleRate)
		}

		if !info.WasIdle {
			return
		}
		_ = stats.TimingDuration(statf("get_connection.%s.idle_time", host), info.IdleTime, sampleRate)
		_ = stats.Inc(statf("get_connection.%s.was_idle", host), 1, sampleRate)
	}
}

// hookTLSHandshakeStart returns a TLSHandshakeStart hook that records the
// current time in tlsData so hookTLSHandshakeDone can measure the handshake.
func hookTLSHandshakeStart(tlsData *tlsData) func() {
	markStart := func() { tlsData.tlsHandshakeStart = time.Now() }
	return markStart
}

// hookTLSHandshakeDone returns a TLSHandshakeDone hook that reports under
// tls_handshake.<host> (host sanitized by statf):
//
//   - a failure counter, sent with a sample rate of 1, when the handshake
//     returned an error;
//   - otherwise the time since hookTLSHandshakeStart fired, as a success
//     timing sent with sampleRate.
//
// Nothing is reported unless a handshake start time and a host were recorded.
func hookTLSHandshakeDone(stats statsd.Statter, dnsData *dnsData, tlsData *tlsData, sampleRate float32) func(tls.ConnectionState, error) {
	return func(_ tls.ConnectionState, handshakeErr error) {
		ready := dnsData != nil && tlsData != nil && !tlsData.tlsHandshakeStart.IsZero() && dnsData.host != ""
		if !ready {
			return
		}

		elapsed := time.Since(tlsData.tlsHandshakeStart)
		if handshakeErr == nil {
			_ = stats.TimingDuration(statf("tls_handshake.%s.%s", dnsData.host, successSuffix), elapsed, sampleRate)
			return
		}
		_ = stats.Inc(statf("tls_handshake.%s.%s", dnsData.host, failureSuffix), 1, 1)
	}
}

// hookConnectStart returns a ConnectStart hook that records the current time in
// dialData so hookConnectDone can measure the dial. The network and address
// arguments are ignored.
func hookConnectStart(dialData *dialData) func(network, addr string) {
	markStart := func(_, _ string) { dialData.dialStart = time.Now() }
	return markStart
}

// hookConnectDone returns a ConnectDone hook that reports under dial.<host>
// (host sanitized by statf):
//
//   - a failure counter, sent with a sample rate of 1, when the dial errored;
//   - otherwise the time since hookConnectStart fired, as a success timing
//     sent with sampleRate.
//
// Nothing is reported unless a dial start time and a host were recorded.
func hookConnectDone(stats statsd.Statter, dnsData *dnsData, dialData *dialData, sampleRate float32) func(network, addr string, err error) {
	return func(_, _ string, dialErr error) {
		missingState := dialData == nil || dnsData == nil
		if missingState || dialData.dialStart.IsZero() || dnsData.host == "" {
			return
		}

		if dialErr == nil {
			elapsed := time.Since(dialData.dialStart)
			_ = stats.TimingDuration(statf("dial.%s.%s", dnsData.host, successSuffix), elapsed, sampleRate)
			return
		}
		_ = stats.Inc(statf("dial.%s.%s", dnsData.host, failureSuffix), 1, 1)
	}
}

// hookWroteRequest returns a WroteRequest hook that increments
// req_write.<host>.failure (host sanitized by statf, sample rate 1) when
// writing the request failed. Successful writes are not reported, and nothing
// is reported when no host has been recorded.
func hookWroteRequest(stats statsd.Statter, dnsData *dnsData) func(httptrace.WroteRequestInfo) {
	return func(info httptrace.WroteRequestInfo) {
		if info.Err == nil || dnsData == nil || dnsData.host == "" {
			return
		}
		_ = stats.Inc(statf("req_write.%s.failure", dnsData.host), 1, 1)
	}
}

// hookDNSStart returns a DNSStart hook that records the lookup start time and
// the host being resolved in dnsData, for use by hookDNSDone and the other
// stat-emitting hooks.
func hookDNSStart(dnsData *dnsData) func(httptrace.DNSStartInfo) {
	return func(info httptrace.DNSStartInfo) {
		dnsData.dnsStart, dnsData.host = time.Now(), info.Host
	}
}

// hookDNSDone returns a DNSDone hook that reports under <dnsPrefix>.<host>
// (both segments sanitized by statf):
//
//   - a failure counter, sent with a sample rate of 1, when the lookup errored;
//   - otherwise the time since hookDNSStart fired, as a success timing sent
//     with sampleRate;
//   - additionally a coalesced counter, sent with sampleRate, when the lookup
//     result was shared with another caller.
//
// Nothing is reported unless a lookup start time and a host were recorded.
func hookDNSDone(stats statsd.Statter, dnsPrefix string, dnsData *dnsData, sampleRate float32) func(httptrace.DNSDoneInfo) {
	return func(info httptrace.DNSDoneInfo) {
		if dnsData == nil || dnsData.dnsStart.IsZero() || dnsData.host == "" {
			return
		}
		elapsed := time.Since(dnsData.dnsStart)
		host := dnsData.host

		switch {
		case info.Err != nil:
			_ = stats.Inc(statf("%s.%s.%s", dnsPrefix, host, failureSuffix), 1, 1)
		default:
			_ = stats.TimingDuration(statf("%s.%s.%s", dnsPrefix, host, successSuffix), elapsed, sampleRate)
		}

		if info.Coalesced {
			_ = stats.Inc(statf("%s.%s.coalesced", dnsPrefix, host), 1, sampleRate)
		}
	}
}

// statf formats a stat name with fmt.Sprintf after passing every argument
// whose dynamic type is string through sanitizeStat. Other arguments (ints,
// durations, named string types, ...) are formatted unchanged, and the format
// string itself is not sanitized. The caller's arguments are never modified.
// For example:
//
//	statf("mystat.%s.%d", "fo&8*o", 200) returns "mystat.fo_8_o.200"
func statf(format string, a ...interface{}) string {
	// Sanitize into a fresh slice: when called as statf(format, args...), a
	// shares its backing array with the caller's slice, so rewriting it in
	// place would silently modify the caller's data.
	args := make([]interface{}, len(a))
	for i, arg := range a {
		if s, ok := arg.(string); ok {
			arg = sanitizeStat(s)
		}
		args[i] = arg
	}

	return fmt.Sprintf(format, args...)
}

// invalidStatChars matches each maximal run of characters that are not ASCII
// letters, digits, or '-'.
var invalidStatChars = regexp.MustCompile(`[^A-Za-z0-9-]+`)

// sanitizeStat replaces every run of characters matched by invalidStatChars
// with a single underscore. Because '.' is replaced as well, a dotted value
// such as a hostname becomes a single '.'-free stat-name segment. For example:
//
//	sanitizeStat("fo&8*o") returns "fo_8_o"
//	sanitizeStat("api.twitch.tv") returns "api_twitch_tv"
func sanitizeStat(stat string) string {
	return invalidStatChars.ReplaceAllString(stat, "_")
}
