From 3917a9078f5ef3a4b5cb0077d3da6e4fdd251012 Mon Sep 17 00:00:00 2001 From: Blaise Bruer Date: Tue, 10 Sep 2024 00:37:01 -0500 Subject: [PATCH] [Feature] Add Redis Scheduler Adds an experimental redis scheduler that can be used as a distributed state-persistent scheduler backend. This scheduler is optimized to have each worker be its own scheduler or many small schedulers. closes: #359 --- Cargo.lock | 51 ++ nativelink-config/src/schedulers.rs | 10 + nativelink-scheduler/BUILD.bazel | 5 + nativelink-scheduler/Cargo.toml | 3 + .../src/awaited_action_db/awaited_action.rs | 14 +- .../src/default_scheduler_factory.rs | 40 +- .../src/memory_awaited_action_db.rs | 3 +- nativelink-scheduler/src/simple_scheduler.rs | 1 - .../src/store_awaited_action_db.rs | 16 +- .../redis_store_awaited_action_db_test.rs | 449 +++++++++++++ .../tests/simple_scheduler_test.rs | 3 +- nativelink-store/BUILD.bazel | 5 + nativelink-store/Cargo.toml | 11 +- nativelink-store/src/lib.rs | 1 + nativelink-store/src/redis_store.rs | 631 +++++++++++++++++- .../src/redis_utils/ft_aggregate.rs | 127 ++++ nativelink-store/src/redis_utils/mod.rs | 16 + nativelink-util/src/action_messages.rs | 7 +- nativelink-util/src/store_trait.rs | 2 + 19 files changed, 1362 insertions(+), 33 deletions(-) create mode 100644 nativelink-scheduler/tests/redis_store_awaited_action_db_test.rs create mode 100644 nativelink-store/src/redis_utils/ft_aggregate.rs create mode 100644 nativelink-store/src/redis_utils/mod.rs diff --git a/Cargo.lock b/Cargo.lock index 1facda817..e56790fe9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -798,6 +798,26 @@ dependencies = [ "tracing-subscriber", ] +[[package]] +name = "const_format" +version = "0.2.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3a214c7af3d04997541b18d432afaff4c455e79e2029079647e72fc2bd27673" +dependencies = [ + "const_format_proc_macros", +] + +[[package]] +name = "const_format_proc_macros" +version = "0.2.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7f6ff08fd20f4f299298a28e2dfa8a8ba1036e6cd2460ac1de7b425d76f2500" +dependencies = [ + "proc-macro2", + "quote", + "unicode-xid", +] + [[package]] name = "constant_time_eq" version = "0.3.1" @@ -1058,6 +1078,7 @@ dependencies = [ "rustls 0.23.12", "rustls-native-certs 0.7.3", "semver", + "sha-1", "socket2", "tokio", "tokio-rustls 0.26.0", @@ -1830,6 +1851,7 @@ dependencies = [ "async-lock", "async-trait", "bytes", + "fred", "futures", "lru", "mock_instant", @@ -1905,6 +1927,8 @@ dependencies = [ "blake3", "byteorder", "bytes", + "bytes-utils", + "const_format", "filetime", "fred", "futures", @@ -1924,6 +1948,7 @@ dependencies = [ "nativelink-util", "once_cell", "parking_lot", + "patricia_tree", "pretty_assertions", "prost", "rand", @@ -2175,6 +2200,15 @@ dependencies = [ "windows-targets", ] +[[package]] +name = "patricia_tree" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31f2f4539bffe53fc4b4da301df49d114b845b077bd5727b7fe2bd9d8df2ae68" +dependencies = [ + "bitflags", +] + [[package]] name = "percent-encoding" version = "2.3.1" @@ -2850,6 +2884,17 @@ dependencies = [ "syn 2.0.77", ] +[[package]] +name = "sha-1" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f5058ada175748e33390e40e872bd0fe59a19f265d0158daa551c5a88a76009c" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sha1" version = "0.10.6" @@ -3349,6 +3394,12 @@ dependencies = [ "tinyvec", ] +[[package]] +name = "unicode-xid" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "229730647fbc343e3a80e463c1db7f78f3855d3f3739bee0dda773c9a037c90a" + [[package]] name = "untrusted" version = "0.9.0" diff --git a/nativelink-config/src/schedulers.rs b/nativelink-config/src/schedulers.rs index 2e027359e..bd5627da1 100644 --- a/nativelink-config/src/schedulers.rs +++ b/nativelink-config/src/schedulers.rs @@ -130,6 +130,16 @@ pub struct SimpleScheduler { pub enum ExperimentalSimpleSchedulerBackend { /// Use an in-memory store for the scheduler. memory, + /// Use a redis store for the scheduler. + redis(ExperimentalRedisSchedulerBackend), +} + +#[derive(Deserialize, Debug, Default)] +#[serde(deny_unknown_fields)] +pub struct ExperimentalRedisSchedulerBackend { + /// A reference to the redis store to use for the scheduler. + /// Note: This MUST resolve to a RedisStore. + pub redis_store: StoreRefName, } /// A scheduler that simply forwards requests to an upstream scheduler. This diff --git a/nativelink-scheduler/BUILD.bazel b/nativelink-scheduler/BUILD.bazel index 7e407768b..8759db281 100644 --- a/nativelink-scheduler/BUILD.bazel +++ b/nativelink-scheduler/BUILD.bazel @@ -60,6 +60,7 @@ rust_test_suite( "tests/action_messages_test.rs", "tests/cache_lookup_scheduler_test.rs", "tests/property_modifier_scheduler_test.rs", + "tests/redis_store_awaited_action_db_test.rs", "tests/simple_scheduler_test.rs", ], compile_data = [ @@ -79,10 +80,14 @@ rust_test_suite( "//nativelink-store", "//nativelink-util", "@crates//:async-lock", + "@crates//:bytes", + "@crates//:fred", "@crates//:futures", "@crates//:mock_instant", + "@crates//:parking_lot", "@crates//:pretty_assertions", "@crates//:prost", + "@crates//:serde_json", "@crates//:tokio", "@crates//:tokio-stream", "@crates//:uuid", diff --git a/nativelink-scheduler/Cargo.toml b/nativelink-scheduler/Cargo.toml index 1ffc4b181..cdf8e9877 100644 --- a/nativelink-scheduler/Cargo.toml +++ b/nativelink-scheduler/Cargo.toml @@ -35,3 +35,6 @@ static_assertions = "1.1.0" [dev-dependencies] nativelink-macro = { path = "../nativelink-macro" } pretty_assertions = { version = "1.4.0", features = ["std"] } +fred = { version = "9.1.2", default-features = false, features = [ + "mocks", +] } diff --git a/nativelink-scheduler/src/awaited_action_db/awaited_action.rs b/nativelink-scheduler/src/awaited_action_db/awaited_action.rs index ba25c0cd5..b3154c461 100644 --- a/nativelink-scheduler/src/awaited_action_db/awaited_action.rs +++ b/nativelink-scheduler/src/awaited_action_db/awaited_action.rs @@ -28,7 +28,7 @@ use static_assertions::{assert_eq_size, const_assert, const_assert_eq}; /// The version of the awaited action. /// This number will always increment by one each time /// the action is updated. -#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +#[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)] struct AwaitedActionVersion(u64); impl MetricsComponent for AwaitedActionVersion { @@ -80,7 +80,7 @@ pub struct AwaitedAction { } impl AwaitedAction { - pub fn new(operation_id: OperationId, action_info: Arc) -> Self { + pub fn new(operation_id: OperationId, action_info: Arc, now: SystemTime) -> Self { let stage = ActionStage::Queued; let sort_key = AwaitedActionSortKey::new_with_unique_key( action_info.priority, @@ -102,7 +102,7 @@ impl AwaitedAction { operation_id, sort_key, attempts: 0, - last_worker_updated_timestamp: SystemTime::now(), + last_worker_updated_timestamp: now, worker_id: None, state, } @@ -120,11 +120,11 @@ impl AwaitedAction { self.version = AwaitedActionVersion(self.version.0 + 1); } - pub(crate) fn action_info(&self) -> &Arc { + pub fn action_info(&self) -> &Arc { &self.action_info } - pub(crate) fn operation_id(&self) -> &OperationId { + pub fn operation_id(&self) -> &OperationId { &self.operation_id } @@ -132,7 +132,7 @@ impl AwaitedAction { self.sort_key } - pub(crate) fn state(&self) -> &Arc { + pub fn state(&self) -> &Arc { &self.state } @@ -158,7 +158,7 @@ impl AwaitedAction { /// Sets the current state of the action and notifies subscribers. /// Returns true if the state was set, false if there are no subscribers. - pub(crate) fn set_state(&mut self, mut state: Arc, now: Option) { + pub fn set_state(&mut self, mut state: Arc, now: Option) { std::mem::swap(&mut self.state, &mut state); if let Some(now) = now { self.keep_alive(now); diff --git a/nativelink-scheduler/src/default_scheduler_factory.rs b/nativelink-scheduler/src/default_scheduler_factory.rs index 963a79f5f..acb182938 100644 --- a/nativelink-scheduler/src/default_scheduler_factory.rs +++ b/nativelink-scheduler/src/default_scheduler_factory.rs @@ -17,7 +17,8 @@ use std::time::SystemTime; use nativelink_config::schedulers::{ExperimentalSimpleSchedulerBackend, SchedulerConfig}; use nativelink_config::stores::EvictionPolicy; -use nativelink_error::{Error, ResultExt}; +use nativelink_error::{make_input_err, Error, ResultExt}; +use nativelink_store::redis_store::RedisStore; use nativelink_store::store_manager::StoreManager; use nativelink_util::instant_wrapper::InstantWrapper; use nativelink_util::operation_state_manager::ClientStateManager; @@ -28,6 +29,7 @@ use crate::grpc_scheduler::GrpcScheduler; use crate::memory_awaited_action_db::MemoryAwaitedActionDb; use crate::property_modifier_scheduler::PropertyModifierScheduler; use crate::simple_scheduler::SimpleScheduler; +use crate::store_awaited_action_db::StoreAwaitedActionDb; use crate::worker_scheduler::WorkerScheduler; /// Default timeout for recently completed actions in seconds. @@ -51,7 +53,9 @@ fn inner_scheduler_factory( store_manager: &StoreManager, ) -> Result { let scheduler: SchedulerFactoryResults = match scheduler_type_cfg { - SchedulerConfig::simple(config) => simple_scheduler_factory(config)?, + SchedulerConfig::simple(config) => { + simple_scheduler_factory(config, store_manager, SystemTime::now)? + } SchedulerConfig::grpc(config) => (Some(Arc::new(GrpcScheduler::new(config)?)), None), SchedulerConfig::cache_lookup(config) => { let ac_store = store_manager @@ -83,6 +87,8 @@ fn inner_scheduler_factory( fn simple_scheduler_factory( config: &nativelink_config::schedulers::SimpleScheduler, + store_manager: &StoreManager, + now_fn: fn() -> SystemTime, ) -> Result { match config .experimental_backend @@ -100,6 +106,36 @@ fn simple_scheduler_factory( SimpleScheduler::new(config, awaited_action_db, task_change_notify); Ok((Some(action_scheduler), Some(worker_scheduler))) } + ExperimentalSimpleSchedulerBackend::redis(redis_config) => { + let store = store_manager + .get_store(redis_config.redis_store.as_ref()) + .err_tip(|| { + format!( + "'redis_store': '{}' does not exist", + redis_config.redis_store + ) + })?; + let task_change_notify = Arc::new(Notify::new()); + let store = store + .into_inner() + .as_any_arc() + .downcast::() + .map_err(|_| { + make_input_err!( + "Could not downcast to redis store in RedisAwaitedActionDb::new" + ) + })?; + let awaited_action_db = StoreAwaitedActionDb::new( + store, + task_change_notify.clone(), + now_fn, + Default::default, + ) + .err_tip(|| "In state_manager_factory::redis_state_manager")?; + let (action_scheduler, worker_scheduler) = + SimpleScheduler::new(config, awaited_action_db, task_change_notify); + Ok((Some(action_scheduler), Some(worker_scheduler))) + } } } diff --git a/nativelink-scheduler/src/memory_awaited_action_db.rs b/nativelink-scheduler/src/memory_awaited_action_db.rs index b06015b47..514460aea 100644 --- a/nativelink-scheduler/src/memory_awaited_action_db.rs +++ b/nativelink-scheduler/src/memory_awaited_action_db.rs @@ -697,7 +697,8 @@ impl I + Clone + Send + Sync> AwaitedActionDbI ActionUniqueQualifier::Uncachable(_unique_key) => None, }; let operation_id = OperationId::default(); - let awaited_action = AwaitedAction::new(operation_id.clone(), action_info); + let awaited_action = + AwaitedAction::new(operation_id.clone(), action_info, (self.now_fn)().now()); debug_assert!( ActionStage::Queued == awaited_action.state().stage, "Expected action to be queued" diff --git a/nativelink-scheduler/src/simple_scheduler.rs b/nativelink-scheduler/src/simple_scheduler.rs index b793590c3..c344d2c29 100644 --- a/nativelink-scheduler/src/simple_scheduler.rs +++ b/nativelink-scheduler/src/simple_scheduler.rs @@ -332,7 +332,6 @@ impl SimpleScheduler { let worker_change_notify = Arc::new(Notify::new()); let state_manager = SimpleSchedulerStateManager::new( max_job_retries, - // TODO(allada) This should probably have its own config. Duration::from_secs(worker_timeout_s), awaited_action_db, now_fn, diff --git a/nativelink-scheduler/src/store_awaited_action_db.rs b/nativelink-scheduler/src/store_awaited_action_db.rs index 4ed1cb72c..f19eeddf5 100644 --- a/nativelink-scheduler/src/store_awaited_action_db.rs +++ b/nativelink-scheduler/src/store_awaited_action_db.rs @@ -296,17 +296,19 @@ impl SchedulerStoreDataProvider for UpdateClientIdToOperationId { } #[derive(MetricsComponent)] -pub struct StoreAwaitedActionDb { +pub struct StoreAwaitedActionDb OperationId> { store: Arc, now_fn: fn() -> SystemTime, + operation_id_creator: F, _pull_task_change_subscriber_spawn: JoinHandleDropGuard<()>, } -impl StoreAwaitedActionDb { +impl OperationId> StoreAwaitedActionDb { pub fn new( store: Arc, task_change_publisher: Arc, now_fn: fn() -> SystemTime, + operation_id_creator: F, ) -> Result { let mut subscription = store .subscription_manager() @@ -340,6 +342,7 @@ impl StoreAwaitedActionDb { Ok(Self { store, now_fn, + operation_id_creator, _pull_task_change_subscriber_spawn: pull_task_change_subscriber, }) } @@ -409,7 +412,9 @@ impl StoreAwaitedActionDb { } } -impl AwaitedActionDb for StoreAwaitedActionDb { +impl OperationId + Send + Sync + Unpin + 'static> AwaitedActionDb + for StoreAwaitedActionDb +{ type Subscriber = OperationSubscriber; async fn get_awaited_action_by_id( @@ -466,8 +471,9 @@ impl AwaitedActionDb for StoreAwaitedActionDb { return Ok(sub); } - let new_operation_id = OperationId::default(); - let awaited_action = AwaitedAction::new(new_operation_id.clone(), action_info); + let new_operation_id = (self.operation_id_creator)(); + let awaited_action = + AwaitedAction::new(new_operation_id.clone(), action_info, (self.now_fn)()); debug_assert!( ActionStage::Queued == awaited_action.state().stage, "Expected action to be queued" diff --git a/nativelink-scheduler/tests/redis_store_awaited_action_db_test.rs b/nativelink-scheduler/tests/redis_store_awaited_action_db_test.rs new file mode 100644 index 000000000..9c5ff530d --- /dev/null +++ b/nativelink-scheduler/tests/redis_store_awaited_action_db_test.rs @@ -0,0 +1,449 @@ +// Copyright 2024 The NativeLink Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::{HashMap, VecDeque}; +use std::fmt; +use std::sync::Arc; +use std::thread::panicking; +use std::time::{Duration, SystemTime}; + +use bytes::Bytes; +use fred::bytes_utils::string::Str; +use fred::error::{RedisError, RedisErrorKind}; +use fred::mocks::{MockCommand, Mocks}; +use fred::prelude::Builder; +use fred::types::{RedisConfig, RedisValue}; +use mock_instant::SystemTime as MockSystemTime; +use nativelink_error::Error; +use nativelink_macro::nativelink_test; +use nativelink_scheduler::awaited_action_db::{ + AwaitedAction, AwaitedActionDb, AwaitedActionSubscriber, +}; +use nativelink_scheduler::store_awaited_action_db::StoreAwaitedActionDb; +use nativelink_store::redis_store::{RedisStore, RedisSubscriptionManager}; +use nativelink_util::action_messages::{ + ActionInfo, ActionStage, ActionUniqueKey, ActionUniqueQualifier, +}; +use nativelink_util::common::DigestInfo; +use nativelink_util::digest_hasher::DigestHasherFunc; +use nativelink_util::store_trait::{SchedulerStore, SchedulerSubscriptionManager}; +use parking_lot::Mutex; +use pretty_assertions::assert_eq; +use tokio::sync::Notify; + +const INSTANCE_NAME: &str = "instance_name"; +const TEMP_UUID: &str = "550e8400-e29b-41d4-a716-446655440000"; +const SCRIPT_VERSION: &str = "3e762c15"; +const VERSION_SCRIPT_HASH: &str = "fdf1152fd21705c8763752809b86b55c5d4511ce"; + +fn mock_uuid_generator() -> String { + uuid::Uuid::parse_str(TEMP_UUID).unwrap().to_string() +} + +type CommandandCallbackTuple = (MockCommand, Option>); +#[derive(Default)] +struct MockRedisBackend { + /// Commands we expect to encounter, and results we to return to the client. + // Commands are pushed from the back and popped from the front. + expected: Mutex)>>, +} + +impl fmt::Debug for MockRedisBackend { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("MockRedisBackend").finish() + } +} + +impl MockRedisBackend { + fn new() -> Self { + Self::default() + } + + fn expect( + &self, + command: MockCommand, + result: Result, + cb: Option>, + ) -> &Self { + self.expected.lock().push_back(((command, cb), result)); + self + } +} + +impl Mocks for MockRedisBackend { + fn process_command(&self, actual: MockCommand) -> Result { + let Some(((expected, maybe_cb), result)) = self.expected.lock().pop_front() else { + // panic here -- this isn't a redis error, it's a test failure + panic!("Didn't expect any more commands, but received {actual:?}"); + }; + + assert_eq!(actual, expected); + if let Some(cb) = maybe_cb { + (cb)(); + } + + result + } + + fn process_transaction(&self, commands: Vec) -> Result { + static MULTI: MockCommand = MockCommand { + cmd: Str::from_static("MULTI"), + subcommand: None, + args: Vec::new(), + }; + static EXEC: MockCommand = MockCommand { + cmd: Str::from_static("EXEC"), + subcommand: None, + args: Vec::new(), + }; + + let results = std::iter::once(MULTI.clone()) + .chain(commands) + .chain([EXEC.clone()]) + .map(|command| self.process_command(command)) + .collect::, RedisError>>()?; + + Ok(RedisValue::Array(results)) + } +} + +impl Drop for MockRedisBackend { + fn drop(&mut self) { + if panicking() { + // We're already panicking, let's make debugging easier and let future devs solve problems one at a time. + return; + } + + let expected = self.expected.get_mut(); + + if expected.is_empty() { + return; + } + + assert_eq!( + expected + .drain(..) + .map(|((cmd, _), res)| (cmd, res)) + .collect::>(), + VecDeque::new(), + "Didn't receive all expected commands." + ); + + // Panicking isn't enough inside a tokio task, we need to `exit(1)` + std::process::exit(1) + } +} + +#[nativelink_test] +async fn add_action_smoke_test() -> Result<(), Error> { + const CLIENT_OPERATION_ID: &str = "my_client_operation_id"; + const WORKER_OPERATION_ID: &str = "my_worker_operation_id"; + + let worker_awaited_action = AwaitedAction::new( + WORKER_OPERATION_ID.into(), + Arc::new(ActionInfo { + command_digest: DigestInfo::zero_digest(), + input_root_digest: DigestInfo::zero_digest(), + timeout: Duration::from_secs(1), + platform_properties: HashMap::new(), + priority: 0, + load_timestamp: SystemTime::UNIX_EPOCH, + insert_timestamp: SystemTime::UNIX_EPOCH, + unique_qualifier: ActionUniqueQualifier::Cachable(ActionUniqueKey { + instance_name: INSTANCE_NAME.to_string(), + digest_function: DigestHasherFunc::Sha256, + digest: DigestInfo::zero_digest(), + }), + }), + MockSystemTime::now().into(), + ); + let new_awaited_action = { + let mut new_awaited_action = worker_awaited_action.clone(); + let mut new_state = new_awaited_action.state().as_ref().clone(); + new_state.stage = ActionStage::Executing; + new_awaited_action.set_state(Arc::new(new_state), Some(MockSystemTime::now().into())); + new_awaited_action + }; + + const SUB_CHANNEL: &str = "sub_channel"; + let ft_aggregate_args = vec![ + format!("aa__unique_qualifier_{SCRIPT_VERSION}").into(), + format!("@unique_qualifier:{{ {INSTANCE_NAME}_SHA256_0000000000000000000000000000000000000000000000000000000000000000_0_c* }}").into(), + "LOAD".into(), + 2.into(), + "data".into(), + "version".into(), + "SORTBY".into(), + 2.into(), + "@unique_qualifier".into(), + "ASC".into(), + "WITHCURSOR".into(), + "COUNT".into(), + 256.into(), + "MAXIDLE".into(), + 2000.into(), + ]; + static SUBSCRIPTION_MANAGER: Mutex>> = Mutex::new(None); + let mocks = Arc::new(MockRedisBackend::new()); + mocks + .expect( + MockCommand { + cmd: Str::from_static("FT.AGGREGATE"), + subcommand: None, + args: ft_aggregate_args.clone(), + }, + Err(RedisError::new( + RedisErrorKind::NotFound, + String::new(), + )), + None, + ) + .expect( + MockCommand { + cmd: Str::from_static("SUBSCRIBE"), + subcommand: None, + args: vec![SUB_CHANNEL.as_bytes().into()], + }, + Ok(RedisValue::Integer(0)), + None, + ) + .expect( + MockCommand { + cmd: Str::from_static("FT.CREATE"), + subcommand: None, + args: vec![ + format!("aa__unique_qualifier_{SCRIPT_VERSION}").into(), + "ON".into(), + "HASH".into(), + "PREFIX".into(), + 1.into(), + "aa_".into(), + "TEMPORARY".into(), + 86400.into(), + "NOOFFSETS".into(), + "NOHL".into(), + "NOFIELDS".into(), + "NOFREQS".into(), + "SCHEMA".into(), + "unique_qualifier".into(), + "TAG".into(), + "SORTABLE".into(), + ], + }, + Ok(RedisValue::Bytes(Bytes::from("data"))), + None, + ) + .expect( + MockCommand { + cmd: Str::from_static("FT.AGGREGATE"), + subcommand: None, + args: ft_aggregate_args.clone(), + }, + Ok(RedisValue::Array(vec![ + RedisValue::Array(vec![ + RedisValue::Integer(0), + ]), + RedisValue::Integer(0), // Means no more items in cursor. + ])), + None, + ) + .expect( + MockCommand { + cmd: Str::from_static("EVALSHA"), + subcommand: None, + args: vec![ + VERSION_SCRIPT_HASH.into(), + 1.into(), + format!("aa_{WORKER_OPERATION_ID}").as_bytes().into(), + "0".as_bytes().into(), + RedisValue::Bytes(Bytes::from(serde_json::to_string(&worker_awaited_action).unwrap())), + "unique_qualifier".as_bytes().into(), + format!("{INSTANCE_NAME}_SHA256_0000000000000000000000000000000000000000000000000000000000000000_0_c").as_bytes().into(), + "sort_key".as_bytes().into(), + "q_9223372041149743103".as_bytes().into(), + ], + }, + Ok(1.into() /* New version */), + None, + ) + .expect( + MockCommand { + cmd: Str::from_static("PUBLISH"), + subcommand: None, + args: vec![ + SUB_CHANNEL.into(), + format!("aa_{WORKER_OPERATION_ID}").into(), + ], + }, + Ok(0.into() /* unused */), + Some(Box::new(|| SUBSCRIPTION_MANAGER.lock().as_ref().unwrap().notify_for_test(format!("aa_{WORKER_OPERATION_ID}")))), + ) + .expect( + MockCommand { + cmd: Str::from_static("HSET"), + subcommand: None, + args: vec![ + format!("cid_{CLIENT_OPERATION_ID}").as_bytes().into(), + "data".as_bytes().into(), + format!("{{\"String\":\"{WORKER_OPERATION_ID}\"}}").as_bytes().into(), + ], + }, + Ok(RedisValue::new_ok()), + None, + ) + .expect( + MockCommand { + cmd: Str::from_static("PUBLISH"), + subcommand: None, + args: vec![ + SUB_CHANNEL.into(), + format!("cid_{CLIENT_OPERATION_ID}").into(), + ], + }, + Ok(0.into() /* unused */), + Some(Box::new(|| SUBSCRIPTION_MANAGER.lock().as_ref().unwrap().notify_for_test(format!("aa_{CLIENT_OPERATION_ID}")))), + ) + .expect( + MockCommand { + cmd: Str::from_static("HMGET"), + subcommand: None, + args: vec![ + format!("aa_{WORKER_OPERATION_ID}").as_bytes().into(), + "version".as_bytes().into(), + "data".as_bytes().into(), + ], + }, + Ok(RedisValue::Array(vec![ + // Version. + "1".into(), + // Data. + RedisValue::Bytes(Bytes::from(serde_json::to_string(&worker_awaited_action).unwrap())), + ])), + None, + ) + .expect( + MockCommand { + cmd: Str::from_static("EVALSHA"), + subcommand: None, + args: vec![ + VERSION_SCRIPT_HASH.into(), + 1.into(), + format!("aa_{WORKER_OPERATION_ID}").as_bytes().into(), + "0".as_bytes().into(), + RedisValue::Bytes(Bytes::from(serde_json::to_string(&new_awaited_action).unwrap())), + "unique_qualifier".as_bytes().into(), + format!("{INSTANCE_NAME}_SHA256_0000000000000000000000000000000000000000000000000000000000000000_0_c").as_bytes().into(), + "sort_key".as_bytes().into(), + "e_9223372041149743103".as_bytes().into(), + ], + }, + Ok(2.into() /* New version */), + None, + ) + .expect( + MockCommand { + cmd: Str::from_static("PUBLISH"), + subcommand: None, + args: vec![ + SUB_CHANNEL.into(), + format!("aa_{WORKER_OPERATION_ID}").into(), + ], + }, + Ok(0.into() /* unused */), + Some(Box::new(|| SUBSCRIPTION_MANAGER.lock().as_ref().unwrap().notify_for_test(format!("aa_{WORKER_OPERATION_ID}")))), + ) + .expect( + MockCommand { + cmd: Str::from_static("HMGET"), + subcommand: None, + args: vec![ + format!("aa_{WORKER_OPERATION_ID}").as_bytes().into(), + "version".as_bytes().into(), + "data".as_bytes().into(), + ], + }, + Ok(RedisValue::Array(vec![ + // Version. + "2".into(), + // Data. + RedisValue::Bytes(Bytes::from(serde_json::to_string(&new_awaited_action).unwrap())), + ])), + None, + ) + ; + + let store = { + let mut builder = Builder::default_centralized(); + builder.set_config(RedisConfig { + mocks: Some(Arc::clone(&mocks) as Arc), + ..Default::default() + }); + + Arc::new( + RedisStore::new_from_builder_and_parts( + builder, + Some(SUB_CHANNEL.into()), + mock_uuid_generator, + String::new(), + ) + .unwrap(), + ) + }; + SUBSCRIPTION_MANAGER + .lock() + .replace(store.subscription_manager().unwrap()); + + let notifier = Arc::new(Notify::new()); + let awaited_action_db = StoreAwaitedActionDb::new( + store.clone(), + notifier.clone(), + || MockSystemTime::now().into(), + move || WORKER_OPERATION_ID.into(), + ) + .unwrap(); + + let mut subscription = awaited_action_db + .add_action( + CLIENT_OPERATION_ID.into(), + worker_awaited_action.action_info().clone(), + ) + .await + .unwrap(); + + { + // Check initial change state. + let changed_awaited_action_res = subscription.changed().await; + + assert_eq!( + changed_awaited_action_res.unwrap().state().stage, + ActionStage::Queued + ); + } + + { + // Update the action and check the new state. + let (changed_awaited_action_res, update_res) = tokio::join!( + subscription.changed(), + awaited_action_db.update_awaited_action(new_awaited_action.clone()) + ); + assert_eq!(update_res, Ok(())); + + assert_eq!( + changed_awaited_action_res.unwrap().state().stage, + ActionStage::Executing + ); + } + + Ok(()) +} diff --git a/nativelink-scheduler/tests/simple_scheduler_test.rs b/nativelink-scheduler/tests/simple_scheduler_test.rs index f9cc56505..61da818df 100644 --- a/nativelink-scheduler/tests/simple_scheduler_test.rs +++ b/nativelink-scheduler/tests/simple_scheduler_test.rs @@ -23,7 +23,7 @@ use std::time::{Duration, SystemTime, UNIX_EPOCH}; use async_lock::Mutex; use futures::task::Poll; use futures::{poll, Stream, StreamExt}; -use mock_instant::MockClock; +use mock_instant::{MockClock, SystemTime as MockSystemTime}; use nativelink_config::schedulers::PropertyType; use nativelink_error::{make_err, Code, Error, ResultExt}; use nativelink_macro::nativelink_test; @@ -873,6 +873,7 @@ impl AwaitedActionSubscriber for MockAwaitedActionSubscriber { Ok(AwaitedAction::new( OperationId::default(), make_base_action_info(SystemTime::UNIX_EPOCH, DigestInfo::zero_digest()), + MockSystemTime::now().into(), )) } } diff --git a/nativelink-store/BUILD.bazel b/nativelink-store/BUILD.bazel index 7000db9cc..7e5fde3c7 100644 --- a/nativelink-store/BUILD.bazel +++ b/nativelink-store/BUILD.bazel @@ -23,6 +23,8 @@ rust_library( "src/memory_store.rs", "src/noop_store.rs", "src/redis_store.rs", + "src/redis_utils/ft_aggregate.rs", + "src/redis_utils/mod.rs", "src/ref_store.rs", "src/s3_store.rs", "src/shard_store.rs", @@ -48,6 +50,8 @@ rust_library( "@crates//:blake3", "@crates//:byteorder", "@crates//:bytes", + "@crates//:bytes-utils", + "@crates//:const_format", "@crates//:filetime", "@crates//:fred", "@crates//:futures", @@ -57,6 +61,7 @@ rust_library( "@crates//:hyper-rustls", "@crates//:lz4_flex", "@crates//:parking_lot", + "@crates//:patricia_tree", "@crates//:prost", "@crates//:rand", "@crates//:serde", diff --git a/nativelink-store/Cargo.toml b/nativelink-store/Cargo.toml index 81cbfb900..47d567b9c 100644 --- a/nativelink-store/Cargo.toml +++ b/nativelink-store/Cargo.toml @@ -22,8 +22,14 @@ bincode = "1.3.3" blake3 = { version = "1.5.2", default-features = false } byteorder = { version = "1.5.0", default-features = false } bytes = { version = "1.6.1", default-features = false } +bytes-utils = { version = "0.1.4", default-features = false } +const_format = { version = "0.2.32", default-features = false } filetime = "0.2.23" -fred = { version = "9.0.3", features = [ +fred = { version = "9.1.2", default-features = false, features = [ + "i-std", + "i-scripts", + "i-redisearch", + "sha-1", "enable-rustls-ring", "metrics", "blocking-encoding", @@ -31,8 +37,8 @@ fred = { version = "9.0.3", features = [ "sentinel-client", "sentinel-auth", "subscriber-client", - "mocks", ] } +patricia_tree = { version = "0.8.0", default-features = false } futures = { version = "0.3.30", default-features = false } hex = { version = "0.4.3", default-features = false } http-body = "1.0.1" @@ -71,3 +77,4 @@ aws-smithy-runtime-api = "=1.7.1" serial_test = { version = "3.1.1", features = [ "async", ], default-features = false } +fred = { version = "9.1.2", default-features = false, features = ["mocks"] } diff --git a/nativelink-store/src/lib.rs b/nativelink-store/src/lib.rs index 03e6ba3fb..04040fa5b 100644 --- a/nativelink-store/src/lib.rs +++ b/nativelink-store/src/lib.rs @@ -25,6 +25,7 @@ pub mod grpc_store; pub mod memory_store; pub mod noop_store; pub mod redis_store; +mod redis_utils; pub mod ref_store; pub mod s3_store; pub mod shard_store; diff --git a/nativelink-store/src/redis_store.rs b/nativelink-store/src/redis_store.rs index 9aa531096..bda93bf71 100644 --- a/nativelink-store/src/redis_store.rs +++ b/nativelink-store/src/redis_store.rs @@ -15,23 +15,41 @@ use std::borrow::Cow; use std::cmp; use std::pin::Pin; -use std::sync::Arc; +use std::sync::{Arc, Weak}; use std::time::Duration; use async_trait::async_trait; use bytes::Bytes; +use const_format::formatcp; use fred::clients::{RedisClient, RedisPool, SubscriberClient}; use fred::interfaces::{ClientLike, KeysInterface, PubsubInterface}; -use fred::types::{Builder, ConnectionConfig, PerformanceConfig, ReconnectPolicy, RedisConfig}; +use fred::prelude::{EventInterface, HashesInterface, RediSearchInterface}; +use fred::types::{ + Builder, ConnectionConfig, FtCreateOptions, PerformanceConfig, ReconnectPolicy, RedisConfig, + RedisKey, RedisMap, RedisValue, Script, SearchSchema, SearchSchemaKind, +}; +use futures::{FutureExt, Stream, StreamExt}; use nativelink_config::stores::RedisMode; -use nativelink_error::{make_err, Code, Error, ResultExt}; +use nativelink_error::{make_err, make_input_err, Code, Error, ResultExt}; use nativelink_metric::MetricsComponent; use nativelink_util::buf_channel::{DropCloserReadHalf, DropCloserWriteHalf}; use nativelink_util::health_utils::{HealthRegistryBuilder, HealthStatus, HealthStatusIndicator}; -use nativelink_util::store_trait::{StoreDriver, StoreKey, UploadSizeInfo}; +use nativelink_util::spawn; +use nativelink_util::store_trait::{ + BoolValue, SchedulerCurrentVersionProvider, SchedulerIndexProvider, SchedulerStore, + SchedulerStoreDataProvider, SchedulerStoreDecodeTo, SchedulerStoreKeyProvider, + SchedulerSubscription, SchedulerSubscriptionManager, StoreDriver, StoreKey, UploadSizeInfo, +}; +use nativelink_util::task::JoinHandleDropGuard; +use parking_lot::{Mutex, RwLock}; +use patricia_tree::StringPatriciaMap; +use tokio::select; +use tokio::time::sleep; +use tracing::{event, Level}; use uuid::Uuid; use crate::cas_utils::is_zero_digest; +use crate::redis_utils::ft_aggregate; // TODO(caass): These (and other settings) should be made configurable via nativelink-config. pub const READ_CHUNK_SIZE: usize = 64 * 1024; @@ -58,6 +76,14 @@ pub struct RedisStore { /// See [`RedisStore::key_prefix`](`nativelink_config::stores::RedisStore::key_prefix`). #[metric(help = "Prefix to append to all keys before sending to Redis")] key_prefix: String, + + /// Redis script used to update a value in redis if the version matches. + /// This is done by incrementing the version number and then setting the new data + /// only if the version number matches the existing version number. + update_if_version_matches_script: Script, + + /// A manager for subscriptions to keys in Redis. + subscription_manager: Mutex>>, } impl RedisStore { @@ -135,6 +161,8 @@ impl RedisStore { subscriber_client, temp_name_generator_fn, key_prefix, + update_if_version_matches_script: Script::from_lua(LUA_VERSION_SET_SCRIPT), + subscription_manager: Mutex::new(None), }) } @@ -161,13 +189,6 @@ impl RedisStore { } } - // TODO: These helpers eventually should not be necessary, as they are only used for functionality - // that could hypothetically be moved behind this API with some non-trivial logic adjustments - // and the addition of one or two new endpoints. - pub fn get_subscriber_client(&self) -> SubscriberClient { - self.subscriber_client.clone() - } - pub fn get_client(&self) -> RedisClient { self.client_pool.next().clone() } @@ -442,3 +463,591 @@ impl HealthStatusIndicator for RedisStore { StoreDriver::check_health(Pin::new(self), namespace).await } } + +/// ------------------------------------------------------------------- +/// Below this line are specific to the redis scheduler implementation. +/// ------------------------------------------------------------------- + +/// The maximum number of results to return per cursor. +const MAX_COUNT_PER_CURSOR: u64 = 256; +/// The time in milliseconds that a redis cursor can be idle before it is closed. +const CURSOR_IDLE_MS: u64 = 2_000; +/// The name of the field in the Redis hash that stores the data. +const DATA_FIELD_NAME: &str = "data"; +/// The name of the field in the Redis hash that stores the version. +const VERSION_FIELD_NAME: &str = "version"; +/// The time to live of indexes in seconds. After this time redis may delete the index. +const INDEX_TTL_S: u64 = 60 * 60 * 24; // 24 hours. + +/// String of the `FT.CREATE` command used to create the index template. It is done this +/// way so we can use it in both const (compile time) functions and runtime functions. +/// This is a macro because we need to use it in multiple places that sometimes require the +/// data as different data types (specifically for rust's format_args! macro). +macro_rules! get_create_index_template { + () => { + "FT.CREATE {} ON HASH PREFIX 1 {} NOOFFSETS NOHL NOFIELDS NOFREQS SCHEMA {} TAG CASESENSITIVE SORTABLE" + } +} + +/// Lua script to set a key if the version matches. +/// Args: +/// KEYS[1]: The key where the version is stored. +/// ARGV[1]: The expected version. +/// ARGV[2]: The new data. +/// ARGV[3*]: Key-value pairs of additional data to include. +/// Returns: +/// The new version if the version matches. nil is returned if the +/// value was not set. +const LUA_VERSION_SET_SCRIPT: &str = formatcp!( + r#" +local key = KEYS[1] +local expected_version = tonumber(ARGV[1]) +local new_data = ARGV[2] +local new_version = redis.call('HINCRBY', key, '{VERSION_FIELD_NAME}', 1) +local i +local indexes = {{}} + +if new_version-1 ~= expected_version then + redis.call('HINCRBY', key, '{VERSION_FIELD_NAME}', -1) + return 0 +end +-- Skip first 2 argvs, as they are known inputs. +-- Remember: Lua is 1-indexed. +for i=3, #ARGV do + indexes[i-2] = ARGV[i] +end + +-- In testing we witnessed redis sometimes not update our FT indexes +-- resulting in stale data. It appears if we delete our keys then insert +-- them again it works and reduces risk significantly. +redis.call('DEL', key) +redis.call('HSET', key, '{DATA_FIELD_NAME}', new_data, '{VERSION_FIELD_NAME}', new_version, unpack(indexes)) + +return new_version +"# +); + +/// Compile-time fingerprint of the `FT.CREATE` command used to create the index template. +/// This is a simple CRC32 checksum of the command string. We don't care about it actually +/// being a valid CRC32 checksum, just that it's a unique identifier with a low chance of +/// collision. +const fn fingerprint_create_index_template() -> u32 { + const POLY: u32 = 0xEDB88320; + const DATA: &[u8] = get_create_index_template!().as_bytes(); + let mut crc = 0xFFFFFFFF; + let mut i = 0; + while i < DATA.len() { + let byte = DATA[i]; + crc ^= byte as u32; + + let mut j = 0; + while j < 8 { + crc = if crc & 1 != 0 { + (crc >> 1) ^ POLY + } else { + crc >> 1 + }; + j += 1; + } + i += 1; + } + crc +} + +/// Get the name of the index to create for the given field. +/// This will add some prefix data to the name to try and ensure +/// if the index definition changes, the name will get a new name. +macro_rules! get_index_name { + ($prefix:expr, $field:expr) => { + format_args!( + "{}_{}_{:08x}", + $prefix, + $field, + fingerprint_create_index_template(), + ) + }; +} + +/// Try to sanitize a string to be used as a Redis key. +/// We don't actually modify the string, just check if it's valid. +const fn try_sanitize(s: &str) -> Option<&str> { + // Note: We cannot use for loops or iterators here because they are not const. + // Allowing us to use a const function here gives the compiler the ability to + // optimize this function away entirely in the case where the input is constant. + let chars = s.as_bytes(); + let mut i: usize = 0; + let len = s.len(); + loop { + if i >= len { + break; + } + let c = chars[i]; + if !c.is_ascii_alphanumeric() && c != b'_' { + return None; + } + i += 1; + } + Some(s) +} + +/// An individual subscription to a key in Redis. +pub struct RedisSubscription { + receiver: Option>, + weak_subscribed_keys: Weak>>, +} + +impl SchedulerSubscription for RedisSubscription { + /// Wait for the subscription key to change. + async fn changed(&mut self) -> Result<(), Error> { + let receiver = self + .receiver + .as_mut() + .ok_or_else(|| make_err!(Code::Internal, "In RedisSubscription::changed::as_mut"))?; + receiver + .changed() + .await + .map_err(|_| make_err!(Code::Internal, "In RedisSubscription::changed::changed")) + } +} + +// If the subscription is dropped, we need to possibly remove the key from the +// subscribed keys map. +impl Drop for RedisSubscription { + fn drop(&mut self) { + let Some(receiver) = self.receiver.take() else { + event!( + Level::WARN, + "RedisSubscription has already been dropped, nothing to do." + ); + return; // Already dropped, nothing to do. + }; + let key = receiver.borrow().clone(); + // IMPORTANT: This must be dropped before receiver_count() is called. + drop(receiver); + let Some(subscribed_keys) = self.weak_subscribed_keys.upgrade() else { + return; // Already dropped, nothing to do. + }; + let mut subscribed_keys = subscribed_keys.write(); + let Some(value) = subscribed_keys.get(&key) else { + event!( + Level::ERROR, + "Key {key} was not found in subscribed keys when checking if it should be removed." + ); + return; + }; + // If we have no receivers, cleanup the entry from our map. + if value.receiver_count() == 0 { + subscribed_keys.remove(key); + } + } +} + +/// A publisher for a key in Redis. +struct RedisSubscriptionPublisher { + sender: Mutex>, +} + +impl RedisSubscriptionPublisher { + fn new( + key: String, + weak_subscribed_keys: Weak>>, + ) -> (Self, RedisSubscription) { + let (sender, receiver) = tokio::sync::watch::channel(key); + let publisher = Self { + sender: Mutex::new(sender), + }; + let subscription = RedisSubscription { + receiver: Some(receiver), + weak_subscribed_keys, + }; + (publisher, subscription) + } + + fn subscribe( + &self, + weak_subscribed_keys: Weak>>, + ) -> RedisSubscription { + let receiver = self.sender.lock().subscribe(); + RedisSubscription { + receiver: Some(receiver), + weak_subscribed_keys, + } + } + + fn receiver_count(&self) -> usize { + self.sender.lock().receiver_count() + } + + fn notify(&self) { + // TODO(https://github.com/sile/patricia_tree/issues/40) When this is addressed + // we can remove the `Mutex` and use the mutable iterator directly. + self.sender.lock().send_modify(|_| {}); + } +} + +pub struct RedisSubscriptionManager { + subscribed_keys: Arc>>, + tx_for_test: tokio::sync::mpsc::UnboundedSender, + _subscription_spawn: JoinHandleDropGuard<()>, +} + +impl RedisSubscriptionManager { + pub fn new(subscribe_client: SubscriberClient, pub_sub_channel: String) -> Self { + let subscribed_keys = Arc::new(RwLock::new(StringPatriciaMap::new())); + let subscribed_keys_weak = Arc::downgrade(&subscribed_keys); + let (tx_for_test, mut rx_for_test) = tokio::sync::mpsc::unbounded_channel(); + Self { + subscribed_keys, + tx_for_test, + _subscription_spawn: spawn!("redis_subscribe_spawn", async move { + let mut rx = subscribe_client.message_rx(); + loop { + if let Err(e) = subscribe_client.subscribe(&pub_sub_channel).await { + event!(Level::ERROR, "Error subscribing to pattern - {e}"); + return; + } + let mut reconnect_rx = subscribe_client.reconnect_rx(); + let reconnect_fut = reconnect_rx.recv().fuse(); + tokio::pin!(reconnect_fut); + loop { + let key = select! { + value = rx_for_test.recv() => { + let Some(value) = value else { + unreachable!("Channel should never close"); + }; + value.into() + }, + msg = rx.recv() => { + match msg { + Ok(msg) => { + match msg.value { + RedisValue::String(s) => s, + _ => { + event!(Level::ERROR, "Received non-string message in RedisSubscriptionManager"); + continue; + } + } + }, + Err(e) => { + // Check to see if our parent has been dropped and if so kill spawn. + if subscribed_keys_weak.upgrade().is_none() { + event!(Level::WARN, "It appears our parent has been dropped, exiting RedisSubscriptionManager spawn"); + return; + }; + event!(Level::ERROR, "Error receiving message in RedisSubscriptionManager reconnecting and flagging everything changed - {e}"); + break; + } + } + }, + _ = &mut reconnect_fut => { + event!(Level::WARN, "Redis reconnected flagging all subscriptions as changed and resuming"); + break; + } + }; + let Some(subscribed_keys) = subscribed_keys_weak.upgrade() else { + event!(Level::WARN, "It appears our parent has been dropped, exiting RedisSubscriptionManager spawn"); + return; + }; + let subscribed_keys_mux = subscribed_keys.read(); + subscribed_keys_mux + .common_prefix_values(&*key) + .for_each(|publisher| publisher.notify()); + } + // Sleep for a small amount of time to ensure we don't reconnect too quickly. + sleep(Duration::from_secs(1)).await; + // If we reconnect or lag behind we might have had dirty keys, so we need to + // flag all of them as changed. + let Some(subscribed_keys) = subscribed_keys_weak.upgrade() else { + event!(Level::WARN, "It appears our parent has been dropped, exiting RedisSubscriptionManager spawn"); + return; + }; + let subscribed_keys_mux = subscribed_keys.read(); + // Just in case also get a new receiver. + rx = subscribe_client.message_rx(); + // Drop all buffered messages, then flag everything as changed. + rx.resubscribe(); + for publisher in subscribed_keys_mux.values() { + publisher.notify(); + } + } + }), + } + } +} + +impl SchedulerSubscriptionManager for RedisSubscriptionManager { + type Subscription = RedisSubscription; + + fn notify_for_test(&self, value: String) { + self.tx_for_test.send(value).unwrap(); + } + + fn subscribe(&self, key: K) -> Result + where + K: SchedulerStoreKeyProvider, + { + let weak_subscribed_keys = Arc::downgrade(&self.subscribed_keys); + let mut subscribed_keys = self.subscribed_keys.write(); + let key = key.get_key(); + let key_str = key.as_str(); + let mut subscription = if let Some(publisher) = subscribed_keys.get(&key_str) { + publisher.subscribe(weak_subscribed_keys) + } else { + let (publisher, subscription) = + RedisSubscriptionPublisher::new(key_str.to_string(), weak_subscribed_keys); + subscribed_keys.insert(key_str, publisher); + subscription + }; + subscription + .receiver + .as_mut() + .ok_or_else(|| { + make_err!( + Code::Internal, + "Receiver should be set in RedisSubscriptionManager::subscribe" + ) + })? + .mark_changed(); + + Ok(subscription) + } +} + +impl SchedulerStore for RedisStore { + type SubscriptionManager = RedisSubscriptionManager; + + fn subscription_manager(&self) -> Result, Error> { + let mut subscription_manager = self.subscription_manager.lock(); + match &*subscription_manager { + Some(subscription_manager) => Ok(subscription_manager.clone()), + None => { + let Some(pub_sub_channel) = &self.pub_sub_channel else { + return Err(make_input_err!("RedisStore must have a pubsub channel for a Redis Scheduler if using subscriptions")); + }; + let sub = Arc::new(RedisSubscriptionManager::new( + self.subscriber_client.clone(), + pub_sub_channel.clone(), + )); + *subscription_manager = Some(sub.clone()); + Ok(sub) + } + } + } + + async fn update_data(&self, data: T) -> Result, Error> + where + T: SchedulerStoreDataProvider + + SchedulerStoreKeyProvider + + SchedulerCurrentVersionProvider + + Send, + { + let key = data.get_key(); + let key = self.encode_key(&key); + let client = self.client_pool.next(); + let maybe_index = data.get_indexes().err_tip(|| { + format!("Err getting index in RedisStore::update_data::versioned for {key:?}") + })?; + if ::Versioned::VALUE { + let current_version = data.current_version(); + let data = data.try_into_bytes().err_tip(|| { + format!("Could not convert value to bytes in RedisStore::update_data::versioned for {key:?}") + })?; + let mut argv = Vec::with_capacity(3 + maybe_index.len() * 2); + argv.push(Bytes::from(format!("{current_version}"))); + argv.push(data); + for (name, value) in maybe_index { + argv.push(Bytes::from_static(name.as_bytes())); + argv.push(value); + } + let new_version = self + .update_if_version_matches_script + .evalsha_with_reload::>(client, vec![key.as_ref()], argv) + .await + .err_tip(|| format!("In RedisStore::update_data::versioned for {key:?}"))?; + if new_version == 0 { + return Ok(None); + } + // If we have a publish channel configured, send a notice that the key has been set. + if let Some(pub_sub_channel) = &self.pub_sub_channel { + return Ok(client.publish(pub_sub_channel, key.as_ref()).await?); + }; + Ok(Some(new_version)) + } else { + let data = data.try_into_bytes().err_tip(|| { + format!("Could not convert value to bytes in RedisStore::update_data::noversion for {key:?}") + })?; + let mut fields = RedisMap::new(); + fields.reserve(1 + maybe_index.len()); + fields.insert(DATA_FIELD_NAME.into(), data.into()); + for (name, value) in maybe_index { + fields.insert(name.into(), value.into()); + } + client + .hset::<(), _, _>(key.as_ref(), fields) + .await + .err_tip(|| format!("In RedisStore::update_data::noversion for {key:?}"))?; + // If we have a publish channel configured, send a notice that the key has been set. + if let Some(pub_sub_channel) = &self.pub_sub_channel { + return Ok(client.publish(pub_sub_channel, key.as_ref()).await?); + }; + Ok(Some(0)) // Always use "0" version since this is not a versioned request. + } + } + + async fn search_by_index_prefix( + &self, + index: K, + ) -> Result< + impl Stream::DecodeOutput, Error>> + Send, + Error, + > + where + K: SchedulerIndexProvider + SchedulerStoreDecodeTo + Send, + { + let index_value_prefix = index.index_value_prefix(); + let run_ft_aggregate = || { + let client = self.client_pool.next().clone(); + let sanitized_field = try_sanitize(index_value_prefix.as_ref()).err_tip(|| { + format!( + "In RedisStore::search_by_index_prefix::try_sanitize - {index_value_prefix:?}" + ) + })?; + Ok::<_, Error>(async move { + ft_aggregate( + client, + format!("{}", get_index_name!(K::KEY_PREFIX, K::INDEX_NAME)), + format!("@{}:{{ {}* }}", K::INDEX_NAME, sanitized_field), + fred::types::FtAggregateOptions { + load: Some(fred::types::Load::Some(vec![ + fred::types::SearchField { + identifier: DATA_FIELD_NAME.into(), + property: None, + }, + fred::types::SearchField { + identifier: VERSION_FIELD_NAME.into(), + property: None, + }, + ])), + cursor: Some(fred::types::WithCursor { + count: Some(MAX_COUNT_PER_CURSOR), + max_idle: Some(CURSOR_IDLE_MS), + }), + pipeline: vec![fred::types::AggregateOperation::SortBy { + properties: vec![( + format!("@{}", K::INDEX_NAME).into(), + fred::types::SortOrder::Asc, + )], + max: None, + }], + ..Default::default() + }, + ) + .await + }) + }; + let stream = match run_ft_aggregate()?.await { + Ok(stream) => stream, + Err(_) => { + let create_result = self + .client_pool + .next() + .ft_create::<(), _>( + format!("{}", get_index_name!(K::KEY_PREFIX, K::INDEX_NAME)), + FtCreateOptions { + on: Some(fred::types::IndexKind::Hash), + prefixes: vec![K::KEY_PREFIX.into()], + nohl: true, + nofields: true, + nofreqs: true, + nooffsets: true, + temporary: Some(INDEX_TTL_S), + ..Default::default() + }, + vec![SearchSchema { + field_name: K::INDEX_NAME.into(), + alias: None, + kind: SearchSchemaKind::Tag { + sortable: true, + unf: false, + separator: None, + casesensitive: false, + withsuffixtrie: false, + noindex: false, + }, + }], + ) + .await + .err_tip(|| { + format!( + "Error with ft_create in RedisStore::search_by_index_prefix({})", + get_index_name!(K::KEY_PREFIX, K::INDEX_NAME), + ) + }); + let run_result = run_ft_aggregate()?.await.err_tip(|| { + format!( + "Error with second ft_aggregate in RedisStore::search_by_index_prefix({})", + get_index_name!(K::KEY_PREFIX, K::INDEX_NAME), + ) + }); + // Creating the index will race which is ok. If it fails to create, we only + // error if the second ft_aggregate call fails and fails to create. + match run_result { + Ok(stream) => stream, + Err(e) => return create_result.merge(Err(e)), + } + } + }; + Ok(stream.map(|result| { + let mut redis_map = + result.err_tip(|| "Error in stream of in RedisStore::search_by_index_prefix")?; + let bytes_data = redis_map + .remove(&RedisKey::from_static_str(DATA_FIELD_NAME)) + .err_tip(|| "Missing data field in RedisStore::search_by_index_prefix")? + .into_bytes() + .err_tip(|| { + formatcp!("'{DATA_FIELD_NAME}' is not Bytes in RedisStore::search_by_index_prefix::into_bytes") + })?; + let version = if ::Versioned::VALUE { + redis_map + .remove(&RedisKey::from_static_str(VERSION_FIELD_NAME)) + .err_tip(|| "Missing version field in RedisStore::search_by_index_prefix")? + .as_u64() + .err_tip(|| { + formatcp!("'{VERSION_FIELD_NAME}' is not u64 in RedisStore::search_by_index_prefix::as_u64") + })? + } else { + 0 + }; + K::decode(version, bytes_data) + .err_tip(|| "In RedisStore::search_by_index_prefix::decode") + })) + } + + async fn get_and_decode( + &self, + key: K, + ) -> Result::DecodeOutput>, Error> + where + K: SchedulerStoreKeyProvider + SchedulerStoreDecodeTo + Send, + { + let key = key.get_key(); + let key = self.encode_key(&key); + let client = self.client_pool.next(); + let (maybe_version, maybe_data) = client + .hmget::<(Option, Option), _, _>( + key.as_ref(), + vec![ + RedisKey::from(VERSION_FIELD_NAME), + RedisKey::from(DATA_FIELD_NAME), + ], + ) + .await + .err_tip(|| format!("In RedisStore::get_without_version::notversioned {key}"))?; + let Some(data) = maybe_data else { + return Ok(None); + }; + Ok(Some(K::decode(maybe_version.unwrap_or(0), data).err_tip( + || format!("In RedisStore::get_with_version::notversioned::decode {key}"), + )?)) + } +} diff --git a/nativelink-store/src/redis_utils/ft_aggregate.rs b/nativelink-store/src/redis_utils/ft_aggregate.rs new file mode 100644 index 000000000..1b685bbd8 --- /dev/null +++ b/nativelink-store/src/redis_utils/ft_aggregate.rs @@ -0,0 +1,127 @@ +// Copyright 2024 The NativeLink Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::VecDeque; + +use fred::error::{RedisError, RedisErrorKind}; +use fred::interfaces::RediSearchInterface; +use fred::types::{FromRedis, RedisMap, RedisValue}; +use futures::Stream; + +/// Calls FT_AGGREGATE in redis. Fred does not properly support this command +/// so we have to manually handle it. +pub async fn ft_aggregate( + client: C, + index: I, + query: Q, + options: fred::types::FtAggregateOptions, +) -> Result> + Send, RedisError> +where + C: RediSearchInterface, + I: Into, + Q: Into, +{ + let index = index.into(); + let query = query.into(); + let data: RedisCursorData = client.ft_aggregate(index.clone(), query, options).await?; + + struct State { + client: C, + index: bytes_utils::string::Str, + data: RedisCursorData, + } + + let state = State { + client, + index, + data, + }; + Ok(futures::stream::unfold( + Some(state), + move |maybe_state| async move { + let mut state = maybe_state?; + loop { + match state.data.data.pop_front() { + Some(map) => { + return Some((Ok(map), Some(state))); + } + None => { + if state.data.cursor == 0 { + return None; + } + let data_res = state + .client + .ft_cursor_read(state.index.clone(), state.data.cursor, None) + .await; + state.data = match data_res { + Ok(data) => data, + Err(err) => return Some((Err(err), None)), + }; + continue; + } + } + } + }, + )) +} + +#[derive(Debug, Default)] +struct RedisCursorData { + total: u64, + cursor: u64, + data: VecDeque, +} + +impl FromRedis for RedisCursorData { + fn from_value(value: RedisValue) -> Result { + if !value.is_array() { + return Err(RedisError::new(RedisErrorKind::Protocol, "Expected array")); + } + let mut output = Self::default(); + let value = value.into_array(); + if value.len() < 2 { + return Err(RedisError::new( + RedisErrorKind::Protocol, + "Expected at least 2 elements", + )); + } + let mut value = value.into_iter(); + let data_ary = value.next().unwrap().into_array(); + if data_ary.is_empty() { + return Err(RedisError::new( + RedisErrorKind::Protocol, + "Expected at least 1 element in data array", + )); + } + let Some(total) = data_ary[0].as_u64() else { + return Err(RedisError::new( + RedisErrorKind::Protocol, + "Expected integer as first element", + )); + }; + output.total = total; + output.data.reserve(data_ary.len() - 1); + for map_data in data_ary.into_iter().skip(1) { + output.data.push_back(map_data.into_map()?); + } + let Some(cursor) = value.next().unwrap().as_u64() else { + return Err(RedisError::new( + RedisErrorKind::Protocol, + "Expected integer as last element", + )); + }; + output.cursor = cursor; + Ok(output) + } +} diff --git a/nativelink-store/src/redis_utils/mod.rs b/nativelink-store/src/redis_utils/mod.rs new file mode 100644 index 000000000..032e9cbff --- /dev/null +++ b/nativelink-store/src/redis_utils/mod.rs @@ -0,0 +1,16 @@ +// Copyright 2024 The NativeLink Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +mod ft_aggregate; +pub use ft_aggregate::ft_aggregate; diff --git a/nativelink-util/src/action_messages.rs b/nativelink-util/src/action_messages.rs index 046007c76..7689b3859 100644 --- a/nativelink-util/src/action_messages.rs +++ b/nativelink-util/src/action_messages.rs @@ -245,10 +245,11 @@ impl std::fmt::Display for ActionUniqueQualifier { f.write_fmt(format_args!( // Note: We use underscores because it makes escaping easier // for redis. - "{}/{}/{}/{}", + "{}_{}_{}_{}_{}", unique_key.instance_name, unique_key.digest_function, - unique_key.digest, + unique_key.digest.packed_hash(), + unique_key.digest.size_bytes(), if cachable { 'c' } else { 'u' }, )) } @@ -283,7 +284,7 @@ impl std::fmt::Display for ActionUniqueKey { /// to ensure we never match against another `ActionInfo` (when a task should never be cached). /// This struct must be 100% compatible with `ExecuteRequest` struct in `remote_execution.proto` /// except for the salt field. -#[derive(Clone, Debug, Serialize, Deserialize, MetricsComponent)] +#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, MetricsComponent)] pub struct ActionInfo { /// Digest of the underlying `Command`. #[metric(help = "Digest of the underlying Command.")] diff --git a/nativelink-util/src/store_trait.rs b/nativelink-util/src/store_trait.rs index 1323c6df5..9f0f3509c 100644 --- a/nativelink-util/src/store_trait.rs +++ b/nativelink-util/src/store_trait.rs @@ -849,6 +849,8 @@ pub trait SchedulerSubscription: Send + Sync { pub trait SchedulerSubscriptionManager: Send + Sync { type Subscription: SchedulerSubscription; + fn notify_for_test(&self, value: String); + fn subscribe(&self, key: K) -> Result where K: SchedulerStoreKeyProvider;