diff --git a/src/storage/src/hummock/utils.rs b/src/storage/src/hummock/utils.rs index c8889a6164cb2..2e2e1d76f7389 100644 --- a/src/storage/src/hummock/utils.rs +++ b/src/storage/src/hummock/utils.rs @@ -15,14 +15,17 @@ use std::cmp::Ordering; use std::collections::VecDeque; use std::fmt::{Debug, Formatter}; +use std::future::Future; use std::ops::Bound::{Excluded, Included, Unbounded}; use std::ops::{Bound, RangeBounds}; use std::sync::atomic::{AtomicU64, Ordering as AtomicOrdering}; use std::sync::Arc; +use std::task::Poll; use std::time::Duration; use bytes::Bytes; use foyer::memory::CacheContext; +use futures::FutureExt; use parking_lot::Mutex; use risingwave_common::catalog::{TableId, TableOption}; use risingwave_hummock_sdk::key::{ @@ -175,12 +178,40 @@ enum MemoryRequest { pub struct PendingRequestCancelGuard { inner: Option>, + rc: Receiver, } impl Drop for PendingRequestCancelGuard { fn drop(&mut self) { if let Some(limiter) = self.inner.take() { - limiter.may_notify_waiters(); + self.rc.close(); + if let Ok(msg) = self.rc.try_recv() { + drop(msg); + if limiter.pending_request_count.load(AtomicOrdering::Acquire) > 0 { + limiter.may_notify_waiters(); + } + } + } + } +} + +impl Future for PendingRequestCancelGuard { + type Output = Option; + + fn poll( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll { + match self.rc.poll_unpin(cx) { + Poll::Ready(Ok(msg)) => { + self.inner.take(); + Poll::Ready(Some(msg)) + } + Poll::Ready(Err(_)) => { + self.inner.take(); + Poll::Ready(None) + } + Poll::Pending => Poll::Pending, } } } @@ -203,11 +234,11 @@ impl MemoryLimiterInner { fn may_notify_waiters(self: &Arc) { let mut waiters = self.controller.lock(); - while let Some((tx, quota)) = waiters.pop_front() { - if !self.try_require_memory(quota) { - waiters.push_front((tx, quota)); + while let Some((_, quota)) = waiters.front() { + if !self.try_require_memory(*quota) { break; } + let (tx, quota) = waiters.pop_front().unwrap(); let _ = tx.send(MemoryTrackerImpl::new(self.clone(), quota)); } @@ -275,12 +306,12 @@ impl Debug for MemoryLimiter { } } -struct MemoryTrackerImpl { +pub struct MemoryTrackerImpl { limiter: Arc, quota: Option, } impl MemoryTrackerImpl { - pub fn new(limiter: Arc, quota: u64) -> Self { + fn new(limiter: Arc, quota: u64) -> Self { Self { limiter, quota: Some(quota), @@ -350,11 +381,11 @@ impl MemoryLimiter { match self.inner.require_memory(quota) { MemoryRequest::Ready(inner) => MemoryTracker { inner }, MemoryRequest::Pending(rc) => { - let mut guard = PendingRequestCancelGuard { + let guard = PendingRequestCancelGuard { inner: Some(self.inner.clone()), + rc, }; - let inner = rc.await.unwrap(); - guard.inner.take(); + let inner = guard.await.unwrap(); MemoryTracker { inner } } }