Skip to content

Commit

Permalink
caching moderation results
Browse files Browse the repository at this point in the history
  • Loading branch information
zakiali committed Jan 3, 2025
1 parent 198f871 commit 2199bad
Show file tree
Hide file tree
Showing 2 changed files with 192 additions and 26 deletions.
214 changes: 189 additions & 25 deletions crates/goose/src/providers/base.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
use anyhow::Result;
use lazy_static::lazy_static;
use rust_decimal::Decimal;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::select;
use tokio::sync::RwLock;

use super::configs::ModelConfig;
use crate::message::{Message, MessageContent};
Expand Down Expand Up @@ -74,12 +78,49 @@ impl ModerationResult {
}
}

#[derive(Debug, Clone, Default)]
pub struct ModerationCache {
cache: Arc<RwLock<HashMap<String, ModerationResult>>>,
}

impl ModerationCache {
pub fn new() -> Self {
Self {
cache: Arc::new(RwLock::new(HashMap::new())),
}
}

pub async fn get(&self, content: &str) -> Option<ModerationResult> {
let cache = self.cache.read().await;
cache.get(content).cloned()
}

pub async fn set(&self, content: String, result: ModerationResult) {
let mut cache = self.cache.write().await;
cache.insert(content, result);
}
}

lazy_static! {
static ref DEFAULT_CACHE: ModerationCache = ModerationCache::new();
}

use async_trait::async_trait;
use serde_json::Value;

/// Trait for handling content moderation
#[async_trait]
pub trait Moderation: Send + Sync {
/// Get the moderation cache
fn moderation_cache(&self) -> &ModerationCache {
&DEFAULT_CACHE
}

/// Internal moderation method to be implemented by providers
async fn moderate_content_internal(&self, _content: &str) -> Result<ModerationResult> {
Ok(ModerationResult::new(false, None, None))
}

/// Moderate the given content
///
/// # Arguments
Expand All @@ -88,7 +129,20 @@ pub trait Moderation: Send + Sync {
/// # Returns
/// A ModerationResult containing the moderation decision and details
async fn moderate_content(&self, content: &str) -> Result<ModerationResult> {
Ok(ModerationResult::new(false, None, None))
// Check cache first
if let Some(cached) = self.moderation_cache().get(content).await {
return Ok(cached);
}

// If not in cache, do moderation
let result = self.moderate_content_internal(content).await?;

// Cache the result
self.moderation_cache()
.set(content.to_string(), result.clone())
.await;

Ok(result)
}
}

Expand Down Expand Up @@ -128,36 +182,37 @@ pub trait Provider: Send + Sync + Moderation {

// Get the content to moderate
let content = latest_user_msg.content.first().unwrap().as_text().unwrap();
println!("Content to moderate: {}", content);

// Create futures for both operations
let moderation_fut = self.moderate_content(content);
// Start completion and moderation immediately
let completion_fut = self.complete_internal(system, messages, tools);

// Pin the futures
tokio::pin!(moderation_fut);
let moderation_fut = self.moderate_content(content);
tokio::pin!(completion_fut);
tokio::pin!(moderation_fut);

// Use select! to run both concurrently
let result = select! {
// Run moderation and completion concurrently
select! {
moderation = &mut moderation_fut => {
// If moderation completes first, check the result
let moderation_result = moderation?;
if moderation_result.flagged {
let categories = moderation_result.categories
let result = moderation?;

if result.flagged {
let categories = result.categories
.unwrap_or_else(|| vec!["unknown".to_string()])
.join(", ");
return Err(anyhow::anyhow!(
"Content was flagged for moderation in categories: {}",
categories
));
}
// If moderation passes, wait for completion

// Moderation passed, wait for completion
Ok(completion_fut.await?)
}
completion = &mut completion_fut => {
// If completion finishes first, still check moderation
// Completion finished first, still need to check moderation
let completion_result = completion?;
let moderation_result = moderation_fut.await?;

if moderation_result.flagged {
let categories = moderation_result.categories
.unwrap_or_else(|| vec!["unknown".to_string()])
Expand All @@ -167,11 +222,10 @@ pub trait Provider: Send + Sync + Moderation {
categories
));
}

Ok(completion_result)
}
};

result
}
}

/// Internal completion method to be implemented by providers
Expand Down Expand Up @@ -240,7 +294,7 @@ mod tests {

#[async_trait]
impl Moderation for TestProvider {
async fn moderate_content(&self, _content: &str) -> Result<ModerationResult> {
async fn moderate_content_internal(&self, _content: &str) -> Result<ModerationResult> {
// Return quickly with flagged content
Ok(ModerationResult::new(
true,
Expand Down Expand Up @@ -298,7 +352,7 @@ mod tests {

#[async_trait]
impl Moderation for TestProvider {
async fn moderate_content(&self, _content: &str) -> Result<ModerationResult> {
async fn moderate_content_internal(&self, _content: &str) -> Result<ModerationResult> {
sleep(Duration::from_secs(1)).await;
// Return quickly with flagged content
Ok(ModerationResult::new(
Expand Down Expand Up @@ -362,13 +416,27 @@ mod tests {

#[tokio::test]
async fn test_moderation_pass_completion_pass() {
// Create a dedicated cache for this test
let cache = Arc::new(ModerationCache::new());

#[derive(Clone)]
struct TestProvider;
struct TestProvider {
cache: Arc<ModerationCache>,
}

impl TestProvider {
fn new(cache: Arc<ModerationCache>) -> Self {
Self { cache }
}
}

#[async_trait]
impl Moderation for TestProvider {
async fn moderate_content(&self, _content: &str) -> Result<ModerationResult> {
// Return quickly with flagged content
fn moderation_cache(&self) -> &ModerationCache {
&self.cache
}

async fn moderate_content_internal(&self, _content: &str) -> Result<ModerationResult> {
Ok(ModerationResult::new(false, None, None))
}
}
Expand Down Expand Up @@ -404,7 +472,7 @@ mod tests {
}
}

let provider = TestProvider;
let provider = TestProvider::new(cache);
let test_message = Message {
role: Role::User,
created: chrono::Utc::now().timestamp(),
Expand All @@ -415,8 +483,8 @@ mod tests {
};

let result = provider.complete("system", &[test_message], &[]).await;
assert!(result.is_ok(), "Expected Ok result, got {:?}", result);

assert!(result.is_ok());
let (message, usage) = result.unwrap();
assert_eq!(message.content[0].as_text().unwrap(), "test response");
assert_eq!(usage.model, "test-model");
Expand All @@ -429,7 +497,7 @@ mod tests {

#[async_trait]
impl Moderation for TestProvider {
async fn moderate_content(&self, _content: &str) -> Result<ModerationResult> {
async fn moderate_content_internal(&self, _content: &str) -> Result<ModerationResult> {
// Simulate some processing time
sleep(Duration::from_millis(100)).await;
Ok(ModerationResult::new(false, None, None))
Expand Down Expand Up @@ -484,4 +552,100 @@ mod tests {
assert_eq!(message.content[0].as_text().unwrap(), "test response");
assert_eq!(usage.model, "test-model");
}

#[tokio::test]
async fn test_moderation_cache() {
// Create a local cache for this test
let cache = Arc::new(ModerationCache::new());

#[derive(Clone)]
struct TestProvider {
moderation_count: Arc<RwLock<i32>>,
cache: Arc<ModerationCache>,
}

impl TestProvider {
fn new(cache: Arc<ModerationCache>, count: Arc<RwLock<i32>>) -> Self {
Self {
moderation_count: count,
cache,
}
}
}

#[async_trait]
impl Moderation for TestProvider {
fn moderation_cache(&self) -> &ModerationCache {
&self.cache
}

async fn moderate_content_internal(&self, _content: &str) -> Result<ModerationResult> {
// Increment the moderation count
let mut count = self.moderation_count.write().await;
*count += 1;

Ok(ModerationResult::new(false, None, None))
}
}

#[async_trait]
impl Provider for TestProvider {
fn get_usage(&self, _data: &Value) -> Result<Usage> {
Ok(Usage::new(Some(1), Some(1), Some(2)))
}

fn get_model_config(&self) -> &ModelConfig {
panic!("Should not be called");
}

async fn complete_internal(
&self,
_system: &str,
_messages: &[Message],
_tools: &[Tool],
) -> Result<(Message, ProviderUsage)> {
Ok((
Message {
role: Role::Assistant,
created: chrono::Utc::now().timestamp(),
content: vec![MessageContent::text("test response")],
},
ProviderUsage::new(
"test-model".to_string(),
Usage::new(Some(1), Some(1), Some(2)),
None,
),
))
}
}

let count = Arc::new(RwLock::new(0));
let provider = TestProvider::new(cache.clone(), count.clone());
let test_message = Message {
role: Role::User,
created: chrono::Utc::now().timestamp(),
content: vec![MessageContent::Text(TextContent {
text: "test".to_string(),
annotations: None,
})],
};

// First call should trigger moderation
let result = provider
.complete("system", &[test_message.clone()], &[])
.await;
assert!(result.is_ok(), "First call failed: {:?}", result);

// Second call with same message should use cache
let result = provider.complete("system", &[test_message], &[]).await;
assert!(result.is_ok(), "Second call failed: {:?}", result);

// Check that moderation was only called once
let count = count.read().await;
assert_eq!(
*count, 1,
"Expected moderation to be called once, got {}",
*count
);
}
}
4 changes: 3 additions & 1 deletion crates/goose/src/providers/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ impl Provider for OpenAiProvider {
messages: &[Message],
tools: &[Tool],
) -> Result<(Message, ProviderUsage)> {

// Not checking for o1 model here since system message is not supported by o1
let payload = create_openai_request_payload(&self.config.model, system, messages, tools)?;

Expand All @@ -98,6 +99,7 @@ impl Provider for OpenAiProvider {
let model = get_model(&response);
let cost = cost(&usage, &model_pricing_for(&model));


Ok((message, ProviderUsage::new(model, usage, cost)))
}

Expand All @@ -108,7 +110,7 @@ impl Provider for OpenAiProvider {

#[async_trait]
impl Moderation for OpenAiProvider {
async fn moderate_content(&self, content: &str) -> Result<ModerationResult> {
async fn moderate_content_internal(&self, content: &str) -> Result<ModerationResult> {
let url = format!("{}/v1/moderations", self.config.host.trim_end_matches('/'));

let request =
Expand Down

0 comments on commit 2199bad

Please sign in to comment.