diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 52e4a000355d..c2818db41c35 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1245,12 +1245,14 @@ dependencies = [ "parquet", "paste", "sqlparser", + "tokio", ] [[package]] name = "datafusion-common-runtime" version = "41.0.0" dependencies = [ + "log", "tokio", ] diff --git a/datafusion/common-runtime/Cargo.toml b/datafusion/common-runtime/Cargo.toml index c10436087675..a21c72cd9f83 100644 --- a/datafusion/common-runtime/Cargo.toml +++ b/datafusion/common-runtime/Cargo.toml @@ -36,4 +36,8 @@ name = "datafusion_common_runtime" path = "src/lib.rs" [dependencies] +log = { workspace = true } tokio = { workspace = true } + +[dev-dependencies] +tokio = { version = "1.36", features = ["rt", "rt-multi-thread", "time"] } diff --git a/datafusion/common-runtime/src/common.rs b/datafusion/common-runtime/src/common.rs index 2f7ddb972f42..698a846b4844 100644 --- a/datafusion/common-runtime/src/common.rs +++ b/datafusion/common-runtime/src/common.rs @@ -60,8 +60,8 @@ impl SpawnedTask { } /// Joins the task and unwinds the panic if it happens. - pub async fn join_unwind(self) -> R { - self.join().await.unwrap_or_else(|e| { + pub async fn join_unwind(self) -> Result { + self.join().await.map_err(|e| { // `JoinError` can be caused either by panic or cancellation. We have to handle panics: if e.is_panic() { std::panic::resume_unwind(e.into_panic()); @@ -69,9 +69,43 @@ impl SpawnedTask { // Cancellation may be caused by two reasons: // 1. Abort is called, but since we consumed `self`, it's not our case (`JoinHandle` not accessible outside). // 2. The runtime is shutting down. - // So we consider this branch as unreachable. - unreachable!("SpawnedTask was cancelled unexpectedly"); + log::warn!("SpawnedTask was polled during shutdown"); + e } }) } } + +#[cfg(test)] +mod tests { + use super::*; + + use std::future::{pending, Pending}; + + use tokio::runtime::Runtime; + + #[tokio::test] + async fn runtime_shutdown() { + let rt = Runtime::new().unwrap(); + let task = rt + .spawn(async { + SpawnedTask::spawn(async { + let fut: Pending<()> = pending(); + fut.await; + unreachable!("should never return"); + }) + }) + .await + .unwrap(); + + // caller shutdown their DF runtime (e.g. timeout, error in caller, etc) + rt.shutdown_background(); + + // race condition + // poll occurs during shutdown (buffered stream poll calls, etc) + assert!(matches!( + task.join_unwind().await, + Err(e) if e.is_cancelled() + )); + } +} diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml index 8435d0632576..79e20ba1215c 100644 --- a/datafusion/common/Cargo.toml +++ b/datafusion/common/Cargo.toml @@ -63,6 +63,7 @@ parquet = { workspace = true, optional = true, default-features = true } paste = "1.0.15" pyo3 = { version = "0.21.0", optional = true } sqlparser = { workspace = true } +tokio = { workspace = true } [target.'cfg(target_family = "wasm")'.dependencies] instant = { version = "0.1", features = ["wasm-bindgen"] } diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs index 27a25d0c9dd5..05988d6c6da4 100644 --- a/datafusion/common/src/error.rs +++ b/datafusion/common/src/error.rs @@ -34,6 +34,7 @@ use arrow::error::ArrowError; #[cfg(feature = "parquet")] use parquet::errors::ParquetError; use sqlparser::parser::ParserError; +use tokio::task::JoinError; /// Result type for operations that could result in an [DataFusionError] pub type Result = result::Result; @@ -112,6 +113,10 @@ pub enum DataFusionError { /// SQL method, opened a CSV file that is broken, or tried to divide an /// integer by zero. Execution(String), + /// [`JoinError`] during execution of the query. + /// + /// This error can unoccur for unjoined tasks, such as execution shutdown. + ExecutionJoin(JoinError), /// Error when resources (such as memory of scratch disk space) are exhausted. /// /// This error is thrown when a consumer cannot acquire additional memory @@ -306,6 +311,7 @@ impl Error for DataFusionError { DataFusionError::Plan(_) => None, DataFusionError::SchemaError(e, _) => Some(e), DataFusionError::Execution(_) => None, + DataFusionError::ExecutionJoin(e) => Some(e), DataFusionError::ResourcesExhausted(_) => None, DataFusionError::External(e) => Some(e.as_ref()), DataFusionError::Context(_, e) => Some(e.as_ref()), @@ -418,6 +424,7 @@ impl DataFusionError { DataFusionError::Configuration(_) => "Invalid or Unsupported Configuration: ", DataFusionError::SchemaError(_, _) => "Schema error: ", DataFusionError::Execution(_) => "Execution error: ", + DataFusionError::ExecutionJoin(_) => "ExecutionJoin error: ", DataFusionError::ResourcesExhausted(_) => "Resources exhausted: ", DataFusionError::External(_) => "External error: ", DataFusionError::Context(_, _) => "", @@ -453,6 +460,7 @@ impl DataFusionError { Cow::Owned(format!("{desc}{backtrace}")) } DataFusionError::Execution(ref desc) => Cow::Owned(desc.to_string()), + DataFusionError::ExecutionJoin(ref desc) => Cow::Owned(desc.to_string()), DataFusionError::ResourcesExhausted(ref desc) => Cow::Owned(desc.to_string()), DataFusionError::External(ref desc) => Cow::Owned(desc.to_string()), #[cfg(feature = "object_store")] diff --git a/datafusion/core/src/datasource/file_format/arrow.rs b/datafusion/core/src/datasource/file_format/arrow.rs index 8b6a8800119d..95f76195e63d 100644 --- a/datafusion/core/src/datasource/file_format/arrow.rs +++ b/datafusion/core/src/datasource/file_format/arrow.rs @@ -341,7 +341,10 @@ impl DataSink for ArrowFileSink { } } - demux_task.join_unwind().await?; + demux_task + .join_unwind() + .await + .map_err(DataFusionError::ExecutionJoin)??; Ok(row_count as u64) } } diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index f233f3842c8c..83f77ca9371a 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -836,7 +836,10 @@ impl DataSink for ParquetSink { } } - demux_task.join_unwind().await?; + demux_task + .join_unwind() + .await + .map_err(DataFusionError::ExecutionJoin)??; Ok(row_count as u64) } @@ -942,7 +945,10 @@ fn spawn_rg_join_and_finalize_task( let num_cols = column_writer_tasks.len(); let mut finalized_rg = Vec::with_capacity(num_cols); for task in column_writer_tasks.into_iter() { - let (writer, _col_reservation) = task.join_unwind().await?; + let (writer, _col_reservation) = task + .join_unwind() + .await + .map_err(DataFusionError::ExecutionJoin)??; let encoded_size = writer.get_estimated_total_bytes(); rg_reservation.grow(encoded_size); finalized_rg.push(writer.close()?); @@ -1070,7 +1076,8 @@ async fn concatenate_parallel_row_groups( while let Some(task) = serialize_rx.recv().await { let result = task.join_unwind().await; let mut rg_out = parquet_writer.next_row_group()?; - let (serialized_columns, mut rg_reservation, _cnt) = result?; + let (serialized_columns, mut rg_reservation, _cnt) = + result.map_err(DataFusionError::ExecutionJoin)??; for chunk in serialized_columns { chunk.append_to_row_group(&mut rg_out)?; rg_reservation.free(); @@ -1134,7 +1141,10 @@ async fn output_single_parquet_file_parallelized( ) .await?; - launch_serialization_task.join_unwind().await?; + launch_serialization_task + .join_unwind() + .await + .map_err(DataFusionError::ExecutionJoin)??; Ok(file_metadata) } diff --git a/datafusion/core/src/datasource/file_format/write/orchestration.rs b/datafusion/core/src/datasource/file_format/write/orchestration.rs index 1d32063ee9f3..6f27e6f3889f 100644 --- a/datafusion/core/src/datasource/file_format/write/orchestration.rs +++ b/datafusion/core/src/datasource/file_format/write/orchestration.rs @@ -298,8 +298,8 @@ pub(crate) async fn stateless_multipart_put( write_coordinator_task.join_unwind(), demux_task.join_unwind() ); - r1?; - r2?; + r1.map_err(DataFusionError::ExecutionJoin)??; + r2.map_err(DataFusionError::ExecutionJoin)??; let total_count = rx_row_cnt.await.map_err(|_| { internal_datafusion_err!("Did not receive row count from write coordinator") diff --git a/datafusion/core/src/datasource/stream.rs b/datafusion/core/src/datasource/stream.rs index 682565aea909..b53fe8663178 100644 --- a/datafusion/core/src/datasource/stream.rs +++ b/datafusion/core/src/datasource/stream.rs @@ -438,6 +438,9 @@ impl DataSink for StreamWrite { } } drop(sender); - write_task.join_unwind().await + write_task + .join_unwind() + .await + .map_err(DataFusionError::ExecutionJoin)? } }