Skip to content

Commit

Permalink
add encryption logic to chat completions streaming
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgeantonio21 committed Dec 6, 2024
1 parent a9d2977 commit aab4f80
Show file tree
Hide file tree
Showing 6 changed files with 185 additions and 11 deletions.
2 changes: 2 additions & 0 deletions atoma-service/src/handlers/chat_completions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,9 @@ async fn handle_streaming_response(
state.keystore.clone(),
state.address_index,
model.to_string(),
client_encryption_metadata,
timer,
state.encryption_sender.clone(),
))
.keep_alive(
axum::response::sse::KeepAlive::new()
Expand Down
4 changes: 2 additions & 2 deletions atoma-service/src/handlers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ pub(crate) async fn handle_confidential_compute_encryption_response(
client_encryption_metadata: Option<EncryptionMetadata>,
) -> Result<(), StatusCode> {
if let Some(EncryptionMetadata {
client_dh_public_key,
proxy_x25519_public_key,
salt,
}) = client_encryption_metadata
{
Expand All @@ -100,7 +100,7 @@ pub(crate) async fn handle_confidential_compute_encryption_response(
ConfidentialComputeEncryptionRequest {
plaintext: response_body.to_string().as_bytes().to_vec(),
salt,
diffie_hellman_public_key: client_dh_public_key,
proxy_x25519_public_key,
},
sender,
))
Expand Down
13 changes: 7 additions & 6 deletions atoma-service/src/middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ const IMAGE_N: &str = "n";
/// Metadata for confidential compute encryption requests
#[derive(Clone, Debug)]
pub struct EncryptionMetadata {
/// The client's Diffie-Hellman public key
pub client_dh_public_key: [u8; DH_PUBLIC_KEY_SIZE],
/// The client's proxy X25519 public key
pub proxy_x25519_public_key: [u8; DH_PUBLIC_KEY_SIZE],
/// The salt
pub salt: Vec<u8>,
}
Expand Down Expand Up @@ -143,11 +143,11 @@ impl RequestMetadata {
/// ```
pub fn with_client_encryption_metadata(
mut self,
client_dh_public_key: [u8; DH_PUBLIC_KEY_SIZE],
proxy_x25519_public_key: [u8; DH_PUBLIC_KEY_SIZE],
salt: Vec<u8>,
) -> Self {
self.client_encryption_metadata = Some(EncryptionMetadata {
client_dh_public_key,
proxy_x25519_public_key,
salt,
});
self
Expand All @@ -170,6 +170,7 @@ impl RequestMetadata {
self.endpoint_path = endpoint_path;
self
}
}

/// Middleware for verifying the signature of incoming requests.
///
Expand Down Expand Up @@ -621,7 +622,7 @@ pub async fn confidential_compute_middleware(
let confidential_compute_decryption_request = ConfidentialComputeDecryptionRequest {
ciphertext: cyphertext_bytes,
nonce: nonce_bytes,
salt: salt_bytes,
salt: salt_bytes.clone(),
proxy_x25519_public_key: proxy_x25519_public_key_bytes,
node_x25519_public_key: node_x25519_public_key_bytes,
};
Expand All @@ -641,7 +642,7 @@ pub async fn confidential_compute_middleware(
let body = Body::from(plaintext);
req_parts.extensions.insert(
RequestMetadata::default()
.with_client_encryption_metadata(diffie_hellman_public_key_bytes, salt_bytes),
.with_client_encryption_metadata(proxy_x25519_public_key_bytes, salt_bytes),
);
let req = Request::from_parts(req_parts, body);
Ok(next.run(req).await)
Expand Down
2 changes: 1 addition & 1 deletion atoma-service/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ type DecryptionRequest = (
oneshot::Sender<ConfidentialComputeDecryptionResponse>,
);

type EncryptionRequest = (
pub(crate) type EncryptionRequest = (
ConfidentialComputeEncryptionRequest,
oneshot::Sender<ConfidentialComputeEncryptionResponse>,
);
Expand Down
174 changes: 172 additions & 2 deletions atoma-service/src/streamer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ use std::{
task::{Context, Poll},
};

use atoma_confidential::types::{
ConfidentialComputeEncryptionRequest, ConfidentialComputeEncryptionResponse,
};
use atoma_state::types::AtomaAtomaStateManagerEvent;
use atoma_utils::hashing::blake2b_hash;
use axum::body::Bytes;
Expand All @@ -13,14 +16,20 @@ use futures::Stream;
use prometheus::HistogramTimer;
use serde_json::{json, Value};
use sui_keys::keystore::FileBasedKeystore;
use tokio::sync::{
mpsc::UnboundedSender,
oneshot::{self, error::TryRecvError},
};
use tracing::{error, instrument};

use crate::{
handlers::prometheus::{
CHAT_COMPLETIONS_DECODING_TIME, CHAT_COMPLETIONS_INPUT_TOKENS_METRICS,
CHAT_COMPLETIONS_OUTPUT_TOKENS_METRICS,
},
middleware::EncryptionMetadata,
server::utils,
server::EncryptionRequest,
};

/// The chunk that indicates the end of a streaming response
Expand Down Expand Up @@ -62,6 +71,14 @@ pub struct Streamer {
first_token_generation_timer: Option<HistogramTimer>,
/// The decoding phase timer for the request.
decoding_phase_timer: Option<HistogramTimer>,
/// The client encryption metadata for the request
client_encryption_metadata: Option<EncryptionMetadata>,
/// Confidential compute encryption sender
confidential_compute_encryption_sender: UnboundedSender<EncryptionRequest>,
/// The receiver for the encryption response
encryption_response_receiver: Option<oneshot::Receiver<ConfidentialComputeEncryptionResponse>>,
/// Boolean value flagging if the stream is currently waiting for a chunk encryption
waiting_for_encrypted_chunk: bool,
}

/// Represents the various states of a streaming process
Expand Down Expand Up @@ -89,7 +106,9 @@ impl Streamer {
keystore: Arc<FileBasedKeystore>,
address_index: usize,
model: String,
client_encryption_metadata: Option<EncryptionMetadata>,
first_token_generation_timer: HistogramTimer,
confidential_compute_encryption_sender: UnboundedSender<EncryptionRequest>,
) -> Self {
Self {
stream: Box::pin(stream),
Expand All @@ -104,6 +123,10 @@ impl Streamer {
model,
first_token_generation_timer: Some(first_token_generation_timer),
decoding_phase_timer: None,
client_encryption_metadata,
confidential_compute_encryption_sender,
encryption_response_receiver: None,
waiting_for_encrypted_chunk: false,
}
}

Expand Down Expand Up @@ -218,6 +241,117 @@ impl Streamer {

Ok(signature)
}

/// Handles the processing of an encrypted chunk response from the confidential compute service.
///
/// This method attempts to receive and process an encrypted response from a previously initiated
/// encryption request. It manages the oneshot channel receiver that was set up to receive the
/// encrypted data.
///
/// # Returns
///
/// Returns a `Result<Option<Value>, Error>` where:
/// * `Ok(Some(Value))` - Successfully received and processed an encrypted chunk
/// * `Ok(None)` - No encrypted chunk is available yet (receiver is empty)
/// * `Err(Error)` - The channel has been dropped or another error occurred
///
/// # State Changes
///
/// * Consumes `encryption_response_receiver` using `take()` to process the response
///
/// # Example Response Format
///
/// When successful, returns a JSON object containing:
/// ```json
/// {
/// "ciphertext": "encrypted_data_here",
/// "nonce": "nonce_value_here"
/// }
/// ```
#[instrument(
level = "debug",
skip(self),
fields(path = "streamer-handle_encrypted_chunk")
)]
fn handle_encrypted_chunk(&mut self) -> Result<Option<Value>, Error> {
if let Some(mut receiver) = self.encryption_response_receiver.take() {
match receiver.try_recv() {
Ok(ConfidentialComputeEncryptionResponse { ciphertext, nonce }) => {
// Construct encrypted JSON
let encrypted_chunk = json!({
"ciphertext": ciphertext,
"nonce": nonce,
});

return Ok(Some(encrypted_chunk));
}
Err(e) => {
if e == TryRecvError::Empty {
return Ok(None);
}
return Err(Error::new(format!(
"Oneshot sender channel has been dropped"
)));
}
}
}
Ok(None)
}

/// Handles the encryption request for a chunk of streaming data.
///
/// This method initiates the encryption process for a given chunk by:
/// 1. Creating a oneshot channel for receiving the encryption response
/// 2. Sending the encryption request with the chunk data to the confidential compute service
/// 3. Setting up the streamer to wait for the encrypted response
///
/// # Arguments
///
/// * `chunk` - The JSON value containing the data to be encrypted
/// * `proxy_x25519_public_key` - The X25519 public key of the proxy (32 bytes)
/// * `salt` - The salt value used in the encryption process
///
/// # Returns
///
/// Returns a `Result<(), Error>` where:
/// * `Ok(())` - The encryption request was successfully sent
/// * `Err(Error)` - An error occurred while sending the encryption request
///
/// # State Changes
///
/// * Sets `waiting_for_encrypted_chunk` to `true`
/// * Updates `encryption_response_receiver` with the new receiver
#[instrument(
level = "debug",
skip_all,
fields(
proxy_x25519_public_key = ?proxy_x25519_public_key
)
)]
fn handle_encryption_request(
&mut self,
chunk: &Value,
proxy_x25519_public_key: [u8; 32],
salt: Vec<u8>,
) -> Result<(), Error> {
let (sender, receiver) = oneshot::channel();
self.confidential_compute_encryption_sender
.send((
ConfidentialComputeEncryptionRequest {
plaintext: chunk.to_string().into(),
proxy_x25519_public_key,
salt,
},
sender,
))
.map_err(|e| {
error!("Error sending encryption request: {}", e);
Error::new(format!("Error sending encryption request: {}", e))
})?;
self.waiting_for_encrypted_chunk = true;
self.encryption_response_receiver = Some(receiver);
Ok(())
}
}

impl Stream for Streamer {
Expand All @@ -230,6 +364,17 @@ impl Stream for Streamer {

match self.stream.as_mut().poll_next(cx) {
Poll::Ready(Some(Ok(chunk))) => {
if self.waiting_for_encrypted_chunk {
match self.handle_encrypted_chunk() {
Ok(Some(chunk)) => {
self.waiting_for_encrypted_chunk = false;
return Poll::Ready(Some(Ok(Event::default().json_data(&chunk)?)));
}
Err(e) => return Poll::Ready(Some(Err(e))),
Ok(None) => return Poll::Pending,
}
}

if self.status != StreamStatus::Started {
self.status = StreamStatus::Started;
}
Expand Down Expand Up @@ -286,15 +431,40 @@ impl Stream for Streamer {
self.status = StreamStatus::Completed;
let signature = self.handle_final_chunk(usage)?;
chunk["signature"] = json!(signature);
Poll::Ready(Some(Ok(Event::default().json_data(&chunk)?)))
let client_encryption_metadata = self.client_encryption_metadata.clone();
if let Some(EncryptionMetadata {
proxy_x25519_public_key,
salt,
}) = client_encryption_metadata
{
// NOTE: We only need to perform chunk encryption when sending the chunk back to the client
self.handle_encryption_request(&chunk, proxy_x25519_public_key, salt)?;
// NOTE: We don't expect the encryption to be ready immediately, so we return pending
// for now, so next time we poll, we'll check if the encryption is ready
Poll::Pending
} else {
Poll::Ready(Some(Ok(Event::default().json_data(&chunk)?)))
}
} else {
error!("Error getting usage from chunk");
Poll::Ready(Some(Err(Error::new("Error getting usage from chunk"))))
}
} else {
// Accumulate regular chunks
self.accumulated_response.push(chunk.clone());
Poll::Ready(Some(Ok(Event::default().json_data(&chunk)?)))
let should_encrypt = self
.client_encryption_metadata
.as_ref()
.map(|metadata| (metadata.proxy_x25519_public_key, metadata.salt.clone()));
if let Some((proxy_x25519_public_key, salt)) = should_encrypt {
// NOTE: We only need to perform chunk encryption when sending the chunk back to the client
self.handle_encryption_request(&chunk, proxy_x25519_public_key, salt)?;
// NOTE: We don't expect the encryption to be ready immediately, so we return pending
// for now, so next time we poll, we'll check if the encryption is ready
Poll::Pending
} else {
Poll::Ready(Some(Ok(Event::default().json_data(&chunk)?)))
}
}
}
Poll::Ready(Some(Err(e))) => {
Expand Down
1 change: 1 addition & 0 deletions atoma-service/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -944,6 +944,7 @@ mod middleware {
payload_hash: [0u8; 32],
request_type: RequestType::ChatCompletions,
endpoint_path: "/".to_string(),
client_encryption_metadata: None,
};

let mut req = Request::builder()
Expand Down

0 comments on commit aab4f80

Please sign in to comment.