Skip to content

Commit

Permalink
feat(test): add configurable evaluation model for test results
Browse files Browse the repository at this point in the history
- Add --evaluation-model flag to override default model
- Fetch available models from the /api/v1/models API
- Use first available model by default
- Record evaluation model in test results
  • Loading branch information
lukemarsden committed Nov 9, 2024
1 parent 733fb29 commit 1098831
Showing 1 changed file with 85 additions and 23 deletions.
108 changes: 85 additions & 23 deletions api/cmd/helix/test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,18 @@ type Content struct {
}

type TestResult struct {
TestName string `json:"test_name"`
Prompt string `json:"prompt"`
Response string `json:"response"`
Expected string `json:"expected"`
Result string `json:"result"`
Reason string `json:"reason"`
SessionID string `json:"session_id"`
Model string `json:"model"`
InferenceTime time.Duration `json:"inference_time"`
EvaluationTime time.Duration `json:"evaluation_time"`
HelixURL string `json:"helix_url"`
TestName string `json:"test_name"`
Prompt string `json:"prompt"`
Response string `json:"response"`
Expected string `json:"expected"`
Result string `json:"result"`
Reason string `json:"reason"`
SessionID string `json:"session_id"`
Model string `json:"model"`
EvaluationModel string `json:"evaluation_model"`
InferenceTime time.Duration `json:"inference_time"`
EvaluationTime time.Duration `json:"evaluation_time"`
HelixURL string `json:"helix_url"`
}

type ChatResponse struct {
Expand All @@ -66,6 +67,14 @@ type ChatResponse struct {
} `json:"choices"`
}

// ModelResponse is the JSON envelope this file decodes the model-listing
// response into: an object whose "data" array holds one entry per model.
type ModelResponse struct {
// Data holds the decoded entries of the "data" array, in response order.
Data []struct {
// ID is the value of the entry's "id" key.
ID string `json:"id"`
// Name is the value of the entry's "name" key.
Name string `json:"name"`
// Description is the value of the entry's "description" key.
Description string `json:"description"`
} `json:"data"`
}

const htmlTemplate = `
<!DOCTYPE html>
<html lang="en">
Expand Down Expand Up @@ -326,22 +335,24 @@ const htmlTemplate = `

func NewTestCmd() *cobra.Command {
var yamlFile string
var evaluationModel string

cmd := &cobra.Command{
Use: "test",
Short: "Run tests for Helix app",
Long: `This command runs tests defined in helix.yaml or a specified YAML file and evaluates the results.`,
RunE: func(cmd *cobra.Command, args []string) error {
return runTest(cmd, yamlFile)
return runTest(cmd, yamlFile, evaluationModel)
},
}

cmd.Flags().StringVarP(&yamlFile, "file", "f", "helix.yaml", "Path to the YAML file containing test definitions")
cmd.Flags().StringVar(&evaluationModel, "evaluation-model", "", "Model to use for evaluating test results")

return cmd
}

func runTest(cmd *cobra.Command, yamlFile string) error {
func runTest(cmd *cobra.Command, yamlFile string, evaluationModel string) error {
appConfig, helixYamlContent, err := readHelixYaml(yamlFile)
if err != nil {
return err
Expand All @@ -350,6 +361,8 @@ func runTest(cmd *cobra.Command, yamlFile string) error {
testID := system.GenerateTestRunID()
namespacedAppName := fmt.Sprintf("%s/%s", testID, appConfig.Name)

fmt.Printf("Using evaluation model: %s\n", evaluationModel)

// Deploy the app with the namespaced name and appConfig
appID, err := deployApp(namespacedAppName, yamlFile)
if err != nil {
Expand All @@ -374,7 +387,16 @@ func runTest(cmd *cobra.Command, yamlFile string) error {

helixURL := getHelixURL()

results, totalTime, err := runTests(appConfig, appID, apiKey, helixURL)
// Get available models if evaluation model is not specified
if evaluationModel == "" {
models, err := getAvailableModels(apiKey, helixURL)
if err != nil {
return fmt.Errorf("error getting available models: %v", err)
}
evaluationModel = models[0]
}

results, totalTime, err := runTests(appConfig, appID, apiKey, helixURL, evaluationModel)
if err != nil {
return err
}
Expand Down Expand Up @@ -435,7 +457,7 @@ func getHelixURL() string {
return helixURL
}

func runTests(appConfig types.AppHelixConfig, appID, apiKey, helixURL string) ([]TestResult, time.Duration, error) {
func runTests(appConfig types.AppHelixConfig, appID, apiKey, helixURL, evaluationModel string) ([]TestResult, time.Duration, error) {
var results []TestResult
totalStartTime := time.Now()

Expand All @@ -452,7 +474,7 @@ func runTests(appConfig types.AppHelixConfig, appID, apiKey, helixURL string) ([
semaphore <- struct{}{}
defer func() { <-semaphore }()

result, err := runSingleTest(assistantName, testName, step, appID, apiKey, helixURL, assistant.Model)
result, err := runSingleTest(assistantName, testName, step, appID, apiKey, helixURL, assistant.Model, evaluationModel)
if err != nil {
result.Reason = err.Error()
result.Result = "ERROR"
Expand Down Expand Up @@ -492,16 +514,17 @@ func runTests(appConfig types.AppHelixConfig, appID, apiKey, helixURL string) ([
return results, totalTime, nil
}

func runSingleTest(assistantName, testName string, step types.TestStep, appID, apiKey, helixURL, model string) (TestResult, error) {
func runSingleTest(assistantName, testName string, step types.TestStep, appID, apiKey, helixURL, model, evaluationModel string) (TestResult, error) {
inferenceStartTime := time.Now()

// partial result in case of error
result := TestResult{
TestName: fmt.Sprintf("%s - %s", assistantName, testName),
Prompt: step.Prompt,
Expected: step.ExpectedOutput,
Model: model,
HelixURL: helixURL,
TestName: fmt.Sprintf("%s - %s", assistantName, testName),
Prompt: step.Prompt,
Expected: step.ExpectedOutput,
Model: model,
EvaluationModel: evaluationModel,
HelixURL: helixURL,
}

chatReq := ChatRequest{
Expand All @@ -527,7 +550,7 @@ func runSingleTest(assistantName, testName string, step types.TestStep, appID, a
evaluationStartTime := time.Now()

evalReq := ChatRequest{
Model: "llama3.1:8b-instruct-q8_0",
Model: evaluationModel,
System: "You are an AI assistant tasked with evaluating test results. Output only PASS or FAIL followed by a brief explanation on the next line. Be fairly liberal about what you consider to be a PASS, as long as everything specifically requested is present. However, if the response is not as expected, you should output FAIL.",
Messages: []Message{
{
Expand Down Expand Up @@ -857,3 +880,42 @@ func openBrowser(url string) {
fmt.Printf("Error opening browser: %v\n", err)
}
}

// getAvailableModels asks the Helix API at helixURL for its model list and
// returns the model IDs in the order the server reported them.
//
// The request is a GET to helixURL+"/api/v1/models", authenticated by sending
// apiKey as a bearer token. An error is returned when the request cannot be
// built or sent, when the server answers with a non-2xx status, when the body
// is not valid JSON, or when the response lists no models. On success the
// returned slice therefore always has at least one element.
func getAvailableModels(apiKey, helixURL string) ([]string, error) {
	req, err := http.NewRequest(http.MethodGet, helixURL+"/api/v1/models", nil)
	if err != nil {
		return nil, fmt.Errorf("error creating request: %w", err)
	}

	req.Header.Set("Authorization", "Bearer "+apiKey)

	// Bound the call so an unresponsive server cannot hang the whole run.
	client := &http.Client{Timeout: 30 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("error fetching models: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("error reading response: %w", err)
	}

	// An error status (bad API key, server failure, ...) must not be parsed as
	// a model list: that would surface as a misleading "no models available"
	// or, worse, as bogus model IDs.
	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		// Keep only the start of the body so an HTML error page does not
		// flood the error message.
		const maxErrBody = 512
		snippet := body
		if len(snippet) > maxErrBody {
			snippet = snippet[:maxErrBody]
		}
		return nil, fmt.Errorf("error fetching models: unexpected status %s: %s", resp.Status, snippet)
	}

	var modelResp ModelResponse
	if err := json.Unmarshal(body, &modelResp); err != nil {
		return nil, fmt.Errorf("error parsing response JSON: %w", err)
	}

	if len(modelResp.Data) == 0 {
		return nil, fmt.Errorf("no models available")
	}

	// Extract model IDs from the response.
	models := make([]string, 0, len(modelResp.Data))
	for _, model := range modelResp.Data {
		models = append(models, model.ID)
	}

	return models, nil
}

0 comments on commit 1098831

Please sign in to comment.