diff --git a/src/common/procedure/src/local/runner.rs b/src/common/procedure/src/local/runner.rs index c4470d888b40..6302adfc649c 100644 --- a/src/common/procedure/src/local/runner.rs +++ b/src/common/procedure/src/local/runner.rs @@ -153,6 +153,10 @@ impl Runner { // Release locks and notify parent procedure. guard.finish(); + // Clean the staled locks. + self.manager_ctx + .key_lock + .clean_keys(self.meta.lock_key.keys_to_lock().map(|k| k.as_string())); // If this is the root procedure, clean up message cache. if self.meta.parent_id.is_none() { @@ -787,6 +791,7 @@ mod tests { runner.manager_ctx = manager_ctx.clone(); runner.run().await; + assert!(manager_ctx.key_lock.is_empty()); // Check child procedures. for child_id in children_ids { @@ -1045,10 +1050,11 @@ mod tests { // Manually add this procedure to the manager ctx. assert!(manager_ctx.try_insert_procedure(meta.clone())); // Replace the manager ctx. - runner.manager_ctx = manager_ctx; + runner.manager_ctx = manager_ctx.clone(); // Run the runner and execute the procedure. runner.run().await; + assert!(manager_ctx.key_lock.is_empty()); let err = meta.state().error().unwrap().output_msg(); assert!(err.contains("subprocedure failed"), "{err}"); } diff --git a/src/common/procedure/src/local/rwlock.rs b/src/common/procedure/src/local/rwlock.rs index b1c6d474572c..7ad983fc381f 100644 --- a/src/common/procedure/src/local/rwlock.rs +++ b/src/common/procedure/src/local/rwlock.rs @@ -43,29 +43,6 @@ pub struct KeyRwLock { inner: Arc>>>>, } -impl KeyRwLock -where - K: Eq + Hash + Clone, -{ - /// Remove locks that are not locked currently. - fn clean_up(locks: &mut HashMap>>) { - let keys = locks - .iter() - .filter_map(|(key, lock)| { - if lock.try_write().is_ok() { - Some(key.clone()) - } else { - None - } - }) - .collect::>(); - - for key in keys { - locks.remove(&key); - } - } -} - impl KeyRwLock where K: Eq + Hash + Send + Clone, @@ -80,8 +57,7 @@ where pub async fn read(&self, key: K) -> OwnedRwLockReadGuard<()> { let lock = { let mut locks = self.inner.lock().unwrap(); - Self::clean_up(&mut locks); - locks.entry(key.clone()).or_default().clone() + locks.entry(key).or_default().clone() }; lock.read_owned().await @@ -91,8 +67,7 @@ where pub async fn write(&self, key: K) -> OwnedRwLockWriteGuard<()> { let lock = { let mut locks = self.inner.lock().unwrap(); - Self::clean_up(&mut locks); - locks.entry(key.clone()).or_default().clone() + locks.entry(key).or_default().clone() }; lock.write_owned().await @@ -102,8 +77,7 @@ where pub fn try_read(&self, key: K) -> Result, TryLockError> { let lock = { let mut locks = self.inner.lock().unwrap(); - Self::clean_up(&mut locks); - locks.entry(key.clone()).or_default().clone() + locks.entry(key).or_default().clone() }; lock.try_read_owned() @@ -113,8 +87,7 @@ where pub fn try_write(&self, key: K) -> Result, TryLockError> { let lock = { let mut locks = self.inner.lock().unwrap(); - Self::clean_up(&mut locks); - locks.entry(key.clone()).or_default().clone() + locks.entry(key).or_default().clone() }; lock.try_write_owned() @@ -129,6 +102,24 @@ where pub fn is_empty(&self) -> bool { self.len() == 0 } + + /// Clean up stale locks. + pub fn clean_keys<'a>(&'a self, iter: impl IntoIterator) { + let mut locks = self.inner.lock().unwrap(); + + let mut keys = Vec::new(); + for key in iter { + if let Some(lock) = locks.get(key) { + if lock.try_write().is_ok() { + keys.push(key); + } + } + } + + for key in keys { + locks.remove(key); + } + } } #[cfg(test)] @@ -155,5 +146,8 @@ mod tests { } assert_eq!(lock_key.len(), 2); + + lock_key.clean_keys(&vec!["test1", "test2"]); + assert!(lock_key.is_empty()); } } diff --git a/src/common/procedure/src/procedure.rs b/src/common/procedure/src/procedure.rs index 1b6c5c13fc20..65f5dc0c2d55 100644 --- a/src/common/procedure/src/procedure.rs +++ b/src/common/procedure/src/procedure.rs @@ -136,6 +136,13 @@ impl StringKey { StringKey::Exclusive(s) => s, } } + + pub fn as_string(&self) -> &String { + match self { + StringKey::Share(s) => s, + StringKey::Exclusive(s) => s, + } + } } impl LockKey {