Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow mutating queue name in StatefulSet Webhook. #3520

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 1 addition & 25 deletions pkg/controller/jobs/pod/pod_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ package pod
import (
"cmp"
"context"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"slices"
Expand Down Expand Up @@ -562,29 +560,7 @@ func getRoleHash(p corev1.Pod) (string, error) {
if roleHash, ok := p.Annotations[RoleHashAnnotation]; ok {
return roleHash, nil
}

shape := map[string]interface{}{
"spec": map[string]interface{}{
"initContainers": containersShape(p.Spec.InitContainers),
"containers": containersShape(p.Spec.Containers),
"nodeSelector": p.Spec.NodeSelector,
"affinity": p.Spec.Affinity,
"tolerations": p.Spec.Tolerations,
"runtimeClassName": p.Spec.RuntimeClassName,
"priority": p.Spec.Priority,
"topologySpreadConstraints": p.Spec.TopologySpreadConstraints,
"overhead": p.Spec.Overhead,
"resourceClaims": p.Spec.ResourceClaims,
},
}

shapeJSON, err := json.Marshal(shape)
if err != nil {
return "", err
}

// Trim hash to 8 characters and return
return fmt.Sprintf("%x", sha256.Sum256(shapeJSON))[:8], nil
return utilpod.GenerateShape(p.Spec)
}

// Load loads all pods in the group
Expand Down
13 changes: 0 additions & 13 deletions pkg/controller/jobs/pod/pod_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,19 +109,6 @@ func getPodOptions(integrationOpts map[string]any) (*configapi.PodIntegrationOpt

var _ admission.CustomDefaulter = &PodWebhook{}

func containersShape(containers []corev1.Container) (result []map[string]interface{}) {
for _, c := range containers {
result = append(result, map[string]interface{}{
"resources": map[string]interface{}{
"requests": c.Resources.Requests,
},
"ports": c.Ports,
})
}

return result
}

// addRoleHash calculates the role hash and adds it to the pod's annotations
func (p *Pod) addRoleHash() error {
if p.pod.Annotations == nil {
Expand Down
11 changes: 4 additions & 7 deletions pkg/controller/jobs/statefulset/statefulset_reconciler.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (
corev1 "k8s.io/api/core/v1"
"k8s.io/client-go/tools/record"
"k8s.io/klog/v2"
"k8s.io/utils/ptr"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/controller/controllerutil"
Expand Down Expand Up @@ -57,11 +56,6 @@ func (r *Reconciler) Reconcile(ctx context.Context, req reconcile.Request) (reco
ctx = ctrl.LoggerInto(ctx, log)
log.V(2).Info("Reconciling StatefulSet")

// For now, handle only scaling down to zero.
if ptr.Deref(sts.Spec.Replicas, 1) != 0 {
return ctrl.Result{}, nil
}

err = r.fetchAndFinalizePods(ctx, req.Namespace, req.Name)
if err != nil {
return ctrl.Result{}, err
Expand All @@ -73,7 +67,7 @@ func (r *Reconciler) Reconcile(ctx context.Context, req reconcile.Request) (reco
func (r *Reconciler) fetchAndFinalizePods(ctx context.Context, namespace, statefulSetName string) error {
podList := &corev1.PodList{}
if err := r.client.List(ctx, podList, client.InNamespace(namespace), client.MatchingLabels{
pod.GroupNameLabel: GetWorkloadName(statefulSetName),
StatefulSetNameLabel: statefulSetName,
}); err != nil {
return err
}
Expand All @@ -84,6 +78,9 @@ func (r *Reconciler) finalizePods(ctx context.Context, pods []corev1.Pod) error
log := ctrl.LoggerFrom(ctx)
return parallelize.Until(ctx, len(pods), func(i int) error {
p := &pods[i]
if p.Status.Phase != corev1.PodSucceeded && p.Status.Phase != corev1.PodFailed {
return nil
}
err := clientutil.Patch(ctx, r.client, p, true, func() (bool, error) {
removed := controllerutil.RemoveFinalizer(p, pod.PodFinalizer)
if removed {
Expand Down
22 changes: 10 additions & 12 deletions pkg/controller/jobs/statefulset/statefulset_reconciler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ var (
)

func TestReconciler(t *testing.T) {
baseSTS := statefulsettesting.MakeStatefulSet("sts", "ns").
Replicas(1).
Queue("lq")

cases := map[string]struct {
statefulSet appsv1.StatefulSet
pods []corev1.Pod
Expand All @@ -49,23 +53,17 @@ func TestReconciler(t *testing.T) {
wantErr error
}{
"statefulset with replicas != zero": {
statefulSet: *statefulsettesting.MakeStatefulSet("sts", "ns").
Replicas(1).
Queue("lq").
DeepCopy(),
wantStatefulSet: *statefulsettesting.MakeStatefulSet("sts", "ns").
Replicas(1).
Queue("lq").
DeepCopy(),
statefulSet: *baseSTS.DeepCopy(),
wantStatefulSet: *baseSTS.DeepCopy(),
pods: []corev1.Pod{
*testingjobspod.MakePod("pod", "ns").
Label(pod.GroupNameLabel, GetWorkloadName("sts")).
Label(pod.GroupNameLabel, MustGetWorkloadName(baseSTS.DeepCopy())).
Finalizer(pod.PodFinalizer).
Obj(),
},
wantPods: []corev1.Pod{
*testingjobspod.MakePod("pod", "ns").
Label(pod.GroupNameLabel, GetWorkloadName("sts")).
Label(pod.GroupNameLabel, MustGetWorkloadName(baseSTS.DeepCopy())).
Finalizer(pod.PodFinalizer).
Obj(),
},
Expand All @@ -81,13 +79,13 @@ func TestReconciler(t *testing.T) {
DeepCopy(),
pods: []corev1.Pod{
*testingjobspod.MakePod("pod", "ns").
Label(pod.GroupNameLabel, GetWorkloadName("sts")).
Label(pod.GroupNameLabel, MustGetWorkloadName(baseSTS.Replicas(0).DeepCopy())).
Finalizer(pod.PodFinalizer).
Obj(),
},
wantPods: []corev1.Pod{
*testingjobspod.MakePod("pod", "ns").
Label(pod.GroupNameLabel, GetWorkloadName("sts")).
Label(pod.GroupNameLabel, MustGetWorkloadName(baseSTS.Replicas(0).DeepCopy())).
Obj(),
},
},
Expand Down
86 changes: 42 additions & 44 deletions pkg/controller/jobs/statefulset/statefulset_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package statefulset
import (
"context"
"fmt"
utilpod "sigs.k8s.io/kueue/pkg/util/pod"

appsv1 "k8s.io/api/apps/v1"
apivalidation "k8s.io/apimachinery/pkg/api/validation"
Expand All @@ -35,16 +36,17 @@ import (
"sigs.k8s.io/kueue/pkg/controller/jobs/pod"
)

const (
// StatefulSetNameLabel is the label key under which the webhook's Default
// stores the StatefulSet's name on the pod template; the reconciler uses
// the same key to select the pods it finalizes.
StatefulSetNameLabel = "kueue.x-k8s.io/statefulset-name"
)

type Webhook struct {
client client.Client
manageJobsWithoutQueueName bool
client client.Client
}

func SetupWebhook(mgr ctrl.Manager, opts ...jobframework.Option) error {
options := jobframework.ProcessOptions(opts...)
func SetupWebhook(mgr ctrl.Manager, _ ...jobframework.Option) error {
wh := &Webhook{
client: mgr.GetClient(),
manageJobsWithoutQueueName: options.ManageJobsWithoutQueueName,
client: mgr.GetClient(),
}
return ctrl.NewWebhookManagedBy(mgr).
For(&appsv1.StatefulSet{}).
Expand All @@ -62,17 +64,21 @@ func (wh *Webhook) Default(ctx context.Context, obj runtime.Object) error {
log := ctrl.LoggerFrom(ctx).WithName("statefulset-webhook")
log.V(5).Info("Applying defaults")

cqLabel, ok := ss.Labels[constants.QueueLabel]
if !ok {
queueName := jobframework.QueueNameForObject(ss.Object())
if queueName == "" {
return nil
}

if ss.Spec.Template.Labels == nil {
ss.Spec.Template.Labels = make(map[string]string, 2)
ss.Spec.Template.Labels = make(map[string]string, 3)
}

ss.Spec.Template.Labels[constants.QueueLabel] = cqLabel
ss.Spec.Template.Labels[pod.GroupNameLabel] = GetWorkloadName(ss.Name)
ss.Spec.Template.Labels[StatefulSetNameLabel] = ss.Name
ss.Spec.Template.Labels[constants.QueueLabel] = queueName
groupName, err := GetWorkloadName(obj.(*appsv1.StatefulSet))
if err != nil {
return err
}
ss.Spec.Template.Labels[pod.GroupNameLabel] = groupName

if ss.Spec.Template.Annotations == nil {
ss.Spec.Template.Annotations = make(map[string]string, 2)
Expand Down Expand Up @@ -114,37 +120,16 @@ func (wh *Webhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.Ob
log := ctrl.LoggerFrom(ctx).WithName("statefulset-webhook")
log.V(5).Info("Validating update")

allErrs := apivalidation.ValidateImmutableField(
newStatefulSet.GetLabels()[constants.QueueLabel],
oldStatefulSet.GetLabels()[constants.QueueLabel],
queueNameLabelPath,
)
allErrs = append(allErrs, apivalidation.ValidateImmutableField(
newStatefulSet.Spec.Template.GetLabels()[constants.QueueLabel],
oldStatefulSet.Spec.Template.GetLabels()[constants.QueueLabel],
podSpecQueueNameLabelPath,
)...)
allErrs = append(allErrs, apivalidation.ValidateImmutableField(
newStatefulSet.GetLabels()[pod.GroupNameLabel],
oldStatefulSet.GetLabels()[pod.GroupNameLabel],
groupNameLabelPath,
)...)

oldReplicas := ptr.Deref(oldStatefulSet.Spec.Replicas, 1)
newReplicas := ptr.Deref(newStatefulSet.Spec.Replicas, 1)

// Allow only scale down to zero and scale up from zero.
// TODO(#3279): Support custom resizes later
if newReplicas != 0 && oldReplicas != 0 {
allErrs = append(allErrs, apivalidation.ValidateImmutableField(
newStatefulSet.Spec.Replicas,
oldStatefulSet.Spec.Replicas,
replicasPath,
)...)
}
oldQueueName := jobframework.QueueNameForObject(oldStatefulSet.Object())
newQueueName := jobframework.QueueNameForObject(newStatefulSet.Object())

if oldReplicas == 0 && newReplicas > 0 && newStatefulSet.Status.Replicas > 0 {
allErrs = append(allErrs, field.Forbidden(replicasPath, "scaling down is still in progress"))
allErrs := field.ErrorList{}
allErrs = append(allErrs, jobframework.ValidateQueueName(newStatefulSet.Object())...)

// Prevents updating the queue-name if at least one Pod is not suspended
// or if the queue-name has been deleted.
if oldStatefulSet.Status.ReadyReplicas > 0 || newQueueName == "" {
allErrs = append(allErrs, apivalidation.ValidateImmutableField(oldQueueName, newQueueName, queueNameLabelPath)...)
}

return warnings, allErrs.ToAggregate()
Expand All @@ -154,7 +139,20 @@ func (wh *Webhook) ValidateDelete(context.Context, runtime.Object) (warnings adm
return nil, nil
}

func GetWorkloadName(statefulSetName string) string {
func GetWorkloadName(sts *appsv1.StatefulSet) (string, error) {
shape, err := utilpod.GenerateShape(sts.Spec.Template.Spec)
if err != nil {
return "", err
}
ownerName := fmt.Sprintf("%s-%s", sts.Name, shape)
// Passing empty UID as it is not available before object creation
return jobframework.GetWorkloadNameForOwnerWithGVK(statefulSetName, "", gvk)
return jobframework.GetWorkloadNameForOwnerWithGVK(ownerName, "", gvk), nil
}

// MustGetWorkloadName is the panicking variant of GetWorkloadName: it returns
// the workload name computed for sts and panics with the underlying error if
// GetWorkloadName fails.
func MustGetWorkloadName(sts *appsv1.StatefulSet) string {
	workloadName, err := GetWorkloadName(sts)
	if err == nil {
		return workloadName
	}
	panic(err)
}
10 changes: 4 additions & 6 deletions pkg/controller/jobs/statefulset/statefulset_webhook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,9 @@ import (

func TestDefault(t *testing.T) {
testCases := map[string]struct {
statefulset *appsv1.StatefulSet
manageJobsWithoutQueueName bool
enableIntegrations []string
want *appsv1.StatefulSet
statefulset *appsv1.StatefulSet
enableIntegrations []string
want *appsv1.StatefulSet
}{
"statefulset with queue": {
enableIntegrations: []string{"pod"},
Expand Down Expand Up @@ -80,8 +79,7 @@ func TestDefault(t *testing.T) {
cli := builder.Build()

w := &Webhook{
client: cli,
manageJobsWithoutQueueName: tc.manageJobsWithoutQueueName,
client: cli,
}

ctx, _ := utiltesting.ContextWithLog(t)
Expand Down
40 changes: 40 additions & 0 deletions pkg/util/pod/pod.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ limitations under the License.
package pod

import (
"crypto/sha256"
"encoding/json"
"fmt"
"slices"

corev1 "k8s.io/api/core/v1"
Expand Down Expand Up @@ -56,3 +59,40 @@ func gateIndex(p *corev1.Pod, gateName string) int {
return g.Name == gateName
})
}

// GenerateShape computes a short fingerprint of podSpec.
//
// Only a fixed subset of the spec takes part in the fingerprint: the
// init/regular containers (reduced by containersShape), node selector,
// affinity, tolerations, runtime class name, priority, topology spread
// constraints, overhead and resource claims. That subset is JSON-encoded,
// hashed with SHA-256, and the first 8 hexadecimal characters of the digest
// are returned. An error is returned only if the JSON encoding fails.
func GenerateShape(podSpec corev1.PodSpec) (string, error) {
	specFields := map[string]interface{}{
		"initContainers":            containersShape(podSpec.InitContainers),
		"containers":                containersShape(podSpec.Containers),
		"nodeSelector":              podSpec.NodeSelector,
		"affinity":                  podSpec.Affinity,
		"tolerations":               podSpec.Tolerations,
		"runtimeClassName":          podSpec.RuntimeClassName,
		"priority":                  podSpec.Priority,
		"topologySpreadConstraints": podSpec.TopologySpreadConstraints,
		"overhead":                  podSpec.Overhead,
		"resourceClaims":            podSpec.ResourceClaims,
	}

	encoded, err := json.Marshal(map[string]interface{}{"spec": specFields})
	if err != nil {
		return "", err
	}

	digest := sha256.Sum256(encoded)
	hexDigest := fmt.Sprintf("%x", digest)

	// Only the leading 8 hex characters are kept as the shape identifier.
	return hexDigest[:8], nil
}

// containersShape reduces each container to the fields that take part in the
// pod shape: its resource requests and its ports. The entries are returned in
// the same order as containers. For an empty input the result stays nil
// (no slice is allocated).
func containersShape(containers []corev1.Container) (result []map[string]interface{}) {
	for i := range containers {
		ctr := &containers[i]
		requestsOnly := map[string]interface{}{
			"requests": ctr.Resources.Requests,
		}
		entry := map[string]interface{}{
			"resources": requestsOnly,
			"ports":     ctr.Ports,
		}
		result = append(result, entry)
	}
	return result
}
Loading