Skip to content

Commit

Permalink
assert.ErrorAs: log target type
Browse files Browse the repository at this point in the history
  • Loading branch information
craig65535 committed Dec 31, 2024
1 parent 7c367bb commit ca6698b
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 37 deletions.
49 changes: 37 additions & 12 deletions assert/assertions.go
Original file line number Diff line number Diff line change
Expand Up @@ -2102,7 +2102,7 @@ func ErrorIs(t TestingT, err, target error, msgAndArgs ...interface{}) bool {
expectedText = target.Error()
}

chain := buildErrorChainString(err)
chain := buildErrorChainString(err, false)

return Fail(t, fmt.Sprintf("Target error should be in err chain:\n"+
"expected: %q\n"+
Expand All @@ -2125,7 +2125,7 @@ func NotErrorIs(t TestingT, err, target error, msgAndArgs ...interface{}) bool {
expectedText = target.Error()
}

chain := buildErrorChainString(err)
chain := buildErrorChainString(err, false)

return Fail(t, fmt.Sprintf("Target error should not be in err chain:\n"+
"found: %q\n"+
Expand All @@ -2143,10 +2143,10 @@ func ErrorAs(t TestingT, err error, target interface{}, msgAndArgs ...interface{
return true
}

chain := buildErrorChainString(err)
chain := buildErrorChainString(err, true)

return Fail(t, fmt.Sprintf("Should be in error chain:\n"+
"expected: %q\n"+
"expected: %T\n"+
"in chain: %s", target, chain,
), msgAndArgs...)
}
Expand All @@ -2161,24 +2161,49 @@ func NotErrorAs(t TestingT, err error, target interface{}, msgAndArgs ...interfa
return true
}

chain := buildErrorChainString(err)
chain := buildErrorChainString(err, true)

return Fail(t, fmt.Sprintf("Target error should not be in err chain:\n"+
"found: %q\n"+
"found: %T\n"+
"in chain: %s", target, chain,
), msgAndArgs...)
}

func buildErrorChainString(err error) string {
func unwrapAll(err error) (errs []error) {
errs = append(errs, err)
switch x := err.(type) {
case interface{ Unwrap() error }:
err = x.Unwrap()
if err == nil {
return
}
errs = append(errs, unwrapAll(err)...)
case interface{ Unwrap() []error }:
for _, err := range x.Unwrap() {
errs = append(errs, unwrapAll(err)...)
}
return
default:
return
}
return
}

func buildErrorChainString(err error, withType bool) string {
if err == nil {
return ""
}

e := errors.Unwrap(err)
chain := fmt.Sprintf("%q", err.Error())
for e != nil {
chain += fmt.Sprintf("\n\t%q", e.Error())
e = errors.Unwrap(e)
var chain string
errs := unwrapAll(err)
for i := range errs {
if i != 0 {
chain += "\n\t"
}
chain += fmt.Sprintf("%q", errs[i].Error())
if withType {
chain += fmt.Sprintf(" (%T)", errs[i])
}
}
return chain
}
87 changes: 62 additions & 25 deletions assert/assertions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3175,11 +3175,13 @@ func parseLabeledOutput(output string) []labeledContent {
}

type captureTestingT struct {
msg string
failed bool
msg string
}

func (ctt *captureTestingT) Errorf(format string, args ...interface{}) {
ctt.msg = fmt.Sprintf(format, args...)
ctt.failed = true
}

func (ctt *captureTestingT) checkResultAndErrMsg(t *testing.T, expectedRes, res bool, expectedErrMsg string) {
Expand All @@ -3188,6 +3190,9 @@ func (ctt *captureTestingT) checkResultAndErrMsg(t *testing.T, expectedRes, res
t.Errorf("Should return %t", expectedRes)
return
}
if res == ctt.failed {
t.Errorf("The test result (%t) should be reflected in the testing.T type (%t)", res, !ctt.failed)
}
contents := parseLabeledOutput(ctt.msg)
if res == true {
if contents != nil {
Expand Down Expand Up @@ -3348,50 +3353,82 @@ func TestNotErrorIs(t *testing.T) {

func TestErrorAs(t *testing.T) {
tests := []struct {
err error
result bool
err error
result bool
resultErrMsg string
}{
{fmt.Errorf("wrap: %w", &customError{}), true},
{io.EOF, false},
{nil, false},
{
err: fmt.Errorf("wrap: %w", &customError{}),
result: true,
},
{
err: io.EOF,
result: false,
resultErrMsg: "" +
"Should be in error chain:\n" +
"expected: **assert.customError\n" +
"in chain: \"EOF\" (*errors.errorString)\n",
},
{
err: nil,
result: false,
resultErrMsg: "" +
"Should be in error chain:\n" +
"expected: **assert.customError\n" +
"in chain: \n",
},
{
err: fmt.Errorf("abc: %w", errors.New("def")),
result: false,
resultErrMsg: "" +
"Should be in error chain:\n" +
"expected: **assert.customError\n" +
"in chain: \"abc: def\" (*fmt.wrapError)\n" +
"\t\"def\" (*errors.errorString)\n",
},
}
for _, tt := range tests {
tt := tt
var target *customError
t.Run(fmt.Sprintf("ErrorAs(%#v,%#v)", tt.err, target), func(t *testing.T) {
mockT := new(testing.T)
mockT := new(captureTestingT)
res := ErrorAs(mockT, tt.err, &target)
if res != tt.result {
t.Errorf("ErrorAs(%#v,%#v) should return %t", tt.err, target, tt.result)
}
if res == mockT.Failed() {
t.Errorf("The test result (%t) should be reflected in the testing.T type (%t)", res, !mockT.Failed())
}
mockT.checkResultAndErrMsg(t, tt.result, res, tt.resultErrMsg)
})
}
}

func TestNotErrorAs(t *testing.T) {
tests := []struct {
err error
result bool
err error
result bool
resultErrMsg string
}{
{fmt.Errorf("wrap: %w", &customError{}), false},
{io.EOF, true},
{nil, true},
{
err: fmt.Errorf("wrap: %w", &customError{}),
result: false,
resultErrMsg: "" +
"Target error should not be in err chain:\n" +
"found: **assert.customError\n" +
"in chain: \"wrap: fail\" (*fmt.wrapError)\n" +
"\t\"fail\" (*assert.customError)\n",
},
{
err: io.EOF,
result: true,
},
{
err: nil,
result: true,
},
}
for _, tt := range tests {
tt := tt
var target *customError
t.Run(fmt.Sprintf("NotErrorAs(%#v,%#v)", tt.err, target), func(t *testing.T) {
mockT := new(testing.T)
mockT := new(captureTestingT)
res := NotErrorAs(mockT, tt.err, &target)
if res != tt.result {
t.Errorf("NotErrorAs(%#v,%#v) should not return %t", tt.err, target, tt.result)
}
if res == mockT.Failed() {
t.Errorf("The test result (%t) should be reflected in the testing.T type (%t)", res, !mockT.Failed())
}
mockT.checkResultAndErrMsg(t, tt.result, res, tt.resultErrMsg)
})
}
}

0 comments on commit ca6698b

Please sign in to comment.