Skip to content

Commit

Permalink
Feat: Multi-gpu workloads (#762)
Browse files Browse the repository at this point in the history
- Enable multi-gpu workloads (on multi-gpu workers)
  • Loading branch information
luke-lombardi authored Dec 6, 2024
1 parent f5a3f7e commit a256510
Show file tree
Hide file tree
Showing 18 changed files with 633 additions and 524 deletions.
4 changes: 2 additions & 2 deletions pkg/abstractions/container/container.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ func (cs *CmdContainerService) ExecuteCommand(in *pb.CommandExecutionRequest, st
gpuRequest = append(gpuRequest, stubConfig.Runtime.Gpu.String())
}

gpuCount := 0
if len(gpuRequest) > 0 {
gpuCount := stubConfig.Runtime.GpuCount
if stubConfig.RequiresGPU() && gpuCount == 0 {
gpuCount = 1
}

Expand Down
8 changes: 6 additions & 2 deletions pkg/abstractions/endpoint/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ func (i *endpointInstance) startContainers(containersToRun int) error {
gpuRequest = append(gpuRequest, i.StubConfig.Runtime.Gpu.String())
}

gpuCount := 0
if len(gpuRequest) > 0 {
gpuCount := i.StubConfig.Runtime.GpuCount
if i.StubConfig.RequiresGPU() && gpuCount == 0 {
gpuCount = 1
}

Expand All @@ -78,6 +78,10 @@ func (i *endpointInstance) startContainers(containersToRun int) error {
checkpointEnabled = false
}

if gpuCount > 1 {
checkpointEnabled = false
}

for c := 0; c < containersToRun; c++ {
containerId := i.genContainerId()

Expand Down
4 changes: 2 additions & 2 deletions pkg/abstractions/function/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ func (t *FunctionTask) run(ctx context.Context, stub *types.StubWithRelated) err
gpuRequest = append(gpuRequest, stubConfig.Runtime.Gpu.String())
}

gpuCount := 0
if len(gpuRequest) > 0 {
gpuCount := stubConfig.Runtime.GpuCount
if stubConfig.RequiresGPU() && gpuCount == 0 {
gpuCount = 1
}

Expand Down
8 changes: 6 additions & 2 deletions pkg/abstractions/taskqueue/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ func (i *taskQueueInstance) startContainers(containersToRun int) error {
gpuRequest = append(gpuRequest, i.StubConfig.Runtime.Gpu.String())
}

gpuCount := 0
if len(gpuRequest) > 0 {
gpuCount := i.StubConfig.Runtime.GpuCount
if i.StubConfig.RequiresGPU() && gpuCount == 0 {
gpuCount = 1
}

Expand All @@ -75,6 +75,10 @@ func (i *taskQueueInstance) startContainers(containersToRun int) error {
checkpointEnabled = false
}

if gpuCount > 1 {
checkpointEnabled = false
}

for c := 0; c < containersToRun; c++ {
runRequest := &types.ContainerRequest{
ContainerId: i.genContainerId(),
Expand Down
1 change: 1 addition & 0 deletions pkg/common/config.default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ gateway:
stubLimits:
memory: 32768
maxReplicas: 10
maxGpuCount: 2
imageService:
localCacheEnabled: true
registryStore: local
Expand Down
1 change: 1 addition & 0 deletions pkg/gateway/gateway.proto
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ message GetOrCreateStubRequest {
uint32 concurrent_requests = 24;
string extra = 25;
bool checkpoint_enabled = 26;
uint32 gpu_count = 27;
}

message GetOrCreateStubResponse {
Expand Down
90 changes: 55 additions & 35 deletions pkg/gateway/services/stub.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,37 +30,6 @@ func (gws *GatewayService) GetOrCreateStub(ctx context.Context, in *pb.GetOrCrea

gpus := types.GPUTypesFromString(in.Gpu)

if len(gpus) > 0 {
concurrencyLimit, err := gws.backendRepo.GetConcurrencyLimitByWorkspaceId(ctx, authInfo.Workspace.ExternalId)
if err != nil && concurrencyLimit != nil && concurrencyLimit.GPULimit <= 0 {
return &pb.GetOrCreateStubResponse{
Ok: false,
ErrMsg: "GPU concurrency limit is 0.",
}, nil
}

gpuCounts, err := gws.providerRepo.GetGPUCounts(gws.appConfig.Worker.Pools)
if err != nil {
return &pb.GetOrCreateStubResponse{
Ok: false,
ErrMsg: "Failed to get GPU counts.",
}, nil
}

// T4s are currently in a different pool than other GPUs and won't show up in gpu counts
lowGpus := []string{}

for _, gpu := range gpus {
if gpuCounts[gpu.String()] <= 1 && gpu.String() != types.GPU_T4.String() {
lowGpus = append(lowGpus, gpu.String())
}
}

if len(lowGpus) > 0 {
warning = fmt.Sprintf("GPU capacity for %s is currently low.", strings.Join(lowGpus, ", "))
}
}

autoscaler := &types.Autoscaler{}
if in.Autoscaler.Type == "" {
autoscaler.Type = types.QueueDepthAutoscaler
Expand All @@ -83,12 +52,27 @@ func (gws *GatewayService) GetOrCreateStub(ctx context.Context, in *pb.GetOrCrea
in.Extra = "{}"
}

if in.GpuCount > gws.appConfig.GatewayService.StubLimits.MaxGpuCount {
return &pb.GetOrCreateStubResponse{
Ok: false,
ErrMsg: fmt.Sprintf("GPU count must be %d or less.", gws.appConfig.GatewayService.StubLimits.MaxGpuCount),
}, nil
}

if in.GpuCount > 1 && !authInfo.Workspace.MultiGpuEnabled {
return &pb.GetOrCreateStubResponse{
Ok: false,
ErrMsg: "Multi-GPU containers are not enabled for this workspace.",
}, nil
}

stubConfig := types.StubConfigV1{
Runtime: types.Runtime{
Cpu: in.Cpu,
Gpus: gpus,
Memory: in.Memory,
ImageId: in.ImageId,
Cpu: in.Cpu,
Gpus: gpus,
GpuCount: in.GpuCount,
Memory: in.Memory,
ImageId: in.ImageId,
},
Handler: in.Handler,
OnStart: in.OnStart,
Expand All @@ -107,6 +91,42 @@ func (gws *GatewayService) GetOrCreateStub(ctx context.Context, in *pb.GetOrCrea
CheckpointEnabled: in.CheckpointEnabled,
}

// Ensure GPU count is at least 1 if a GPU is required
if stubConfig.RequiresGPU() && in.GpuCount == 0 {
stubConfig.Runtime.GpuCount = 1
}

if stubConfig.RequiresGPU() {
concurrencyLimit, err := gws.backendRepo.GetConcurrencyLimitByWorkspaceId(ctx, authInfo.Workspace.ExternalId)
if err != nil && concurrencyLimit != nil && concurrencyLimit.GPULimit <= 0 {
return &pb.GetOrCreateStubResponse{
Ok: false,
ErrMsg: "GPU concurrency limit is 0.",
}, nil
}

gpuCounts, err := gws.providerRepo.GetGPUCounts(gws.appConfig.Worker.Pools)
if err != nil {
return &pb.GetOrCreateStubResponse{
Ok: false,
ErrMsg: "Failed to get GPU counts.",
}, nil
}

// T4s are currently in a different pool than other GPUs and won't show up in gpu counts
lowGpus := []string{}

for _, gpu := range gpus {
if gpuCounts[gpu.String()] <= 1 && gpu.String() != types.GPU_T4.String() {
lowGpus = append(lowGpus, gpu.String())
}
}

if len(lowGpus) > 0 {
warning = fmt.Sprintf("GPU capacity for %s is currently low.", strings.Join(lowGpus, ", "))
}
}

// Get secrets
for _, secret := range in.Secrets {
secret, err := gws.backendRepo.GetSecretByName(ctx, authInfo.Workspace, secret.Name)
Expand Down
8 changes: 4 additions & 4 deletions pkg/repository/backend_postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ func (r *PostgresBackendRepository) CreateWorkspace(ctx context.Context) (types.
func (r *PostgresBackendRepository) GetWorkspaceByExternalId(ctx context.Context, externalId string) (types.Workspace, error) {
var workspace types.Workspace

query := `SELECT id, name, created_at, concurrency_limit_id, volume_cache_enabled FROM workspace WHERE external_id = $1;`
query := `SELECT id, name, created_at, concurrency_limit_id, volume_cache_enabled, multi_gpu_enabled FROM workspace WHERE external_id = $1;`
err := r.client.GetContext(ctx, &workspace, query, externalId)
if err != nil {
return types.Workspace{}, err
Expand All @@ -155,7 +155,7 @@ func (r *PostgresBackendRepository) GetWorkspaceByExternalId(ctx context.Context
func (r *PostgresBackendRepository) GetWorkspaceByExternalIdWithSigningKey(ctx context.Context, externalId string) (types.Workspace, error) {
var workspace types.Workspace

query := `SELECT id, name, created_at, concurrency_limit_id, signing_key, volume_cache_enabled FROM workspace WHERE external_id = $1;`
query := `SELECT id, name, created_at, concurrency_limit_id, signing_key, volume_cache_enabled, multi_gpu_enabled FROM workspace WHERE external_id = $1;`
err := r.client.GetContext(ctx, &workspace, query, externalId)
if err != nil {
return types.Workspace{}, err
Expand Down Expand Up @@ -199,7 +199,7 @@ func (r *PostgresBackendRepository) AuthorizeToken(ctx context.Context, tokenKey
query := `
SELECT t.id, t.external_id, t.key, t.created_at, t.updated_at, t.active, t.disabled_by_cluster_admin , t.token_type, t.reusable, t.workspace_id,
w.id "workspace.id", w.name "workspace.name", w.external_id "workspace.external_id", w.signing_key "workspace.signing_key", w.created_at "workspace.created_at",
w.updated_at "workspace.updated_at", w.volume_cache_enabled "workspace.volume_cache_enabled"
w.updated_at "workspace.updated_at", w.volume_cache_enabled "workspace.volume_cache_enabled", w.multi_gpu_enabled "workspace.multi_gpu_enabled"
FROM token t
INNER JOIN workspace w ON t.workspace_id = w.id
WHERE t.key = $1 AND t.active = TRUE;
Expand Down Expand Up @@ -791,7 +791,7 @@ func (r *PostgresBackendRepository) GetStubByExternalId(ctx context.Context, ext
var stub types.StubWithRelated
qb := squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar).Select(
`s.id, s.external_id, s.name, s.type, s.config, s.config_version, s.object_id, s.workspace_id, s.created_at, s.updated_at,
w.id AS "workspace.id", w.external_id AS "workspace.external_id", w.name AS "workspace.name", w.created_at AS "workspace.created_at", w.updated_at AS "workspace.updated_at", w.signing_key AS "workspace.signing_key", w.volume_cache_enabled AS "workspace.volume_cache_enabled",
w.id AS "workspace.id", w.external_id AS "workspace.external_id", w.name AS "workspace.name", w.created_at AS "workspace.created_at", w.updated_at AS "workspace.updated_at", w.signing_key AS "workspace.signing_key", w.volume_cache_enabled AS "workspace.volume_cache_enabled", w.multi_gpu_enabled AS "workspace.multi_gpu_enabled",
o.id AS "object.id", o.external_id AS "object.external_id", o.hash AS "object.hash", o.size AS "object.size", o.workspace_id AS "object.workspace_id", o.created_at AS "object.created_at"`,
).
From("stub s").
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package backend_postgres_migrations

import (
"context"
"database/sql"

"github.com/pressly/goose/v3"
)

// init registers this migration with goose, pairing the up function that adds
// workspace.multi_gpu_enabled with the down function that drops it.
func init() {
goose.AddMigrationContext(upAddFieldWorkspaceMultiGpuEnabled, downDropFieldWorkspaceMultiGpuEnabled)
}

// upAddFieldWorkspaceMultiGpuEnabled adds the multi_gpu_enabled column
// (BOOLEAN, default FALSE) to the workspace table.
//
// The statement is executed inside the supplied transaction and returns any
// error reported by the database unchanged.
func upAddFieldWorkspaceMultiGpuEnabled(ctx context.Context, tx *sql.Tx) error {
	// Use ExecContext so cancellation or a deadline on ctx is honoured instead
	// of being silently ignored by the context-less Exec.
	_, err := tx.ExecContext(ctx, `ALTER TABLE workspace ADD COLUMN multi_gpu_enabled BOOLEAN DEFAULT FALSE;`)
	return err
}

// downDropFieldWorkspaceMultiGpuEnabled reverts the migration by dropping the
// multi_gpu_enabled column from the workspace table.
//
// The statement is executed inside the supplied transaction and returns any
// error reported by the database unchanged.
func downDropFieldWorkspaceMultiGpuEnabled(ctx context.Context, tx *sql.Tx) error {
	// Use ExecContext so cancellation or a deadline on ctx is honoured instead
	// of being silently ignored by the context-less Exec.
	_, err := tx.ExecContext(ctx, `ALTER TABLE workspace DROP COLUMN multi_gpu_enabled;`)
	return err
}
16 changes: 11 additions & 5 deletions pkg/types/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type Workspace struct {
UpdatedAt time.Time `db:"updated_at" json:"updated_at,omitempty"`
SigningKey *string `db:"signing_key" json:"signing_key"`
VolumeCacheEnabled bool `db:"volume_cache_enabled" json:"volume_cache_enabled"`
MultiGpuEnabled bool `db:"multi_gpu_enabled" json:"multi_gpu_enabled"`
ConcurrencyLimitId *uint `db:"concurrency_limit_id" json:"concurrency_limit_id,omitempty"`
ConcurrencyLimit *ConcurrencyLimit `db:"concurrency_limit" json:"concurrency_limit"`
}
Expand Down Expand Up @@ -186,6 +187,10 @@ type StubConfigV1 struct {
CheckpointEnabled bool `json:"checkpoint_enabled"`
}

// RequiresGPU reports whether the stub's runtime asks for a GPU, either through
// the single Gpu field or through a non-empty Gpus list.
func (c *StubConfigV1) RequiresGPU() bool {
	runtime := c.Runtime

	// A non-empty single-GPU field is enough on its own.
	if runtime.Gpu != "" {
		return true
	}

	// Otherwise a GPU is required only when the Gpus list has entries.
	return len(runtime.Gpus) != 0
}

type AutoscalerType string

const (
Expand Down Expand Up @@ -291,11 +296,12 @@ type Image struct {
}

type Runtime struct {
Cpu int64 `json:"cpu"`
Gpu GpuType `json:"gpu"`
Memory int64 `json:"memory"`
ImageId string `json:"image_id"`
Gpus []GpuType `json:"gpus"`
Cpu int64 `json:"cpu"`
Gpu GpuType `json:"gpu"`
GpuCount uint32 `json:"gpu_count"`
Memory int64 `json:"memory"`
ImageId string `json:"image_id"`
Gpus []GpuType `json:"gpus"`
}

type GpuType string
Expand Down
1 change: 1 addition & 0 deletions pkg/types/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ type CORSConfig struct {
// StubLimits holds the per-stub limits read from the gateway's stubLimits
// configuration section.
type StubLimits struct {
// Memory is the configured memory limit for a stub.
// NOTE(review): unit is presumably MiB (default is 32768) — confirm.
Memory uint64 `key:"memory" json:"memory"`
// MaxReplicas is the configured replica limit for a stub.
// NOTE(review): presumably an upper bound enforced by the autoscaler — confirm against callers.
MaxReplicas uint64 `key:"maxReplicas" json:"max_replicas"`
// MaxGpuCount is the largest GPU count a stub may request; the gateway
// rejects GetOrCreateStub requests whose gpu_count exceeds it.
MaxGpuCount uint32 `key:"maxGpuCount" json:"max_gpu_count"`
}

type GatewayServiceConfig struct {
Expand Down
Loading

0 comments on commit a256510

Please sign in to comment.