package contextlogger

import (
	"context"
	"sync"

	log "github.com/Sirupsen/logrus"
	"github.com/pkg/errors"
)

type (
	// logFieldsKey is the context key under which the per-context log fields
	// are stored. It is an unexported empty struct type, so code outside this
	// package cannot construct it.
	logFieldsKey struct{}

	// logFields is a mutable set of log fields shared through a context.
	// The accessors in this package take the embedded RWMutex before they
	// touch the embedded Fields map.
	logFields struct {
		log.Fields
		sync.RWMutex
	}
)

// ContextWithLogFields returns a child of ctx that carries a new, empty set
// of log fields. The *ContextLogField helpers in this package operate on that
// set and return an error for a context that does not carry one.
func ContextWithLogFields(ctx context.Context) context.Context {
	fields := &logFields{Fields: log.Fields{}}
	return context.WithValue(ctx, logFieldsKey{}, fields)
}

// getLogFields fetches the *logFields stored in ctx by ContextWithLogFields.
// It returns an error when ctx (including its ancestors) carries no such value.
func getLogFields(ctx context.Context) (*logFields, error) {
	if lf, ok := ctx.Value(logFieldsKey{}).(*logFields); ok {
		return lf, nil
	}
	return nil, errors.New("Context is missing log fields")
}

// IncContextLogField increments the int log field named field in ctx. A field
// that is not present yet is created with the value 1. It returns an error if
// ctx has no log fields or if the existing value is not an int.
func IncContextLogField(ctx context.Context, field string) error {
	lf, err := getLogFields(ctx)
	if err != nil {
		return err
	}

	lf.Lock()
	defer lf.Unlock()

	current, exists := lf.Fields[field]
	if !exists {
		// A missing counter behaves exactly like one that is currently 0.
		current = 0
	}

	count, isInt := current.(int)
	if !isInt {
		return errors.Errorf("Field %q isn't an int", field)
	}
	lf.Fields[field] = count + 1
	return nil
}

// AddContextLogField adds value to the float64 log field named field in ctx.
// A field that is not present yet is set to value unchanged. It returns an
// error if ctx has no log fields or if the existing value is not a float64.
func AddContextLogField(ctx context.Context, field string, value float64) error {
	lf, err := getLogFields(ctx)
	if err != nil {
		return err
	}

	lf.Lock()
	defer lf.Unlock()

	existing, present := lf.Fields[field]
	if !present {
		// Store value as-is rather than 0.0+value, which would turn -0.0 into +0.0.
		lf.Fields[field] = value
		return nil
	}

	total, isFloat := existing.(float64)
	if !isFloat {
		return errors.Errorf("Field %q isn't a float64", field)
	}
	lf.Fields[field] = total + value
	return nil
}

// GetContextLogField returns the current value of the log field named field
// in ctx. The value is nil if the field is absent. It returns an error if ctx
// has no log fields.
func GetContextLogField(ctx context.Context, field string) (interface{}, error) {
	lf, err := getLogFields(ctx)
	if err != nil {
		return nil, err
	}

	// A read lock must be released with RUnlock. Calling Unlock on an RWMutex
	// that is only read-locked is a fatal runtime error.
	lf.RLock()
	defer lf.RUnlock()

	return lf.Fields[field], nil
}

// SetContextLogField stores value under the log field named field in ctx,
// overwriting any previous value regardless of its type. It returns an error
// if ctx has no log fields.
func SetContextLogField(ctx context.Context, field string, value interface{}) error {
	lf, err := getLogFields(ctx)
	if err == nil {
		lf.Lock()
		defer lf.Unlock()
		lf.Fields[field] = value
	}
	return err
}
