Skip to content

Commit

Permalink
Merge pull request #2 from nicheinc/feature/must
Browse files Browse the repository at this point in the history
Add "must" functions for tersely asserting function success
  • Loading branch information
jonathansharman authored Oct 8, 2024
2 parents c0e69ea + e5884d4 commit 6d71217
Show file tree
Hide file tree
Showing 4 changed files with 225 additions and 0 deletions.
53 changes: 53 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,56 @@ func ErrorIsAll(expected ...error) ErrorCheck {
}
}
}

// Must can be used on a (value, error) pair to either get the value or
// immediately fail the test if the error is non-nil. The T parameter is
// curried, rather than passed as a third argument, so that (value, error)
// function return values can be passed to Must directly, without assigning them
// to intermediate variables.
//
// See also Must0, Must2, and Must3 for working with functions of other coarity.
//
// bytes := expect.Must(io.ReadAll(reader))(t)
func Must[V any](value V, err error) func(T) V {
return func(t T) V {
t.Helper()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
return value
}
}

// Must0 is similar to Must but for functions returning just an error, without a
// value.
func Must0(err error) func(T) {
return func(t T) {
t.Helper()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
}
}

// Must2 is similar to Must but for functions returning two values and an error.
func Must2[V1 any, V2 any](value1 V1, value2 V2, err error) func(T) (V1, V2) {
return func(t T) (V1, V2) {
t.Helper()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
return value1, value2
}
}

// Must3 is similar to Must but for functions returning three values and an
// error.
func Must3[V1 any, V2 any, V3 any](value1 V1, value2 V2, value3 V3, err error) func(T) (V1, V2, V3) {
return func(t T) (V1, V2, V3) {
t.Helper()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
return value1, value2, value3
}
}
141 changes: 141 additions & 0 deletions errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,146 @@ func TestErrors(t *testing.T) {
}
}

func TestMust(t *testing.T) {
type testCase struct {
f func() (bool, error)
expectedValue bool
expectedFatalCalls int32
}
run := func(name string, testCase testCase) {
t.Helper()
t.Run(name, func(t *testing.T) {
t.Helper()
tMock := newTMock()
value := Must(testCase.f())(tMock)
Equal(t, value, testCase.expectedValue)
Equal(t, testCase.expectedFatalCalls, tMock.FatalfCalled)
})
}

run("Error", testCase{
f: func() (bool, error) {
return false, ErrTest
},
expectedValue: false,
expectedFatalCalls: 1,
})
run("Success", testCase{
f: func() (bool, error) {
return true, nil
},
expectedValue: true,
expectedFatalCalls: 0,
})
}

func TestMust0(t *testing.T) {
type testCase struct {
f func() error
expectedFatalCalls int32
}
run := func(name string, testCase testCase) {
t.Helper()
t.Run(name, func(t *testing.T) {
t.Helper()
tMock := newTMock()
Must0(testCase.f())(tMock)
Equal(t, testCase.expectedFatalCalls, tMock.FatalfCalled)
})
}

run("Error", testCase{
f: func() error {
return ErrTest
},
expectedFatalCalls: 1,
})
run("Success", testCase{
f: func() error {
return nil
},
expectedFatalCalls: 0,
})
}

func TestMust2(t *testing.T) {
type testCase struct {
f func() (bool, bool, error)
expectedValue1 bool
expectedValue2 bool
expectedFatalCalls int32
}
run := func(name string, testCase testCase) {
t.Helper()
t.Run(name, func(t *testing.T) {
t.Helper()
tMock := newTMock()
value1, value2 := Must2(testCase.f())(tMock)
Equal(t, value1, testCase.expectedValue1)
Equal(t, value2, testCase.expectedValue2)
Equal(t, testCase.expectedFatalCalls, tMock.FatalfCalled)
})
}

run("Error", testCase{
f: func() (bool, bool, error) {
return false, false, ErrTest
},
expectedValue1: false,
expectedValue2: false,
expectedFatalCalls: 1,
})
run("Success", testCase{
f: func() (bool, bool, error) {
return true, true, nil
},
expectedValue1: true,
expectedValue2: true,
expectedFatalCalls: 0,
})
}

func TestMust3(t *testing.T) {
type testCase struct {
f func() (bool, bool, bool, error)
expectedValue1 bool
expectedValue2 bool
expectedValue3 bool
expectedFatalCalls int32
}
run := func(name string, testCase testCase) {
t.Helper()
t.Run(name, func(t *testing.T) {
t.Helper()
tMock := newTMock()
value1, value2, value3 := Must3(testCase.f())(tMock)
Equal(t, value1, testCase.expectedValue1)
Equal(t, value2, testCase.expectedValue2)
Equal(t, value3, testCase.expectedValue3)
Equal(t, testCase.expectedFatalCalls, tMock.FatalfCalled)
})
}

run("Error", testCase{
f: func() (bool, bool, bool, error) {
return false, false, false, ErrTest
},
expectedValue1: false,
expectedValue2: false,
expectedValue3: false,
expectedFatalCalls: 1,
})
run("Success", testCase{
f: func() (bool, bool, bool, error) {
return true, true, true, nil
},
expectedValue1: true,
expectedValue2: true,
expectedValue3: true,
expectedFatalCalls: 0,
})
}

type testErrorA struct{}

func (e testErrorA) Error() string {
Expand Down Expand Up @@ -199,5 +339,6 @@ func newTMock() *TMock {
return &TMock{
HelperStub: func() {},
ErrorfStub: func(format string, args ...any) {},
FatalfStub: func(format string, args ...any) {},
}
}
1 change: 1 addition & 0 deletions testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ package expect
type T interface {
Helper()
Errorf(format string, args ...any)
Fatalf(format string, args ...any)
}
30 changes: 30 additions & 0 deletions testing_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,59 @@ package expect

import (
"sync/atomic"
"testing"
)

// TMock is a mock implementation of the T
// interface.
type TMock struct {
T *testing.T
HelperStub func()
HelperCalled int32
ErrorfStub func(format string, args ...any)
ErrorfCalled int32
FatalfStub func(format string, args ...any)
FatalfCalled int32
}

// Verify that *TMock implements T.
var _ T = &TMock{}

// Helper is a stub for the T.Helper
// method that records the number of times it has been called.
func (m *TMock) Helper() {
atomic.AddInt32(&m.HelperCalled, 1)
if m.HelperStub == nil {
if m.T != nil {
m.T.Error("HelperStub is nil")
}
panic("Helper unimplemented")
}
m.HelperStub()
}

// Errorf is a stub for the T.Errorf
// method that records the number of times it has been called.
func (m *TMock) Errorf(format string, args ...any) {
atomic.AddInt32(&m.ErrorfCalled, 1)
if m.ErrorfStub == nil {
if m.T != nil {
m.T.Error("ErrorfStub is nil")
}
panic("Errorf unimplemented")
}
m.ErrorfStub(format, args...)
}

// Fatalf is a stub for the T.Fatalf
// method that records the number of times it has been called.
func (m *TMock) Fatalf(format string, args ...any) {
atomic.AddInt32(&m.FatalfCalled, 1)
if m.FatalfStub == nil {
if m.T != nil {
m.T.Error("FatalfStub is nil")
}
panic("Fatalf unimplemented")
}
m.FatalfStub(format, args...)
}

0 comments on commit 6d71217

Please sign in to comment.