From 1797994ad8db94665ec9838d978b35da57a2ee7f Mon Sep 17 00:00:00 2001 From: Adam Luzsi Date: Wed, 18 May 2022 22:16:23 +0200 Subject: [PATCH] add support for comparing values in assert.Asserter with IsEqual method --- assert/Asserter.go | 108 ++++++++++++++++++++++++++++------------ assert/Asserter_test.go | 82 +++++++++++++++++++++++++++++- assert/example_test.go | 52 +++++++++++++++++++ 3 files changed, 208 insertions(+), 34 deletions(-) diff --git a/assert/Asserter.go b/assert/Asserter.go index 4226fb1..c722802 100644 --- a/assert/Asserter.go +++ b/assert/Asserter.go @@ -3,11 +3,12 @@ package assert import ( "errors" "fmt" - "github.com/adamluzsi/testcase/internal" "reflect" "strings" "testing" + "github.com/adamluzsi/testcase/internal" + "github.com/adamluzsi/testcase/internal/fmterror" ) @@ -149,6 +150,14 @@ func (a Asserter) NotPanic(blk func(), msg ...interface{}) { }) } +type equalable[T any] interface { + IsEqual(oth T) bool +} + +type equalableWithError[T any] interface { + IsEqual(oth T) (bool, error) +} + func (a Asserter) Equal(expected, actually interface{}, msg ...interface{}) { a.TB.Helper() if a.eq(expected, actually) { @@ -193,9 +202,50 @@ func (a Asserter) NotEqual(v, oth interface{}, msg ...interface{}) { } func (a Asserter) eq(exp, act interface{}) bool { + if isEqual, ok := a.tryIsEqual(exp, act); ok { + return isEqual + } + return reflect.DeepEqual(exp, act) } +func (a Asserter) tryIsEqual(exp, act interface{}) (isEqual bool, ok bool) { + defer func() { recover() }() + expRV := reflect.ValueOf(exp) + actRV := reflect.ValueOf(act) + + if expRV.Type() != actRV.Type() { + return false, false + } + + method := expRV.MethodByName("IsEqual") + methodType := method.Type() + + if methodType.NumIn() != 1 { + return false, false + } + if numOut := methodType.NumOut(); !(numOut == 1 || numOut == 2) { + return false, false + } + if methodType.In(0) != actRV.Type() { + return false, false + } + + res := method.Call([]reflect.Value{actRV}) + + switch { + case methodType.NumOut() == 1: // IsEqual(T) (bool) + return res[0].Bool(), true + + case methodType.NumOut() == 2: // IsEqual(T) (bool, error) + Must(a.TB).Nil(res[1].Interface()) + return res[0].Bool(), true + + default: + return false, false + } +} + func (a Asserter) Contain(src, has interface{}, msg ...interface{}) { a.TB.Helper() rSrc := reflect.ValueOf(src) @@ -575,55 +625,49 @@ func (a Asserter) AnyOf(blk func(a *AnyOf), msg ...interface{}) { blk(anyOf) } -// Empty gets whether the specified value is considered empty. -func (a Asserter) Empty(v interface{}, msg ...interface{}) { - a.TB.Helper() - - fail := func() { - a.Fn(fmterror.Message{ - Method: "Empty", - Cause: "Value was expected to be empty.", - Values: []fmterror.Value{ - {Label: "value", Value: v}, - }, - UserMessage: msg, - }) - } - +func (a Asserter) isEmpty(v any) bool { if v == nil { - return + return true } rv := reflect.ValueOf(v) switch rv.Kind() { case reflect.Chan, reflect.Map, reflect.Slice: - if rv.Len() != 0 { - fail() - } + return rv.Len() == 0 + case reflect.Array: zero := reflect.New(rv.Type()).Elem().Interface() - if !a.eq(zero, v) { - fail() - } + return a.eq(zero, v) case reflect.Ptr: if rv.IsNil() { - return - } - if !a.try(func(a Asserter) { a.Empty(rv.Elem().Interface()) }) { - fail() + return true } + return a.isEmpty(rv.Elem().Interface()) default: - if !a.eq(reflect.Zero(rv.Type()).Interface(), v) { - fail() - } + return a.eq(reflect.Zero(rv.Type()).Interface(), v) } } +// Empty gets whether the specified value is considered empty. +func (a Asserter) Empty(v interface{}, msg ...interface{}) { + a.TB.Helper() + if a.isEmpty(v) { + return + } + a.Fn(fmterror.Message{ + Method: "Empty", + Cause: "Value was expected to be empty.", + Values: []fmterror.Value{ + {Label: "value", Value: v}, + }, + UserMessage: msg, + }) +} + // NotEmpty gets whether the specified value is considered empty. func (a Asserter) NotEmpty(v interface{}, msg ...interface{}) { a.TB.Helper() - if !a.try(func(a Asserter) { a.Empty(v) }) { return } @@ -639,7 +683,7 @@ func (a Asserter) NotEmpty(v interface{}, msg ...interface{}) { // ErrorIs allows you to assert an error value by an expectation. // if the implementation of the test subject later changes, and for example, it starts to use wrapping, -// this should not be an issue as the err's error chain is also matched against the expectation. +// this should not be an issue as the IsEqualErr's error chain is also matched against the expectation. func (a Asserter) ErrorIs(expected, actual error, msg ...interface{}) { a.TB.Helper() diff --git a/assert/Asserter_test.go b/assert/Asserter_test.go index 075db9c..d07a02a 100644 --- a/assert/Asserter_test.go +++ b/assert/Asserter_test.go @@ -3,12 +3,13 @@ package assert_test import ( "errors" "fmt" - "github.com/adamluzsi/testcase" - "github.com/adamluzsi/testcase/internal" "reflect" "strings" "testing" + "github.com/adamluzsi/testcase" + "github.com/adamluzsi/testcase/internal" + "github.com/adamluzsi/testcase/assert" "github.com/adamluzsi/testcase/random" ) @@ -314,6 +315,54 @@ func TestAsserter_Equal(t *testing.T) { Actual: []byte("foo"), IsFailed: true, }, + { + Desc: "when value implements equalable and the two value is equal by IsEqual", + Expected: ExampleEqualable{ + relevantUnexportedValue: 42, + IrrelevantExportedField: 42, + }, + Actual: ExampleEqualable{ + relevantUnexportedValue: 42, + IrrelevantExportedField: 24, + }, + IsFailed: false, + }, + { + Desc: "when value implements equalable and the two value is not equal by IsEqual", + Expected: ExampleEqualable{ + relevantUnexportedValue: 24, + IrrelevantExportedField: 42, + }, + Actual: ExampleEqualable{ + relevantUnexportedValue: 42, + IrrelevantExportedField: 42, + }, + IsFailed: true, + }, + { + Desc: "when value implements equalableWithError and the two value is equal by IsEqual", + Expected: ExampleEqualableWithError{ + relevantUnexportedValue: 42, + IrrelevantExportedField: 42, + }, + Actual: ExampleEqualableWithError{ + relevantUnexportedValue: 42, + IrrelevantExportedField: 4242, + }, + IsFailed: false, + }, + { + Desc: "when value implements equalableWithError and the two value is not equal by IsEqual", + Expected: ExampleEqualableWithError{ + relevantUnexportedValue: 42, + IrrelevantExportedField: 42, + }, + Actual: ExampleEqualableWithError{ + relevantUnexportedValue: 4242, + IrrelevantExportedField: 42, + }, + IsFailed: true, + }, //{ // Desc: "when equal function provided", // Expected: fn1, @@ -352,6 +401,35 @@ func TestAsserter_Equal(t *testing.T) { } } +func TestAsserter_Equal_equalableWithError_ErrorReturned(t *testing.T) { + t.Log("when value implements equalableWithError and IsEqual returns an error") + + expected := ExampleEqualableWithError{ + relevantUnexportedValue: 42, + IrrelevantExportedField: 42, + IsEqualErr: errors.New("boom"), + } + + actual := ExampleEqualableWithError{ + relevantUnexportedValue: 42, + IrrelevantExportedField: 42, + } + + stub := &testcase.StubTB{} + + internal.Recover(func() { + a := assert.Asserter{ + TB: stub, + Fn: stub.Fatal, + } + + a.Equal(expected, actual) + }) + if !stub.IsFailed { + t.Fatal("expected that testing.TB is failed because the returned error") + } +} + func TestAsserter_NotEqual(t *testing.T) { type TestCase struct { Desc string diff --git a/assert/example_test.go b/assert/example_test.go index 7448339..c485f35 100644 --- a/assert/example_test.go +++ b/assert/example_test.go @@ -248,3 +248,55 @@ func ExampleAsserter_ErrorIs() { assert.Must(tb).ErrorIs(errors.New("boom"), actualErr) // passes for equality assert.Must(tb).ErrorIs(errors.New("boom"), fmt.Errorf("wrapped error: %w", actualErr)) // passes for wrapped errors } + +type ExampleEqualable struct { + IrrelevantExportedField int + relevantUnexportedValue int +} + +func (es ExampleEqualable) IsEqual(oth ExampleEqualable) bool { + return es.relevantUnexportedValue == oth.relevantUnexportedValue +} + +func ExampleAsserter_Equal_isEqualFunctionUsedForComparison() { + var tb testing.TB + + expected := ExampleEqualable{ + IrrelevantExportedField: 42, + relevantUnexportedValue: 24, + } + + actual := ExampleEqualable{ + IrrelevantExportedField: 4242, + relevantUnexportedValue: 24, + } + + assert.Must(tb).Equal(expected, actual) // passes as by IsEqual terms the two value is equal +} + +type ExampleEqualableWithError struct { + IrrelevantExportedField int + relevantUnexportedValue int + IsEqualErr error +} + +func (es ExampleEqualableWithError) IsEqual(oth ExampleEqualableWithError) (bool, error) { + return es.relevantUnexportedValue == oth.relevantUnexportedValue, es.IsEqualErr +} + +func ExampleAsserter_Equal_isEqualFunctionThatSupportsErrorReturning() { + var tb testing.TB + + expected := ExampleEqualableWithError{ + IrrelevantExportedField: 42, + relevantUnexportedValue: 24, + IsEqualErr: errors.New("sadly something went wrong"), + } + + actual := ExampleEqualableWithError{ + IrrelevantExportedField: 42, + relevantUnexportedValue: 24, + } + + assert.Must(tb).Equal(expected, actual) // fails because the error returned from the IsEqual function. +}