From 6dd2c2bc0d1e72feeb10beb69cd84a4d940109f3 Mon Sep 17 00:00:00 2001 From: Andrey Smirnov Date: Tue, 5 Nov 2024 20:21:24 +0400 Subject: [PATCH] fix: make inmem bookmarks random for each run As inmem storage is not persisted, make sure watch bookmark from one run doesn't work with another run. This still allows watches to be restarted on connection failures, but if the watch is restarted on a program restart, bookmark won't match anymore. Signed-off-by: Andrey Smirnov --- pkg/state/errors.go | 12 +++++++++ pkg/state/impl/inmem/collection.go | 32 +++++++++++++++++++----- pkg/state/impl/inmem/errors.go | 13 ++++++++++ pkg/state/impl/inmem/errors_test.go | 2 ++ pkg/state/impl/inmem/local_test.go | 34 ++++++++++++++++++++++++++ pkg/state/protobuf/client/client.go | 28 ++++++++++++++++++--- pkg/state/protobuf/client/errors.go | 7 ++++++ pkg/state/protobuf/protobuf_test.go | 38 +++++++++++++++++++++++++++++ pkg/state/protobuf/server/server.go | 7 +++++- 9 files changed, 162 insertions(+), 11 deletions(-) diff --git a/pkg/state/errors.go b/pkg/state/errors.go index dc87bb4..3d087e7 100644 --- a/pkg/state/errors.go +++ b/pkg/state/errors.go @@ -137,3 +137,15 @@ func errPhaseConflict(r resource.Reference, expectedPhase resource.Phase) error }, } } + +// ErrInvalidWatchBookmark should be implemented by "invalid watch bookmark" errors. +type ErrInvalidWatchBookmark interface { + InvalidWatchBookmarkError() +} + +// IsInvalidWatchBookmarkError checks if err is invalid watch bookmark. +func IsInvalidWatchBookmarkError(err error) bool { + var i ErrInvalidWatchBookmark + + return errors.As(err, &i) +} diff --git a/pkg/state/impl/inmem/collection.go b/pkg/state/impl/inmem/collection.go index 986045a..080c80f 100644 --- a/pkg/state/impl/inmem/collection.go +++ b/pkg/state/impl/inmem/collection.go @@ -6,8 +6,10 @@ package inmem import ( "context" + "crypto/rand" "encoding/binary" "fmt" + "io" "slices" "sort" "sync" @@ -265,16 +267,34 @@ func (collection *ResourceCollection) Destroy(ctx context.Context, ptr resource. return nil } +// bookmarkCookie is a random cookie used to encode bookmarks. +// +// As the state is in-memory, we need to distinguish between bookmarks from different runs of the program. +var bookmarkCookie = sync.OnceValue(func() []byte { + cookie := make([]byte, 8) + + _, err := io.ReadFull(rand.Reader, cookie) + if err != nil { + panic(err) + } + + return cookie +}) + func encodeBookmark(pos int64) state.Bookmark { - return binary.BigEndian.AppendUint64(nil, uint64(pos)) + return binary.BigEndian.AppendUint64(slices.Clone(bookmarkCookie()), uint64(pos)) } func decodeBookmark(bookmark state.Bookmark) (int64, error) { - if len(bookmark) != 8 { - return 0, fmt.Errorf("invalid bookmark length: %d", len(bookmark)) + if len(bookmark) != 16 { + return 0, ErrInvalidWatchBookmark + } + + if !slices.Equal(bookmark[:8], bookmarkCookie()) { + return 0, ErrInvalidWatchBookmark } - return int64(binary.BigEndian.Uint64(bookmark)), nil + return int64(binary.BigEndian.Uint64(bookmark[8:])), nil } // Watch for specific resource changes. @@ -321,7 +341,7 @@ func (collection *ResourceCollection) Watch(ctx context.Context, id resource.ID, } if pos < collection.writePos-int64(collection.capacity)+int64(collection.gap) || pos < 0 || pos >= collection.writePos { - return fmt.Errorf("invalid bookmark: %d", pos) + return ErrInvalidWatchBookmark } // skip the bookmarked event @@ -478,7 +498,7 @@ func (collection *ResourceCollection) WatchAll(ctx context.Context, singleCh cha } if pos < collection.writePos-int64(collection.capacity)+int64(collection.gap) || pos < -1 || pos >= collection.writePos { - return fmt.Errorf("invalid bookmark: %d", pos) + return ErrInvalidWatchBookmark } // skip the bookmarked event diff --git a/pkg/state/impl/inmem/errors.go b/pkg/state/impl/inmem/errors.go index ae61b97..a92e685 100644 --- a/pkg/state/impl/inmem/errors.go +++ b/pkg/state/impl/inmem/errors.go @@ -5,6 +5,7 @@ package inmem import ( + "errors" "fmt" "github.com/cosi-project/runtime/pkg/resource" @@ -100,3 +101,15 @@ func ErrPhaseConflict(r resource.Reference, expectedPhase resource.Phase) error }, } } + +//nolint:errname +type eInvalidWatchBookmark struct { + error +} + +func (eInvalidWatchBookmark) InvalidWatchBookmarkError() {} + +// ErrInvalidWatchBookmark generates error compatible with state.ErrInvalidWatchBookmark. +var ErrInvalidWatchBookmark = eInvalidWatchBookmark{ + errors.New("invalid watch bookmark"), +} diff --git a/pkg/state/impl/inmem/errors_test.go b/pkg/state/impl/inmem/errors_test.go index fd848f3..f5535f9 100644 --- a/pkg/state/impl/inmem/errors_test.go +++ b/pkg/state/impl/inmem/errors_test.go @@ -32,4 +32,6 @@ func TestErrors(t *testing.T) { assert.True(t, state.IsConflictError(inmem.ErrAlreadyExists(resource.NewMetadata("ns", "a", "b", resource.VersionUndefined)), state.WithResourceType("a"), state.WithResourceNamespace("ns"))) assert.False(t, state.IsConflictError(inmem.ErrAlreadyExists(resource.NewMetadata("ns", "a", "b", resource.VersionUndefined)), state.WithResourceType("z"), state.WithResourceNamespace("ns"))) + + assert.True(t, state.IsInvalidWatchBookmarkError(inmem.ErrInvalidWatchBookmark)) } diff --git a/pkg/state/impl/inmem/local_test.go b/pkg/state/impl/inmem/local_test.go index 68dd4f7..cff4c4a 100644 --- a/pkg/state/impl/inmem/local_test.go +++ b/pkg/state/impl/inmem/local_test.go @@ -7,6 +7,7 @@ package inmem_test import ( "context" "fmt" + "slices" "strconv" "testing" "time" @@ -199,3 +200,36 @@ func TestNoBufferOverrunDynamic(t *testing.T) { } } } + +func TestWatchInvalidBookmark(t *testing.T) { + t.Parallel() + + const namespace = "default" + + st := state.WrapCore(inmem.NewState(namespace)) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + // start watching for changes + watchKindCh := make(chan state.Event) + + err := st.WatchKind(ctx, resource.NewMetadata(namespace, conformance.PathResourceType, "", resource.VersionUndefined), watchKindCh) + require.NoError(t, err) + + // insert resource + err = st.Create(ctx, conformance.NewPathResource(namespace, "0")) + require.NoError(t, err) + + ev := <-watchKindCh + + require.Equal(t, state.Created, ev.Type) + require.NotEmpty(t, ev.Bookmark) + + invalidBookmark := slices.Clone(ev.Bookmark) + invalidBookmark[0] ^= 0xff + + err = st.WatchKind(ctx, resource.NewMetadata(namespace, conformance.PathResourceType, "", resource.VersionUndefined), watchKindCh, state.WithKindStartFromBookmark(invalidBookmark)) + require.Error(t, err) + require.True(t, state.IsInvalidWatchBookmarkError(err)) +} diff --git a/pkg/state/protobuf/client/client.go b/pkg/state/protobuf/client/client.go index bae1dba..d8297cd 100644 --- a/pkg/state/protobuf/client/client.go +++ b/pkg/state/protobuf/client/client.go @@ -340,7 +340,12 @@ func (adapter *Adapter) Watch(ctx context.Context, resourcePointer resource.Poin // receive first (empty) watch event _, err = cli.Recv() if err != nil { - return err + switch status.Code(err) { //nolint:exhaustive + case codes.FailedPrecondition: + return eInvalidWatchBookmark{err} + default: + return err + } } go adapter.watchAdapter(ctx, cli, ch, nil, opts.UnmarshalOptions.SkipProtobufUnmarshal, req) @@ -388,7 +393,12 @@ func (adapter *Adapter) WatchKind(ctx context.Context, resourceKind resource.Kin // receive first (empty) watch event _, err = cli.Recv() if err != nil { - return err + switch status.Code(err) { //nolint:exhaustive + case codes.FailedPrecondition: + return eInvalidWatchBookmark{err} + default: + return err + } } go adapter.watchAdapter(ctx, cli, ch, nil, opts.UnmarshalOptions.SkipProtobufUnmarshal, req) @@ -437,7 +447,12 @@ func (adapter *Adapter) WatchKindAggregated(ctx context.Context, resourceKind re // receive first (empty) watch event _, err = cli.Recv() if err != nil { - return err + switch status.Code(err) { //nolint:exhaustive + case codes.FailedPrecondition: + return eInvalidWatchBookmark{err} + default: + return err + } } go adapter.watchAdapter(ctx, cli, nil, ch, opts.UnmarshalOptions.SkipProtobufUnmarshal, req) @@ -526,7 +541,12 @@ func (adapter *Adapter) watchAdapter( _, err = cli.Recv() if err != nil { - continue + switch status.Code(err) { //nolint:exhaustive + case codes.FailedPrecondition: // abort retries on invalid watch bookmark + return nil, eInvalidWatchBookmark{err} + default: + continue + } } msg, err = cli.Recv() diff --git a/pkg/state/protobuf/client/errors.go b/pkg/state/protobuf/client/errors.go index 0863dd1..79098e9 100644 --- a/pkg/state/protobuf/client/errors.go +++ b/pkg/state/protobuf/client/errors.go @@ -38,3 +38,10 @@ type ePhaseConflict struct { } func (ePhaseConflict) PhaseConflictError() {} + +//nolint:errname +type eInvalidWatchBookmark struct { + error +} + +func (eInvalidWatchBookmark) InvalidWatchBookmarkError() {} diff --git a/pkg/state/protobuf/protobuf_test.go b/pkg/state/protobuf/protobuf_test.go index e53acd3..d21d6a8 100644 --- a/pkg/state/protobuf/protobuf_test.go +++ b/pkg/state/protobuf/protobuf_test.go @@ -215,6 +215,44 @@ func TestProtobufWatchRestart(t *testing.T) { } } +func TestProtobufWatchInvalidBookmark(t *testing.T) { + grpcConn, _, _, _ := ProtobufSetup(t) //nolint:dogsled + + stateClient := v1alpha1.NewStateClient(grpcConn) + + st := state.WrapCore(client.NewAdapter(stateClient, + client.WithRetryLogger(zaptest.NewLogger(t)), + )) + + ch := make(chan []state.Event) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + t.Cleanup(cancel) + + require.NoError(t, st.WatchKindAggregated(ctx, conformance.NewPathResource("test", "/foo").Metadata(), ch, state.WithBootstrapContents(true))) + + var bookmark []byte + + select { + case <-ctx.Done(): + t.Fatal("timeout") + case ev := <-ch: + require.Len(t, ev, 1) + + assert.Equal(t, state.Bootstrapped, ev[0].Type) + assert.NotEmpty(t, ev[0].Bookmark) + + bookmark = ev[0].Bookmark + } + + // send invalid bookmark + bookmark[0] ^= 0xff + + err := st.WatchKindAggregated(ctx, conformance.NewPathResource("test", "/foo").Metadata(), ch, state.WithKindStartFromBookmark(bookmark)) + require.Error(t, err) + assert.True(t, state.IsInvalidWatchBookmarkError(err)) +} + func noError[T any](t *testing.T, fn func(T) error, v T, ignored ...error) { t.Helper() diff --git a/pkg/state/protobuf/server/server.go b/pkg/state/protobuf/server/server.go index 35cf45b..4a3a8c6 100644 --- a/pkg/state/protobuf/server/server.go +++ b/pkg/state/protobuf/server/server.go @@ -328,7 +328,12 @@ func (server *State) Watch(req *v1alpha1.WatchRequest, srv v1alpha1.State_WatchS } if err != nil { - return err + switch { + case state.IsInvalidWatchBookmarkError(err): + return status.Error(codes.FailedPrecondition, err.Error()) + default: + return err + } } // send empty event to signal that watch is ready