From dd2fc19b8551ae3588a3d4319df27d499f50af40 Mon Sep 17 00:00:00 2001 From: Thomas Hipp Date: Mon, 11 Dec 2023 23:22:15 +0100 Subject: [PATCH] Move db network ACL functions to ClusterTx Signed-off-by: Thomas Hipp --- cmd/incusd/network_acls.go | 14 +- cmd/incusd/patches.go | 60 ++++---- internal/server/db/entity.go | 11 +- internal/server/db/network_acls.go | 153 ++++++++----------- internal/server/network/acl/acl_firewall.go | 12 +- internal/server/network/acl/acl_load.go | 47 +++++- internal/server/network/acl/acl_ovn.go | 28 +++- internal/server/network/acl/driver_common.go | 44 ++++-- internal/server/network/driver_common.go | 20 ++- internal/server/network/driver_ovn.go | 32 +++- 10 files changed, 272 insertions(+), 149 deletions(-) diff --git a/cmd/incusd/network_acls.go b/cmd/incusd/network_acls.go index 709b337ae52..f238a84a310 100644 --- a/cmd/incusd/network_acls.go +++ b/cmd/incusd/network_acls.go @@ -2,6 +2,7 @@ package main import ( "bytes" + "context" "encoding/json" "fmt" "net/http" @@ -12,6 +13,7 @@ import ( "github.com/lxc/incus/internal/server/auth" clusterRequest "github.com/lxc/incus/internal/server/cluster/request" + "github.com/lxc/incus/internal/server/db" "github.com/lxc/incus/internal/server/lifecycle" "github.com/lxc/incus/internal/server/network/acl" "github.com/lxc/incus/internal/server/project" @@ -150,8 +152,16 @@ func networkACLsGet(d *Daemon, r *http.Request) response.Response { recursion := localUtil.IsRecursionRequest(r) - // Get list of Network ACLs. - aclNames, err := s.DB.Cluster.GetNetworkACLs(projectName) + var aclNames []string + + err = s.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + var err error + + // Get list of Network ACLs. + aclNames, err = tx.GetNetworkACLs(ctx, projectName) + + return err + }) if err != nil { return response.InternalError(err) } diff --git a/cmd/incusd/patches.go b/cmd/incusd/patches.go index e8a8145412d..63fa0263978 100644 --- a/cmd/incusd/patches.go +++ b/cmd/incusd/patches.go @@ -290,44 +290,50 @@ func patchNetworkACLRemoveDefaults(name string, d *Daemon) error { return err } - // Get ACLs in projects. - for _, projectName := range projectNames { - aclNames, err := d.db.Cluster.GetNetworkACLs(projectName) - if err != nil { - return err - } - - for _, aclName := range aclNames { - aclID, acl, err := d.db.Cluster.GetNetworkACL(projectName, aclName) + err = d.db.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + // Get ACLs in projects. + for _, projectName := range projectNames { + aclNames, err := tx.GetNetworkACLs(ctx, projectName) if err != nil { return err } - modified := false + for _, aclName := range aclNames { + aclID, acl, err := tx.GetNetworkACL(ctx, projectName, aclName) + if err != nil { + return err + } - // Remove the offending keys if found. - _, found := acl.Config["default.action"] - if found { - delete(acl.Config, "default.action") - modified = true - } + modified := false - _, found = acl.Config["default.logged"] - if found { - delete(acl.Config, "default.logged") - modified = true - } + // Remove the offending keys if found. + _, found := acl.Config["default.action"] + if found { + delete(acl.Config, "default.action") + modified = true + } - // Write back modified config if needed. - if modified { - err = d.db.Cluster.UpdateNetworkACL(aclID, &acl.NetworkACLPut) - if err != nil { - return fmt.Errorf("Failed updating network ACL %d: %w", aclID, err) + _, found = acl.Config["default.logged"] + if found { + delete(acl.Config, "default.logged") + modified = true + } + + // Write back modified config if needed. + if modified { + err = tx.UpdateNetworkACL(ctx, aclID, &acl.NetworkACLPut) + if err != nil { + return fmt.Errorf("Failed updating network ACL %d: %w", aclID, err) + } } } } - } + return nil + }) + if err != nil { + return err + } return nil } diff --git a/internal/server/db/entity.go b/internal/server/db/entity.go index 5b59df7c1ea..d44aa0d711c 100644 --- a/internal/server/db/entity.go +++ b/internal/server/db/entity.go @@ -233,7 +233,16 @@ func (c *Cluster) GetURIFromEntity(entityType int, entityID int) (string, error) uri = fmt.Sprintf(cluster.EntityURIs[entityType], networkName, projectName) case cluster.TypeNetworkACL: - networkACLName, projectName, err := c.GetNetworkACLNameAndProjectWithID(entityID) + var networkACLName string + var projectName string + + err := c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error { + var err error + + networkACLName, projectName, err = tx.GetNetworkACLNameAndProjectWithID(ctx, entityID) + + return err + }) if err != nil { return "", fmt.Errorf("Failed to get network ACL name and project name: %w", err) } diff --git a/internal/server/db/network_acls.go b/internal/server/db/network_acls.go index c3a190bfb9e..1dd9ace26b9 100644 --- a/internal/server/db/network_acls.go +++ b/internal/server/db/network_acls.go @@ -15,7 +15,7 @@ import ( ) // GetNetworkACLs returns the names of existing Network ACLs. -func (c *Cluster) GetNetworkACLs(project string) ([]string, error) { +func (c *ClusterTx) GetNetworkACLs(ctx context.Context, project string) ([]string, error) { q := `SELECT name FROM networks_acls WHERE project_id = (SELECT id FROM projects WHERE name = ? LIMIT 1) ORDER BY id @@ -23,20 +23,18 @@ func (c *Cluster) GetNetworkACLs(project string) ([]string, error) { var aclNames []string - err := c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error { - return query.Scan(ctx, tx.Tx(), q, func(scan func(dest ...any) error) error { - var aclName string + err := query.Scan(ctx, c.tx, q, func(scan func(dest ...any) error) error { + var aclName string - err := scan(&aclName) - if err != nil { - return err - } + err := scan(&aclName) + if err != nil { + return err + } - aclNames = append(aclNames, aclName) + aclNames = append(aclNames, aclName) - return nil - }, project) - }) + return nil + }, project) if err != nil { return nil, err } @@ -45,7 +43,7 @@ func (c *Cluster) GetNetworkACLs(project string) ([]string, error) { } // GetNetworkACLIDsByNames returns a map of names to IDs of existing Network ACLs. -func (c *Cluster) GetNetworkACLIDsByNames(project string) (map[string]int64, error) { +func (c *ClusterTx) GetNetworkACLIDsByNames(ctx context.Context, project string) (map[string]int64, error) { q := `SELECT id, name FROM networks_acls WHERE project_id = (SELECT id FROM projects WHERE name = ? LIMIT 1) ORDER BY id @@ -53,21 +51,19 @@ func (c *Cluster) GetNetworkACLIDsByNames(project string) (map[string]int64, err acls := make(map[string]int64) - err := c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error { - return query.Scan(ctx, tx.Tx(), q, func(scan func(dest ...any) error) error { - var aclID int64 - var aclName string + err := query.Scan(ctx, c.tx, q, func(scan func(dest ...any) error) error { + var aclID int64 + var aclName string - err := scan(&aclID, &aclName) - if err != nil { - return err - } + err := scan(&aclID, &aclName) + if err != nil { + return err + } - acls[aclName] = aclID + acls[aclName] = aclID - return nil - }, project) - }) + return nil + }, project) if err != nil { return nil, err } @@ -76,7 +72,7 @@ func (c *Cluster) GetNetworkACLIDsByNames(project string) (map[string]int64, err } // GetNetworkACL returns the Network ACL with the given name in the given project. -func (c *Cluster) GetNetworkACL(projectName string, name string) (int64, *api.NetworkACL, error) { +func (c *ClusterTx) GetNetworkACL(ctx context.Context, projectName string, name string) (int64, *api.NetworkACL, error) { var id int64 = int64(-1) var ingressJSON string var egressJSON string @@ -94,25 +90,22 @@ func (c *Cluster) GetNetworkACL(projectName string, name string) (int64, *api.Ne LIMIT 1 ` - err := c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error { - err := tx.tx.QueryRowContext(ctx, q, projectName, name).Scan(&id, &acl.Description, &ingressJSON, &egressJSON) - if err != nil { - return err + err := c.tx.QueryRowContext(ctx, q, projectName, name).Scan(&id, &acl.Description, &ingressJSON, &egressJSON) + if err != nil { + if err == sql.ErrNoRows { + return -1, nil, api.StatusErrorf(http.StatusNotFound, "Network ACL not found") } - err = networkACLConfig(ctx, tx, id, &acl) - if err != nil { - return fmt.Errorf("Failed loading config: %w", err) - } + return -1, nil, err + } - return nil - }) + err = networkACLConfig(ctx, c, id, &acl) if err != nil { if err == sql.ErrNoRows { return -1, nil, api.StatusErrorf(http.StatusNotFound, "Network ACL not found") } - return -1, nil, err + return -1, nil, fmt.Errorf("Failed loading config: %w", err) } acl.Ingress = []api.NetworkACLRule{} @@ -135,15 +128,13 @@ func (c *Cluster) GetNetworkACL(projectName string, name string) (int64, *api.Ne } // GetNetworkACLNameAndProjectWithID returns the network ACL name and project name for the given ID. -func (c *Cluster) GetNetworkACLNameAndProjectWithID(networkACLID int) (string, string, error) { +func (c *ClusterTx) GetNetworkACLNameAndProjectWithID(ctx context.Context, networkACLID int) (string, string, error) { var networkACLName string var projectName string q := `SELECT networks_acls.name, projects.name FROM networks_acls JOIN projects ON projects.id=networks.project_id WHERE networks_acls.id=?` - err := c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error { - return tx.tx.QueryRowContext(ctx, q, networkACLID).Scan(&networkACLName, &projectName) - }) + err := c.tx.QueryRowContext(ctx, q, networkACLID).Scan(&networkACLName, &projectName) if err != nil { if err == sql.ErrNoRows { return "", "", api.StatusErrorf(http.StatusNotFound, "Network ACL not found") @@ -184,8 +175,7 @@ func networkACLConfig(ctx context.Context, tx *ClusterTx, id int64, acl *api.Net } // CreateNetworkACL creates a new Network ACL. -func (c *Cluster) CreateNetworkACL(projectName string, info *api.NetworkACLsPost) (int64, error) { - var id int64 +func (c *ClusterTx) CreateNetworkACL(ctx context.Context, projectName string, info *api.NetworkACLsPost) (int64, error) { var err error var ingressJSON, egressJSON []byte @@ -203,30 +193,23 @@ func (c *Cluster) CreateNetworkACL(projectName string, info *api.NetworkACLsPost } } - err = c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error { - // Insert a new Network ACL record. - result, err := tx.tx.Exec(` + // Insert a new Network ACL record. + result, err := c.tx.ExecContext(ctx, ` INSERT INTO networks_acls (project_id, name, description, ingress, egress) VALUES ((SELECT id FROM projects WHERE name = ? LIMIT 1), ?, ?, ?, ?) `, projectName, info.Name, info.Description, string(ingressJSON), string(egressJSON)) - if err != nil { - return err - } - - id, err := result.LastInsertId() - if err != nil { - return err - } + if err != nil { + return -1, err + } - err = networkACLConfigAdd(tx.tx, id, info.Config) - if err != nil { - return err - } + id, err := result.LastInsertId() + if err != nil { + return -1, err + } - return nil - }) + err = networkACLConfigAdd(c.tx, id, info.Config) if err != nil { - id = -1 + return -1, err } return id, err @@ -257,7 +240,7 @@ func networkACLConfigAdd(tx *sql.Tx, id int64, config map[string]string) error { } // UpdateNetworkACL updates the Network ACL with the given ID. -func (c *Cluster) UpdateNetworkACL(id int64, config *api.NetworkACLPut) error { +func (c *ClusterTx) UpdateNetworkACL(ctx context.Context, id int64, config *api.NetworkACLPut) error { var err error var ingressJSON, egressJSON []byte @@ -275,44 +258,40 @@ func (c *Cluster) UpdateNetworkACL(id int64, config *api.NetworkACLPut) error { } } - return c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error { - _, err := tx.tx.Exec(` + _, err = c.tx.ExecContext(ctx, ` UPDATE networks_acls SET description=?, ingress = ?, egress = ? WHERE id=? `, config.Description, ingressJSON, egressJSON, id) - if err != nil { - return err - } + if err != nil { + return err + } - _, err = tx.tx.Exec("DELETE FROM networks_acls_config WHERE network_acl_id=?", id) - if err != nil { - return err - } + _, err = c.tx.ExecContext(ctx, "DELETE FROM networks_acls_config WHERE network_acl_id=?", id) + if err != nil { + return err + } - err = networkACLConfigAdd(tx.tx, id, config.Config) - if err != nil { - return err - } + err = networkACLConfigAdd(c.tx, id, config.Config) + if err != nil { + return err + } - return nil - }) + return nil } // RenameNetworkACL renames a Network ACL. -func (c *Cluster) RenameNetworkACL(id int64, newName string) error { - return c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error { - _, err := tx.tx.Exec("UPDATE networks_acls SET name=? WHERE id=?", newName, id) - return err - }) +func (c *ClusterTx) RenameNetworkACL(ctx context.Context, id int64, newName string) error { + _, err := c.tx.ExecContext(ctx, "UPDATE networks_acls SET name=? WHERE id=?", newName, id) + + return err } // DeleteNetworkACL deletes the Network ACL. -func (c *Cluster) DeleteNetworkACL(id int64) error { - return c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error { - _, err := tx.tx.Exec("DELETE FROM networks_acls WHERE id=?", id) - return err - }) +func (c *ClusterTx) DeleteNetworkACL(ctx context.Context, id int64) error { + _, err := c.tx.ExecContext(ctx, "DELETE FROM networks_acls WHERE id=?", id) + + return err } // GetNetworkACLURIs returns the URIs for the network ACLs with the given project. diff --git a/internal/server/network/acl/acl_firewall.go b/internal/server/network/acl/acl_firewall.go index fb2e6a15be5..10be47c7d9a 100644 --- a/internal/server/network/acl/acl_firewall.go +++ b/internal/server/network/acl/acl_firewall.go @@ -1,8 +1,10 @@ package acl import ( + "context" "fmt" + "github.com/lxc/incus/internal/server/db" firewallDrivers "github.com/lxc/incus/internal/server/firewall/drivers" "github.com/lxc/incus/internal/server/state" "github.com/lxc/incus/shared/api" @@ -60,7 +62,15 @@ func FirewallApplyACLRules(s *state.State, logger logger.Logger, aclProjectName // Load ACLs specified by network. for _, aclName := range util.SplitNTrimSpace(aclNet.Config["security.acls"], ",", -1, true) { - _, aclInfo, err := s.DB.Cluster.GetNetworkACL(aclProjectName, aclName) + var aclInfo *api.NetworkACL + + err := s.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + var err error + + _, aclInfo, err = tx.GetNetworkACL(ctx, aclProjectName, aclName) + + return err + }) if err != nil { return fmt.Errorf("Failed loading ACL %q for network %q: %w", aclName, aclNet.Name, err) } diff --git a/internal/server/network/acl/acl_load.go b/internal/server/network/acl/acl_load.go index 1fcda5576eb..9de3b72e154 100644 --- a/internal/server/network/acl/acl_load.go +++ b/internal/server/network/acl/acl_load.go @@ -16,7 +16,16 @@ import ( // LoadByName loads and initialises a Network ACL from the database by project and name. func LoadByName(s *state.State, projectName string, name string) (NetworkACL, error) { - id, aclInfo, err := s.DB.Cluster.GetNetworkACL(projectName, name) + var id int64 + var aclInfo *api.NetworkACL + + err := s.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + var err error + + id, aclInfo, err = tx.GetNetworkACL(ctx, projectName, name) + + return err + }) if err != nil { return nil, err } @@ -42,8 +51,12 @@ func Create(s *state.State, projectName string, aclInfo *api.NetworkACLsPost) er return err } - // Insert DB record. - _, err = s.DB.Cluster.CreateNetworkACL(projectName, aclInfo) + err = s.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + // Insert DB record. + _, err := tx.CreateNetworkACL(ctx, projectName, aclInfo) + + return err + }) if err != nil { return err } @@ -54,7 +67,15 @@ func Create(s *state.State, projectName string, aclInfo *api.NetworkACLsPost) er // Exists checks the ACL name(s) provided exists in the project. // If multiple names are provided, also checks that duplicate names aren't specified in the list. func Exists(s *state.State, projectName string, name ...string) error { - existingACLNames, err := s.DB.Cluster.GetNetworkACLs(projectName) + var existingACLNames []string + + err := s.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + var err error + + existingACLNames, err = tx.GetNetworkACLs(ctx, projectName) + + return err + }) if err != nil { return err } @@ -160,14 +181,26 @@ func UsedBy(s *state.State, aclProjectName string, usageFunc func(matchedACLName } } - // Find ACLs that have rules that reference the ACLs. - aclNames, err := s.DB.Cluster.GetNetworkACLs(aclProjectName) + var aclNames []string + + err = s.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + // Find ACLs that have rules that reference the ACLs. + aclNames, err = tx.GetNetworkACLs(ctx, aclProjectName) + + return err + }) if err != nil { return err } for _, aclName := range aclNames { - _, aclInfo, err := s.DB.Cluster.GetNetworkACL(aclProjectName, aclName) + var aclInfo *api.NetworkACL + + err := s.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + _, aclInfo, err = tx.GetNetworkACL(ctx, aclProjectName, aclName) + + return err + }) if err != nil { return err } diff --git a/internal/server/network/acl/acl_ovn.go b/internal/server/network/acl/acl_ovn.go index 45c483f4f19..cb170311a2a 100644 --- a/internal/server/network/acl/acl_ovn.go +++ b/internal/server/network/acl/acl_ovn.go @@ -133,8 +133,14 @@ func OVNEnsureACLs(s *state.State, l logger.Logger, client *openvswitch.OVN, acl } if portGroupUUID == "" { - // Load the config we'll need to create the port group with ACL rules. - _, aclInfo, err := s.DB.Cluster.GetNetworkACL(aclProjectName, aclName) + var aclInfo *api.NetworkACL + + err := s.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + // Load the config we'll need to create the port group with ACL rules. + _, aclInfo, err = tx.GetNetworkACL(ctx, aclProjectName, aclName) + + return err + }) if err != nil { return nil, fmt.Errorf("Failed loading Network ACL %q: %w", aclName, err) } @@ -164,7 +170,11 @@ func OVNEnsureACLs(s *state.State, l logger.Logger, client *openvswitch.OVN, acl // the default rule we add. We also need to reapply the rules if we are adding any // new per-ACL-per-network port groups. if reapplyRules || !portGroupHasACLs || len(addACLNets) > 0 { - _, aclInfo, err = s.DB.Cluster.GetNetworkACL(aclProjectName, aclName) + err = s.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + _, aclInfo, err = tx.GetNetworkACL(ctx, aclProjectName, aclName) + + return err + }) if err != nil { return nil, fmt.Errorf("Failed loading Network ACL %q: %w", aclName, err) } @@ -747,8 +757,16 @@ func OVNApplyNetworkBaselineRules(client *openvswitch.OVN, switchName openvswitc // the desired ACLs are considered unused by the usage type even if the referring config has not yet been removed // from the database. func OVNPortGroupDeleteIfUnused(s *state.State, l logger.Logger, client *openvswitch.OVN, aclProjectName string, ignoreUsageType any, ignoreUsageNicName string, keepACLs ...string) error { - // Get map of ACL names to DB IDs (used for generating OVN port group names). - aclNameIDs, err := s.DB.Cluster.GetNetworkACLIDsByNames(aclProjectName) + var aclNameIDs map[string]int64 + + err := s.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + var err error + + // Get map of ACL names to DB IDs (used for generating OVN port group names). + aclNameIDs, err = tx.GetNetworkACLIDsByNames(ctx, aclProjectName) + + return err + }) if err != nil { return fmt.Errorf("Failed getting network ACL IDs for security ACL port group removal: %w", err) } diff --git a/internal/server/network/acl/driver_common.go b/internal/server/network/acl/driver_common.go index ce3bee36b06..a3d6a03fa19 100644 --- a/internal/server/network/acl/driver_common.go +++ b/internal/server/network/acl/driver_common.go @@ -2,6 +2,7 @@ package acl import ( "bufio" + "context" "fmt" "net" "os" @@ -299,8 +300,16 @@ func (d *common) validateRule(direction ruleDirection, rule api.NetworkACLRule) return fmt.Errorf("State must be one of: %s", strings.Join(validStates, ", ")) } - // Get map of ACL names to DB IDs (used for generating OVN port group names). - acls, err := d.state.DB.Cluster.GetNetworkACLIDsByNames(d.Project()) + var acls map[string]int64 + + err := d.state.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + var err error + + // Get map of ACL names to DB IDs (used for generating OVN port group names). + acls, err = tx.GetNetworkACLIDsByNames(ctx, d.Project()) + + return err + }) if err != nil { return fmt.Errorf("Failed getting network ACLs for security ACL subject validation: %w", err) } @@ -583,9 +592,11 @@ func (d *common) Update(config *api.NetworkACLPut, clientType request.ClientType if clientType == request.ClientTypeNormal { oldConfig := d.info.NetworkACLPut - // Update database. Its important this occurs before we attempt to apply to networks using the ACL - // as usage functions will inspect the database. - err = d.state.DB.Cluster.UpdateNetworkACL(d.id, config) + err = d.state.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + // Update database. Its important this occurs before we attempt to apply to networks using the ACL + // as usage functions will inspect the database. + return tx.UpdateNetworkACL(ctx, d.id, config) + }) if err != nil { return err } @@ -595,7 +606,10 @@ func (d *common) Update(config *api.NetworkACLPut, clientType request.ClientType d.init(d.state, d.id, d.projectName, d.info) revert.Add(func() { - _ = d.state.DB.Cluster.UpdateNetworkACL(d.id, &oldConfig) + _ = d.state.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + return tx.UpdateNetworkACL(ctx, d.id, &oldConfig) + }) + d.info.NetworkACLPut = oldConfig d.init(d.state, d.id, d.projectName, d.info) }) @@ -636,8 +650,14 @@ func (d *common) Update(config *api.NetworkACLPut, clientType request.ClientType return fmt.Errorf("Failed to get OVN client: %w", err) } - // Get map of ACL names to DB IDs (used for generating OVN port group names). - aclNameIDs, err := d.state.DB.Cluster.GetNetworkACLIDsByNames(d.Project()) + var aclNameIDs map[string]int64 + + err = d.state.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + // Get map of ACL names to DB IDs (used for generating OVN port group names). + aclNameIDs, err = tx.GetNetworkACLIDsByNames(ctx, d.Project()) + + return err + }) if err != nil { return fmt.Errorf("Failed getting network ACL IDs for security ACL update: %w", err) } @@ -704,7 +724,9 @@ func (d *common) Rename(newName string) error { return err } - err = d.state.DB.Cluster.RenameNetworkACL(d.id, newName) + err = d.state.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + return tx.RenameNetworkACL(ctx, d.id, newName) + }) if err != nil { return err } @@ -726,7 +748,9 @@ func (d *common) Delete() error { return fmt.Errorf("Cannot delete an ACL that is in use") } - return d.state.DB.Cluster.DeleteNetworkACL(d.id) + return d.state.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + return tx.DeleteNetworkACL(ctx, d.id) + }) } // GetLog gets the ACL log. diff --git a/internal/server/network/driver_common.go b/internal/server/network/driver_common.go index 65184733fe6..8b077069063 100644 --- a/internal/server/network/driver_common.go +++ b/internal/server/network/driver_common.go @@ -1460,14 +1460,28 @@ func (n *common) peerUsedBy(peerName string, firstOnly bool) ([]string, error) { return false } - // Find ACLs that have rules that reference the peer connection. - aclNames, err := n.state.DB.Cluster.GetNetworkACLs(n.Project()) + var aclNames []string + + err := n.state.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + var err error + + // Find ACLs that have rules that reference the peer connection. + aclNames, err = tx.GetNetworkACLs(ctx, n.Project()) + + return err + }) if err != nil { return nil, err } for _, aclName := range aclNames { - _, aclInfo, err := n.state.DB.Cluster.GetNetworkACL(n.Project(), aclName) + var aclInfo *api.NetworkACL + + err := n.state.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + _, aclInfo, err = tx.GetNetworkACL(ctx, n.Project(), aclName) + + return err + }) if err != nil { return nil, err } diff --git a/internal/server/network/driver_ovn.go b/internal/server/network/driver_ovn.go index 60404d9753f..4cb61282995 100644 --- a/internal/server/network/driver_ovn.go +++ b/internal/server/network/driver_ovn.go @@ -2391,8 +2391,16 @@ func (n *ovn) setup(update bool) error { // Ensure any network assigned security ACL port groups are created ready for instance NICs to use. securityACLS := util.SplitNTrimSpace(n.config["security.acls"], ",", -1, true) if len(securityACLS) > 0 { - // Get map of ACL names to DB IDs (used for generating OVN port group names). - aclNameIDs, err := n.state.DB.Cluster.GetNetworkACLIDsByNames(n.Project()) + var aclNameIDs map[string]int64 + + err := n.state.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + var err error + + // Get map of ACL names to DB IDs (used for generating OVN port group names). + aclNameIDs, err = tx.GetNetworkACLIDsByNames(ctx, n.Project()) + + return err + }) if err != nil { return fmt.Errorf("Failed getting network ACL IDs for security ACL setup: %w", err) } @@ -2969,8 +2977,14 @@ func (n *ovn) Update(newNetwork api.NetworkPut, targetNode string, clientType re } } - // Get map of ACL names to DB IDs (used for generating OVN port group names). - aclNameIDs, err := n.state.DB.Cluster.GetNetworkACLIDsByNames(n.Project()) + var aclNameIDs map[string]int64 + + err = n.state.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + // Get map of ACL names to DB IDs (used for generating OVN port group names). + aclNameIDs, err = tx.GetNetworkACLIDsByNames(ctx, n.Project()) + + return err + }) if err != nil { return fmt.Errorf("Failed getting network ACL IDs for security ACL update: %w", err) } @@ -3821,8 +3835,14 @@ func (n *ovn) InstanceDevicePortStart(opts *OVNInstanceNICSetupOpts, securityACL n.logger.Debug("Scheduled logical port for network port group addition", logger.Ctx{"portGroup": acl.OVNIntSwitchPortGroupName(n.ID()), "port": instancePortName}) if len(nicACLNames) > 0 || len(securityACLsRemove) > 0 { - // Get map of ACL names to DB IDs (used for generating OVN port group names). - aclNameIDs, err := n.state.DB.Cluster.GetNetworkACLIDsByNames(n.Project()) + var aclNameIDs map[string]int64 + + err = n.state.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + // Get map of ACL names to DB IDs (used for generating OVN port group names). + aclNameIDs, err = tx.GetNetworkACLIDsByNames(ctx, n.Project()) + + return err + }) if err != nil { return "", nil, fmt.Errorf("Failed getting network ACL IDs for security ACL setup: %w", err) }