package comparison

import (
	"encoding/json"
	"fmt"
	"reflect"
	"testing"

	. "github.com/smartystreets/goconvey/convey"
)

// NearlyCorrectPreprocessor is a stateless preprocessor used by the tests in
// this file. Its Preprocess method only acts on result index 0 of the method
// named "test1", where it tolerates Old/New values under Test3.Field that
// differ by at most 2.
type NearlyCorrectPreprocessor struct{}

// TestCompareRequiresPreprocessor decodes two JSON documents that differ only
// in the numbers under Test3.Field and checks what DeepCompare reports for
// them with no preprocessor, with NearlyCorrectPreprocessor under a method
// name it does not handle ("test2"), and under the one it does ("test1").
func TestCompareRequiresPreprocessor(t *testing.T) {
	Convey("When comparing JSON data", t, func() {
		oldDoc := `[{
            "Test1": 2,
            "Test2": "test",
            "Test3": {
                "Field": [
                    1, 2, 3
                ]
            }
        }]`

		newDoc := `[{
            "Test1": 2,
            "Test2": "test",
            "Test3": {
                "Field": [
                    3, 0, 4
                ]
            }
        }]`

		// decode parses one document into a generic slice and asserts that
		// parsing succeeded before the value is used.
		decode := func(doc string) []interface{} {
			parsed := make([]interface{}, 0)
			parseErr := json.Unmarshal([]byte(doc), &parsed)
			So(parseErr, ShouldBeNil)
			return parsed
		}

		oldVals := decode(oldDoc)
		newVals := decode(newDoc)

		oldCmp := ComparablesFromJSONObj(oldVals)
		newCmp := ComparablesFromJSONObj(newVals)

		Convey("No preprocessor fails", func() {
			matched, cmpErr := DeepCompare("test1", nil, oldCmp, newCmp)
			So(cmpErr, ShouldBeNil)
			So(matched, ShouldBeFalse)
		})

		Convey("Preprocessor against wrong method fails", func() {
			// "test2" is not the method name the preprocessor reacts to.
			matched, cmpErr := DeepCompare("test2", &NearlyCorrectPreprocessor{}, oldCmp, newCmp)
			So(cmpErr, ShouldBeNil)
			So(matched, ShouldBeFalse)
		})

		Convey("Correct preprocessor succeeds", func() {
			matched, cmpErr := DeepCompare("test1", &NearlyCorrectPreprocessor{}, oldCmp, newCmp)
			So(cmpErr, ShouldBeNil)
			So(matched, ShouldBeTrue)
		})
	})
}

// Preprocess applies a tolerant comparison to the values found under
// Test3 -> Field -> * in fixture, but only for result index 0 of the method
// named "test1". Every other method/index combination returns (true, nil)
// without touching the fixture.
//
// For each extracted Old/New pair it returns:
//   - (false, nil) if exactly one side is nil, or if the two numbers differ
//     by more than 2 in either direction;
//   - (false, err) if either side is not a float64 (this includes the case
//     where both sides are nil);
//   - (true, nil) once every pair has passed.
//
// NOTE(review): how the caller interprets the boolean, and whether
// ExtractValues has side effects on the fixture, is defined outside this
// file — confirm against PreprocessFixture and DeepCompare.
func (p *NearlyCorrectPreprocessor) Preprocess(methodName string, resultIndex int, fixture *PreprocessFixture) (bool, error) {
	if methodName != "test1" || resultIndex != 0 {
		return true, nil
	}

	// typeName must be nil-safe: reflect.TypeOf(nil) returns a nil Type, and
	// calling a method on it panics. String() is used rather than Name()
	// because Name() is empty for unnamed types such as []interface{}.
	typeName := func(v interface{}) string {
		t := reflect.TypeOf(v)
		if t == nil {
			return "nil"
		}
		return t.String()
	}

	for _, pair := range fixture.ExtractValues("Test3", "Field", "*") {
		// A value present on only one side is a mismatch, not an error.
		if (pair.Old == nil) != (pair.New == nil) {
			return false, nil
		}

		// encoding/json decodes every JSON number into float64 when the
		// target is interface{}, so anything else is unexpected input.
		oldNum, ok := pair.Old.(float64)
		if !ok {
			return false, fmt.Errorf("%v is not a float, instead is %s", pair.Old, typeName(pair.Old))
		}
		newNum, ok := pair.New.(float64)
		if !ok {
			return false, fmt.Errorf("%v is not a float, instead is %s", pair.New, typeName(pair.New))
		}

		if diff := oldNum - newNum; diff < -2 || diff > 2 {
			return false, nil
		}
	}

	return true, nil
}
