diff --git a/assert/assert.go b/assert/assert.go index ba557d1..ec70c33 100644 --- a/assert/assert.go +++ b/assert/assert.go @@ -2,6 +2,7 @@ package assert import ( + "errors" "fmt" "reflect" "testing" @@ -47,6 +48,23 @@ func Error(t *testing.T, err error, expectedError string, errorMessage ...string fail(t, msg, errorMessage...) } +// ErrorIs asserts that a function returned an error that matches the specified error. +func ErrorIs(t *testing.T, err, expectedError error, errorMessage ...string) { + t.Helper() + if err == nil { + msg := fmt.Sprintf("Error not returned: \nexpected: %v\nactual : nil", expectedError) + fail(t, msg, errorMessage...) + return + } + + if errors.Is(err, expectedError) { + return + } + + msg := fmt.Sprintf("Error not equal: \nexpected: %v\nactual : %v", expectedError, err) + fail(t, msg, errorMessage...) +} + // True asserts that the specified value is true. func True(t *testing.T, value bool, errorMessage ...string) { t.Helper() diff --git a/assert/assert_test.go b/assert/assert_test.go index 9419042..4802c1c 100644 --- a/assert/assert_test.go +++ b/assert/assert_test.go @@ -1,6 +1,37 @@ package assert -import "testing" +import ( + "errors" + "fmt" + "testing" +) + +func TestEqual(t *testing.T) { + Equal(t, 1, 1) +} + +func TestNoError(t *testing.T) { + NoError(t, nil) +} + +func TestError(t *testing.T) { + err := errors.New("error text") + Error(t, err, err.Error()) +} + +func TestErrorIs(t *testing.T) { + errTest := errors.New("error") + err := fmt.Errorf("error: %w", errTest) + ErrorIs(t, err, errTest) +} + +func TestTrue(t *testing.T) { + True(t, true) +} + +func TestFalse(t *testing.T) { + False(t, false) +} func TestInterfaceNilEqual(t *testing.T) { var values []int