Skip to content

Commit

Permalink
Add Moderation trait to OpenRouter
Browse files Browse the repository at this point in the history
  • Loading branch information
zakiali committed Jan 10, 2025
1 parent ba24be7 commit 7ded65b
Show file tree
Hide file tree
Showing 9 changed files with 23 additions and 13 deletions.
6 changes: 3 additions & 3 deletions crates/goose/src/providers/anthropic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ impl Provider for AnthropicProvider {

#[async_trait]
impl Moderation for AnthropicProvider {
async fn moderate_content(&self, content: &str) -> Result<ModerationResult> {
async fn moderate_content(&self, _content: &str) -> Result<ModerationResult> {
Ok(ModerationResult::new(false, None, None))
}
}
Expand Down Expand Up @@ -346,7 +346,7 @@ mod tests {
let messages = vec![Message::user().with_text("Hello?")];

let (message, usage) = provider
.complete("You are a helpful assistant.", &messages, &[])
.complete_internal("You are a helpful assistant.", &messages, &[])
.await?;

if let MessageContent::Text(text) = &message.content[0] {
Expand Down Expand Up @@ -405,7 +405,7 @@ mod tests {
);

let (message, usage) = provider
.complete("You are a helpful assistant.", &messages, &[tool])
.complete_internal("You are a helpful assistant.", &messages, &[tool])
.await?;

if let MessageContent::ToolRequest(tool_request) = &message.content[0] {
Expand Down
2 changes: 1 addition & 1 deletion crates/goose/src/providers/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ use tokio::sync::RwLock;

use super::configs::ModelConfig;
use crate::message::{Message, MessageContent};
use mcp_core::content::TextContent;
use mcp_core::role::Role;
use mcp_core::tool::Tool;

Expand Down Expand Up @@ -241,6 +240,7 @@ pub trait Provider: Send + Sync + Moderation {
#[cfg(test)]
mod tests {
use super::*;
use mcp_core::content::TextContent;
use serde_json::json;
use std::time::Duration;
use tokio::time::sleep;
Expand Down
2 changes: 1 addition & 1 deletion crates/goose/src/providers/databricks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ impl Provider for DatabricksProvider {

#[async_trait]
impl Moderation for DatabricksProvider {
async fn moderate_content(&self, content: &str) -> Result<ModerationResult> {
async fn moderate_content(&self, _content: &str) -> Result<ModerationResult> {
Ok(ModerationResult::new(false, None, None))
}
}
Expand Down
6 changes: 3 additions & 3 deletions crates/goose/src/providers/google.rs
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ impl Provider for GoogleProvider {

#[async_trait]
impl Moderation for GoogleProvider {
async fn moderate_content(&self, content: &str) -> Result<ModerationResult> {
async fn moderate_content(&self, _content: &str) -> Result<ModerationResult> {
Ok(ModerationResult::new(false, None, None))
}
}
Expand Down Expand Up @@ -659,7 +659,7 @@ mod tests {

// Call the complete method
let (message, usage) = provider
.complete("You are a helpful assistant.", &messages, &[])
.complete_internal("You are a helpful assistant.", &messages, &[])
.await?;

// Assert the response
Expand Down Expand Up @@ -690,7 +690,7 @@ mod tests {

// Call the complete method
let (message, usage) = provider
.complete(
.complete_internal(
"You are a helpful assistant.",
&messages,
&[create_test_tool()],
Expand Down
2 changes: 1 addition & 1 deletion crates/goose/src/providers/groq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ impl Provider for GroqProvider {

#[async_trait]
impl Moderation for GroqProvider {
async fn moderate_content(&self, content: &str) -> Result<ModerationResult> {
async fn moderate_content(&self, _content: &str) -> Result<ModerationResult> {
Ok(ModerationResult::new(false, None, None))
}
}
Expand Down
2 changes: 1 addition & 1 deletion crates/goose/src/providers/mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ impl Provider for MockProvider {

#[async_trait]
impl Moderation for MockProvider {
async fn moderate_content(&self, content: &str) -> Result<ModerationResult> {
async fn moderate_content(&self, _content: &str) -> Result<ModerationResult> {
Ok(ModerationResult::new(false, None, None))
}
}
2 changes: 1 addition & 1 deletion crates/goose/src/providers/ollama.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ impl Provider for OllamaProvider {

#[async_trait]
impl Moderation for OllamaProvider {
async fn moderate_content(&self, content: &str) -> Result<ModerationResult> {
async fn moderate_content(&self, _content: &str) -> Result<ModerationResult> {
Ok(ModerationResult::new(false, None, None))
}
}
Expand Down
4 changes: 3 additions & 1 deletion crates/goose/src/providers/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,9 @@ impl Moderation for OpenAiProvider {
.send()
.await?;

let response_json: serde_json::Value = response.json().await?;
let response_json = handle_response(serde_json::to_value(&request)?, response)
.await?
.unwrap();

let flagged = response_json["results"][0]["flagged"]
.as_bool()
Expand Down
10 changes: 9 additions & 1 deletion crates/goose/src/providers/openrouter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use serde_json::Value;
use std::time::Duration;

use super::base::ProviderUsage;
use super::base::{Moderation, ModerationResult};
use super::base::{Provider, Usage};
use super::configs::OpenAiProviderConfig;
use super::configs::{ModelConfig, ProviderModelConfig};
Expand Down Expand Up @@ -73,7 +74,7 @@ impl Provider for OpenRouterProvider {
cost
)
)]
async fn complete(
async fn complete_internal(
&self,
system: &str,
messages: &[Message],
Expand Down Expand Up @@ -112,6 +113,13 @@ impl Provider for OpenRouterProvider {
}
}

#[async_trait]
impl Moderation for OpenRouterProvider {
async fn moderate_content(&self, _content: &str) -> Result<ModerationResult> {
Ok(ModerationResult::new(false, None, None))
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down

0 comments on commit 7ded65b

Please sign in to comment.