Skip to content

Commit

Permalink
fix: fix the panic in loading state from storage
Browse files Browse the repository at this point in the history
When importing clusters from a file, grow the used buffer to the correct size of the next cluster, so that it does not panic when unmarshaling it.

Handle panics on the storage's save & load, so that it will never crash the discovery service when it fails.

Additionally:
- fix the slice growing logic when exporting clusters so that we avoid over-growing the slices for the affiliates and the endpoints.
- modify the storage tests to use the real state instead of a mock, replace the assertions to ignore the order accordingly.

Signed-off-by: Utku Ozdemir <[email protected]>
  • Loading branch information
utkuozdemir committed May 28, 2024
1 parent 10c83d2 commit 417251c
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 51 deletions.
12 changes: 9 additions & 3 deletions internal/state/snapshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import (
func (state *State) ExportClusterSnapshots(f func(snapshot *storagepb.ClusterSnapshot) error) error {
var err error

// reuse the same snapshotin each iteration
// reuse the same snapshot in each iteration
clusterSnapshot := &storagepb.ClusterSnapshot{}

state.clusters.Enumerate(func(_ string, cluster *Cluster) bool {
Expand Down Expand Up @@ -67,7 +67,10 @@ func snapshotCluster(cluster *Cluster, snapshot *storagepb.ClusterSnapshot) {
snapshot.Id = cluster.id

// reuse the same slice, resize it as needed
snapshot.Affiliates = slices.Grow(snapshot.Affiliates, len(cluster.affiliates))
if len(cluster.affiliates) > cap(snapshot.Affiliates) {
snapshot.Affiliates = slices.Grow(snapshot.Affiliates, len(cluster.affiliates)-len(snapshot.Affiliates))
}

snapshot.Affiliates = snapshot.Affiliates[:len(cluster.affiliates)]

i := 0
Expand All @@ -88,7 +91,10 @@ func snapshotCluster(cluster *Cluster, snapshot *storagepb.ClusterSnapshot) {
snapshot.Affiliates[i].Data = affiliate.data

// reuse the same slice, resize it as needed
snapshot.Affiliates[i].Endpoints = slices.Grow(snapshot.Affiliates[i].Endpoints, len(affiliate.endpoints))
if len(affiliate.endpoints) > cap(snapshot.Affiliates[i].Endpoints) {
snapshot.Affiliates[i].Endpoints = slices.Grow(snapshot.Affiliates[i].Endpoints, len(affiliate.endpoints)-len(snapshot.Affiliates[i].Endpoints))
}

snapshot.Affiliates[i].Endpoints = snapshot.Affiliates[i].Endpoints[:len(affiliate.endpoints)]

for j, endpoint := range affiliate.endpoints {
Expand Down
16 changes: 15 additions & 1 deletion internal/state/storage/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,13 @@ func (storage *Storage) Save() (err error) {
}
}()

// never panic, convert it into an error instead
defer func() {
if recovered := recover(); recovered != nil {
err = fmt.Errorf("save panicked: %v", recovered)
}
}()

if err = os.MkdirAll(filepath.Dir(storage.path), 0o755); err != nil {
return fmt.Errorf("failed to create directory path: %w", err)
}
Expand Down Expand Up @@ -195,6 +202,13 @@ func (storage *Storage) Load() (err error) {
}
}()

// never panic, convert it into an error instead
defer func() {
if recovered := recover(); recovered != nil {
err = fmt.Errorf("load panicked: %v", recovered)
}
}()

start := time.Now()

// open file for reading
Expand Down Expand Up @@ -254,7 +268,7 @@ func (storage *Storage) Import(reader io.Reader) (SnapshotStats, error) {
}

if clusterSize > cap(buffer) {
buffer = slices.Grow(buffer, clusterSize-cap(buffer))
buffer = slices.Grow(buffer, clusterSize)
}

buffer = buffer[:clusterSize]
Expand Down
154 changes: 107 additions & 47 deletions internal/state/storage/storage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"math"
"os"
"path/filepath"
"slices"
"strings"
"sync"
"testing"
Expand All @@ -24,6 +25,7 @@ import (
"google.golang.org/protobuf/types/known/timestamppb"

storagepb "github.com/siderolabs/discovery-service/api/storage"
"github.com/siderolabs/discovery-service/internal/state"
"github.com/siderolabs/discovery-service/internal/state/storage"
"github.com/siderolabs/discovery-service/pkg/limits"
)
Expand Down Expand Up @@ -51,11 +53,12 @@ func TestExport(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

snapshot := buildTestSnapshot(10)
tempDir := t.TempDir()
path := filepath.Join(tempDir, "test.binpb")
state := &mockSnapshotter{data: snapshot}
logger := zaptest.NewLogger(t)
state := state.NewState(logger)

importTestState(t, state, tc.snapshot)

stateStorage := storage.New(path, state, logger)

Expand All @@ -64,12 +67,13 @@ func TestExport(t *testing.T) {
exportStats, err := stateStorage.Export(&buffer)
require.NoError(t, err)

assert.Equal(t, statsForSnapshot(snapshot), exportStats)
assert.Equal(t, statsForSnapshot(tc.snapshot), exportStats)

expected, err := snapshot.MarshalVT()
require.NoError(t, err)
exported := &storagepb.StateSnapshot{}

require.NoError(t, exported.UnmarshalVT(buffer.Bytes()))

require.Equal(t, expected, buffer.Bytes())
requireEqualIgnoreOrder(t, tc.snapshot, exported)
})
}
}
Expand All @@ -91,16 +95,15 @@ func TestImport(t *testing.T) {
},
{
"large state",
buildTestSnapshot(100),
buildTestSnapshot(2),
},
} {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

path := filepath.Join(t.TempDir(), "test.binpb")
state := &mockSnapshotter{data: tc.snapshot}
logger := zaptest.NewLogger(t)

state := state.NewState(logger)
stateStorage := storage.New(path, state, logger)

data, err := tc.snapshot.MarshalVT()
Expand All @@ -111,10 +114,9 @@ func TestImport(t *testing.T) {

require.Equal(t, statsForSnapshot(tc.snapshot), importStats)

loads := state.getLoads()
importedState := exportTestState(t, state)

require.Len(t, loads, 1)
require.True(t, loads[0].EqualVT(tc.snapshot))
requireEqualIgnoreOrder(t, tc.snapshot, importedState)
})
}
}
Expand All @@ -125,8 +127,8 @@ func TestImportMaxSize(t *testing.T) {
cluster := buildMaxSizeCluster()
stateSnapshot := &storagepb.StateSnapshot{Clusters: []*storagepb.ClusterSnapshot{cluster}}
path := filepath.Join(t.TempDir(), "test.binpb")
state := &mockSnapshotter{data: stateSnapshot}
logger := zaptest.NewLogger(t)
state := state.NewState(logger)

stateStorage := storage.New(path, state, logger)

Expand Down Expand Up @@ -161,7 +163,7 @@ func TestStorage(t *testing.T) {
snapshot := buildTestSnapshot(10)
tempDir := t.TempDir()
path := filepath.Join(tempDir, "test.binpb")
state := &mockSnapshotter{data: snapshot}
state := newTestSnapshotter(t, snapshot)
logger := zaptest.NewLogger(t)

stateStorage := storage.New(path, state, logger)
Expand All @@ -170,19 +172,20 @@ func TestStorage(t *testing.T) {

require.NoError(t, stateStorage.Save())

expectedData, err := snapshot.MarshalVT()
savedBytes, err := os.ReadFile(path)
require.NoError(t, err)

actualData, err := os.ReadFile(path)
require.NoError(t, err)
savedSnapshot := &storagepb.StateSnapshot{}

require.Equal(t, expectedData, actualData)
require.NoError(t, savedSnapshot.UnmarshalVT(savedBytes))

requireEqualIgnoreOrder(t, snapshot, savedSnapshot)

// test load

require.NoError(t, stateStorage.Load())
require.Len(t, state.getLoads(), 1)
require.True(t, snapshot.EqualVT(state.getLoads()[0]))
requireEqualIgnoreOrder(t, snapshot, state.getLoads()[0])

// modify, save & load again to assert that the file content gets overwritten

Expand All @@ -191,7 +194,7 @@ func TestStorage(t *testing.T) {
require.NoError(t, stateStorage.Save())
require.NoError(t, stateStorage.Load())
require.Len(t, state.getLoads(), 2)
require.True(t, snapshot.EqualVT(state.getLoads()[1]))
requireEqualIgnoreOrder(t, snapshot, state.getLoads()[1])
}

func TestSchedule(t *testing.T) {
Expand All @@ -201,7 +204,7 @@ func TestSchedule(t *testing.T) {
snapshot := buildTestSnapshot(10)
tempDir := t.TempDir()
path := filepath.Join(tempDir, "test.binpb")
state := &mockSnapshotter{data: snapshot}
state := newTestSnapshotter(t, snapshot)
logger := zaptest.NewLogger(t)

stateStorage := storage.New(path, state, logger)
Expand Down Expand Up @@ -246,66 +249,70 @@ func TestSchedule(t *testing.T) {
}, 2*time.Second, 100*time.Millisecond)
}

type mockSnapshotter struct {
data *storagepb.StateSnapshot
loads []*storagepb.StateSnapshot
// testSnapshotter is a mock implementation of storage.Snapshotter for testing purposes.
//
// It keeps track of the loads and the number of snapshots that have been performed to be used in assertions.
type testSnapshotter struct {
exportData *storagepb.StateSnapshot

tb testing.TB
loads []*storagepb.StateSnapshot
snapshots int
lock sync.Mutex
}

func newTestSnapshotter(tb testing.TB, exportData *storagepb.StateSnapshot) *testSnapshotter {
state := state.NewState(zaptest.NewLogger(tb))

importTestState(tb, state, exportData)

lock sync.Mutex
return &testSnapshotter{exportData: exportData, tb: tb}
}

func (m *mockSnapshotter) getSnapshots() int {
func (m *testSnapshotter) getSnapshots() int {
m.lock.Lock()
defer m.lock.Unlock()

return m.snapshots
}

func (m *mockSnapshotter) getLoads() []*storagepb.StateSnapshot {
func (m *testSnapshotter) getLoads() []*storagepb.StateSnapshot {
m.lock.Lock()
defer m.lock.Unlock()

return append([]*storagepb.StateSnapshot(nil), m.loads...)
}

// ExportClusterSnapshots implements storage.Snapshotter interface.
func (m *mockSnapshotter) ExportClusterSnapshots(f func(snapshot *storagepb.ClusterSnapshot) error) error {
func (m *testSnapshotter) ExportClusterSnapshots(f func(snapshot *storagepb.ClusterSnapshot) error) error {
m.lock.Lock()
defer m.lock.Unlock()

m.snapshots++
tempState := state.NewState(zaptest.NewLogger(m.tb))

for _, cluster := range m.data.Clusters {
if err := f(cluster); err != nil {
return err
}
importTestState(m.tb, tempState, m.exportData)

if err := tempState.ExportClusterSnapshots(f); err != nil {
return err
}

m.snapshots++

return nil
}

// ImportClusterSnapshots implements storage.Snapshotter interface.
func (m *mockSnapshotter) ImportClusterSnapshots(f func() (*storagepb.ClusterSnapshot, bool, error)) error {
func (m *testSnapshotter) ImportClusterSnapshots(f func() (*storagepb.ClusterSnapshot, bool, error)) error {
m.lock.Lock()
defer m.lock.Unlock()

var clusters []*storagepb.ClusterSnapshot
tempState := state.NewState(zaptest.NewLogger(m.tb))

for {
cluster, ok, err := f()
if err != nil {
return err
}

if !ok {
break
}

clusters = append(clusters, cluster)
if err := tempState.ImportClusterSnapshots(f); err != nil {
return err
}

m.loads = append(m.loads, &storagepb.StateSnapshot{Clusters: clusters})
m.loads = append(m.loads, exportTestState(m.tb, tempState))

return nil
}
Expand Down Expand Up @@ -396,3 +403,56 @@ func buildMaxSizeCluster() *storagepb.ClusterSnapshot {
Affiliates: affiliates,
}
}

// importTestState feeds every cluster snapshot contained in the given state
// snapshot into st through its ImportClusterSnapshots callback, failing the
// test immediately on error.
//
// The parameter is named st (not state) to avoid shadowing the imported
// state package.
func importTestState(tb testing.TB, st *state.State, snapshot *storagepb.StateSnapshot) {
	tb.Helper()

	clusters := snapshot.Clusters
	next := 0

	err := st.ImportClusterSnapshots(func() (*storagepb.ClusterSnapshot, bool, error) {
		// signal end of input once all clusters have been consumed
		if next >= len(clusters) {
			return nil, false, nil
		}

		cluster := clusters[next]
		next++

		return cluster, true, nil
	})
	require.NoError(tb, err)
}

// exportTestState drains all cluster snapshots from st into a freshly built
// StateSnapshot, failing the test immediately on error.
//
// Each cluster is cloned before being appended, because the state reuses the
// same ClusterSnapshot instance across iterations of the export callback.
//
// The parameter is named st (not state) to avoid shadowing the imported
// state package.
func exportTestState(tb testing.TB, st *state.State) *storagepb.StateSnapshot {
	tb.Helper()

	snapshot := &storagepb.StateSnapshot{}

	err := st.ExportClusterSnapshots(func(cluster *storagepb.ClusterSnapshot) error {
		snapshot.Clusters = append(snapshot.Clusters, cluster.CloneVT())

		return nil
	})
	require.NoError(tb, err)

	return snapshot
}

// requireEqualIgnoreOrder asserts that two state snapshots are equal while
// disregarding the order of their clusters and of the affiliates within each
// cluster.
func requireEqualIgnoreOrder(tb testing.TB, expected, actual *storagepb.StateSnapshot) {
	tb.Helper()

	// operate on clones so that sorting never mutates the caller's inputs
	expectedClone := expected.CloneVT()
	actualClone := actual.CloneVT()

	// normalize ordering: clusters by ID, then affiliates by ID within each cluster
	sortSnapshot := func(snapshot *storagepb.StateSnapshot) {
		slices.SortFunc(snapshot.Clusters, func(left, right *storagepb.ClusterSnapshot) int {
			return strings.Compare(left.Id, right.Id)
		})

		for _, cluster := range snapshot.Clusters {
			slices.SortFunc(cluster.Affiliates, func(left, right *storagepb.AffiliateSnapshot) int {
				return strings.Compare(left.Id, right.Id)
			})
		}
	}

	sortSnapshot(expectedClone)
	sortSnapshot(actualClone)

	assert.True(tb, expectedClone.EqualVT(actualClone))
}

0 comments on commit 417251c

Please sign in to comment.