Skip to content

Commit

Permalink
Move db network ACL functions to ClusterTx
Browse files Browse the repository at this point in the history
Signed-off-by: Thomas Hipp <[email protected]>
  • Loading branch information
monstermunchkin committed Dec 12, 2023
1 parent 2ec03ce commit dd2fc19
Show file tree
Hide file tree
Showing 10 changed files with 272 additions and 149 deletions.
14 changes: 12 additions & 2 deletions cmd/incusd/network_acls.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
Expand All @@ -12,6 +13,7 @@ import (

"github.com/lxc/incus/internal/server/auth"
clusterRequest "github.com/lxc/incus/internal/server/cluster/request"
"github.com/lxc/incus/internal/server/db"
"github.com/lxc/incus/internal/server/lifecycle"
"github.com/lxc/incus/internal/server/network/acl"
"github.com/lxc/incus/internal/server/project"
Expand Down Expand Up @@ -150,8 +152,16 @@ func networkACLsGet(d *Daemon, r *http.Request) response.Response {

recursion := localUtil.IsRecursionRequest(r)

// Get list of Network ACLs.
aclNames, err := s.DB.Cluster.GetNetworkACLs(projectName)
var aclNames []string

err = s.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error {
var err error

// Get list of Network ACLs.
aclNames, err = tx.GetNetworkACLs(ctx, projectName)

return err
})
if err != nil {
return response.InternalError(err)
}
Expand Down
60 changes: 33 additions & 27 deletions cmd/incusd/patches.go
Original file line number Diff line number Diff line change
Expand Up @@ -290,44 +290,50 @@ func patchNetworkACLRemoveDefaults(name string, d *Daemon) error {
return err
}

// Get ACLs in projects.
for _, projectName := range projectNames {
aclNames, err := d.db.Cluster.GetNetworkACLs(projectName)
if err != nil {
return err
}

for _, aclName := range aclNames {
aclID, acl, err := d.db.Cluster.GetNetworkACL(projectName, aclName)
err = d.db.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error {
// Get ACLs in projects.
for _, projectName := range projectNames {
aclNames, err := tx.GetNetworkACLs(ctx, projectName)
if err != nil {
return err
}

modified := false
for _, aclName := range aclNames {
aclID, acl, err := tx.GetNetworkACL(ctx, projectName, aclName)
if err != nil {
return err
}

// Remove the offending keys if found.
_, found := acl.Config["default.action"]
if found {
delete(acl.Config, "default.action")
modified = true
}
modified := false

_, found = acl.Config["default.logged"]
if found {
delete(acl.Config, "default.logged")
modified = true
}
// Remove the offending keys if found.
_, found := acl.Config["default.action"]
if found {
delete(acl.Config, "default.action")
modified = true
}

// Write back modified config if needed.
if modified {
err = d.db.Cluster.UpdateNetworkACL(aclID, &acl.NetworkACLPut)
if err != nil {
return fmt.Errorf("Failed updating network ACL %d: %w", aclID, err)
_, found = acl.Config["default.logged"]
if found {
delete(acl.Config, "default.logged")
modified = true
}

// Write back modified config if needed.
if modified {
err = tx.UpdateNetworkACL(ctx, aclID, &acl.NetworkACLPut)
if err != nil {
return fmt.Errorf("Failed updating network ACL %d: %w", aclID, err)
}
}
}
}
}

return nil
})
if err != nil {
return err
}
return nil
}

Expand Down
11 changes: 10 additions & 1 deletion internal/server/db/entity.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,16 @@ func (c *Cluster) GetURIFromEntity(entityType int, entityID int) (string, error)

uri = fmt.Sprintf(cluster.EntityURIs[entityType], networkName, projectName)
case cluster.TypeNetworkACL:
networkACLName, projectName, err := c.GetNetworkACLNameAndProjectWithID(entityID)
var networkACLName string
var projectName string

err := c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error {
var err error

networkACLName, projectName, err = tx.GetNetworkACLNameAndProjectWithID(ctx, entityID)

return err
})
if err != nil {
return "", fmt.Errorf("Failed to get network ACL name and project name: %w", err)
}
Expand Down
153 changes: 66 additions & 87 deletions internal/server/db/network_acls.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,26 @@ import (
)

// GetNetworkACLs returns the names of existing Network ACLs.
func (c *Cluster) GetNetworkACLs(project string) ([]string, error) {
func (c *ClusterTx) GetNetworkACLs(ctx context.Context, project string) ([]string, error) {
q := `SELECT name FROM networks_acls
WHERE project_id = (SELECT id FROM projects WHERE name = ? LIMIT 1)
ORDER BY id
`

var aclNames []string

err := c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error {
return query.Scan(ctx, tx.Tx(), q, func(scan func(dest ...any) error) error {
var aclName string
err := query.Scan(ctx, c.tx, q, func(scan func(dest ...any) error) error {
var aclName string

err := scan(&aclName)
if err != nil {
return err
}
err := scan(&aclName)
if err != nil {
return err
}

aclNames = append(aclNames, aclName)
aclNames = append(aclNames, aclName)

return nil
}, project)
})
return nil
}, project)
if err != nil {
return nil, err
}
Expand All @@ -45,29 +43,27 @@ func (c *Cluster) GetNetworkACLs(project string) ([]string, error) {
}

// GetNetworkACLIDsByNames returns a map of names to IDs of existing Network ACLs.
func (c *Cluster) GetNetworkACLIDsByNames(project string) (map[string]int64, error) {
func (c *ClusterTx) GetNetworkACLIDsByNames(ctx context.Context, project string) (map[string]int64, error) {
q := `SELECT id, name FROM networks_acls
WHERE project_id = (SELECT id FROM projects WHERE name = ? LIMIT 1)
ORDER BY id
`

acls := make(map[string]int64)

err := c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error {
return query.Scan(ctx, tx.Tx(), q, func(scan func(dest ...any) error) error {
var aclID int64
var aclName string
err := query.Scan(ctx, c.tx, q, func(scan func(dest ...any) error) error {
var aclID int64
var aclName string

err := scan(&aclID, &aclName)
if err != nil {
return err
}
err := scan(&aclID, &aclName)
if err != nil {
return err
}

acls[aclName] = aclID
acls[aclName] = aclID

return nil
}, project)
})
return nil
}, project)
if err != nil {
return nil, err
}
Expand All @@ -76,7 +72,7 @@ func (c *Cluster) GetNetworkACLIDsByNames(project string) (map[string]int64, err
}

// GetNetworkACL returns the Network ACL with the given name in the given project.
func (c *Cluster) GetNetworkACL(projectName string, name string) (int64, *api.NetworkACL, error) {
func (c *ClusterTx) GetNetworkACL(ctx context.Context, projectName string, name string) (int64, *api.NetworkACL, error) {
var id int64 = int64(-1)
var ingressJSON string
var egressJSON string
Expand All @@ -94,25 +90,22 @@ func (c *Cluster) GetNetworkACL(projectName string, name string) (int64, *api.Ne
LIMIT 1
`

err := c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error {
err := tx.tx.QueryRowContext(ctx, q, projectName, name).Scan(&id, &acl.Description, &ingressJSON, &egressJSON)
if err != nil {
return err
err := c.tx.QueryRowContext(ctx, q, projectName, name).Scan(&id, &acl.Description, &ingressJSON, &egressJSON)
if err != nil {
if err == sql.ErrNoRows {
return -1, nil, api.StatusErrorf(http.StatusNotFound, "Network ACL not found")
}

err = networkACLConfig(ctx, tx, id, &acl)
if err != nil {
return fmt.Errorf("Failed loading config: %w", err)
}
return -1, nil, err
}

return nil
})
err = networkACLConfig(ctx, c, id, &acl)
if err != nil {
if err == sql.ErrNoRows {
return -1, nil, api.StatusErrorf(http.StatusNotFound, "Network ACL not found")
}

return -1, nil, err
return -1, nil, fmt.Errorf("Failed loading config: %w", err)
}

acl.Ingress = []api.NetworkACLRule{}
Expand All @@ -135,15 +128,13 @@ func (c *Cluster) GetNetworkACL(projectName string, name string) (int64, *api.Ne
}

// GetNetworkACLNameAndProjectWithID returns the network ACL name and project name for the given ID.
func (c *Cluster) GetNetworkACLNameAndProjectWithID(networkACLID int) (string, string, error) {
func (c *ClusterTx) GetNetworkACLNameAndProjectWithID(ctx context.Context, networkACLID int) (string, string, error) {
var networkACLName string
var projectName string

q := `SELECT networks_acls.name, projects.name FROM networks_acls JOIN projects ON projects.id=networks.project_id WHERE networks_acls.id=?`

err := c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error {
return tx.tx.QueryRowContext(ctx, q, networkACLID).Scan(&networkACLName, &projectName)
})
err := c.tx.QueryRowContext(ctx, q, networkACLID).Scan(&networkACLName, &projectName)
if err != nil {
if err == sql.ErrNoRows {
return "", "", api.StatusErrorf(http.StatusNotFound, "Network ACL not found")
Expand Down Expand Up @@ -184,8 +175,7 @@ func networkACLConfig(ctx context.Context, tx *ClusterTx, id int64, acl *api.Net
}

// CreateNetworkACL creates a new Network ACL.
func (c *Cluster) CreateNetworkACL(projectName string, info *api.NetworkACLsPost) (int64, error) {
var id int64
func (c *ClusterTx) CreateNetworkACL(ctx context.Context, projectName string, info *api.NetworkACLsPost) (int64, error) {
var err error
var ingressJSON, egressJSON []byte

Expand All @@ -203,30 +193,23 @@ func (c *Cluster) CreateNetworkACL(projectName string, info *api.NetworkACLsPost
}
}

err = c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error {
// Insert a new Network ACL record.
result, err := tx.tx.Exec(`
// Insert a new Network ACL record.
result, err := c.tx.ExecContext(ctx, `
INSERT INTO networks_acls (project_id, name, description, ingress, egress)
VALUES ((SELECT id FROM projects WHERE name = ? LIMIT 1), ?, ?, ?, ?)
`, projectName, info.Name, info.Description, string(ingressJSON), string(egressJSON))
if err != nil {
return err
}

id, err := result.LastInsertId()
if err != nil {
return err
}
if err != nil {
return -1, err
}

err = networkACLConfigAdd(tx.tx, id, info.Config)
if err != nil {
return err
}
id, err := result.LastInsertId()
if err != nil {
return -1, err
}

return nil
})
err = networkACLConfigAdd(c.tx, id, info.Config)
if err != nil {
id = -1
return -1, err
}

return id, err
Expand Down Expand Up @@ -257,7 +240,7 @@ func networkACLConfigAdd(tx *sql.Tx, id int64, config map[string]string) error {
}

// UpdateNetworkACL updates the Network ACL with the given ID.
func (c *Cluster) UpdateNetworkACL(id int64, config *api.NetworkACLPut) error {
func (c *ClusterTx) UpdateNetworkACL(ctx context.Context, id int64, config *api.NetworkACLPut) error {
var err error
var ingressJSON, egressJSON []byte

Expand All @@ -275,44 +258,40 @@ func (c *Cluster) UpdateNetworkACL(id int64, config *api.NetworkACLPut) error {
}
}

return c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error {
_, err := tx.tx.Exec(`
_, err = c.tx.ExecContext(ctx, `
UPDATE networks_acls
SET description=?, ingress = ?, egress = ?
WHERE id=?
`, config.Description, ingressJSON, egressJSON, id)
if err != nil {
return err
}
if err != nil {
return err
}

_, err = tx.tx.Exec("DELETE FROM networks_acls_config WHERE network_acl_id=?", id)
if err != nil {
return err
}
_, err = c.tx.ExecContext(ctx, "DELETE FROM networks_acls_config WHERE network_acl_id=?", id)
if err != nil {
return err
}

err = networkACLConfigAdd(tx.tx, id, config.Config)
if err != nil {
return err
}
err = networkACLConfigAdd(c.tx, id, config.Config)
if err != nil {
return err
}

return nil
})
return nil
}

// RenameNetworkACL renames a Network ACL.
func (c *Cluster) RenameNetworkACL(id int64, newName string) error {
return c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error {
_, err := tx.tx.Exec("UPDATE networks_acls SET name=? WHERE id=?", newName, id)
return err
})
func (c *ClusterTx) RenameNetworkACL(ctx context.Context, id int64, newName string) error {
_, err := c.tx.ExecContext(ctx, "UPDATE networks_acls SET name=? WHERE id=?", newName, id)

return err
}

// DeleteNetworkACL deletes the Network ACL.
func (c *Cluster) DeleteNetworkACL(id int64) error {
return c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error {
_, err := tx.tx.Exec("DELETE FROM networks_acls WHERE id=?", id)
return err
})
func (c *ClusterTx) DeleteNetworkACL(ctx context.Context, id int64) error {
_, err := c.tx.ExecContext(ctx, "DELETE FROM networks_acls WHERE id=?", id)

return err
}

// GetNetworkACLURIs returns the URIs for the network ACLs with the given project.
Expand Down
Loading

0 comments on commit dd2fc19

Please sign in to comment.