diff --git a/atomic.go b/atomic.go index 1ca50dc..1db6849 100644 --- a/atomic.go +++ b/atomic.go @@ -25,6 +25,7 @@ package atomic import ( "math" "sync/atomic" + "time" ) // Int32 is an atomic wrapper around an int32. @@ -304,6 +305,47 @@ func (f *Float64) CAS(old, new float64) bool { return atomic.CompareAndSwapUint64(&f.v, math.Float64bits(old), math.Float64bits(new)) } +// Duration is an atomic wrapper around time.Duration +// https://godoc.org/time#Duration +type Duration struct { + v Int64 +} + +// NewDuration creates a Duration. +func NewDuration(d time.Duration) *Duration { + return &Duration{v: *NewInt64(int64(d))} +} + +// Load atomically loads the wrapped value. +func (d *Duration) Load() time.Duration { + return time.Duration(d.v.Load()) +} + +// Store atomically stores the passed value. +func (d *Duration) Store(n time.Duration) { + d.v.Store(int64(n)) +} + +// Add atomically adds to the wrapped time.Duration and returns the new value. +func (d *Duration) Add(n time.Duration) time.Duration { + return time.Duration(d.v.Add(int64(n))) +} + +// Sub atomically subtracts from the wrapped time.Duration and returns the new value. +func (d *Duration) Sub(n time.Duration) time.Duration { + return time.Duration(d.v.Sub(int64(n))) +} + +// Swap atomically swaps the wrapped time.Duration and returns the old value. +func (d *Duration) Swap(n time.Duration) time.Duration { + return time.Duration(d.v.Swap(int64(n))) +} + +// CAS is an atomic compare-and-swap. +func (d *Duration) CAS(old, new time.Duration) bool { + return d.v.CAS(int64(old), int64(new)) +} + // Value shadows the type of the same name from sync/atomic // https://godoc.org/sync/atomic#Value type Value struct{ atomic.Value } diff --git a/atomic_test.go b/atomic_test.go index 9f293b7..6666f8a 100644 --- a/atomic_test.go +++ b/atomic_test.go @@ -22,6 +22,7 @@ package atomic import ( "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -140,6 +141,23 @@ func TestFloat64(t *testing.T) { require.Equal(t, float64(42.0), atom.Sub(0.5), "Sub didn't work.") } +func TestDuration(t *testing.T) { + atom := NewDuration(5 * time.Minute) + + require.Equal(t, 5*time.Minute, atom.Load(), "Load didn't work.") + require.Equal(t, 6*time.Minute, atom.Add(time.Minute), "Add didn't work.") + require.Equal(t, 4*time.Minute, atom.Sub(2*time.Minute), "Sub didn't work.") + + require.True(t, atom.CAS(4*time.Minute, time.Minute), "CAS didn't report a swap.") + require.Equal(t, time.Minute, atom.Load(), "CAS didn't set the correct value.") + + require.Equal(t, time.Minute, atom.Swap(2*time.Minute), "Swap didn't return the old value.") + require.Equal(t, 2*time.Minute, atom.Load(), "Swap didn't set the correct value.") + + atom.Store(10 * time.Minute) + require.Equal(t, 10*time.Minute, atom.Load(), "Store didn't set the correct value.") +} + func TestValue(t *testing.T) { var v Value assert.Nil(t, v.Load(), "initial Value is not nil") diff --git a/stress_test.go b/stress_test.go index f35de2d..8fd1251 100644 --- a/stress_test.go +++ b/stress_test.go @@ -34,17 +34,18 @@ const ( ) var _stressTests = map[string]func() func(){ - "i32/std": stressStdInt32, - "i32": stressInt32, - "i64/std": stressStdInt32, - "i64": stressInt64, - "u32/std": stressStdUint32, - "u32": stressUint32, - "u64/std": stressStdUint64, - "u64": stressUint64, - "f64": stressFloat64, - "bool": stressBool, - "string": stressString, + "i32/std": stressStdInt32, + "i32": stressInt32, + "i64/std": stressStdInt32, + "i64": stressInt64, + "u32/std": stressStdUint32, + "u32": stressUint32, + "u64/std": stressStdUint64, + "u64": stressUint64, + "f64": stressFloat64, + "bool": stressBool, + "string": stressString, + "duration": stressDuration, } func TestStress(t *testing.T) { @@ -243,3 +244,15 @@ func stressString() func() { atom.Store("") } } + +func stressDuration() func() { + var atom = NewDuration(0) + return func() { + atom.Load() + atom.Add(1) + atom.Sub(2) + atom.CAS(1, 0) + atom.Swap(5) + atom.Store(1) + } +}