diff --git a/src/models.rs b/src/models.rs
index 43fbe14f..ed61d1f5 100644
--- a/src/models.rs
+++ b/src/models.rs
@@ -77,6 +77,7 @@ impl std::ops::DerefMut for DetectorParams {
 
 /// User request to orchestrator
 #[derive(Clone, Debug, Serialize, Deserialize)]
+#[serde(deny_unknown_fields)]
 pub struct GuardrailsHttpRequest {
     /// Text generation model ID
     pub model_id: String,
@@ -343,6 +344,7 @@ pub struct ClassifiedGeneratedTextResult {
 
 /// The request format expected in the /api/v2/text/detection/content endpoint.
 #[derive(Clone, Debug, Serialize, Deserialize)]
+#[serde(deny_unknown_fields)]
 pub struct TextContentDetectionHttpRequest {
     /// The content to run detectors on
     pub content: String,
@@ -836,6 +838,7 @@ impl From for ClassifiedGenerat
 
 /// The request format expected in the /api/v2/text/generation-detection endpoint.
 #[derive(Clone, Debug, Serialize, Deserialize)]
+#[serde(deny_unknown_fields)]
 pub struct GenerationWithDetectionHttpRequest {
     /// The model_id of the LLM to be invoked.
     pub model_id: String,
@@ -917,6 +920,7 @@ pub enum GuardrailDetection {
 
 /// The request format expected in the /api/v2/text/context endpoint.
 #[derive(Clone, Debug, Serialize, Deserialize)]
+#[serde(deny_unknown_fields)]
 pub struct ContextDocsHttpRequest {
     /// The map of detectors to be used, along with their respective parameters, e.g. thresholds.
     pub detectors: HashMap<String, DetectorParams>,
@@ -958,8 +962,9 @@ pub struct ContextDocsResult {
     pub detections: Vec,
 }
 
-/// The request format expected in the /api/v2/text/detect/generated endpoint.
+/// The request format expected in the /api/v2/text/detect/chat endpoint.
 #[derive(Clone, Debug, Serialize, Deserialize)]
+#[serde(deny_unknown_fields)]
 pub struct ChatDetectionHttpRequest {
     /// The map of detectors to be used, along with their respective parameters, e.g. thresholds.
     pub detectors: HashMap<String, DetectorParams>,
@@ -1033,6 +1038,7 @@ pub struct ChatDetectionResult {
 
 /// The request format expected in the /api/v2/text/detect/generated endpoint.
 #[derive(Clone, Debug, Serialize, Deserialize)]
+#[serde(deny_unknown_fields)]
 pub struct DetectionOnGeneratedHttpRequest {
     /// The prompt to be sent to the LLM.
     pub prompt: String,
@@ -1120,6 +1126,7 @@ pub struct EvidenceObj {
 
 /// Stream content detection stream request
 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
+#[serde(deny_unknown_fields)]
 #[cfg_attr(test, derive(Default))]
 pub struct StreamingContentDetectionRequest {
     pub detectors: Option<HashMap<String, DetectorParams>>,
@@ -1206,6 +1213,30 @@ mod tests {
         let error = result.unwrap_err().to_string();
         assert!(error.contains("`model_id` is required"));
 
+        // Unknown field (`guardrails_config`, a typo of the correct `guardrail_config`)
+        let json_data = r#"
+        {
+            "inputs": "The cow jumped over the moon.",
+            "model_id": "model-id",
+            "guardrails_config": {
+                "input": {
+                    "models": {
+                        "hap_model": {}
+                    }
+                },
+                "output": {
+                    "models": {
+                        "hap_model": {}
+                    }
+                }
+            }
+        }
+        "#;
+        let result: Result<GuardrailsHttpRequest, _> = serde_json::from_str(json_data);
+        assert!(result.is_err());
+        let error = result.unwrap_err().to_string();
+        assert!(error.contains("unknown field `guardrails_config`"));
+
         // No inputs
         let request = GuardrailsHttpRequest {
             model_id: "model".to_string(),
diff --git a/src/orchestrator/chat_completions_detection.rs b/src/orchestrator/chat_completions_detection.rs
index 55f1514d..1b40c9eb 100644
--- a/src/orchestrator/chat_completions_detection.rs
+++ b/src/orchestrator/chat_completions_detection.rs
@@ -14,20 +14,25 @@
 limitations under the License.
 
 */
+use std::{
+    collections::HashMap,
+    sync::Arc,
+    time::{SystemTime, UNIX_EPOCH},
+};
+
 use axum::http::HeaderMap;
 use futures::future::{join_all, try_join_all};
-use std::time::{SystemTime, UNIX_EPOCH};
-use std::{collections::HashMap, sync::Arc};
+use serde::{Deserialize, Serialize};
 use tracing::{debug, info, instrument};
+use uuid::Uuid;
 
 use super::{ChatCompletionsDetectionTask, Context, Error, Orchestrator};
-use crate::clients::openai::OrchestratorWarning;
 use crate::{
     clients::{
         detector::{ChatDetectionRequest, ContentAnalysisRequest},
         openai::{
             ChatCompletion, ChatCompletionChoice, ChatCompletionsRequest, ChatCompletionsResponse,
-            ChatDetections, Content, InputDetectionResult, OpenAiClient,
+            ChatDetections, Content, InputDetectionResult, OpenAiClient, OrchestratorWarning,
         },
     },
     config::DetectorType,
@@ -38,8 +43,6 @@ use crate::{
         Chunk, UNSUITABLE_INPUT_MESSAGE,
     },
 };
-use serde::{Deserialize, Serialize};
-use uuid::Uuid;
 
 /// Internal structure to capture chat messages (both request and response)
 /// and prepare it for processing
@@ -386,8 +389,10 @@ mod tests {
     use std::any::{Any, TypeId};
 
     use super::*;
-    use crate::config::DetectorConfig;
-    use crate::orchestrator::{ClientMap, OrchestratorConfig};
+    use crate::{
+        config::DetectorConfig,
+        orchestrator::{ClientMap, OrchestratorConfig},
+    };
 
     // Test to verify preprocess_chat_messages works correctly for multiple content type detectors
     // with single message in chat request
diff --git a/src/orchestrator/detector_processing/content.rs b/src/orchestrator/detector_processing/content.rs
index fea862c6..96cca597 100644
--- a/src/orchestrator/detector_processing/content.rs
+++ b/src/orchestrator/detector_processing/content.rs
@@ -55,7 +55,6 @@ pub fn filter_chat_messages(
 #[cfg(test)]
 mod tests {
     use super::*;
-
     use crate::orchestrator::chat_completions_detection::ChatMessageInternal;
 
     #[tokio::test]
diff --git a/src/server.rs b/src/server.rs
index 1b93421e..28d9ad93 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -489,7 +489,10 @@ async fn stream_content_detection(
 async fn detection_content(
     State(state): State<Arc<ServerState>>,
     headers: HeaderMap,
-    Json(request): Json<TextContentDetectionHttpRequest>,
+    WithRejection(Json(request), _): WithRejection<
+        Json<TextContentDetectionHttpRequest>,
+        Error,
+    >,
 ) -> Result<impl IntoResponse, Error> {
     let trace_id = Span::current().context().span().span_context().trace_id();
     info!(?trace_id, "handling request");
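
Note: the `#[serde(deny_unknown_fields)]` attribute added throughout models.rs makes serde reject any JSON field that does not map to a struct member, instead of silently ignoring it. A minimal standalone sketch of that behavior, using a made-up ExampleRequest struct rather than the orchestrator's real types:

use serde::Deserialize;

// Hypothetical request type; the attribute behaves the same way on the
// orchestrator's request structs in the diff above.
#[derive(Debug, Deserialize)]
#[serde(deny_unknown_fields)]
struct ExampleRequest {
    model_id: String,
}

fn main() {
    // A typo'd field name ("modle_id") now fails deserialization
    // outright instead of being silently dropped.
    let err = serde_json::from_str::<ExampleRequest>(
        r#"{ "model_id": "m", "modle_id": "oops" }"#,
    )
    .unwrap_err();
    assert!(err.to_string().contains("unknown field `modle_id`"));
    println!("rejected as expected: {err}");
}

The trade-off is a stricter API: extra fields that were previously ignored now fail the request, which is exactly what the new `guardrails_config` typo test pins down.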
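Note: the server.rs change swaps the bare `Json` extractor for axum_extra's `WithRejection`, so a failed JSON deserialization is converted into the orchestrator's own `Error` type rather than axum's default plain-text rejection. A rough sketch of the pattern under that assumption; the `Error` type, its `From<JsonRejection>` impl, and the handler below are stand-ins, since the real definitions are not shown in this diff:

use axum::{
    extract::rejection::JsonRejection,
    http::StatusCode,
    response::{IntoResponse, Response},
    Json,
};
use axum_extra::extract::WithRejection;
use serde::Deserialize;

// Stand-in for the orchestrator's Error type (not shown in this diff).
#[derive(Debug)]
struct Error(String);

// WithRejection<Json<T>, Error> requires Error: From<JsonRejection>.
// When extraction fails (e.g. an unknown field tripping
// deny_unknown_fields), the rejection is converted into Error instead
// of surfacing axum's default response.
impl From<JsonRejection> for Error {
    fn from(rejection: JsonRejection) -> Self {
        Error(rejection.body_text())
    }
}

impl IntoResponse for Error {
    fn into_response(self) -> Response {
        (StatusCode::UNPROCESSABLE_ENTITY, self.0).into_response()
    }
}

#[derive(Deserialize)]
struct ExampleRequest {
    content: String,
}

// Same shape as the detection_content handler in the diff above.
async fn example_handler(
    WithRejection(Json(request), _): WithRejection<Json<ExampleRequest>, Error>,
) -> Result<String, Error> {
    Ok(request.content)
}

Combined with deny_unknown_fields, this is what turns a client-side typo into a structured error response from the orchestrator instead of a silently mis-parsed request.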