Skip to content

Commit

Permalink
add support for comparing values in assert.Asserter with IsEqual method
Browse files Browse the repository at this point in the history
  • Loading branch information
adamluzsi committed May 18, 2022
1 parent a24ea1d commit 1797994
Show file tree
Hide file tree
Showing 3 changed files with 208 additions and 34 deletions.
108 changes: 76 additions & 32 deletions assert/Asserter.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ package assert
import (
"errors"
"fmt"
"github.com/adamluzsi/testcase/internal"
"reflect"
"strings"
"testing"

"github.com/adamluzsi/testcase/internal"

"github.com/adamluzsi/testcase/internal/fmterror"
)

Expand Down Expand Up @@ -149,6 +150,14 @@ func (a Asserter) NotPanic(blk func(), msg ...interface{}) {
})
}

type equalable[T any] interface {
IsEqual(oth T) bool
}

type equalableWithError[T any] interface {
IsEqual(oth T) (bool, error)
}

func (a Asserter) Equal(expected, actually interface{}, msg ...interface{}) {
a.TB.Helper()
if a.eq(expected, actually) {
Expand Down Expand Up @@ -193,9 +202,50 @@ func (a Asserter) NotEqual(v, oth interface{}, msg ...interface{}) {
}

func (a Asserter) eq(exp, act interface{}) bool {
if isEqual, ok := a.tryIsEqual(exp, act); ok {
return isEqual
}

return reflect.DeepEqual(exp, act)
}

func (a Asserter) tryIsEqual(exp, act interface{}) (isEqual bool, ok bool) {
defer func() { recover() }()
expRV := reflect.ValueOf(exp)
actRV := reflect.ValueOf(act)

if expRV.Type() != actRV.Type() {
return false, false
}

method := expRV.MethodByName("IsEqual")
methodType := method.Type()

if methodType.NumIn() != 1 {
return false, false
}
if numOut := methodType.NumOut(); !(numOut == 1 || numOut == 2) {
return false, false
}
if methodType.In(0) != actRV.Type() {
return false, false
}

res := method.Call([]reflect.Value{actRV})

switch {
case methodType.NumOut() == 1: // IsEqual(T) (bool)
return res[0].Bool(), true

case methodType.NumOut() == 2: // IsEqual(T) (bool, error)
Must(a.TB).Nil(res[1].Interface())
return res[0].Bool(), true

default:
return false, false
}
}

func (a Asserter) Contain(src, has interface{}, msg ...interface{}) {
a.TB.Helper()
rSrc := reflect.ValueOf(src)
Expand Down Expand Up @@ -575,55 +625,49 @@ func (a Asserter) AnyOf(blk func(a *AnyOf), msg ...interface{}) {
blk(anyOf)
}

// Empty gets whether the specified value is considered empty.
func (a Asserter) Empty(v interface{}, msg ...interface{}) {
a.TB.Helper()

fail := func() {
a.Fn(fmterror.Message{
Method: "Empty",
Cause: "Value was expected to be empty.",
Values: []fmterror.Value{
{Label: "value", Value: v},
},
UserMessage: msg,
})
}

func (a Asserter) isEmpty(v any) bool {
if v == nil {
return
return true
}
rv := reflect.ValueOf(v)
switch rv.Kind() {
case reflect.Chan, reflect.Map, reflect.Slice:
if rv.Len() != 0 {
fail()
}
return rv.Len() == 0

case reflect.Array:
zero := reflect.New(rv.Type()).Elem().Interface()
if !a.eq(zero, v) {
fail()
}
return a.eq(zero, v)

case reflect.Ptr:
if rv.IsNil() {
return
}
if !a.try(func(a Asserter) { a.Empty(rv.Elem().Interface()) }) {
fail()
return true
}
return a.isEmpty(rv.Elem().Interface())

default:
if !a.eq(reflect.Zero(rv.Type()).Interface(), v) {
fail()
}
return a.eq(reflect.Zero(rv.Type()).Interface(), v)
}
}

// Empty gets whether the specified value is considered empty.
func (a Asserter) Empty(v interface{}, msg ...interface{}) {
a.TB.Helper()
if a.isEmpty(v) {
return
}
a.Fn(fmterror.Message{
Method: "Empty",
Cause: "Value was expected to be empty.",
Values: []fmterror.Value{
{Label: "value", Value: v},
},
UserMessage: msg,
})
}

// NotEmpty gets whether the specified value is considered empty.
func (a Asserter) NotEmpty(v interface{}, msg ...interface{}) {
a.TB.Helper()

if !a.try(func(a Asserter) { a.Empty(v) }) {
return
}
Expand All @@ -639,7 +683,7 @@ func (a Asserter) NotEmpty(v interface{}, msg ...interface{}) {

// ErrorIs allows you to assert an error value by an expectation.
// if the implementation of the test subject later changes, and for example, it starts to use wrapping,
// this should not be an issue as the err's error chain is also matched against the expectation.
// this should not be an issue as the IsEqualErr's error chain is also matched against the expectation.
func (a Asserter) ErrorIs(expected, actual error, msg ...interface{}) {
a.TB.Helper()

Expand Down
82 changes: 80 additions & 2 deletions assert/Asserter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@ package assert_test
import (
"errors"
"fmt"
"github.com/adamluzsi/testcase"
"github.com/adamluzsi/testcase/internal"
"reflect"
"strings"
"testing"

"github.com/adamluzsi/testcase"
"github.com/adamluzsi/testcase/internal"

"github.com/adamluzsi/testcase/assert"
"github.com/adamluzsi/testcase/random"
)
Expand Down Expand Up @@ -314,6 +315,54 @@ func TestAsserter_Equal(t *testing.T) {
Actual: []byte("foo"),
IsFailed: true,
},
{
Desc: "when value implements equalable and the two value is equal by IsEqual",
Expected: ExampleEqualable{
relevantUnexportedValue: 42,
IrrelevantExportedField: 42,
},
Actual: ExampleEqualable{
relevantUnexportedValue: 42,
IrrelevantExportedField: 24,
},
IsFailed: false,
},
{
Desc: "when value implements equalable and the two value is not equal by IsEqual",
Expected: ExampleEqualable{
relevantUnexportedValue: 24,
IrrelevantExportedField: 42,
},
Actual: ExampleEqualable{
relevantUnexportedValue: 42,
IrrelevantExportedField: 42,
},
IsFailed: true,
},
{
Desc: "when value implements equalableWithError and the two value is equal by IsEqual",
Expected: ExampleEqualableWithError{
relevantUnexportedValue: 42,
IrrelevantExportedField: 42,
},
Actual: ExampleEqualableWithError{
relevantUnexportedValue: 42,
IrrelevantExportedField: 4242,
},
IsFailed: false,
},
{
Desc: "when value implements equalableWithError and the two value is not equal by IsEqual",
Expected: ExampleEqualableWithError{
relevantUnexportedValue: 42,
IrrelevantExportedField: 42,
},
Actual: ExampleEqualableWithError{
relevantUnexportedValue: 4242,
IrrelevantExportedField: 42,
},
IsFailed: true,
},
//{
// Desc: "when equal function provided",
// Expected: fn1,
Expand Down Expand Up @@ -352,6 +401,35 @@ func TestAsserter_Equal(t *testing.T) {
}
}

func TestAsserter_Equal_equalableWithError_ErrorReturned(t *testing.T) {
t.Log("when value implements equalableWithError and IsEqual returns an error")

expected := ExampleEqualableWithError{
relevantUnexportedValue: 42,
IrrelevantExportedField: 42,
IsEqualErr: errors.New("boom"),
}

actual := ExampleEqualableWithError{
relevantUnexportedValue: 42,
IrrelevantExportedField: 42,
}

stub := &testcase.StubTB{}

internal.Recover(func() {
a := assert.Asserter{
TB: stub,
Fn: stub.Fatal,
}

a.Equal(expected, actual)
})
if !stub.IsFailed {
t.Fatal("expected that testing.TB is failed because the returned error")
}
}

func TestAsserter_NotEqual(t *testing.T) {
type TestCase struct {
Desc string
Expand Down
52 changes: 52 additions & 0 deletions assert/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,3 +248,55 @@ func ExampleAsserter_ErrorIs() {
assert.Must(tb).ErrorIs(errors.New("boom"), actualErr) // passes for equality
assert.Must(tb).ErrorIs(errors.New("boom"), fmt.Errorf("wrapped error: %w", actualErr)) // passes for wrapped errors
}

type ExampleEqualable struct {
IrrelevantExportedField int
relevantUnexportedValue int
}

func (es ExampleEqualable) IsEqual(oth ExampleEqualable) bool {
return es.relevantUnexportedValue == oth.relevantUnexportedValue
}

func ExampleAsserter_Equal_isEqualFunctionUsedForComparison() {
var tb testing.TB

expected := ExampleEqualable{
IrrelevantExportedField: 42,
relevantUnexportedValue: 24,
}

actual := ExampleEqualable{
IrrelevantExportedField: 4242,
relevantUnexportedValue: 24,
}

assert.Must(tb).Equal(expected, actual) // passes as by IsEqual terms the two value is equal
}

type ExampleEqualableWithError struct {
IrrelevantExportedField int
relevantUnexportedValue int
IsEqualErr error
}

func (es ExampleEqualableWithError) IsEqual(oth ExampleEqualableWithError) (bool, error) {
return es.relevantUnexportedValue == oth.relevantUnexportedValue, es.IsEqualErr
}

func ExampleAsserter_Equal_isEqualFunctionThatSupportsErrorReturning() {
var tb testing.TB

expected := ExampleEqualableWithError{
IrrelevantExportedField: 42,
relevantUnexportedValue: 24,
IsEqualErr: errors.New("sadly something went wrong"),
}

actual := ExampleEqualableWithError{
IrrelevantExportedField: 42,
relevantUnexportedValue: 24,
}

assert.Must(tb).Equal(expected, actual) // fails because the error returned from the IsEqual function.
}

0 comments on commit 1797994

Please sign in to comment.