From 42ad62597818551269289abf633176bf57d010f8 Mon Sep 17 00:00:00 2001
From: Blaise Bruer
Date: Fri, 11 Oct 2024 00:42:20 -0500
Subject: [PATCH] [Bug fix] Adds retry logic to redis store

Adds retry logic and configs for redis store. This will set up redis to
reconnect and retry commands if the connection to redis is lost.

closes #1266
---
 nativelink-config/src/stores.rs               |  65 +++
 nativelink-metric-collector/Cargo.toml        |   2 +-
 .../redis_store_awaited_action_db_test.rs     |  29 +-
 nativelink-store/src/default_store_factory.rs |   2 +-
 nativelink-store/src/redis_store.rs           | 503 ++++++++++--------
 nativelink-store/tests/redis_store_test.rs    | 265 +++++++--
 6 files changed, 598 insertions(+), 268 deletions(-)

diff --git a/nativelink-config/src/stores.rs b/nativelink-config/src/stores.rs
index fd4bcfb2f..9c4e21a87 100644
--- a/nativelink-config/src/stores.rs
+++ b/nativelink-config/src/stores.rs
@@ -924,6 +924,71 @@ pub struct RedisStore {
     /// Default: standard,
     #[serde(default)]
     pub mode: RedisMode,
+
+    /// When using the pubsub interface, this is the maximum number of items to keep
+    /// queued up before dropping old items.
+    ///
+    /// Default: 4096
+    #[serde(default, deserialize_with = "convert_numeric_with_shellexpand")]
+    pub broadcast_channel_capacity: usize,
+
+    /// The amount of time in milliseconds until the redis store considers the
+    /// command to be timed out. This will trigger a retry of the command and
+    /// potentially a reconnection to the redis server.
+    ///
+    /// Default: 10000 (10 seconds)
+    #[serde(default, deserialize_with = "convert_numeric_with_shellexpand")]
+    pub command_timeout_ms: u64,
+
+    /// The amount of time in milliseconds until the redis store considers the
+    /// connection to be unresponsive. This will trigger a reconnection to the
+    /// redis server.
+    ///
+    /// Default: 3000 (3 seconds)
+    #[serde(default, deserialize_with = "convert_numeric_with_shellexpand")]
+    pub connection_timeout_ms: u64,
+
+    /// The amount of data to read from the redis server at a time.
+    /// This is used to limit the amount of memory used when reading
+    /// large objects from the redis server as well as limiting the
+    /// amount of time a single read operation can take.
+    ///
+    /// IMPORTANT: If this value is too high, the `command_timeout_ms`
+    /// might be triggered if the latency to the redis server is too
+    /// high or the throughput is too low.
+    ///
+    /// Default: 64KiB
+    #[serde(default, deserialize_with = "convert_numeric_with_shellexpand")]
+    pub read_chunk_size: usize,
+
+    /// The number of connections to keep open to the redis server(s).
+    ///
+    /// Default: 3
+    #[serde(default, deserialize_with = "convert_numeric_with_shellexpand")]
+    pub connection_pool_size: usize,
+
+    /// The maximum number of upload chunks to allow per update.
+    /// This is used to limit the amount of memory used when uploading
+    /// large objects to the redis server. A good rule of thumb is to
+    /// think of the data as:
+    /// AVAIL_MEMORY / (read_chunk_size * max_chunk_uploads_per_update) = THEORETICAL_MAX_CONCURRENT_UPLOADS
+    /// (note: it is a good idea to divide AVAIL_MEMORY by ~10 to account for other memory usage)
+    ///
+    /// Default: 10
+    #[serde(default, deserialize_with = "convert_numeric_with_shellexpand")]
+    pub max_chunk_uploads_per_update: usize,
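The rule of thumb above is easier to see with numbers plugged in. A minimal sketch of the arithmetic, using an illustrative 1 GiB memory budget (not a default):

```rust
/// Capacity planning for the redis store upload path, following the rule of
/// thumb documented on `max_chunk_uploads_per_update` above.
fn theoretical_max_concurrent_uploads(
    avail_memory: usize,
    read_chunk_size: usize,
    max_chunk_uploads_per_update: usize,
) -> usize {
    avail_memory / (read_chunk_size * max_chunk_uploads_per_update)
}

fn main() {
    // Illustrative: 1 GiB available, divided by ~10 to account for other
    // memory usage, with the default 64 KiB chunks and 10 in-flight chunk
    // uploads per update.
    let budget = (1024 * 1024 * 1024) / 10;
    assert_eq!(theoretical_max_concurrent_uploads(budget, 64 * 1024, 10), 163);
}
```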
+
+    /// Retry configuration to use when a network request fails.
+    /// See the `Retry` struct for more information.
+    ///
+    /// Default: Retry {
+    ///   max_retries: 0, /* unlimited */
+    ///   delay: 0.1, /* 100ms */
+    ///   jitter: 0.5, /* 50% */
+    ///   retry_on_errors: None, /* not used in redis store */
+    /// }
+    #[serde(default)]
+    pub retry: Retry,
 }

 #[derive(Debug, Default, Deserialize, Serialize, Clone, PartialEq, Eq)]
diff --git a/nativelink-metric-collector/Cargo.toml b/nativelink-metric-collector/Cargo.toml
index de688c4ef..3fe2c628f 100644
--- a/nativelink-metric-collector/Cargo.toml
+++ b/nativelink-metric-collector/Cargo.toml
@@ -6,7 +6,7 @@ rust-version = "1.79.0"

 [dependencies]
 nativelink-metric = { path = "../nativelink-metric" }
-opentelemetry = { version = "0.24.0", default-features = false }
+opentelemetry = { version = "0.24.0", features = ["metrics"], default-features = false }
 parking_lot = "0.12.3"
 serde = { version = "1.0.210", default-features = false }
 tracing = { version = "0.1.40", default-features = false }
diff --git a/nativelink-scheduler/tests/redis_store_awaited_action_db_test.rs b/nativelink-scheduler/tests/redis_store_awaited_action_db_test.rs
index 419551534..00709cb2e 100644
--- a/nativelink-scheduler/tests/redis_store_awaited_action_db_test.rs
+++ b/nativelink-scheduler/tests/redis_store_awaited_action_db_test.rs
@@ -20,11 +20,12 @@ use std::time::{Duration, SystemTime};

 use bytes::Bytes;
 use fred::bytes_utils::string::Str;
+use fred::clients::SubscriberClient;
 use fred::error::{RedisError, RedisErrorKind};
 use fred::mocks::{MockCommand, Mocks};
-use fred::prelude::Builder;
-use fred::types::{RedisConfig, RedisValue};
-use mock_instant::thread_local::SystemTime as MockSystemTime;
+use fred::prelude::{Builder, RedisPool};
+use fred::types::{PerformanceConfig, RedisConfig, RedisValue};
+use mock_instant::global::SystemTime as MockSystemTime;
 use nativelink_error::Error;
 use nativelink_macro::nativelink_test;
 use nativelink_scheduler::awaited_action_db::{
@@ -46,6 +47,7 @@ const INSTANCE_NAME: &str = "instance_name";
 const TEMP_UUID: &str = "550e8400-e29b-41d4-a716-446655440000";
 const SCRIPT_VERSION: &str = "3e762c15";
 const VERSION_SCRIPT_HASH: &str = "fdf1152fd21705c8763752809b86b55c5d4511ce";
+const MAX_CHUNK_UPLOADS_PER_UPDATE: usize = 10;

 fn mock_uuid_generator() -> String {
     uuid::Uuid::parse_str(TEMP_UUID).unwrap().to_string()
@@ -145,6 +147,20 @@ impl Drop for MockRedisBackend {
     }
 }

+fn make_clients(mut builder: Builder) -> (RedisPool, SubscriberClient) {
+    const CONNECTION_POOL_SIZE: usize = 1;
+    let client_pool = builder
+        .set_performance_config(PerformanceConfig {
+            broadcast_channel_capacity: 4096,
+            ..Default::default()
+        })
+        .build_pool(CONNECTION_POOL_SIZE)
+        .unwrap();
+
+    let subscriber_client = builder.build_subscriber_client().unwrap();
+    (client_pool, subscriber_client)
+}
+
 #[nativelink_test]
 async fn add_action_smoke_test() -> Result<(), Error> {
     const CLIENT_OPERATION_ID: &str = "my_client_operation_id";
@@ -389,13 +405,16 @@ async fn add_action_smoke_test() -> Result<(), Error> {
         mocks: Some(Arc::clone(&mocks) as Arc<dyn Mocks>),
         ..Default::default()
     });
-
+    let (client_pool, subscriber_client) = make_clients(builder);
     Arc::new(
         RedisStore::new_from_builder_and_parts(
-            builder,
+            client_pool,
+            subscriber_client,
             Some(SUB_CHANNEL.into()),
             mock_uuid_generator,
             String::new(),
+            4064,
+            MAX_CHUNK_UPLOADS_PER_UPDATE,
         )
         .unwrap(),
     )
diff --git a/nativelink-store/src/default_store_factory.rs b/nativelink-store/src/default_store_factory.rs
index 60a5b1c31..b72b8e1a4 100644
--- a/nativelink-store/src/default_store_factory.rs
+++ 
b/nativelink-store/src/default_store_factory.rs @@ -53,7 +53,7 @@ pub fn store_factory<'a>( StoreConfig::experimental_s3_store(config) => { S3Store::new(config, SystemTime::now).await? } - StoreConfig::redis_store(config) => RedisStore::new(config)?, + StoreConfig::redis_store(config) => RedisStore::new(config.clone())?, StoreConfig::verify(config) => VerifyStore::new( config, store_factory(&config.backend, store_manager, None).await?, diff --git a/nativelink-store/src/redis_store.rs b/nativelink-store/src/redis_store.rs index 5a5d4bbe8..80db12341 100644 --- a/nativelink-store/src/redis_store.rs +++ b/nativelink-store/src/redis_store.rs @@ -21,14 +21,15 @@ use std::time::Duration; use async_trait::async_trait; use bytes::Bytes; use const_format::formatcp; -use fred::clients::{RedisClient, RedisPool, SubscriberClient}; +use fred::clients::{RedisPool, SubscriberClient}; use fred::interfaces::{ClientLike, KeysInterface, PubsubInterface}; use fred::prelude::{EventInterface, HashesInterface, RediSearchInterface}; use fred::types::{ Builder, ConnectionConfig, FtCreateOptions, PerformanceConfig, ReconnectPolicy, RedisConfig, - RedisKey, RedisMap, RedisValue, Script, SearchSchema, SearchSchemaKind, + RedisKey, RedisMap, RedisValue, Script, SearchSchema, SearchSchemaKind, UnresponsiveConfig, }; -use futures::{FutureExt, Stream, StreamExt}; +use futures::stream::FuturesUnordered; +use futures::{future, FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt}; use nativelink_config::stores::RedisMode; use nativelink_error::{make_err, make_input_err, Code, Error, ResultExt}; use nativelink_metric::MetricsComponent; @@ -51,9 +52,36 @@ use uuid::Uuid; use crate::cas_utils::is_zero_digest; use crate::redis_utils::ft_aggregate; -// TODO(caass): These (and other settings) should be made configurable via nativelink-config. -pub const READ_CHUNK_SIZE: usize = 64 * 1024; -const CONNECTION_POOL_SIZE: usize = 3; +/// The default size of the read chunk when reading data from Redis. +/// Note: If this changes it should be updated in the config documentation. +const DEFAULT_READ_CHUNK_SIZE: usize = 64 * 1024; + +/// The default size of the connection pool if not specified. +/// Note: If this changes it should be updated in the config documentation. +const DEFAULT_CONNECTION_POOL_SIZE: usize = 3; + +/// The default delay between retries if not specified. +/// Note: If this changes it should be updated in the config documentation. +const DEFAULT_RETRY_DELAY: f32 = 0.1; +/// The amount of jitter to add to the retry delay if not specified. +/// Note: If this changes it should be updated in the config documentation. +const DEFAULT_RETRY_JITTER: f32 = 0.5; + +/// The default maximum capacity of the broadcast channel if not specified. +/// Note: If this changes it should be updated in the config documentation. +const DEFAULT_BROADCAST_CHANNEL_CAPACITY: usize = 4096; + +/// The default connection timeout in milliseconds if not specified. +/// Note: If this changes it should be updated in the config documentation. +const DEFAULT_CONNECTION_TIMEOUT_MS: u64 = 3000; + +/// The default command timeout in milliseconds if not specified. +/// Note: If this changes it should be updated in the config documentation. +const DEFAULT_COMMAND_TIMEOUT_MS: u64 = 10_000; + +/// The default maximum number of chunk uploads per update. +/// Note: If this changes it should be updated in the config documentation. 
+const DEFAULT_MAX_CHUNK_UPLOADS_PER_UPDATE: usize = 10;

 #[allow(clippy::trivially_copy_pass_by_ref)]
 fn to_hex(value: &u32) -> String {
@@ -92,6 +120,15 @@ pub struct RedisStore {
     #[metric(help = "Prefix to append to all keys before sending to Redis")]
     key_prefix: String,

+    /// The amount of data to read from Redis at a time.
+    #[metric(help = "The amount of data to read from Redis at a time")]
+    read_chunk_size: usize,
+
+    /// The maximum number of chunk uploads per update.
+    /// This is used to limit the amount of memory used when uploading large objects.
+    #[metric(help = "The maximum number of chunk uploads per update")]
+    max_chunk_uploads_per_update: usize,
+
     /// Redis script used to update a value in redis if the version matches.
     /// This is done by incrementing the version number and then setting the new data
     /// only if the version number matches the existing version number.
@@ -103,7 +140,7 @@ pub struct RedisStore {

 impl RedisStore {
     /// Create a new `RedisStore` from the given configuration.
-    pub fn new(config: &nativelink_config::stores::RedisStore) -> Result<Arc<Self>, Error> {
+    pub fn new(mut config: nativelink_config::stores::RedisStore) -> Result<Arc<Self>, Error> {
         if config.addresses.is_empty() {
             return Err(make_err!(
                 Code::InvalidArgument,
@@ -125,48 +162,105 @@ impl RedisStore {
             )
         })?;

+        let reconnect_policy = {
+            if config.retry.delay == 0.0 {
+                config.retry.delay = DEFAULT_RETRY_DELAY;
+            }
+            if config.retry.jitter == 0.0 {
+                config.retry.jitter = DEFAULT_RETRY_JITTER;
+            }
+
+            let max_retries = u32::try_from(config.retry.max_retries)
+                .err_tip(|| "max_retries could not be converted to u32 in RedisStore::new")?;
+            let min_delay_ms = (config.retry.delay * 1000.0) as u32;
+            let max_delay_ms = 8000;
+            let jitter = (config.retry.jitter * config.retry.delay * 1000.0) as u32;
+
+            let mut reconnect_policy = ReconnectPolicy::new_exponential(
+                max_retries,  /* max_retries, 0 is unlimited */
+                min_delay_ms, /* min_delay */
+                max_delay_ms, /* max_delay */
+                2,            /* mult */
+            );
+            reconnect_policy.set_jitter(jitter);
+            reconnect_policy
+        };
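With the defaults wired in above (`delay` = 0.1s, multiplier 2, cap 8000ms), the reconnect schedule is straightforward to reason about. A minimal sketch of the resulting backoff curve, assuming fred's exponential policy computes roughly `min_delay * mult^(attempt - 1)` clamped to `max_delay` (the configured jitter, up to 50ms with the defaults, is added on top by the client and omitted here):

```rust
/// Approximate delay before the nth reconnect attempt under the exponential
/// policy built above. This mirrors the configured math, not fred internals.
fn nth_delay_ms(attempt: u32, min_delay_ms: u64, max_delay_ms: u64, mult: u64) -> u64 {
    min_delay_ms
        .saturating_mul(mult.saturating_pow(attempt.saturating_sub(1)))
        .min(max_delay_ms)
}

fn main() {
    // Defaults: 100ms base delay, multiplier 2, capped at 8000ms.
    let schedule: Vec<u64> = (1..=8).map(|n| nth_delay_ms(n, 100, 8000, 2)).collect();
    assert_eq!(schedule, [100, 200, 400, 800, 1600, 3200, 6400, 8000]);
}
```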
+        {
+            if config.broadcast_channel_capacity == 0 {
+                config.broadcast_channel_capacity = DEFAULT_BROADCAST_CHANNEL_CAPACITY;
+            }
+            if config.connection_timeout_ms == 0 {
+                config.connection_timeout_ms = DEFAULT_CONNECTION_TIMEOUT_MS;
+            }
+            if config.command_timeout_ms == 0 {
+                config.command_timeout_ms = DEFAULT_COMMAND_TIMEOUT_MS;
+            }
+            if config.connection_pool_size == 0 {
+                config.connection_pool_size = DEFAULT_CONNECTION_POOL_SIZE;
+            }
+            if config.read_chunk_size == 0 {
+                config.read_chunk_size = DEFAULT_READ_CHUNK_SIZE;
+            }
+            if config.max_chunk_uploads_per_update == 0 {
+                config.max_chunk_uploads_per_update = DEFAULT_MAX_CHUNK_UPLOADS_PER_UPDATE;
+            }
+        }
+        let connection_timeout = Duration::from_millis(config.connection_timeout_ms);
+        let command_timeout = Duration::from_millis(config.command_timeout_ms);
+
         let mut builder = Builder::from_config(redis_config);
         builder
             .set_performance_config(PerformanceConfig {
-                default_command_timeout: Duration::from_secs(config.response_timeout_s),
+                default_command_timeout: command_timeout,
+                broadcast_channel_capacity: config.broadcast_channel_capacity,
                 ..Default::default()
             })
             .set_connection_config(ConnectionConfig {
-                connection_timeout: Duration::from_secs(config.connection_timeout_s),
-                internal_command_timeout: Duration::from_secs(config.response_timeout_s),
+                connection_timeout,
+                internal_command_timeout: command_timeout,
+                unresponsive: UnresponsiveConfig {
+                    max_timeout: Some(connection_timeout),
+                    // This interval needs to be less than the connection timeout.
+                    // We divide by 4 as it is a good balance between not spamming
+                    // the server and not waiting too long.
+                    interval: connection_timeout / 4,
+                },
                 ..Default::default()
             })
-            // TODO(caass): Make this configurable.
-            .set_policy(ReconnectPolicy::new_constant(1, 0));
+            .set_policy(reconnect_policy);
+
+        let client_pool = builder
+            .build_pool(config.connection_pool_size)
+            .err_tip(|| "while creating redis connection pool")?;
+
+        let subscriber_client = builder
+            .build_subscriber_client()
+            .err_tip(|| "while creating redis subscriber client")?;

         Self::new_from_builder_and_parts(
-            builder,
+            client_pool,
+            subscriber_client,
             config.experimental_pub_sub_channel.clone(),
             || Uuid::new_v4().to_string(),
             config.key_prefix.clone(),
+            config.read_chunk_size,
+            config.max_chunk_uploads_per_update,
         )
         .map(Arc::new)
     }

     /// Used for testing when determinism is required.
     pub fn new_from_builder_and_parts(
-        mut builder: Builder,
+        client_pool: RedisPool,
+        subscriber_client: SubscriberClient,
         pub_sub_channel: Option<String>,
         temp_name_generator_fn: fn() -> String,
         key_prefix: String,
+        read_chunk_size: usize,
+        max_chunk_uploads_per_update: usize,
     ) -> Result<Self, Error> {
-        let client_pool = builder
-            .set_performance_config(PerformanceConfig {
-                broadcast_channel_capacity: 4096,
-                ..Default::default()
-            })
-            .build_pool(CONNECTION_POOL_SIZE)
-            .err_tip(|| "while creating redis connection pool")?;
-
-        let subscriber_client = builder
-            .build_subscriber_client()
-            .err_tip(|| "while creating redis subscriber client")?;
-
-        // Fires off a background task using `tokio::spawn`.
+        // Start connection pool (this will retry forever by default).
         client_pool.connect();
         subscriber_client.connect();

@@ -177,6 +271,8 @@ impl RedisStore {
             fingerprint_create_index: fingerprint_create_index_template(),
             temp_name_generator_fn,
             key_prefix,
+            read_chunk_size,
+            max_chunk_uploads_per_update,
             update_if_version_matches_script: Script::from_lua(LUA_VERSION_SET_SCRIPT),
             subscription_manager: Mutex::new(None),
         })
@@ -204,10 +300,6 @@ impl RedisStore {
             }
         }
     }
-
-    pub fn get_client(&self) -> RedisClient {
-        self.client_pool.next().clone()
-    }
 }

 #[async_trait]
 impl StoreDriver for RedisStore {
     async fn has_with_results(
         &self,
         keys: &[StoreKey<'_>],
         results: &mut [Option<u64>],
     ) -> Result<(), Error> {
-        // TODO(caass): Optimize for the case where `keys.len() == 1`
-        let pipeline = self.client_pool.next().pipeline();
-
-        results.iter_mut().for_each(|result| *result = None);
-
-        for (idx, key) in keys.iter().enumerate() {
-            // Don't bother with zero-length digests.
-            if is_zero_digest(key.borrow()) {
-                results[idx] = Some(0);
-                continue;
-            }
-
-            let encoded_key = self.encode_key(key);
-
-            // This command is queued in memory, but not yet sent down the pipeline; the `await` returns instantly.
-            pipeline
-                .strlen::<(), _>(encoded_key.as_ref())
-                .await
-                .err_tip(|| "In RedisStore::has_with_results")?;
-        }
-
-        // Send the queued commands.
-        let mut responses = pipeline.all::<Vec<u64>>().await?.into_iter();
-
-        let mut remaining_results = results.iter_mut().filter(|option| {
-            // Anything that's `Some` was already set from `is_zero_digest`.
-            option.is_none()
-        });
-
-        // Similar to `Iterator::zip`, but with some verification at the end that the lengths were equal.
-        while let (Some(response), Some(result_slot)) = (responses.next(), remaining_results.next())
-        {
-            if response == 0 {
-                // Redis returns 0 when the key doesn't exist AND when the key exists with value of length 0.
-                // Since we already checked zero-lengths with `is_zero_digest`, this means the value doesn't exist.
-                continue;
-            }
+        // TODO(allada) We could use pipeline here, but it makes retry more
+        // difficult and it doesn't work very well in cluster mode.
+        // If we wanted to optimize this with pipeline be careful to
+        // implement retry and to support cluster mode.
+        let client = self.client_pool.next();
+        keys.iter()
+            .zip(results.iter_mut())
+            .map(|(key, result)| async move {
+                // We need to do a special pass to ensure our zero digest keys exist.
+                if is_zero_digest(key.borrow()) {
+                    *result = Some(0);
+                    return Ok::<_, Error>(());
+                }
+                let encoded_key = self.encode_key(key);
+                let pipeline = client.pipeline();
+                pipeline
+                    .strlen::<(), _>(encoded_key.as_ref())
+                    .await
+                    .err_tip(|| {
+                        format!("In RedisStore::has_with_results::strlen for {encoded_key}")
+                    })?;
+                // Redis returns 0 when the key doesn't exist
+                // AND when the key exists with value of length 0.
+                // Therefore, we need to check both length and existence
+                // and do it in a pipeline for efficiency.
+                pipeline
+                    .exists::<(), _>(encoded_key.as_ref())
+                    .await
+                    .err_tip(|| {
+                        format!("In RedisStore::has_with_results::exists for {encoded_key}")
+                    })?;
+                let (blob_len, exists) = pipeline
+                    .all::<(u64, bool)>()
+                    .await
+                    .err_tip(|| "In RedisStore::has_with_results::query")?;

-            *result_slot = Some(response);
-        }
+                *result = if exists { Some(blob_len) } else { None };

-        if responses.next().is_some() {
-            Err(make_err!(
-                Code::Internal,
-                "Received more responses than expected in RedisStore::has_with_results"
-            ))
-        } else if remaining_results.next().is_some() {
-            Err(make_err!(
-                Code::Internal,
-                "Received fewer responses than expected in RedisStore::has_with_results"
-            ))
-        } else {
-            Ok(())
-        }
+                Ok::<_, Error>(())
+            })
+            .collect::<FuturesUnordered<_>>()
+            .try_collect()
+            .await
     }
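The STRLEN/EXISTS pairing above is necessary because STRLEN alone is ambiguous: Redis answers 0 both for a missing key and for an existing key whose value is empty. A standalone sketch of the decision made once the pipelined pair returns (plain Rust, no Redis client involved):

```rust
/// Interpret the (STRLEN, EXISTS) pair returned by the pipelined query in
/// `has_with_results` above.
fn interpret(blob_len: u64, exists: bool) -> Option<u64> {
    if exists { Some(blob_len) } else { None }
}

fn main() {
    assert_eq!(interpret(123, true), Some(123)); // key present, 123 bytes
    assert_eq!(interpret(0, true), Some(0));     // key present but empty
    assert_eq!(interpret(0, false), None);       // key missing entirely
}
```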

     async fn update(
@@ -294,112 +376,84 @@ impl StoreDriver for RedisStore {
             &final_key
         );

-        let client = self.client_pool.next();
-
-        // This loop is a little confusing at first glance, but essentially the process is:
-        // - Get as much data from the reader as possible
-        // - When the reader is empty, but the writer isn't done sending data, write that data to redis
-        // - When the writer is done sending data, write the data and break from the loop
-        //
-        // At one extreme, we could append data in redis every time we read some bytes -- that is, make one TCP request
-        // per channel read. This is wasteful since we anticipate reading many small chunks of bytes from the reader.
-        //
-        // At the other extreme, we could make a single TCP request to write all of the data all at once.
-        // This could also be an issue if we read loads of data, since we'd send one massive TCP request
-        // rather than a few moderately-sized requests.
-        //
-        // To compromise, we buffer opportunistically -- when the reader doesn't have any data ready to read, but it's
-        // not done getting data, we flush the data we _have_ read to redis before waiting for the reader to get more.
-        //
-        // As a result of this, there will be a span of time where a key in Redis has only partial data. We want other
-        // observers to notice atomic updates to keys, rather than partial updates, so we first write to a temporary key
-        // and then rename that key once we're done appending data.
-        let mut is_first_chunk = true;
-        let mut eof_reached = false;
-        let mut pipe = client.pipeline();
-
-        while !eof_reached {
-            pipe = client.pipeline();
-            let mut pipe_size = 0;
-            const MAX_PIPE_SIZE: usize = 5 * 1024 * 1024; // 5 MB
-
+        if is_zero_digest(key.borrow()) {
             let chunk = reader
-                .recv()
+                .peek()
                 .await
-                .err_tip(|| "Failed to reach chunk in update in redis store")?;
-
+                .err_tip(|| "Failed to peek in RedisStore::update")?;
             if chunk.is_empty() {
-                // There are three cases where we receive an empty chunk:
-                // 1. The first chunk of a zero-digest key. We're required to treat all zero-digest keys as if they exist
-                //    and are empty per the RBE spec, so we can just return early as if we've pushed it -- any attempts to
-                //    read the value later will similarly avoid the network trip.
-                // 2. This is an empty first chunk of a non-zero-digest key. In this case, we _do_ need to push up an
-                //    empty key, but can skip the rest of the process around renaming since there's only the one operation.
-                // 3. This is the last chunk (EOF) of a regular key. In that case we can skip pushing this chunk.
-                //
-                // In all three cases, we're done pushing data and can move it from the temporary key to the final key.
-                if is_first_chunk && is_zero_digest(key.borrow()) {
-                    // Case 1, a zero-digest key.
-                    return Ok(());
-                } else if is_first_chunk {
-                    // Case 2, an empty non-zero-digest key.
-                    pipe.append::<(), _, _>(&temp_key, "")
-                        .await
-                        .err_tip(|| "While appending to temp key in RedisStore::update")?;
-                };
-
-                // Note: setting `eof_reached = true` and calling `continue` is semantically equivalent to `break`.
-                // Since we need to use the `eof_reached` flag in the inner loop, we do the same here
-                // for consistency.
-                eof_reached = true;
-                continue;
-            } else {
-                // Not EOF, but we've now received our first chunk.
-                is_first_chunk = false;
+                reader
+                    .drain()
+                    .await
+                    .err_tip(|| "Failed to drain in RedisStore::update")?;
+                // Zero-digest keys are special -- we don't need to do anything with them.
+                return Ok(());
             }
+        }

-            // Queue the append, but don't execute until we've received all the chunks.
-            pipe_size += chunk.len();
-            pipe.append::<(), _, _>(&temp_key, chunk)
-                .await
-                .err_tip(|| "Failed to append to temp key in RedisStore::update")?;
-
-            // Opportunistically grab any other chunks already in the reader.
-            while let Some(chunk) = reader
-                .try_recv()
-                .transpose()
-                .err_tip(|| "Failed to reach chunk in update in redis store")?
-            {
-                if chunk.is_empty() {
-                    eof_reached = true;
-                    break;
-                } else {
-                    pipe_size += chunk.len();
-                    pipe.append::<(), _, _>(&temp_key, chunk)
+        let client = self.client_pool.next();
+
+        let mut read_stream = reader
+            .scan(0u32, |bytes_read, chunk_res| {
+                future::ready(Some(
+                    chunk_res
+                        .err_tip(|| "Failed to read chunk in update in redis store")
+                        .and_then(|chunk| {
+                            let offset = *bytes_read;
+                            let chunk_len = u32::try_from(chunk.len()).err_tip(|| {
+                                "Could not convert chunk length to u32 in RedisStore::update"
+                            })?;
+                            let new_bytes_read = bytes_read
+                                .checked_add(chunk_len)
+                                .err_tip(|| "Overflow protection in RedisStore::update")?;
+                            *bytes_read = new_bytes_read;
+                            Ok::<_, Error>((offset, *bytes_read, chunk))
+                        }),
+                ))
+            })
+            .map(|res| {
+                let (offset, end_pos, chunk) = res?;
+                let temp_key_ref = &temp_key;
+                Ok(async move {
+                    client
+                        .setrange::<(), _, _>(temp_key_ref, offset, chunk)
                         .await
-                        .err_tip(|| "Failed to append to temp key in RedisStore::update")?;
-                }
+                        .err_tip(|| {
+                            "While appending to temp key in RedisStore::update"
+                        })?;
+                    Ok::<u32, Error>(end_pos)
+                })
+            })
+            .try_buffer_unordered(self.max_chunk_uploads_per_update);

-                // Stop appending if the pipeline is already holding 5MB of data.
-                if pipe_size >= MAX_PIPE_SIZE {
-                    break;
-                }
+        let mut total_len: u32 = 0;
+        while let Some(last_pos) = read_stream.try_next().await? {
+            if last_pos > total_len {
+                total_len = last_pos;
             }
+        }

-            // We've exhausted the reader (or hit the 5MB cap), but more data is expected.
-            // Executing the queued commands appends the data we just received to the temp key.
-            pipe.all::<()>()
-                .await
-                .err_tip(|| "Failed to append to temporary key in RedisStore::update")?;
+        let blob_len = client
+            .strlen::<u64, _>(&temp_key)
+            .await
+            .err_tip(|| format!("In RedisStore::update strlen check for {temp_key}"))?;
+        // This is a safety check to ensure that in the event some kind of retry was to happen
+        // and the data was appended to the key twice, we reject the data.
+        if blob_len != u64::from(total_len) {
+            return Err(make_input_err!(
+                "Data length mismatch in RedisStore::update for {}({}) - expected {} bytes, got {} bytes",
+                key.borrow().as_str(),
+                temp_key,
+                total_len,
+                blob_len,
+            ));
         }

         // Rename the temp key so that the data appears under the real key. Any data already present in the real key is lost.
-        pipe.rename::<(), _, _>(&temp_key, final_key.as_ref())
+        client
+            .rename::<(), _, _>(&temp_key, final_key.as_ref())
             .await
             .err_tip(|| "While queueing key rename in RedisStore::update()")?;
-        pipe.all::<()>()
-            .await
-            .err_tip(|| "While renaming key in RedisStore::update()")?;

         // If we have a publish channel configured, send a notice that the key has been set.
         if let Some(pub_sub_channel) = &self.pub_sub_channel {
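Because every chunk carries its own absolute offset, the `SETRANGE` writes above can complete in any order (up to `max_chunk_uploads_per_update` in flight) and still assemble the value correctly, while the final `STRLEN` comparison rejects the case where a retry appended data twice. A minimal sketch of that bookkeeping, with an in-memory buffer standing in for Redis (the `setrange` helper is a hypothetical stand-in, not part of the store):

```rust
/// In-memory stand-in for Redis SETRANGE: writes `chunk` at `offset`,
/// zero-padding the buffer if it needs to grow first.
fn setrange(buf: &mut Vec<u8>, offset: usize, chunk: &[u8]) {
    if buf.len() < offset + chunk.len() {
        buf.resize(offset + chunk.len(), 0);
    }
    buf[offset..offset + chunk.len()].copy_from_slice(chunk);
}

fn main() {
    let mut blob = Vec::new();
    // Chunks may land out of order under `try_buffer_unordered`.
    setrange(&mut blob, 5, b" world"); // the second chunk finishes first
    setrange(&mut blob, 0, b"hello");
    assert_eq!(blob, b"hello world".to_vec());

    // The safety check: bytes streamed must equal what Redis reports via
    // STRLEN, otherwise a retried write appended duplicate data.
    let total_len = "hello world".len() as u64;
    assert_eq!(blob.len() as u64, total_len);
}
```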
@@ -442,9 +496,12 @@
             .saturating_add(length.unwrap_or(isize::MAX as usize))
             .saturating_sub(1);

-        // And we don't ever want to read more than `READ_CHUNK_SIZE` bytes at a time, so we'll need to iterate.
+        // And we don't ever want to read more than `read_chunk_size` bytes at a time, so we'll need to iterate.
         let mut chunk_start = data_start;
-        let mut chunk_end = cmp::min(data_start.saturating_add(READ_CHUNK_SIZE) - 1, data_end);
+        let mut chunk_end = cmp::min(
+            data_start.saturating_add(self.read_chunk_size) - 1,
+            data_end,
+        );

         loop {
             let chunk: Bytes = client
@@ -452,7 +509,7 @@
                 .await
                 .err_tip(|| "In RedisStore::get_part::getrange")?;

-            let didnt_receive_full_chunk = chunk.len() < READ_CHUNK_SIZE;
+            let didnt_receive_full_chunk = chunk.len() < self.read_chunk_size;
             let reached_end_of_data = chunk_end == data_end;

             if didnt_receive_full_chunk || reached_end_of_data {
@@ -474,7 +531,10 @@
             // ...and go grab the next chunk.
             chunk_start = chunk_end + 1;
-            chunk_end = cmp::min(chunk_start.saturating_add(READ_CHUNK_SIZE) - 1, data_end);
+            chunk_end = cmp::min(
+                chunk_start.saturating_add(self.read_chunk_size) - 1,
+                data_end,
+            );
         }

         // If we didn't write any data, check if the key exists, if not return a NotFound error.
@@ -1007,57 +1067,54 @@ impl SchedulerStore for RedisStore {
                 .await
             })
         };
-        let stream = if let Ok(stream) = run_ft_aggregate()?.await {
-            stream
-        } else {
-            let create_result = self
-                .client_pool
-                .next()
-                .ft_create::<(), _>(
-                    format!("{}", get_index_name!(K::KEY_PREFIX, K::INDEX_NAME)),
-                    FtCreateOptions {
-                        on: Some(fred::types::IndexKind::Hash),
-                        prefixes: vec![K::KEY_PREFIX.into()],
-                        nohl: true,
-                        nofields: true,
-                        nofreqs: true,
-                        nooffsets: true,
-                        temporary: Some(INDEX_TTL_S),
-                        ..Default::default()
-                    },
-                    vec![SearchSchema {
-                        field_name: K::INDEX_NAME.into(),
-                        alias: None,
-                        kind: SearchSchemaKind::Tag {
-                            sortable: true,
-                            unf: false,
-                            separator: None,
-                            casesensitive: false,
-                            withsuffixtrie: false,
-                            noindex: false,
+        let stream = run_ft_aggregate()?
+            .or_else(|_| async move {
+                let create_result = self
+                    .client_pool
+                    .next()
+                    .ft_create::<(), _>(
+                        format!("{}", get_index_name!(K::KEY_PREFIX, K::INDEX_NAME)),
+                        FtCreateOptions {
+                            on: Some(fred::types::IndexKind::Hash),
+                            prefixes: vec![K::KEY_PREFIX.into()],
+                            nohl: true,
+                            nofields: true,
+                            nofreqs: true,
+                            nooffsets: true,
+                            temporary: Some(INDEX_TTL_S),
+                            ..Default::default()
                         },
-                    }],
-                )
-                .await
-                .err_tip(|| {
+                        vec![SearchSchema {
+                            field_name: K::INDEX_NAME.into(),
+                            alias: None,
+                            kind: SearchSchemaKind::Tag {
+                                sortable: true,
+                                unf: false,
+                                separator: None,
+                                casesensitive: false,
+                                withsuffixtrie: false,
+                                noindex: false,
+                            },
+                        }],
+                    )
+                    .await
+                    .err_tip(|| {
+                        format!(
+                            "Error with ft_create in RedisStore::search_by_index_prefix({})",
+                            get_index_name!(K::KEY_PREFIX, K::INDEX_NAME),
+                        )
+                    });
+                let run_result = run_ft_aggregate()?.await.err_tip(|| {
                     format!(
-                        "Error with ft_create in RedisStore::search_by_index_prefix({})",
+                        "Error with second ft_aggregate in RedisStore::search_by_index_prefix({})",
                         get_index_name!(K::KEY_PREFIX, K::INDEX_NAME),
                     )
                 });
-            let run_result = run_ft_aggregate()?.await.err_tip(|| {
-                format!(
-                    "Error with second ft_aggregate in RedisStore::search_by_index_prefix({})",
-                    get_index_name!(K::KEY_PREFIX, K::INDEX_NAME),
-                )
-            });
-            // Creating the index will race which is ok. If it fails to create, we only
-            // error if the second ft_aggregate call fails and fails to create.
-            match run_result {
-                Ok(stream) => stream,
-                Err(e) => return create_result.merge(Err(e)),
-            }
-        };
+                // Creating the index will race which is ok. If it fails to create, we only
+                // error if the second ft_aggregate call also fails; its error is merged with the create error.
+ run_result.or_else(move |e| create_result.merge(Err(e))) + }) + .await?; Ok(stream.map(|result| { let mut redis_map = result.err_tip(|| "Error in stream of in RedisStore::search_by_index_prefix")?; diff --git a/nativelink-store/tests/redis_store_test.rs b/nativelink-store/tests/redis_store_test.rs index 438f5224f..8432d0bda 100644 --- a/nativelink-store/tests/redis_store_test.rs +++ b/nativelink-store/tests/redis_store_test.rs @@ -19,16 +19,17 @@ use std::thread::panicking; use bytes::{Bytes, BytesMut}; use fred::bytes_utils::string::Str; +use fred::clients::SubscriberClient; use fred::error::RedisError; use fred::mocks::{MockCommand, Mocks}; -use fred::prelude::Builder; -use fred::types::{RedisConfig, RedisValue}; +use fred::prelude::{Builder, RedisPool}; +use fred::types::{PerformanceConfig, RedisConfig, RedisValue}; use nativelink_error::{Code, Error}; use nativelink_macro::nativelink_test; use nativelink_metric::{MetricFieldData, MetricKind, MetricsComponent, RootMetricsComponent}; use nativelink_metric_collector::MetricsCollectorLayer; use nativelink_store::cas_utils::ZERO_BYTE_DIGESTS; -use nativelink_store::redis_store::{RedisStore, READ_CHUNK_SIZE}; +use nativelink_store::redis_store::RedisStore; use nativelink_store::store_manager::StoreManager; use nativelink_util::buf_channel::make_buf_channel_pair; use nativelink_util::common::DigestInfo; @@ -42,6 +43,9 @@ use tracing_subscriber::layer::SubscriberExt; const VALID_HASH1: &str = "3031323334353637383961626364656630303030303030303030303030303030"; const TEMP_UUID: &str = "550e8400-e29b-41d4-a716-446655440000"; +const DEFAULT_READ_CHUNK_SIZE: usize = 1024; +const DEFAULT_MAX_CHUNK_UPLOADS_PER_UPDATE: usize = 10; + fn mock_uuid_generator() -> String { uuid::Uuid::parse_str(TEMP_UUID).unwrap().to_string() } @@ -166,6 +170,20 @@ impl Drop for MockRedisBackend { } } +fn make_clients(mut builder: Builder) -> (RedisPool, SubscriberClient) { + const CONNECTION_POOL_SIZE: usize = 1; + let client_pool = builder + .set_performance_config(PerformanceConfig { + broadcast_channel_capacity: 4096, + ..Default::default() + }) + .build_pool(CONNECTION_POOL_SIZE) + .unwrap(); + + let subscriber_client = builder.build_subscriber_client().unwrap(); + (client_pool, subscriber_client) +} + #[nativelink_test] async fn upload_and_get_data() -> Result<(), Error> { // Construct the data we want to send. Since it's small, we expect it to be sent in a single chunk. @@ -187,12 +205,22 @@ async fn upload_and_get_data() -> Result<(), Error> { // Append the real value to the temp key. .expect( MockCommand { - cmd: Str::from_static("APPEND"), + cmd: Str::from_static("SETRANGE"), subcommand: None, - args: vec![temp_key.clone(), chunk_data], + args: vec![temp_key.clone(), 0.into(), chunk_data], }, Ok(RedisValue::Array(vec![RedisValue::Null])), ) + .expect( + MockCommand { + cmd: Str::from_static("STRLEN"), + subcommand: None, + args: vec![temp_key.clone()], + }, + Ok(RedisValue::Array(vec![RedisValue::Integer( + data.len() as i64 + )])), + ) // Move the data from the fake key to the real key. .expect( MockCommand { @@ -214,6 +242,14 @@ async fn upload_and_get_data() -> Result<(), Error> { }, Ok(RedisValue::Integer(2)), ) + .expect( + MockCommand { + cmd: Str::from_static("EXISTS"), + subcommand: None, + args: vec![real_key.clone()], + }, + Ok(RedisValue::Integer(1)), + ) // Retrieve the data from the real key. 
        .expect(
            MockCommand {
                cmd: Str::from_static("GETRANGE"),
                subcommand: None,
                args: vec![real_key, RedisValue::Integer(0), RedisValue::Integer(1023)],
            },
            Ok(RedisValue::Bytes(data.clone())),
        );
@@ -230,8 +266,17 @@
             mocks: Some(Arc::clone(&mocks) as Arc<dyn Mocks>),
             ..Default::default()
         });
-
-        RedisStore::new_from_builder_and_parts(builder, None, mock_uuid_generator, String::new())?
+        let (client_pool, subscriber_client) = make_clients(builder);
+        RedisStore::new_from_builder_and_parts(
+            client_pool,
+            subscriber_client,
+            None,
+            mock_uuid_generator,
+            String::new(),
+            DEFAULT_READ_CHUNK_SIZE,
+            DEFAULT_MAX_CHUNK_UPLOADS_PER_UPDATE,
+        )
+        .unwrap()
     };

     store.update_oneshot(digest, data.clone()).await.unwrap();
@@ -269,12 +314,22 @@ async fn upload_and_get_data_with_prefix() -> Result<(), Error> {
     mocks
         .expect(
             MockCommand {
-                cmd: Str::from_static("APPEND"),
+                cmd: Str::from_static("SETRANGE"),
                 subcommand: None,
-                args: vec![temp_key.clone(), chunk_data],
+                args: vec![temp_key.clone(), 0.into(), chunk_data],
             },
             Ok(RedisValue::Array(vec![RedisValue::Null])),
         )
+        .expect(
+            MockCommand {
+                cmd: Str::from_static("STRLEN"),
+                subcommand: None,
+                args: vec![temp_key.clone()],
+            },
+            Ok(RedisValue::Array(vec![RedisValue::Integer(
+                data.len() as i64
+            )])),
+        )
         .expect(
             MockCommand {
                 cmd: Str::from_static("RENAME"),
@@ -291,6 +346,14 @@ async fn upload_and_get_data_with_prefix() -> Result<(), Error> {
             },
             Ok(RedisValue::Integer(2)),
         )
+        .expect(
+            MockCommand {
+                cmd: Str::from_static("EXISTS"),
+                subcommand: None,
+                args: vec![real_key.clone()],
+            },
+            Ok(RedisValue::Integer(1)),
+        )
         .expect(
             MockCommand {
                 cmd: Str::from_static("GETRANGE"),
@@ -307,12 +370,17 @@ async fn upload_and_get_data_with_prefix() -> Result<(), Error> {
             ..Default::default()
         });

+        let (client_pool, subscriber_client) = make_clients(builder);
         RedisStore::new_from_builder_and_parts(
-            builder,
+            client_pool,
+            subscriber_client,
             None,
             mock_uuid_generator,
             prefix.to_string(),
-        )?
+            DEFAULT_READ_CHUNK_SIZE,
+            DEFAULT_MAX_CHUNK_UPLOADS_PER_UPDATE,
+        )
+        .unwrap()
     };

     store.update_oneshot(digest, data.clone()).await.unwrap();
@@ -338,12 +406,16 @@ async fn upload_empty_data() -> Result<(), Error> {
     let data = Bytes::from_static(b"");
     let digest = ZERO_BYTE_DIGESTS[0];

+    let (client_pool, subscriber_client) = make_clients(Builder::default_centralized());
     // We expect to skip both uploading and downloading when the digest is known zero.
     let store = RedisStore::new_from_builder_and_parts(
-        Builder::default_centralized(),
+        client_pool,
+        subscriber_client,
         None,
         mock_uuid_generator,
         String::new(),
+        DEFAULT_READ_CHUNK_SIZE,
+        DEFAULT_MAX_CHUNK_UPLOADS_PER_UPDATE,
     )
     .unwrap();
@@ -364,11 +436,15 @@ async fn upload_empty_data_with_prefix() -> Result<(), Error> {
     let digest = ZERO_BYTE_DIGESTS[0];
     let prefix = "TEST_PREFIX-";

+    let (client_pool, subscriber_client) = make_clients(Builder::default_centralized());
     let store = RedisStore::new_from_builder_and_parts(
-        Builder::default_centralized(),
+        client_pool,
+        subscriber_client,
         None,
         mock_uuid_generator,
         prefix.to_string(),
+        DEFAULT_READ_CHUNK_SIZE,
+        DEFAULT_MAX_CHUNK_UPLOADS_PER_UPDATE,
     )
     .unwrap();
@@ -385,7 +461,7 @@

 #[nativelink_test]
 async fn test_large_downloads_are_chunked() -> Result<(), Error> {
-    // Requires multiple chunks as data is larger than 64K.
+ const READ_CHUNK_SIZE: usize = 1024; let data = Bytes::from(vec![0u8; READ_CHUNK_SIZE + 128]); let digest = DigestInfo::try_new(VALID_HASH1, 1)?; @@ -399,12 +475,22 @@ async fn test_large_downloads_are_chunked() -> Result<(), Error> { mocks .expect( MockCommand { - cmd: Str::from_static("APPEND"), + cmd: Str::from_static("SETRANGE"), subcommand: None, - args: vec![temp_key.clone(), data.clone().into()], + args: vec![temp_key.clone(), 0.into(), data.clone().into()], }, Ok(RedisValue::Array(vec![RedisValue::Null])), ) + .expect( + MockCommand { + cmd: Str::from_static("STRLEN"), + subcommand: None, + args: vec![temp_key.clone()], + }, + Ok(RedisValue::Array(vec![RedisValue::Integer( + data.len() as i64 + )])), + ) .expect( MockCommand { cmd: Str::from_static("RENAME"), @@ -421,6 +507,14 @@ async fn test_large_downloads_are_chunked() -> Result<(), Error> { }, Ok(RedisValue::Integer(data.len().try_into().unwrap())), ) + .expect( + MockCommand { + cmd: Str::from_static("EXISTS"), + subcommand: None, + args: vec![real_key.clone()], + }, + Ok(RedisValue::Integer(1)), + ) .expect( MockCommand { cmd: Str::from_static("GETRANGE"), @@ -456,7 +550,17 @@ async fn test_large_downloads_are_chunked() -> Result<(), Error> { ..Default::default() }); - RedisStore::new_from_builder_and_parts(builder, None, mock_uuid_generator, String::new())? + let (client_pool, subscriber_client) = make_clients(builder); + RedisStore::new_from_builder_and_parts( + client_pool, + subscriber_client, + None, + mock_uuid_generator, + String::new(), + READ_CHUNK_SIZE, + DEFAULT_MAX_CHUNK_UPLOADS_PER_UPDATE, + ) + .unwrap() }; store.update_oneshot(digest, data.clone()).await.unwrap(); @@ -473,8 +577,7 @@ async fn test_large_downloads_are_chunked() -> Result<(), Error> { .unwrap(); assert_eq!( - get_result, - data.clone(), + get_result, data, "Expected redis store to have updated value", ); @@ -483,8 +586,8 @@ async fn test_large_downloads_are_chunked() -> Result<(), Error> { #[nativelink_test] async fn yield_between_sending_packets_in_update() -> Result<(), Error> { - let data_p1 = Bytes::from(vec![b'A'; 6 * 1024]); - let data_p2 = Bytes::from(vec![b'B'; 4 * 1024]); + let data_p1 = Bytes::from(vec![b'A'; DEFAULT_READ_CHUNK_SIZE + 512]); + let data_p2 = Bytes::from(vec![b'B'; DEFAULT_READ_CHUNK_SIZE]); let mut data = BytesMut::new(); data.extend_from_slice(&data_p1); @@ -499,26 +602,39 @@ async fn yield_between_sending_packets_in_update() -> Result<(), Error> { let mocks = Arc::new(MockRedisBackend::new()); let first_append = MockCommand { - cmd: Str::from_static("APPEND"), + cmd: Str::from_static("SETRANGE"), subcommand: None, - args: vec![temp_key.clone(), data_p1.clone().into()], + args: vec![temp_key.clone(), 0.into(), data_p1.clone().into()], }; mocks - // We expect multiple `"APPEND"`s as we send data in multiple chunks + // We expect multiple `"SETRANGE"`s as we send data in multiple chunks .expect( first_append.clone(), Ok(RedisValue::Array(vec![RedisValue::Null])), ) .expect( MockCommand { - cmd: Str::from_static("APPEND"), + cmd: Str::from_static("SETRANGE"), subcommand: None, - args: vec![temp_key.clone(), data_p2.clone().into()], + args: vec![ + temp_key.clone(), + data_p1.len().try_into().unwrap(), + data_p2.clone().into(), + ], }, Ok(RedisValue::Array(vec![RedisValue::Null])), ) - // The rest of the process looks the same. 
+ .expect( + MockCommand { + cmd: Str::from_static("STRLEN"), + subcommand: None, + args: vec![temp_key.clone()], + }, + Ok(RedisValue::Array(vec![RedisValue::Integer( + data.len() as i64 + )])), + ) .expect( MockCommand { cmd: Str::from_static("RENAME"), @@ -535,11 +651,47 @@ async fn yield_between_sending_packets_in_update() -> Result<(), Error> { }, Ok(RedisValue::Integer(2)), ) + .expect( + MockCommand { + cmd: Str::from_static("EXISTS"), + subcommand: None, + args: vec![real_key.clone()], + }, + Ok(RedisValue::Integer(1)), + ) .expect( MockCommand { cmd: Str::from_static("GETRANGE"), subcommand: None, - args: vec![real_key, RedisValue::Integer(0), RedisValue::Integer(10239)], + args: vec![ + real_key.clone(), + RedisValue::Integer(0), + RedisValue::Integer((DEFAULT_READ_CHUNK_SIZE - 1) as i64), + ], + }, + Ok(RedisValue::Bytes(data.clone())), + ) + .expect( + MockCommand { + cmd: Str::from_static("GETRANGE"), + subcommand: None, + args: vec![ + real_key.clone(), + RedisValue::Integer(DEFAULT_READ_CHUNK_SIZE as i64), + RedisValue::Integer((DEFAULT_READ_CHUNK_SIZE * 2 - 1) as i64), + ], + }, + Ok(RedisValue::Bytes(data.clone())), + ) + .expect( + MockCommand { + cmd: Str::from_static("GETRANGE"), + subcommand: None, + args: vec![ + real_key, + RedisValue::Integer((DEFAULT_READ_CHUNK_SIZE * 2) as i64), + RedisValue::Integer((data_p1.len() + data_p2.len() - 1) as i64), + ], }, Ok(RedisValue::Bytes(data.clone())), ); @@ -551,7 +703,17 @@ async fn yield_between_sending_packets_in_update() -> Result<(), Error> { ..Default::default() }); - RedisStore::new_from_builder_and_parts(builder, None, mock_uuid_generator, String::new())? + let (client_pool, subscriber_client) = make_clients(builder); + RedisStore::new_from_builder_and_parts( + client_pool, + subscriber_client, + None, + mock_uuid_generator, + String::new(), + DEFAULT_READ_CHUNK_SIZE, + DEFAULT_MAX_CHUNK_UPLOADS_PER_UPDATE, + ) + .unwrap() }; let (mut tx, rx) = make_buf_channel_pair(); @@ -610,7 +772,7 @@ async fn zero_len_items_exist_check() -> Result<(), Error> { RedisValue::Integer(0), // We expect to be asked for data from `0..READ_CHUNK_SIZE`, but since GETRANGE is inclusive // the actual call should be from `0..=(READ_CHUNK_SIZE - 1)`. - RedisValue::Integer(READ_CHUNK_SIZE as i64 - 1), + RedisValue::Integer(DEFAULT_READ_CHUNK_SIZE as i64 - 1), ], }, Ok(RedisValue::String(Str::from_static(""))), @@ -631,7 +793,17 @@ async fn zero_len_items_exist_check() -> Result<(), Error> { ..Default::default() }); - RedisStore::new_from_builder_and_parts(builder, None, mock_uuid_generator, String::new())? + let (client_pool, subscriber_client) = make_clients(builder); + RedisStore::new_from_builder_and_parts( + client_pool, + subscriber_client, + None, + mock_uuid_generator, + String::new(), + DEFAULT_READ_CHUNK_SIZE, + DEFAULT_MAX_CHUNK_UPLOADS_PER_UPDATE, + ) + .unwrap() }; let result = store.get_part_unchunked(digest, 0, None).await; @@ -650,7 +822,17 @@ async fn dont_loop_forever_on_empty() -> Result<(), Error> { ..Default::default() }); - RedisStore::new_from_builder_and_parts(builder, None, mock_uuid_generator, String::new())? 
+ let (client_pool, subscriber_client) = make_clients(builder); + RedisStore::new_from_builder_and_parts( + client_pool, + subscriber_client, + None, + mock_uuid_generator, + String::new(), + DEFAULT_READ_CHUNK_SIZE, + DEFAULT_MAX_CHUNK_UPLOADS_PER_UPDATE, + ) + .unwrap() }; let digest = DigestInfo::try_new(VALID_HASH1, 2).unwrap(); @@ -685,12 +867,19 @@ async fn test_redis_fingerprint_metric() -> Result<(), Error> { ..Default::default() }); - Store::new(Arc::new(RedisStore::new_from_builder_and_parts( - builder, - None, - mock_uuid_generator, - String::new(), - )?)) + let (client_pool, subscriber_client) = make_clients(builder); + Store::new(Arc::new( + RedisStore::new_from_builder_and_parts( + client_pool, + subscriber_client, + None, + mock_uuid_generator, + String::new(), + DEFAULT_READ_CHUNK_SIZE, + DEFAULT_MAX_CHUNK_UPLOADS_PER_UPDATE, + ) + .unwrap(), + )) }; store_manager.add_store("redis_store", store);