diff --git a/utils/timed_version.go b/utils/timed_version.go index 8053191f..d609eabd 100644 --- a/utils/timed_version.go +++ b/utils/timed_version.go @@ -15,6 +15,9 @@ package utils import ( + "database/sql/driver" + "encoding/binary" + "errors" "fmt" "sync" "time" @@ -110,10 +113,22 @@ func TimedVersionFromTime(t time.Time) TimedVersion { } func (t *TimedVersion) Update(other *TimedVersion) bool { + return t.Upgrade(other) +} + +func (t *TimedVersion) Upgrade(other *TimedVersion) bool { + return t.update(other, func(ov, prev uint64) bool { return ov > prev }) +} + +func (t *TimedVersion) Downgrade(other *TimedVersion) bool { + return t.update(other, func(ov, prev uint64) bool { return ov < prev }) +} + +func (t *TimedVersion) update(other *TimedVersion, cmp func(ov, prev uint64) bool) bool { ov := other.v.Load() for { prev := t.v.Load() - if ov <= prev { + if !cmp(ov, prev) { return false } if t.v.CompareAndSwap(prev, ov) { @@ -167,3 +182,36 @@ func (t *TimedVersion) String() string { ts, ticks := timedVersionComponents(t.v.Load()) return fmt.Sprintf("%d.%d", ts, ticks) } + +func (t TimedVersion) Value() (driver.Value, error) { + if t.IsZero() { + return nil, nil + } + + ts, ticks := timedVersionComponents(t.v.Load()) + b := make([]byte, 0, 12) + b = binary.BigEndian.AppendUint64(b, uint64(ts)) + b = binary.BigEndian.AppendUint32(b, uint32(ticks)) + return b, nil +} + +func (t *TimedVersion) Scan(src interface{}) (err error) { + switch b := src.(type) { + case []byte: + switch len(b) { + case 0: + t.v.Store(0) + case 12: + ts := int64(binary.BigEndian.Uint64(b)) + ticks := int32(binary.BigEndian.Uint32(b[8:])) + *t = timedVersionFromComponents(ts, ticks) + default: + return errors.New("(*TimedVersion).Scan: unsupported format") + } + case nil: + t.v.Store(0) + default: + return errors.New("(*TimedVersion).Scan: unsupported data type") + } + return nil +} diff --git a/utils/timed_version_test.go b/utils/timed_version_test.go index 1905132f..9a8b16ac 100644 --- a/utils/timed_version_test.go +++ b/utils/timed_version_test.go @@ -64,4 +64,8 @@ func TestTimedVersion(t *testing.T) { require.Equal(t, ts1, ts2) require.Equal(t, tv1.v.Load(), tv2.v.Load()) }) + + t.Run("timed version from nil is zero", func(t *testing.T) { + require.True(t, NewTimedVersionFromProto(nil).IsZero()) + }) }