package util

import (
	"errors"
	"log"
	"net"
	"net/http"
	"sync"
	"time"

	"github.com/labstack/echo/v4"

	"a.yandex-team.ru/security/osquery/osquery-sender/metrics"
)

// Per-peer counters exposed for diagnostics. Each map is allocated lazily by
// its add* helper and must only be touched while holding the mutex declared
// directly above it.
var (
	// connectionPeersMu guards connectionPeers.
	connectionPeersMu sync.Mutex
	// connectionPeers counts connections accepted through a limitListener and
	// not yet closed, keyed by the connection's remote address string.
	connectionPeers map[string]int

	// requestPeersMu guards requestPeers.
	requestPeersMu sync.Mutex
	// requestPeers counts requests currently inside ConcurrentRequestLimiter,
	// keyed by the client IP reported by echo.
	requestPeers map[string]int
)

// NewLimitListener wraps l so that at most limit connections accepted through
// it are open at the same time. It is a variant of netutil.LimitListener: once
// the returned listener has been closed, Accept returns an error straight away
// instead of calling the wrapped listener's Accept.
func NewLimitListener(l net.Listener, limit int) net.Listener {
	ll := new(limitListener)
	ll.Listener = l
	ll.sem = make(chan struct{}, limit)
	ll.done = make(chan struct{})
	return ll
}

// limitListener is a net.Listener that caps the number of simultaneously open
// accepted connections with a counting semaphore (a buffered channel).
type limitListener struct {
	net.Listener
	// sem holds one token per connection that is being accepted or is still
	// open; its capacity is the connection limit.
	sem chan struct{}
	// done is closed by Close to make pending and future Accept calls abort.
	done chan struct{}
	// closeOnce makes Close idempotent: closing done twice would panic.
	closeOnce sync.Once
}

// acquire takes a semaphore slot, blocking while the limit is reached, and
// reports false once the listener has been closed. done is checked on its own
// first because select chooses randomly among ready cases, so a closed
// listener with free capacity could otherwise still report success.
func (l *limitListener) acquire() bool {
	select {
	case <-l.done:
		return false
	default:
	}
	select {
	case <-l.done:
		return false
	case l.sem <- struct{}{}:
		return true
	}
}

// release returns a slot previously taken by acquire.
func (l *limitListener) release() {
	<-l.sem
}

// count returns the number of slots currently held.
func (l *limitListener) count() int {
	return len(l.sem)
}

// Accept waits for a free slot and then accepts the next connection from the
// wrapped listener. If the listener has been closed, it returns an error
// without calling the wrapped Accept. The returned connection gives its slot
// back when it is closed.
func (l *limitListener) Accept() (net.Conn, error) {
	if !l.acquire() {
		log.Printf("server is closing\n")
		// done is only closed by Close after it has closed the wrapped
		// listener, so this repeat close is best effort and its error is moot.
		_ = l.Listener.Close()
		return nil, errors.New("do not Accept(): server is closing")
	}
	c, err := l.Listener.Accept()
	if err != nil {
		log.Printf("ERROR: failed to accept: %v\n", err)
		l.release()
		return nil, err
	}
	metrics.MaxConcurrentConnections.Report(uint64(l.count()))
	addr := c.RemoteAddr().String()
	addConnectionPeers(addr, 1)
	log.Printf("INFO: new connection from: %v\n", addr)
	return &limitListenerConn{
		Conn:    c,
		release: l.release,
	}, nil
}

// Close closes the wrapped listener and signals Accept to stop. It is safe to
// call more than once; only the first call closes the done channel, and every
// call returns the wrapped listener's Close error.
func (l *limitListener) Close() error {
	err := l.Listener.Close()
	l.closeOnce.Do(func() { close(l.done) })
	return err
}

// limitListenerConn is a connection handed out by limitListener.Accept.
// Closing it gives the listener's semaphore slot back and updates the
// per-peer connection counters.
type limitListenerConn struct {
	net.Conn
	releaseOnce sync.Once
	release     func()
}

// Close closes the underlying connection and returns its error. Close may be
// called more than once. The slot release, the peer-counter decrement and the
// log line therefore run only on the first call. Decrementing on every call
// would push the peer count below zero and leave a stale entry in the map.
func (l *limitListenerConn) Close() error {
	err := l.Conn.Close()
	l.releaseOnce.Do(func() {
		l.release()
		addr := l.Conn.RemoteAddr().String()
		addConnectionPeers(addr, -1)
		log.Printf("INFO: closed connection from: %v\n", addr)
	})
	return err
}

// addConnectionPeers adjusts the open-connection counter for addr by delta,
// allocating the map on first use and removing the entry when it drops to zero.
func addConnectionPeers(addr string, delta int) {
	connectionPeersMu.Lock()
	defer connectionPeersMu.Unlock()

	if connectionPeers == nil {
		connectionPeers = map[string]int{}
	}
	updated := connectionPeers[addr] + delta
	if updated == 0 {
		delete(connectionPeers, addr)
		return
	}
	connectionPeers[addr] = updated
}

// GetConnectionPeers returns a snapshot of the open-connection counters keyed
// by remote address. The result is a new, never-nil map that the caller owns.
func GetConnectionPeers() map[string]int {
	connectionPeersMu.Lock()
	defer connectionPeersMu.Unlock()

	snapshot := make(map[string]int, len(connectionPeers))
	for addr, n := range connectionPeers {
		snapshot[addr] = n
	}
	return snapshot
}

// requestLimiter is a counting semaphore bounding how many requests are
// processed at once. Waiting for a free slot is bounded by timeout.
type requestLimiter struct {
	sem     chan struct{}
	timeout time.Duration
}

// acquire takes a slot, waiting at most l.timeout for one to become free, and
// reports whether it succeeded. A free slot is taken immediately without
// starting a timer. With a non-positive timeout, the timer channel can be
// ready at the same moment as the semaphore. select would then pick between
// them at random and reject requests while capacity is available.
func (l *requestLimiter) acquire() bool {
	select {
	case l.sem <- struct{}{}:
		return true
	default:
	}

	timer := time.NewTimer(l.timeout)
	defer timer.Stop()
	select {
	case l.sem <- struct{}{}:
		return true
	case <-timer.C:
		return false
	}
}

// release returns a slot previously taken by acquire.
func (l *requestLimiter) release() {
	<-l.sem
}

// count returns the number of slots currently held.
func (l *requestLimiter) count() int {
	return len(l.sem)
}

// ConcurrentRequestLimiter returns echo middleware that lets at most limit
// requests run at the same time. A request that cannot obtain a slot within
// timeout is answered with 503 Service Unavailable and never reaches the next
// handler. Admitted requests are counted per client IP while they run.
func ConcurrentRequestLimiter(limit int, timeout time.Duration) echo.MiddlewareFunc {
	limiter := &requestLimiter{sem: make(chan struct{}, limit), timeout: timeout}

	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			if ok := limiter.acquire(); !ok {
				metrics.RejectedConcurrentRequests.Inc()
				log.Printf("ERROR: server is overloaded, concurrent requests hit\n")
				return c.String(http.StatusServiceUnavailable, "Server is overloaded, request limit hit")
			}
			defer limiter.release()

			metrics.MaxConcurrentRequests.Report(uint64(limiter.count()))

			peer := c.RealIP()
			addRequestPeers(peer, 1)
			defer addRequestPeers(peer, -1)

			return next(c)
		}
	}
}

// addRequestPeers adjusts the in-flight request counter for addr by delta,
// allocating the map on first use and removing the entry when it drops to zero.
func addRequestPeers(addr string, delta int) {
	requestPeersMu.Lock()
	defer requestPeersMu.Unlock()

	if requestPeers == nil {
		requestPeers = map[string]int{}
	}
	if next := requestPeers[addr] + delta; next != 0 {
		requestPeers[addr] = next
		return
	}
	delete(requestPeers, addr)
}

// GetRequestPeers returns a snapshot of the in-flight request counters keyed
// by client IP. The result is a new, never-nil map that the caller owns.
// Returning the live map would let the caller read it after the lock is
// released while addRequestPeers keeps mutating it, which is a data race.
func GetRequestPeers() map[string]int {
	requestPeersMu.Lock()
	defer requestPeersMu.Unlock()

	requestPeersCopy := make(map[string]int, len(requestPeers))
	for k, v := range requestPeers {
		requestPeersCopy[k] = v
	}
	return requestPeersCopy
}
