diff --git a/Cargo.lock b/Cargo.lock index 721fdd342..a376815b2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -502,9 +502,9 @@ dependencies = [ [[package]] name = "aws-smithy-runtime-api" -version = "1.3.0" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ccb2b3a7030dc9a3c9a08ce0b25decea5130e9db19619d4dffbbff34f75fe850" +checksum = "4cc56a5c96ec741de6c5e6bf1ce6948be969d6506dfa9c39cffc284e31e4979b" dependencies = [ "aws-smithy-async", "aws-smithy-types", @@ -1856,6 +1856,7 @@ dependencies = [ "aws-config", "aws-sdk-s3", "aws-smithy-runtime", + "aws-smithy-runtime-api", "aws-smithy-types", "bincode", "blake3", @@ -1865,8 +1866,10 @@ dependencies = [ "futures", "hex", "http 1.1.0", + "http-body 1.0.0", "hyper", "hyper-rustls", + "log", "lz4_flex", "memory-stats", "nativelink-config", diff --git a/nativelink-config/src/stores.rs b/nativelink-config/src/stores.rs index 0a9daa983..cf57ebd0c 100644 --- a/nativelink-config/src/stores.rs +++ b/nativelink-config/src/stores.rs @@ -480,6 +480,14 @@ pub struct S3Store { #[serde(default)] pub retry: Retry, + /// The maximum buffer size to retain in case of a retryable error + /// during upload. Setting this to zero will disable upload buffering; + /// this means that in the event of a failure during upload, the entire + /// upload will be aborted and the client will likely receive an error. + /// + /// Default: 5MB. + pub max_retry_buffer_per_request: Option<usize>, + /// Maximum number of concurrent UploadPart requests per MultipartUpload. /// /// Default: 10. diff --git a/nativelink-store/BUILD.bazel b/nativelink-store/BUILD.bazel index 002a6e27c..adf0f67ef 100644 --- a/nativelink-store/BUILD.bazel +++ b/nativelink-store/BUILD.bazel @@ -49,8 +49,10 @@ rust_library( "@crates//:filetime", "@crates//:futures", "@crates//:hex", + "@crates//:http-body", "@crates//:hyper", "@crates//:hyper-rustls", + "@crates//:log", "@crates//:lz4_flex", "@crates//:parking_lot", "@crates//:prost", @@ -98,12 +100,14 @@ rust_test_suite( "@crates//:async-lock", "@crates//:aws-sdk-s3", "@crates//:aws-smithy-runtime", + "@crates//:aws-smithy-runtime-api", "@crates//:aws-smithy-types", "@crates//:bincode", "@crates//:bytes", "@crates//:filetime", "@crates//:futures", "@crates//:http", + "@crates//:http-body", "@crates//:hyper", "@crates//:memory-stats", "@crates//:once_cell", diff --git a/nativelink-store/Cargo.toml b/nativelink-store/Cargo.toml index b89c93990..1ea647768 100644 --- a/nativelink-store/Cargo.toml +++ b/nativelink-store/Cargo.toml @@ -21,8 +21,10 @@ bytes = "1.6.0" filetime = "0.2.23" futures = "0.3.30" hex = "0.4.3" +http-body = "1.0.0" hyper = { version = "0.14.28" } hyper-rustls = { version = "0.24.2", features = ["webpki-tokio"] } +log = "0.4.21" lz4_flex = "0.11.3" parking_lot = "0.12.1" prost = "0.12.4" @@ -46,3 +48,4 @@ http = "1.1.0" aws-smithy-types = "1.1.8" aws-sdk-s3 = { version = "1.21.0" } aws-smithy-runtime = { version = "1.2.1", features = ["test-util"] } +aws-smithy-runtime-api = "1.4.0" diff --git a/nativelink-store/src/s3_store.rs b/nativelink-store/src/s3_store.rs index c514b9e8e..6e11409df 100644 --- a/nativelink-store/src/s3_store.rs +++ b/nativelink-store/src/s3_store.rs @@ -22,7 +22,7 @@ use std::{cmp, env}; use async_trait::async_trait; use aws_config::default_provider::credentials; -use aws_config::BehaviorVersion; +use aws_config::{AppName, BehaviorVersion}; use aws_sdk_s3::config::Region; use aws_sdk_s3::operation::create_multipart_upload::CreateMultipartUploadOutput; use
aws_sdk_s3::operation::get_object::GetObjectError; @@ -31,36 +31,103 @@ use aws_sdk_s3::primitives::{ByteStream, SdkBody}; use aws_sdk_s3::types::builders::{CompletedMultipartUploadBuilder, CompletedPartBuilder}; use aws_sdk_s3::Client; use aws_smithy_runtime::client::http::hyper_014::HyperClientBuilder; +use bytes::Bytes; +use futures::future::FusedFuture; use futures::stream::{unfold, FuturesUnordered}; -use futures::{try_join, FutureExt, StreamExt, TryStreamExt}; -use hyper::client::connect::HttpConnector; +use futures::{FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt}; +use http_body::{Frame, SizeHint}; +use hyper::client::connect::{Connected, Connection, HttpConnector}; use hyper::service::Service; use hyper::Uri; use hyper_rustls::{HttpsConnector, MaybeHttpsStream}; -use nativelink_error::{error_if, make_err, make_input_err, Code, Error, ResultExt}; -use nativelink_util::buf_channel::{DropCloserReadHalf, DropCloserWriteHalf}; +use nativelink_error::{make_err, make_input_err, Code, Error, ResultExt}; +use nativelink_util::buf_channel::{ + make_buf_channel_pair, DropCloserReadHalf, DropCloserWriteHalf, +}; use nativelink_util::common::DigestInfo; +use nativelink_util::fs; use nativelink_util::health_utils::{default_health_status_indicator, HealthStatusIndicator}; use nativelink_util::retry::{Retrier, RetryResult}; use nativelink_util::store_trait::{Store, UploadSizeInfo}; use rand::rngs::OsRng; use rand::Rng; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::net::TcpStream; -use tokio::sync::mpsc; +use tokio::sync::{mpsc, SemaphorePermit}; use tokio::time::sleep; -use tokio_stream::wrappers::ReceiverStream; use tracing::info; use crate::cas_utils::is_zero_digest; // S3 parts cannot be smaller than this number. See: // https://docs.aws.amazon.com/AmazonS3/latest/userguide/qfacts.html -const MIN_MULTIPART_SIZE: usize = 5 * 1024 * 1024; // 5mb. +const MIN_MULTIPART_SIZE: usize = 5 * 1024 * 1024; // 5MB. + +// S3 parts cannot be larger than this number. See: +// https://docs.aws.amazon.com/AmazonS3/latest/userguide/qfacts.html +const MAX_MULTIPART_SIZE: usize = 5 * 1024 * 1024 * 1024; // 5GB. + +// S3 multipart uploads cannot have more than this number of parts. See: +// https://docs.aws.amazon.com/AmazonS3/latest/userguide/qfacts.html +const MAX_UPLOAD_PARTS: usize = 10_000; + +// Default max buffer size for retrying upload requests. +// Note: If you change this, adjust the docs in the config. +const DEFAULT_MAX_RETRY_BUFFER_PER_REQUEST: usize = 5 * 1024 * 1024; // 5MB. // Default limit for concurrent part uploads per multipart upload. // Note: If you change this, adjust the docs in the config.
const DEFAULT_MULTIPART_MAX_CONCURRENT_UPLOADS: usize = 10; +pub struct ConnectionWithPermit<T: Connection + AsyncRead + AsyncWrite + Unpin> { + connection: T, + _permit: SemaphorePermit<'static>, +} + +impl<T: Connection + AsyncRead + AsyncWrite + Unpin> Connection for ConnectionWithPermit<T> { + fn connected(&self) -> Connected { + self.connection.connected() + } +} + +impl<T: Connection + AsyncRead + AsyncWrite + Unpin> AsyncRead for ConnectionWithPermit<T> { + #[inline] + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context, + buf: &mut ReadBuf<'_>, + ) -> Poll<Result<(), tokio::io::Error>> { + Pin::new(&mut Pin::get_mut(self).connection).poll_read(cx, buf) + } +} + +impl<T: Connection + AsyncRead + AsyncWrite + Unpin> AsyncWrite for ConnectionWithPermit<T> { + #[inline] + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<Result<usize, tokio::io::Error>> { + Pin::new(&mut Pin::get_mut(self).connection).poll_write(cx, buf) + } + + #[inline] + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Result<(), tokio::io::Error>> { + Pin::new(&mut Pin::get_mut(self).connection).poll_flush(cx) + } + + #[inline] + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Result<(), tokio::io::Error>> { + Pin::new(&mut Pin::get_mut(self).connection).poll_shutdown(cx) + } +} + #[derive(Clone)] pub struct TlsConnector { connector: HttpsConnector<HttpConnector>, @@ -97,10 +164,20 @@ impl TlsConnector { } } - async fn call_with_retry(&self, req: &Uri) -> Result<MaybeHttpsStream<TcpStream>, Error> { + async fn call_with_retry( + &self, + req: &Uri, + ) -> Result<ConnectionWithPermit<MaybeHttpsStream<TcpStream>>, Error> { let retry_stream_fn = unfold(self.connector.clone(), move |mut connector| async move { + let _permit = fs::get_permit().await.unwrap(); match connector.call(req.clone()).await { - Ok(stream) => Some((RetryResult::Ok(stream), connector)), + Ok(connection) => Some(( + RetryResult::Ok(ConnectionWithPermit { + connection, + _permit, + }), + connector, + )), Err(e) => Some(( RetryResult::Retry(make_err!( Code::Unavailable, @@ -115,7 +192,7 @@ } impl Service<Uri> for TlsConnector { - type Response = MaybeHttpsStream<TcpStream>; + type Response = ConnectionWithPermit<MaybeHttpsStream<TcpStream>>; type Error = Error; type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>; @@ -132,11 +209,36 @@ } } +pub struct BodyWrapper { + reader: DropCloserReadHalf, + size: u64, +} + +impl http_body::Body for BodyWrapper { + type Data = Bytes; + type Error = std::io::Error; + + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> { + let reader = Pin::new(&mut Pin::get_mut(self).reader); + reader + .poll_next(cx) + .map(|maybe_bytes_res| maybe_bytes_res.map(|res| res.map(Frame::data))) + } + + fn size_hint(&self) -> SizeHint { + SizeHint::with_exact(self.size) + } +} + pub struct S3Store { s3_client: Arc<Client>, bucket: String, key_prefix: String, retrier: Retrier, + max_retry_buffer_per_request: usize, multipart_max_concurrent_uploads: usize, } @@ -157,6 +259,12 @@ impl S3Store { let credential_provider = credentials::default_provider().await; let mut config_builder = aws_config::defaults(BehaviorVersion::v2023_11_09()) .credentials_provider(credential_provider) + .app_name(AppName::new("nativelink").expect("valid app name")) + .timeout_config( + aws_config::timeout::TimeoutConfig::builder() + .connect_timeout(Duration::from_secs(15)) + .build(), + ) .region(Region::new(Cow::Owned(config.region.clone()))) .http_client(http_client); // TODO(allada) When aws-sdk supports this env variable we should be able @@ -184,6 +292,9 @@ jitter_fn, config.retry.to_owned(), ), + max_retry_buffer_per_request: config + .max_retry_buffer_per_request + .unwrap_or(DEFAULT_MAX_RETRY_BUFFER_PER_REQUEST), multipart_max_concurrent_uploads: config .multipart_max_concurrent_uploads
.map_or(DEFAULT_MULTIPART_MAX_CONCURRENT_UPLOADS, |v| v), @@ -278,86 +389,120 @@ impl Store for S3Store { let max_size = match upload_size { UploadSizeInfo::ExactSize(sz) | UploadSizeInfo::MaxSize(sz) => sz, }; - // NOTE(blaise.bruer) It might be more optimal to use a different + + // Note(allada) It might be more optimal to use a different // heuristic here, but for simplicity we use a hard coded value. // Anything going down this if-statement will have the advantage of only // 1 network request for the upload instead of minimum of 3 required for // multipart upload requests. - if max_size < MIN_MULTIPART_SIZE { - let (body, content_length) = if let UploadSizeInfo::ExactSize(sz) = upload_size { - ( - ByteStream::new(SdkBody::from( - reader - .consume(Some(sz)) - .await - .err_tip(|| "Failed to take {sz} bytes from reader in S3")?, - )), - sz as i64, - ) - } else { - // Just in case, we want to capture the EOF, so +1. - let write_buf = reader - .consume(Some(max_size + 1)) - .await - .err_tip(|| "Failed to read file in upload to s3 in single chunk")?; - error_if!( - write_buf.len() > max_size, - "More data than provided max_size in s3_store {}", - digest.hash_str() - ); - let content_length = write_buf.len(); - ( - ByteStream::new(SdkBody::from(write_buf)), - content_length as i64, - ) - }; - + // + // Note(allada) If the upload size is not known, we go down the multipart upload path. + // This is not very efficient, but it greatly reduces the complexity of the code. + if max_size < MIN_MULTIPART_SIZE && matches!(upload_size, UploadSizeInfo::ExactSize(_)) { + reader.set_max_recent_data_size(self.max_retry_buffer_per_request); return self - .s3_client - .put_object() - .bucket(&self.bucket) - .key(s3_path.clone()) - .content_length(content_length) - .body(body) - .send() - .await - .map_or_else(|e| Err(make_err!(Code::Internal, "{e:?}")), |_| Ok(())) - .err_tip(|| "Failed to upload file to s3 in single chunk"); + .retrier + .retry(unfold(reader, move |mut reader| async move { + let UploadSizeInfo::ExactSize(sz) = upload_size else { + unreachable!("upload_size must be UploadSizeInfo::ExactSize here"); + }; + // We need to make a new pair here because the aws sdk does not give us + // back the body after we send it in order to retry. + let (mut tx, rx) = make_buf_channel_pair(); + + // Upload the data to the S3 backend. + let result = { + let reader_ref = &mut reader; + let (upload_res, bind_res) = tokio::join!( + self.s3_client + .put_object() + .bucket(&self.bucket) + .key(s3_path.clone()) + .content_length(sz as i64) + .body(ByteStream::from_body_1_x(BodyWrapper { + reader: rx, + size: sz as u64, + })) + .send() + .map_ok_or_else(|e| Err(make_err!(Code::Aborted, "{e:?}")), |_| Ok(())), + // Stream all data from the reader channel to the writer channel. + tx.bind(reader_ref) + ); + upload_res + .merge(bind_res) + .err_tip(|| "Failed to upload file to s3 in single chunk") + }; + + // If we failed to upload the file, check to see if we can retry. + let retry_result = result.map_or_else(|mut e| { + // Ensure our code is Code::Aborted, so the client can retry if possible. 
+ e.code = Code::Aborted; + let bytes_received = reader.get_bytes_received(); + if let Err(try_reset_err) = reader.try_reset_stream() { + let e = e + .merge(try_reset_err) + .append(format!("Failed to retry upload with {bytes_received} bytes received in S3Store::update")); + log::error!("{e:?}"); + return RetryResult::Err(e); + } + let e = e.append(format!("Retry on upload happened with {bytes_received} bytes received in S3Store::update")); + log::info!("{e:?}"); + RetryResult::Retry(e) + }, |()| RetryResult::Ok(())); + Some((retry_result, reader)) + })) + .await; } - // S3 requires us to upload in parts if the size is greater than 5GB. The part size must be at least - // 5mb and can have up to 10,000 parts. - let bytes_per_upload_part = - cmp::max(MIN_MULTIPART_SIZE, max_size / (MIN_MULTIPART_SIZE - 1)); - - let response: CreateMultipartUploadOutput = self - .s3_client - .create_multipart_upload() - .bucket(&self.bucket) - .key(s3_path) - .send() - .await - .map_err(|e| { - make_err!( - Code::Internal, - "Failed to create multipart upload to s3: {e:?}" - ) - })?; - - let upload_id = response - .upload_id - .err_tip(|| "Expected upload_id to be set by s3 response")? + let upload_id = &self + .retrier + .retry(unfold((), move |()| async move { + let retry_result = self + .s3_client + .create_multipart_upload() + .bucket(&self.bucket) + .key(s3_path) + .send() + .await + .map_or_else( + |e| { + RetryResult::Retry(make_err!( + Code::Aborted, + "Failed to create multipart upload to s3: {e:?}" + )) + }, + |CreateMultipartUploadOutput { upload_id, .. }| { + upload_id.map_or_else( + || { + RetryResult::Err(make_err!( + Code::Internal, + "Expected upload_id to be set by s3 response" + )) + }, + RetryResult::Ok, + ) + }, + ); + Some((retry_result, ())) + })) + .await?; - let complete_result = { - let mut part_number: i32 = 1; + // S3 requires us to upload in parts if the size is greater than 5GB. The part size must be at least + // 5MB (except last part) and can have up to 10,000 parts. + let bytes_per_upload_part = cmp::min( + cmp::max(MIN_MULTIPART_SIZE, max_size / (MIN_MULTIPART_SIZE - 1)), + MAX_MULTIPART_SIZE, + ); + let upload_parts = move || async move { // This will ensure we only have `multipart_max_concurrent_uploads` * `bytes_per_upload_part` // bytes in memory at any given time waiting to be uploaded. - let (tx, rx) = mpsc::channel(self.multipart_max_concurrent_uploads); - let upload_id_clone = upload_id.clone(); + let (tx, mut rx) = mpsc::channel(self.multipart_max_concurrent_uploads); let read_stream_fut = async move { - loop { + let retrier = &Pin::get_ref(self).retrier; + // Note: Our break condition is when we reach EOF. + for part_number in 1..i32::MAX { let write_buf = reader .consume(Some(bytes_per_upload_part)) .await @@ -366,92 +511,125 @@ break; // Reached EOF. } - let upload_fut = self - .s3_client - .upload_part() - .bucket(self.bucket.clone()) - .key(s3_path) - .upload_id(upload_id_clone.clone()) - .body(ByteStream::new(SdkBody::from(write_buf))) - .part_number(part_number) - .send() - .map(move |result| { - result.map_or_else( - |e| { - Err(make_err!( - Code::Internal, - "Failed to upload part {part_number} in S3 store: {e:?}" - )) - }, - |mut response| { - Ok(CompletedPartBuilder::default() - // Only set an entity tag if it exists. This saves - // 13 bytes per part on the final request if it can - // omit the `<ETag>` string.
- .set_e_tag(response.e_tag.take()) - .part_number(part_number) - .build()) - }, - ) - }); - tx.send(upload_fut).await.map_err(|e| { - make_err!( - Code::Internal, - "Could not send across mpsc for {part_number} in S3 store: {e:?}" - ) - })?; - part_number += 1; + tx.send(retrier.retry(unfold( + write_buf, + move |write_buf| { + async move { + let retry_result = self + .s3_client + .upload_part() + .bucket(&self.bucket) + .key(s3_path) + .upload_id(upload_id) + .body(ByteStream::new(SdkBody::from(write_buf.clone()))) + .part_number(part_number) + .send() + .await + .map_or_else( + |e| { + RetryResult::Retry(make_err!( + Code::Aborted, + "Failed to upload part {part_number} in S3 store: {e:?}" + )) + }, + |mut response| { + RetryResult::Ok( + CompletedPartBuilder::default() + // Only set an entity tag if it exists. This saves + // 13 bytes per part on the final request if it can + // omit the `<ETag>` string. + .set_e_tag(response.e_tag.take()) + .part_number(part_number) + .build(), + ) + }, + ); + Some((retry_result, write_buf)) + } + } + ))).await.map_err(|_| make_err!(Code::Internal, "Failed to send part to channel in s3_store"))?; } - Ok(()) - }; - - // This will ensure we only have `multipart_max_concurrent_uploads` requests in flight - // at any given time. - let completed_parts_fut = ReceiverStream::new(rx) - .buffer_unordered(self.multipart_max_concurrent_uploads) - .try_collect::<Vec<CompletedPart>>(); - - // Wait for the entire stream to be read and all parts to be uploaded. - let ((), mut completed_parts) = - try_join!(read_stream_fut, completed_parts_fut).err_tip(|| "In s3 store")?; + Result::<_, Error>::Ok(()) + }.fuse(); + + let mut upload_futures = FuturesUnordered::new(); + + let mut completed_parts = Vec::with_capacity(cmp::min( + MAX_UPLOAD_PARTS, + (max_size / bytes_per_upload_part) + 1, + )); + tokio::pin!(read_stream_fut); + loop { + if read_stream_fut.is_terminated() && rx.is_empty() && upload_futures.is_empty() { + break; // No more data to process. + } + tokio::select! { + result = &mut read_stream_fut => result?, // Return error or wait for other futures. + Some(upload_result) = upload_futures.next() => completed_parts.push(upload_result?), + Some(fut) = rx.recv() => upload_futures.push(fut), + } + } // Even though the spec does not require parts to be sorted by number, we do it just in case // there's an S3 implementation that requires it.
completed_parts.sort_unstable_by_key(|part| part.part_number); - self.s3_client - .complete_multipart_upload() - .bucket(&self.bucket) - .key(s3_path.clone()) - .multipart_upload( - CompletedMultipartUploadBuilder::default() - .set_parts(Some(completed_parts)) - .build(), - ) - .upload_id(upload_id.clone()) - .send() + self.retrier + .retry(unfold(completed_parts, move |completed_parts| async move { + Some(( + self.s3_client + .complete_multipart_upload() + .bucket(&self.bucket) + .key(s3_path) + .multipart_upload( + CompletedMultipartUploadBuilder::default() + .set_parts(Some(completed_parts.clone())) + .build(), + ) + .upload_id(upload_id) + .send() + .await + .map_or_else( + |e| { + RetryResult::Retry(make_err!( + Code::Aborted, + "Failed to complete multipart upload in S3 store: {e:?}" + )) + }, + |_| RetryResult::Ok(()), + ), + completed_parts, + )) + })) .await - .map_or_else(|e| Err(make_err!(Code::Internal, "{e:?}")), |_| Ok(())) - .err_tip(|| "Failed to complete multipart to s3")?; - Ok(()) }; - if complete_result.is_err() { - let abort_result = self - .s3_client - .abort_multipart_upload() - .bucket(&self.bucket) - .key(s3_path.clone()) - .upload_id(upload_id.clone()) - .send() - .await; - if let Err(err) = abort_result { - info!( - "\x1b[0;31ms3_store\x1b[0m: Failed to abort_multipart_upload: {:?}", - err - ); - } - } - complete_result + // Upload our parts and complete the multipart upload. + // If we fail, attempt to abort the multipart upload (cleanup). + upload_parts() + .or_else(move |e| async move { + Result::<(), _>::Err(e).merge( + // Note: We don't retry here because this is just a best attempt. + self.s3_client + .abort_multipart_upload() + .bucket(&self.bucket) + .key(s3_path) + .upload_id(upload_id) + .send() + .await + .map_or_else( + |e| { + let err = make_err!( + Code::Aborted, + "Failed to abort multipart upload in S3 store: {e:?}" + ); + info!("{err:?}"); + Err(err) + }, + |_| Ok(()), + ), + ) + }) + .await } async fn get_part_ref( diff --git a/nativelink-store/tests/s3_store_test.rs b/nativelink-store/tests/s3_store_test.rs index 49f767033..ce4dcc341 100644 --- a/nativelink-store/tests/s3_store_test.rs +++ b/nativelink-store/tests/s3_store_test.rs @@ -17,18 +17,20 @@ use std::sync::Arc; use std::time::Duration; use aws_sdk_s3::config::{BehaviorVersion, Builder, Region}; +use aws_sdk_s3::primitives::ByteStream; use aws_smithy_runtime::client::http::test_util::{ReplayEvent, StaticReplayClient}; use aws_smithy_types::body::SdkBody; -use bytes::Bytes; +use bytes::{BufMut, Bytes, BytesMut}; use futures::join; +use futures::task::Poll; use http::header; use http::status::StatusCode; use hyper::Body; -use nativelink_error::{Error, ResultExt}; +use nativelink_error::{make_input_err, Error, ResultExt}; use nativelink_store::s3_store::S3Store; use nativelink_util::buf_channel::make_buf_channel_pair; use nativelink_util::common::{DigestInfo, JoinHandleDropGuard}; -use nativelink_util::store_trait::Store; +use nativelink_util::store_trait::{Store, UploadSizeInfo}; use sha2::{Digest, Sha256}; // TODO(aaronmondal): Figure out how to test the connector retry mechanism.
@@ -183,30 +185,25 @@ async fn simple_update_ac() -> Result<(), Error> { const AC_ENTRY_SIZE: u64 = 199; const CONTENT_LENGTH: usize = 50; - let mut send_data = Vec::with_capacity(CONTENT_LENGTH); - for i in 0..send_data.capacity() { - send_data.push(((i * 3) % 256) as u8); + let mut send_data = BytesMut::new(); + for i in 0..CONTENT_LENGTH { + send_data.put_u8(((i % 93) + 33) as u8); // Printable characters only. } - - let mock_client = StaticReplayClient::new(vec![ReplayEvent::new( - http::Request::builder() - .uri(format!( - "https://{BUCKET_NAME}.s3.{REGION}.amazonaws.com/{VALID_HASH1}-{AC_ENTRY_SIZE}?x-id=PutObject", - )) - .method("PUT") - .header("content-type", "application/octet-stream") - .header("content-length", CONTENT_LENGTH.to_string()) - .body(SdkBody::from(send_data.clone())) + let send_data = send_data.freeze(); + + let (mock_client, request_receiver) = + aws_smithy_runtime::client::http::test_util::capture_request(Some( + aws_smithy_runtime_api::http::Response::new( + StatusCode::OK.into(), + SdkBody::empty(), // This is an upload, so server does not send a body. + ) + .try_into_http02x() .unwrap(), - http::Response::builder() - .status(StatusCode::OK) - .body(SdkBody::empty()) - .unwrap(), - )]); + )); let test_config = Builder::new() .behavior_version(BehaviorVersion::v2023_11_09()) .region(Region::from_static(REGION)) - .http_client(mock_client.clone()) + .http_client(mock_client) .build(); let s3_client = aws_sdk_s3::Client::from_conf(test_config); let store = S3Store::new_with_client_and_jitter( @@ -217,14 +214,64 @@ s3_client, Arc::new(move |_delay| Duration::from_secs(0)), )?; - let store_pin = Pin::new(&store); - store_pin - .update_oneshot( - DigestInfo::try_new(VALID_HASH1, AC_ENTRY_SIZE)?, - send_data.clone().into(), - ) - .await?; - mock_client.assert_requests_match(&[]); + let (mut tx, rx) = make_buf_channel_pair(); + // Make future responsible for processing the datastream + // and forwarding it to the s3 backend/server. + let mut update_fut = Box::pin(async move { + Pin::new(&store) + .update( + DigestInfo::try_new(VALID_HASH1, AC_ENTRY_SIZE)?, + rx, + UploadSizeInfo::ExactSize(CONTENT_LENGTH), + ) + .await + }); + + // Extract out the body stream sent by the s3 store. + let body_stream = { + // We need to poll here to get the request sent, but the future + // won't be done until we send all the data (which we do later). + assert_eq!(Poll::Pending, futures::poll!(&mut update_fut)); + let sent_request = request_receiver.expect_request(); + assert_eq!(sent_request.method(), "PUT"); + assert_eq!(sent_request.uri(), format!("https://{BUCKET_NAME}.s3.{REGION}.amazonaws.com/{VALID_HASH1}-{AC_ENTRY_SIZE}?x-id=PutObject")); + ByteStream::from_body_0_4(sent_request.into_body()) + }; + + let send_data_copy = send_data.clone(); + // Create spawn that is responsible for sending the stream of data + // to the S3Store and processing/forwarding to the S3 backend. + let spawn_fut = tokio::spawn(async move { + tokio::try_join!(update_fut, async move { + for i in 0..CONTENT_LENGTH { + tx.send(send_data_copy.slice(i..(i + 1))).await?; + } + tx.send_eof() + }) + .or_else(|e| { + // Printing error to make it easier to debug, since ordering + // of futures is not guaranteed. + eprintln!("Error updating or sending in spawn: {e:?}"); + Err(e) + }) + }); + + // Wait for all the data to be received by the s3 backend server.
+ let data_sent_to_s3 = body_stream + .collect() + .await + .map_err(|e| make_input_err!("{e:?}"))?; + assert_eq!( + send_data, + data_sent_to_s3.into_bytes(), + "Expected data to match" + ); + + // Collect our spawn future to ensure it completes without error. + spawn_fut + .await + .err_tip(|| "Failed to launch spawn")? + .err_tip(|| "In spawn")?; Ok(()) } diff --git a/nativelink-util/src/buf_channel.rs b/nativelink-util/src/buf_channel.rs index d9a07a40f..0f3771838 100644 --- a/nativelink-util/src/buf_channel.rs +++ b/nativelink-util/src/buf_channel.rs @@ -243,8 +243,8 @@ impl DropCloserReadHalf { /// Sets the maximum size of the recent_data buffer. If the number of bytes /// received exceeds this size, the recent_data buffer will be cleared and /// no longer populated. - pub fn set_max_recent_data_size(&mut self, size: u64) { - self.max_recent_data_size = size; + pub fn set_max_recent_data_size(&mut self, size: usize) { + self.max_recent_data_size = size as u64; } /// Attempts to reset the stream to before any data was received. This will diff --git a/nativelink-util/src/fs.rs b/nativelink-util/src/fs.rs index 080e51870..e9bba313b 100644 --- a/nativelink-util/src/fs.rs +++ b/nativelink-util/src/fs.rs @@ -287,10 +287,8 @@ static TOTAL_FILE_SEMAPHORES: AtomicUsize = AtomicUsize::new(DEFAULT_OPEN_FILE_P pub static OPEN_FILE_SEMAPHORE: Semaphore = Semaphore::const_new(DEFAULT_OPEN_FILE_PERMITS); /// Try to acquire a permit from the open file semaphore. -/// This function will block, so it must be run from a thread that -/// can be blocked. #[inline] -async fn get_permit() -> Result<SemaphorePermit<'static>, Error> { +pub async fn get_permit() -> Result<SemaphorePermit<'static>, Error> { OPEN_FILE_SEMAPHORE .acquire() .await diff --git a/nativelink-util/src/retry.rs b/nativelink-util/src/retry.rs index 8157aeda0..42a32c33f 100644 --- a/nativelink-util/src/retry.rs +++ b/nativelink-util/src/retry.rs @@ -126,17 +126,20 @@ impl Retrier { .take(self.config.max_retries) // Remember this is number of retries, so will run max_retries + 1. } - pub fn retry<'a, T, Fut>( + // Clippy complains that this function can be `async fn`, but this is not true. + // If we use `async fn`, other places in our code will fail to compile stating + // something about the async blocks not matching. + // This appears to happen due to a compiler bug while inlining, because the + // function that it complained about was calling another function that called + // this one. + #[allow(clippy::manual_async_fn)] + pub fn retry<'a, T: Send>( &'a self, - operation: Fut, - ) -> Pin<Box<dyn Future<Output = Result<T, Error>> + 'a + Send>> - where - Fut: futures::stream::Stream<Item = RetryResult<T>> + Send + 'a, - T: Send, - { - Box::pin(async move { + operation: impl futures::stream::Stream<Item = RetryResult<T>> + Send + 'a, + ) -> impl Future<Output = Result<T, Error>> + Send + 'a { + async move { let mut iter = self.get_retry_config(); - let mut operation = Box::pin(operation); + tokio::pin!(operation); let mut attempt = 0; loop { attempt += 1; @@ -149,7 +152,7 @@ } Some(RetryResult::Ok(value)) => return Ok(value), Some(RetryResult::Err(e)) => { - return Err(e.append(format!("On attempt {attempt}"))) + return Err(e.append(format!("On attempt {attempt}"))); } Some(RetryResult::Retry(e)) => { if !self.should_retry(&e.code) { @@ -164,6 +167,6 @@ } } } - }) + } } }