Skip to content

Commit

Permalink
fileset: flatten union nodes in AST to save recursion stack
Browse files Browse the repository at this point in the history
This is somewhat similar to the templater, where the "x ++ y" operator is special-cased.
  • Loading branch information
yuja committed Jul 11, 2024
1 parent 503771c commit f90b061
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 18 deletions.
21 changes: 8 additions & 13 deletions lib/src/fileset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
use std::collections::HashMap;
use std::{iter, path, slice};

use itertools::Itertools as _;
use once_cell::sync::Lazy;
use thiserror::Error;

Expand Down Expand Up @@ -227,18 +228,6 @@ impl FilesetExpression {
FilesetExpression::Pattern(FilePattern::PrefixPath(path))
}

/// Builds an expression that matches whatever `self` matches, whatever
/// `other` matches, or both.
///
/// If `self` is already a `UnionAll`, `other` is appended to its operand
/// list, so chaining `x.union(y).union(z)` yields one flat list rather than
/// nested unions. `other` is never unpacked, even if it is a union itself.
pub fn union(self, other: Self) -> Self {
    let members = if let FilesetExpression::UnionAll(mut members) = self {
        // Reuse the existing operand list ("x | y | z" stays flat).
        members.push(other);
        members
    } else {
        // `self` was not moved by the failed pattern, so it is still usable.
        vec![self, other]
    };
    FilesetExpression::UnionAll(members)
}

/// Expression that matches any of the given `expressions`.
pub fn union_all(expressions: Vec<FilesetExpression>) -> Self {
match expressions.len() {
Expand Down Expand Up @@ -442,11 +431,17 @@ fn resolve_expression(
let lhs = resolve_expression(path_converter, lhs_node)?;
let rhs = resolve_expression(path_converter, rhs_node)?;
match op {
BinaryOp::Union => Ok(lhs.union(rhs)),
BinaryOp::Intersection => Ok(lhs.intersection(rhs)),
BinaryOp::Difference => Ok(lhs.difference(rhs)),
}
}
ExpressionKind::UnionAll(nodes) => {
let expressions = nodes
.iter()
.map(|node| resolve_expression(path_converter, node))
.try_collect()?;
Ok(FilesetExpression::union_all(expressions))
}
ExpressionKind::FunctionCall(function) => resolve_function(path_converter, function),
}
}
Expand Down
35 changes: 30 additions & 5 deletions lib/src/fileset_parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,14 @@ fn rename_rules_in_pest_error(err: pest::error::Error<Rule>) -> pest::error::Err
/// Payload of a parsed fileset expression node.
///
/// String slices tagged with `'i` are borrowed (presumably from the parsed
/// input text — confirm against the parser), while `String` fields are owned.
pub enum ExpressionKind<'i> {
/// Identifier, stored as a borrowed string slice.
Identifier(&'i str),
/// String value, stored as an owned `String`.
String(String),
// NOTE(review): `StringPattern` appears twice in this view (single-line and
// multi-line forms). This looks like the before/after of a formatting-only
// change in the diff; only one declaration should exist in the real file.
StringPattern { kind: &'i str, value: String },
/// String pattern made of a borrowed `kind` and an owned `value`.
StringPattern {
kind: &'i str,
value: String,
},
/// Unary operator applied to a single boxed operand.
Unary(UnaryOp, Box<ExpressionNode<'i>>),
/// Binary operator applied to boxed left and right operands.
Binary(BinaryOp, Box<ExpressionNode<'i>>, Box<ExpressionNode<'i>>),
/// `x | y | ..`
///
/// Union operands are kept in one flat list rather than as nested binary
/// nodes.
UnionAll(Vec<ExpressionNode<'i>>),
/// Function call, stored as a boxed `FunctionCallNode`.
FunctionCall(Box<FunctionCallNode<'i>>),
}

Expand All @@ -178,8 +183,6 @@ pub enum UnaryOp {

#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum BinaryOp {
/// `|`
Union,
/// `&`
Intersection,
/// `~`
Expand All @@ -189,6 +192,20 @@ pub enum BinaryOp {
/// Fileset expression node: the generic `dsl_util::ExpressionNode`
/// instantiated with [`ExpressionKind`] as its payload.
pub type ExpressionNode<'i> = dsl_util::ExpressionNode<'i, ExpressionKind<'i>>;
/// Fileset function call node: the generic `dsl_util::FunctionCallNode`
/// instantiated with [`ExpressionKind`] as its expression payload.
pub type FunctionCallNode<'i> = dsl_util::FunctionCallNode<'i, ExpressionKind<'i>>;

/// Combines `lhs | rhs` into a single `UnionAll` node.
///
/// The resulting node's span runs from the start of `lhs` to the end of
/// `rhs`. If `lhs` is already a `UnionAll`, `rhs` is appended to its operand
/// list instead of wrapping it in a new node, so a chain such as `x | y | z`
/// becomes one flat list. Long union chains (e.g. in machine-generated
/// queries) then do not deepen the tree, which saves recursion stack when the
/// tree is walked later.
fn union_nodes<'i>(lhs: ExpressionNode<'i>, rhs: ExpressionNode<'i>) -> ExpressionNode<'i> {
    // Compute the covering span first, before either operand is moved.
    let start = lhs.span.start_pos();
    let end = rhs.span.end_pos();
    let span = start.span(&end);

    let operands = if let ExpressionKind::UnionAll(mut operands) = lhs.kind {
        // Extend the existing flat list; the old `lhs` span is superseded.
        operands.push(rhs);
        operands
    } else {
        // `lhs.kind` was not moved by the failed pattern, so `lhs` is intact.
        vec![lhs, rhs]
    };
    ExpressionNode::new(ExpressionKind::UnionAll(operands), span)
}

fn parse_function_call_node(pair: Pair<Rule>) -> FilesetParseResult<FunctionCallNode> {
assert_eq!(pair.as_rule(), Rule::function);
let (name_pair, args_pair) = pair.into_inner().collect_tuple().unwrap();
Expand Down Expand Up @@ -273,7 +290,7 @@ fn parse_expression_node(pair: Pair<Rule>) -> FilesetParseResult<ExpressionNode>
})
.map_infix(|lhs, op, rhs| {
let op_kind = match op.as_rule() {
Rule::union_op => BinaryOp::Union,
Rule::union_op => return Ok(union_nodes(lhs?, rhs?)),
Rule::intersection_op => BinaryOp::Intersection,
Rule::difference_op => BinaryOp::Difference,
r => panic!("unexpected infix operator rule {r:?}"),
Expand Down Expand Up @@ -389,6 +406,10 @@ mod tests {
let rhs = Box::new(normalize_tree(*rhs));
ExpressionKind::Binary(op, lhs, rhs)
}
ExpressionKind::UnionAll(nodes) => {
let nodes = normalize_list(nodes);
ExpressionKind::UnionAll(nodes)
}
ExpressionKind::FunctionCall(function) => {
let function = Box::new(normalize_function_call(*function));
ExpressionKind::FunctionCall(function)
Expand Down Expand Up @@ -525,7 +546,11 @@ mod tests {
);
assert_matches!(
parse_into_kind("x|y"),
Ok(ExpressionKind::Binary(BinaryOp::Union, _, _))
Ok(ExpressionKind::UnionAll(nodes)) if nodes.len() == 2
);
assert_matches!(
parse_into_kind("x|y|z"),
Ok(ExpressionKind::UnionAll(nodes)) if nodes.len() == 3
);
assert_matches!(
parse_into_kind("x&y"),
Expand Down

0 comments on commit f90b061

Please sign in to comment.