package shutdown

import (
	"errors"
	"os"
	"sync"
	"sync/atomic"
	"syscall"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestManager exercises the shutdown Manager: hook registration and
// execution order, error collection and reporting, early and repeated hook
// execution, re-entrant Shutdown calls, background routines started with
// RunUntilComplete, interrupt listening, and WaitForCompletion deadlines.
func TestManager(t *testing.T) {
	expected := errors.New("expected error")

	t.Run("Should shutdown cleanly without hooks", func(t *testing.T) {
		assert.Empty(t, NewManager().Shutdown())
	})

	t.Run("Should shutdown cleanly with clean hooks", func(t *testing.T) {
		m := NewManager()
		m.RegisterHook("test", func() error { return nil })
		m.RegisterHook("test2", func() error { return nil })
		assert.Empty(t, m.Shutdown())
	})

	t.Run("Should count shutdown errors", func(t *testing.T) {
		m := NewManager()
		m.RegisterHook("test", func() error { return expected })
		m.RegisterHook("test2", func() error { return nil })
		assert.Equal(t, []error{expected}, m.Shutdown())
	})

	t.Run("Should reset count when called", func(t *testing.T) {
		m := NewManager()
		m.RegisterHook("test", func() error { return expected })
		assert.Equal(t, []error{expected}, m.Shutdown())
		// A second Shutdown has nothing left to run, so it reports no errors.
		assert.Empty(t, m.Shutdown())
	})

	t.Run("Should report shutdown errors to the Reporter", func(t *testing.T) {
		m := NewManager()
		var errored interface{}
		m.SetOnError(func(key interface{}, err error) { errored = key })
		m.RegisterHook("test", func() error { return expected })
		m.RegisterHook("test2", func() error { return nil })
		assert.Equal(t, []error{expected}, m.Shutdown())
		assert.Equal(t, "test", errored)
	})

	t.Run("Should allow early execution of a hook", func(t *testing.T) {
		m := NewManager()
		var errored interface{}
		m.SetOnError(func(key interface{}, err error) { errored = key })
		m.RegisterHook("test", func() error { return expected })
		assert.Equal(t, expected, m.ExecuteHook("test"))
		assert.Equal(t, "test", errored)
		// The hook already ran, so Shutdown must not run or report it again.
		assert.Empty(t, m.Shutdown())
	})

	t.Run("Should ignore reexecution of a hook", func(t *testing.T) {
		m := NewManager()
		m.RegisterHook("test", func() error { return errors.New("broken") })
		assert.Equal(t, errors.New("broken"), m.ExecuteHook("test"))
		assert.Equal(t, nil, m.ExecuteHook("test"))
		// Re-registering the same key after execution makes it runnable once more.
		m.RegisterHook("test", func() error { return errors.New("broken") })
		assert.Equal(t, errors.New("broken"), m.ExecuteHook("test"))
		assert.Equal(t, nil, m.ExecuteHook("test"))
		assert.Equal(t, []error{}, m.Shutdown())
	})

	t.Run("Should allow reentry during hook execution", func(t *testing.T) {
		m := NewManager()
		// Hooks call Shutdown re-entrantly; the outer call must still return
		// exactly the single hook error.
		m.RegisterHook("test", func() error { m.Shutdown(); return nil })
		m.RegisterHook("test2", func() error { m.Shutdown(); return expected })
		assert.Equal(t, []error{expected}, m.Shutdown())
	})

	t.Run("Should execute hooks in FILO/stack order", func(t *testing.T) {
		m := NewManager()
		count := 0
		m.RegisterHook("test", func() error {
			// Registered before test3, so it must run after it.
			assert.Equal(t, 1, count)
			count = count + 1
			return nil
		})
		m.RegisterHook("test2", func() error { return expected })
		assert.Equal(t, expected, m.ExecuteHook("test2"))
		m.RegisterHook("test3", func() error {
			// Registered after test, so it must run before it.
			assert.Equal(t, 0, count)
			count = count + 1
			return nil
		})
		m.RegisterHook("test4", func() error {
			return expected
		})
		// Only test4's error is expected: test2 was already executed above.
		assert.Equal(t, []error{expected}, m.Shutdown())
		// Both counting hooks must actually have run.
		assert.Equal(t, 2, count)
	})

	t.Run("Should allow registration of background threads", func(t *testing.T) {
		m := NewManager()
		m.RunUntilComplete(func() {})
		assert.Equal(t, []error{}, m.Shutdown())
	})

	t.Run("Should block on background threads", func(t *testing.T) {
		m := NewManager()
		var background sync.WaitGroup
		background.Add(1)
		m.RunUntilComplete(func() { background.Wait() })

		done := make(chan struct{})
		go func() {
			m.Shutdown()
			close(done)
		}()

		// While the background routine is blocked, Shutdown must not return.
		select {
		case <-done:
			assert.Fail(t, "Shutdown returned before background work completed")
		case <-time.After(10 * time.Millisecond):
		}

		// Releasing the routine must let Shutdown finish. A generous timeout
		// avoids flakiness on a loaded machine.
		background.Done()
		select {
		case <-done:
		case <-time.After(time.Second):
			assert.Fail(t, "Shutdown did not return after background work completed")
		}
	})

	t.Run("Should allow waiting for system calls", func(t *testing.T) {
		proc, err := os.FindProcess(os.Getpid())
		require.NoError(t, err, "Unable to get process id for tests, can't test interrupt functionality")

		m := NewManager()
		// The listener reports through a buffered channel rather than calling
		// assert itself. This keeps every t.* call on the test goroutine, even
		// if ListenForInterrupt returns after the subtest ends.
		received := make(chan interface{}, 1)
		go func() { received <- m.ListenForInterrupt() }()

		// Give the listener time to start before checking it is still blocked
		// and before sending the signal.
		time.Sleep(10 * time.Millisecond)
		select {
		case sig := <-received:
			assert.Fail(t, "ListenForInterrupt returned before a signal was sent", "got %v", sig)
		default:
		}

		require.NoError(t, proc.Signal(syscall.SIGUSR2))
		select {
		case sig := <-received:
			assert.Equal(t, syscall.SIGUSR2, sig)
		case <-time.After(time.Second):
			assert.Fail(t, "ListenForInterrupt did not return after SIGUSR2")
		}
	})

	t.Run("Should allow waiting for completion (all threads finish)", func(t *testing.T) {
		m := NewManager()
		var ran int32
		m.RunUntilComplete(func() { atomic.StoreInt32(&ran, 1) })
		m.WaitForCompletion(time.Now().Add(time.Hour))
		// With a distant deadline, the routine must have finished.
		assert.Equal(t, int32(1), atomic.LoadInt32(&ran))
	})

	t.Run("Should allow waiting for completion (time expired)", func(t *testing.T) {
		m := NewManager()
		release := make(chan struct{})
		var finished int32
		// Block on a channel rather than spinning. The routine cannot finish
		// before the deadline, and no CPU core is burned while waiting.
		m.RunUntilComplete(func() {
			<-release
			atomic.StoreInt32(&finished, 1)
		})
		m.WaitForCompletion(time.Now().Add(time.Millisecond))
		assert.Zero(t, atomic.LoadInt32(&finished), "WaitForCompletion should return at the deadline while work is still running")
		close(release)
	})
}
