package hitcounter

import (
	"errors"
	"golang.org/x/net/context"
	"sync/atomic"
	"time"
)

// hitCounts maps a client name (as passed to Hit) to the number of hits
// counted for it since the previous flush.
type hitCounts map[string]int

// QuotaReader supplies the per-client hit counts that quota decisions are
// based on. updateQuotas calls GetHitCounts with an empty string as the
// second argument and a window length in minutes as the third, and treats the
// returned map as client name -> hit count.
// NOTE(review): the string argument presumably filters by client, with ""
// meaning "all clients" — confirm against the implementation.
type QuotaReader interface {
	GetHitCounts(context.Context, string, int) (map[string]int, error)
}

// HitCountWriter persists hit counts. The run loop calls
// UpdateOrInsertHitCount once per client per flush, with newHits set to the
// hits counted for that client since the previous flush and bucketDuration set
// to the flush period.
// NOTE(review): whether an existing count is incremented or replaced is up to
// the implementation (the name suggests an upsert) — confirm.
type HitCountWriter interface {
	UpdateOrInsertHitCount(ctx context.Context, client string, newHits int, bucketDuration time.Duration) error
}

// HitCounter is the public surface of a hit counter:
//   - Hit records one hit for client, returning an error if it was not counted.
//   - HasQuota reports whether client is still within its quota.
//   - Stop shuts the counter down.
type HitCounter interface {
	Hit(client string) error
	HasQuota(client string) bool
	Stop()
}

// hitCountersImpl is the hit counter built by StartHitCounters.
type hitCountersImpl struct {
	// hitCountsByClient accumulates hits for the current flush period. Only
	// the run goroutine touches it; it is replaced after every hand-off.
	hitCountsByClient hitCounts
	// hits carries client names from Hit to the run loop; Stop closes it.
	hits              chan string
	// errorSink receives database write errors (from run) and quota-refresh
	// errors (from StartUpdateQuotas). Sends are blocking, so the owner must
	// keep draining it.
	errorSink         chan error
	// limits maps a client name to its maximum hit count, with an optional
	// "default" entry for clients without their own; set by StartUpdateQuotas.
	limits            map[string]int
	// over holds the latest map[string]bool snapshot (client -> over quota)
	// stored by the StartUpdateQuotas goroutine and loaded by HasQuota.
	over              atomic.Value
}

// StartHitCounters builds a hit counter and launches its event loop (run) in a
// new goroutine. The loop writes accumulated counts through dbWriter every
// flushPeriod, passing flushPeriod as the bucket duration, and forwards write
// errors to errorSink; those sends block, so errorSink must be drained.
// The hits queue holds 100 entries; while it is full, Hit rejects new hits.
// Call Stop to end the event loop.
func StartHitCounters(dbWriter HitCountWriter, flushPeriod time.Duration, errorSink chan error) *hitCountersImpl {
	counters := &hitCountersImpl{
		errorSink:         errorSink,
		hits:              make(chan string, 100),
		hitCountsByClient: hitCounts{},
	}

	go counters.run(flushPeriod, dbWriter)
	return counters
}

// ErrHitNotCounted is returned by Hit when the hit was dropped because the
// queue feeding the run loop is full.
var ErrHitNotCounted = errors.New("hit not counted: blocked channel")

// Hit records one hit for client without blocking. If the hits queue is full,
// the hit is dropped and ErrHitNotCounted is returned (test with errors.Is).
//
// Hit must not be called after Stop: Stop closes the hits channel, and a send
// on a closed channel panics even inside a select.
func (h *hitCountersImpl) Hit(client string) error {
	select {
	case h.hits <- client:
		return nil
	default:
		return ErrHitNotCounted
	}
}

// Stop shuts the counter down by closing the hits channel, which makes the run
// loop return. Call it at most once and do not call Hit afterwards: closing an
// already-closed channel, or sending on one, panics.
func (h *hitCountersImpl) Stop() {
	queue := h.hits
	close(queue)
}

// Internal routines
// run is the counter's event loop; StartHitCounters starts it in its own
// goroutine. run is the only code that touches h.hitCountsByClient. Each client
// name received on h.hits increments that client's count. Every flushPeriod,
// the accumulated map is handed to a writer goroutine and replaced with a
// fresh one. The writer persists each entry with
// dbWriter.UpdateOrInsertHitCount, passing flushPeriod as the bucket duration,
// and forwards any write error to h.errorSink.
//
// run returns when h.hits is closed (see Stop). Counts accumulated since the
// last tick are handed off first, so they are not lost. Closing the flush
// channel then lets the writer persist everything already queued and exit.
//
// Sends to h.errorSink block, so a nil or undrained sink stalls the writer.
// Once the flush buffer fills, run blocks as well and Hit starts dropping
// hits. flushPeriod must be positive: time.NewTicker panics otherwise.
func (h *hitCountersImpl) run(flushPeriod time.Duration, dbWriter HitCountWriter) {
	// Buffered so a slow database write does not stall hit counting right away.
	flush := make(chan hitCounts, 10)
	// Closing flush ends the writer's range loop once run returns.
	defer close(flush)

	ticker := time.NewTicker(flushPeriod)
	defer ticker.Stop()

	// Writer goroutine: it owns every map it receives from flush.
	go func() {
		for counts := range flush {
			for client, n := range counts {
				if err := dbWriter.UpdateOrInsertHitCount(context.Background(), client, n, flushPeriod); err != nil {
					h.errorSink <- err
				}
			}
		}
	}()

	// handOff passes the current counts to the writer and starts a new map.
	// Empty maps are skipped: they would produce no writes anyway.
	handOff := func() {
		if len(h.hitCountsByClient) == 0 {
			return
		}
		flush <- h.hitCountsByClient
		// The writer now owns the old map; never touch it again here.
		h.hitCountsByClient = make(hitCounts)
	}

	for {
		select {
		// Ticker.Stop never closes ticker.C, so no closed-channel check is needed.
		case <-ticker.C:
			handOff()
		case client, isOpen := <-h.hits:
			if !isOpen {
				// Stop was called: persist the partial period instead of dropping it.
				handOff()
				return
			}
			h.hitCountsByClient[client]++
		}
	}
}

// HasQuota reports whether client may still be served. It returns false only
// when the latest snapshot stored in h.over marks client as over its limit.
// Before any snapshot has been stored, and for clients absent from the
// snapshot, it returns true.
func (h *hitCountersImpl) HasQuota(client string) bool {
	// A failed assertion (nothing stored yet) leaves overByClient nil. Looking
	// up a nil map, or a missing key, yields false, which reads as "not over".
	overByClient, _ := h.over.Load().(map[string]bool)
	return !overByClient[client]
}

// updateQuotas fetches hit counts from quotaReader and decides, per client,
// whether the client is over its limit. It calls GetHitCounts with an empty
// string as the second argument and interval_minutes as the third.
//
// A client's own entry in h.limits takes precedence over the "default" entry.
// A client is over when its count strictly exceeds the limit. Clients with
// neither kind of entry are omitted from the result. A reader error is
// returned unchanged, together with a nil map.
func (h *hitCountersImpl) updateQuotas(quotaReader QuotaReader, interval_minutes int) (map[string]bool, error) {
	counts, err := quotaReader.GetHitCounts(context.Background(), "", interval_minutes)
	if err != nil {
		return nil, err
	}

	overByClient := make(map[string]bool, len(counts))
	for client, hits := range counts {
		limit, found := h.limits[client]
		if !found {
			limit, found = h.limits["default"]
		}
		if !found {
			continue
		}
		overByClient[client] = hits > limit
	}
	return overByClient, nil
}

// StartUpdateQuotas installs limits and starts a background goroutine that
// refreshes the over-quota snapshot used by HasQuota every frequency_seconds
// seconds. Each refresh calls updateQuotas with interval_minutes, which is
// presumably the look-back window of the QuotaReader; confirm with its
// implementation.
//
// limits maps a client name to its maximum hit count. Its "default" entry, if
// present, applies to clients without their own entry. The map is stored, not
// copied, so callers must not modify it afterwards. No snapshot exists until
// the first tick succeeds, and until then HasQuota allows every client.
//
// A failed refresh is sent to the error sink (blocking). The previous snapshot
// stays in effect and the refresh is retried on the next tick, so one
// transient read failure no longer freezes quotas permanently. The goroutine
// has no stop mechanism and runs for the life of the process.
// frequency_seconds must be positive: time.NewTicker panics otherwise.
func (h *hitCountersImpl) StartUpdateQuotas(limits map[string]int, quotaReader QuotaReader, frequency_seconds int, interval_minutes int) {
	h.limits = limits

	ticker := time.NewTicker(time.Duration(frequency_seconds) * time.Second)
	go func() {
		for range ticker.C {
			newOvers, err := h.updateQuotas(quotaReader, interval_minutes)
			if err != nil {
				// Keep enforcing the last good snapshot and try again next tick.
				h.errorSink <- err
				continue
			}
			h.over.Store(newOvers)
		}
	}()
}
