From 415c831e306a4259167cf103a7c89a9ec883935d Mon Sep 17 00:00:00 2001 From: Yuya Nishihara Date: Sun, 7 Jul 2024 16:23:20 +0900 Subject: [PATCH] revset: flatten union nodes in AST to save recursion stack Maybe it'll also be good to keep RevsetExpression::Union(_) flattened, but that's not needed to get around stack overflow. The constructed expression tree is balanced. test_expand_symbol_alias() is slightly adjusted since there are more than one representation for "a|b|c" now. Fixes #4031 --- lib/src/revset.rs | 8 +++++++- lib/src/revset_parser.rs | 38 ++++++++++++++++++++++++++++++++------ lib/tests/test_revset.rs | 24 ++++++++++++++++++++++++ 3 files changed, 63 insertions(+), 7 deletions(-) diff --git a/lib/src/revset.rs b/lib/src/revset.rs index d88a584d4b..357cc48a55 100644 --- a/lib/src/revset.rs +++ b/lib/src/revset.rs @@ -824,13 +824,19 @@ pub fn lower_expression( let lhs = lower_expression(lhs_node, context)?; let rhs = lower_expression(rhs_node, context)?; match op { - BinaryOp::Union => Ok(lhs.union(&rhs)), BinaryOp::Intersection => Ok(lhs.intersection(&rhs)), BinaryOp::Difference => Ok(lhs.minus(&rhs)), BinaryOp::DagRange => Ok(lhs.dag_range_to(&rhs)), BinaryOp::Range => Ok(lhs.range(&rhs)), } } + ExpressionKind::UnionAll(nodes) => { + let expressions: Vec<_> = nodes + .iter() + .map(|node| lower_expression(node, context)) + .try_collect()?; + Ok(RevsetExpression::union_all(&expressions)) + } ExpressionKind::FunctionCall(function) => lower_function_call(function, context), ExpressionKind::Modifier(modifier) => { let name = modifier.name; diff --git a/lib/src/revset_parser.rs b/lib/src/revset_parser.rs index 6c0c315e0c..b860f5041b 100644 --- a/lib/src/revset_parser.rs +++ b/lib/src/revset_parser.rs @@ -311,6 +311,8 @@ pub enum ExpressionKind<'i> { RangeAll, Unary(UnaryOp, Box>), Binary(BinaryOp, Box>, Box>), + /// `x | y | ..` + UnionAll(Vec>), FunctionCall(Box>), /// `name: body` Modifier(Box>), @@ -341,6 +343,10 @@ impl<'i> FoldableExpression<'i> for ExpressionKind<'i> { let rhs = Box::new(folder.fold_expression(*rhs)?); Ok(ExpressionKind::Binary(op, lhs, rhs)) } + ExpressionKind::UnionAll(nodes) => { + let nodes = dsl_util::fold_expression_nodes(folder, nodes)?; + Ok(ExpressionKind::UnionAll(nodes)) + } ExpressionKind::FunctionCall(function) => folder.fold_function_call(function, span), ExpressionKind::Modifier(modifier) => { let modifier = Box::new(ModifierNode { @@ -392,8 +398,6 @@ pub enum UnaryOp { #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] pub enum BinaryOp { - /// `|` - Union, /// `&` Intersection, /// `~` @@ -418,6 +422,20 @@ pub struct ModifierNode<'i> { pub body: ExpressionNode<'i>, } +fn union_nodes<'i>(lhs: ExpressionNode<'i>, rhs: ExpressionNode<'i>) -> ExpressionNode<'i> { + let span = lhs.span.start_pos().span(&rhs.span.end_pos()); + let expr = match lhs.kind { + // Flatten "x | y | z" to save recursion stack. Machine-generated query + // might have long chain of unions. + ExpressionKind::UnionAll(mut nodes) => { + nodes.push(rhs); + ExpressionKind::UnionAll(nodes) + } + _ => ExpressionKind::UnionAll(vec![lhs, rhs]), + }; + ExpressionNode::new(expr, span) +} + pub(super) fn parse_program(revset_str: &str) -> Result { let mut pairs = RevsetParser::parse(Rule::program, revset_str)?; let first = pairs.next().unwrap(); @@ -551,7 +569,7 @@ fn parse_expression_node(pairs: Pairs) -> Result BinaryOp::Union, + Rule::union_op => return Ok(union_nodes(lhs?, rhs?)), Rule::compat_add_op => Err(not_infix_op(&op, "|", "union"))?, Rule::intersection_op => BinaryOp::Intersection, Rule::difference_op => BinaryOp::Difference, @@ -883,6 +901,10 @@ mod tests { let rhs = Box::new(normalize_tree(*rhs)); ExpressionKind::Binary(op, lhs, rhs) } + ExpressionKind::UnionAll(nodes) => { + let nodes = normalize_list(nodes); + ExpressionKind::UnionAll(nodes) + } ExpressionKind::FunctionCall(function) => { let function = Box::new(normalize_function_call(*function)); ExpressionKind::FunctionCall(function) @@ -1067,7 +1089,11 @@ mod tests { // Parse the "union" operator assert_matches!( parse_into_kind("foo | bar"), - Ok(ExpressionKind::Binary(BinaryOp::Union, _, _)) + Ok(ExpressionKind::UnionAll(nodes)) if nodes.len() == 2 + ); + assert_matches!( + parse_into_kind("foo | bar | baz"), + Ok(ExpressionKind::UnionAll(nodes)) if nodes.len() == 3 ); // Parse the "difference" operator assert_matches!( @@ -1479,8 +1505,8 @@ mod tests { #[test] fn test_expand_symbol_alias() { assert_eq!( - with_aliases([("AB", "a|b")]).parse_normalized("AB|c"), - parse_normalized("(a|b)|c") + with_aliases([("AB", "a&b")]).parse_normalized("AB|c"), + parse_normalized("(a&b)|c") ); assert_eq!( with_aliases([("AB", "a|b")]).parse_normalized("AB::heads(AB)"), diff --git a/lib/tests/test_revset.rs b/lib/tests/test_revset.rs index 86ca9260d0..91e091f6b6 100644 --- a/lib/tests/test_revset.rs +++ b/lib/tests/test_revset.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::iter; use std::path::Path; use assert_matches::assert_matches; @@ -2641,6 +2642,29 @@ fn test_evaluate_expression_union() { ); } +#[test] +fn test_evaluate_expression_machine_generated_union() { + let settings = testutils::user_settings(); + let test_repo = TestRepo::init(); + let repo = &test_repo.repo; + + let mut tx = repo.start_transaction(&settings); + let mut_repo = tx.mut_repo(); + let mut graph_builder = CommitGraphBuilder::new(&settings, mut_repo); + let commit1 = graph_builder.initial_commit(); + let commit2 = graph_builder.commit_with_parents(&[&commit1]); + + // This query shouldn't trigger stack overflow. Here we use "x::y" in case + // we had optimization path for trivial "commit_id|.." expression. + let revset_str = iter::repeat(format!("({}::{})", commit1.id().hex(), commit2.id().hex())) + .take(5000) + .join("|"); + assert_eq!( + resolve_commit_ids(mut_repo, &revset_str), + vec![commit2.id().clone(), commit1.id().clone()] + ); +} + #[test] fn test_evaluate_expression_intersection() { let settings = testutils::user_settings();