Skip to content

Commit

Permalink
Update LMEvalJob CRD to pass secrets
Browse files Browse the repository at this point in the history
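Add the `envSecrets` field to the LMEvalJob spec so that secrets can be passed to the job as environment variables.
Each EnvSecret entry is either a plain-text value (e.g., an API key)
or a reference to a key in a Secret object.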

Signed-off-by: Yihong Wang <[email protected]>
  • Loading branch information
yhwang committed Jul 23, 2024
1 parent e6b8f73 commit 5248cbb
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 57 deletions.
15 changes: 15 additions & 0 deletions api/v1beta1/lmevaljob_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package v1beta1

import (
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)

Expand Down Expand Up @@ -59,6 +60,17 @@ type Arg struct {
Value string `json:"value,omitempty"`
}

// EnvSecret maps a secret value to an environment variable. The value is
// taken either from a key of a Secret object (SecretRef) or from a plain-text
// string (Secret). When both are set, the plain-text Secret is used; when
// neither is set, the entry is ignored.
type EnvSecret struct {
// Name of the environment variable that receives the secret value
Env string `json:"env"`
// Reference to a key of a Secret object that holds the value
// +optional
SecretRef *corev1.SecretKeySelector `json:"secretRef,omitempty"`
// The secret value given directly as plain text
// +optional
Secret *string `json:"secret,omitempty"`
}

// LMEvalJobSpec defines the desired state of LMEvalJob
type LMEvalJobSpec struct {
// INSERT ADDITIONAL SPEC FIELDS - desired state of cluster
Expand Down Expand Up @@ -86,6 +98,9 @@ type LMEvalJobSpec struct {
// model, will be saved at per-document granularity
// +optional
LogSamples *bool `json:"logSamples,omitempty"`
// Assign secrets to the environment variables
// +optional
EnvSecrets []EnvSecret `json:"envSecrets,omitempty"`
}

// LMEvalJobStatus defines the observed state of LMEvalJob
Expand Down
100 changes: 52 additions & 48 deletions backend/controller/lmevaljob_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ const (
DestDriverPath = "/opt/app-root/src/bin/driver"
PodImageKey = "pod-image"
DriverImageKey = "driver-image"
DriverServiceAccountKey = "driver-serviceaccount"
PodCheckingIntervalKey = "pod-checking-interval"
ImagePullPolicyKey = "image-pull-policy"
GrpcPortKey = "grpc-port"
Expand Down Expand Up @@ -82,15 +81,14 @@ var (
}

optionKeys = map[string]string{
"PodImage": PodImageKey,
"DriverImage": DriverImageKey,
"DriverServiceAccount": DriverServiceAccountKey,
"PodCheckingInterval": PodCheckingIntervalKey,
"ImagePullPolicy": ImagePullPolicyKey,
"GrpcPort": GrpcPortKey,
"GrpcService": GrpcServiceKey,
"GrpcServerSecret": GrpcServerSecretKey,
"GrpcClientSecret": GrpcClientSecretKey,
"PodImage": PodImageKey,
"DriverImage": DriverImageKey,
"PodCheckingInterval": PodCheckingIntervalKey,
"ImagePullPolicy": ImagePullPolicyKey,
"GrpcPort": GrpcPortKey,
"GrpcService": GrpcServiceKey,
"GrpcServerSecret": GrpcServerSecretKey,
"GrpcClientSecret": GrpcClientSecretKey,
}
)

Expand All @@ -113,16 +111,15 @@ type LMEvalJobReconciler struct {
}

type ServiceOptions struct {
PodImage string
DriverImage string
DriverServiceAccount string
PodCheckingInterval time.Duration
ImagePullPolicy corev1.PullPolicy
GrpcPort int
GrpcService string
GrpcServerSecret string
GrpcClientSecret string
grpcTLSMode TLSMode
PodImage string
DriverImage string
PodCheckingInterval time.Duration
ImagePullPolicy corev1.PullPolicy
GrpcPort int
GrpcService string
GrpcServerSecret string
GrpcClientSecret string
grpcTLSMode TLSMode
}

// +kubebuilder:rbac:groups=foundation-model-stack.github.com.github.com,resources=lmevaljobs,verbs=get;list;watch;create;update;patch;delete
Expand Down Expand Up @@ -296,15 +293,14 @@ func (r *LMEvalJobReconciler) updateStatus(ctx context.Context, newStatus *backe
func (r *LMEvalJobReconciler) constructOptionsFromConfigMap(
ctx context.Context, configmap *corev1.ConfigMap) error {
r.options = &ServiceOptions{
DriverImage: DefaultDriverImage,
PodImage: DefaultPodImage,
DriverServiceAccount: DefaultDriverServiceAccount,
PodCheckingInterval: DefaultPodCheckingInterval,
ImagePullPolicy: DefaultImagePullPolicy,
GrpcPort: DefaultGrpcPort,
GrpcService: DefaultGrpcService,
GrpcServerSecret: DefaultGrpcServerSecret,
GrpcClientSecret: DefaultGrpcClientSecret,
DriverImage: DefaultDriverImage,
PodImage: DefaultPodImage,
PodCheckingInterval: DefaultPodCheckingInterval,
ImagePullPolicy: DefaultImagePullPolicy,
GrpcPort: DefaultGrpcPort,
GrpcService: DefaultGrpcService,
GrpcServerSecret: DefaultGrpcServerSecret,
GrpcClientSecret: DefaultGrpcClientSecret,
}

log := log.FromContext(ctx)
Expand Down Expand Up @@ -462,18 +458,18 @@ func (r *LMEvalJobReconciler) checkScheduledPod(ctx context.Context, log logr.Lo
mainIndex := slices.IndexFunc(pod.Status.ContainerStatuses, func(s corev1.ContainerStatus) bool {
return s.Name == "main"
})
if mainIndex == -1 || pod.Status.ContainerStatuses[mainIndex].LastTerminationState.Terminated == nil {
if mainIndex == -1 || pod.Status.ContainerStatuses[mainIndex].State.Terminated == nil {
// wait for the main container to finish
return ctrl.Result{Requeue: true, RequeueAfter: r.options.PodCheckingInterval}, nil
}

// main container finished. update status
job.Status.State = lmevalservicev1beta1.CompleteJobState
if pod.Status.ContainerStatuses[mainIndex].LastTerminationState.Terminated.ExitCode == 0 {
if pod.Status.ContainerStatuses[mainIndex].State.Terminated.ExitCode == 0 {
job.Status.Reason = lmevalservicev1beta1.SucceedReason
} else {
job.Status.Reason = lmevalservicev1beta1.FailedReason
job.Status.Message = pod.Status.ContainerStatuses[mainIndex].LastTerminationState.Terminated.Reason
job.Status.Message = pod.Status.ContainerStatuses[mainIndex].State.Terminated.Reason
}

err = r.Status().Update(ctx, job)
Expand Down Expand Up @@ -575,19 +571,7 @@ func (r *LMEvalJobReconciler) createPod(job *lmevalservicev1beta1.LMEvalJob) *co
var runAsUser int64 = 1001030000
var secretMode int32 = 420

var envVars = []corev1.EnvVar{
{
Name: "GENAI_KEY",
ValueFrom: &corev1.EnvVarSource{
SecretKeyRef: &corev1.SecretKeySelector{
Key: "key",
LocalObjectReference: corev1.LocalObjectReference{
Name: "genai-key",
},
},
},
},
}
var envVars = generateEnvs(job.Spec.EnvSecrets)

var volumeMounts = []corev1.VolumeMount{
{
Expand Down Expand Up @@ -757,9 +741,8 @@ func (r *LMEvalJobReconciler) createPod(job *lmevalservicev1beta1.LMEvalJob) *co
Type: corev1.SeccompProfileTypeRuntimeDefault,
},
},
ServiceAccountName: r.options.DriverServiceAccount,
Volumes: volumes,
RestartPolicy: corev1.RestartPolicyNever,
Volumes: volumes,
RestartPolicy: corev1.RestartPolicyNever,
},
}
return &pod
Expand Down Expand Up @@ -826,3 +809,24 @@ func argsToString(args []lmevalservicev1beta1.Arg) string {
}
return strings.Join(equalForms, ",")
}

// generateEnvs translates the job's EnvSecret entries into container
// environment variables.
//
// For each entry, a plain-text Secret becomes a literal EnvVar value; otherwise
// a SecretRef becomes a SecretKeyRef value source. A plain-text Secret wins when
// both are present, and an entry with neither is dropped. The returned slice is
// never nil, even when secrets is empty.
func generateEnvs(secrets []lmevalservicev1beta1.EnvSecret) []corev1.EnvVar {
	result := []corev1.EnvVar{}

	for i := range secrets {
		entry := secrets[i]
		switch {
		case entry.Secret != nil:
			// Inline the plain-text value directly.
			result = append(result, corev1.EnvVar{
				Name:  entry.Env,
				Value: *entry.Secret,
			})
		case entry.SecretRef != nil:
			// Resolve the value from the referenced Secret key when the pod starts.
			source := &corev1.EnvVarSource{SecretKeyRef: entry.SecretRef}
			result = append(result, corev1.EnvVar{
				Name:      entry.Env,
				ValueFrom: source,
			})
		}
	}

	return result
}
13 changes: 11 additions & 2 deletions backend/driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,18 @@ func (d *driverImpl) Run() error {
if err := d.updateStatus(lmevalservicev1beta1.RunningJobState); err != nil {
return err
}
err := d.exec()
execErr := d.exec()

return d.updateCompleteStatus(err)
// dump stderr and stdout to the console
var toConsole = func(file string) {
if data, err := os.ReadFile(file); err == nil {
os.Stdout.Write(data)
}
}
toConsole(filepath.Join(d.Option.OutputPath, "stdout.log"))
toConsole(filepath.Join(d.Option.OutputPath, "stderr.log"))

return d.updateCompleteStatus(execErr)
}

func (d *driverImpl) Cleanup() {
Expand Down
36 changes: 29 additions & 7 deletions backend/driver/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,14 @@ package driver
import (
"context"
"flag"
"fmt"
"net"
"os"
"testing"

"github.com/foundation-model-stack/fms-lm-eval-service/backend/api/v1beta1"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/controller-runtime/pkg/log/zap"
Expand All @@ -30,7 +36,23 @@ var (
driverLog = ctrl.Log.WithName("driver-test")
)

// DummyUpdateServer is a stub LMEvalJobUpdateService gRPC server that
// Test_Driver registers so the driver has an endpoint to report job status to.
// It embeds the generated UnimplementedLMEvalJobUpdateServiceServer
// (presumably the usual generated-code pattern for satisfying the service
// interface — confirm against the generated v1beta1 package).
type DummyUpdateServer struct {
v1beta1.UnimplementedLMEvalJobUpdateServiceServer
}

// UpdateStatus acknowledges every status update unconditionally: it ignores
// the incoming context and job status and always returns an OK response with
// a nil error.
func (*DummyUpdateServer) UpdateStatus(context.Context, *v1beta1.JobStatus) (*v1beta1.Response, error) {
	ack := &v1beta1.Response{
		Code:    v1beta1.ResponseCode_OK,
		Message: "updated the job status successfully",
	}
	return ack, nil
}

func Test_Driver(t *testing.T) {
server := grpc.NewServer()
v1beta1.RegisterLMEvalJobUpdateServiceServer(server, &DummyUpdateServer{})
lis, err := net.Listen("tcp", fmt.Sprintf(":%d", 8082))
assert.Nil(t, err)
go server.Serve(lis)

opts := zap.Options{
Development: true,
Expand All @@ -44,18 +66,18 @@ func Test_Driver(t *testing.T) {
Context: context.Background(),
JobNamespace: "fms-lm-eval-service-system",
JobName: "evaljob-sample",
GrpcService: "lm-eval-grpc",
GrpcService: "localhost",
GrpcPort: 8082,
OutputPath: ".",
Logger: driverLog,
Args: []string{"--", "sh", "-ec", "echo \"tttttttttttttttttttt\""},
})

if err != nil {
t.Errorf("Create Driver failed: %v", err)
}
assert.Nil(t, err)

if err := driver.Run(); err != nil {
t.Errorf("Unexpected error: %v", err)
}
assert.Nil(t, driver.Run())

server.Stop()
assert.Nil(t, os.Remove("./stderr.log"))
assert.Nil(t, os.Remove("./stdout.log"))
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,41 @@ spec:
spec:
description: LMEvalJobSpec defines the desired state of LMEvalJob
properties:
envSecrets:
description: Assign secrets to the environment variables
items:
properties:
env:
description: Environment's name
type: string
secret:
description: The secret is from a plain text
type: string
secretRef:
description: The secret is from a secret object
properties:
key:
description: The key of the secret to select from. Must
be a valid secret key.
type: string
name:
description: |-
Name of the referent.
More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names
TODO: Add other useful fields. apiVersion, kind, uid?
type: string
optional:
description: Specify whether the Secret or its key must
be defined
type: boolean
required:
- key
type: object
x-kubernetes-map-type: atomic
required:
- env
type: object
type: array
genArgs:
description: Map to `--gen_kwargs` parameter for the underlying library.
items:
Expand Down
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ require (
github.com/onsi/ginkgo/v2 v2.17.1
github.com/onsi/gomega v1.32.0
github.com/spf13/viper v1.19.0
github.com/stretchr/testify v1.9.0
google.golang.org/grpc v1.64.0
google.golang.org/protobuf v1.34.2
k8s.io/api v0.30.0
Expand Down Expand Up @@ -50,6 +51,7 @@ require (
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/prometheus/client_golang v1.16.0 // indirect
github.com/prometheus/client_model v0.4.0 // indirect
github.com/prometheus/common v0.44.0 // indirect
Expand Down

0 comments on commit 5248cbb

Please sign in to comment.