GrpcStore Write Retry
The current implementation of retry in GrpcStore is awkward and only allows
retry up until the first call to the WriteRequestStreamWrapper; however, that
wrapper already buffers the first message.  Therefore, with a bit of
refactoring we are able to retry up until the second message is requested by
Tonic without any degradation in performance.  This also has the added benefit
of allowing the interface to be refactored into a Stream.
chrisstaite-menlo committed Jan 30, 2024
1 parent d46180c commit 6f132dd
Showing 6 changed files with 254 additions and 95 deletions.
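The key idea is easiest to see in isolation. The sketch below is a minimal, hypothetical stand-in for the real WriteRequestStreamWrapper (the RetryableFirst name, its fields, and the String item type are invented for illustration): because the first message stays buffered, the stream can be reset and replayed after a failed attempt, but only until the consumer asks for a second message, which is exactly the retry window the commit message describes.

// Hypothetical sketch only -- not the nativelink implementation.  A stream
// that buffers its first item so it can be replayed after a failed attempt,
// but only until the consumer has requested a second item.
use std::pin::Pin;
use std::task::{Context, Poll};

use futures::stream::{iter, Stream, StreamExt};

struct RetryableFirst<S: Stream<Item = String> + Unpin> {
    first: Option<String>, // buffered copy of the first message
    replay_first: bool,    // set by reset(): hand the buffered message out again
    past_first: bool,      // true once a second message has been requested
    inner: S,
}

impl<S: Stream<Item = String> + Unpin> RetryableFirst<S> {
    // Retrying is only safe while nothing past the buffered first message
    // has been consumed.
    fn is_retryable(&self) -> bool {
        !self.past_first
    }

    // Rewind to the buffered first message for the next attempt.
    fn reset(&mut self) {
        self.replay_first = true;
    }
}

impl<S: Stream<Item = String> + Unpin> Stream for RetryableFirst<S> {
    type Item = String;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<String>> {
        let this = &mut *self;
        if this.replay_first {
            this.replay_first = false;
            return Poll::Ready(this.first.clone());
        }
        if this.first.is_none() {
            // First request: pull the message from the inner stream and keep a copy.
            return match Pin::new(&mut this.inner).poll_next(cx) {
                Poll::Ready(message) => {
                    this.first = message.clone();
                    Poll::Ready(message)
                }
                Poll::Pending => Poll::Pending,
            };
        }
        // Second and later requests: the retry window has closed.
        this.past_first = true;
        Pin::new(&mut this.inner).poll_next(cx)
    }
}

fn main() {
    futures::executor::block_on(async {
        let mut stream = RetryableFirst {
            first: None,
            replay_first: false,
            past_first: false,
            inner: iter(vec!["header".to_string(), "chunk".to_string()]),
        };
        assert_eq!(stream.next().await.as_deref(), Some("header"));
        assert!(stream.is_retryable()); // a failure here can still be retried
        stream.reset();
        assert_eq!(stream.next().await.as_deref(), Some("header")); // replayed
        assert_eq!(stream.next().await.as_deref(), Some("chunk"));
        assert!(!stream.is_retryable()); // past the buffered message: no retry
    });
}

In the real change the window is tracked by WriteRequestStreamWrapper's is_retryable() and reset() calls (used in grpc_store.rs below), and the items are WriteRequest messages rather than strings.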
6 changes: 3 additions & 3 deletions nativelink-service/src/bytestream_server.rs
@@ -396,11 +396,11 @@ impl ByteStreamServer {
// by counting the number of bytes sent from the client. If they send
// less than the amount they said they were going to send and then
// close the stream, we know there's a problem.
Ok(None) => return Err(make_input_err!("Client closed stream before sending all data")),
None => return Err(make_input_err!("Client closed stream before sending all data")),
// Code path for client stream error. Probably client disconnect.
Err(err) => return Err(err),
Some(Err(err)) => return Err(err),
// Code path for received chunk of data.
Ok(Some(write_request)) => write_request,
Some(Ok(write_request)) => write_request,
};

if write_request.write_offset < 0 {
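The changed match arms above follow from the wrapper now implementing Stream: its next() yields Option<Result<WriteRequest, Error>> instead of the old Result<Option<WriteRequest>, Error>, so the nesting of the arms inverts. A simplified stand-alone version of the new shape (the WriteRequest and Error aliases here are placeholders, not the real types):

use futures::stream::{Stream, StreamExt};

// Placeholder aliases for illustration only.
type WriteRequest = Vec<u8>;
type Error = String;

async fn next_request(
    stream: &mut (impl Stream<Item = Result<WriteRequest, Error>> + Unpin),
) -> Result<WriteRequest, Error> {
    match stream.next().await {
        // Stream ended before the client sent all of the promised data.
        None => Err("Client closed stream before sending all data".to_string()),
        // Client stream error, probably a disconnect.
        Some(Err(err)) => Err(err),
        // Received a chunk of data.
        Some(Ok(write_request)) => Ok(write_request),
    }
}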
2 changes: 1 addition & 1 deletion nativelink-service/tests/bytestream_server_test.rs
@@ -235,7 +235,7 @@ pub mod write_tests {
// Now disconnect our stream.
drop(tx);
let (result, _bs_server) = join_handle.await?;
assert!(result.is_ok(), "Expected success to be returned");
result?;
}
{
// Check to make sure our store recorded the data properly.
158 changes: 97 additions & 61 deletions nativelink-store/src/grpc_store.rs
@@ -15,6 +15,7 @@
use std::marker::Send;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;

use async_trait::async_trait;
@@ -35,6 +36,7 @@ use nativelink_proto::google::bytestream::{
};
use nativelink_util::buf_channel::{DropCloserReadHalf, DropCloserWriteHalf};
use nativelink_util::common::DigestInfo;
use nativelink_util::resource_info::ResourceInfo;
use nativelink_util::retry::{ExponentialBackoff, Retrier, RetryResult};
use nativelink_util::store_trait::{Store, UploadSizeInfo};
use nativelink_util::tls_utils;
@@ -90,6 +92,65 @@ impl Stream for FirstStream {
}
}

/// This structure wraps all of the information required to perform a write
/// request on the GrpcStore, it is used to allow a write to retry on failure.
struct WriteState<T, E>
where
T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static,
E: Into<Error> + 'static,
{
instance_name: String,
error: Option<Error>,
read_stream: WriteRequestStreamWrapper<T, E>,
client: ByteStreamClient<Channel>,
}

/// A wrapper around WriteState to allow it to be reclaimed from the underlying
/// write call in the case of failure.
struct WriteStateWrapper<T, E>
where
T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static,
E: Into<Error> + 'static,
{
shared_state: Arc<Mutex<WriteState<T, E>>>,
}

impl<T, E> Stream for WriteStateWrapper<T, E>
where
T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static,
E: Into<Error> + 'static,
{
type Item = WriteRequest;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
// This should be an uncontended lock since write was called.
let mut local_state = self.shared_state.lock();
let Poll::Ready(maybe_message) = Pin::new(&mut local_state.read_stream).poll_next(cx) else {
return Poll::Pending;
};
const IS_UPLOAD_TRUE: bool = true;
let result = match maybe_message {
Some(Ok(mut message)) => match ResourceInfo::new(&message.resource_name, IS_UPLOAD_TRUE) {
Ok(mut resource_name) => {
resource_name.instance_name = &local_state.instance_name;
message.resource_name = resource_name.to_string(IS_UPLOAD_TRUE);
Some(message)
}
Err(err) => {
error!("{err:?}");
None
}
},
Some(Err(err)) => {
local_state.error = Some(err);
None
}
None => None,
};
Poll::Ready(result)
}
}

impl GrpcStore {
pub async fn new(config: &nativelink_config::stores::GrpcStore) -> Result<Self, Error> {
let jitter_amt = config.retry.jitter;
@@ -305,80 +366,55 @@ impl GrpcStore {
"CAS operation on AC store"
);

struct LocalState<T, E>
where
T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static,
E: Into<Error> + 'static,
{
instance_name: String,
error: Mutex<Option<Error>>,
read_stream: Mutex<Option<WriteRequestStreamWrapper<T, E>>>,
client: ByteStreamClient<Channel>,
}

let local_state = Arc::new(LocalState {
let local_state = Arc::new(Mutex::new(WriteState {
instance_name: self.instance_name.clone(),
error: Mutex::new(None),
read_stream: Mutex::new(Some(stream)),
error: None,
read_stream: stream,
client: self.bytestream_client.clone(),
});
}));

let retry_config = self.get_retry_config();
let result = self
.retrier
.retry(
retry_config,
unfold(local_state, move |local_state| async move {
let stream = unfold((None, local_state.clone()), move |(stream, local_state)| async {
// Only consume the stream on the first request to read,
// then pass it for future requests in the unfold.
let mut stream = stream.or_else(|| local_state.read_stream.lock().take())?;
let maybe_message = stream.next().await;
if let Ok(maybe_message) = maybe_message {
if let Some(mut message) = maybe_message {
// `resource_name` pattern is: "{instance_name}/uploads/{uuid}/blobs/{hash}/{size}".
let first_slash_pos = match message.resource_name.find('/') {
Some(pos) => pos,
None => {
error!("{}", "Resource name should follow pattern {instance_name}/uploads/{uuid}/blobs/{hash}/{size}");
return None;
}
};
message.resource_name = format!(
"{}/{}",
&local_state.instance_name,
message.resource_name.get((first_slash_pos + 1)..).unwrap()
);
return Some((message, (Some(stream), local_state)));
let mut client = local_state.lock().client.clone();
// The client write may occur on a separate thread and
// therefore in order to share the state with it we have to
// wrap it in a Mutex and retrieve it after the write
// has completed. There is no way to get the value back
// from the client.
let result = client
.write(WriteStateWrapper {
shared_state: local_state.clone(),
})
.await;

// Get the state back from StateWrapper, this should be
// uncontended since write has returned.
let mut local_state_locked = local_state.lock();

let result = if let Some(err) = &local_state_locked.error {
// If there was an error with the stream, then don't
// retry.
RetryResult::Err(err.clone())
} else {
// On error determine whether it is possible to retry.
match result.err_tip(|| "in GrpcStore::write") {
Err(err) => {
if local_state_locked.read_stream.is_retryable() {
local_state_locked.read_stream.reset();
RetryResult::Retry(err)
} else {
RetryResult::Err(err.append("Retry is not possible"))
}
}
return None;
Ok(response) => RetryResult::Ok(response),
}
// TODO(allada) I'm sure there's a way to do this without a mutex, but rust can be super
// picky with borrowing through a stream await.
*local_state.error.lock() = Some(maybe_message.unwrap_err());
None
});

let result = local_state.client.clone()
.write(stream)
.await
.err_tip(|| "in GrpcStore::write");

// If the stream has been consumed, don't retry, but
// otherwise it's ok to try again.
let result = if local_state.read_stream.lock().is_some() {
result.map_or_else(RetryResult::Retry, RetryResult::Ok)
} else {
result.map_or_else(RetryResult::Err, RetryResult::Ok)
};

// If there was an error with the stream, then don't retry.
let result = if let Some(err) = local_state.error.lock().take() {
RetryResult::Err(err)
} else {
result
};

drop(local_state_locked);
Some((result, local_state))
}),
)
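The core of the new retry loop is that client.write(...) consumes its request stream by value, so the only way to learn what happened to the stream afterwards is to hand the client a thin wrapper around an Arc<Mutex<WriteState>> and lock the shared state again once the call returns. Below is a stripped-down, hypothetical model of that reclaim-after-call pattern (SharedState, StateWrapper, mock_write, and Attempt are mock stand-ins for the Tonic client and nativelink's retry types; std Mutex is used instead of the non-poisoning lock above):

use std::sync::{Arc, Mutex};

#[derive(Default)]
struct SharedState {
    stream_error: Option<String>, // set by the stream wrapper on a read failure
    messages_consumed: usize,     // how far the consumer got
}

// The consumer takes this by value; only the Arc keeps the state reachable.
struct StateWrapper {
    shared: Arc<Mutex<SharedState>>,
}

// Stand-in for ByteStreamClient::write, which consumes the request stream.
fn mock_write(wrapper: StateWrapper) -> Result<(), String> {
    wrapper.shared.lock().unwrap().messages_consumed += 1;
    Err("upstream hung up".to_string()) // pretend the RPC failed
}

enum Attempt {
    Done,
    Retry(String),
    Fail(String),
}

fn write_once(shared: &Arc<Mutex<SharedState>>) -> Attempt {
    let result = mock_write(StateWrapper { shared: Arc::clone(shared) });
    // The wrapper is gone, but the state survives behind the Arc.
    let state = shared.lock().unwrap();
    if let Some(err) = &state.stream_error {
        // The local read stream itself failed: retrying cannot help.
        return Attempt::Fail(err.clone());
    }
    match result {
        Ok(()) => Attempt::Done,
        // Retry only while just the replayable first message was consumed.
        Err(err) if state.messages_consumed <= 1 => Attempt::Retry(err),
        Err(err) => Attempt::Fail(err),
    }
}

fn main() {
    let shared = Arc::new(Mutex::new(SharedState::default()));
    match write_once(&shared) {
        Attempt::Retry(err) => println!("attempt failed, will retry: {err}"),
        Attempt::Done => println!("write finished"),
        Attempt::Fail(err) => println!("giving up: {err}"),
    }
}

In the real code the retry decision additionally goes through WriteRequestStreamWrapper::is_retryable() and reset(), and RetryResult::Retry feeds back into the store's Retrier and ExponentialBackoff configuration.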
24 changes: 23 additions & 1 deletion nativelink-util/src/resource_info.rs
@@ -84,6 +84,7 @@ pub struct ResourceInfo<'a> {
pub compressor: Option<&'a str>,
pub digest_function: Option<&'a str>,
pub hash: &'a str,
size: &'a str,
pub expected_size: usize,
pub optional_metadata: Option<&'a str>,
}
@@ -129,6 +130,25 @@ impl<'a> ResourceInfo<'a> {
}
Ok(output)
}

pub fn to_string(&self, is_upload: bool) -> String {
[
Some(self.instance_name),
is_upload.then_some("uploads"),
self.uuid,
Some(self.compressor.map_or("blobs", |_| "compressed-blobs")),
self.compressor,
self.digest_function,
Some(self.hash),
Some(self.size),
self.optional_metadata,
]
.into_iter()
.flatten()
.filter(|part| !part.is_empty())
.collect::<Vec<&str>>()
.join("/")
}
}

#[derive(Debug, PartialEq)]
@@ -177,8 +197,9 @@ fn recursive_parse<'a>(
output.compressor = Some(part);
*bytes_processed += part.len() + SLASH_SIZE;
return Ok(state);
} else {
return Err(make_input_err!("Expected compressor, got {part}"));
}
continue;
}
State::DigestFunction => {
state = State::Hash;
@@ -196,6 +217,7 @@
return Ok(State::Size);
}
State::Size => {
output.size = part;
output.expected_size = part
.parse::<usize>()
.map_err(|_| make_input_err!("Digest size_bytes was not convertible to usize. Got: {}", part))?;
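The new to_string reverses the parse, which is what lets WriteStateWrapper::poll_next above swap in the configured instance_name without the old find('/') string surgery. A hypothetical usage sketch follows; the nativelink_error::Error crate path and the public visibility of the instance_name field are assumptions inferred from the diff, not verified:

use nativelink_error::Error; // assumed crate path for the Error type used above
use nativelink_util::resource_info::ResourceInfo;

// Re-home an upload resource name onto a different instance name, the same way
// WriteStateWrapper::poll_next rewrites each WriteRequest above.
fn rewrite_instance_name<'a>(
    resource_name: &'a str,
    instance_name: &'a str,
) -> Result<String, Error> {
    const IS_UPLOAD: bool = true;
    let mut info = ResourceInfo::new(resource_name, IS_UPLOAD)?;
    info.instance_name = instance_name;
    // e.g. "old/uploads/<uuid>/blobs/<hash>/142" -> "<instance_name>/uploads/<uuid>/blobs/<hash>/142"
    Ok(info.to_string(IS_UPLOAD))
}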