Skip to content

Commit

Permalink
fix(core): Poll TimeoutLayer::sleep once to make sure timer registered
Browse files Browse the repository at this point in the history
* feat(test): Cover the test for list timeout

Signed-off-by: Xuanwo <[email protected]>

* Fix

Signed-off-by: Xuanwo <[email protected]>

---------

Signed-off-by: Xuanwo <[email protected]>
  • Loading branch information
Xuanwo authored Feb 21, 2024
1 parent 7303f78 commit 72e3ea0
Showing 1 changed file with 71 additions and 26 deletions.
97 changes: 71 additions & 26 deletions core/src/layers/timeout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -288,22 +288,21 @@ impl<R> TimeoutWrapper<R> {

#[inline]
fn poll_timeout(&mut self, cx: &mut Context<'_>, op: &'static str) -> Result<()> {
if let Some(sleep) = self.sleep.as_mut() {
match sleep.as_mut().poll(cx) {
Poll::Pending => Ok(()),
Poll::Ready(_) => {
self.sleep = None;
Err(
Error::new(ErrorKind::Unexpected, "io operation timeout reached")
.with_operation(op)
.with_context("io_timeout", self.timeout.as_secs_f64().to_string())
.set_temporary(),
)
}
let sleep = self
.sleep
.get_or_insert_with(|| Box::pin(tokio::time::sleep(self.timeout)));

match sleep.as_mut().poll(cx) {
Poll::Pending => Ok(()),
Poll::Ready(_) => {
self.sleep = None;
Err(
Error::new(ErrorKind::Unexpected, "io operation timeout reached")
.with_operation(op)
.with_context("io_timeout", self.timeout.as_secs_f64().to_string())
.set_temporary(),
)
}
} else {
self.sleep = Some(Box::pin(tokio::time::sleep(self.timeout)));
Ok(())
}
}
}
Expand Down Expand Up @@ -380,6 +379,7 @@ mod tests {

use async_trait::async_trait;
use bytes::Bytes;
use futures::StreamExt;
use tokio::time::sleep;
use tokio::time::timeout;

Expand All @@ -397,7 +397,7 @@ mod tests {
impl Accessor for MockService {
type Reader = MockReader;
type Writer = ();
type Lister = ();
type Lister = MockLister;
type BlockingReader = ();
type BlockingWriter = ();
type BlockingLister = ();
Expand All @@ -424,6 +424,10 @@ mod tests {

Ok(RpDelete::default())
}

async fn list(&self, _: &str, _: OpList) -> Result<(RpList, Self::Lister)> {
Ok((RpList::default(), MockLister))
}
}

#[derive(Debug, Clone, Default)]
Expand All @@ -443,6 +447,15 @@ mod tests {
}
}

#[derive(Debug, Clone, Default)]
struct MockLister;

impl oio::List for MockLister {
fn poll_next(&mut self, _: &mut Context<'_>) -> Poll<Result<Option<oio::Entry>>> {
Poll::Pending
}
}

#[tokio::test]
async fn test_operation_timeout() {
let acc = Arc::new(TypeEraseLayer.layer(MockService)) as FusedAccessor;
Expand All @@ -468,18 +481,50 @@ mod tests {
let op = Operator::from_inner(acc)
.layer(TimeoutLayer::new().with_io_timeout(Duration::from_secs(1)));

let fut = async {
let mut reader = op.reader("test").await.unwrap();
let mut reader = op.reader("test").await.unwrap();

let res = reader.read(&mut [0; 4]).await;
assert!(res.is_err());
let err = res.unwrap_err();
assert_eq!(err.kind(), ErrorKind::Unexpected);
assert!(err.to_string().contains("timeout"))
};
let res = reader.read(&mut [0; 4]).await;
assert!(res.is_err());
let err = res.unwrap_err();
assert_eq!(err.kind(), ErrorKind::Unexpected);
assert!(err.to_string().contains("timeout"))
}

timeout(Duration::from_secs(2), fut)
#[tokio::test]
async fn test_list_timeout() {
let acc = Arc::new(TypeEraseLayer.layer(MockService)) as FusedAccessor;
let op = Operator::from_inner(acc).layer(
TimeoutLayer::new()
.with_timeout(Duration::from_secs(1))
.with_io_timeout(Duration::from_secs(1)),
);

let mut lister = op.lister("test").await.unwrap();

let res = lister.next().await.unwrap();
assert!(res.is_err());
let err = res.unwrap_err();
assert_eq!(err.kind(), ErrorKind::Unexpected);
assert!(err.to_string().contains("timeout"))
}

#[tokio::test]
async fn test_list_timeout_raw() {
let acc = MockService;
let timeout_layer = TimeoutLayer::new()
.with_timeout(Duration::from_secs(1))
.with_io_timeout(Duration::from_secs(1));
let timeout_acc = timeout_layer.layer(acc);

let (_, mut lister) = Accessor::list(&timeout_acc, "test", OpList::default())
.await
.expect("this test should not exceed 2 seconds")
.unwrap();

use oio::ListExt;
let res = lister.next().await;
assert!(res.is_err());
let err = res.unwrap_err();
assert_eq!(err.kind(), ErrorKind::Unexpected);
assert!(err.to_string().contains("timeout"));
}
}

0 comments on commit 72e3ea0

Please sign in to comment.