diff --git a/pkg/state/errors.go b/pkg/state/errors.go index dc87bb4..3d087e7 100644 --- a/pkg/state/errors.go +++ b/pkg/state/errors.go @@ -137,3 +137,15 @@ func errPhaseConflict(r resource.Reference, expectedPhase resource.Phase) error }, } } + +// ErrInvalidWatchBookmark should be implemented by "invalid watch bookmark" errors. +type ErrInvalidWatchBookmark interface { + InvalidWatchBookmarkError() +} + +// IsInvalidWatchBookmarkError checks if err is invalid watch bookmark. +func IsInvalidWatchBookmarkError(err error) bool { + var i ErrInvalidWatchBookmark + + return errors.As(err, &i) +} diff --git a/pkg/state/impl/inmem/collection.go b/pkg/state/impl/inmem/collection.go index 986045a..080c80f 100644 --- a/pkg/state/impl/inmem/collection.go +++ b/pkg/state/impl/inmem/collection.go @@ -6,8 +6,10 @@ package inmem import ( "context" + "crypto/rand" "encoding/binary" "fmt" + "io" "slices" "sort" "sync" @@ -265,16 +267,34 @@ func (collection *ResourceCollection) Destroy(ctx context.Context, ptr resource. return nil } +// bookmarkCookie is a random cookie used to encode bookmarks. +// +// As the state is in-memory, we need to distinguish between bookmarks from different runs of the program. +var bookmarkCookie = sync.OnceValue(func() []byte { + cookie := make([]byte, 8) + + _, err := io.ReadFull(rand.Reader, cookie) + if err != nil { + panic(err) + } + + return cookie +}) + func encodeBookmark(pos int64) state.Bookmark { - return binary.BigEndian.AppendUint64(nil, uint64(pos)) + return binary.BigEndian.AppendUint64(slices.Clone(bookmarkCookie()), uint64(pos)) } func decodeBookmark(bookmark state.Bookmark) (int64, error) { - if len(bookmark) != 8 { - return 0, fmt.Errorf("invalid bookmark length: %d", len(bookmark)) + if len(bookmark) != 16 { + return 0, ErrInvalidWatchBookmark + } + + if !slices.Equal(bookmark[:8], bookmarkCookie()) { + return 0, ErrInvalidWatchBookmark } - return int64(binary.BigEndian.Uint64(bookmark)), nil + return int64(binary.BigEndian.Uint64(bookmark[8:])), nil } // Watch for specific resource changes. @@ -321,7 +341,7 @@ func (collection *ResourceCollection) Watch(ctx context.Context, id resource.ID, } if pos < collection.writePos-int64(collection.capacity)+int64(collection.gap) || pos < 0 || pos >= collection.writePos { - return fmt.Errorf("invalid bookmark: %d", pos) + return ErrInvalidWatchBookmark } // skip the bookmarked event @@ -478,7 +498,7 @@ func (collection *ResourceCollection) WatchAll(ctx context.Context, singleCh cha } if pos < collection.writePos-int64(collection.capacity)+int64(collection.gap) || pos < -1 || pos >= collection.writePos { - return fmt.Errorf("invalid bookmark: %d", pos) + return ErrInvalidWatchBookmark } // skip the bookmarked event diff --git a/pkg/state/impl/inmem/errors.go b/pkg/state/impl/inmem/errors.go index ae61b97..a92e685 100644 --- a/pkg/state/impl/inmem/errors.go +++ b/pkg/state/impl/inmem/errors.go @@ -5,6 +5,7 @@ package inmem import ( + "errors" "fmt" "github.com/cosi-project/runtime/pkg/resource" @@ -100,3 +101,15 @@ func ErrPhaseConflict(r resource.Reference, expectedPhase resource.Phase) error }, } } + +//nolint:errname +type eInvalidWatchBookmark struct { + error +} + +func (eInvalidWatchBookmark) InvalidWatchBookmarkError() {} + +// ErrInvalidWatchBookmark generates error compatible with state.ErrInvalidWatchBookmark. +var ErrInvalidWatchBookmark = eInvalidWatchBookmark{ + errors.New("invalid watch bookmark"), +} diff --git a/pkg/state/impl/inmem/errors_test.go b/pkg/state/impl/inmem/errors_test.go index fd848f3..f5535f9 100644 --- a/pkg/state/impl/inmem/errors_test.go +++ b/pkg/state/impl/inmem/errors_test.go @@ -32,4 +32,6 @@ func TestErrors(t *testing.T) { assert.True(t, state.IsConflictError(inmem.ErrAlreadyExists(resource.NewMetadata("ns", "a", "b", resource.VersionUndefined)), state.WithResourceType("a"), state.WithResourceNamespace("ns"))) assert.False(t, state.IsConflictError(inmem.ErrAlreadyExists(resource.NewMetadata("ns", "a", "b", resource.VersionUndefined)), state.WithResourceType("z"), state.WithResourceNamespace("ns"))) + + assert.True(t, state.IsInvalidWatchBookmarkError(inmem.ErrInvalidWatchBookmark)) } diff --git a/pkg/state/impl/inmem/local_test.go b/pkg/state/impl/inmem/local_test.go index 68dd4f7..cff4c4a 100644 --- a/pkg/state/impl/inmem/local_test.go +++ b/pkg/state/impl/inmem/local_test.go @@ -7,6 +7,7 @@ package inmem_test import ( "context" "fmt" + "slices" "strconv" "testing" "time" @@ -199,3 +200,36 @@ func TestNoBufferOverrunDynamic(t *testing.T) { } } } + +func TestWatchInvalidBookmark(t *testing.T) { + t.Parallel() + + const namespace = "default" + + st := state.WrapCore(inmem.NewState(namespace)) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + // start watching for changes + watchKindCh := make(chan state.Event) + + err := st.WatchKind(ctx, resource.NewMetadata(namespace, conformance.PathResourceType, "", resource.VersionUndefined), watchKindCh) + require.NoError(t, err) + + // insert resource + err = st.Create(ctx, conformance.NewPathResource(namespace, "0")) + require.NoError(t, err) + + ev := <-watchKindCh + + require.Equal(t, state.Created, ev.Type) + require.NotEmpty(t, ev.Bookmark) + + invalidBookmark := slices.Clone(ev.Bookmark) + invalidBookmark[0] ^= 0xff + + err = st.WatchKind(ctx, resource.NewMetadata(namespace, conformance.PathResourceType, "", resource.VersionUndefined), watchKindCh, state.WithKindStartFromBookmark(invalidBookmark)) + require.Error(t, err) + require.True(t, state.IsInvalidWatchBookmarkError(err)) +} diff --git a/pkg/state/protobuf/client/client.go b/pkg/state/protobuf/client/client.go index bae1dba..d8297cd 100644 --- a/pkg/state/protobuf/client/client.go +++ b/pkg/state/protobuf/client/client.go @@ -340,7 +340,12 @@ func (adapter *Adapter) Watch(ctx context.Context, resourcePointer resource.Poin // receive first (empty) watch event _, err = cli.Recv() if err != nil { - return err + switch status.Code(err) { //nolint:exhaustive + case codes.FailedPrecondition: + return eInvalidWatchBookmark{err} + default: + return err + } } go adapter.watchAdapter(ctx, cli, ch, nil, opts.UnmarshalOptions.SkipProtobufUnmarshal, req) @@ -388,7 +393,12 @@ func (adapter *Adapter) WatchKind(ctx context.Context, resourceKind resource.Kin // receive first (empty) watch event _, err = cli.Recv() if err != nil { - return err + switch status.Code(err) { //nolint:exhaustive + case codes.FailedPrecondition: + return eInvalidWatchBookmark{err} + default: + return err + } } go adapter.watchAdapter(ctx, cli, ch, nil, opts.UnmarshalOptions.SkipProtobufUnmarshal, req) @@ -437,7 +447,12 @@ func (adapter *Adapter) WatchKindAggregated(ctx context.Context, resourceKind re // receive first (empty) watch event _, err = cli.Recv() if err != nil { - return err + switch status.Code(err) { //nolint:exhaustive + case codes.FailedPrecondition: + return eInvalidWatchBookmark{err} + default: + return err + } } go adapter.watchAdapter(ctx, cli, nil, ch, opts.UnmarshalOptions.SkipProtobufUnmarshal, req) @@ -526,7 +541,12 @@ func (adapter *Adapter) watchAdapter( _, err = cli.Recv() if err != nil { - continue + switch status.Code(err) { //nolint:exhaustive + case codes.FailedPrecondition: // abort retries on invalid watch bookmark + return nil, eInvalidWatchBookmark{err} + default: + continue + } } msg, err = cli.Recv() diff --git a/pkg/state/protobuf/client/errors.go b/pkg/state/protobuf/client/errors.go index 0863dd1..79098e9 100644 --- a/pkg/state/protobuf/client/errors.go +++ b/pkg/state/protobuf/client/errors.go @@ -38,3 +38,10 @@ type ePhaseConflict struct { } func (ePhaseConflict) PhaseConflictError() {} + +//nolint:errname +type eInvalidWatchBookmark struct { + error +} + +func (eInvalidWatchBookmark) InvalidWatchBookmarkError() {} diff --git a/pkg/state/protobuf/protobuf_test.go b/pkg/state/protobuf/protobuf_test.go index e53acd3..d21d6a8 100644 --- a/pkg/state/protobuf/protobuf_test.go +++ b/pkg/state/protobuf/protobuf_test.go @@ -215,6 +215,44 @@ func TestProtobufWatchRestart(t *testing.T) { } } +func TestProtobufWatchInvalidBookmark(t *testing.T) { + grpcConn, _, _, _ := ProtobufSetup(t) //nolint:dogsled + + stateClient := v1alpha1.NewStateClient(grpcConn) + + st := state.WrapCore(client.NewAdapter(stateClient, + client.WithRetryLogger(zaptest.NewLogger(t)), + )) + + ch := make(chan []state.Event) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + t.Cleanup(cancel) + + require.NoError(t, st.WatchKindAggregated(ctx, conformance.NewPathResource("test", "/foo").Metadata(), ch, state.WithBootstrapContents(true))) + + var bookmark []byte + + select { + case <-ctx.Done(): + t.Fatal("timeout") + case ev := <-ch: + require.Len(t, ev, 1) + + assert.Equal(t, state.Bootstrapped, ev[0].Type) + assert.NotEmpty(t, ev[0].Bookmark) + + bookmark = ev[0].Bookmark + } + + // send invalid bookmark + bookmark[0] ^= 0xff + + err := st.WatchKindAggregated(ctx, conformance.NewPathResource("test", "/foo").Metadata(), ch, state.WithKindStartFromBookmark(bookmark)) + require.Error(t, err) + assert.True(t, state.IsInvalidWatchBookmarkError(err)) +} + func noError[T any](t *testing.T, fn func(T) error, v T, ignored ...error) { t.Helper() diff --git a/pkg/state/protobuf/server/server.go b/pkg/state/protobuf/server/server.go index 35cf45b..4a3a8c6 100644 --- a/pkg/state/protobuf/server/server.go +++ b/pkg/state/protobuf/server/server.go @@ -328,7 +328,12 @@ func (server *State) Watch(req *v1alpha1.WatchRequest, srv v1alpha1.State_WatchS } if err != nil { - return err + switch { + case state.IsInvalidWatchBookmarkError(err): + return status.Error(codes.FailedPrecondition, err.Error()) + default: + return err + } } // send empty event to signal that watch is ready