Fix: Race condition when adding workers on remote nodes (#714)
- Add a lock keyed by pool name when adding workers in the occupyAvailableMachines
func. This prevents multiple instances of the gateway from adding workers to
the same pool at the same time.

Resolve BE-2045
nickpetrovic authored Nov 14, 2024
1 parent 3eec936 commit 64e47f3
Showing 5 changed files with 126 additions and 11 deletions.
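The race in question: when several gateway replicas run a pool sizer for the same pool, each can list the pool's machines and add a worker to the same machine before the others' writes are visible. The fix serializes occupyAvailableMachines behind a per-pool lock. A minimal, self-contained sketch of the pattern follows (process-local only, with illustrative names; in the commit the lock lives in Redis, so it also excludes other gateway processes, not just other goroutines):

package main

import (
	"errors"
	"fmt"
	"sync"
)

var poolLocks sync.Map // pool name -> *sync.Mutex

func occupyAvailableMachines(pool string) error {
	mu, _ := poolLocks.LoadOrStore(pool, &sync.Mutex{})
	lock := mu.(*sync.Mutex)

	// Try-lock: if another sizer already holds the lock, return instead
	// of waiting (mirrors Retries: 0 in the Redis lock options below).
	if !lock.TryLock() {
		return errors.New("pool is already being sized")
	}
	defer lock.Unlock()

	fmt.Println("adding workers to machines in pool", pool)
	return nil
}

func main() {
	var wg sync.WaitGroup
	for i := 0; i < 2; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			if err := occupyAvailableMachines("pool1"); err != nil {
				fmt.Println("skipped:", err)
			}
		}()
	}
	wg.Wait()
}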
14 changes: 3 additions & 11 deletions pkg/common/keys.go
@@ -36,6 +36,7 @@ var (
 	workerNetworkLock        string = "worker:network:%s:lock"
 	workerNetworkIpIndex     string = "worker:network:%s:ip_index"
 	workerNetworkContainerIp string = "worker:network:%s:container_ip:%s"
+	workerPoolSizerLock      string = "worker:pool_sizer:%s:lock"
 )

 var (
@@ -49,11 +50,6 @@ var (
 	taskRetryLock string = "task:%s:%s:%s:retry_lock"
 )

-var (
-	workerPoolLock  string = "workerpool:lock:%s"
-	workerPoolState string = "workerpool:state:%s"
-)
-
 var (
 	workspacePrefix string = "workspace"

@@ -242,12 +238,8 @@ func (rl *redisKeys) WorkspaceAuthorizedToken(token string) string {
 }

 // WorkerPool keys
-func (rk *redisKeys) WorkerPoolLock(poolName string) string {
-	return fmt.Sprintf(workerPoolLock, poolName)
-}
-
-func (rk *redisKeys) WorkerPoolState(poolName string) string {
-	return fmt.Sprintf(workerPoolState, poolName)
+func (rk *redisKeys) WorkerPoolSizerLock(poolName string) string {
+	return fmt.Sprintf(workerPoolSizerLock, poolName)
 }

 // Tailscale keys
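For reference, the new key is rendered per pool via fmt.Sprintf, e.g. (illustrative snippet, not repo code):

package main

import "fmt"

func main() {
	const workerPoolSizerLock = "worker:pool_sizer:%s:lock"
	// For a pool named "pool1" this yields: worker:pool_sizer:pool1:lock
	fmt.Println(fmt.Sprintf(workerPoolSizerLock, "pool1"))
}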
2 changes: 2 additions & 0 deletions pkg/repository/base.go
@@ -35,6 +35,8 @@ type WorkerRepository interface {
 	GetContainerIps(networkPrefix string) ([]string, error)
 	SetNetworkLock(networkPrefix string, ttl, retries int) error
 	RemoveNetworkLock(networkPrefix string) error
+	SetWorkerPoolSizerLock(controllerName string) error
+	RemoveWorkerPoolSizerLock(controllerName string) error
 }

 type ContainerRepository interface {
13 changes: 13 additions & 0 deletions pkg/repository/worker_redis.go
@@ -364,6 +364,19 @@ func (r *WorkerRedisRepository) UpdateWorkerCapacity(worker *types.Worker, reque
 	return nil
 }

+func (r *WorkerRedisRepository) SetWorkerPoolSizerLock(poolName string) error {
+	err := r.lock.Acquire(context.TODO(), common.RedisKeys.WorkerPoolSizerLock(poolName), common.RedisLockOptions{TtlS: 3, Retries: 0})
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (r *WorkerRedisRepository) RemoveWorkerPoolSizerLock(poolName string) error {
+	return r.lock.Release(common.RedisKeys.WorkerPoolSizerLock(poolName))
+}
+
 func (r *WorkerRedisRepository) ScheduleContainerRequest(worker *types.Worker, request *types.ContainerRequest) error {
 	// Serialize the ContainerRequest -> JSON
 	requestJSON, err := json.Marshal(request)
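The lock options encode the behavior that makes this safe: Retries: 0 turns Acquire into a try-lock (it fails immediately if another gateway instance holds the key, rather than blocking the sizer loop), and TtlS: 3 lets Redis expire the key after three seconds so a crashed holder cannot wedge the pool. The repo's lock wrapper itself isn't shown in this diff; a rough equivalent using go-redis SET NX directly would look like this (a sketch under that assumption, not the actual common.RedisLock implementation):

package main

import (
	"context"
	"errors"
	"fmt"
	"time"

	"github.com/redis/go-redis/v9"
)

func trySizerLock(ctx context.Context, rdb *redis.Client, pool string) error {
	// SET key value NX PX 3000: succeed only if the key does not exist,
	// and let Redis expire it after 3s if the holder dies (TtlS: 3).
	ok, err := rdb.SetNX(ctx, "worker:pool_sizer:"+pool+":lock", "1", 3*time.Second).Result()
	if err != nil {
		return err
	}
	if !ok {
		// Retries: 0 -> do not wait; another sizer owns the pool.
		return errors.New("pool sizer lock already held")
	}
	return nil
}

func releaseSizerLock(ctx context.Context, rdb *redis.Client, pool string) error {
	return rdb.Del(ctx, "worker:pool_sizer:"+pool+":lock").Err()
}

func main() {
	rdb := redis.NewClient(&redis.Options{Addr: "localhost:6379"})
	ctx := context.Background()
	if err := trySizerLock(ctx, rdb, "pool1"); err != nil {
		fmt.Println("skip this sizing pass:", err)
		return
	}
	defer releaseSizerLock(ctx, rdb, "pool1")
	fmt.Println("sizing pool1")
}

A production lock would also store a unique token and verify ownership before deleting, so one instance cannot release a lock that expired and was re-acquired by another.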
7 changes: 7 additions & 0 deletions pkg/scheduler/pool_sizing.go
@@ -75,7 +75,14 @@ func (s *WorkerPoolSizer) Start() {
 }

 // occupyAvailableMachines ensures that all manually provisioned machines always have workers occupying them
+// This only adds one worker per machine, so if a machine has more capacity, it will not be fully utilized unless
+// this is called multiple times.
 func (s *WorkerPoolSizer) occupyAvailableMachines() error {
+	if err := s.workerRepo.SetWorkerPoolSizerLock(s.controller.Name()); err != nil {
+		return err
+	}
+	defer s.workerRepo.RemoveWorkerPoolSizerLock(s.controller.Name())
+
 	machines, err := s.providerRepo.ListAllMachines(string(*s.workerPoolConfig.Provider), s.controller.Name(), true)
 	if err != nil {
 		return err
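Because the lock error is simply returned, a sizer that loses the race skips this pass and tries again on its next tick; combined with the one-worker-per-machine-per-call behavior noted in the comment, the pool converges over several ticks. Start() is not part of this diff, but an illustrative loop shows why a failed acquisition is harmless (the ticker interval and function names here are assumptions):

package main

import (
	"context"
	"time"
)

// runSizer is illustrative only; occupy stands in for occupyAvailableMachines.
func runSizer(ctx context.Context, occupy func() error) {
	ticker := time.NewTicker(1 * time.Second) // interval is an assumption
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			// Losing the per-pool lock is not fatal: another gateway
			// instance is sizing the pool, so retry on the next tick.
			if err := occupy(); err != nil {
				continue
			}
		}
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()
	runSizer(ctx, func() error { return nil })
}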
101 changes: 101 additions & 0 deletions pkg/scheduler/pool_sizing_test.go
@@ -2,13 +2,15 @@ package scheduler

 import (
 	"context"
+	"fmt"
 	"testing"

 	"github.com/alicebob/miniredis/v2"
 	"github.com/beam-cloud/beta9/pkg/common"
 	repo "github.com/beam-cloud/beta9/pkg/repository"
 	"github.com/beam-cloud/beta9/pkg/types"
 	"github.com/stretchr/testify/assert"
+	"golang.org/x/sync/errgroup"
 )

 func TestAddWorkerIfNeeded(t *testing.T) {
@@ -331,3 +333,102 @@ func TestOccupyAvailableMachines(t *testing.T) {
 	assert.NoError(t, err)
 	assert.Equal(t, 2, len(workers))
 }
+
+func TestOccupyAvailableMachinesConcurrency(t *testing.T) {
+	s, err := miniredis.Run()
+	assert.NotNil(t, s)
+	assert.Nil(t, err)
+
+	redisClient, err := common.NewRedisClient(types.RedisConfig{Addrs: []string{s.Addr()}, Mode: types.RedisModeSingle})
+	assert.NotNil(t, redisClient)
+	assert.Nil(t, err)
+
+	providerRepo := repo.NewProviderRedisRepositoryForTest(redisClient)
+	workerRepo := repo.NewWorkerRedisRepositoryForTest(redisClient)
+
+	ctx, cancel := context.WithCancel(context.Background())
+
+	poolName := "pool1"
+	controller := &ExternalWorkerPoolControllerForTest{
+		ctx:          ctx,
+		name:         poolName,
+		workerRepo:   workerRepo,
+		providerRepo: providerRepo,
+		poolName:     poolName,
+		providerName: string(types.ProviderGeneric),
+	}
+
+	sizer1 := &WorkerPoolSizer{
+		providerRepo: providerRepo,
+		workerRepo:   workerRepo,
+		controller:   controller,
+		workerPoolConfig: &types.WorkerPoolConfig{
+			Provider: &types.ProviderGeneric,
+			GPUType:  "A10G",
+			Mode:     types.PoolModeExternal,
+		},
+		workerPoolSizingConfig: &types.WorkerPoolSizingConfig{
+			DefaultWorkerCpu:      1000,
+			DefaultWorkerMemory:   1000,
+			DefaultWorkerGpuType:  "A10G",
+			DefaultWorkerGpuCount: 1,
+		},
+	}
+
+	sizer2 := &WorkerPoolSizer{
+		providerRepo: providerRepo,
+		workerRepo:   workerRepo,
+		controller:   controller,
+		workerPoolConfig: &types.WorkerPoolConfig{
+			Provider: &types.ProviderGeneric,
+			GPUType:  "A10G",
+			Mode:     types.PoolModeExternal,
+		},
+		workerPoolSizingConfig: &types.WorkerPoolSizingConfig{
+			DefaultWorkerCpu:      1000,
+			DefaultWorkerMemory:   1000,
+			DefaultWorkerGpuType:  "A10G",
+			DefaultWorkerGpuCount: 1,
+		},
+	}
+
+	maxMachinesAndWorkers := 100
+	for i := 0; i < maxMachinesAndWorkers; i++ {
+		machineName := fmt.Sprintf("machine-%d", i)
+		machineState := &types.ProviderMachineState{
+			Gpu:             "A10G",
+			GpuCount:        2,
+			AutoConsolidate: false,
+			Cpu:             10000,
+			Memory:          10000,
+			Status:          types.MachineStatusRegistered,
+		}
+		err = providerRepo.AddMachine(string(types.ProviderGeneric), poolName, machineName, machineState)
+		assert.NoError(t, err)
+
+		err = providerRepo.RegisterMachine(string(types.ProviderGeneric), poolName, machineName, machineState)
+		assert.NoError(t, err)
+	}
+
+	var g errgroup.Group
+
+	g.Go(func() error {
+		return sizer1.occupyAvailableMachines()
+	})
+	g.Go(func() error {
+		return sizer2.occupyAvailableMachines()
+	})
+
+	// One of the sizers should fail to occupy the machines because of a lock
+	err = g.Wait()
+	assert.Error(t, err)
+
+	// Check that only 100 workers were added, not 102, 105, etc.
+	// This is because occupyAvailableMachines adds just one worker per machine (if there's capacity and no lock)
+	// each time it is called.
+	workers, err := workerRepo.GetAllWorkers()
+	assert.NoError(t, err)
+	assert.Equal(t, maxMachinesAndWorkers, len(workers))
+
+	cancel()
+}
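Note on the assertions: errgroup's Wait returns the first non-nil error from any goroutine in the group, so assert.Error(t, err) holds as long as at least one sizer lost the lock race, while the worker count proves the loser added nothing. A minimal illustration of that Wait semantics:

package main

import (
	"errors"
	"fmt"

	"golang.org/x/sync/errgroup"
)

func main() {
	var g errgroup.Group
	g.Go(func() error { return nil })
	g.Go(func() error { return errors.New("lock already held") })

	// Exactly one "sizer" failed, so Wait surfaces that error.
	if err := g.Wait(); err != nil {
		fmt.Println("one goroutine failed:", err)
	}
}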
