Skip to content

Commit

Permalink
Add Redis ClientStateManager Implementation
Browse files Browse the repository at this point in the history
Adds RedisStore based implementation for ClientStateManager trait.
Provides API hooks for adding or filtering Operations.
  • Loading branch information
zbirenbaum committed Jun 18, 2024
1 parent 4ec3d22 commit 69d345b
Showing 1 changed file with 145 additions and 9 deletions.
154 changes: 145 additions & 9 deletions nativelink-scheduler/src/redis_operation_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,30 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashMap;
use std::str::FromStr;
use std::sync::Arc;
use std::time::SystemTime;
use futures::{join};
use nativelink_util::buf_channel::make_buf_channel_pair;
use nativelink_util::background_spawn;
use std::iter::zip;

use tonic::async_trait;
use nativelink_util::action_messages::{ActionInfo, ActionInfoHashKey, ActionResult, ActionStage, ActionState, OperationId, WorkerId};
use nativelink_error::{make_input_err, Error};
use tokio::sync::watch;

use redis::AsyncCommands;
use redis::aio::{ConnectionLike, ConnectionManager};
use redis_macros::{FromRedisValue, ToRedisArgs};
use serde::{Serialize, Deserialize};
use futures::{join, StreamExt};

use nativelink_store::redis_store::RedisStore;
use nativelink_util::action_messages::{ActionInfo, ActionInfoHashKey, ActionResult, ActionStage, ActionState, OperationId, WorkerId};
use nativelink_error::{make_input_err, Error};
use nativelink_util::buf_channel::make_buf_channel_pair;
use nativelink_util::background_spawn;
use nativelink_util::store_trait::{StoreDriver, StoreLike, StoreSubscription};

use crate::operation_state_manager::{ActionStateResult, OperationStageFlags};
use nativelink_util::store_trait::StoreSubscription;
use crate::operation_state_manager::{ActionStateResultStream, ClientStateManager, OperationFilter};

#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)]
enum OperationStage {
Expand All @@ -36,7 +46,7 @@ enum OperationStage {
Unknown,
}

fn _parse_stage_flags(flags: &OperationStageFlags) -> Vec<OperationStage> {
fn parse_stage_flags(flags: &OperationStageFlags) -> Vec<OperationStage> {
if flags.contains(OperationStageFlags::Any) {
return Vec::from([
OperationStage::CacheCheck,
Expand Down Expand Up @@ -68,7 +78,7 @@ pub struct RedisOperationState {
}

impl RedisOperationState {
fn _new(inner: RedisOperation, mut subscription: Box<dyn StoreSubscription>) -> Self {
fn new(inner: RedisOperation, mut subscription: Box<dyn StoreSubscription>) -> Self {
let (tx, rx) = watch::channel(inner.as_state().unwrap());

let _join_handle = background_spawn!("redis_subscription_watcher", async move {
Expand Down Expand Up @@ -189,7 +199,7 @@ impl RedisOperation {
}
}
}
fn _unique_qualifier(&self) -> &ActionInfoHashKey {
fn unique_qualifier(&self) -> &ActionInfoHashKey {
&self.operation_id.unique_qualifier
}
}
Expand All @@ -212,3 +222,129 @@ impl TryFrom<RedisOperation> for ActionState {
})
}
}

fn match_optional_filter<T: PartialEq>(value_opt: Option<T>, filter_opt: Option<T>, cond: impl Fn(T, T) -> bool) -> bool {
let Some(filter) = filter_opt else { return true };
let Some(value) = value_opt else { return false; };
cond(filter, value)
}

pub fn matches_filter(operation: &RedisOperation, filter: &OperationFilter) -> bool {
if !parse_stage_flags(&filter.stages).contains(&operation.stage) && !(filter.stages == OperationStageFlags::Any) {
return false
}
match_optional_filter(Some(&operation.operation_id), filter.operation_id.as_ref(), |a, b| { a == b })
&& match_optional_filter(operation.worker_id, filter.worker_id, |a, b| { a == b })
&& match_optional_filter(Some(operation.unique_qualifier().digest), filter.action_digest, |a, b| { a == b })
&& match_optional_filter(operation.completed_at, filter.completed_before, |a, b| { a < b })
&& match_optional_filter(operation.last_client_update, filter.last_client_update_before, |a, b| { a < b })
}
pub struct RedisStateManager<T: ConnectionLike + Unpin + Clone + Send + Sync + 'static = ConnectionManager> {
pub store: Arc<RedisStore<T>>
}

impl<T: ConnectionLike + Unpin + Clone + Send + Sync + 'static> RedisStateManager<T> {

pub async fn get_conn(&self) -> Result<T, Error> {
self.store.get_conn().await
}

async fn list<'a, V>(&self, prefix: &str, result_map: &mut HashMap<String, V>) -> Result<(), Error>
where
V: FromStr
{
let mut con = self.get_conn().await?;
let ids_iter = con.scan_match::<&str, String>(prefix).await?;
let keys = ids_iter.collect::<Vec<String>>().await;
let raw_values: Vec<String> = con.get(&keys).await?;

let value_res: Result<Vec<V>, Error> = raw_values.into_iter().map(|s| {
V::from_str(&s).map_err(|_| {
make_input_err!("list: Failed to convert value to type")
})
}).collect();
match value_res {
Ok(values) => {
let zipped = zip(keys.into_iter(), values.into_iter());
*result_map = HashMap::from_iter(zipped);
Ok(())
},
Err(e) => {
Err(e)
}
}
}

async fn inner_add_action(
&self,
action_info: ActionInfo,
) -> Result<Arc<dyn ActionStateResult>, Error> {
let operation_id = OperationId::new(action_info.unique_qualifier.clone());
let mut con = self.get_conn().await?;
let hash_key = operation_id.unique_qualifier.action_name().clone();

let action_key = format!("actions:{}", hash_key.clone());
// TODO: List API call to find existing actions.
let mut existing_operations: Vec<String> = Vec::new();
let operation = match existing_operations.pop() {
Some(existing_operation) => {
let operation: RedisOperation = con.get(format!("operations:{}", &existing_operation)).await?;
RedisOperation::from_existing(operation.clone(), operation_id.clone())
},
None => {
RedisOperation::new(action_info, operation_id.clone())
}
};

let operation_key = format!("operations:{}", operation_id).to_string();
let store = self.store.as_store_driver_pin();
store
.update_oneshot(operation_key.into(), operation.as_json().into())
.await?;
store
.update_oneshot(action_key.into(), operation_id.to_string().into())
.await?;
let store_subscription = self.store.clone().subscribe(format!("operations:{}", operation_id).into()).await;
let state = RedisOperationState::new(operation, store_subscription);
Ok(Arc::new(state))
}

async fn inner_filter_operations(
&self,
filter: OperationFilter,
) -> Result<ActionStateResultStream, Error> {
let mut existing_operations: HashMap<String, RedisOperation> = HashMap::new();
self.list("operations:*", &mut existing_operations).await?;
let mut v: Vec<Arc<dyn ActionStateResult>> = Vec::new();
for operation in existing_operations.values() {
if matches_filter(operation, &filter) {

let store_subscription = self.store.clone().subscribe(format!("operations:{}", operation.operation_id).into()).await;
v.push(Arc::new(
RedisOperationState::new(
operation.clone(), store_subscription
)
)
);
}
}
Ok(Box::pin(futures::stream::iter(v)))
}
}

#[async_trait]
impl ClientStateManager for RedisStateManager {
async fn add_action(
&mut self,
action_info: ActionInfo,
) -> Result<Arc<dyn ActionStateResult>, Error> {
self.inner_add_action(action_info).await
}

async fn filter_operations(
&self,
filter: OperationFilter,
) -> Result<ActionStateResultStream, Error> {
self.inner_filter_operations(filter).await
}
}

0 comments on commit 69d345b

Please sign in to comment.