Skip to content

Commit

Permalink
Move db storage volume snapshot function to ClusterTx
Browse files Browse the repository at this point in the history
Signed-off-by: Thomas Hipp <[email protected]>
  • Loading branch information
monstermunchkin committed Jan 7, 2024
1 parent 6fa429e commit 9702b16
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 71 deletions.
24 changes: 21 additions & 3 deletions cmd/incusd/storage_volumes_snapshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,13 @@ func storagePoolVolumeSnapshotTypeGet(d *Daemon, r *http.Request) response.Respo
return response.SmartError(err)
}

expiry, err := s.DB.Cluster.GetStorageVolumeSnapshotExpiry(dbVolume.ID)
var expiry time.Time

err = s.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error {
expiry, err = tx.GetStorageVolumeSnapshotExpiry(ctx, dbVolume.ID)

return err
})
if err != nil {
return response.SmartError(err)
}
Expand Down Expand Up @@ -840,7 +846,13 @@ func storagePoolVolumeSnapshotTypePut(d *Daemon, r *http.Request) response.Respo
return response.SmartError(err)
}

expiry, err := s.DB.Cluster.GetStorageVolumeSnapshotExpiry(dbVolume.ID)
var expiry time.Time

err = s.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error {
expiry, err = tx.GetStorageVolumeSnapshotExpiry(ctx, dbVolume.ID)

return err
})
if err != nil {
return response.SmartError(err)
}
Expand Down Expand Up @@ -971,7 +983,13 @@ func storagePoolVolumeSnapshotTypePatch(d *Daemon, r *http.Request) response.Res
return err
})

expiry, err := s.DB.Cluster.GetStorageVolumeSnapshotExpiry(dbVolume.ID)
var expiry time.Time

err = s.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error {
expiry, err = tx.GetStorageVolumeSnapshotExpiry(ctx, dbVolume.ID)

return err
})
if err != nil {
return response.SmartError(err)
}
Expand Down
8 changes: 7 additions & 1 deletion internal/server/db/entity.go
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,13 @@ func (c *Cluster) GetURIFromEntity(entityType int, entityID int) (string, error)

uri = fmt.Sprintf(cluster.EntityURIs[entityType], volume.PoolName, volume.TypeName, volume.Name, backup.Name, volume.ProjectName)
case cluster.TypeStorageVolumeSnapshot:
snapshot, err := c.GetStorageVolumeSnapshotWithID(entityID)
var snapshot StorageVolumeArgs

err := c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error {
snapshot, err = tx.GetStorageVolumeSnapshotWithID(ctx, entityID)

return err
})
if err != nil {
return "", fmt.Errorf("Failed to get volume snapshot: %w", err)
}
Expand Down
6 changes: 5 additions & 1 deletion internal/server/db/storage_pools_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,11 @@ func TestCreateStoragePoolVolume_Snapshot(t *testing.T) {
require.NoError(t, err)

config = map[string]string{"k": "v"}
_, err = cluster.CreateStorageVolumeSnapshot("default", "v1/snap0", "", 1, poolID, config, time.Now(), time.Time{})
err = cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error {
_, err = tx.CreateStorageVolumeSnapshot(ctx, "default", "v1/snap0", "", 1, poolID, config, time.Now(), time.Time{})

return err
})
require.NoError(t, err)

n := cluster.GetNextStorageVolumeSnapshotIndex("p1", "v1", 1, "snap%d")
Expand Down
111 changes: 48 additions & 63 deletions internal/server/db/storage_volume_snapshots.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,93 +17,82 @@ import (

// CreateStorageVolumeSnapshot creates a new storage volume snapshot attached to a given
// storage pool.
func (c *Cluster) CreateStorageVolumeSnapshot(projectName string, volumeName string, volumeDescription string, volumeType int, poolID int64, volumeConfig map[string]string, creationDate time.Time, expiryDate time.Time) (int64, error) {
func (c *ClusterTx) CreateStorageVolumeSnapshot(ctx context.Context, projectName string, volumeName string, volumeDescription string, volumeType int, poolID int64, volumeConfig map[string]string, creationDate time.Time, expiryDate time.Time) (int64, error) {
var volumeID int64

var snapshotName string
parts := strings.Split(volumeName, internalInstance.SnapshotDelimiter)
volumeName = parts[0]
snapshotName = parts[1]

err := c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error {
// Figure out the volume ID of the parent.
parentID, err := tx.storagePoolVolumeGetTypeID(ctx, projectName, volumeName, volumeType, poolID, c.nodeID)
if err != nil {
return fmt.Errorf("Failed finding parent volume record for snapshot: %w", err)
}

_, err = tx.tx.Exec("UPDATE sqlite_sequence SET seq = seq + 1 WHERE name = 'storage_volumes'")
if err != nil {
return fmt.Errorf("Failed incrementing storage volumes sequence: %w", err)
}
// Figure out the volume ID of the parent.
parentID, err := c.storagePoolVolumeGetTypeID(ctx, projectName, volumeName, volumeType, poolID, c.nodeID)
if err != nil {
return -1, fmt.Errorf("Failed finding parent volume record for snapshot: %w", err)
}

row := tx.tx.QueryRowContext(ctx, "SELECT seq FROM sqlite_sequence WHERE name = 'storage_volumes' LIMIT 1")
err = row.Scan(&volumeID)
if err != nil {
return fmt.Errorf("Failed getting storage volumes sequence: %w", err)
}
_, err = c.tx.ExecContext(ctx, "UPDATE sqlite_sequence SET seq = seq + 1 WHERE name = 'storage_volumes'")
if err != nil {
return -1, fmt.Errorf("Failed incrementing storage volumes sequence: %w", err)
}

_, err = tx.tx.Exec("INSERT INTO storage_volumes_snapshots (id, storage_volume_id, name, description, creation_date, expiry_date) VALUES (?, ?, ?, ?, ?, ?)", volumeID, parentID, snapshotName, volumeDescription, creationDate, expiryDate)
if err != nil {
return fmt.Errorf("Failed creating volume snapshot record: %w", err)
}
row := c.tx.QueryRowContext(ctx, "SELECT seq FROM sqlite_sequence WHERE name = 'storage_volumes' LIMIT 1")
err = row.Scan(&volumeID)
if err != nil {
return -1, fmt.Errorf("Failed getting storage volumes sequence: %w", err)
}

err = storageVolumeConfigAdd(tx.tx, volumeID, volumeConfig, true)
if err != nil {
return fmt.Errorf("Failed inserting storage volume snapshot record configuration: %w", err)
}
_, err = c.tx.ExecContext(ctx, "INSERT INTO storage_volumes_snapshots (id, storage_volume_id, name, description, creation_date, expiry_date) VALUES (?, ?, ?, ?, ?, ?)", volumeID, parentID, snapshotName, volumeDescription, creationDate, expiryDate)
if err != nil {
return -1, fmt.Errorf("Failed creating volume snapshot record: %w", err)
}

return nil
})
err = storageVolumeConfigAdd(c.tx, volumeID, volumeConfig, true)
if err != nil {
volumeID = -1
return -1, fmt.Errorf("Failed inserting storage volume snapshot record configuration: %w", err)
}

return volumeID, err
return volumeID, nil
}

// UpdateStorageVolumeSnapshot updates the storage volume snapshot attached to a given storage pool.
func (c *Cluster) UpdateStorageVolumeSnapshot(projectName string, volumeName string, volumeType int, poolID int64, volumeDescription string, volumeConfig map[string]string, expiryDate time.Time) error {
func (c *ClusterTx) UpdateStorageVolumeSnapshot(ctx context.Context, projectName string, volumeName string, volumeType int, poolID int64, volumeDescription string, volumeConfig map[string]string, expiryDate time.Time) error {
var err error

if !strings.Contains(volumeName, internalInstance.SnapshotDelimiter) {
return fmt.Errorf("Volume is not a snapshot")
}

err = c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error {
volume, err := tx.GetStoragePoolVolume(ctx, poolID, projectName, volumeType, volumeName, true)
if err != nil {
return err
}

err = storageVolumeConfigClear(tx.tx, volume.ID, true)
if err != nil {
return err
}
volume, err := c.GetStoragePoolVolume(ctx, poolID, projectName, volumeType, volumeName, true)
if err != nil {
return err
}

err = storageVolumeConfigAdd(tx.tx, volume.ID, volumeConfig, true)
if err != nil {
return err
}
err = storageVolumeConfigClear(c.tx, volume.ID, true)
if err != nil {
return err
}

err = storageVolumeDescriptionUpdate(tx.tx, volume.ID, volumeDescription, true)
if err != nil {
return err
}
err = storageVolumeConfigAdd(c.tx, volume.ID, volumeConfig, true)
if err != nil {
return err
}

err = storageVolumeSnapshotExpiryDateUpdate(tx.tx, volume.ID, expiryDate)
if err != nil {
return err
}
err = storageVolumeDescriptionUpdate(c.tx, volume.ID, volumeDescription, true)
if err != nil {
return err
}

return nil
})
err = storageVolumeSnapshotExpiryDateUpdate(c.tx, volume.ID, expiryDate)
if err != nil {
return err
}

return err
return nil
}

// GetStorageVolumeSnapshotWithID returns the volume snapshot with the given ID.
func (c *Cluster) GetStorageVolumeSnapshotWithID(snapshotID int) (StorageVolumeArgs, error) {
func (c *ClusterTx) GetStorageVolumeSnapshotWithID(ctx context.Context, snapshotID int) (StorageVolumeArgs, error) {
args := StorageVolumeArgs{}
q := `
SELECT
Expand All @@ -121,9 +110,7 @@ WHERE volumes.id=?
arg1 := []any{snapshotID}
outfmt := []any{&args.ID, &args.Name, &args.CreationDate, &args.PoolName, &args.Type, &args.ProjectName}

err := c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error {
return dbQueryRowScan(ctx, tx, q, arg1, outfmt)
})
err := dbQueryRowScan(ctx, c, q, arg1, outfmt)
if err != nil {
if err == sql.ErrNoRows {
return args, api.StatusErrorf(http.StatusNotFound, "Storage pool volume snapshot not found")
Expand All @@ -142,16 +129,14 @@ WHERE volumes.id=?
}

// GetStorageVolumeSnapshotExpiry gets the expiry date of a storage volume snapshot.
func (c *Cluster) GetStorageVolumeSnapshotExpiry(volumeID int64) (time.Time, error) {
func (c *ClusterTx) GetStorageVolumeSnapshotExpiry(ctx context.Context, volumeID int64) (time.Time, error) {
var expiry time.Time

query := "SELECT expiry_date FROM storage_volumes_snapshots WHERE id=?"
inargs := []any{volumeID}
outargs := []any{&expiry}

err := c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error {
return dbQueryRowScan(ctx, tx, query, inargs, outargs)
})
err := dbQueryRowScan(ctx, c, query, inargs, outargs)
if err != nil {
if err == sql.ErrNoRows {
return expiry, api.StatusErrorf(http.StatusNotFound, "Storage pool volume snapshot not found")
Expand Down
12 changes: 10 additions & 2 deletions internal/server/storage/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -5378,7 +5378,13 @@ func (b *backend) UpdateCustomVolumeSnapshot(projectName string, volName string,
return err
}

curExpiryDate, err := b.state.DB.Cluster.GetStorageVolumeSnapshotExpiry(curVol.ID)
var curExpiryDate time.Time

err = b.state.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error {
curExpiryDate, err = tx.GetStorageVolumeSnapshotExpiry(ctx, curVol.ID)

return err
})
if err != nil {
return err
}
Expand All @@ -5392,7 +5398,9 @@ func (b *backend) UpdateCustomVolumeSnapshot(projectName string, volName string,

// Update the database if description changed. Use current config.
if newDesc != curVol.Description || newExpiryDate != curExpiryDate {
err = b.state.DB.Cluster.UpdateStorageVolumeSnapshot(projectName, volName, db.StoragePoolVolumeTypeCustom, b.ID(), newDesc, curVol.Config, newExpiryDate)
err = b.state.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error {
return tx.UpdateStorageVolumeSnapshot(ctx, projectName, volName, db.StoragePoolVolumeTypeCustom, b.ID(), newDesc, curVol.Config, newExpiryDate)
})
if err != nil {
return err
}
Expand Down
6 changes: 5 additions & 1 deletion internal/server/storage/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,11 @@ func VolumeDBCreate(pool Pool, projectName string, volumeName string, volumeDesc

// Create the database entry for the storage volume.
if snapshot {
_, err = p.state.DB.Cluster.CreateStorageVolumeSnapshot(projectName, volumeName, volumeDescription, volDBType, pool.ID(), vol.Config(), creationDate, expiryDate)
err = p.state.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error {
_, err = tx.CreateStorageVolumeSnapshot(ctx, projectName, volumeName, volumeDescription, volDBType, pool.ID(), vol.Config(), creationDate, expiryDate)

return err
})
} else {
_, err = p.state.DB.Cluster.CreateStoragePoolVolume(projectName, volumeName, volumeDescription, volDBType, pool.ID(), vol.Config(), volDBContentType, creationDate)
}
Expand Down

0 comments on commit 9702b16

Please sign in to comment.