Skip to content

Commit

Permalink
fix: make inmem bookmarks random for each run
Browse files Browse the repository at this point in the history
As inmem storage is not persisted, make sure watch bookmark from one run
doesn't work with another run.

This still allows watches to be restarted on connection failures, but if
the watch is restarted on a program restart, bookmark won't match
anymore.

Signed-off-by: Andrey Smirnov <[email protected]>
  • Loading branch information
smira committed Dec 12, 2024
1 parent f4ff7ab commit 6dd2c2b
Show file tree
Hide file tree
Showing 9 changed files with 162 additions and 11 deletions.
12 changes: 12 additions & 0 deletions pkg/state/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,15 @@ func errPhaseConflict(r resource.Reference, expectedPhase resource.Phase) error
},
}
}

// ErrInvalidWatchBookmark should be implemented by "invalid watch bookmark" errors.
type ErrInvalidWatchBookmark interface {
InvalidWatchBookmarkError()
}

// IsInvalidWatchBookmarkError checks if err is invalid watch bookmark.
func IsInvalidWatchBookmarkError(err error) bool {
var i ErrInvalidWatchBookmark

return errors.As(err, &i)
}
32 changes: 26 additions & 6 deletions pkg/state/impl/inmem/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ package inmem

import (
"context"
"crypto/rand"
"encoding/binary"
"fmt"
"io"
"slices"
"sort"
"sync"
Expand Down Expand Up @@ -265,16 +267,34 @@ func (collection *ResourceCollection) Destroy(ctx context.Context, ptr resource.
return nil
}

// bookmarkCookie is a random cookie used to encode bookmarks.
//
// As the state is in-memory, we need to distinguish between bookmarks from different runs of the program.
var bookmarkCookie = sync.OnceValue(func() []byte {
cookie := make([]byte, 8)

_, err := io.ReadFull(rand.Reader, cookie)
if err != nil {
panic(err)
}

return cookie
})

func encodeBookmark(pos int64) state.Bookmark {
return binary.BigEndian.AppendUint64(nil, uint64(pos))
return binary.BigEndian.AppendUint64(slices.Clone(bookmarkCookie()), uint64(pos))
}

func decodeBookmark(bookmark state.Bookmark) (int64, error) {
if len(bookmark) != 8 {
return 0, fmt.Errorf("invalid bookmark length: %d", len(bookmark))
if len(bookmark) != 16 {
return 0, ErrInvalidWatchBookmark
}

if !slices.Equal(bookmark[:8], bookmarkCookie()) {
return 0, ErrInvalidWatchBookmark
}

return int64(binary.BigEndian.Uint64(bookmark)), nil
return int64(binary.BigEndian.Uint64(bookmark[8:])), nil
}

// Watch for specific resource changes.
Expand Down Expand Up @@ -321,7 +341,7 @@ func (collection *ResourceCollection) Watch(ctx context.Context, id resource.ID,
}

if pos < collection.writePos-int64(collection.capacity)+int64(collection.gap) || pos < 0 || pos >= collection.writePos {
return fmt.Errorf("invalid bookmark: %d", pos)
return ErrInvalidWatchBookmark
}

// skip the bookmarked event
Expand Down Expand Up @@ -478,7 +498,7 @@ func (collection *ResourceCollection) WatchAll(ctx context.Context, singleCh cha
}

if pos < collection.writePos-int64(collection.capacity)+int64(collection.gap) || pos < -1 || pos >= collection.writePos {
return fmt.Errorf("invalid bookmark: %d", pos)
return ErrInvalidWatchBookmark
}

// skip the bookmarked event
Expand Down
13 changes: 13 additions & 0 deletions pkg/state/impl/inmem/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package inmem

import (
"errors"
"fmt"

"github.com/cosi-project/runtime/pkg/resource"
Expand Down Expand Up @@ -100,3 +101,15 @@ func ErrPhaseConflict(r resource.Reference, expectedPhase resource.Phase) error
},
}
}

//nolint:errname
type eInvalidWatchBookmark struct {
error
}

func (eInvalidWatchBookmark) InvalidWatchBookmarkError() {}

// ErrInvalidWatchBookmark generates error compatible with state.ErrInvalidWatchBookmark.
var ErrInvalidWatchBookmark = eInvalidWatchBookmark{
errors.New("invalid watch bookmark"),
}
2 changes: 2 additions & 0 deletions pkg/state/impl/inmem/errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,6 @@ func TestErrors(t *testing.T) {

assert.True(t, state.IsConflictError(inmem.ErrAlreadyExists(resource.NewMetadata("ns", "a", "b", resource.VersionUndefined)), state.WithResourceType("a"), state.WithResourceNamespace("ns")))
assert.False(t, state.IsConflictError(inmem.ErrAlreadyExists(resource.NewMetadata("ns", "a", "b", resource.VersionUndefined)), state.WithResourceType("z"), state.WithResourceNamespace("ns")))

assert.True(t, state.IsInvalidWatchBookmarkError(inmem.ErrInvalidWatchBookmark))
}
34 changes: 34 additions & 0 deletions pkg/state/impl/inmem/local_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package inmem_test
import (
"context"
"fmt"
"slices"
"strconv"
"testing"
"time"
Expand Down Expand Up @@ -199,3 +200,36 @@ func TestNoBufferOverrunDynamic(t *testing.T) {
}
}
}

func TestWatchInvalidBookmark(t *testing.T) {
t.Parallel()

const namespace = "default"

st := state.WrapCore(inmem.NewState(namespace))

ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

// start watching for changes
watchKindCh := make(chan state.Event)

err := st.WatchKind(ctx, resource.NewMetadata(namespace, conformance.PathResourceType, "", resource.VersionUndefined), watchKindCh)
require.NoError(t, err)

// insert resource
err = st.Create(ctx, conformance.NewPathResource(namespace, "0"))
require.NoError(t, err)

ev := <-watchKindCh

require.Equal(t, state.Created, ev.Type)
require.NotEmpty(t, ev.Bookmark)

invalidBookmark := slices.Clone(ev.Bookmark)
invalidBookmark[0] ^= 0xff

err = st.WatchKind(ctx, resource.NewMetadata(namespace, conformance.PathResourceType, "", resource.VersionUndefined), watchKindCh, state.WithKindStartFromBookmark(invalidBookmark))
require.Error(t, err)
require.True(t, state.IsInvalidWatchBookmarkError(err))
}
28 changes: 24 additions & 4 deletions pkg/state/protobuf/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,12 @@ func (adapter *Adapter) Watch(ctx context.Context, resourcePointer resource.Poin
// receive first (empty) watch event
_, err = cli.Recv()
if err != nil {
return err
switch status.Code(err) { //nolint:exhaustive
case codes.FailedPrecondition:
return eInvalidWatchBookmark{err}
default:
return err
}
}

go adapter.watchAdapter(ctx, cli, ch, nil, opts.UnmarshalOptions.SkipProtobufUnmarshal, req)
Expand Down Expand Up @@ -388,7 +393,12 @@ func (adapter *Adapter) WatchKind(ctx context.Context, resourceKind resource.Kin
// receive first (empty) watch event
_, err = cli.Recv()
if err != nil {
return err
switch status.Code(err) { //nolint:exhaustive
case codes.FailedPrecondition:
return eInvalidWatchBookmark{err}
default:
return err
}
}

go adapter.watchAdapter(ctx, cli, ch, nil, opts.UnmarshalOptions.SkipProtobufUnmarshal, req)
Expand Down Expand Up @@ -437,7 +447,12 @@ func (adapter *Adapter) WatchKindAggregated(ctx context.Context, resourceKind re
// receive first (empty) watch event
_, err = cli.Recv()
if err != nil {
return err
switch status.Code(err) { //nolint:exhaustive
case codes.FailedPrecondition:
return eInvalidWatchBookmark{err}
default:
return err
}
}

go adapter.watchAdapter(ctx, cli, nil, ch, opts.UnmarshalOptions.SkipProtobufUnmarshal, req)
Expand Down Expand Up @@ -526,7 +541,12 @@ func (adapter *Adapter) watchAdapter(

_, err = cli.Recv()
if err != nil {
continue
switch status.Code(err) { //nolint:exhaustive
case codes.FailedPrecondition: // abort retries on invalid watch bookmark
return nil, eInvalidWatchBookmark{err}
default:
continue
}
}

msg, err = cli.Recv()
Expand Down
7 changes: 7 additions & 0 deletions pkg/state/protobuf/client/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,10 @@ type ePhaseConflict struct {
}

func (ePhaseConflict) PhaseConflictError() {}

//nolint:errname
type eInvalidWatchBookmark struct {
error
}

func (eInvalidWatchBookmark) InvalidWatchBookmarkError() {}
38 changes: 38 additions & 0 deletions pkg/state/protobuf/protobuf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,44 @@ func TestProtobufWatchRestart(t *testing.T) {
}
}

func TestProtobufWatchInvalidBookmark(t *testing.T) {
grpcConn, _, _, _ := ProtobufSetup(t) //nolint:dogsled

stateClient := v1alpha1.NewStateClient(grpcConn)

st := state.WrapCore(client.NewAdapter(stateClient,
client.WithRetryLogger(zaptest.NewLogger(t)),
))

ch := make(chan []state.Event)

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
t.Cleanup(cancel)

require.NoError(t, st.WatchKindAggregated(ctx, conformance.NewPathResource("test", "/foo").Metadata(), ch, state.WithBootstrapContents(true)))

var bookmark []byte

select {
case <-ctx.Done():
t.Fatal("timeout")
case ev := <-ch:
require.Len(t, ev, 1)

assert.Equal(t, state.Bootstrapped, ev[0].Type)
assert.NotEmpty(t, ev[0].Bookmark)

bookmark = ev[0].Bookmark
}

// send invalid bookmark
bookmark[0] ^= 0xff

err := st.WatchKindAggregated(ctx, conformance.NewPathResource("test", "/foo").Metadata(), ch, state.WithKindStartFromBookmark(bookmark))
require.Error(t, err)
assert.True(t, state.IsInvalidWatchBookmarkError(err))
}

func noError[T any](t *testing.T, fn func(T) error, v T, ignored ...error) {
t.Helper()

Expand Down
7 changes: 6 additions & 1 deletion pkg/state/protobuf/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,12 @@ func (server *State) Watch(req *v1alpha1.WatchRequest, srv v1alpha1.State_WatchS
}

if err != nil {
return err
switch {
case state.IsInvalidWatchBookmarkError(err):
return status.Error(codes.FailedPrecondition, err.Error())
default:
return err
}
}

// send empty event to signal that watch is ready
Expand Down

0 comments on commit 6dd2c2b

Please sign in to comment.