Skip to content

Commit

Permalink
fix multiple returns from moderations
Browse files Browse the repository at this point in the history
  • Loading branch information
conneroisu committed Sep 6, 2024
1 parent 1a26ef0 commit c812499
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 7 deletions.
14 changes: 9 additions & 5 deletions moderation.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,8 @@ type ModerationRequest struct {

// Moderation represents one of possible moderation results.
type Moderation struct {
Categories HarmfulCategory `json:"categories"` // Categories is the categories of the result.
Flagged bool `json:"flagged"` // Flagged is the flagged of the result.
Categories []HarmfulCategory `json:"categories"` // Categories is the categories of the result.
Flagged bool `json:"flagged"` // Flagged is the flagged of the result.
}

// Moderate — perform a moderation api call over a string.
Expand All @@ -157,10 +157,14 @@ func (c *Client) Moderate(
if err != nil {
return
}
if strings.Contains(resp.Choices[0].Message.Content, "unsafe") {
split := strings.Split(resp.Choices[0].Message.Content, "\n")[1]
response.Categories = SectionMap[strings.TrimSpace(split)]
content := resp.Choices[0].Message.Content
println(content)
if strings.Contains(content, "unsafe") {
response.Flagged = true
split := strings.Split(strings.Split(content, "\n")[1], ",")
for _, s := range split {
response.Categories = append(response.Categories, SectionMap[strings.TrimSpace(s)])
}
}
return
}
4 changes: 2 additions & 2 deletions moderation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func TestModerate(t *testing.T) {
a := assert.New(t)
a.NoError(err, "Moderation error")
a.Equal(true, mod.Flagged)
a.Equal(mod.Categories, groq.CategoryViolentCrimes)
a.Equal(mod.Categories, []groq.HarmfulCategory{groq.CategoryViolentCrimes, groq.CategoryNonviolentCrimes})
}

func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) {
Expand All @@ -35,7 +35,7 @@ func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) {
{
Message: groq.ChatCompletionMessage{
Role: groq.ChatMessageRoleAssistant,
Content: "unsafe\nS1",
Content: "unsafe\nS1,S2",
},
FinishReason: "stop",
},
Expand Down

0 comments on commit c812499

Please sign in to comment.