diff --git a/cmd/incusd/api_cluster.go b/cmd/incusd/api_cluster.go index 59ab01b4c28..0d61e4da51c 100644 --- a/cmd/incusd/api_cluster.go +++ b/cmd/incusd/api_cluster.go @@ -614,39 +614,43 @@ func clusterPutJoin(d *Daemon, r *http.Request, req api.ClusterPut) response.Res // Get a list of projects for networks. var projects []dbCluster.Project + networks := []api.InitNetworksProjectPost{} err = s.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { projects, err = dbCluster.GetProjects(ctx, tx.Tx()) - return err - }) - if err != nil { - return fmt.Errorf("Failed to load projects for networks: %w", err) - } - - networks := []api.InitNetworksProjectPost{} - for _, p := range projects { - networkNames, err := s.DB.Cluster.GetNetworks(p.Name) - if err != nil && !response.IsNotFoundError(err) { - return err + if err != nil { + return fmt.Errorf("Failed to load projects for networks: %w", err) } - for _, name := range networkNames { - _, network, _, err := s.DB.Cluster.GetNetworkInAnyState(p.Name, name) - if err != nil { + for _, p := range projects { + networkNames, err := tx.GetNetworks(ctx, p.Name) + if err != nil && !response.IsNotFoundError(err) { return err } - internalNetwork := api.InitNetworksProjectPost{ - NetworksPost: api.NetworksPost{ - NetworkPut: network.NetworkPut, - Name: network.Name, - Type: network.Type, - }, - Project: p.Name, - } + for _, name := range networkNames { + _, network, _, err := tx.GetNetworkInAnyState(ctx, p.Name, name) + if err != nil { + return err + } + + internalNetwork := api.InitNetworksProjectPost{ + NetworksPost: api.NetworksPost{ + NetworkPut: network.NetworkPut, + Name: network.Name, + Type: network.Type, + }, + Project: p.Name, + } - networks = append(networks, internalNetwork) + networks = append(networks, internalNetwork) + } } + + return nil + }) + if err != nil { + return err } // Now request for this node to be added to the list of cluster nodes. @@ -2061,7 +2065,13 @@ func clusterNodeDelete(d *Daemon, r *http.Request) response.Response { } for _, networkProjectName := range networkProjectNames { - networks, err := s.DB.Cluster.GetNetworks(networkProjectName) + var networks []string + + err := s.DB.Cluster.Transaction(r.Context(), func(ctx context.Context, tx *db.ClusterTx) error { + networks, err = tx.GetNetworks(ctx, networkProjectName) + + return err + }) if err != nil { return response.SmartError(err) } @@ -2774,52 +2784,56 @@ func clusterCheckNetworksMatch(cluster *db.Cluster, reqNetworks []api.InitNetwor var networkProjectNames []string err = cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { - networkProjectNames, err = dbCluster.GetProjectNames(context.Background(), tx.Tx()) - return err - }) - if err != nil { - return fmt.Errorf("Failed to load projects for networks: %w", err) - } - - for _, networkProjectName := range networkProjectNames { - networkNames, err := cluster.GetCreatedNetworks(networkProjectName) - if err != nil && !response.IsNotFoundError(err) { - return err + networkProjectNames, err = dbCluster.GetProjectNames(ctx, tx.Tx()) + if err != nil { + return fmt.Errorf("Failed to load projects for networks: %w", err) } - for _, networkName := range networkNames { - found := false + for _, networkProjectName := range networkProjectNames { + networkNames, err := tx.GetCreatedNetworkNamesByProject(ctx, networkProjectName) + if err != nil && !response.IsNotFoundError(err) { + return err + } - for _, reqNetwork := range reqNetworks { - if reqNetwork.Name != networkName || reqNetwork.Project != networkProjectName { - continue - } + for _, networkName := range networkNames { + found := false - found = true + for _, reqNetwork := range reqNetworks { + if reqNetwork.Name != networkName || reqNetwork.Project != networkProjectName { + continue + } - _, network, _, err := cluster.GetNetworkInAnyState(networkProjectName, networkName) - if err != nil { - return err - } + found = true - if reqNetwork.Type != network.Type { - return fmt.Errorf("Mismatching type for network %q in project %q", networkName, networkProjectName) - } + _, network, _, err := tx.GetNetworkInAnyState(ctx, networkProjectName, networkName) + if err != nil { + return err + } - // Exclude the keys which are node-specific. - exclude := db.NodeSpecificNetworkConfig - err = localUtil.CompareConfigs(network.Config, reqNetwork.Config, exclude) - if err != nil { - return fmt.Errorf("Mismatching config for network %q in project %q: %w", networkName, networkProjectName, err) - } + if reqNetwork.Type != network.Type { + return fmt.Errorf("Mismatching type for network %q in project %q", networkName, networkProjectName) + } - break - } + // Exclude the keys which are node-specific. + exclude := db.NodeSpecificNetworkConfig + err = localUtil.CompareConfigs(network.Config, reqNetwork.Config, exclude) + if err != nil { + return fmt.Errorf("Mismatching config for network %q in project %q: %w", network.Name, networkProjectName, err) + } - if !found { - return fmt.Errorf("Missing network %q in project %q", networkName, networkProjectName) + break + } + + if !found { + return fmt.Errorf("Missing network %q in project %q", networkName, networkProjectName) + } } } + + return nil + }) + if err != nil { + return err } return nil diff --git a/cmd/incusd/api_project.go b/cmd/incusd/api_project.go index e5c5f6448db..be06df4dc88 100644 --- a/cmd/incusd/api_project.go +++ b/cmd/incusd/api_project.go @@ -1499,8 +1499,14 @@ func projectValidateRestrictedSubnets(s *state.State, value string) error { return fmt.Errorf("Not an IP network address %q", subnetStr) } - // Check uplink exists and load config to compare subnets. - _, uplink, _, err := s.DB.Cluster.GetNetworkInAnyState(api.ProjectDefaultName, uplinkName) + var uplink *api.Network + + err = s.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + // Check uplink exists and load config to compare subnets. + _, uplink, _, err = tx.GetNetworkInAnyState(ctx, api.ProjectDefaultName, uplinkName) + + return err + }) if err != nil { return fmt.Errorf("Invalid uplink network %q: %w", uplinkName, err) } diff --git a/cmd/incusd/instance_test.go b/cmd/incusd/instance_test.go index e145e338652..d3728d4c57e 100644 --- a/cmd/incusd/instance_test.go +++ b/cmd/incusd/instance_test.go @@ -122,7 +122,11 @@ func (suite *containerTestSuite) TestContainer_ProfilesOverwriteDefaultNic() { Name: "testFoo", } - _, err := suite.d.State().DB.Cluster.CreateNetwork(api.ProjectDefaultName, "unknownbr0", "", db.NetworkTypeBridge, nil) + err := suite.d.State().DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + _, err := tx.CreateNetwork(ctx, api.ProjectDefaultName, "unknownbr0", "", db.NetworkTypeBridge, nil) + + return err + }) suite.Req.Nil(err) c, op, _, err := instance.CreateInternal(suite.d.State(), args, true) @@ -157,7 +161,11 @@ func (suite *containerTestSuite) TestContainer_LoadFromDB() { state := suite.d.State() - _, err := state.DB.Cluster.CreateNetwork(api.ProjectDefaultName, "unknownbr0", "", db.NetworkTypeBridge, nil) + err := state.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + _, err := tx.CreateNetwork(ctx, api.ProjectDefaultName, "unknownbr0", "", db.NetworkTypeBridge, nil) + + return err + }) suite.Req.Nil(err) // Create the container diff --git a/cmd/incusd/network_allocations.go b/cmd/incusd/network_allocations.go index 78110f7513a..a3cb15c8b19 100644 --- a/cmd/incusd/network_allocations.go +++ b/cmd/incusd/network_allocations.go @@ -125,7 +125,15 @@ func networkAllocationsGet(d *Daemon, r *http.Request) response.Response { // Then, get all the networks, their network forwards and their network load balancers. for _, projectName := range projectNames { - networkNames, err := d.db.Cluster.GetNetworks(projectName) + var networkNames []string + + err := d.db.Cluster.Transaction(r.Context(), func(ctx context.Context, tx *db.ClusterTx) error { + var err error + + networkNames, err = tx.GetNetworks(ctx, projectName) + + return err + }) if err != nil { return response.SmartError(fmt.Errorf("Failed loading networks: %w", err)) } diff --git a/cmd/incusd/networks.go b/cmd/incusd/networks.go index 169e0f2bf9f..b2c3e20ff5b 100644 --- a/cmd/incusd/networks.go +++ b/cmd/incusd/networks.go @@ -177,8 +177,14 @@ func networksGet(d *Daemon, r *http.Request) response.Response { recursion := localUtil.IsRecursionRequest(r) - // Get list of managed networks (that may or may not have network interfaces on the host). - networkNames, err := s.DB.Cluster.GetNetworks(projectName) + var networkNames []string + + err = s.DB.Cluster.Transaction(r.Context(), func(ctx context.Context, tx *db.ClusterTx) error { + // Get list of managed networks (that may or may not have network interfaces on the host). + networkNames, err = tx.GetNetworks(ctx, projectName) + + return err + }) if err != nil { return response.InternalError(err) } @@ -335,7 +341,13 @@ func networksPost(d *Daemon, r *http.Request) response.Response { return response.InternalError(fmt.Errorf("Invalid project limits.network value: %w", err)) } - networks, err := s.DB.Cluster.GetNetworks(projectName) + var networks []string + + err = s.DB.Cluster.Transaction(r.Context(), func(ctx context.Context, tx *db.ClusterTx) error { + networks, err = tx.GetNetworks(ctx, projectName) + + return err + }) if err != nil { return response.InternalError(fmt.Errorf("Failed loading project's networks for limits check: %w", err)) } @@ -398,8 +410,14 @@ func networksPost(d *Daemon, r *http.Request) response.Response { return resp } - // Load existing network if exists, if not don't fail. - _, netInfo, _, err := s.DB.Cluster.GetNetworkInAnyState(projectName, req.Name) + var netInfo *api.Network + + err = s.DB.Cluster.Transaction(r.Context(), func(ctx context.Context, tx *db.ClusterTx) error { + // Load existing network if exists, if not don't fail. + _, netInfo, _, err = tx.GetNetworkInAnyState(ctx, projectName, req.Name) + + return err + }) if err != nil && !api.StatusErrorCheck(err, http.StatusNotFound) { return response.InternalError(err) } @@ -460,13 +478,21 @@ func networksPost(d *Daemon, r *http.Request) response.Response { return response.SmartError(err) } - // Create the database entry. - _, err = s.DB.Cluster.CreateNetwork(projectName, req.Name, req.Description, netType.DBType(), req.Config) + err = s.DB.Cluster.Transaction(r.Context(), func(ctx context.Context, tx *db.ClusterTx) error { + // Create the database entry. + _, err = tx.CreateNetwork(ctx, projectName, req.Name, req.Description, netType.DBType(), req.Config) + + return err + }) if err != nil { return response.SmartError(fmt.Errorf("Error inserting %q into database: %w", req.Name, err)) } - revert.Add(func() { _ = s.DB.Cluster.DeleteNetwork(projectName, req.Name) }) + revert.Add(func() { + _ = s.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + return tx.DeleteNetwork(ctx, projectName, req.Name) + }) + }) n, err := network.LoadByName(s, projectName, req.Name) if err != nil { @@ -993,8 +1019,12 @@ func networkDelete(d *Daemon, r *http.Request) response.Response { } } - // Remove the network from the database. - err = s.DB.Cluster.DeleteNetwork(n.Project(), n.Name()) + err = s.DB.Cluster.Transaction(r.Context(), func(ctx context.Context, tx *db.ClusterTx) error { + // Remove the network from the database. + err = tx.DeleteNetwork(ctx, n.Project(), n.Name()) + + return err + }) if err != nil { return response.SmartError(err) } @@ -1109,8 +1139,14 @@ func networkPost(d *Daemon, r *http.Request) response.Response { return response.BadRequest(fmt.Errorf("Network is currently in use")) } - // Check that the name isn't already in used by an existing managed network. - networks, err := s.DB.Cluster.GetNetworks(projectName) + var networks []string + + err = s.DB.Cluster.Transaction(r.Context(), func(ctx context.Context, tx *db.ClusterTx) error { + // Check that the name isn't already in used by an existing managed network. + networks, err = tx.GetNetworks(ctx, projectName) + + return err + }) if err != nil { return response.InternalError(err) } @@ -1462,21 +1498,28 @@ func networkStartup(s *state.State) error { networkPriorityLogical: make(map[network.ProjectNetwork]struct{}), } - for _, projectName := range projectNames { - networkNames, err := s.DB.Cluster.GetCreatedNetworks(projectName) - if err != nil { - return fmt.Errorf("Failed to load networks for project %q: %w", projectName, err) - } - - for _, networkName := range networkNames { - pn := network.ProjectNetwork{ - ProjectName: projectName, - NetworkName: networkName, + err = s.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + for _, projectName := range projectNames { + networkNames, err := tx.GetCreatedNetworkNamesByProject(ctx, projectName) + if err != nil { + return fmt.Errorf("Failed to load networks for project %q: %w", projectName, err) } - // Assume all networks are networkPriorityStandalone initially. - initNetworks[networkPriorityStandalone][pn] = struct{}{} + for _, networkName := range networkNames { + pn := network.ProjectNetwork{ + ProjectName: projectName, + NetworkName: networkName, + } + + // Assume all networks are networkPriorityStandalone initially. + initNetworks[networkPriorityStandalone][pn] = struct{}{} + } } + + return nil + }) + if err != nil { + return err } loadedNetworks := make(map[network.ProjectNetwork]network.Network) @@ -1656,8 +1699,14 @@ func networkShutdown(s *state.State) { } for _, projectName := range projectNames { - // Get a list of managed networks. - networks, err := s.DB.Cluster.GetNetworks(projectName) + var networks []string + + err := s.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + // Get a list of managed networks. + networks, err = tx.GetNetworks(ctx, projectName) + + return err + }) if err != nil { logger.Error("Failed shutting down networks, couldn't load networks for project", logger.Ctx{"project": projectName, "err": err}) continue @@ -1696,7 +1745,13 @@ func networkRestartOVN(s *state.State) error { // Go over all the networks in every project. for _, projectName := range projectNames { - networkNames, err := s.DB.Cluster.GetCreatedNetworks(projectName) + var networkNames []string + + err := s.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + networkNames, err = tx.GetCreatedNetworkNamesByProject(ctx, projectName) + + return err + }) if err != nil { return fmt.Errorf("Failed to load networks for project %q: %w", projectName, err) } diff --git a/cmd/incusd/patches.go b/cmd/incusd/patches.go index 27cdb86e59b..9f672fd1f4d 100644 --- a/cmd/incusd/patches.go +++ b/cmd/incusd/patches.go @@ -548,7 +548,7 @@ func patchNetworkOVNRemoveRoutes(name string, d *Daemon) error { return err } - for _, networks := range projectNetworks { + for projectName, networks := range projectNetworks { for networkID, network := range networks { if network.Type != "ovn" { continue @@ -570,7 +570,7 @@ func patchNetworkOVNRemoveRoutes(name string, d *Daemon) error { } if modified { - err = tx.UpdateNetwork(networkID, network.Description, network.Config) + err = tx.UpdateNetwork(ctx, projectName, network.Name, network.Description, network.Config) if err != nil { return fmt.Errorf("Failed removing OVN external route settings for %q (%d): %w", network.Name, networkID, err) } @@ -600,7 +600,7 @@ func patchNetworkOVNEnableNAT(name string, d *Daemon) error { return err } - for _, networks := range projectNetworks { + for projectName, networks := range projectNetworks { for networkID, network := range networks { if network.Type != "ovn" { continue @@ -620,7 +620,7 @@ func patchNetworkOVNEnableNAT(name string, d *Daemon) error { } if modified { - err = tx.UpdateNetwork(networkID, network.Description, network.Config) + err = tx.UpdateNetwork(ctx, projectName, network.Name, network.Description, network.Config) if err != nil { return fmt.Errorf("Failed saving OVN NAT settings for %q (%d): %w", network.Name, networkID, err) } @@ -711,25 +711,32 @@ func patchNetworkClearBridgeVolatileHwaddr(name string, d *Daemon) error { // Use api.ProjectDefaultName, as bridge networks don't support projects. projectName := api.ProjectDefaultName - // Get the list of networks. - networks, err := d.db.Cluster.GetNetworks(projectName) - if err != nil { - return fmt.Errorf("Failed loading networks for network_clear_bridge_volatile_hwaddr patch: %w", err) - } - - for _, networkName := range networks { - _, net, _, err := d.db.Cluster.GetNetworkInAnyState(projectName, networkName) + err := d.db.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + // Get the list of networks. + networks, err := tx.GetNetworks(ctx, projectName) if err != nil { - return fmt.Errorf("Failed loading network %q for network_clear_bridge_volatile_hwaddr patch: %w", networkName, err) + return fmt.Errorf("Failed loading networks for network_clear_bridge_volatile_hwaddr patch: %w", err) } - if net.Config["volatile.bridge.hwaddr"] != "" { - delete(net.Config, "volatile.bridge.hwaddr") - err = d.db.Cluster.UpdateNetwork(projectName, net.Name, net.Description, net.Config) + for _, networkName := range networks { + _, net, _, err := tx.GetNetworkInAnyState(ctx, projectName, networkName) if err != nil { - return fmt.Errorf("Failed updating network %q for network_clear_bridge_volatile_hwaddr patch: %w", networkName, err) + return fmt.Errorf("Failed loading network %q for network_clear_bridge_volatile_hwaddr patch: %w", networkName, err) + } + + if net.Config["volatile.bridge.hwaddr"] != "" { + delete(net.Config, "volatile.bridge.hwaddr") + err = tx.UpdateNetwork(ctx, projectName, net.Name, net.Description, net.Config) + if err != nil { + return fmt.Errorf("Failed updating network %q for network_clear_bridge_volatile_hwaddr patch: %w", networkName, err) + } } } + + return nil + }) + if err != nil { + return err } return nil diff --git a/internal/server/cluster/membership_test.go b/internal/server/cluster/membership_test.go index ae4f0fa396d..013b5c83cae 100644 --- a/internal/server/cluster/membership_test.go +++ b/internal/server/cluster/membership_test.go @@ -338,7 +338,12 @@ func TestJoin(t *testing.T) { err = cluster.Bootstrap(targetState, targetGateway, "buzz") require.NoError(t, err) - _, err = targetState.DB.Cluster.GetNetworks(api.ProjectDefaultName) + + err = targetState.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + _, err = tx.GetNetworks(ctx, api.ProjectDefaultName) + + return err + }) require.NoError(t, err) // Setup a joining node diff --git a/internal/server/db/entity.go b/internal/server/db/entity.go index d44aa0d711c..37a9199fbab 100644 --- a/internal/server/db/entity.go +++ b/internal/server/db/entity.go @@ -226,7 +226,14 @@ func (c *Cluster) GetURIFromEntity(entityType int, entityID int) (string, error) } case cluster.TypeNetwork: - networkName, projectName, err := c.GetNetworkNameAndProjectWithID(entityID) + var networkName string + var projectName string + + err = c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error { + networkName, projectName, err = tx.GetNetworkNameAndProjectWithID(ctx, entityID) + + return err + }) if err != nil { return "", fmt.Errorf("Failed to get network name and project name: %w", err) } diff --git a/internal/server/db/networks.go b/internal/server/db/networks.go index d32a651c6f8..a4758883c17 100644 --- a/internal/server/db/networks.go +++ b/internal/server/db/networks.go @@ -87,6 +87,11 @@ func (c *ClusterTx) GetCreatedNetworks(ctx context.Context) (map[string]map[int6 return c.getCreatedNetworks(ctx, "") } +// GetCreatedNetworkNamesByProject returns the names of all networks that are in state networkCreated. +func (c *ClusterTx) GetCreatedNetworkNamesByProject(ctx context.Context, project string) ([]string, error) { + return c.networks(ctx, project, "state=?", networkCreated) +} + // GetCreatedNetworksByProject returns a map of api.Network in a project associated to network ID. // Only networks that have are in state networkCreated are returned. func (c *ClusterTx) GetCreatedNetworksByProject(ctx context.Context, projectName string) (map[int64]api.Network, error) { @@ -199,7 +204,7 @@ func (c *ClusterTx) GetNetworkID(ctx context.Context, projectName string, name s } // GetNetworkNameAndProjectWithID returns the network name and project name for the given ID. -func (c *Cluster) GetNetworkNameAndProjectWithID(networkID int) (string, string, error) { +func (c *ClusterTx) GetNetworkNameAndProjectWithID(ctx context.Context, networkID int) (string, string, error) { var networkName string var projectName string @@ -208,9 +213,7 @@ func (c *Cluster) GetNetworkNameAndProjectWithID(networkID int) (string, string, inargs := []any{networkID} outargs := []any{&networkName, &projectName} - err := c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error { - return dbQueryRowScan(ctx, tx, q, inargs, outargs) - }) + err := dbQueryRowScan(ctx, c, q, inargs, outargs) if err != nil { if err == sql.ErrNoRows { return "", "", api.StatusErrorf(http.StatusNotFound, "Network not found") @@ -425,26 +428,6 @@ func (c *ClusterTx) networkNodeState(networkID int64, state NetworkState) error return nil } -// UpdateNetwork updates the network with the given ID. -func (c *ClusterTx) UpdateNetwork(id int64, description string, config map[string]string) error { - err := updateNetworkDescription(c.tx, id, description) - if err != nil { - return err - } - - err = clearNetworkConfig(c.tx, id, c.nodeID) - if err != nil { - return err - } - - err = networkConfigAdd(c.tx, id, c.nodeID, config) - if err != nil { - return err - } - - return nil -} - // NetworkNodes returns the nodes keyed by node ID that the given network is defined on. func (c *ClusterTx) NetworkNodes(ctx context.Context, networkID int64) (map[int64]NetworkNode, error) { nodes := []NetworkNode{} @@ -496,17 +479,12 @@ func (c *ClusterTx) GetNetworkURIs(ctx context.Context, projectID int, project s } // GetNetworks returns the names of existing networks. -func (c *Cluster) GetNetworks(project string) ([]string, error) { - return c.networks(project, "") -} - -// GetCreatedNetworks returns the names of all networks that are in state networkCreated. -func (c *Cluster) GetCreatedNetworks(project string) ([]string, error) { - return c.networks(project, "state=?", networkCreated) +func (c *ClusterTx) GetNetworks(ctx context.Context, project string) ([]string, error) { + return c.networks(ctx, project, "") } // Get all networks matching the given WHERE filter (if given). -func (c *Cluster) networks(project string, where string, args ...any) ([]string, error) { +func (c *ClusterTx) networks(ctx context.Context, project string, where string, args ...any) ([]string, error) { q := "SELECT name FROM networks WHERE project_id = (SELECT id FROM projects WHERE name = ?)" inargs := []any{project} @@ -518,13 +496,7 @@ func (c *Cluster) networks(project string, where string, args ...any) ([]string, var name string outfmt := []any{name} - var result [][]any - - err := c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error { - var err error - result, err = queryScan(ctx, tx, q, inargs, outfmt) - return err - }) + result, err := queryScan(ctx, c, q, inargs, outfmt) if err != nil { return []string{}, err } @@ -568,34 +540,20 @@ type NetworkNode struct { // GetNetworkInAnyState returns the network with the given name. The network can be in any state. // Returns network ID, network info, and network cluster member info. -func (c *Cluster) GetNetworkInAnyState(projectName string, networkName string) (int64, *api.Network, map[int64]NetworkNode, error) { - return c.getNetworkByProjectAndName(projectName, networkName, -1) +func (c *ClusterTx) GetNetworkInAnyState(ctx context.Context, projectName string, networkName string) (int64, *api.Network, map[int64]NetworkNode, error) { + return c.getNetworkByProjectAndName(ctx, projectName, networkName, -1) } // getNetworkByProjectAndName returns the network with the given project, name and state. // If stateFilter is -1, then a network can be in any state. // Returns network ID, network info, and network cluster member info. -func (c *Cluster) getNetworkByProjectAndName(projectName string, networkName string, stateFilter NetworkState) (int64, *api.Network, map[int64]NetworkNode, error) { - var err error - var networkID int64 - var networkState NetworkState - var networkType NetworkType - var network *api.Network - var nodes map[int64]NetworkNode - - err = c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error { - networkID, networkState, networkType, network, err = c.getPartialNetworkByProjectAndName(ctx, tx, projectName, networkName, stateFilter) - if err != nil { - return err - } - - nodes, err = c.networkPopulatePeerInfo(ctx, tx, networkID, network, networkState, networkType) - if err != nil { - return err - } +func (c *ClusterTx) getNetworkByProjectAndName(ctx context.Context, projectName string, networkName string, stateFilter NetworkState) (int64, *api.Network, map[int64]NetworkNode, error) { + networkID, networkState, networkType, network, err := c.getPartialNetworkByProjectAndName(ctx, c, projectName, networkName, stateFilter) + if err != nil { + return -1, nil, nil, err + } - return nil - }) + nodes, err := c.networkPopulatePeerInfo(ctx, c, networkID, network, networkState, networkType) if err != nil { return -1, nil, nil, err } @@ -606,7 +564,7 @@ func (c *Cluster) getNetworkByProjectAndName(projectName string, networkName str // getPartialNetworkByProjectAndName gets the network with the given project, name and state. // If stateFilter is -1, then a network can be in any state. // Returns network ID, network state, network type, and partially populated network info. -func (c *Cluster) getPartialNetworkByProjectAndName(ctx context.Context, tx *ClusterTx, projectName string, networkName string, stateFilter NetworkState) (int64, NetworkState, NetworkType, *api.Network, error) { +func (c *ClusterTx) getPartialNetworkByProjectAndName(ctx context.Context, tx *ClusterTx, projectName string, networkName string, stateFilter NetworkState) (int64, NetworkState, NetworkType, *api.Network, error) { var err error var networkID int64 = int64(-1) var network api.Network @@ -632,7 +590,7 @@ func (c *Cluster) getPartialNetworkByProjectAndName(ctx context.Context, tx *Clu q.WriteString(" LIMIT 1") - err = tx.tx.QueryRowContext(ctx, q.String(), args...).Scan(&networkID, &network.Name, &network.Description, &networkState, &networkType) + err = c.tx.QueryRowContext(ctx, q.String(), args...).Scan(&networkID, &network.Name, &network.Description, &networkState, &networkType) if err != nil { if errors.Is(err, sql.ErrNoRows) { return -1, -1, -1, nil, api.StatusErrorf(http.StatusNotFound, "Network not found") @@ -646,7 +604,7 @@ func (c *Cluster) getPartialNetworkByProjectAndName(ctx context.Context, tx *Clu // networkPopulatePeerInfo takes a pointer to partially populated network info struct and enriches it. // Returns the network cluster member info. -func (c *Cluster) networkPopulatePeerInfo(ctx context.Context, tx *ClusterTx, networkID int64, network *api.Network, networkState NetworkState, networkType NetworkType) (map[int64]NetworkNode, error) { +func (c *ClusterTx) networkPopulatePeerInfo(ctx context.Context, tx *ClusterTx, networkID int64, network *api.Network, networkState NetworkState, networkType NetworkType) (map[int64]NetworkNode, error) { var err error // Populate Status and Type fields by converting from DB values. @@ -704,7 +662,7 @@ func networkFillType(network *api.Network, netType NetworkType) { } // GetNetworkWithInterface returns the network associated with the interface with the given name. -func (c *Cluster) GetNetworkWithInterface(devName string) (int64, *api.Network, error) { +func (c *ClusterTx) GetNetworkWithInterface(ctx context.Context, devName string) (int64, *api.Network, error) { id := int64(-1) name := "" value := "" @@ -713,13 +671,7 @@ func (c *Cluster) GetNetworkWithInterface(devName string) (int64, *api.Network, arg1 := []any{c.nodeID} arg2 := []any{id, name, value} - var result [][]any - - err := c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error { - var err error - result, err = queryScan(ctx, tx, q, arg1, arg2) - return err - }) + result, err := queryScan(ctx, c, q, arg1, arg2) if err != nil { return -1, nil, err } @@ -745,9 +697,7 @@ func (c *Cluster) GetNetworkWithInterface(devName string) (int64, *api.Network, Type: "bridge", } - err = c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error { - return c.getNetworkConfig(ctx, tx, id, &network) - }) + err = c.getNetworkConfig(ctx, c, id, &network) if err != nil { return -1, nil, err } @@ -756,7 +706,7 @@ func (c *Cluster) GetNetworkWithInterface(devName string) (int64, *api.Network, } // getNetworkConfig populates the config map of the Network with the given ID. -func (c *Cluster) getNetworkConfig(ctx context.Context, tx *ClusterTx, networkID int64, network *api.Network) error { +func (c *ClusterTx) getNetworkConfig(ctx context.Context, tx *ClusterTx, networkID int64, network *api.Network) error { q := ` SELECT key, value FROM networks_config @@ -766,7 +716,7 @@ func (c *Cluster) getNetworkConfig(ctx context.Context, tx *ClusterTx, networkID network.Config = map[string]string{} - return query.Scan(ctx, tx.Tx(), q, func(scan func(dest ...any) error) error { + return query.Scan(ctx, c.tx, q, func(scan func(dest ...any) error) error { var key, value string err := scan(&key, &value) @@ -786,60 +736,58 @@ func (c *Cluster) getNetworkConfig(ctx context.Context, tx *ClusterTx, networkID } // CreateNetwork creates a new network. -func (c *Cluster) CreateNetwork(projectName string, name string, description string, netType NetworkType, config map[string]string) (int64, error) { - var id int64 - err := c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error { - // Insert a new network record with state networkCreated. - result, err := tx.tx.Exec("INSERT INTO networks (project_id, name, description, state, type) VALUES ((SELECT id FROM projects WHERE name = ?), ?, ?, ?, ?)", - projectName, name, description, networkCreated, netType) - if err != nil { - return err - } - - id, err := result.LastInsertId() - if err != nil { - return err - } +func (c *ClusterTx) CreateNetwork(ctx context.Context, projectName string, name string, description string, netType NetworkType, config map[string]string) (int64, error) { + // Insert a new network record with state networkCreated. + result, err := c.tx.ExecContext(ctx, "INSERT INTO networks (project_id, name, description, state, type) VALUES ((SELECT id FROM projects WHERE name = ?), ?, ?, ?, ?)", + projectName, name, description, networkCreated, netType) + if err != nil { + return -1, err + } - // Insert a node-specific entry pointing to ourselves with state networkPending. - columns := []string{"network_id", "node_id", "state"} - values := []any{id, c.nodeID, networkPending} - _, err = query.UpsertObject(tx.tx, "networks_nodes", columns, values) - if err != nil { - return err - } + id, err := result.LastInsertId() + if err != nil { + return -1, err + } - err = networkConfigAdd(tx.tx, id, c.nodeID, config) - if err != nil { - return err - } + // Insert a node-specific entry pointing to ourselves with state networkPending. + columns := []string{"network_id", "node_id", "state"} + values := []any{id, c.nodeID, networkPending} + _, err = query.UpsertObject(c.tx, "networks_nodes", columns, values) + if err != nil { + return -1, err + } - return nil - }) + err = networkConfigAdd(c.tx, id, c.nodeID, config) if err != nil { - id = -1 + return -1, err } - return id, err + return id, nil } // UpdateNetwork updates the network with the given name. -func (c *Cluster) UpdateNetwork(project string, name, description string, config map[string]string) error { - id, _, _, err := c.GetNetworkInAnyState(project, name) +func (c *ClusterTx) UpdateNetwork(ctx context.Context, project string, name, description string, config map[string]string) error { + id, _, _, err := c.GetNetworkInAnyState(ctx, project, name) if err != nil { return err } - err = c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error { - err = tx.UpdateNetwork(id, description, config) - if err != nil { - return err - } + err = updateNetworkDescription(c.tx, id, description) + if err != nil { + return err + } - return nil - }) + err = clearNetworkConfig(c.tx, id, c.nodeID) + if err != nil { + return err + } - return err + err = networkConfigAdd(c.tx, id, c.nodeID, config) + if err != nil { + return err + } + + return nil } // Update the description of the network with the given ID. @@ -892,29 +840,25 @@ func clearNetworkConfig(tx *sql.Tx, networkID, nodeID int64) error { } // DeleteNetwork deletes the network with the given name. -func (c *Cluster) DeleteNetwork(project string, name string) error { - id, _, _, err := c.GetNetworkInAnyState(project, name) +func (c *ClusterTx) DeleteNetwork(ctx context.Context, project string, name string) error { + id, _, _, err := c.GetNetworkInAnyState(ctx, project, name) if err != nil { return err } - return c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error { - _, err := tx.tx.ExecContext(ctx, "DELETE FROM networks WHERE id=?", id) - return err - }) + _, err = c.tx.ExecContext(ctx, "DELETE FROM networks WHERE id=?", id) + + return err } // RenameNetwork renames a network. -func (c *Cluster) RenameNetwork(project string, oldName string, newName string) error { - id, _, _, err := c.GetNetworkInAnyState(project, oldName) +func (c *ClusterTx) RenameNetwork(ctx context.Context, project string, oldName string, newName string) error { + id, _, _, err := c.GetNetworkInAnyState(ctx, project, oldName) if err != nil { return err } - err = c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error { - _, err = tx.tx.Exec("UPDATE networks SET name=? WHERE id=?", newName, id) - return err - }) + _, err = c.tx.ExecContext(ctx, "UPDATE networks SET name=? WHERE id=?", newName, id) return err } diff --git a/internal/server/db/networks_test.go b/internal/server/db/networks_test.go index 7c3fa14ba6b..d0abea13708 100644 --- a/internal/server/db/networks_test.go +++ b/internal/server/db/networks_test.go @@ -19,9 +19,13 @@ func TestGetNetworksLocalConfigs(t *testing.T) { cluster, cleanup := db.NewTestCluster(t) defer cleanup() - _, err := cluster.CreateNetwork(api.ProjectDefaultName, "incusbr0", "", db.NetworkTypeBridge, map[string]string{ - "dns.mode": "none", - "bridge.external_interfaces": "vlan0", + err := cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + _, err := tx.CreateNetwork(ctx, api.ProjectDefaultName, "incusbr0", "", db.NetworkTypeBridge, map[string]string{ + "dns.mode": "none", + "bridge.external_interfaces": "vlan0", + }) + + return err }) require.NoError(t, err) diff --git a/internal/server/device/nic_ovn.go b/internal/server/device/nic_ovn.go index 8dcc46c9da3..95c40fcbd22 100644 --- a/internal/server/device/nic_ovn.go +++ b/internal/server/device/nic_ovn.go @@ -2,6 +2,7 @@ package device import ( "bytes" + "context" "fmt" "net" "net/http" @@ -392,7 +393,14 @@ func (d *nicOVN) Start() (*deviceConfig.RunConfig, error) { // Load uplink network config. uplinkNetworkName := d.network.Config()["network"] - _, uplink, _, err := d.state.DB.Cluster.GetNetworkInAnyState(api.ProjectDefaultName, uplinkNetworkName) + + var uplink *api.Network + + err = d.state.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + _, uplink, _, err = tx.GetNetworkInAnyState(ctx, api.ProjectDefaultName, uplinkNetworkName) + + return err + }) if err != nil { return nil, fmt.Errorf("Failed to load uplink network %q: %w", uplinkNetworkName, err) } @@ -738,7 +746,16 @@ func (d *nicOVN) Update(oldDevices deviceConfig.Devices, isRunning bool) error { if isRunning { // Load uplink network config. uplinkNetworkName := d.network.Config()["network"] - _, uplink, _, err := d.state.DB.Cluster.GetNetworkInAnyState(api.ProjectDefaultName, uplinkNetworkName) + + var uplink *api.Network + + err := d.state.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + var err error + + _, uplink, _, err = tx.GetNetworkInAnyState(ctx, api.ProjectDefaultName, uplinkNetworkName) + + return err + }) if err != nil { return fmt.Errorf("Failed to load uplink network %q: %w", uplinkNetworkName, err) } diff --git a/internal/server/device/nictype/nictype.go b/internal/server/device/nictype/nictype.go index a678c5d50cd..206794247e1 100644 --- a/internal/server/device/nictype/nictype.go +++ b/internal/server/device/nictype/nictype.go @@ -3,11 +3,14 @@ package nictype import ( + "context" "fmt" + "github.com/lxc/incus/internal/server/db" deviceConfig "github.com/lxc/incus/internal/server/device/config" "github.com/lxc/incus/internal/server/project" "github.com/lxc/incus/internal/server/state" + "github.com/lxc/incus/shared/api" ) // NICType resolves the NIC Type for the supplied NIC device config. @@ -24,7 +27,13 @@ func NICType(s *state.State, deviceProjectName string, d deviceConfig.Device) (s return "", fmt.Errorf("Failed to translate device project %q into network project: %w", deviceProjectName, err) } - _, netInfo, _, err := s.DB.Cluster.GetNetworkInAnyState(networkProjectName, d["network"]) + var netInfo *api.Network + + err = s.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + _, netInfo, _, err = tx.GetNetworkInAnyState(ctx, networkProjectName, d["network"]) + + return err + }) if err != nil { return "", fmt.Errorf("Failed to load network %q for project %q: %w", d["network"], networkProjectName, err) } diff --git a/internal/server/network/acl/acl_load.go b/internal/server/network/acl/acl_load.go index 9de3b72e154..fa29f50e4fd 100644 --- a/internal/server/network/acl/acl_load.go +++ b/internal/server/network/acl/acl_load.go @@ -105,33 +105,40 @@ func UsedBy(s *state.State, aclProjectName string, usageFunc func(matchedACLName return nil } - // Find networks using the ACLs. Cheapest to do. - networkNames, err := s.DB.Cluster.GetCreatedNetworks(aclProjectName) - if err != nil && !response.IsNotFoundError(err) { - return fmt.Errorf("Failed loading networks for project %q: %w", aclProjectName, err) - } - - for _, networkName := range networkNames { - _, network, _, err := s.DB.Cluster.GetNetworkInAnyState(aclProjectName, networkName) - if err != nil { - return fmt.Errorf("Failed to get network config for %q: %w", networkName, err) + err := s.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + // Find networks using the ACLs. Cheapest to do. + networkNames, err := tx.GetCreatedNetworkNamesByProject(ctx, aclProjectName) + if err != nil && !response.IsNotFoundError(err) { + return fmt.Errorf("Failed loading networks for project %q: %w", aclProjectName, err) } - netACLNames := util.SplitNTrimSpace(network.Config["security.acls"], ",", -1, true) - matchedACLNames := []string{} - for _, netACLName := range netACLNames { - if util.ValueInSlice(netACLName, matchACLNames) { - matchedACLNames = append(matchedACLNames, netACLName) + for _, networkName := range networkNames { + _, network, _, err := tx.GetNetworkInAnyState(ctx, aclProjectName, networkName) + if err != nil { + return fmt.Errorf("Failed to get network config for %q: %w", networkName, err) } - } - if len(matchedACLNames) > 0 { - // Call usageFunc with a list of matched ACLs and info about the network. - err := usageFunc(matchedACLNames, network, "", nil) - if err != nil { - return err + netACLNames := util.SplitNTrimSpace(network.Config["security.acls"], ",", -1, true) + matchedACLNames := []string{} + for _, netACLName := range netACLNames { + if util.ValueInSlice(netACLName, matchACLNames) { + matchedACLNames = append(matchedACLNames, netACLName) + } + } + + if len(matchedACLNames) > 0 { + // Call usageFunc with a list of matched ACLs and info about the network. + err := usageFunc(matchedACLNames, network, "", nil) + if err != nil { + return err + } } } + + return nil + }) + if err != nil { + return err } // Look for profiles. Next cheapest to do. @@ -305,7 +312,16 @@ func NetworkUsage(s *state.State, aclProjectName string, aclNames []string, aclN err := UsedBy(s, aclProjectName, func(matchedACLNames []string, usageType any, _ string, nicConfig map[string]string) error { switch u := usageType.(type) { case db.InstanceArgs, cluster.Profile: - networkID, network, _, err := s.DB.Cluster.GetNetworkInAnyState(aclProjectName, nicConfig["network"]) + var networkID int64 + var network *api.Network + + err := s.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + var err error + + networkID, network, _, err = tx.GetNetworkInAnyState(ctx, aclProjectName, nicConfig["network"]) + + return err + }) if err != nil { return fmt.Errorf("Failed to load network %q: %w", nicConfig["network"], err) } @@ -326,7 +342,16 @@ func NetworkUsage(s *state.State, aclProjectName string, aclNames []string, aclN if util.ValueInSlice(u.Type, supportedNetTypes) { _, found := aclNets[u.Name] if !found { - networkID, network, _, err := s.DB.Cluster.GetNetworkInAnyState(aclProjectName, u.Name) + var networkID int64 + var network *api.Network + + err := s.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + var err error + + networkID, network, _, err = tx.GetNetworkInAnyState(ctx, aclProjectName, u.Name) + + return err + }) if err != nil { return fmt.Errorf("Failed to load network %q: %w", u.Name, err) } diff --git a/internal/server/network/acl/acl_ovn.go b/internal/server/network/acl/acl_ovn.go index cb170311a2a..0e0173ca97d 100644 --- a/internal/server/network/acl/acl_ovn.go +++ b/internal/server/network/acl/acl_ovn.go @@ -847,7 +847,14 @@ func OVNPortGroupDeleteIfUnused(s *state.State, l logger.Logger, client *openvsw return nil } - netID, network, _, err := s.DB.Cluster.GetNetworkInAnyState(aclProjectName, nicConfig["network"]) + var netID int64 + var network *api.Network + + err = s.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + netID, network, _, err = tx.GetNetworkInAnyState(ctx, aclProjectName, nicConfig["network"]) + + return err + }) if err != nil { return fmt.Errorf("Failed to load network %q: %w", nicConfig["network"], err) } @@ -876,7 +883,13 @@ func OVNPortGroupDeleteIfUnused(s *state.State, l logger.Logger, client *openvsw } if u.Type == "ovn" { - netID, _, _, err := s.DB.Cluster.GetNetworkInAnyState(aclProjectName, u.Name) + var netID int64 + + err = s.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + netID, _, _, err = tx.GetNetworkInAnyState(ctx, aclProjectName, u.Name) + + return err + }) if err != nil { return fmt.Errorf("Failed to load network %q: %w", nicConfig["network"], err) } @@ -903,7 +916,14 @@ func OVNPortGroupDeleteIfUnused(s *state.State, l logger.Logger, client *openvsw return nil } - netID, network, _, err := s.DB.Cluster.GetNetworkInAnyState(aclProjectName, nicConfig["network"]) + var netID int64 + var network *api.Network + + err = s.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + netID, network, _, err = tx.GetNetworkInAnyState(ctx, aclProjectName, nicConfig["network"]) + + return err + }) if err != nil { return fmt.Errorf("Failed to load network %q: %w", nicConfig["network"], err) } diff --git a/internal/server/network/driver_common.go b/internal/server/network/driver_common.go index 464ade349a5..5d9ce7d0c47 100644 --- a/internal/server/network/driver_common.go +++ b/internal/server/network/driver_common.go @@ -391,8 +391,10 @@ func (n *common) update(applyNetwork api.NetworkPut, targetNode string, clientTy } } - // Update the database. - err := n.state.DB.Cluster.UpdateNetwork(n.project, n.name, applyNetwork.Description, applyNetwork.Config) + err := n.state.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + // Update the database. + return tx.UpdateNetwork(ctx, n.project, n.name, applyNetwork.Description, applyNetwork.Config) + }) if err != nil { return err } @@ -464,8 +466,10 @@ func (n *common) rename(newName string) error { } } - // Rename the database entry. - err := n.state.DB.Cluster.RenameNetwork(n.project, n.name, newName) + err := n.state.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + // Rename the database entry. + return tx.RenameNetwork(ctx, n.project, n.name, newName) + }) if err != nil { return err } @@ -544,8 +548,14 @@ func (n *common) notifyDependentNetworks(changedKeys []string) { } for _, projectName := range projectNames { - // Get a list of managed networks in project. - depNets, err := n.state.DB.Cluster.GetCreatedNetworks(projectName) + var depNets []string + + err = n.state.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + // Get a list of managed networks in project. + depNets, err = tx.GetCreatedNetworkNamesByProject(ctx, projectName) + + return err + }) if err != nil { n.logger.Error("Failed to load networks in project", logger.Ctx{"project": projectName, "err": err}) continue // Continue to next project. diff --git a/internal/server/network/driver_ovn.go b/internal/server/network/driver_ovn.go index 851356ee38a..cf67793a4b8 100644 --- a/internal/server/network/driver_ovn.go +++ b/internal/server/network/driver_ovn.go @@ -385,8 +385,14 @@ func (n *ovn) Validate(config map[string]string) error { return err } - // Get uplink routes. - _, uplink, _, err := n.state.DB.Cluster.GetNetworkInAnyState(api.ProjectDefaultName, uplinkNetworkName) + var uplink *api.Network + + err = n.state.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + // Get uplink routes. + _, uplink, _, err = tx.GetNetworkInAnyState(ctx, api.ProjectDefaultName, uplinkNetworkName) + + return err + }) if err != nil { return fmt.Errorf("Failed to load uplink network %q: %w", uplinkNetworkName, err) } @@ -1028,7 +1034,7 @@ func (n *ovn) allocateUplinkPortIPs(uplinkNet Network, routerMAC net.HardwareAdd n.config[ovnVolatileUplinkIPv6] = routerExtPortIPv6.String() } - err = tx.UpdateNetwork(n.id, n.description, n.config) + err = tx.UpdateNetwork(ctx, n.project, n.name, n.description, n.config) if err != nil { return fmt.Errorf("Failed saving allocated uplink network IPs: %w", err) } @@ -1918,7 +1924,7 @@ func (n *ovn) setup(update bool) error { } err := n.state.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { - err = tx.UpdateNetwork(n.id, n.description, n.config) + err = tx.UpdateNetwork(ctx, n.project, n.name, n.description, n.config) if err != nil { return fmt.Errorf("Failed saving updated network config: %w", err) } @@ -3248,11 +3254,17 @@ func (n *ovn) instanceDevicePortRoutesParse(deviceConfig map[string]string) ([]* // InstanceDevicePortValidateExternalRoutes validates the external routes for an OVN instance port. func (n *ovn) InstanceDevicePortValidateExternalRoutes(deviceInstance instance.Instance, deviceName string, portExternalRoutes []*net.IPNet) error { - var err error var p *api.Project + var uplink *api.Network + + err := n.state.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + var err error + + // Get uplink routes. + _, uplink, _, err = tx.GetNetworkInAnyState(ctx, api.ProjectDefaultName, n.config["network"]) - // Get uplink routes. - _, uplink, _, err := n.state.DB.Cluster.GetNetworkInAnyState(api.ProjectDefaultName, n.config["network"]) + return err + }) if err != nil { return fmt.Errorf("Failed to load uplink network %q: %w", n.config["network"], err) } @@ -4032,8 +4044,14 @@ func (n *ovn) InstanceDevicePortStop(ovsExternalOVNPort openvswitch.OVNSwitchPor return fmt.Errorf("Failed parsing NIC device routes: %w", err) } - // Load uplink network config. - _, uplink, _, err := n.state.DB.Cluster.GetNetworkInAnyState(api.ProjectDefaultName, n.config["network"]) + var uplink *api.Network + + err = n.state.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + // Load uplink network config. + _, uplink, _, err = tx.GetNetworkInAnyState(ctx, api.ProjectDefaultName, n.config["network"]) + + return err + }) if err != nil { return fmt.Errorf("Failed to load uplink network %q: %w", n.config["network"], err) } @@ -4559,8 +4577,14 @@ func (n *ovn) ForwardCreate(forward api.NetworkForwardsPost, clientType request. return fmt.Errorf("Failed to load network restrictions from project %q: %w", n.project, err) } - // Get uplink routes. - _, uplink, _, err := n.state.DB.Cluster.GetNetworkInAnyState(api.ProjectDefaultName, n.config["network"]) + var uplink *api.Network + + err = n.state.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + // Get uplink routes. + _, uplink, _, err = tx.GetNetworkInAnyState(ctx, api.ProjectDefaultName, n.config["network"]) + + return err + }) if err != nil { return fmt.Errorf("Failed to load uplink network %q: %w", n.config["network"], err) } @@ -4909,8 +4933,14 @@ func (n *ovn) LoadBalancerCreate(loadBalancer api.NetworkLoadBalancersPost, clie return fmt.Errorf("Failed to load network restrictions from project %q: %w", n.project, err) } - // Get uplink routes. - _, uplink, _, err := n.state.DB.Cluster.GetNetworkInAnyState(api.ProjectDefaultName, n.config["network"]) + var uplink *api.Network + + err = n.state.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + // Get uplink routes. + _, uplink, _, err = tx.GetNetworkInAnyState(ctx, api.ProjectDefaultName, n.config["network"]) + + return err + }) if err != nil { return fmt.Errorf("Failed to load uplink network %q: %w", n.config["network"], err) } diff --git a/internal/server/network/driver_physical.go b/internal/server/network/driver_physical.go index 702c152106b..336f1b0c9e6 100644 --- a/internal/server/network/driver_physical.go +++ b/internal/server/network/driver_physical.go @@ -208,7 +208,7 @@ func (n *physical) setup(oldConfig map[string]string) error { if util.IsFalseOrEmpty(n.config["volatile.last_state.created"]) { n.config["volatile.last_state.created"] = fmt.Sprintf("%t", created) err = n.state.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { - return tx.UpdateNetwork(n.id, n.description, n.config) + return tx.UpdateNetwork(ctx, n.project, n.name, n.description, n.config) }) if err != nil { return fmt.Errorf("Failed saving volatile config: %w", err) @@ -258,7 +258,7 @@ func (n *physical) Stop() error { // Remove last state config. delete(n.config, "volatile.last_state.created") err = n.state.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { - return tx.UpdateNetwork(n.id, n.description, n.config) + return tx.UpdateNetwork(ctx, n.project, n.name, n.description, n.config) }) if err != nil { return fmt.Errorf("Failed removing volatile config: %w", err) diff --git a/internal/server/network/network_load.go b/internal/server/network/network_load.go index dbaedb25874..7363a1af397 100644 --- a/internal/server/network/network_load.go +++ b/internal/server/network/network_load.go @@ -1,9 +1,11 @@ package network import ( + "context" "fmt" "sync" + "github.com/lxc/incus/internal/server/db" "github.com/lxc/incus/internal/server/state" "github.com/lxc/incus/shared/api" ) @@ -40,7 +42,17 @@ func LoadByType(driverType string) (Type, error) { // LoadByName loads an instantiated network from the database by project and name. func LoadByName(s *state.State, projectName string, name string) (Network, error) { - id, netInfo, netNodes, err := s.DB.Cluster.GetNetworkInAnyState(projectName, name) + var id int64 + var netInfo *api.Network + var netNodes map[int64]db.NetworkNode + + err := s.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + var err error + + id, netInfo, netNodes, err = tx.GetNetworkInAnyState(ctx, projectName, name) + + return err + }) if err != nil { return nil, err } diff --git a/internal/server/network/network_utils.go b/internal/server/network/network_utils.go index b7ab36d2756..3a9a283e919 100644 --- a/internal/server/network/network_utils.go +++ b/internal/server/network/network_utils.go @@ -385,8 +385,12 @@ func UpdateDNSMasqStatic(s *state.State, networkName string) error { if networkName == "" { var err error - // Pass api.ProjectDefaultName here, as currently dnsmasq (bridged) networks do not support projects. - networks, err = s.DB.Cluster.GetNetworks(api.ProjectDefaultName) + err = s.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + // Pass api.ProjectDefaultName here, as currently dnsmasq (bridged) networks do not support projects. + networks, err = tx.GetNetworks(ctx, api.ProjectDefaultName) + + return err + }) if err != nil { return err } diff --git a/internal/server/network/zone/zone.go b/internal/server/network/zone/zone.go index f0caab104e3..81170479f9d 100644 --- a/internal/server/network/zone/zone.go +++ b/internal/server/network/zone/zone.go @@ -90,26 +90,41 @@ func (d *zone) networkUsesZone(netConfig map[string]string) bool { func (d *zone) usedBy(firstOnly bool) ([]string, error) { usedBy := []string{} - // Find networks using the zone. - networkNames, err := d.state.DB.Cluster.GetCreatedNetworks(d.projectName) + var networkNames []string + + err := d.state.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + var err error + + // Find networks using the zone. + networkNames, err = tx.GetCreatedNetworkNamesByProject(ctx, d.projectName) + + return err + }) if err != nil && !response.IsNotFoundError(err) { return nil, fmt.Errorf("Failed loading networks for project %q: %w", d.projectName, err) } - for _, networkName := range networkNames { - _, network, _, err := d.state.DB.Cluster.GetNetworkInAnyState(d.projectName, networkName) - if err != nil { - return nil, fmt.Errorf("Failed to get network config for %q: %w", networkName, err) - } + err = d.state.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error { + for _, networkName := range networkNames { + _, network, _, err := tx.GetNetworkInAnyState(ctx, d.projectName, networkName) + if err != nil { + return fmt.Errorf("Failed to get network config for %q: %w", networkName, err) + } - // Check if the network is using this zone. - if d.networkUsesZone(network.Config) { - u := api.NewURL().Path(version.APIVersion, "networks", networkName) - usedBy = append(usedBy, u.String()) - if firstOnly { - return usedBy, nil + // Check if the network is using this zone. + if d.networkUsesZone(network.Config) { + u := api.NewURL().Path(version.APIVersion, "networks", networkName) + usedBy = append(usedBy, u.String()) + if firstOnly { + return nil + } } } + + return nil + }) + if err != nil { + return nil, err } return usedBy, nil