diff --git a/api/cmd/helix/test.go b/api/cmd/helix/test.go index d4cec278..781544c7 100644 --- a/api/cmd/helix/test.go +++ b/api/cmd/helix/test.go @@ -44,17 +44,18 @@ type Content struct { } type TestResult struct { - TestName string `json:"test_name"` - Prompt string `json:"prompt"` - Response string `json:"response"` - Expected string `json:"expected"` - Result string `json:"result"` - Reason string `json:"reason"` - SessionID string `json:"session_id"` - Model string `json:"model"` - InferenceTime time.Duration `json:"inference_time"` - EvaluationTime time.Duration `json:"evaluation_time"` - HelixURL string `json:"helix_url"` + TestName string `json:"test_name"` + Prompt string `json:"prompt"` + Response string `json:"response"` + Expected string `json:"expected"` + Result string `json:"result"` + Reason string `json:"reason"` + SessionID string `json:"session_id"` + Model string `json:"model"` + EvaluationModel string `json:"evaluation_model"` + InferenceTime time.Duration `json:"inference_time"` + EvaluationTime time.Duration `json:"evaluation_time"` + HelixURL string `json:"helix_url"` } type ChatResponse struct { @@ -66,6 +67,14 @@ type ChatResponse struct { } `json:"choices"` } +type ModelResponse struct { + Data []struct { + ID string `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + } `json:"data"` +} + const htmlTemplate = ` @@ -326,22 +335,24 @@ const htmlTemplate = ` func NewTestCmd() *cobra.Command { var yamlFile string + var evaluationModel string cmd := &cobra.Command{ Use: "test", Short: "Run tests for Helix app", Long: `This command runs tests defined in helix.yaml or a specified YAML file and evaluates the results.`, RunE: func(cmd *cobra.Command, args []string) error { - return runTest(cmd, yamlFile) + return runTest(cmd, yamlFile, evaluationModel) }, } cmd.Flags().StringVarP(&yamlFile, "file", "f", "helix.yaml", "Path to the YAML file containing test definitions") + cmd.Flags().StringVar(&evaluationModel, "evaluation-model", "", "Model to use for evaluating test results") return cmd } -func runTest(cmd *cobra.Command, yamlFile string) error { +func runTest(cmd *cobra.Command, yamlFile string, evaluationModel string) error { appConfig, helixYamlContent, err := readHelixYaml(yamlFile) if err != nil { return err @@ -350,6 +361,8 @@ func runTest(cmd *cobra.Command, yamlFile string) error { testID := system.GenerateTestRunID() namespacedAppName := fmt.Sprintf("%s/%s", testID, appConfig.Name) + fmt.Printf("Using evaluation model: %s\n", evaluationModel) + // Deploy the app with the namespaced name and appConfig appID, err := deployApp(namespacedAppName, yamlFile) if err != nil { @@ -374,7 +387,16 @@ func runTest(cmd *cobra.Command, yamlFile string) error { helixURL := getHelixURL() - results, totalTime, err := runTests(appConfig, appID, apiKey, helixURL) + // Get available models if evaluation model is not specified + if evaluationModel == "" { + models, err := getAvailableModels(apiKey, helixURL) + if err != nil { + return fmt.Errorf("error getting available models: %v", err) + } + evaluationModel = models[0] + } + + results, totalTime, err := runTests(appConfig, appID, apiKey, helixURL, evaluationModel) if err != nil { return err } @@ -435,7 +457,7 @@ func getHelixURL() string { return helixURL } -func runTests(appConfig types.AppHelixConfig, appID, apiKey, helixURL string) ([]TestResult, time.Duration, error) { +func runTests(appConfig types.AppHelixConfig, appID, apiKey, helixURL, evaluationModel string) ([]TestResult, time.Duration, error) { var results []TestResult totalStartTime := time.Now() @@ -452,7 +474,7 @@ func runTests(appConfig types.AppHelixConfig, appID, apiKey, helixURL string) ([ semaphore <- struct{}{} defer func() { <-semaphore }() - result, err := runSingleTest(assistantName, testName, step, appID, apiKey, helixURL, assistant.Model) + result, err := runSingleTest(assistantName, testName, step, appID, apiKey, helixURL, assistant.Model, evaluationModel) if err != nil { result.Reason = err.Error() result.Result = "ERROR" @@ -492,16 +514,17 @@ func runTests(appConfig types.AppHelixConfig, appID, apiKey, helixURL string) ([ return results, totalTime, nil } -func runSingleTest(assistantName, testName string, step types.TestStep, appID, apiKey, helixURL, model string) (TestResult, error) { +func runSingleTest(assistantName, testName string, step types.TestStep, appID, apiKey, helixURL, model, evaluationModel string) (TestResult, error) { inferenceStartTime := time.Now() // partial result in case of error result := TestResult{ - TestName: fmt.Sprintf("%s - %s", assistantName, testName), - Prompt: step.Prompt, - Expected: step.ExpectedOutput, - Model: model, - HelixURL: helixURL, + TestName: fmt.Sprintf("%s - %s", assistantName, testName), + Prompt: step.Prompt, + Expected: step.ExpectedOutput, + Model: model, + EvaluationModel: evaluationModel, + HelixURL: helixURL, } chatReq := ChatRequest{ @@ -527,7 +550,7 @@ func runSingleTest(assistantName, testName string, step types.TestStep, appID, a evaluationStartTime := time.Now() evalReq := ChatRequest{ - Model: "llama3.1:8b-instruct-q8_0", + Model: evaluationModel, System: "You are an AI assistant tasked with evaluating test results. Output only PASS or FAIL followed by a brief explanation on the next line. Be fairly liberal about what you consider to be a PASS, as long as everything specifically requested is present. However, if the response is not as expected, you should output FAIL.", Messages: []Message{ { @@ -857,3 +880,42 @@ func openBrowser(url string) { fmt.Printf("Error opening browser: %v\n", err) } } + +func getAvailableModels(apiKey, helixURL string) ([]string, error) { + req, err := http.NewRequest("GET", helixURL+"/api/v1/models", nil) + if err != nil { + return nil, fmt.Errorf("error creating request: %v", err) + } + + req.Header.Set("Authorization", "Bearer "+apiKey) + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("error fetching models: %v", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("error reading response: %v", err) + } + + var modelResp ModelResponse + err = json.Unmarshal(body, &modelResp) + if err != nil { + return nil, fmt.Errorf("error parsing response JSON: %v", err) + } + + if len(modelResp.Data) == 0 { + return nil, fmt.Errorf("no models available") + } + + // Extract model IDs from the response + var models []string + for _, model := range modelResp.Data { + models = append(models, model.ID) + } + + return models, nil +}