package netpool

import (
	"errors"
	"net"
	"time"
)

// Inspired by: http://dustin.sallings.org/2014/04/25/chan-pool.html

// Errors returned by the pool.
var (
	// ErrTimeout is returned when neither an idle connection nor a free
	// creation slot became available before the timeout expired.
	ErrTimeout = errors.New("timeout waiting to build connection")

	// ErrClosedPool is returned when a connection is requested from a pool
	// that has been closed.
	ErrClosedPool = errors.New("the connection pool is closed")

	// ErrNoPool is returned when a connection is requested from a nil
	// *NetPool.
	ErrNoPool = errors.New("no connection pool")
)

// Timing knobs used when handing out connections.
const (
	// connPoolTimeout is the timeout Acquire passes to GetWithTimeout.
	connPoolTimeout = 2 * time.Second

	// connPoolAvailWaitTime is how long GetWithTimeout waits for an idle
	// connection to be handed back before it also considers dialing a
	// fresh one.
	connPoolAvailWaitTime = 1 * time.Millisecond
)

// PoolDialFunc builds a new connection for the pool. The pool calls it
// only while it holds a free creation slot.
type PoolDialFunc func() (net.Conn, error)

// NetPool is a bounded pool of network connections. Idle connections are
// buffered in a channel, and a second channel acts as a counting semaphore
// that limits how many connections the pool creates.
type NetPool struct {
	// connections buffers idle connections that can be handed out.
	connections chan net.Conn

	// createSem holds one token per connection the pool has created and
	// not yet given back; its capacity is the creation limit.
	createSem chan bool

	// dial builds new connections when the pool needs to grow.
	dial PoolDialFunc

	// poolSize is the size the pool was constructed with.
	poolSize int
}

// New creates a pool that builds connections with dialFunc and limits the
// number of connections it creates to poolSize. It acquires one connection
// up front and returns it to the pool, so a failing dialFunc is reported
// here instead of on first use.
//
// It returns a nil pool and an error in three cases:
//   - dialFunc is nil;
//   - poolSize is not positive;
//   - the first connection cannot be obtained (the dial error, or
//     ErrTimeout).
func New(dialFunc PoolDialFunc, poolSize int) (*NetPool, error) {
	if dialFunc == nil {
		return nil, errors.New("dial function is nil")
	}
	// A negative poolSize makes the make calls below panic. With zero, both
	// channels are unbuffered, so the pool can neither create nor hold a
	// connection and Acquire would just wait out its timeout.
	if poolSize <= 0 {
		return nil, errors.New("pool size must be positive")
	}

	pool := &NetPool{
		connections: make(chan net.Conn, poolSize),
		createSem:   make(chan bool, poolSize),
		dial:        dialFunc,
		poolSize:    poolSize,
	}

	// Prove the dial function works before handing the pool to the caller.
	conn, err := pool.Acquire()
	if err != nil {
		return nil, err
	}
	pool.Release(conn)

	return pool, nil
}

// Close shuts the pool down, in two steps:
//  1. It closes the idle-connection channel. After that, requesting a
//     connection yields ErrClosedPool, and Release closes a returned
//     connection instead of pooling it.
//  2. It closes every connection still idle in the pool.
//
// Connections currently checked out are not touched.
//
// It returns ErrNoPool for a nil pool and ErrClosedPool if the pool was
// already closed.
func (cp *NetPool) Close() error {
	if cp == nil {
		return ErrNoPool
	}

	// close panics on an already-closed channel; turn a second Close into
	// an error instead of a crash. The recover is scoped to this call only.
	alreadyClosed := func() (closedBefore bool) {
		defer func() {
			if recover() != nil {
				closedBefore = true
			}
		}()
		close(cp.connections)
		return false
	}()
	if alreadyClosed {
		return ErrClosedPool
	}

	// Errors from closing individual idle connections are deliberately
	// ignored: the pool is going away and has no one to report them to.
	for c := range cp.connections {
		_ = c.Close()
	}
	return nil
}

// GetWithTimeout returns a connection from the pool. If no idle connection
// turns up, it dials a new one while the pool still has a free creation
// slot.
//
// It proceeds in three stages:
//  1. take an idle connection immediately if one is buffered;
//  2. otherwise wait up to connPoolAvailWaitTime for one to be handed back;
//  3. otherwise wait up to d for whichever comes first: a handed-back
//     connection, or a free creation slot (then call the dial function).
//
// Because stage 2 runs before the d timer starts, the total wait can exceed
// d by up to connPoolAvailWaitTime.
//
// It returns ErrNoPool for a nil pool, ErrClosedPool once the pool has been
// closed, ErrTimeout if stage 3 times out, or the dial function's error. On
// any error the returned connection is nil.
func (cp *NetPool) GetWithTimeout(d time.Duration) (net.Conn, error) {
	if cp == nil {
		return nil, ErrNoPool
	}

	// fromPool interprets a receive on cp.connections: a closed channel
	// means the pool has been closed.
	fromPool := func(c net.Conn, open bool) (net.Conn, error) {
		if !open {
			return nil, ErrClosedPool
		}
		return c, nil
	}

	// Stage 1: short-circuit an idle connection without arming a timer.
	select {
	case c, open := <-cp.connections:
		return fromPool(c, open)
	default:
	}

	t := time.NewTimer(connPoolAvailWaitTime)
	defer t.Stop()

	// Stage 2: give a busy caller a moment to hand a connection back.
	select {
	case c, open := <-cp.connections:
		return fromPool(c, open)
	case <-t.C:
	}

	// Stage 3: t fired and its channel was drained above, so Reset is safe.
	// Reuse the timer for the caller's full timeout.
	t.Reset(d)
	select {
	case c, open := <-cp.connections:
		return fromPool(c, open)
	case cp.createSem <- true:
		// We now hold a creation slot for the connection about to be built.
		conn, err := cp.dial()
		if err == nil && conn == nil {
			// Release(nil) is a no-op, so a nil connection would leak this
			// slot forever. Treat it as a dial failure instead.
			err = errors.New("dial returned a nil connection")
		}
		if err != nil {
			// Give the slot back so later callers may retry. Return no
			// connection: its slot is gone, so it would be untracked.
			<-cp.createSem
			return nil, err
		}
		return conn, nil
	case <-t.C:
		return nil, ErrTimeout
	}
}

// Acquire is GetWithTimeout using the package's default timeout,
// connPoolTimeout.
func (cp *NetPool) Acquire() (conn net.Conn, err error) {
	conn, err = cp.GetWithTimeout(connPoolTimeout)
	return conn, err
}

// Release hands c back to the pool so a later request can reuse it.
//
// Special cases:
//   - A nil c is ignored.
//   - If cp is nil, or the pool has been closed, c is closed instead.
//   - If the pool's idle buffer is already full, c is treated as surplus:
//     it is closed, and a creation slot is freed if one is held.
func (cp *NetPool) Release(c net.Conn) {
	if c == nil {
		return
	}

	if cp == nil {
		_ = c.Close()
		return
	}

	defer func() {
		// Sending on cp.connections panics once Close has closed the
		// channel. The pool no longer wants the connection, so close it.
		if recover() != nil {
			_ = c.Close()
		}
	}()

	select {
	case cp.connections <- c:
	default:
		// Buffer full: the pool cannot keep c. Free a creation slot without
		// blocking. If none is held (e.g. c did not come from this pool, or
		// was released twice), a blocking receive could hang indefinitely.
		select {
		case <-cp.createSem:
		default:
		}
		_ = c.Close()
	}
}

// Discard closes c instead of returning it for reuse, and gives its
// creation slot back so the pool may dial a replacement. Use it for broken
// connections.
//
// It reports whether a creation slot was released. It returns false if:
//   - c is nil;
//   - cp is nil (c is still closed);
//   - the pool held no slot to release;
//   - c.Close panicked.
func (cp *NetPool) Discard(c net.Conn) (ok bool) {
	if c == nil {
		return false
	}
	if cp == nil {
		_ = c.Close()
		return false
	}

	// c.Close belongs to the caller's connection type; report a panic there
	// as failure rather than crashing the caller.
	defer func() {
		if recover() != nil {
			ok = false
		}
	}()

	_ = c.Close()

	// Non-blocking: if no slot is held (e.g. c was discarded twice), a
	// blocking receive would wait indefinitely and then swallow the slot
	// of the next connection the pool creates.
	select {
	case <-cp.createSem:
		return true
	default:
		return false
	}
}

// PoolSize returns the poolSize value the pool was created with by New.
// It returns 0 for a nil pool.
func (cp *NetPool) PoolSize() int {
	if cp == nil {
		return 0
	}
	return cp.poolSize
}
