Skip to content

Commit

Permalink
Move db network functions to ClusterTx
Browse files Browse the repository at this point in the history
Signed-off-by: Thomas Hipp <[email protected]>
  • Loading branch information
monstermunchkin committed Dec 19, 2023
1 parent becec1a commit 5fa4cc7
Show file tree
Hide file tree
Showing 20 changed files with 507 additions and 307 deletions.
132 changes: 73 additions & 59 deletions cmd/incusd/api_cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -614,39 +614,43 @@ func clusterPutJoin(d *Daemon, r *http.Request, req api.ClusterPut) response.Res

// Get a list of projects for networks.
var projects []dbCluster.Project
networks := []api.InitNetworksProjectPost{}

err = s.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error {
projects, err = dbCluster.GetProjects(ctx, tx.Tx())
return err
})
if err != nil {
return fmt.Errorf("Failed to load projects for networks: %w", err)
}

networks := []api.InitNetworksProjectPost{}
for _, p := range projects {
networkNames, err := s.DB.Cluster.GetNetworks(p.Name)
if err != nil && !response.IsNotFoundError(err) {
return err
if err != nil {
return fmt.Errorf("Failed to load projects for networks: %w", err)
}

for _, name := range networkNames {
_, network, _, err := s.DB.Cluster.GetNetworkInAnyState(p.Name, name)
if err != nil {
for _, p := range projects {
networkNames, err := tx.GetNetworks(ctx, p.Name)
if err != nil && !response.IsNotFoundError(err) {
return err
}

internalNetwork := api.InitNetworksProjectPost{
NetworksPost: api.NetworksPost{
NetworkPut: network.NetworkPut,
Name: network.Name,
Type: network.Type,
},
Project: p.Name,
}
for _, name := range networkNames {
_, network, _, err := tx.GetNetworkInAnyState(ctx, p.Name, name)
if err != nil {
return err
}

internalNetwork := api.InitNetworksProjectPost{
NetworksPost: api.NetworksPost{
NetworkPut: network.NetworkPut,
Name: network.Name,
Type: network.Type,
},
Project: p.Name,
}

networks = append(networks, internalNetwork)
networks = append(networks, internalNetwork)
}
}

return nil
})
if err != nil {
return err
}

// Now request for this node to be added to the list of cluster nodes.
Expand Down Expand Up @@ -2061,7 +2065,13 @@ func clusterNodeDelete(d *Daemon, r *http.Request) response.Response {
}

for _, networkProjectName := range networkProjectNames {
networks, err := s.DB.Cluster.GetNetworks(networkProjectName)
var networks []string

err := s.DB.Cluster.Transaction(r.Context(), func(ctx context.Context, tx *db.ClusterTx) error {
networks, err = tx.GetNetworks(ctx, networkProjectName)

return err
})
if err != nil {
return response.SmartError(err)
}
Expand Down Expand Up @@ -2774,52 +2784,56 @@ func clusterCheckNetworksMatch(cluster *db.Cluster, reqNetworks []api.InitNetwor
var networkProjectNames []string

err = cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error {
networkProjectNames, err = dbCluster.GetProjectNames(context.Background(), tx.Tx())
return err
})
if err != nil {
return fmt.Errorf("Failed to load projects for networks: %w", err)
}

for _, networkProjectName := range networkProjectNames {
networkNames, err := cluster.GetCreatedNetworks(networkProjectName)
if err != nil && !response.IsNotFoundError(err) {
return err
networkProjectNames, err = dbCluster.GetProjectNames(ctx, tx.Tx())
if err != nil {
return fmt.Errorf("Failed to load projects for networks: %w", err)
}

for _, networkName := range networkNames {
found := false
for _, networkProjectName := range networkProjectNames {
networkNames, err := tx.GetCreatedNetworkNamesByProject(ctx, networkProjectName)
if err != nil && !response.IsNotFoundError(err) {
return err
}

for _, reqNetwork := range reqNetworks {
if reqNetwork.Name != networkName || reqNetwork.Project != networkProjectName {
continue
}
for _, networkName := range networkNames {
found := false

found = true
for _, reqNetwork := range reqNetworks {
if reqNetwork.Name != networkName || reqNetwork.Project != networkProjectName {
continue
}

_, network, _, err := cluster.GetNetworkInAnyState(networkProjectName, networkName)
if err != nil {
return err
}
found = true

if reqNetwork.Type != network.Type {
return fmt.Errorf("Mismatching type for network %q in project %q", networkName, networkProjectName)
}
_, network, _, err := tx.GetNetworkInAnyState(ctx, networkProjectName, networkName)
if err != nil {
return err
}

// Exclude the keys which are node-specific.
exclude := db.NodeSpecificNetworkConfig
err = localUtil.CompareConfigs(network.Config, reqNetwork.Config, exclude)
if err != nil {
return fmt.Errorf("Mismatching config for network %q in project %q: %w", networkName, networkProjectName, err)
}
if reqNetwork.Type != network.Type {
return fmt.Errorf("Mismatching type for network %q in project %q", networkName, networkProjectName)
}

break
}
// Exclude the keys which are node-specific.
exclude := db.NodeSpecificNetworkConfig
err = localUtil.CompareConfigs(network.Config, reqNetwork.Config, exclude)
if err != nil {
return fmt.Errorf("Mismatching config for network %q in project %q: %w", network.Name, networkProjectName, err)
}

if !found {
return fmt.Errorf("Missing network %q in project %q", networkName, networkProjectName)
break
}

if !found {
return fmt.Errorf("Missing network %q in project %q", networkName, networkProjectName)
}
}
}

return nil
})
if err != nil {
return err
}

return nil
Expand Down
10 changes: 8 additions & 2 deletions cmd/incusd/api_project.go
Original file line number Diff line number Diff line change
Expand Up @@ -1499,8 +1499,14 @@ func projectValidateRestrictedSubnets(s *state.State, value string) error {
return fmt.Errorf("Not an IP network address %q", subnetStr)
}

// Check uplink exists and load config to compare subnets.
_, uplink, _, err := s.DB.Cluster.GetNetworkInAnyState(api.ProjectDefaultName, uplinkName)
var uplink *api.Network

err = s.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error {
// Check uplink exists and load config to compare subnets.
_, uplink, _, err = tx.GetNetworkInAnyState(ctx, api.ProjectDefaultName, uplinkName)

return err
})
if err != nil {
return fmt.Errorf("Invalid uplink network %q: %w", uplinkName, err)
}
Expand Down
12 changes: 10 additions & 2 deletions cmd/incusd/instance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,11 @@ func (suite *containerTestSuite) TestContainer_ProfilesOverwriteDefaultNic() {
Name: "testFoo",
}

_, err := suite.d.State().DB.Cluster.CreateNetwork(api.ProjectDefaultName, "unknownbr0", "", db.NetworkTypeBridge, nil)
err := suite.d.State().DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error {
_, err := tx.CreateNetwork(ctx, api.ProjectDefaultName, "unknownbr0", "", db.NetworkTypeBridge, nil)

return err
})
suite.Req.Nil(err)

c, op, _, err := instance.CreateInternal(suite.d.State(), args, true)
Expand Down Expand Up @@ -157,7 +161,11 @@ func (suite *containerTestSuite) TestContainer_LoadFromDB() {

state := suite.d.State()

_, err := state.DB.Cluster.CreateNetwork(api.ProjectDefaultName, "unknownbr0", "", db.NetworkTypeBridge, nil)
err := state.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error {
_, err := tx.CreateNetwork(ctx, api.ProjectDefaultName, "unknownbr0", "", db.NetworkTypeBridge, nil)

return err
})
suite.Req.Nil(err)

// Create the container
Expand Down
10 changes: 9 additions & 1 deletion cmd/incusd/network_allocations.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,15 @@ func networkAllocationsGet(d *Daemon, r *http.Request) response.Response {

// Then, get all the networks, their network forwards and their network load balancers.
for _, projectName := range projectNames {
networkNames, err := d.db.Cluster.GetNetworks(projectName)
var networkNames []string

err := d.db.Cluster.Transaction(r.Context(), func(ctx context.Context, tx *db.ClusterTx) error {
var err error

networkNames, err = tx.GetNetworks(ctx, projectName)

return err
})
if err != nil {
return response.SmartError(fmt.Errorf("Failed loading networks: %w", err))
}
Expand Down
Loading

0 comments on commit 5fa4cc7

Please sign in to comment.