package twirp

import (
	"reflect"
	"testing"

	"golang.org/x/net/context"
)

// TestChainHooks verifies that ChainHooks invokes each component hook's
// handler in the order the hooks were passed. It also checks that the chained
// hooks do not panic when some (or all) component hooks leave a handler nil.
func TestChainHooks(t *testing.T) {
	// ctxKey is an unexported key type. Values stored by these hooks therefore
	// cannot collide with context keys set by any other package; built-in
	// types such as string should not be used as context keys.
	type ctxKey string
	const key ctxKey = "key"

	var (
		hook1 = new(ServerHooks)
		hook2 = new(ServerHooks)
		hook3 = new(ServerHooks)
	)

	// Every hook handles RequestReceived. Each handler appends its own name to
	// the slice stored under key, so the final value records the invocation
	// order.
	hook1.RequestReceived = func(ctx context.Context) context.Context {
		return context.WithValue(ctx, key, []string{"hook1"})
	}
	hook2.RequestReceived = func(ctx context.Context) context.Context {
		v := ctx.Value(key).([]string)
		return context.WithValue(ctx, key, append(v, "hook2"))
	}
	hook3.RequestReceived = func(ctx context.Context) context.Context {
		v := ctx.Value(key).([]string)
		return context.WithValue(ctx, key, append(v, "hook3"))
	}

	// Only the first hook handles RequestRouted.
	hook1.RequestRouted = func(ctx context.Context) context.Context {
		return context.WithValue(ctx, key, []string{"hook1"})
	}

	// Only the second hook handles ResponsePrepared.
	hook2.ResponsePrepared = func(ctx context.Context) context.Context {
		return context.WithValue(ctx, key, []string{"hook2"})
	}

	// No hook handles ResponseSent.

	chain := ChainHooks(hook1, hook2, hook3)

	ctx := context.Background()

	// When all three chained hooks have a handler, all should be called in order.
	want := []string{"hook1", "hook2", "hook3"}
	have := chain.RequestReceived(ctx).Value(key)
	if !reflect.DeepEqual(want, have) {
		t.Errorf("RequestReceived chain has unexpected ctx, have=%v, want=%v", have, want)
	}

	// When only the first chained hook has a handler, it should be called, and
	// there should be no panic.
	want = []string{"hook1"}
	have = chain.RequestRouted(ctx).Value(key)
	if !reflect.DeepEqual(want, have) {
		t.Errorf("RequestRouted chain has unexpected ctx, have=%v, want=%v", have, want)
	}

	// When only the second chained hook has a handler, it should be called, and
	// there should be no panic.
	want = []string{"hook2"}
	have = chain.ResponsePrepared(ctx).Value(key)
	if !reflect.DeepEqual(want, have) {
		t.Errorf("ResponsePrepared chain has unexpected ctx, have=%v, want=%v", have, want)
	}

	// When none of the chained hooks has a handler there should be no panic,
	// and the context should come back without a value under key.
	have = chain.ResponseSent(ctx).Value(key)
	if have != nil {
		t.Errorf("ResponseSent chain has unexpected ctx, have=%v, want=%v", have, nil)
	}
}
