From 9702b160232a861b5d522bf3eef34b2ba9dc89be Mon Sep 17 00:00:00 2001 From: Thomas Hipp Date: Sun, 7 Jan 2024 21:10:39 +0100 Subject: [PATCH] Move db storage volume snapshot function to ClusterTx Signed-off-by: Thomas Hipp --- cmd/incusd/storage_volumes_snapshot.go | 24 +++- internal/server/db/entity.go | 8 +- internal/server/db/storage_pools_test.go | 6 +- .../server/db/storage_volume_snapshots.go | 111 ++++++++---------- internal/server/storage/backend.go | 12 +- internal/server/storage/utils.go | 6 +- 6 files changed, 96 insertions(+), 71 deletions(-) diff --git a/cmd/incusd/storage_volumes_snapshot.go b/cmd/incusd/storage_volumes_snapshot.go index 69e487691a2..78f564a24ec 100644 --- a/cmd/incusd/storage_volumes_snapshot.go +++ b/cmd/incusd/storage_volumes_snapshot.go @@ -711,7 +711,13 @@ func storagePoolVolumeSnapshotTypeGet(d *Daemon, r *http.Request) response.Respo return response.SmartError(err) } - expiry, err := s.DB.Cluster.GetStorageVolumeSnapshotExpiry(dbVolume.ID) + var expiry time.Time + + err = s.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + expiry, err = tx.GetStorageVolumeSnapshotExpiry(ctx, dbVolume.ID) + + return err + }) if err != nil { return response.SmartError(err) } @@ -840,7 +846,13 @@ func storagePoolVolumeSnapshotTypePut(d *Daemon, r *http.Request) response.Respo return response.SmartError(err) } - expiry, err := s.DB.Cluster.GetStorageVolumeSnapshotExpiry(dbVolume.ID) + var expiry time.Time + + err = s.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + expiry, err = tx.GetStorageVolumeSnapshotExpiry(ctx, dbVolume.ID) + + return err + }) if err != nil { return response.SmartError(err) } @@ -971,7 +983,13 @@ func storagePoolVolumeSnapshotTypePatch(d *Daemon, r *http.Request) response.Res return err }) - expiry, err := s.DB.Cluster.GetStorageVolumeSnapshotExpiry(dbVolume.ID) + var expiry time.Time + + err = s.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + expiry, err = tx.GetStorageVolumeSnapshotExpiry(ctx, dbVolume.ID) + + return err + }) if err != nil { return response.SmartError(err) } diff --git a/internal/server/db/entity.go b/internal/server/db/entity.go index 07ac027ffc6..ac1b522e221 100644 --- a/internal/server/db/entity.go +++ b/internal/server/db/entity.go @@ -347,7 +347,13 @@ func (c *Cluster) GetURIFromEntity(entityType int, entityID int) (string, error) uri = fmt.Sprintf(cluster.EntityURIs[entityType], volume.PoolName, volume.TypeName, volume.Name, backup.Name, volume.ProjectName) case cluster.TypeStorageVolumeSnapshot: - snapshot, err := c.GetStorageVolumeSnapshotWithID(entityID) + var snapshot StorageVolumeArgs + + err := c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error { + snapshot, err = tx.GetStorageVolumeSnapshotWithID(ctx, entityID) + + return err + }) if err != nil { return "", fmt.Errorf("Failed to get volume snapshot: %w", err) } diff --git a/internal/server/db/storage_pools_test.go b/internal/server/db/storage_pools_test.go index 58de1958502..0580a6aadca 100644 --- a/internal/server/db/storage_pools_test.go +++ b/internal/server/db/storage_pools_test.go @@ -261,7 +261,11 @@ func TestCreateStoragePoolVolume_Snapshot(t *testing.T) { require.NoError(t, err) config = map[string]string{"k": "v"} - _, err = cluster.CreateStorageVolumeSnapshot("default", "v1/snap0", "", 1, poolID, config, time.Now(), time.Time{}) + err = cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + _, err = tx.CreateStorageVolumeSnapshot(ctx, "default", "v1/snap0", "", 1, poolID, config, time.Now(), time.Time{}) + + return err + }) require.NoError(t, err) n := cluster.GetNextStorageVolumeSnapshotIndex("p1", "v1", 1, "snap%d") diff --git a/internal/server/db/storage_volume_snapshots.go b/internal/server/db/storage_volume_snapshots.go index 30acfad8da2..38428ed888b 100644 --- a/internal/server/db/storage_volume_snapshots.go +++ b/internal/server/db/storage_volume_snapshots.go @@ -17,7 +17,7 @@ import ( // CreateStorageVolumeSnapshot creates a new storage volume snapshot attached to a given // storage pool. -func (c *Cluster) CreateStorageVolumeSnapshot(projectName string, volumeName string, volumeDescription string, volumeType int, poolID int64, volumeConfig map[string]string, creationDate time.Time, expiryDate time.Time) (int64, error) { +func (c *ClusterTx) CreateStorageVolumeSnapshot(ctx context.Context, projectName string, volumeName string, volumeDescription string, volumeType int, poolID int64, volumeConfig map[string]string, creationDate time.Time, expiryDate time.Time) (int64, error) { var volumeID int64 var snapshotName string @@ -25,85 +25,74 @@ func (c *Cluster) CreateStorageVolumeSnapshot(projectName string, volumeName str volumeName = parts[0] snapshotName = parts[1] - err := c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error { - // Figure out the volume ID of the parent. - parentID, err := tx.storagePoolVolumeGetTypeID(ctx, projectName, volumeName, volumeType, poolID, c.nodeID) - if err != nil { - return fmt.Errorf("Failed finding parent volume record for snapshot: %w", err) - } - - _, err = tx.tx.Exec("UPDATE sqlite_sequence SET seq = seq + 1 WHERE name = 'storage_volumes'") - if err != nil { - return fmt.Errorf("Failed incrementing storage volumes sequence: %w", err) - } + // Figure out the volume ID of the parent. + parentID, err := c.storagePoolVolumeGetTypeID(ctx, projectName, volumeName, volumeType, poolID, c.nodeID) + if err != nil { + return -1, fmt.Errorf("Failed finding parent volume record for snapshot: %w", err) + } - row := tx.tx.QueryRowContext(ctx, "SELECT seq FROM sqlite_sequence WHERE name = 'storage_volumes' LIMIT 1") - err = row.Scan(&volumeID) - if err != nil { - return fmt.Errorf("Failed getting storage volumes sequence: %w", err) - } + _, err = c.tx.ExecContext(ctx, "UPDATE sqlite_sequence SET seq = seq + 1 WHERE name = 'storage_volumes'") + if err != nil { + return -1, fmt.Errorf("Failed incrementing storage volumes sequence: %w", err) + } - _, err = tx.tx.Exec("INSERT INTO storage_volumes_snapshots (id, storage_volume_id, name, description, creation_date, expiry_date) VALUES (?, ?, ?, ?, ?, ?)", volumeID, parentID, snapshotName, volumeDescription, creationDate, expiryDate) - if err != nil { - return fmt.Errorf("Failed creating volume snapshot record: %w", err) - } + row := c.tx.QueryRowContext(ctx, "SELECT seq FROM sqlite_sequence WHERE name = 'storage_volumes' LIMIT 1") + err = row.Scan(&volumeID) + if err != nil { + return -1, fmt.Errorf("Failed getting storage volumes sequence: %w", err) + } - err = storageVolumeConfigAdd(tx.tx, volumeID, volumeConfig, true) - if err != nil { - return fmt.Errorf("Failed inserting storage volume snapshot record configuration: %w", err) - } + _, err = c.tx.ExecContext(ctx, "INSERT INTO storage_volumes_snapshots (id, storage_volume_id, name, description, creation_date, expiry_date) VALUES (?, ?, ?, ?, ?, ?)", volumeID, parentID, snapshotName, volumeDescription, creationDate, expiryDate) + if err != nil { + return -1, fmt.Errorf("Failed creating volume snapshot record: %w", err) + } - return nil - }) + err = storageVolumeConfigAdd(c.tx, volumeID, volumeConfig, true) if err != nil { - volumeID = -1 + return -1, fmt.Errorf("Failed inserting storage volume snapshot record configuration: %w", err) } - return volumeID, err + return volumeID, nil } // UpdateStorageVolumeSnapshot updates the storage volume snapshot attached to a given storage pool. -func (c *Cluster) UpdateStorageVolumeSnapshot(projectName string, volumeName string, volumeType int, poolID int64, volumeDescription string, volumeConfig map[string]string, expiryDate time.Time) error { +func (c *ClusterTx) UpdateStorageVolumeSnapshot(ctx context.Context, projectName string, volumeName string, volumeType int, poolID int64, volumeDescription string, volumeConfig map[string]string, expiryDate time.Time) error { var err error if !strings.Contains(volumeName, internalInstance.SnapshotDelimiter) { return fmt.Errorf("Volume is not a snapshot") } - err = c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error { - volume, err := tx.GetStoragePoolVolume(ctx, poolID, projectName, volumeType, volumeName, true) - if err != nil { - return err - } - - err = storageVolumeConfigClear(tx.tx, volume.ID, true) - if err != nil { - return err - } + volume, err := c.GetStoragePoolVolume(ctx, poolID, projectName, volumeType, volumeName, true) + if err != nil { + return err + } - err = storageVolumeConfigAdd(tx.tx, volume.ID, volumeConfig, true) - if err != nil { - return err - } + err = storageVolumeConfigClear(c.tx, volume.ID, true) + if err != nil { + return err + } - err = storageVolumeDescriptionUpdate(tx.tx, volume.ID, volumeDescription, true) - if err != nil { - return err - } + err = storageVolumeConfigAdd(c.tx, volume.ID, volumeConfig, true) + if err != nil { + return err + } - err = storageVolumeSnapshotExpiryDateUpdate(tx.tx, volume.ID, expiryDate) - if err != nil { - return err - } + err = storageVolumeDescriptionUpdate(c.tx, volume.ID, volumeDescription, true) + if err != nil { + return err + } - return nil - }) + err = storageVolumeSnapshotExpiryDateUpdate(c.tx, volume.ID, expiryDate) + if err != nil { + return err + } - return err + return nil } // GetStorageVolumeSnapshotWithID returns the volume snapshot with the given ID. -func (c *Cluster) GetStorageVolumeSnapshotWithID(snapshotID int) (StorageVolumeArgs, error) { +func (c *ClusterTx) GetStorageVolumeSnapshotWithID(ctx context.Context, snapshotID int) (StorageVolumeArgs, error) { args := StorageVolumeArgs{} q := ` SELECT @@ -121,9 +110,7 @@ WHERE volumes.id=? arg1 := []any{snapshotID} outfmt := []any{&args.ID, &args.Name, &args.CreationDate, &args.PoolName, &args.Type, &args.ProjectName} - err := c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error { - return dbQueryRowScan(ctx, tx, q, arg1, outfmt) - }) + err := dbQueryRowScan(ctx, c, q, arg1, outfmt) if err != nil { if err == sql.ErrNoRows { return args, api.StatusErrorf(http.StatusNotFound, "Storage pool volume snapshot not found") @@ -142,16 +129,14 @@ WHERE volumes.id=? } // GetStorageVolumeSnapshotExpiry gets the expiry date of a storage volume snapshot. -func (c *Cluster) GetStorageVolumeSnapshotExpiry(volumeID int64) (time.Time, error) { +func (c *ClusterTx) GetStorageVolumeSnapshotExpiry(ctx context.Context, volumeID int64) (time.Time, error) { var expiry time.Time query := "SELECT expiry_date FROM storage_volumes_snapshots WHERE id=?" inargs := []any{volumeID} outargs := []any{&expiry} - err := c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error { - return dbQueryRowScan(ctx, tx, query, inargs, outargs) - }) + err := dbQueryRowScan(ctx, c, query, inargs, outargs) if err != nil { if err == sql.ErrNoRows { return expiry, api.StatusErrorf(http.StatusNotFound, "Storage pool volume snapshot not found") diff --git a/internal/server/storage/backend.go b/internal/server/storage/backend.go index 850cc1c6e9a..8342f15f9f9 100644 --- a/internal/server/storage/backend.go +++ b/internal/server/storage/backend.go @@ -5378,7 +5378,13 @@ func (b *backend) UpdateCustomVolumeSnapshot(projectName string, volName string, return err } - curExpiryDate, err := b.state.DB.Cluster.GetStorageVolumeSnapshotExpiry(curVol.ID) + var curExpiryDate time.Time + + err = b.state.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + curExpiryDate, err = tx.GetStorageVolumeSnapshotExpiry(ctx, curVol.ID) + + return err + }) if err != nil { return err } @@ -5392,7 +5398,9 @@ func (b *backend) UpdateCustomVolumeSnapshot(projectName string, volName string, // Update the database if description changed. Use current config. if newDesc != curVol.Description || newExpiryDate != curExpiryDate { - err = b.state.DB.Cluster.UpdateStorageVolumeSnapshot(projectName, volName, db.StoragePoolVolumeTypeCustom, b.ID(), newDesc, curVol.Config, newExpiryDate) + err = b.state.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + return tx.UpdateStorageVolumeSnapshot(ctx, projectName, volName, db.StoragePoolVolumeTypeCustom, b.ID(), newDesc, curVol.Config, newExpiryDate) + }) if err != nil { return err } diff --git a/internal/server/storage/utils.go b/internal/server/storage/utils.go index 83cb36d82f2..623271433c0 100644 --- a/internal/server/storage/utils.go +++ b/internal/server/storage/utils.go @@ -285,7 +285,11 @@ func VolumeDBCreate(pool Pool, projectName string, volumeName string, volumeDesc // Create the database entry for the storage volume. if snapshot { - _, err = p.state.DB.Cluster.CreateStorageVolumeSnapshot(projectName, volumeName, volumeDescription, volDBType, pool.ID(), vol.Config(), creationDate, expiryDate) + err = p.state.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + _, err = tx.CreateStorageVolumeSnapshot(ctx, projectName, volumeName, volumeDescription, volDBType, pool.ID(), vol.Config(), creationDate, expiryDate) + + return err + }) } else { _, err = p.state.DB.Cluster.CreateStoragePoolVolume(projectName, volumeName, volumeDescription, volDBType, pool.ID(), vol.Config(), volDBContentType, creationDate) }