Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add error handling for unknown fields in orchestrator request #285

Merged
merged 3 commits into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion src/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ impl std::ops::DerefMut for DetectorParams {

/// User request to orchestrator
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct GuardrailsHttpRequest {
/// Text generation model ID
pub model_id: String,
Expand Down Expand Up @@ -343,6 +344,7 @@ pub struct ClassifiedGeneratedTextResult {

/// The request format expected in the /api/v2/text/detection/content endpoint.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct TextContentDetectionHttpRequest {
/// The content to run detectors on
pub content: String,
Expand Down Expand Up @@ -836,6 +838,7 @@ impl From<pb::caikit_data_model::nlp::GeneratedTextResult> for ClassifiedGenerat

/// The request format expected in the /api/v2/text/generation-detection endpoint.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct GenerationWithDetectionHttpRequest {
/// The model_id of the LLM to be invoked.
pub model_id: String,
Expand Down Expand Up @@ -917,6 +920,7 @@ pub enum GuardrailDetection {

/// The request format expected in the /api/v2/text/context endpoint.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ContextDocsHttpRequest {
/// The map of detectors to be used, along with their respective parameters, e.g. thresholds.
pub detectors: HashMap<String, DetectorParams>,
Expand Down Expand Up @@ -958,8 +962,9 @@ pub struct ContextDocsResult {
pub detections: Vec<DetectionResult>,
}

/// The request format expected in the /api/v2/text/detect/generated endpoint.
/// The request format expected in the /api/v2/text/detect/chat endpoint.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ChatDetectionHttpRequest {
/// The map of detectors to be used, along with their respective parameters, e.g. thresholds.
pub detectors: HashMap<String, DetectorParams>,
Expand Down Expand Up @@ -1033,6 +1038,7 @@ pub struct ChatDetectionResult {

/// The request format expected in the /api/v2/text/detect/generated endpoint.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct DetectionOnGeneratedHttpRequest {
/// The prompt to be sent to the LLM.
pub prompt: String,
Expand Down Expand Up @@ -1120,6 +1126,7 @@ pub struct EvidenceObj {

/// Stream content detection stream request
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
#[cfg_attr(test, derive(Default))]
pub struct StreamingContentDetectionRequest {
pub detectors: Option<HashMap<String, DetectorParams>>,
Expand Down Expand Up @@ -1206,6 +1213,32 @@ mod tests {
let error = result.unwrap_err().to_string();
assert!(error.contains("`model_id` is required"));

// Additional unknown field (guardrails_config, a typo of the correct field name "guardrail_config")
let json_data = r#"
{
"inputs": "The cow jumped over the moon.",
"model_id": "model-id",
"guardrails_config": {
"input": {
"models": {
"hap_model": {}
}
},
"output": {
"models": {
"hap_model": {}
}
}
}
}
"#;
let result: Result<GuardrailsHttpRequest, _> = serde_json::from_str(json_data);
assert!(result.is_err());
let error = result.unwrap_err().to_string();
assert!(error
.to_string()
.contains("unknown field `guardrails_config`"));

// No inputs
let request = GuardrailsHttpRequest {
model_id: "model".to_string(),
Expand Down
21 changes: 13 additions & 8 deletions src/orchestrator/chat_completions_detection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,25 @@
limitations under the License.

*/
use std::{
collections::HashMap,
sync::Arc,
time::{SystemTime, UNIX_EPOCH},
};

use axum::http::HeaderMap;
use futures::future::{join_all, try_join_all};
use std::time::{SystemTime, UNIX_EPOCH};
use std::{collections::HashMap, sync::Arc};
use serde::{Deserialize, Serialize};
use tracing::{debug, info, instrument};
use uuid::Uuid;

use super::{ChatCompletionsDetectionTask, Context, Error, Orchestrator};
use crate::clients::openai::OrchestratorWarning;
use crate::{
clients::{
detector::{ChatDetectionRequest, ContentAnalysisRequest},
openai::{
ChatCompletion, ChatCompletionChoice, ChatCompletionsRequest, ChatCompletionsResponse,
ChatDetections, Content, InputDetectionResult, OpenAiClient,
ChatDetections, Content, InputDetectionResult, OpenAiClient, OrchestratorWarning,
},
},
config::DetectorType,
Expand All @@ -38,8 +43,6 @@ use crate::{
Chunk, UNSUITABLE_INPUT_MESSAGE,
},
};
use serde::{Deserialize, Serialize};
use uuid::Uuid;

/// Internal structure to capture chat messages (both request and response)
/// and prepare it for processing
Expand Down Expand Up @@ -386,8 +389,10 @@ mod tests {
use std::any::{Any, TypeId};

use super::*;
use crate::config::DetectorConfig;
use crate::orchestrator::{ClientMap, OrchestratorConfig};
use crate::{
config::DetectorConfig,
orchestrator::{ClientMap, OrchestratorConfig},
};

// Test to verify preprocess_chat_messages works correctly for multiple content type detectors
// with single message in chat request
Expand Down
1 change: 0 additions & 1 deletion src/orchestrator/detector_processing/content.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ pub fn filter_chat_messages(
#[cfg(test)]
mod tests {
use super::*;

use crate::orchestrator::chat_completions_detection::ChatMessageInternal;

#[tokio::test]
Expand Down
5 changes: 4 additions & 1 deletion src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,10 @@ async fn stream_content_detection(
async fn detection_content(
State(state): State<Arc<ServerState>>,
headers: HeaderMap,
Json(request): Json<models::TextContentDetectionHttpRequest>,
WithRejection(Json(request), _): WithRejection<
Json<models::TextContentDetectionHttpRequest>,
Error,
>,
) -> Result<impl IntoResponse, Error> {
let trace_id = Span::current().context().span().span_context().trace_id();
info!(?trace_id, "handling request");
Expand Down
Loading