package loggers

import (
	"bytes"
	"crypto/rand"
	"encoding/json"
	"fmt"
	"io"
	"io/ioutil"
	"os"
	"sync"
	"testing"
	"time"
)

// slowWriter is an io.Writer decorator whose writes are gated by a latch:
// each Write first waits for one value from latch (or for latch to be
// closed) and only then forwards the bytes to w. The tests in this file use
// it to stall the ring buffer's downstream writer while entries are pushed.
type slowWriter struct {
	w     io.Writer // destination the bytes are forwarded to
	latch chan bool // one receive per Write; closing it releases every Write
}

// Compile-time check that *slowWriter satisfies io.Writer.
var _ io.Writer = (*slowWriter)(nil)

// Write blocks until a value is received from sw.latch (the value itself is
// ignored) or the channel is closed, then passes p to the wrapped writer and
// returns whatever that writer returns.
func (sw *slowWriter) Write(p []byte) (int, error) {
	<-sw.latch
	n, err := sw.w.Write(p)
	return n, err
}

// ExampleRingBufferWriter writes four lines through a two-entry
// RingBufferWriter whose destination is stalled. Only the two most recent
// lines ("c" and "d") reach stdout once the destination is released.
func ExampleRingBufferWriter() {
	// Every write to stdout waits on this gate, so the ring buffer fills up
	// while the lines below are being written.
	gate := make(chan bool)
	stalled := &slowWriter{
		w:     os.Stdout,
		latch: gate,
	}

	// No drop callback is registered.
	var onDropped DroppedItemCallback
	bufferedOut, _ := NewRingBufferWriter(2, stalled, onDropped)
	for _, line := range []string{"a", "b", "c", "d"} {
		fmt.Fprintln(bufferedOut, line)
	}

	// Release the destination, then shut the writer down.
	close(gate)
	bufferedOut.Close()

	// Output:
	// c
	// d
}

// ExampleRingBufferWriter_withJsonLogger shows a RingBufferWriter used as the
// sink of a JSONLogger: the logger's JSON encoder writes into the ring
// buffer, which forwards each encoded line to stdout.
func ExampleRingBufferWriter_withJsonLogger() {
	bufferSize := 2
	wrappedWriter := os.Stdout
	onDropped := func() {
		// do something whenever an item is dropped
	}
	bufferedOut, _ := NewRingBufferWriter(bufferSize, wrappedWriter, onDropped)
	encoder := json.NewEncoder(bufferedOut)
	jsonLogger := &JSONLogger{
		Dest:    encoder,
		OnError: nil,
	}
	// With returns a logger that adds "MyKey":"MyValue" to every entry.
	decoratedLogger := With(jsonLogger, "MyKey", "MyValue")
	decoratedLogger.Log("Some Message!")
	// Close before returning so the pending entry is written to stdout.
	bufferedOut.Close()

	// Output: {"MyKey":"MyValue","msg":"Some Message!"}
}

// TestInsertToRing fills a four-slot ring exactly to capacity and checks that
// poll hands the entries back in insertion (FIFO) order.
func TestInsertToRing(t *testing.T) {
	const capacity = 4
	ring := newConcurrentRingBuffer(capacity, nil)
	for i := 0; i < capacity; i++ {
		ring.push([]byte(fmt.Sprintf("%d", i)))
	}

	for want := 0; want < capacity; want++ {
		got := string(ring.poll())
		if expected := fmt.Sprintf("%d", want); got != expected {
			t.Fatalf("expected %s, got %s", expected, got)
		}
	}
}

// TestWrapAroundCallback pushes five entries into a one-slot ring and checks
// that the drop callback fires once for each of the four overwritten entries.
func TestWrapAroundCallback(t *testing.T) {
	numDropped := 0
	ring := newConcurrentRingBuffer(1, func() {
		numDropped++
	})

	for i := 0; i < 5; i++ {
		ring.push([]byte{0})
	}

	if numDropped != 4 {
		t.Fatalf("expected the dropped items callback to be called 4 times, got %d", numDropped)
	}
}

// TestBufferWrapAround pushes five entries into a four-slot ring and checks
// that the oldest entry ("0") was overwritten, so polling yields 1..4 in order.
func TestBufferWrapAround(t *testing.T) {
	const capacity = 4
	ring := newConcurrentRingBuffer(capacity, nil)
	for i := 0; i <= capacity; i++ {
		ring.push([]byte(fmt.Sprintf("%d", i)))
	}

	// Entry 0 was displaced by entry 4, so the first poll returns 1.
	for want := 1; want <= capacity; want++ {
		expected := fmt.Sprintf("%d", want)
		if got := string(ring.poll()); got != expected {
			t.Fatalf("expected %s, got %s", expected, got)
		}
	}
}

// TestBlockingPoll starts a poll on an empty ring from another goroutine,
// pushes one entry, and expects that poll to return the pushed bytes within
// one second.
func TestBlockingPoll(t *testing.T) {
	// Buffered so the poller can deliver its result and exit even if the
	// test has already timed out and stopped receiving.
	c := make(chan []byte, 1)
	ring := newConcurrentRingBuffer(2, nil)
	go func() {
		c <- ring.poll()
	}()

	input := []byte("foo")
	ring.push(input)
	select {
	case res := <-c:
		if string(res) != "foo" {
			t.Fatalf("expected foo, got %s", string(res))
		}
	case <-time.After(time.Second):
		t.Fatal("timed out waiting for poll to return the pushed entry")
	}
}

func TestCancelBlockingPoll(t *testing.T) {
	wg := &sync.WaitGroup{}
	ring := newConcurrentRingBuffer(2, nil)

	wg.Add(1)
	go func() {
		for {
			d := ring.poll()
			if d != nil {
				t.Fatalf("received an unexpected message")
			}
			wg.Done()
			return
		}
	}()

	ring.cancel()
	wg.Wait()
}

// TestDrainAfterCancel checks that entries pushed before cancel are still
// delivered in FIFO order, and that poll returns nil once the cancelled ring
// has been drained.
func TestDrainAfterCancel(t *testing.T) {
	wg := &sync.WaitGroup{}
	// Sized for both pushed entries so the poller never blocks on send, even
	// if the test fails before consuming them.
	c := make(chan []byte, 2)
	ring := newConcurrentRingBuffer(3, nil)

	wg.Add(1)
	go func() {
		defer wg.Done()
		for {
			d := ring.poll()
			if d == nil {
				return
			}
			c <- d
		}
	}()

	for i := 0; i < 2; i++ {
		ring.push([]byte(fmt.Sprintf("%d", i)))
	}
	ring.cancel()

	for _, expected := range []string{"0", "1"} {
		select {
		case d := <-c:
			if s := string(d); s != expected {
				t.Fatalf("expected %s, got %s", expected, s)
			}
		case <-time.After(time.Second):
			t.Fatalf("timed out waiting for %s", expected)
		}
	}

	// wg.Wait has no timeout of its own; wait for it on a helper goroutine
	// so that a poll that never returns nil fails the test instead of hanging.
	done := make(chan struct{})
	go func() {
		wg.Wait()
		close(done)
	}()
	select {
	case <-done:
	case <-time.After(time.Second):
		t.Fatal("poll did not return nil after the cancelled ring was drained")
	}
}

// TestPushAfterCancel checks that a cancelled ring rejects new entries and
// reports ErrCanceled.
func TestPushAfterCancel(t *testing.T) {
	ring := newConcurrentRingBuffer(1, nil)
	ring.cancel()

	if err := ring.push([]byte("foo")); err != ErrCanceled {
		t.Fatalf("expected %+v, got %+v", ErrCanceled, err)
	}
}

// TestWriteToRingBuffer writes one line through a RingBufferWriter and checks
// that it reaches the wrapped writer unchanged once the writer is closed.
func TestWriteToRingBuffer(t *testing.T) {
	buf := &bytes.Buffer{}

	rbw, err := NewRingBufferWriter(1<<10, buf, nil)
	if err != nil {
		t.Fatalf("NewRingBufferWriter: %v", err)
	}
	fmt.Fprintln(rbw, "hello world")
	rbw.Close()

	// %q makes the expected trailing newline visible in a failure message.
	const expected = "hello world\n"
	if contents := buf.String(); contents != expected {
		t.Fatalf("expected %q, got %q", expected, contents)
	}
}

// TestOverwriteRingBuffer stalls the wrapped writer, writes four lines into a
// two-entry RingBufferWriter, and expects only the two newest lines ("2" and
// "3") to reach the destination once two writes are released.
func TestOverwriteRingBuffer(t *testing.T) {
	buf := &bytes.Buffer{}
	sw := &slowWriter{
		w:     buf,
		latch: make(chan bool),
	}
	rbw, err := NewRingBufferWriter(2, sw, nil)
	if err != nil {
		t.Fatalf("NewRingBufferWriter: %v", err)
	}

	for i := 0; i < 4; i++ {
		fmt.Fprintf(rbw, "%d\n", i)
	}
	// Release exactly two writes, one per surviving entry.
	sw.latch <- false
	sw.latch <- false
	rbw.Close()

	actual := buf.String()
	expected := "2\n3\n"
	if expected != actual {
		t.Fatalf("expected %q, got %q", expected, actual)
	}
}

// TestReadPointerMovesToOldestEntryOnOverwrite writes five lines into a
// two-entry RingBufferWriter whose destination is stalled, and checks that
// reading resumes at the oldest surviving entry ("3"), not at a stale slot.
func TestReadPointerMovesToOldestEntryOnOverwrite(t *testing.T) {
	buf := &bytes.Buffer{}
	sw := &slowWriter{
		w:     buf,
		latch: make(chan bool),
	}
	rbw, err := NewRingBufferWriter(2, sw, nil)
	if err != nil {
		t.Fatalf("NewRingBufferWriter: %v", err)
	}

	for i := 0; i < 5; i++ {
		fmt.Fprintf(rbw, "%d\n", i)
	}
	// Release exactly two writes, one per surviving entry.
	sw.latch <- false
	sw.latch <- false
	rbw.Close()

	actual := buf.String()
	expected := "3\n4\n"
	if expected != actual {
		t.Fatalf("expected %q, got %q", expected, actual)
	}
}

// BenchmarkWriteToRingBuffer measures concurrent Write throughput of a
// RingBufferWriter (1 MiB-entry capacity) that forwards to a discarding sink,
// using a 10 KiB random payload per write.
func BenchmarkWriteToRingBuffer(b *testing.B) {
	rbw, err := NewRingBufferWriter(1<<20, ioutil.Discard, nil)
	if err != nil {
		b.Fatalf("NewRingBufferWriter: %v", err)
	}
	randomMsg := make([]byte, 10<<10) // 10 KiB
	if _, err := rand.Read(randomMsg); err != nil {
		b.Fatalf("rand.Read: %v", err)
	}
	// Report MB/s and allocations per write alongside ns/op.
	b.SetBytes(int64(len(randomMsg)))
	b.ReportAllocs()
	b.ResetTimer()

	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			rbw.Write(randomMsg)
		}
	})

	// Shutting the writer down is not part of the measured write path.
	b.StopTimer()
	rbw.Close()
}
