diff --git a/src/storage/src/hummock/utils.rs b/src/storage/src/hummock/utils.rs index 2e2e1d76f7389..0c42af9288e9f 100644 --- a/src/storage/src/hummock/utils.rs +++ b/src/storage/src/hummock/utils.rs @@ -697,8 +697,10 @@ pub(crate) async fn wait_for_epoch( #[cfg(test)] mod tests { use std::future::{poll_fn, Future}; + use std::sync::Arc; use std::task::Poll; + use futures::future::join_all; use futures::FutureExt; use crate::hummock::utils::MemoryLimiter; @@ -731,4 +733,39 @@ mod tests { drop(tracker3); assert_eq!(0, memory_limiter.get_memory_usage()); } + + #[tokio::test(flavor = "multi_thread", worker_threads = 5)] + async fn test_multi_thread_acquire_memory() { + const QUOTA: u64 = 15; + let memory_limiter = Arc::new(MemoryLimiter::new(100)); + let mut handles = vec![]; + for _ in 0..10 { + let limiter = memory_limiter.clone(); + let h = tokio::spawn(async move { + let mut buffers = vec![]; + for idx in 0..1000 { + if let Some(tracker) = limiter.try_require_memory(QUOTA) { + buffers.push(tracker); + } else { + buffers.clear(); + let req = limiter.require_memory(QUOTA); + match tokio::time::timeout(std::time::Duration::from_millis(1), req).await { + Ok(tracker) => { + buffers.push(tracker); + } + Err(_) => { + continue; + } + } + } + if idx % 3 == 0 { + tokio::time::sleep(std::time::Duration::from_millis(1)).await; + } + } + }); + handles.push(h); + } + let h = join_all(handles); + let _ = h.await; + } }