From f90b06180870e9cda32ca090cc9a001135796f01 Mon Sep 17 00:00:00 2001 From: Yuya Nishihara Date: Sun, 7 Jul 2024 16:04:04 +0900 Subject: [PATCH] fileset: flatten union nodes in AST to save recursion stack This is somewhat similar to templater where "x ++ y" operator is special cased. --- lib/src/fileset.rs | 21 ++++++++------------- lib/src/fileset_parser.rs | 35 ++++++++++++++++++++++++++++++----- 2 files changed, 38 insertions(+), 18 deletions(-) diff --git a/lib/src/fileset.rs b/lib/src/fileset.rs index 032373a27d..acffdbf943 100644 --- a/lib/src/fileset.rs +++ b/lib/src/fileset.rs @@ -17,6 +17,7 @@ use std::collections::HashMap; use std::{iter, path, slice}; +use itertools::Itertools as _; use once_cell::sync::Lazy; use thiserror::Error; @@ -227,18 +228,6 @@ impl FilesetExpression { FilesetExpression::Pattern(FilePattern::PrefixPath(path)) } - /// Expression that matches either `self` or `other` (or both). - pub fn union(self, other: Self) -> Self { - match self { - // Micro optimization for "x | y | z" - FilesetExpression::UnionAll(mut expressions) => { - expressions.push(other); - FilesetExpression::UnionAll(expressions) - } - expr => FilesetExpression::UnionAll(vec![expr, other]), - } - } - /// Expression that matches any of the given `expressions`. pub fn union_all(expressions: Vec) -> Self { match expressions.len() { @@ -442,11 +431,17 @@ fn resolve_expression( let lhs = resolve_expression(path_converter, lhs_node)?; let rhs = resolve_expression(path_converter, rhs_node)?; match op { - BinaryOp::Union => Ok(lhs.union(rhs)), BinaryOp::Intersection => Ok(lhs.intersection(rhs)), BinaryOp::Difference => Ok(lhs.difference(rhs)), } } + ExpressionKind::UnionAll(nodes) => { + let expressions = nodes + .iter() + .map(|node| resolve_expression(path_converter, node)) + .try_collect()?; + Ok(FilesetExpression::union_all(expressions)) + } ExpressionKind::FunctionCall(function) => resolve_function(path_converter, function), } } diff --git a/lib/src/fileset_parser.rs b/lib/src/fileset_parser.rs index 7b0a30aada..68e7bf74bf 100644 --- a/lib/src/fileset_parser.rs +++ b/lib/src/fileset_parser.rs @@ -164,9 +164,14 @@ fn rename_rules_in_pest_error(err: pest::error::Error) -> pest::error::Err pub enum ExpressionKind<'i> { Identifier(&'i str), String(String), - StringPattern { kind: &'i str, value: String }, + StringPattern { + kind: &'i str, + value: String, + }, Unary(UnaryOp, Box>), Binary(BinaryOp, Box>, Box>), + /// `x | y | ..` + UnionAll(Vec>), FunctionCall(Box>), } @@ -178,8 +183,6 @@ pub enum UnaryOp { #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] pub enum BinaryOp { - /// `|` - Union, /// `&` Intersection, /// `~` @@ -189,6 +192,20 @@ pub enum BinaryOp { pub type ExpressionNode<'i> = dsl_util::ExpressionNode<'i, ExpressionKind<'i>>; pub type FunctionCallNode<'i> = dsl_util::FunctionCallNode<'i, ExpressionKind<'i>>; +fn union_nodes<'i>(lhs: ExpressionNode<'i>, rhs: ExpressionNode<'i>) -> ExpressionNode<'i> { + let span = lhs.span.start_pos().span(&rhs.span.end_pos()); + let expr = match lhs.kind { + // Flatten "x | y | z" to save recursion stack. Machine-generated query + // might have long chain of unions. + ExpressionKind::UnionAll(mut nodes) => { + nodes.push(rhs); + ExpressionKind::UnionAll(nodes) + } + _ => ExpressionKind::UnionAll(vec![lhs, rhs]), + }; + ExpressionNode::new(expr, span) +} + fn parse_function_call_node(pair: Pair) -> FilesetParseResult { assert_eq!(pair.as_rule(), Rule::function); let (name_pair, args_pair) = pair.into_inner().collect_tuple().unwrap(); @@ -273,7 +290,7 @@ fn parse_expression_node(pair: Pair) -> FilesetParseResult }) .map_infix(|lhs, op, rhs| { let op_kind = match op.as_rule() { - Rule::union_op => BinaryOp::Union, + Rule::union_op => return Ok(union_nodes(lhs?, rhs?)), Rule::intersection_op => BinaryOp::Intersection, Rule::difference_op => BinaryOp::Difference, r => panic!("unexpected infix operator rule {r:?}"), @@ -389,6 +406,10 @@ mod tests { let rhs = Box::new(normalize_tree(*rhs)); ExpressionKind::Binary(op, lhs, rhs) } + ExpressionKind::UnionAll(nodes) => { + let nodes = normalize_list(nodes); + ExpressionKind::UnionAll(nodes) + } ExpressionKind::FunctionCall(function) => { let function = Box::new(normalize_function_call(*function)); ExpressionKind::FunctionCall(function) @@ -525,7 +546,11 @@ mod tests { ); assert_matches!( parse_into_kind("x|y"), - Ok(ExpressionKind::Binary(BinaryOp::Union, _, _)) + Ok(ExpressionKind::UnionAll(nodes)) if nodes.len() == 2 + ); + assert_matches!( + parse_into_kind("x|y|z"), + Ok(ExpressionKind::UnionAll(nodes)) if nodes.len() == 3 ); assert_matches!( parse_into_kind("x&y"),