package hystrixhelper

import (
	"bytes"
	"io/ioutil"
	"net/http"
	"testing"

	"github.com/afex/hystrix-go/hystrix"
	"github.com/pkg/errors"

	. "github.com/smartystreets/goconvey/convey"
)

// TestHystrixRoundTripper drives HystrixRoundTripWrapper.RoundTrip against a
// mocked downstream and checks, per scenario, the returned response/error and
// how many times the downstream was invoked:
//   - a 200 response is returned after a single call;
//   - a 500 response is returned after RetryCount+1 calls;
//   - a downstream error is surfaced after RetryCount+1 calls;
//   - after repeated 500 responses a further call returns an error and no
//     response (the scenario name attributes this to the hystrix circuit opening).
func TestHystrixRoundTripper(t *testing.T) {
	Convey("Test RoundTrip", t, func() {
		req := generateHTTPRequest([]byte("req"))

		commandName := "test"
		hystrix.ConfigureCommand(commandName, hystrix.CommandConfig{
			Timeout:                1000,
			MaxConcurrentRequests:  100,
			RequestVolumeThreshold: 4,
		})

		Convey("when downstream service returns ok", func() {
			wrapper := HystrixRoundTripWrapper{
				Next:       &responseMock{statusCode: http.StatusOK},
				Name:       commandName,
				RetryCount: 2,
			}

			resp, err := wrapper.RoundTrip(req)

			So(err, ShouldBeNil)
			So(resp, ShouldNotBeNil)
			So(resp.StatusCode, ShouldEqual, http.StatusOK)
			// A successful first attempt must not be retried.
			So(wrapper.Next.(*responseMock).callCount, ShouldEqual, 1)
		})

		Convey("when downstream service returns internal server error", func() {
			wrapper := HystrixRoundTripWrapper{
				Next:       &responseMock{statusCode: http.StatusInternalServerError},
				Name:       commandName,
				RetryCount: 2,
			}

			resp, err := wrapper.RoundTrip(req)

			So(err, ShouldBeNil)
			So(resp, ShouldNotBeNil)
			So(resp.StatusCode, ShouldEqual, http.StatusInternalServerError)
			// One initial attempt plus two retries.
			So(wrapper.Next.(*responseMock).callCount, ShouldEqual, 3)
		})

		Convey("when downstream service returns error", func() {
			wrapper := HystrixRoundTripWrapper{
				Next:       &responseMock{},
				Name:       commandName,
				RetryCount: 3,
			}

			resp, err := wrapper.RoundTrip(req)

			So(err, ShouldNotBeNil)
			So(resp, ShouldBeNil)
			// One initial attempt plus three retries.
			So(wrapper.Next.(*responseMock).callCount, ShouldEqual, 4)
		})

		Convey("when downstream service returns internal server errors that opens the hystrix circuit", func() {
			wrapper := HystrixRoundTripWrapper{
				Next:       &responseMock{statusCode: http.StatusInternalServerError},
				Name:       commandName,
				RetryCount: 2,
			}

			resp, err := wrapper.RoundTrip(req)

			So(err, ShouldBeNil)
			So(resp, ShouldNotBeNil)
			So(resp.StatusCode, ShouldEqual, http.StatusInternalServerError)
			So(wrapper.Next.(*responseMock).callCount, ShouldEqual, 3)

			// Keep hitting the failing downstream so the circuit trips; the
			// individual results of these calls are deliberately ignored.
			for i := 0; i < 4; i++ {
				_, _ = wrapper.RoundTrip(req)
			}

			resp, err = wrapper.RoundTrip(req)

			So(err, ShouldNotBeNil)
			So(resp, ShouldBeNil)
		})
	})
}

// responseMock is a test double for the downstream transport: its RoundTrip
// method returns a canned response or error and records how often it ran.
type responseMock struct {
	statusCode int // status of the canned response; 0 makes RoundTrip return an error instead
	callCount  int // incremented by one on every RoundTrip call
}

// RoundTrip records the call and answers with canned data; req is ignored.
// A non-zero statusCode yields a response with that status and the body
// "resp"; a zero statusCode yields a nil response and an error.
func (m *responseMock) RoundTrip(req *http.Request) (*http.Response, error) {
	m.callCount++

	if m.statusCode != 0 {
		return generateHTTPResponse(m.statusCode, []byte("resp")), nil
	}
	return nil, errors.New("error")
}

func generateHTTPRequest(bodyBytes []byte) *http.Request {
	reqBody := ioutil.NopCloser(bytes.NewBuffer(bodyBytes))
	req, err := http.NewRequest("GET", "myurl", reqBody)
	So(err, ShouldBeNil)
	return req
}

// generateHTTPResponse returns a minimal *http.Response for tests: only
// StatusCode and Body are populated. Body yields bodyBytes and its Close is a
// no-op.
func generateHTTPResponse(statusCode int, bodyBytes []byte) *http.Response {
	return &http.Response{
		StatusCode: statusCode,
		Body:       ioutil.NopCloser(bytes.NewReader(bodyBytes)),
	}
}
