package proxy

/* This file contains functions for validating StatsD lines and routing valid lines to forwarders. */

import (
	"bytes"
	"crypto/sha1"
	"errors"
	"fmt"
	"hash/fnv"
	"regexp"
	"strconv"
	"sync"
	"time"

	cache "github.com/patrickmn/go-cache"
)

// Names of the built-in hash strategies; use them as keys into HashStrategies.
const (
	// SHA1 selects a hash made from the first four bytes of the SHA-1 digest.
	// It exists for legacy deployments; prefer FNV1A otherwise.
	SHA1 = "sha1"
	// FNV1A selects the 32-bit FNV-1a hash.
	FNV1A = "fnv1a"
)

// Hasher maps a metric key to a 32-bit hash that is used to pick a forwarder.
type Hasher func([]byte) uint32

// HashStrategies is a collection of hash functions implementing Hasher, keyed by
// strategy name (SHA1 or FNV1A). FNV1A should be used in all cases except legacy
// ones: it is fast and consistent.
var HashStrategies = map[string]Hasher{
	SHA1: func(metric []byte) uint32 {
		sum := sha1.Sum(metric)
		// Fold the first four digest bytes into a little-endian uint32.
		return uint32(sum[0]) | uint32(sum[1])<<8 | uint32(sum[2])<<16 | uint32(sum[3])<<24
	},

	FNV1A: func(metric []byte) uint32 {
		h := fnv.New32a()
		// hash.Hash's Write never returns an error, so the result is discarded.
		_, _ = h.Write(metric)
		return h.Sum32()
	},
}

// statsdRegex matches a valid StatsD metric name (the part before ':'): one or
// more ASCII letters, digits, '.', '-', '_' or '/'. It does not match a whole
// line, since lines also contain ':' and '|'.
var statsdRegex = regexp.MustCompile(`^[0-9a-zA-Z.\-_/]+$`)

// regexCache memoizes statsdRegex results per metric name, so a name that keeps
// arriving is checked against the regex only once per cache lifetime. The cache
// has a 5 minute default expiration and a 30 second cleanup interval.
// NOTE(review): keys come straight from incoming lines and the cache has no size
// bound, so a flood of distinct names grows it until the entries expire.
var regexCache = cache.New(time.Minute*5, time.Second*30)

// BufferLine carries one StatsD line, sliced out of a collector buffer, to a
// forwarder for shipping to the statsd endpoint. The forwarder must call Done
// once it no longer needs Line. The buffer is handed back to the collector only
// after every line taken from it has been marked done.
type BufferLine struct {
	Line []byte          // a single metric line
	wg   *sync.WaitGroup // shared by every line cut from the same buffer
}

// Done reports that the receiving forwarder has finished with this line. It
// releases the line's count on the WaitGroup shared with its source buffer.
func (bl BufferLine) Done() {
	bl.wg.Done()
}

// Validator reads lines from a Collector, validates them, and routes each
// valid line to a downstream Forwarder chosen by hashing its metric name.
type Validator interface {
	// Run is the main loop: it polls a collector, validates lines and ships
	// them. It blocks; StatsdValidator's implementation returns only after its
	// shutdown channel becomes readable.
	Run()
	// Validate returns the metric name if the line is valid, or a non-nil
	// error describing why it is not.
	Validate([]byte) ([]byte, error)
	// Hash returns the Forwarder that should receive lines for the given
	// metric name.
	Hash([]byte) Forwarder
}

// StatsdValidator is the Validator for StatsD traffic. It checks the syntax of
// every line and sends valid lines to a forwarder picked by hashing the
// metric name.
type StatsdValidator struct {
	// collector supplies packet buffers and takes them back once forwarded.
	collector Collector
	// forwarders are the downstream destinations Hash chooses between.
	forwarders []Forwarder
	// hasher is the hash function Hash applies to metric names.
	hasher Hasher
	// linesInCount counts lines seen since the last "LinesIn" report. It is
	// read and written by Run and processBuffer without synchronization.
	linesInCount int64
	// observer receives stats and error reports.
	observer Observer
	// shutdown, once readable, makes Run return.
	shutdown <-chan struct{}
}

// NewStatsdValidator creates a StatsdValidator that reads packet buffers from
// c and routes valid lines across the forwarders f. The hash function is the
// HashStrategies entry named h (for example SHA1 or FNV1A). Stats and errors
// are reported to o, and Run stops once shutdown is readable.
//
// It returns an error if h names no known hash strategy or if f is empty.
func NewStatsdValidator(c Collector, f []Forwarder, h string, o Observer, shutdown <-chan struct{}) (*StatsdValidator, error) {
	hasher := HashStrategies[h]
	if hasher == nil {
		return nil, fmt.Errorf("could not find hash strategy %q", h)
	}
	// Hash reduces the hash modulo len(forwarders). With no forwarders, every
	// valid line would panic with an integer divide by zero, so fail fast here.
	if len(f) == 0 {
		return nil, errors.New("at least one forwarder is required")
	}

	return &StatsdValidator{
		collector:  c,
		forwarders: f,
		hasher:     hasher,
		observer:   o,
		shutdown:   shutdown,
	}, nil
}

// Run is the StatsdValidator's main loop. It repeatedly calls processBuffer.
// Every 500ms it reports the number of lines seen since the previous report as
// the "LinesIn" stat, then resets the count. The shutdown channel is checked
// only on those ticks, so Run returns at the first tick after shutdown becomes
// readable. If processBuffer blocks (e.g. if the collector waits for a packet),
// that return is delayed by the same amount.
func (sv *StatsdValidator) Run() {
	statTicker := time.NewTicker(500 * time.Millisecond)
	// Release the ticker when Run returns; before Go 1.23 an unstopped Ticker
	// is never garbage-collected.
	defer statTicker.Stop()
	for {
		select {
		case <-statTicker.C:
			// Submit the incoming-line stat, then see whether we should stop.
			sv.observer.Increment("LinesIn", sv.linesInCount)
			sv.linesInCount = 0
			select {
			case <-sv.shutdown:
				return
			default:
			}
		default:
			sv.processBuffer()
		}
	}
}

// processBuffer takes one packet buffer from the collector, splits it on '\n'
// and hands every valid line to the forwarder chosen by Hash.
//
// All lines of the buffer share one WaitGroup, with one count per line. Dropped
// lines (empty or invalid) release their count immediately. Forwarded lines
// release it when the forwarder calls BufferLine.Done. A background goroutine
// waits for all counts and then returns the buffer to the collector, so the
// buffer goes back only after every forwarded line has been marked done.
func (sv *StatsdValidator) processBuffer() {
	packet := sv.collector.GetBuffer()
	pending := new(sync.WaitGroup)
	lines := bytes.Split(packet, []byte{'\n'})

	// Every split segment, including empty ones, counts toward LinesIn.
	sv.linesInCount += int64(len(lines))
	// Reserve a count for every line before dispatching any of them, so a fast
	// forwarder cannot drive the WaitGroup to zero early.
	pending.Add(len(lines))

	// drop records why a line was discarded and releases its count.
	drop := func(stat string) {
		sv.observer.Increment(stat, 1)
		pending.Done()
	}

	for _, line := range lines {
		if len(line) == 0 {
			drop("EmptyLines")
			continue
		}
		name, err := sv.Validate(line)
		if err != nil {
			sv.observer.Error("invalid line", err, "line", string(line))
			drop("InvalidLines")
			continue
		}
		sv.Hash(name).SendLine(&BufferLine{Line: line, wg: pending})
	}

	go func() {
		pending.Wait()
		sv.collector.ReturnBuffer(packet)
	}()
}

// Metric type tokens as they appear after the '|' in a StatsD line.
var (
	// TypeCounter is the type token for a counter.
	TypeCounter = []byte("c")
	// TypeGauge is the type token for a gauge.
	TypeGauge = []byte("g")
	// TypeSet is the type token for a set.
	TypeSet = []byte("s")
	// TypeHistogram is the type token for a histogram.
	TypeHistogram = []byte("h")
	// TypeTimer is the type token for a timer.
	TypeTimer = []byte("ms")
)

// Sentinel errors for StatsD line validation; compare against them with errors.Is.
var (
	// ErrBadMetric reports a malformed "<name>:<value>" section or a metric
	// name with disallowed characters.
	ErrBadMetric = errors.New("bad_metric")
	// ErrBadValue reports a metric value that is not a number.
	ErrBadValue = errors.New("bad_value")
	// ErrBadType reports an unrecognized metric type token.
	ErrBadType = errors.New("bad_type")
	// ErrUnexpectedEndOfLine reports a line that ended unexpectedly.
	ErrUnexpectedEndOfLine = errors.New("unexpected_EOL")
	// ErrTooManyFields reports a line with more than three '|'-separated fields.
	ErrTooManyFields = errors.New("too_many_fields")
	// ErrTooFewFields reports a line with fewer than two '|'-separated fields.
	ErrTooFewFields = errors.New("too_few_fields")
	// ErrIncorrectSampling reports a malformed or out-of-range "@<rate>" field.
	ErrIncorrectSampling = errors.New("incorrect_sampling")
)

// Validate returns the name of the metric if line is a valid StatsD metric of
// the form
//
//	<name>:<value>|<type>[|@<rate>]
//
// and otherwise an error. The returned name is a sub-slice of line, not a copy.
// Checks run in this order, and the first failure is returned:
//   - ErrTooFewFields / ErrTooManyFields: 2 or 3 '|'-separated fields are required.
//   - ErrIncorrectSampling: an optional third field must be '@' followed by a
//     float in [0, 1] (NaN is rejected).
//   - ErrBadType: the type must be one of TypeCounter, TypeTimer, TypeGauge,
//     TypeSet or TypeHistogram.
//   - ErrBadMetric: "<name>:<value>" must contain exactly one ':' and the name
//     must match statsdRegex.
//   - ErrBadValue: the value must parse as a float64.
func (sv *StatsdValidator) Validate(line []byte) ([]byte, error) {
	fields := bytes.Split(line, []byte{'|'})
	switch {
	case len(fields) < 2:
		return nil, ErrTooFewFields
	case len(fields) > 3:
		return nil, ErrTooManyFields
	}
	if len(fields) == 3 {
		if err := validateStatsdSampling(fields[2]); err != nil {
			return nil, err
		}
	}
	if err := validateStatsdType(fields[1]); err != nil {
		return nil, err
	}

	// There should be exactly 2 parts: a key and a value.
	keyAndValue := bytes.Split(fields[0], []byte{':'})
	if len(keyAndValue) != 2 {
		return nil, ErrBadMetric
	}
	key, value := keyAndValue[0], keyAndValue[1]
	if !validStatsdName(key) {
		return nil, ErrBadMetric
	}
	if _, err := strconv.ParseFloat(string(value), 64); err != nil {
		return nil, ErrBadValue
	}
	return key, nil
}

// validateStatsdSampling checks a "@<rate>" field, where rate must be a float in [0, 1].
func validateStatsdSampling(field []byte) error {
	if len(field) == 0 || field[0] != '@' {
		return ErrIncorrectSampling
	}
	rate, err := strconv.ParseFloat(string(field[1:]), 64)
	// The test is written as "not in range" rather than "out of range" so that
	// NaN, which compares false against everything, is rejected instead of
	// slipping through.
	if err != nil || !(rate >= 0 && rate <= 1) {
		return ErrIncorrectSampling
	}
	return nil
}

// validateStatsdType checks that typ is one of the supported metric type tokens.
func validateStatsdType(typ []byte) error {
	switch {
	case bytes.Equal(typ, TypeCounter),
		bytes.Equal(typ, TypeTimer),
		bytes.Equal(typ, TypeGauge),
		bytes.Equal(typ, TypeSet),
		bytes.Equal(typ, TypeHistogram):
		return nil
	}
	return ErrBadType
}

// validStatsdName reports whether name matches statsdRegex, memoizing the
// answer in regexCache so repeated names skip the regex engine.
func validStatsdName(name []byte) bool {
	key := string(name)
	if cached, ok := regexCache.Get(key); ok {
		return cached.(bool)
	}
	valid := statsdRegex.Match(name)
	regexCache.Set(key, valid, cache.DefaultExpiration)
	return valid
}

// Hash returns the forwarder that should receive lines for the given metric
// name. The name is hashed with the configured Hasher and the hash is reduced
// modulo the number of forwarders. For a fixed forwarder list, a given name
// therefore always maps to the same forwarder. This is plain modulo hashing,
// not a consistent hash ring: changing the number of forwarders remaps most
// names.
//
// The reduction is done in unsigned 64-bit arithmetic. Converting the uint32
// hash to int first makes it negative on 32-bit platforms, and negating
// math.MinInt32 stays negative, which could index out of range. On 64-bit
// platforms the resulting index is identical to the old computation. Hash
// panics (integer divide by zero) if there are no forwarders.
func (sv *StatsdValidator) Hash(metric []byte) Forwarder {
	sum := uint64(sv.hasher(metric))
	return sv.forwarders[sum%uint64(len(sv.forwarders))]
}

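// The spec is that a line has the form [a-zA-Z0-9_\-\./]+:<float>|{c,ms,g,h,m}[|@<float> {only allowed for c}].
// See https://github.com/b/statsd_spec for more information.
// Note that Validate differs from this spec. It accepts the types c, ms, g, s
// and h (not m), and it allows the @<float> sampling suffix on every type,
// requiring the rate to lie in [0, 1].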
