From ec2592b33a8f34a038ac68db4e8b3873d35cde6f Mon Sep 17 00:00:00 2001 From: Jonah Gao Date: Sun, 21 Apr 2024 17:22:59 +0800 Subject: [PATCH 1/7] minor: avoid cloning the `SetExpr` during planning of `SelectInto` (#10152) * minor: avoid cloning the `SetExpr` during planning of `SelectInto` * retry ci --- datafusion/sql/src/query.rs | 37 ++++++++++++++++++------------------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index ba876d052f5e2..058496e88367c 100644 --- a/datafusion/sql/src/query.rs +++ b/datafusion/sql/src/query.rs @@ -46,29 +46,28 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { query: Query, planner_context: &mut PlannerContext, ) -> Result { - let set_expr = query.body; + let mut set_expr = query.body; if let Some(with) = query.with { self.plan_with_clause(with, planner_context)?; } - let plan = self.set_expr_to_plan(*(set_expr.clone()), planner_context)?; - let plan = self.order_by(plan, query.order_by, planner_context)?; - let plan = self.limit(plan, query.offset, query.limit)?; - - let plan = match *set_expr { - SetExpr::Select(select) if select.into.is_some() => { - let select_into = select.into.unwrap(); - LogicalPlan::Ddl(DdlStatement::CreateMemoryTable(CreateMemoryTable { - name: self.object_name_to_table_reference(select_into.name)?, - constraints: Constraints::empty(), - input: Arc::new(plan), - if_not_exists: false, - or_replace: false, - column_defaults: vec![], - })) - } - _ => plan, + // Take the `SelectInto` for later processing. + let select_into = match set_expr.as_mut() { + SetExpr::Select(select) => select.into.take(), + _ => None, }; - + let plan = self.set_expr_to_plan(*set_expr, planner_context)?; + let plan = self.order_by(plan, query.order_by, planner_context)?; + let mut plan = self.limit(plan, query.offset, query.limit)?; + if let Some(into) = select_into { + plan = LogicalPlan::Ddl(DdlStatement::CreateMemoryTable(CreateMemoryTable { + name: self.object_name_to_table_reference(into.name)?, + constraints: Constraints::empty(), + input: Arc::new(plan), + if_not_exists: false, + or_replace: false, + column_defaults: vec![], + })) + } Ok(plan) } From 70d1a5d795f0a4f16ba80c98df28ecc556c6da8f Mon Sep 17 00:00:00 2001 From: Jeffrey Vo Date: Sun, 21 Apr 2024 21:06:55 +1000 Subject: [PATCH 2/7] Add distinct aggregate tests to sqllogictest (#10158) * Add distinct aggregate tests to sqllogictest * Update tests --- .../sqllogictest/test_files/aggregate.slt | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 457cd11211f19..c25f6d50b3a3c 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -373,6 +373,15 @@ SELECT var(c2) FROM aggregate_test_100 ---- 1.886363636364 +# csv_query_distinct_variance +query R +SELECT var(distinct c2) FROM aggregate_test_100 +---- +2.5 + +statement error DataFusion error: This feature is not implemented: VAR\(DISTINCT\) aggregations are not available +SELECT var(c2), var(distinct c2) FROM aggregate_test_100 + # csv_query_variance_5 query R SELECT var_samp(c2) FROM aggregate_test_100 @@ -457,6 +466,24 @@ SELECT median(col_i8) FROM median_table ---- -14 +# distinct_median_i8 +query I +SELECT median(distinct col_i8) FROM median_table +---- +100 + +statement error DataFusion error: This feature is not implemented: MEDIAN\(DISTINCT\) aggregations are not available +SELECT median(col_i8), median(distinct col_i8) FROM median_table + +# approx_distinct_median_i8 +query I +SELECT approx_median(distinct col_i8) FROM median_table +---- +100 + +statement error DataFusion error: This feature is not implemented: APPROX_MEDIAN\(DISTINCT\) aggregations are not available +SELECT approx_median(col_i8), approx_median(distinct col_i8) FROM median_table + # median_i16 query I SELECT median(col_i16) FROM median_table @@ -2498,6 +2525,15 @@ select avg(x_dict) from value_dict; ---- 2.625 +# distinct_average +query R +select avg(distinct x_dict) from value_dict; +---- +3 + +statement error DataFusion error: This feature is not implemented: AVG\(DISTINCT\) aggregations are not available +select avg(x_dict), avg(distinct x_dict) from value_dict; + query I select min(x_dict) from value_dict; ---- From 6f5d6a224385108b7109305336c8f91fad2f9eb0 Mon Sep 17 00:00:00 2001 From: Jeffrey Vo Date: Sun, 21 Apr 2024 21:09:31 +1000 Subject: [PATCH 3/7] Add test for LIKE newline handling (#10160) --- datafusion/sqllogictest/test_files/regexp.slt | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/datafusion/sqllogictest/test_files/regexp.slt b/datafusion/sqllogictest/test_files/regexp.slt index 19966be2095b4..a45ce3718bc40 100644 --- a/datafusion/sqllogictest/test_files/regexp.slt +++ b/datafusion/sqllogictest/test_files/regexp.slt @@ -309,5 +309,16 @@ SELECT regexp_replace(arrow_cast('foobar', 'Dictionary(Int32, Utf8)'), 'bar', 'x ---- fooxx +# multiline string +query B +SELECT 'foo\nbar\nbaz' ~ 'bar'; +---- +true + +query B +SELECT 'foo\nbar\nbaz' LIKE '%bar%'; +---- +true + statement ok drop table t; From 16e3831734358d2c628c7ff281cddd680dc4aa10 Mon Sep 17 00:00:00 2001 From: Devin D'Angelo Date: Sun, 21 Apr 2024 06:10:43 -0500 Subject: [PATCH 4/7] add test and remove panics (#10150) --- datafusion/sql/src/unparser/expr.rs | 107 +++++++++++++++++++----- datafusion/sql/tests/sql_integration.rs | 1 + 2 files changed, 86 insertions(+), 22 deletions(-) diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 623f61fb600de..b99cfe11f0630 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -319,7 +319,28 @@ impl Unparser<'_> { expr: Box::new(sql_parser_expr), }) } - _ => not_impl_err!("Unsupported expression: {expr:?}"), + Expr::ScalarVariable(_, _) => { + not_impl_err!("Unsupported Expr conversion: {expr:?}") + } + Expr::IsNull(_) => not_impl_err!("Unsupported Expr conversion: {expr:?}"), + Expr::IsNotFalse(_) => not_impl_err!("Unsupported Expr conversion: {expr:?}"), + Expr::GetIndexedField(_) => { + not_impl_err!("Unsupported Expr conversion: {expr:?}") + } + Expr::TryCast(_) => not_impl_err!("Unsupported Expr conversion: {expr:?}"), + Expr::Wildcard { qualifier: _ } => { + not_impl_err!("Unsupported Expr conversion: {expr:?}") + } + Expr::GroupingSet(_) => { + not_impl_err!("Unsupported Expr conversion: {expr:?}") + } + Expr::Placeholder(_) => { + not_impl_err!("Unsupported Expr conversion: {expr:?}") + } + Expr::OuterReferenceColumn(_, _) => { + not_impl_err!("Unsupported Expr conversion: {expr:?}") + } + Expr::Unnest(_) => not_impl_err!("Unsupported Expr conversion: {expr:?}"), } } @@ -638,29 +659,71 @@ impl Unparser<'_> { } DataType::Date32 => Ok(ast::DataType::Date), DataType::Date64 => Ok(ast::DataType::Datetime(None)), - DataType::Time32(_) => todo!(), - DataType::Time64(_) => todo!(), - DataType::Duration(_) => todo!(), - DataType::Interval(_) => todo!(), - DataType::Binary => todo!(), - DataType::FixedSizeBinary(_) => todo!(), - DataType::LargeBinary => todo!(), - DataType::BinaryView => todo!(), + DataType::Time32(_) => { + not_impl_err!("Unsupported DataType: conversion: {data_type:?}") + } + DataType::Time64(_) => { + not_impl_err!("Unsupported DataType: conversion: {data_type:?}") + } + DataType::Duration(_) => { + not_impl_err!("Unsupported DataType: conversion: {data_type:?}") + } + DataType::Interval(_) => { + not_impl_err!("Unsupported DataType: conversion: {data_type:?}") + } + DataType::Binary => { + not_impl_err!("Unsupported DataType: conversion: {data_type:?}") + } + DataType::FixedSizeBinary(_) => { + not_impl_err!("Unsupported DataType: conversion: {data_type:?}") + } + DataType::LargeBinary => { + not_impl_err!("Unsupported DataType: conversion: {data_type:?}") + } + DataType::BinaryView => { + not_impl_err!("Unsupported DataType: conversion: {data_type:?}") + } DataType::Utf8 => Ok(ast::DataType::Varchar(None)), DataType::LargeUtf8 => Ok(ast::DataType::Text), - DataType::Utf8View => todo!(), - DataType::List(_) => todo!(), - DataType::FixedSizeList(_, _) => todo!(), - DataType::LargeList(_) => todo!(), - DataType::ListView(_) => todo!(), - DataType::LargeListView(_) => todo!(), - DataType::Struct(_) => todo!(), - DataType::Union(_, _) => todo!(), - DataType::Dictionary(_, _) => todo!(), - DataType::Decimal128(_, _) => todo!(), - DataType::Decimal256(_, _) => todo!(), - DataType::Map(_, _) => todo!(), - DataType::RunEndEncoded(_, _) => todo!(), + DataType::Utf8View => { + not_impl_err!("Unsupported DataType: conversion: {data_type:?}") + } + DataType::List(_) => { + not_impl_err!("Unsupported DataType: conversion: {data_type:?}") + } + DataType::FixedSizeList(_, _) => { + not_impl_err!("Unsupported DataType: conversion: {data_type:?}") + } + DataType::LargeList(_) => { + not_impl_err!("Unsupported DataType: conversion: {data_type:?}") + } + DataType::ListView(_) => { + not_impl_err!("Unsupported DataType: conversion: {data_type:?}") + } + DataType::LargeListView(_) => { + not_impl_err!("Unsupported DataType: conversion: {data_type:?}") + } + DataType::Struct(_) => { + not_impl_err!("Unsupported DataType: conversion: {data_type:?}") + } + DataType::Union(_, _) => { + not_impl_err!("Unsupported DataType: conversion: {data_type:?}") + } + DataType::Dictionary(_, _) => { + not_impl_err!("Unsupported DataType: conversion: {data_type:?}") + } + DataType::Decimal128(_, _) => { + not_impl_err!("Unsupported DataType: conversion: {data_type:?}") + } + DataType::Decimal256(_, _) => { + not_impl_err!("Unsupported DataType: conversion: {data_type:?}") + } + DataType::Map(_, _) => { + not_impl_err!("Unsupported DataType: conversion: {data_type:?}") + } + DataType::RunEndEncoded(_, _) => { + not_impl_err!("Unsupported DataType: conversion: {data_type:?}") + } } } } diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 410cbdad747ab..da1baf65de342 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -4613,6 +4613,7 @@ fn roundtrip_statement() -> Result<()> { "select * from (select id, first_name from (select * from person))", "select id, count(*) as cnt from (select id from person) group by id", "select (id-1)/2, count(*) / (sum(id/10)-1) as agg_expr from (select (id-1) as id from person) group by id", + "select CAST(id/2 as VARCHAR) NOT LIKE 'foo*' from person where NOT EXISTS (select ta.j1_id, tb.j2_string from j1 ta join j2 tb on (ta.j1_id = tb.j2_id))", r#"select "First Name" from person_quoted_cols"#, r#"select id, count("First Name") as cnt from (select id, "First Name" from person_quoted_cols) group by id"#, "select id, count(*) as cnt from (select p1.id as id from person p1 inner join person p2 on p1.id=p2.id) group by id", From eb72debc2abb3fb2f132c56da8986794d523e39d Mon Sep 17 00:00:00 2001 From: Georgi Krastev Date: Sun, 21 Apr 2024 15:46:34 +0300 Subject: [PATCH 5/7] Support Duration and Union types in ScalarValue::iter_to_array (#10139) --- datafusion/common/src/scalar/mod.rs | 19 ++++++++++++++++--- .../sqllogictest/test_files/aggregate.slt | 6 ++++++ 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 88d40a35585d7..365898abc3d71 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -1575,6 +1575,18 @@ impl ScalarValue { tz ) } + DataType::Duration(TimeUnit::Second) => { + build_array_primitive!(DurationSecondArray, DurationSecond) + } + DataType::Duration(TimeUnit::Millisecond) => { + build_array_primitive!(DurationMillisecondArray, DurationMillisecond) + } + DataType::Duration(TimeUnit::Microsecond) => { + build_array_primitive!(DurationMicrosecondArray, DurationMicrosecond) + } + DataType::Duration(TimeUnit::Nanosecond) => { + build_array_primitive!(DurationNanosecondArray, DurationNanosecond) + } DataType::Interval(IntervalUnit::DayTime) => { build_array_primitive!(IntervalDayTimeArray, IntervalDayTime) } @@ -1605,7 +1617,10 @@ impl ScalarValue { let arrays = arrays.iter().map(|a| a.as_ref()).collect::>(); arrow::compute::concat(arrays.as_slice())? } - DataType::List(_) | DataType::LargeList(_) | DataType::Struct(_) => { + DataType::List(_) + | DataType::LargeList(_) + | DataType::Struct(_) + | DataType::Union(_, _) => { let arrays = scalars.map(|s| s.to_array()).collect::>>()?; let arrays = arrays.iter().map(|a| a.as_ref()).collect::>(); arrow::compute::concat(arrays.as_slice())? @@ -1673,8 +1688,6 @@ impl ScalarValue { | DataType::Time32(TimeUnit::Nanosecond) | DataType::Time64(TimeUnit::Second) | DataType::Time64(TimeUnit::Millisecond) - | DataType::Duration(_) - | DataType::Union(_, _) | DataType::Map(_, _) | DataType::RunEndEncoded(_, _) | DataType::Utf8View diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index c25f6d50b3a3c..030b8ef8ce7db 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -1983,6 +1983,12 @@ SELECT avg(nanos), avg(micros), avg(millis), avg(secs) FROM t statement error DataFusion error: Error during planning: No function matches the given name and argument types 'AVG\(Timestamp\(Nanosecond, None\)\)'\. You might need to add explicit type casts\. SELECT tag, avg(nanos), avg(micros), avg(millis), avg(secs) FROM t GROUP BY tag ORDER BY tag; +# aggregate_duration_array_agg +query T? +SELECT tag, array_agg(millis - arrow_cast(secs, 'Timestamp(Millisecond, None)')) FROM t GROUP BY tag ORDER BY tag; +---- +X [0 days 0 hours 0 mins 0.011 secs, 0 days 0 hours 0 mins 0.123 secs] +Y [, 0 days 0 hours 0 mins 0.432 secs] statement ok drop table t_source; From fc34dacdb9842cde4d056d5a659796ede4ae5e74 Mon Sep 17 00:00:00 2001 From: Jeffrey Vo Date: Sun, 21 Apr 2024 23:09:01 +1000 Subject: [PATCH 6/7] chore(deps): update sqlparser requirement from 0.44.0 to 0.45.0 (#10137) * chore(deps): update sqlparser requirement from 0.44.0 to 0.45.0 * Bump datafusion-cli Cargo.lock * Enhance error messages --- Cargo.toml | 2 +- datafusion-cli/Cargo.lock | 4 +- .../common/src/functional_dependencies.rs | 44 +++++++++++++------ datafusion/sql/src/expr/mod.rs | 1 + datafusion/sql/src/expr/value.rs | 8 +++- datafusion/sql/src/statement.rs | 22 +++++++++- 6 files changed, 60 insertions(+), 21 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 8876f41986428..3002a5760fbcc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -103,7 +103,7 @@ parquet = { version = "51.0.0", default-features = false, features = ["arrow", " rand = "0.8" rstest = "0.19.0" serde_json = "1" -sqlparser = { version = "0.44.0", features = ["visitor"] } +sqlparser = { version = "0.45.0", features = ["visitor"] } tempfile = "3" thiserror = "1.0.44" tokio = { version = "1.36", features = ["macros", "rt", "sync"] } diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 9a27d7fff923b..ba3e68e4011fb 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -3257,9 +3257,9 @@ checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" [[package]] name = "sqlparser" -version = "0.44.0" +version = "0.45.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aaf9c7ff146298ffda83a200f8d5084f08dcee1edfc135fcc1d646a45d50ffd6" +checksum = "f7bbffee862a796d67959a89859d6b1046bb5016d63e23835ad0da182777bbe0" dependencies = [ "log", "sqlparser_derive", diff --git a/datafusion/common/src/functional_dependencies.rs b/datafusion/common/src/functional_dependencies.rs index 2eab0ece6d8b5..d1c3747b52b4c 100644 --- a/datafusion/common/src/functional_dependencies.rs +++ b/datafusion/common/src/functional_dependencies.rs @@ -68,32 +68,48 @@ impl Constraints { let constraints = constraints .iter() .map(|c: &TableConstraint| match c { - TableConstraint::Unique { - columns, - is_primary, - .. - } => { + TableConstraint::Unique { name, columns, .. } => { let field_names = df_schema.field_names(); - // Get primary key and/or unique indices in the schema: + // Get unique constraint indices in the schema: let indices = columns .iter() - .map(|pk| { + .map(|u| { let idx = field_names .iter() - .position(|item| *item == pk.value) + .position(|item| *item == u.value) .ok_or_else(|| { + let name = name + .as_ref() + .map(|name| format!("with name '{name}' ")) + .unwrap_or("".to_string()); DataFusionError::Execution( - "Primary key doesn't exist".to_string(), + format!("Column for unique constraint {}not found in schema: {}", name,u.value) ) })?; Ok(idx) }) .collect::>>()?; - Ok(if *is_primary { - Constraint::PrimaryKey(indices) - } else { - Constraint::Unique(indices) - }) + Ok(Constraint::Unique(indices)) + } + TableConstraint::PrimaryKey { columns, .. } => { + let field_names = df_schema.field_names(); + // Get primary key indices in the schema: + let indices = columns + .iter() + .map(|pk| { + let idx = field_names + .iter() + .position(|item| *item == pk.value) + .ok_or_else(|| { + DataFusionError::Execution(format!( + "Column for primary key not found in schema: {}", + pk.value + )) + })?; + Ok(idx) + }) + .collect::>>()?; + Ok(Constraint::PrimaryKey(indices)) } TableConstraint::ForeignKey { .. } => { _plan_err!("Foreign key constraints are not currently supported") diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index f07377ce50e12..7d4e745eb21e9 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -194,6 +194,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLExpr::MapAccess { column, keys } => { if let SQLExpr::Identifier(id) = *column { + let keys = keys.into_iter().map(|mak| mak.key).collect(); self.plan_indexed( col(self.normalizer.normalize(id)), keys, diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs index 8d19b32b8e40c..25857db839c8b 100644 --- a/datafusion/sql/src/expr/value.rs +++ b/datafusion/sql/src/expr/value.rs @@ -215,13 +215,17 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { return not_impl_err!("Unsupported interval operator: {op:?}"); } }; - match (interval.leading_field, left.as_ref(), right.as_ref()) { + match ( + interval.leading_field.as_ref(), + left.as_ref(), + right.as_ref(), + ) { (_, _, SQLExpr::Value(_)) => { let left_expr = self.sql_interval_to_expr( negative, Interval { value: left, - leading_field: interval.leading_field, + leading_field: interval.leading_field.clone(), leading_precision: None, last_field: None, fractional_seconds_precision: None, diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 1bb024733c343..53fbfb0552bb4 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -94,13 +94,27 @@ fn calc_inline_constraints_from_columns(columns: &[ColumnDef]) -> Vec constraints.push(ast::TableConstraint::Unique { name: name.clone(), columns: vec![column.name.clone()], - is_primary: *is_primary, characteristics: *characteristics, + index_name: None, + index_type_display: ast::KeyOrIndexDisplay::None, + index_type: None, + index_options: vec![], + }), + ast::ColumnOption::Unique { + is_primary: true, + characteristics, + } => constraints.push(ast::TableConstraint::PrimaryKey { + name: name.clone(), + columns: vec![column.name.clone()], + characteristics: *characteristics, + index_name: None, + index_type: None, + index_options: vec![], }), ast::ColumnOption::ForeignKey { foreign_table, @@ -465,6 +479,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { table_alias, replace_into, priority, + insert_alias, } => { if or.is_some() { plan_err!("Inserts with or clauses not supported")?; @@ -503,6 +518,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { "Inserts with a `PRIORITY` clause not supported: {priority:?}" )? }; + if insert_alias.is_some() { + plan_err!("Inserts with an alias not supported")?; + } let _ = into; // optional keyword doesn't change behavior self.insert_to_plan(table_name, columns, source, overwrite) } From 70db5eab8996af4816958f798f6ee887dffb69ed Mon Sep 17 00:00:00 2001 From: Eduard Karacharov <13005055+korowa@users.noreply.github.com> Date: Sun, 21 Apr 2024 19:37:10 +0300 Subject: [PATCH 7/7] fix: duplicate output for HashJoinExec in CollectLeft mode (#9757) * fix: duplicate output for HashJoinExec in CollectLeft mode * address review comments * test fix after merging main --- .../src/physical_optimizer/join_selection.rs | 73 ++--- .../physical-plan/src/joins/hash_join.rs | 291 ++++++++++++++---- datafusion/sqllogictest/test_files/joins.slt | 14 +- 3 files changed, 261 insertions(+), 117 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs b/datafusion/core/src/physical_optimizer/join_selection.rs index 72174b0e6e2f6..f7512cb6d0756 100644 --- a/datafusion/core/src/physical_optimizer/join_selection.rs +++ b/datafusion/core/src/physical_optimizer/join_selection.rs @@ -305,11 +305,6 @@ impl PhysicalOptimizerRule for JoinSelection { /// `CollectLeft` mode is applicable. Otherwise, it will try to swap the join sides. /// When the `ignore_threshold` is false, this function will also check left /// and right sizes in bytes or rows. -/// -/// For [`JoinType::Full`], it can not use `CollectLeft` mode and will return `None`. -/// For [`JoinType::Left`] and [`JoinType::LeftAnti`], it can not run `CollectLeft` -/// mode as is, but it can do so by changing the join type to [`JoinType::Right`] -/// and [`JoinType::RightAnti`], respectively. fn try_collect_left( hash_join: &HashJoinExec, ignore_threshold: bool, @@ -318,38 +313,20 @@ fn try_collect_left( ) -> Result>> { let left = hash_join.left(); let right = hash_join.right(); - let join_type = hash_join.join_type(); - let left_can_collect = match join_type { - JoinType::Left | JoinType::Full | JoinType::LeftAnti => false, - JoinType::Inner - | JoinType::LeftSemi - | JoinType::Right - | JoinType::RightSemi - | JoinType::RightAnti => { - ignore_threshold - || supports_collect_by_thresholds( - &**left, - threshold_byte_size, - threshold_num_rows, - ) - } - }; - let right_can_collect = match join_type { - JoinType::Right | JoinType::Full | JoinType::RightAnti => false, - JoinType::Inner - | JoinType::RightSemi - | JoinType::Left - | JoinType::LeftSemi - | JoinType::LeftAnti => { - ignore_threshold - || supports_collect_by_thresholds( - &**right, - threshold_byte_size, - threshold_num_rows, - ) - } - }; + let left_can_collect = ignore_threshold + || supports_collect_by_thresholds( + &**left, + threshold_byte_size, + threshold_num_rows, + ); + let right_can_collect = ignore_threshold + || supports_collect_by_thresholds( + &**right, + threshold_byte_size, + threshold_num_rows, + ); + match (left_can_collect, right_can_collect) { (true, true) => { if should_swap_join_order(&**left, &**right)? @@ -916,9 +893,9 @@ mod tests_statistical { } #[tokio::test] - async fn test_left_join_with_swap() { + async fn test_left_join_no_swap() { let (big, small) = create_big_and_small(); - // Left out join should alway swap when the mode is PartitionMode::CollectLeft, even left side is small and right side is large + let join = Arc::new( HashJoinExec::try_new( Arc::clone(&small), @@ -942,32 +919,18 @@ mod tests_statistical { .optimize(join.clone(), &ConfigOptions::new()) .unwrap(); - let swapping_projection = optimized_join - .as_any() - .downcast_ref::() - .expect("A proj is required to swap columns back to their original order"); - - assert_eq!(swapping_projection.expr().len(), 2); - let (col, name) = &swapping_projection.expr()[0]; - assert_eq!(name, "small_col"); - assert_col_expr(col, "small_col", 1); - let (col, name) = &swapping_projection.expr()[1]; - assert_eq!(name, "big_col"); - assert_col_expr(col, "big_col", 0); - - let swapped_join = swapping_projection - .input() + let swapped_join = optimized_join .as_any() .downcast_ref::() .expect("The type of the plan should not be changed"); assert_eq!( swapped_join.left().statistics().unwrap().total_byte_size, - Precision::Inexact(2097152) + Precision::Inexact(8192) ); assert_eq!( swapped_join.right().statistics().unwrap().total_byte_size, - Precision::Inexact(8192) + Precision::Inexact(2097152) ); crosscheck_plans(join.clone()).unwrap(); } diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 1c0181c2e1165..2b553135ada14 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -19,6 +19,7 @@ use std::fmt; use std::mem::size_of; +use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use std::task::Poll; use std::{any::Any, usize, vec}; @@ -72,6 +73,9 @@ use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef}; use ahash::RandomState; use futures::{ready, Stream, StreamExt, TryStreamExt}; +use parking_lot::Mutex; + +type SharedBitmapBuilder = Mutex; /// HashTable and input data for the left (build side) of a join struct JoinLeftData { @@ -79,6 +83,11 @@ struct JoinLeftData { hash_map: JoinHashMap, /// The input rows for the build side batch: RecordBatch, + /// Shared bitmap builder for visited left indices + visited_indices_bitmap: Mutex, + /// Counter of running probe-threads, potentially + /// able to update `visited_indices_bitmap` + probe_threads_counter: AtomicUsize, /// Memory reservation that tracks memory used by `hash_map` hash table /// `batch`. Cleared on drop. #[allow(dead_code)] @@ -90,20 +99,19 @@ impl JoinLeftData { fn new( hash_map: JoinHashMap, batch: RecordBatch, + visited_indices_bitmap: SharedBitmapBuilder, + probe_threads_counter: AtomicUsize, reservation: MemoryReservation, ) -> Self { Self { hash_map, batch, + visited_indices_bitmap, + probe_threads_counter, reservation, } } - /// Returns the number of rows in the build side - fn num_rows(&self) -> usize { - self.batch.num_rows() - } - /// return a reference to the hash map fn hash_map(&self) -> &JoinHashMap { &self.hash_map @@ -113,6 +121,17 @@ impl JoinLeftData { fn batch(&self) -> &RecordBatch { &self.batch } + + /// returns a reference to the visited indices bitmap + fn visited_indices_bitmap(&self) -> &SharedBitmapBuilder { + &self.visited_indices_bitmap + } + + /// Decrements the counter of running threads, and returns `true` + /// if caller is the last running thread + fn report_probe_completed(&self) -> bool { + self.probe_threads_counter.fetch_sub(1, Ordering::Relaxed) == 1 + } } /// Join execution plan: Evaluates eqijoin predicates in parallel on multiple @@ -715,6 +734,8 @@ impl ExecutionPlan for HashJoinExec { context.clone(), join_metrics.clone(), reservation, + need_produce_result_in_final(self.join_type), + self.right().output_partitioning().partition_count(), ) }), PartitionMode::Partitioned => { @@ -730,6 +751,8 @@ impl ExecutionPlan for HashJoinExec { context.clone(), join_metrics.clone(), reservation, + need_produce_result_in_final(self.join_type), + 1, )) } PartitionMode::Auto => { @@ -742,9 +765,6 @@ impl ExecutionPlan for HashJoinExec { let batch_size = context.session_config().batch_size(); - let reservation = MemoryConsumer::new(format!("HashJoinStream[{partition}]")) - .register(context.memory_pool()); - // we have the batches and the hash map with their keys. We can how create a stream // over the right that uses this information to issue new batches. let right_stream = self.right.execute(partition, context)?; @@ -769,7 +789,6 @@ impl ExecutionPlan for HashJoinExec { random_state: self.random_state.clone(), join_metrics, null_equals_null: self.null_equals_null, - reservation, state: HashJoinStreamState::WaitBuildSide, build_side: BuildSide::Initial(BuildSideInitialState { left_fut }), batch_size, @@ -808,6 +827,7 @@ impl ExecutionPlan for HashJoinExec { /// Reads the left (build) side of the input, buffering it in memory, to build a /// hash table (`LeftJoinData`) +#[allow(clippy::too_many_arguments)] async fn collect_left_input( partition: Option, random_state: RandomState, @@ -816,6 +836,8 @@ async fn collect_left_input( context: Arc, metrics: BuildProbeJoinMetrics, reservation: MemoryReservation, + with_visited_indices_bitmap: bool, + probe_threads_count: usize, ) -> Result { let schema = left.schema(); @@ -892,10 +914,29 @@ async fn collect_left_input( )?; offset += batch.num_rows(); } - // Merge all batches into a single batch, so we - // can directly index into the arrays + // Merge all batches into a single batch, so we can directly index into the arrays let single_batch = concat_batches(&schema, batches_iter)?; - let data = JoinLeftData::new(hashmap, single_batch, reservation); + + // Reserve additional memory for visited indices bitmap and create shared builder + let visited_indices_bitmap = if with_visited_indices_bitmap { + let bitmap_size = bit_util::ceil(single_batch.num_rows(), 8); + reservation.try_grow(bitmap_size)?; + metrics.build_mem_used.add(bitmap_size); + + let mut bitmap_buffer = BooleanBufferBuilder::new(single_batch.num_rows()); + bitmap_buffer.append_n(num_rows, false); + bitmap_buffer + } else { + BooleanBufferBuilder::new(0) + }; + + let data = JoinLeftData::new( + hashmap, + single_batch, + Mutex::new(visited_indices_bitmap), + AtomicUsize::new(probe_threads_count), + reservation, + ); Ok(data) } @@ -965,10 +1006,6 @@ struct BuildSideInitialState { struct BuildSideReadyState { /// Collected build-side data left_data: Arc, - /// Which build-side rows have been matched while creating output. - /// For some OUTER joins, we need to know which rows have not been matched - /// to produce the correct output. - visited_left_side: BooleanBufferBuilder, } impl BuildSide { @@ -1087,8 +1124,6 @@ struct HashJoinStream { column_indices: Vec, /// If null_equals_null is true, null == null else null != null null_equals_null: bool, - /// Memory reservation - reservation: MemoryReservation, /// State of the stream state: HashJoinStreamState, /// Build side @@ -1250,6 +1285,14 @@ pub fn equal_rows_arr( )) } +fn get_final_indices_from_shared_bitmap( + shared_bitmap: &SharedBitmapBuilder, + join_type: JoinType, +) -> (UInt64Array, UInt32Array) { + let bitmap = shared_bitmap.lock(); + get_final_indices_from_bit_map(&bitmap, join_type) +} + impl HashJoinStream { /// Separate implementation function that unpins the [`HashJoinStream`] so /// that partial borrows work correctly @@ -1292,35 +1335,8 @@ impl HashJoinStream { .get_shared(cx))?; build_timer.done(); - // Reserving memory for visited_left_side bitmap in case it hasn't been initialized yet - // and join_type requires to store it - if need_produce_result_in_final(self.join_type) { - // TODO: Replace `ceil` wrapper with stable `div_cell` after - // https://github.com/rust-lang/rust/issues/88581 - let visited_bitmap_size = bit_util::ceil(left_data.num_rows(), 8); - self.reservation.try_grow(visited_bitmap_size)?; - self.join_metrics.build_mem_used.add(visited_bitmap_size); - } - - let visited_left_side = if need_produce_result_in_final(self.join_type) { - let num_rows = left_data.num_rows(); - // Some join types need to track which row has be matched or unmatched: - // `left semi` join: need to use the bitmap to produce the matched row in the left side - // `left` join: need to use the bitmap to produce the unmatched row in the left side with null - // `left anti` join: need to use the bitmap to produce the unmatched row in the left side - // `full` join: need to use the bitmap to produce the unmatched row in the left side with null - let mut buffer = BooleanBufferBuilder::new(num_rows); - buffer.append_n(num_rows, false); - buffer - } else { - BooleanBufferBuilder::new(0) - }; - self.state = HashJoinStreamState::FetchProbeBatch; - self.build_side = BuildSide::Ready(BuildSideReadyState { - left_data, - visited_left_side, - }); + self.build_side = BuildSide::Ready(BuildSideReadyState { left_data }); Poll::Ready(Ok(StatefulStreamResult::Continue)) } @@ -1405,8 +1421,9 @@ impl HashJoinStream { // mark joined left-side indices as visited, if required by join type if need_produce_result_in_final(self.join_type) { + let mut bitmap = build_side.left_data.visited_indices_bitmap().lock(); left_indices.iter().flatten().for_each(|x| { - build_side.visited_left_side.set_bit(x as usize, true); + bitmap.set_bit(x as usize, true); }); } @@ -1485,15 +1502,20 @@ impl HashJoinStream { if !need_produce_result_in_final(self.join_type) { self.state = HashJoinStreamState::Completed; - return Ok(StatefulStreamResult::Continue); } let build_side = self.build_side.try_as_ready()?; + if !build_side.left_data.report_probe_completed() { + self.state = HashJoinStreamState::Completed; + return Ok(StatefulStreamResult::Continue); + } // use the global left bitmap to produce the left indices and right indices - let (left_side, right_side) = - get_final_indices_from_bit_map(&build_side.visited_left_side, self.join_type); + let (left_side, right_side) = get_final_indices_from_shared_bitmap( + build_side.left_data.visited_indices_bitmap(), + self.join_type, + ); let empty_right_batch = RecordBatch::new_empty(self.right.schema()); // use the left and right indices to produce the batch result let result = build_batch_from_indices( @@ -1644,26 +1666,73 @@ mod tests { join_type: &JoinType, null_equals_null: bool, context: Arc, + ) -> Result<(Vec, Vec)> { + join_collect_with_partition_mode( + left, + right, + on, + join_type, + PartitionMode::Partitioned, + null_equals_null, + context, + ) + .await + } + + async fn join_collect_with_partition_mode( + left: Arc, + right: Arc, + on: JoinOn, + join_type: &JoinType, + partition_mode: PartitionMode, + null_equals_null: bool, + context: Arc, ) -> Result<(Vec, Vec)> { let partition_count = 4; let (left_expr, right_expr) = on.iter().map(|(l, r)| (l.clone(), r.clone())).unzip(); - let join = HashJoinExec::try_new( - Arc::new(RepartitionExec::try_new( + let left_repartitioned: Arc = match partition_mode { + PartitionMode::CollectLeft => Arc::new(CoalescePartitionsExec::new(left)), + PartitionMode::Partitioned => Arc::new(RepartitionExec::try_new( left, Partitioning::Hash(left_expr, partition_count), )?), - Arc::new(RepartitionExec::try_new( + PartitionMode::Auto => { + return internal_err!("Unexpected PartitionMode::Auto in join tests") + } + }; + + let right_repartitioned: Arc = match partition_mode { + PartitionMode::CollectLeft => { + let partition_column_name = right.schema().field(0).name().clone(); + let partition_expr = vec![Arc::new(Column::new_with_schema( + &partition_column_name, + &right.schema(), + )?) as _]; + Arc::new(RepartitionExec::try_new( + right, + Partitioning::Hash(partition_expr, partition_count), + )?) as _ + } + PartitionMode::Partitioned => Arc::new(RepartitionExec::try_new( right, Partitioning::Hash(right_expr, partition_count), )?), + PartitionMode::Auto => { + return internal_err!("Unexpected PartitionMode::Auto in join tests") + } + }; + + let join = HashJoinExec::try_new( + left_repartitioned, + right_repartitioned, on, None, join_type, None, - PartitionMode::Partitioned, + partition_mode, null_equals_null, )?; @@ -3316,6 +3385,120 @@ mod tests { Ok(()) } + /// Test for parallelised HashJoinExec with PartitionMode::CollectLeft + #[tokio::test] + async fn test_collect_left_multiple_partitions_join() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b2", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) as _, + )]; + + let expected_inner = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "+----+----+----+----+----+----+", + ]; + let expected_left = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 3 | 7 | 9 | | | |", + "+----+----+----+----+----+----+", + ]; + let expected_right = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| | | | 30 | 6 | 90 |", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "+----+----+----+----+----+----+", + ]; + let expected_full = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| | | | 30 | 6 | 90 |", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 3 | 7 | 9 | | | |", + "+----+----+----+----+----+----+", + ]; + let expected_left_semi = vec![ + "+----+----+----+", + "| a1 | b1 | c1 |", + "+----+----+----+", + "| 1 | 4 | 7 |", + "| 2 | 5 | 8 |", + "+----+----+----+", + ]; + let expected_left_anti = vec![ + "+----+----+----+", + "| a1 | b1 | c1 |", + "+----+----+----+", + "| 3 | 7 | 9 |", + "+----+----+----+", + ]; + let expected_right_semi = vec![ + "+----+----+----+", + "| a2 | b2 | c2 |", + "+----+----+----+", + "| 10 | 4 | 70 |", + "| 20 | 5 | 80 |", + "+----+----+----+", + ]; + let expected_right_anti = vec![ + "+----+----+----+", + "| a2 | b2 | c2 |", + "+----+----+----+", + "| 30 | 6 | 90 |", + "+----+----+----+", + ]; + + let test_cases = vec![ + (JoinType::Inner, expected_inner), + (JoinType::Left, expected_left), + (JoinType::Right, expected_right), + (JoinType::Full, expected_full), + (JoinType::LeftSemi, expected_left_semi), + (JoinType::LeftAnti, expected_left_anti), + (JoinType::RightSemi, expected_right_semi), + (JoinType::RightAnti, expected_right_anti), + ]; + + for (join_type, expected) in test_cases { + let (_, batches) = join_collect_with_partition_mode( + left.clone(), + right.clone(), + on.clone(), + &join_type, + PartitionMode::CollectLeft, + false, + task_ctx.clone(), + ) + .await?; + assert_batches_sorted_eq!(expected, &batches); + } + + Ok(()) + } + #[tokio::test] async fn join_date32() -> Result<()> { let schema = Arc::new(Schema::new(vec![ diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index aa84031d55bd7..d999734ba70ed 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -3662,16 +3662,14 @@ logical_plan physical_plan 01)ProjectionExec: expr=[c@2 as c, d@3 as d, e@0 as e, f@1 as f] 02)--CoalesceBatchesExec: target_batch_size=2 -03)----HashJoinExec: mode=Partitioned, join_type=Full, on=[(e@0, c@0)] +03)----HashJoinExec: mode=CollectLeft, join_type=Full, on=[(e@0, c@0)] 04)------ProjectionExec: expr=[1 as e, 3 as f] 05)--------PlaceholderRowExec -06)------CoalesceBatchesExec: target_batch_size=2 -07)--------RepartitionExec: partitioning=Hash([c@0], 1), input_partitions=2 -08)----------UnionExec -09)------------ProjectionExec: expr=[1 as c, 2 as d] -10)--------------PlaceholderRowExec -11)------------ProjectionExec: expr=[1 as c, 3 as d] -12)--------------PlaceholderRowExec +06)------UnionExec +07)--------ProjectionExec: expr=[1 as c, 2 as d] +08)----------PlaceholderRowExec +09)--------ProjectionExec: expr=[1 as c, 3 as d] +10)----------PlaceholderRowExec query IIII rowsort SELECT * FROM (