package perp

import (
	"context"
	"runtime"
	"runtime/debug"
	"sync"
	"sync/atomic"
	"testing"
)

// TestNewPoolSize checks that concurrent Do calls spread their work over the
// elements of pools created by NewPoolSize. After hammering a pool of n
// elements with concurrent callers, at least half of the element indices
// must have been handed to some callback.
func TestNewPoolSize(t *testing.T) {
	// testcase builds a subtest for a pool of n elements. before and after
	// bracket the whole subtest; each runs after every round of concurrent
	// calls.
	testcase := func(n int, before, after, each func()) func(t *testing.T) {
		return func(t *testing.T) {
			ctx := context.Background()

			// uses[i] counts the callbacks that received index i. The
			// increment is unsynchronized; presumably the pool guarantees an
			// index is held by at most one callback at a time — TODO confirm
			// against NewPoolSize (running with -race would flag a violation).
			uses := make([]int64, n)
			p := NewPoolSize(len(uses))

			before()
			defer after()

			for round := 0; round < n; round++ {
				var wg sync.WaitGroup
				start := make(chan struct{})
				for w := 0; w < len(uses)*10; w++ {
					wg.Add(1)
					go func() {
						defer wg.Done()

						// Release all workers at once to maximize contention.
						<-start
						for call := 0; call < 1+100/n; call++ {
							p.Do(ctx, func(i int) {
								// Only the first iteration records the use;
								// the rest just burns time while the element
								// is held.
								for j := 0; j < 1000; j++ {
									if j == 0 {
										uses[i]++
									}
								}
							})
						}
					}()
				}
				close(start)
				wg.Wait()

				each()
			}

			var unused []int
			for i, count := range uses {
				if count == 0 {
					unused = append(unused, i)
				}
			}
			if len(unused)*2 > len(uses) {
				// Report used/total elements.
				t.Errorf("used less than half of the pool elements (%d/%d)", len(uses)-len(unused), len(uses))
				t.Logf("index usage: %v", uses)
			}
		}
	}

	noop := func() {}

	procs := runtime.GOMAXPROCS(0)
	// SetGCPercent returns the previous setting; restoring it immediately
	// records the current GC percentage in gogc without changing it.
	gogc := debug.SetGCPercent(-1)
	debug.SetGCPercent(gogc)

	pauseGC := func() { debug.SetGCPercent(-1) }
	restartGC := func() { debug.SetGCPercent(gogc) }

	spawnThreads(procs)

	// ensure that at least a little work is spread across pools of various
	// sizes.
	t.Run("fixed-1", testcase(1, noop, noop, runtime.GC))
	t.Run("fixed-4", testcase(4, noop, noop, runtime.GC))
	t.Run("fixed-40", testcase(40, noop, noop, runtime.GC))
	t.Run("fixed-400", testcase(400, noop, noop, runtime.GC))

	// the normal usage is to set the pool size to equal the number of Ps in
	// the runtime.
	t.Run("GOMAXPROCS", testcase(procs, noop, noop, runtime.GC))

	// confirm that work will be spread across GOMAXPROCS pool elements in
	// between GC cycles.
	t.Run("GOMAXPROCS-noGC", testcase(procs, pauseGC, restartGC, noop))
}

// BenchmarkPool compares pools from NewPoolSize against two baselines:
// linearPool, which hands out entries round-robin, and lockSinglePool, which
// serializes every call behind one mutex. Each comparison runs at increasing
// numbers of concurrent workers.
//
// Each worker performs b.N calls, so one "op" is one round in which every
// worker completes one call. ns/op therefore scales with the worker count and
// is not the latency of a single Do.
func BenchmarkPool(b *testing.B) {
	type pool interface {
		Do(ctx context.Context, fn func(i int))
	}

	// makeFn returns a callback that bumps a per-index counter; indices must
	// lie in [0, procs).
	makeFn := func(procs int) func(i int) {
		elts := make([]int64, procs)
		return func(i int) {
			for j := 0; j < 100; j++ {
				if j == 0 {
					elts[i]++
				}
			}
		}
	}

	// spin adds a little caller-side busy work between Do calls.
	spin := func() {
		for j := 0; j < 1000; j++ {
		}
	}

	bench := func(p pool, workers int, fn func(i int)) func(b *testing.B) {
		return func(b *testing.B) {
			ctx := context.Background()

			var wg sync.WaitGroup
			start := make(chan struct{})

			for i := 0; i < workers; i++ {
				wg.Add(1)
				go func() {
					defer wg.Done()
					<-start
					for i := 0; i < b.N; i++ {
						spin()
						p.Do(ctx, fn)
					}
				}()
			}
			// Spawning up to GOMAXPROCS*256 parked goroutines is setup, not
			// the operation under test; exclude it from the timed region.
			b.ResetTimer()
			close(start)
			wg.Wait()
		}
	}

	procs := runtime.GOMAXPROCS(0)
	spawnThreads(procs)

	benchConcurrency := func(workers int) func(b *testing.B) {
		return func(b *testing.B) {
			b.Run("Pool", bench(NewPoolSize(procs), workers, makeFn(procs)))
			b.Run("Pool-double", bench(NewPoolSize(procs*2), workers, makeFn(procs*2)))
			b.Run("LinearPool", bench(newLinearPool(procs), workers, makeFn(procs)))
			b.Run("LockSinglePool", bench(newLockSinglePool(procs), workers, makeFn(procs)))
		}
	}

	b.Run("GOMAXPROCS", benchConcurrency(procs))
	b.Run("GOMAXPROCS*2", benchConcurrency(procs*2))
	b.Run("GOMAXPROCS*4", benchConcurrency(procs*4))
	b.Run("GOMAXPROCS*16", benchConcurrency(procs*16))
	b.Run("GOMAXPROCS*64", benchConcurrency(procs*64))
	b.Run("GOMAXPROCS*256", benchConcurrency(procs*256))
}

// spawnThreads makes the Go runtime create at least procs OS threads.
//
// Each helper goroutine wires itself to its current OS thread with
// runtime.LockOSThread and then blocks. While a goroutine is locked, no other
// goroutine may execute on its thread, so the scheduler has to start a new
// thread for every further helper. The helpers are released only after all
// of them have locked. They then unlock and exit, and the threads stay with
// the runtime. Presumably this keeps thread creation out of the measurements
// in the callers — NOTE(review): confirm intent.
func spawnThreads(procs int) {
	var locked, done sync.WaitGroup
	release := make(chan struct{})
	for i := 0; i < procs; i++ {
		locked.Add(1)
		done.Add(1)
		go func() {
			defer done.Done()
			runtime.LockOSThread()
			defer runtime.UnlockOSThread()
			locked.Done()
			<-release
		}()
	}
	// Without this wait, release could be closed before some helpers were
	// even scheduled. Those helpers would lock, see the closed channel,
	// unlock and exit right away, letting later helpers reuse their threads,
	// so fewer than procs threads would be created.
	locked.Wait()
	close(release)
	done.Wait()
}

// newLinearPool builds a linearPool with size entries indexed 0..size-1.
// Every entry starts out free: its one-slot channel is pre-loaded with a
// token, which Do takes before running fn and puts back afterwards.
func newLinearPool(size int) *linearPool {
	entries := make([]*entry, size)
	for idx := range entries {
		token := make(chan struct{}, 1)
		token <- struct{}{}
		entries[idx] = &entry{idx: idx, ch: token}
	}
	return &linearPool{entries: entries}
}

// linearPool is a benchmark baseline that hands out its entries in strict
// round-robin order. The choice depends only on a shared counter, never on
// which goroutine is calling.
type linearPool struct {
	// requestCount counts pick calls. It is only read or written through
	// sync/atomic (see pick).
	requestCount int32

	// entries holds one element per index. Each entry's one-slot channel
	// carries a token while that entry is free (see newLinearPool and Do).
	entries []*entry
}

// pick returns the next entry in round-robin order. The atomically
// incremented counter is reinterpreted as unsigned, so the index stays in
// range even after the int32 counter wraps negative. It panics (integer
// division by zero) if the pool has no entries.
func (p *linearPool) pick() *entry {
	ticket := uint32(atomic.AddInt32(&p.requestCount, 1))
	return p.entries[ticket%uint32(len(p.entries))]
}

// Do picks an entry round-robin and waits for it to become free. It then runs
// fn with that entry's index and returns the entry's token afterwards, even
// if fn panics. The entry is chosen before waiting, so Do blocks on a busy
// entry even when others are free. If ctx is done first, Do returns without
// calling fn.
func (p *linearPool) Do(ctx context.Context, fn func(i int)) {
	picked := p.pick()

	select {
	case <-picked.ch:
	case <-ctx.Done():
		return
	}

	defer func() { picked.ch <- struct{}{} }()
	fn(picked.idx)
}

// newLockSinglePool returns a benchmark baseline pool that serializes every
// call behind a single mutex. size is ignored; it appears to exist only so
// the constructor has the same shape as newLinearPool.
func newLockSinglePool(size int) *lockSinglePool {
	return &lockSinglePool{}
}

// lockSinglePool is a pool with exactly one element (index 0), guarded by mu.
// Its zero value is ready to use.
type lockSinglePool struct {
	mu sync.Mutex
}

// Do runs fn(0) while holding p.mu, so at most one fn runs at a time. The
// mutex is released even if fn panics. ctx is ignored: Do waits for the mutex
// even if ctx is already done.
func (p *lockSinglePool) Do(ctx context.Context, fn func(i int)) {
	p.mu.Lock()
	defer p.mu.Unlock()

	fn(0)
}
