From 417251c0ba82917a5d31e9bae56eb054340a0c64 Mon Sep 17 00:00:00 2001 From: Utku Ozdemir Date: Tue, 28 May 2024 18:45:20 +0200 Subject: [PATCH] fix: fix the panic in loading state from storage When importing clusters from a file, grow the used buffer to the correct size of the next cluster, so that it does not panic when unmarshaling it. Handle panics on the storage's save & load, so that it will never crash the discovery service when it fails. Additionally: - fix the slice growing logic when exporting clusters so that we avoid over-growing the slices for the affiliates and the endpoints. - modify the storage tests to use the real state instead of a mock, replace the assertions to ignore the order accordingly. Signed-off-by: Utku Ozdemir --- internal/state/snapshot.go | 12 +- internal/state/storage/storage.go | 16 ++- internal/state/storage/storage_test.go | 154 +++++++++++++++++-------- 3 files changed, 131 insertions(+), 51 deletions(-) diff --git a/internal/state/snapshot.go b/internal/state/snapshot.go index 99ed5c3..f1aa4ba 100644 --- a/internal/state/snapshot.go +++ b/internal/state/snapshot.go @@ -21,7 +21,7 @@ import ( func (state *State) ExportClusterSnapshots(f func(snapshot *storagepb.ClusterSnapshot) error) error { var err error - // reuse the same snapshotin each iteration + // reuse the same snapshot in each iteration clusterSnapshot := &storagepb.ClusterSnapshot{} state.clusters.Enumerate(func(_ string, cluster *Cluster) bool { @@ -67,7 +67,10 @@ func snapshotCluster(cluster *Cluster, snapshot *storagepb.ClusterSnapshot) { snapshot.Id = cluster.id // reuse the same slice, resize it as needed - snapshot.Affiliates = slices.Grow(snapshot.Affiliates, len(cluster.affiliates)) + if len(cluster.affiliates) > cap(snapshot.Affiliates) { + snapshot.Affiliates = slices.Grow(snapshot.Affiliates, len(cluster.affiliates)-len(snapshot.Affiliates)) + } + snapshot.Affiliates = snapshot.Affiliates[:len(cluster.affiliates)] i := 0 @@ -88,7 +91,10 @@ func snapshotCluster(cluster *Cluster, snapshot *storagepb.ClusterSnapshot) { snapshot.Affiliates[i].Data = affiliate.data // reuse the same slice, resize it as needed - snapshot.Affiliates[i].Endpoints = slices.Grow(snapshot.Affiliates[i].Endpoints, len(affiliate.endpoints)) + if len(affiliate.endpoints) > cap(snapshot.Affiliates[i].Endpoints) { + snapshot.Affiliates[i].Endpoints = slices.Grow(snapshot.Affiliates[i].Endpoints, len(affiliate.endpoints)-len(snapshot.Affiliates[i].Endpoints)) + } + snapshot.Affiliates[i].Endpoints = snapshot.Affiliates[i].Endpoints[:len(affiliate.endpoints)] for j, endpoint := range affiliate.endpoints { diff --git a/internal/state/storage/storage.go b/internal/state/storage/storage.go index bdeed0a..0803ea1 100644 --- a/internal/state/storage/storage.go +++ b/internal/state/storage/storage.go @@ -149,6 +149,13 @@ func (storage *Storage) Save() (err error) { } }() + // never panic, convert it into an error instead + defer func() { + if recovered := recover(); recovered != nil { + err = fmt.Errorf("save panicked: %v", recovered) + } + }() + if err = os.MkdirAll(filepath.Dir(storage.path), 0o755); err != nil { return fmt.Errorf("failed to create directory path: %w", err) } @@ -195,6 +202,13 @@ func (storage *Storage) Load() (err error) { } }() + // never panic, convert it into an error instead + defer func() { + if recovered := recover(); recovered != nil { + err = fmt.Errorf("load panicked: %v", recovered) + } + }() + start := time.Now() // open file for reading @@ -254,7 +268,7 @@ func (storage *Storage) Import(reader io.Reader) (SnapshotStats, error) { } if clusterSize > cap(buffer) { - buffer = slices.Grow(buffer, clusterSize-cap(buffer)) + buffer = slices.Grow(buffer, clusterSize) } buffer = buffer[:clusterSize] diff --git a/internal/state/storage/storage_test.go b/internal/state/storage/storage_test.go index 78c243a..b3592bc 100644 --- a/internal/state/storage/storage_test.go +++ b/internal/state/storage/storage_test.go @@ -12,6 +12,7 @@ import ( "math" "os" "path/filepath" + "slices" "strings" "sync" "testing" @@ -24,6 +25,7 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" storagepb "github.com/siderolabs/discovery-service/api/storage" + "github.com/siderolabs/discovery-service/internal/state" "github.com/siderolabs/discovery-service/internal/state/storage" "github.com/siderolabs/discovery-service/pkg/limits" ) @@ -51,11 +53,12 @@ func TestExport(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - snapshot := buildTestSnapshot(10) tempDir := t.TempDir() path := filepath.Join(tempDir, "test.binpb") - state := &mockSnapshotter{data: snapshot} logger := zaptest.NewLogger(t) + state := state.NewState(logger) + + importTestState(t, state, tc.snapshot) stateStorage := storage.New(path, state, logger) @@ -64,12 +67,13 @@ func TestExport(t *testing.T) { exportStats, err := stateStorage.Export(&buffer) require.NoError(t, err) - assert.Equal(t, statsForSnapshot(snapshot), exportStats) + assert.Equal(t, statsForSnapshot(tc.snapshot), exportStats) - expected, err := snapshot.MarshalVT() - require.NoError(t, err) + exported := &storagepb.StateSnapshot{} + + require.NoError(t, exported.UnmarshalVT(buffer.Bytes())) - require.Equal(t, expected, buffer.Bytes()) + requireEqualIgnoreOrder(t, tc.snapshot, exported) }) } } @@ -91,16 +95,15 @@ func TestImport(t *testing.T) { }, { "large state", - buildTestSnapshot(100), + buildTestSnapshot(2), }, } { t.Run(tc.name, func(t *testing.T) { t.Parallel() path := filepath.Join(t.TempDir(), "test.binpb") - state := &mockSnapshotter{data: tc.snapshot} logger := zaptest.NewLogger(t) - + state := state.NewState(logger) stateStorage := storage.New(path, state, logger) data, err := tc.snapshot.MarshalVT() @@ -111,10 +114,9 @@ func TestImport(t *testing.T) { require.Equal(t, statsForSnapshot(tc.snapshot), importStats) - loads := state.getLoads() + importedState := exportTestState(t, state) - require.Len(t, loads, 1) - require.True(t, loads[0].EqualVT(tc.snapshot)) + requireEqualIgnoreOrder(t, tc.snapshot, importedState) }) } } @@ -125,8 +127,8 @@ func TestImportMaxSize(t *testing.T) { cluster := buildMaxSizeCluster() stateSnapshot := &storagepb.StateSnapshot{Clusters: []*storagepb.ClusterSnapshot{cluster}} path := filepath.Join(t.TempDir(), "test.binpb") - state := &mockSnapshotter{data: stateSnapshot} logger := zaptest.NewLogger(t) + state := state.NewState(logger) stateStorage := storage.New(path, state, logger) @@ -161,7 +163,7 @@ func TestStorage(t *testing.T) { snapshot := buildTestSnapshot(10) tempDir := t.TempDir() path := filepath.Join(tempDir, "test.binpb") - state := &mockSnapshotter{data: snapshot} + state := newTestSnapshotter(t, snapshot) logger := zaptest.NewLogger(t) stateStorage := storage.New(path, state, logger) @@ -170,19 +172,20 @@ func TestStorage(t *testing.T) { require.NoError(t, stateStorage.Save()) - expectedData, err := snapshot.MarshalVT() + savedBytes, err := os.ReadFile(path) require.NoError(t, err) - actualData, err := os.ReadFile(path) - require.NoError(t, err) + savedSnapshot := &storagepb.StateSnapshot{} - require.Equal(t, expectedData, actualData) + require.NoError(t, savedSnapshot.UnmarshalVT(savedBytes)) + + requireEqualIgnoreOrder(t, snapshot, savedSnapshot) // test load require.NoError(t, stateStorage.Load()) require.Len(t, state.getLoads(), 1) - require.True(t, snapshot.EqualVT(state.getLoads()[0])) + requireEqualIgnoreOrder(t, snapshot, state.getLoads()[0]) // modify, save & load again to assert that the file content gets overwritten @@ -191,7 +194,7 @@ func TestStorage(t *testing.T) { require.NoError(t, stateStorage.Save()) require.NoError(t, stateStorage.Load()) require.Len(t, state.getLoads(), 2) - require.True(t, snapshot.EqualVT(state.getLoads()[1])) + requireEqualIgnoreOrder(t, snapshot, state.getLoads()[1]) } func TestSchedule(t *testing.T) { @@ -201,7 +204,7 @@ func TestSchedule(t *testing.T) { snapshot := buildTestSnapshot(10) tempDir := t.TempDir() path := filepath.Join(tempDir, "test.binpb") - state := &mockSnapshotter{data: snapshot} + state := newTestSnapshotter(t, snapshot) logger := zaptest.NewLogger(t) stateStorage := storage.New(path, state, logger) @@ -246,23 +249,34 @@ func TestSchedule(t *testing.T) { }, 2*time.Second, 100*time.Millisecond) } -type mockSnapshotter struct { - data *storagepb.StateSnapshot - loads []*storagepb.StateSnapshot +// testSnapshotter is a mock implementation of storage.Snapshotter for testing purposes. +// +// It keeps track of the loads and the number of snapshots that have been performed to be used in assertions. +type testSnapshotter struct { + exportData *storagepb.StateSnapshot + tb testing.TB + loads []*storagepb.StateSnapshot snapshots int + lock sync.Mutex +} + +func newTestSnapshotter(tb testing.TB, exportData *storagepb.StateSnapshot) *testSnapshotter { + state := state.NewState(zaptest.NewLogger(tb)) + + importTestState(tb, state, exportData) - lock sync.Mutex + return &testSnapshotter{exportData: exportData, tb: tb} } -func (m *mockSnapshotter) getSnapshots() int { +func (m *testSnapshotter) getSnapshots() int { m.lock.Lock() defer m.lock.Unlock() return m.snapshots } -func (m *mockSnapshotter) getLoads() []*storagepb.StateSnapshot { +func (m *testSnapshotter) getLoads() []*storagepb.StateSnapshot { m.lock.Lock() defer m.lock.Unlock() @@ -270,42 +284,35 @@ func (m *mockSnapshotter) getLoads() []*storagepb.StateSnapshot { } // ExportClusterSnapshots implements storage.Snapshotter interface. -func (m *mockSnapshotter) ExportClusterSnapshots(f func(snapshot *storagepb.ClusterSnapshot) error) error { +func (m *testSnapshotter) ExportClusterSnapshots(f func(snapshot *storagepb.ClusterSnapshot) error) error { m.lock.Lock() defer m.lock.Unlock() - m.snapshots++ + tempState := state.NewState(zaptest.NewLogger(m.tb)) - for _, cluster := range m.data.Clusters { - if err := f(cluster); err != nil { - return err - } + importTestState(m.tb, tempState, m.exportData) + + if err := tempState.ExportClusterSnapshots(f); err != nil { + return err } + m.snapshots++ + return nil } // ImportClusterSnapshots implements storage.Snapshotter interface. -func (m *mockSnapshotter) ImportClusterSnapshots(f func() (*storagepb.ClusterSnapshot, bool, error)) error { +func (m *testSnapshotter) ImportClusterSnapshots(f func() (*storagepb.ClusterSnapshot, bool, error)) error { m.lock.Lock() defer m.lock.Unlock() - var clusters []*storagepb.ClusterSnapshot + tempState := state.NewState(zaptest.NewLogger(m.tb)) - for { - cluster, ok, err := f() - if err != nil { - return err - } - - if !ok { - break - } - - clusters = append(clusters, cluster) + if err := tempState.ImportClusterSnapshots(f); err != nil { + return err } - m.loads = append(m.loads, &storagepb.StateSnapshot{Clusters: clusters}) + m.loads = append(m.loads, exportTestState(m.tb, tempState)) return nil } @@ -396,3 +403,56 @@ func buildMaxSizeCluster() *storagepb.ClusterSnapshot { Affiliates: affiliates, } } + +func importTestState(tb testing.TB, state *state.State, snapshot *storagepb.StateSnapshot) { + clusters := snapshot.Clusters + i := 0 + + err := state.ImportClusterSnapshots(func() (*storagepb.ClusterSnapshot, bool, error) { + if i >= len(clusters) { + return nil, false, nil + } + + cluster := clusters[i] + i++ + + return cluster, true, nil + }) + require.NoError(tb, err) +} + +func exportTestState(tb testing.TB, state *state.State) *storagepb.StateSnapshot { + snapshot := &storagepb.StateSnapshot{} + + err := state.ExportClusterSnapshots(func(cluster *storagepb.ClusterSnapshot) error { + snapshot.Clusters = append(snapshot.Clusters, cluster.CloneVT()) // clone the cluster here, as its reference is reused across iterations + + return nil + }) + require.NoError(tb, err) + + return snapshot +} + +func requireEqualIgnoreOrder(tb testing.TB, expected, actual *storagepb.StateSnapshot) { + tb.Helper() + + a := expected.CloneVT() + b := actual.CloneVT() + + // sort clusters + for _, st := range []*storagepb.StateSnapshot{a, b} { + slices.SortFunc(st.Clusters, func(a, b *storagepb.ClusterSnapshot) int { + return strings.Compare(a.Id, b.Id) + }) + } + + // sort affiliates + for _, cluster := range append(a.Clusters, b.Clusters...) { + slices.SortFunc(cluster.Affiliates, func(a, b *storagepb.AffiliateSnapshot) int { + return strings.Compare(a.Id, b.Id) + }) + } + + assert.True(tb, a.EqualVT(b)) +}