package repro

import (
	"flag"
	"fmt"
	"runtime"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"testing"
	"unsafe"
)

var (
	// cost is the length of the byte buffer that every inner-loop iteration
	// of BenchmarkLock sums, i.e. the cost of a single iteration.
	cost = flag.Int("cost", 1000, "loop cost")

	// ratio is the fraction of each operation's inner-loop iterations that
	// BenchmarkLock runs while holding the lock.
	ratio = flag.Float64("ratio", 0.01, "fraction of loops with lock held")

	// concurrency is a comma-separated list of limits on how many worker
	// goroutines may run at once; one sub-benchmark is run per positive
	// entry. Empty or invalid entries are ignored.
	concurrency = flag.String("concurrency", "", "number of concurrent workers")

	// sink receives benchmark results so the measured work is observable.
	sink interface{}
)

// BenchmarkLock compares sync.Mutex, the mcs queue lock and channelLock under
// contention.
//
// Each benchmark operation runs loopCount inner iterations, every one of which
// sums a buffer of -cost bytes. The unlocked iterations come first; the last
// int(-ratio*loopCount) iterations run while holding the lock. That count is
// clamped to [0, loopCount]. 4*GOMAXPROCS worker goroutines share the b.N
// operations. A semaphore caps how many of the workers run at once. Its size
// comes from each entry of -concurrency, or defaults to the worker count.
// Every lock is measured as-is ("basic") and with runtime.Gosched called after
// each Unlock ("gosched").
func BenchmarkLock(b *testing.B) {
	const loopCount = 1000

	testcase := func(workers int, mu sync.Locker, loopCost, withLock int, sem semaphore) func(b *testing.B) {
		return func(b *testing.B) {
			var (
				wg sync.WaitGroup

				buf = make([]byte, loopCost)
				out byte

				start = make(chan struct{})
			)

			for i := 0; i < workers; i++ {
				work := b.N / workers
				if i < b.N%workers {
					// Spread the remainder of b.N over the first workers.
					work++
				}

				wg.Add(1)
				go func() {
					defer wg.Done()

					// Hold every worker back until the timer has been reset.
					<-start

					sem.Acquire()
					defer sem.Release()

					for n := 0; n < work; n++ {
						var v byte
						for k := 0; k < loopCount-withLock; k++ {
							for l := range buf {
								v += buf[l]
							}
						}
						mu.Lock()
						for k := 0; k < withLock; k++ {
							for l := range buf {
								v += buf[l]
							}
						}
						// Publish the result under the lock so the summing
						// loops have an observable effect.
						out = v
						mu.Unlock()
					}
				}()
			}

			b.ResetTimer()
			close(start)
			wg.Wait()

			sink = out
		}
	}

	suite := func(newLock func() sync.Locker) func(b *testing.B) {
		return func(b *testing.B) {
			workers := 4 * runtime.GOMAXPROCS(0)
			loopCost, lockRatio := *cost, *ratio

			var limits []int
			for _, v := range strings.Split(*concurrency, ",") {
				n, err := strconv.Atoi(v)
				if err == nil && n > 0 {
					limits = append(limits, n)
				}
			}
			if len(limits) == 0 {
				limits = []int{workers}
			}

			// Clamp so the unlocked and locked phases always add up to
			// loopCount. Otherwise an out-of-range -ratio would change the
			// total work per operation and be reported as a ratio outside
			// [0, 1].
			withLock := int(lockRatio * loopCount)
			if withLock < 0 {
				withLock = 0
			}
			if withLock > loopCount {
				withLock = loopCount
			}

			for _, limit := range limits {
				b.Run(fmt.Sprintf("cost=%d,ratio=%0.3f,concurrency=%d", loopCost, float64(withLock)/loopCount, limit),
					func(b *testing.B) {
						b.Run("basic", testcase(workers,
							&lockTracer{Locker: newLock()},
							loopCost, withLock, make(semaphore, limit)))
						b.Run("gosched", testcase(workers,
							&lockYielder{Locker: &lockTracer{Locker: newLock()}},
							loopCost, withLock, make(semaphore, limit)))
					})
			}
		}
	}

	b.Run("sync.Mutex", suite(func() sync.Locker { return new(sync.Mutex) }))
	b.Run("mcs", suite(func() sync.Locker { return new(mcs) }))
	b.Run("chan", suite(func() sync.Locker { return newChannelLock() }))
}

// semaphore is a counting semaphore whose number of slots is the channel's
// buffer capacity. A nil semaphore imposes no limit: Acquire and Release do
// nothing on it.
type semaphore chan struct{}

// Acquire takes one slot, blocking while every slot is taken.
func (s semaphore) Acquire() {
	if s == nil {
		return
	}
	s <- struct{}{}
}

// Release gives back one slot previously taken by Acquire.
func (s semaphore) Release() {
	if s == nil {
		return
	}
	<-s
}

// lockYielder wraps a sync.Locker and calls runtime.Gosched right after every
// Unlock, so the releasing goroutine gives up its processor as soon as the
// lock is free.
type lockYielder struct {
	sync.Locker
}

// Lock acquires the wrapped Locker.
func (y *lockYielder) Lock() {
	y.Locker.Lock()
}

// Unlock releases the wrapped Locker and then yields.
func (y *lockYielder) Unlock() {
	y.Locker.Unlock()
	runtime.Gosched()
}

// lockTracer wraps a sync.Locker. BenchmarkLock places it around every lock
// under test. In this file it has no methods of its own, so Lock and Unlock
// are the embedded Locker's promoted methods and the wrapper adds no behavior.
type lockTracer struct {
	sync.Locker
	// handoffEnd is never set or read in this file.
	// NOTE(review): presumably meant as a hook for tracing lock hand-offs;
	// confirm whether other files in the package use it, otherwise remove it.
	handoffEnd func()
}

// channelLock is a sync.Locker built on a channel with a single buffer slot.
// Lock fills the slot, blocking while it is occupied, and Unlock drains it.
type channelLock chan struct{}

// newChannelLock returns an unlocked channelLock.
func newChannelLock() channelLock {
	ch := make(channelLock, 1)
	return ch
}

// Lock blocks until the slot is free, then claims it.
func (mu channelLock) Lock() {
	mu <- struct{}{}
}

// Unlock frees the slot. On an unlocked channelLock it blocks, waiting for a
// value to be sent, rather than panicking.
func (mu channelLock) Unlock() {
	<-mu
}

// mcs is an MCS-style queue lock. Each acquirer appends its own qnode to a
// linked queue and is released by its predecessor. Waiters do not spin on a
// flag; they block on the sync.WaitGroup in their own node. The zero value is
// an unlocked lock.
type mcs struct {
	// tail is the most recently enqueued node, or nil when no goroutine holds
	// or is waiting for the lock. It is only accessed through sync/atomic.
	tail *qnode

	holder *qnode // protected by the lock itself; the current owner's node, set by Lock and cleared by Unlock
}

// qnode is one acquirer's entry in an mcs queue.
type qnode struct {
	// next is the node enqueued directly behind this one. The successor
	// writes it and this node's owner reads it in unlock, both via
	// sync/atomic.
	next *qnode

	// wg is raised to 1 when the node is created and lowered by the
	// predecessor's unlock to hand the lock over.
	wg sync.WaitGroup
}

// Lock acquires mu. It blocks until every goroutine queued ahead of the caller
// has held and released the lock, then records the caller's queue node as the
// holder.
func (mu *mcs) Lock() {
	mu.holder = mu.lock()
}

// Unlock releases mu. The holder field is cleared before the hand-off because,
// once unlock passes the lock on, the next owner writes holder itself. On an
// unlocked mcs, holder is nil and that nil node is passed to unlock.
func (mu *mcs) Unlock() {
	held := mu.holder
	mu.holder = nil
	mu.unlock(held)
}

// lock enqueues a fresh qnode at the tail of mu's queue and returns once the
// caller owns the lock. If the queue was empty, the lock is taken immediately.
// Otherwise the new node links itself behind its predecessor and blocks until
// the predecessor's unlock calls Done on the node's WaitGroup. The returned
// node must later be passed to unlock.
func (mu *mcs) lock() *qnode {
	self := new(qnode)
	// Raise the count before the node becomes reachable by other goroutines
	// (via the swap and store below), so a predecessor's Done can never run
	// first.
	self.wg.Add(1)
	// Atomically make self the new tail. prev is the old tail, or nil if
	// nobody held or was waiting for the lock.
	prev := (*qnode)(atomic.SwapPointer((*unsafe.Pointer)(unsafe.Pointer(&mu.tail)), unsafe.Pointer(self)))
	if prev != nil {
		atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&prev.next)), unsafe.Pointer(self)) // prev.next = self
		// wait for prev node to unlock us
		self.wg.Wait()
	}
	return self
}

// unlock releases the lock held through node, which should be the qnode
// returned by the matching lock call. A nil node panics when its next field
// is loaded. If a successor has linked itself behind node, ownership passes
// directly to it. Otherwise, if node is still the tail, the lock is marked
// free by resetting tail to nil.
//
// The loop busy-waits, with no yield or sleep, while a successor has already
// swapped itself into tail but has not yet stored node.next.
func (mu *mcs) unlock(node *qnode) {
	for {
		next := (*qnode)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&node.next))))
		if next != nil {
			// known successor, unblock their call to Lock
			next.wg.Done()
			return
		}

		// The CAS succeeds only if nobody has enqueued behind node since it
		// became the tail.
		if atomic.CompareAndSwapPointer((*unsafe.Pointer)(unsafe.Pointer(&mu.tail)), unsafe.Pointer(node), nil) {
			// no known successor, and in fact there's no successor at all
			return
		}

		// successor hasn't finished adding themselves to the queue
	}
}
