Skip to content

Commit

Permalink
Move db network load balancer functions to ClusterTx
Browse files Browse the repository at this point in the history
Signed-off-by: Thomas Hipp <[email protected]>
  • Loading branch information
monstermunchkin committed Dec 13, 2023
1 parent 4714023 commit 98846c8
Show file tree
Hide file tree
Showing 5 changed files with 196 additions and 131 deletions.
8 changes: 7 additions & 1 deletion cmd/incusd/network_allocations.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,13 @@ func networkAllocationsGet(d *Daemon, r *http.Request) response.Response {
)
}

loadBalancers, err := d.db.Cluster.GetNetworkLoadBalancers(r.Context(), n.ID(), false)
var loadBalancers map[int64]*api.NetworkLoadBalancer

err = d.db.Cluster.Transaction(r.Context(), func(ctx context.Context, tx *db.ClusterTx) error {
loadBalancers, err = tx.GetNetworkLoadBalancers(ctx, n.ID(), false)

return err
})
if err != nil {
return response.SmartError(fmt.Errorf("Failed getting load-balancers for network %q in project %q: %w", networkName, projectName, err))
}
Expand Down
34 changes: 30 additions & 4 deletions cmd/incusd/network_load_balancers.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"context"
"encoding/json"
"fmt"
"net/http"
Expand All @@ -10,6 +11,7 @@ import (

"github.com/lxc/incus/internal/server/auth"
clusterRequest "github.com/lxc/incus/internal/server/cluster/request"
"github.com/lxc/incus/internal/server/db"
"github.com/lxc/incus/internal/server/lifecycle"
"github.com/lxc/incus/internal/server/network"
"github.com/lxc/incus/internal/server/project"
Expand Down Expand Up @@ -160,7 +162,13 @@ func networkLoadBalancersGet(d *Daemon, r *http.Request) response.Response {
memberSpecific := false // Get load balancers for all cluster members.

if localUtil.IsRecursionRequest(r) {
records, err := s.DB.Cluster.GetNetworkLoadBalancers(r.Context(), n.ID(), memberSpecific)
var records map[int64]*api.NetworkLoadBalancer

err = s.DB.Cluster.Transaction(r.Context(), func(ctx context.Context, tx *db.ClusterTx) error {
records, err = tx.GetNetworkLoadBalancers(ctx, n.ID(), memberSpecific)

return err
})
if err != nil {
return response.SmartError(fmt.Errorf("Failed loading network load balancers: %w", err))
}
Expand All @@ -173,7 +181,13 @@ func networkLoadBalancersGet(d *Daemon, r *http.Request) response.Response {
return response.SyncResponse(true, loadBalancers)
}

listenAddresses, err := s.DB.Cluster.GetNetworkLoadBalancerListenAddresses(n.ID(), memberSpecific)
var listenAddresses map[int64]string

err = s.DB.Cluster.Transaction(r.Context(), func(ctx context.Context, tx *db.ClusterTx) error {
listenAddresses, err = tx.GetNetworkLoadBalancerListenAddresses(ctx, n.ID(), memberSpecific)

return err
})
if err != nil {
return response.SmartError(fmt.Errorf("Failed loading network load balancers: %w", err))
}
Expand Down Expand Up @@ -426,7 +440,13 @@ func networkLoadBalancerGet(d *Daemon, r *http.Request) response.Response {
targetMember := request.QueryParam(r, "target")
memberSpecific := targetMember != ""

_, loadBalancer, err := s.DB.Cluster.GetNetworkLoadBalancer(r.Context(), n.ID(), memberSpecific, listenAddress)
var loadBalancer *api.NetworkLoadBalancer

err = s.DB.Cluster.Transaction(r.Context(), func(ctx context.Context, tx *db.ClusterTx) error {
_, loadBalancer, err = tx.GetNetworkLoadBalancer(ctx, n.ID(), memberSpecific, listenAddress)

return err
})
if err != nil {
return response.SmartError(err)
}
Expand Down Expand Up @@ -551,7 +571,13 @@ func networkLoadBalancerPut(d *Daemon, r *http.Request) response.Response {
memberSpecific := targetMember != ""

if r.Method == http.MethodPatch {
_, loadBalancer, err := s.DB.Cluster.GetNetworkLoadBalancer(r.Context(), n.ID(), memberSpecific, listenAddress)
var loadBalancer *api.NetworkLoadBalancer

err = s.DB.Cluster.Transaction(r.Context(), func(ctx context.Context, tx *db.ClusterTx) error {
_, loadBalancer, err = tx.GetNetworkLoadBalancer(ctx, n.ID(), memberSpecific, listenAddress)

return err
})
if err != nil {
return response.SmartError(err)
}
Expand Down
197 changes: 86 additions & 111 deletions internal/server/db/network_load_balancers.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (
// CreateNetworkLoadBalancer creates a new Network Load Balancer.
// If memberSpecific is true, then the load balancer is associated to the current member, rather than being
// associated to all members.
func (c *Cluster) CreateNetworkLoadBalancer(networkID int64, memberSpecific bool, info *api.NetworkLoadBalancersPost) (int64, error) {
func (c *ClusterTx) CreateNetworkLoadBalancer(ctx context.Context, networkID int64, memberSpecific bool, info *api.NetworkLoadBalancersPost) (int64, error) {
var err error
var loadBalancerID int64
var nodeID any
Expand All @@ -43,30 +43,23 @@ func (c *Cluster) CreateNetworkLoadBalancer(networkID int64, memberSpecific bool
}
}

err = c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error {
// Insert a new Network Load Balancer record.
result, err := tx.tx.Exec(`
// Insert a new Network Load Balancer record.
result, err := c.tx.ExecContext(ctx, `
INSERT INTO networks_load_balancers
(network_id, node_id, listen_address, description, backends, ports)
VALUES (?, ?, ?, ?, ?, ?)
`, networkID, nodeID, info.ListenAddress, info.Description, string(backendsJSON), string(portsJSON))
if err != nil {
return err
}

loadBalancerID, err = result.LastInsertId()
if err != nil {
return err
}
if err != nil {
return -1, err
}

// Save config.
err = networkLoadBalancerConfigAdd(tx.tx, loadBalancerID, info.Config)
if err != nil {
return err
}
loadBalancerID, err = result.LastInsertId()
if err != nil {
return -1, err
}

return nil
})
// Save config.
err = networkLoadBalancerConfigAdd(c.tx, loadBalancerID, info.Config)
if err != nil {
return -1, err
}
Expand Down Expand Up @@ -102,7 +95,7 @@ func networkLoadBalancerConfigAdd(tx *sql.Tx, loadBalancerID int64, config map[s
}

// UpdateNetworkLoadBalancer updates an existing Network Load Balancer.
func (c *Cluster) UpdateNetworkLoadBalancer(networkID int64, loadBalancerID int64, info *api.NetworkLoadBalancerPut) error {
func (c *ClusterTx) UpdateNetworkLoadBalancer(ctx context.Context, networkID int64, loadBalancerID int64, info *api.NetworkLoadBalancerPut) error {
var err error
var backendsJSON, portsJSON []byte

Expand All @@ -120,39 +113,32 @@ func (c *Cluster) UpdateNetworkLoadBalancer(networkID int64, loadBalancerID int6
}
}

err = c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error {
// Update existing Network Load Balancer record.
res, err := tx.tx.Exec(`
// Update existing Network Load Balancer record.
res, err := c.tx.ExecContext(ctx, `
UPDATE networks_load_balancers
SET description = ?, backends = ?, ports = ?
WHERE network_id = ? and id = ?
`, info.Description, string(backendsJSON), string(portsJSON), networkID, loadBalancerID)
if err != nil {
return err
}

rowsAffected, err := res.RowsAffected()
if err != nil {
return err
}
if err != nil {
return err
}

if rowsAffected <= 0 {
return api.StatusErrorf(http.StatusNotFound, "Network load balancer not found")
}
rowsAffected, err := res.RowsAffected()
if err != nil {
return err
}

// Save config.
_, err = tx.tx.Exec("DELETE FROM networks_load_balancers_config WHERE network_load_balancer_id=?", loadBalancerID)
if err != nil {
return err
}
if rowsAffected <= 0 {
return api.StatusErrorf(http.StatusNotFound, "Network load balancer not found")
}

err = networkLoadBalancerConfigAdd(tx.tx, loadBalancerID, info.Config)
if err != nil {
return err
}
// Save config.
_, err = c.tx.ExecContext(ctx, "DELETE FROM networks_load_balancers_config WHERE network_load_balancer_id=?", loadBalancerID)
if err != nil {
return err
}

return nil
})
err = networkLoadBalancerConfigAdd(c.tx, loadBalancerID, info.Config)
if err != nil {
return err
}
Expand All @@ -161,34 +147,32 @@ func (c *Cluster) UpdateNetworkLoadBalancer(networkID int64, loadBalancerID int6
}

// DeleteNetworkLoadBalancer deletes an existing Network Load Balancer.
func (c *Cluster) DeleteNetworkLoadBalancer(networkID int64, loadBalancerID int64) error {
return c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error {
// Delete existing Network Load Balancer record.
res, err := tx.tx.Exec(`
func (c *ClusterTx) DeleteNetworkLoadBalancer(ctx context.Context, networkID int64, loadBalancerID int64) error {
// Delete existing Network Load Balancer record.
res, err := c.tx.ExecContext(ctx, `
DELETE FROM networks_load_balancers
WHERE network_id = ? and id = ?
`, networkID, loadBalancerID)
if err != nil {
return err
}
if err != nil {
return err
}

rowsAffected, err := res.RowsAffected()
if err != nil {
return err
}
rowsAffected, err := res.RowsAffected()
if err != nil {
return err
}

if rowsAffected <= 0 {
return api.StatusErrorf(http.StatusNotFound, "Network load balancer not found")
}
if rowsAffected <= 0 {
return api.StatusErrorf(http.StatusNotFound, "Network load balancer not found")
}

return nil
})
return nil
}

// GetNetworkLoadBalancer returns the Network Load Balancer ID and info for the given network ID and listen address.
// If memberSpecific is true, then the search is restricted to load balancers that belong to this member or belong
// to all members.
func (c *Cluster) GetNetworkLoadBalancer(ctx context.Context, networkID int64, memberSpecific bool, listenAddress string) (int64, *api.NetworkLoadBalancer, error) {
func (c *ClusterTx) GetNetworkLoadBalancer(ctx context.Context, networkID int64, memberSpecific bool, listenAddress string) (int64, *api.NetworkLoadBalancer, error) {
loadBalancers, err := c.GetNetworkLoadBalancers(ctx, networkID, memberSpecific, listenAddress)
if (err == nil && len(loadBalancers) <= 0) || errors.Is(err, sql.ErrNoRows) {
return -1, nil, api.StatusErrorf(http.StatusNotFound, "Network load balancer not found")
Expand Down Expand Up @@ -239,7 +223,7 @@ func networkLoadBalancerConfig(ctx context.Context, tx *ClusterTx, loadBalancerI
// network ID keyed on Load Balancer ID.
// If memberSpecific is true, then the search is restricted to load balancers that belong to this member or belong
// to all members.
func (c *Cluster) GetNetworkLoadBalancerListenAddresses(networkID int64, memberSpecific bool) (map[int64]string, error) {
func (c *ClusterTx) GetNetworkLoadBalancerListenAddresses(ctx context.Context, networkID int64, memberSpecific bool) (map[int64]string, error) {
var q *strings.Builder = &strings.Builder{}
args := []any{networkID}

Expand All @@ -258,21 +242,19 @@ func (c *Cluster) GetNetworkLoadBalancerListenAddresses(networkID int64, memberS

loadBalancers := make(map[int64]string)

err := c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error {
return query.Scan(ctx, tx.Tx(), q.String(), func(scan func(dest ...any) error) error {
var loadBalancerID int64 = int64(-1)
var listenAddress string
err := query.Scan(ctx, c.tx, q.String(), func(scan func(dest ...any) error) error {
var loadBalancerID int64 = int64(-1)
var listenAddress string

err := scan(&loadBalancerID, &listenAddress)
if err != nil {
return err
}
err := scan(&loadBalancerID, &listenAddress)
if err != nil {
return err
}

loadBalancers[loadBalancerID] = listenAddress
loadBalancers[loadBalancerID] = listenAddress

return nil
}, args...)
})
return nil
}, args...)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -391,7 +373,7 @@ func (c *ClusterTx) GetProjectNetworkLoadBalancerListenAddressesOnMember(ctx con
// GetNetworkLoadBalancers returns map of Network Load Balancers for the given network ID keyed on Load Balancer ID.
// If memberSpecific is true, then the search is restricted to load balancers that belong to this member or belong
// to all members. Can optionally retrieve only specific network load balancers by listen address.
func (c *Cluster) GetNetworkLoadBalancers(ctx context.Context, networkID int64, memberSpecific bool, listenAddresses ...string) (map[int64]*api.NetworkLoadBalancer, error) {
func (c *ClusterTx) GetNetworkLoadBalancers(ctx context.Context, networkID int64, memberSpecific bool, listenAddresses ...string) (map[int64]*api.NetworkLoadBalancer, error) {
var q *strings.Builder = &strings.Builder{}
args := []any{networkID}

Expand Down Expand Up @@ -423,54 +405,47 @@ func (c *Cluster) GetNetworkLoadBalancers(ctx context.Context, networkID int64,
var err error
loadBalancers := make(map[int64]*api.NetworkLoadBalancer)

err = c.Transaction(ctx, func(ctx context.Context, tx *ClusterTx) error {
err = query.Scan(ctx, tx.Tx(), q.String(), func(scan func(dest ...any) error) error {
var loadBalancerID int64 = int64(-1)
var backendsJSON, portsJSON string
var loadBalancer api.NetworkLoadBalancer

err := scan(&loadBalancerID, &loadBalancer.ListenAddress, &loadBalancer.Description, &loadBalancer.Location, &backendsJSON, &portsJSON)
if err != nil {
return err
}

loadBalancer.Backends = []api.NetworkLoadBalancerBackend{}
if backendsJSON != "" {
err = json.Unmarshal([]byte(backendsJSON), &loadBalancer.Backends)
if err != nil {
return fmt.Errorf("Failed unmarshalling backends: %w", err)
}
}

loadBalancer.Ports = []api.NetworkLoadBalancerPort{}
if portsJSON != "" {
err = json.Unmarshal([]byte(portsJSON), &loadBalancer.Ports)
if err != nil {
return fmt.Errorf("Failed unmarshalling ports: %w", err)
}
}

loadBalancers[loadBalancerID] = &loadBalancer
err = query.Scan(ctx, c.tx, q.String(), func(scan func(dest ...any) error) error {
var loadBalancerID int64 = int64(-1)
var backendsJSON, portsJSON string
var loadBalancer api.NetworkLoadBalancer

return nil
}, args...)
err := scan(&loadBalancerID, &loadBalancer.ListenAddress, &loadBalancer.Description, &loadBalancer.Location, &backendsJSON, &portsJSON)
if err != nil {
return err
}

// Populate config.
for loadBalancerID := range loadBalancers {
err = networkLoadBalancerConfig(ctx, tx, loadBalancerID, loadBalancers[loadBalancerID])
loadBalancer.Backends = []api.NetworkLoadBalancerBackend{}
if backendsJSON != "" {
err = json.Unmarshal([]byte(backendsJSON), &loadBalancer.Backends)
if err != nil {
return err
return fmt.Errorf("Failed unmarshalling backends: %w", err)
}
}

loadBalancer.Ports = []api.NetworkLoadBalancerPort{}
if portsJSON != "" {
err = json.Unmarshal([]byte(portsJSON), &loadBalancer.Ports)
if err != nil {
return fmt.Errorf("Failed unmarshalling ports: %w", err)
}
}

loadBalancers[loadBalancerID] = &loadBalancer

return nil
})
}, args...)
if err != nil {
return nil, err
}

// Populate config.
for loadBalancerID := range loadBalancers {
err = networkLoadBalancerConfig(ctx, c, loadBalancerID, loadBalancers[loadBalancerID])
if err != nil {
return nil, err
}
}

return loadBalancers, nil
}
Loading

0 comments on commit 98846c8

Please sign in to comment.