Skip to content

Commit

Permalink
feat: support normalized expr in CSE (apache#13315)
Browse files Browse the repository at this point in the history
* feat: support normalized expr in CSE

* feat: support normalize_eq in cse optimization

* feat: support cumulative binary expr result in normalize_eq

---------

Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
zhuliquan and alamb authored Dec 20, 2024
1 parent 87b77bb commit 74480ac
Show file tree
Hide file tree
Showing 4 changed files with 790 additions and 32 deletions.
150 changes: 123 additions & 27 deletions datafusion/common/src/cse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,33 +50,63 @@ impl<T: HashNode + ?Sized> HashNode for Arc<T> {
}
}

/// The `Normalizeable` trait defines a method to determine whether a node can be normalized.
///
/// Normalization is the process of converting a node into a canonical form that can be used
/// to compare nodes for equality. This is useful in optimizations like Common Subexpression Elimination (CSE),
/// where semantically equivalent nodes (e.g., `a + b` and `b + a`) should be treated as equal.
pub trait Normalizeable {
fn can_normalize(&self) -> bool;
}

/// The `NormalizeEq` trait extends `Eq` and `Normalizeable` to provide a method for comparing
/// normlized nodes in optimizations like Common Subexpression Elimination (CSE).
///
/// The `normalize_eq` method ensures that two nodes that are semantically equivalent (after normalization)
/// are considered equal in CSE optimization, even if their original forms differ.
///
/// This trait allows for equality comparisons between nodes with equivalent semantics, regardless of their
/// internal representations.
pub trait NormalizeEq: Eq + Normalizeable {
fn normalize_eq(&self, other: &Self) -> bool;
}

/// Identifier that represents a [`TreeNode`] tree.
///
/// This identifier is designed to be efficient and "hash", "accumulate", "equal" and
/// "have no collision (as low as possible)"
#[derive(Debug, Eq, PartialEq)]
struct Identifier<'n, N> {
#[derive(Debug, Eq)]
struct Identifier<'n, N: NormalizeEq> {
// Hash of `node` built up incrementally during the first, visiting traversal.
// Its value is not necessarily equal to default hash of the node. E.g. it is not
// equal to `expr.hash()` if the node is `Expr`.
hash: u64,
node: &'n N,
}

impl<N> Clone for Identifier<'_, N> {
impl<N: NormalizeEq> Clone for Identifier<'_, N> {
fn clone(&self) -> Self {
*self
}
}
impl<N> Copy for Identifier<'_, N> {}
impl<N: NormalizeEq> Copy for Identifier<'_, N> {}

impl<N> Hash for Identifier<'_, N> {
impl<N: NormalizeEq> Hash for Identifier<'_, N> {
fn hash<H: Hasher>(&self, state: &mut H) {
state.write_u64(self.hash);
}
}

impl<'n, N: HashNode> Identifier<'n, N> {
impl<N: NormalizeEq> PartialEq for Identifier<'_, N> {
fn eq(&self, other: &Self) -> bool {
self.hash == other.hash && self.node.normalize_eq(other.node)
}
}

impl<'n, N> Identifier<'n, N>
where
N: HashNode + NormalizeEq,
{
fn new(node: &'n N, random_state: &RandomState) -> Self {
let mut hasher = random_state.build_hasher();
node.hash_node(&mut hasher);
Expand Down Expand Up @@ -213,7 +243,11 @@ pub enum FoundCommonNodes<N> {
///
/// A [`TreeNode`] without any children (column, literal etc.) will not have identifier
/// because they should not be recognized as common subtree.
struct CSEVisitor<'a, 'n, N, C: CSEController<Node = N>> {
struct CSEVisitor<'a, 'n, N, C>
where
N: NormalizeEq,
C: CSEController<Node = N>,
{
/// statistics of [`TreeNode`]s
node_stats: &'a mut NodeStats<'n, N>,

Expand Down Expand Up @@ -244,7 +278,10 @@ struct CSEVisitor<'a, 'n, N, C: CSEController<Node = N>> {
}

/// Record item that used when traversing a [`TreeNode`] tree.
enum VisitRecord<'n, N> {
enum VisitRecord<'n, N>
where
N: NormalizeEq,
{
/// Marks the beginning of [`TreeNode`]. It contains:
/// - The post-order index assigned during the first, visiting traversal.
EnterMark(usize),
Expand All @@ -258,7 +295,11 @@ enum VisitRecord<'n, N> {
NodeItem(Identifier<'n, N>, bool),
}

impl<'n, N: TreeNode + HashNode, C: CSEController<Node = N>> CSEVisitor<'_, 'n, N, C> {
impl<'n, N, C> CSEVisitor<'_, 'n, N, C>
where
N: TreeNode + HashNode + NormalizeEq,
C: CSEController<Node = N>,
{
/// Find the first `EnterMark` in the stack, and accumulates every `NodeItem` before
/// it. Returns a tuple that contains:
/// - The pre-order index of the [`TreeNode`] we marked.
Expand All @@ -271,17 +312,26 @@ impl<'n, N: TreeNode + HashNode, C: CSEController<Node = N>> CSEVisitor<'_, 'n,
/// information up from children to parents via `visit_stack` during the first,
/// visiting traversal and no need to test the expression's validity beforehand with
/// an extra traversal).
fn pop_enter_mark(&mut self) -> (usize, Option<Identifier<'n, N>>, bool) {
let mut node_id = None;
fn pop_enter_mark(
&mut self,
can_normalize: bool,
) -> (usize, Option<Identifier<'n, N>>, bool) {
let mut node_ids: Vec<Identifier<'n, N>> = vec![];
let mut is_valid = true;

while let Some(item) = self.visit_stack.pop() {
match item {
VisitRecord::EnterMark(down_index) => {
if can_normalize {
node_ids.sort_by_key(|i| i.hash);
}
let node_id = node_ids
.into_iter()
.fold(None, |accum, item| Some(item.combine(accum)));
return (down_index, node_id, is_valid);
}
VisitRecord::NodeItem(sub_node_id, sub_node_is_valid) => {
node_id = Some(sub_node_id.combine(node_id));
node_ids.push(sub_node_id);
is_valid &= sub_node_is_valid;
}
}
Expand All @@ -290,8 +340,10 @@ impl<'n, N: TreeNode + HashNode, C: CSEController<Node = N>> CSEVisitor<'_, 'n,
}
}

impl<'n, N: TreeNode + HashNode + Eq, C: CSEController<Node = N>> TreeNodeVisitor<'n>
for CSEVisitor<'_, 'n, N, C>
impl<'n, N, C> TreeNodeVisitor<'n> for CSEVisitor<'_, 'n, N, C>
where
N: TreeNode + HashNode + NormalizeEq,
C: CSEController<Node = N>,
{
type Node = N;

Expand Down Expand Up @@ -331,7 +383,8 @@ impl<'n, N: TreeNode + HashNode + Eq, C: CSEController<Node = N>> TreeNodeVisito
}

fn f_up(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
let (down_index, sub_node_id, sub_node_is_valid) = self.pop_enter_mark();
let (down_index, sub_node_id, sub_node_is_valid) =
self.pop_enter_mark(node.can_normalize());

let node_id = Identifier::new(node, self.random_state).combine(sub_node_id);
let is_valid = C::is_valid(node) && sub_node_is_valid;
Expand Down Expand Up @@ -369,7 +422,11 @@ impl<'n, N: TreeNode + HashNode + Eq, C: CSEController<Node = N>> TreeNodeVisito
/// Rewrite a [`TreeNode`] tree by replacing detected common subtrees with the
/// corresponding temporary [`TreeNode`], that column contains the evaluate result of
/// replaced [`TreeNode`] tree.
struct CSERewriter<'a, 'n, N, C: CSEController<Node = N>> {
struct CSERewriter<'a, 'n, N, C>
where
N: NormalizeEq,
C: CSEController<Node = N>,
{
/// statistics of [`TreeNode`]s
node_stats: &'a NodeStats<'n, N>,

Expand All @@ -386,8 +443,10 @@ struct CSERewriter<'a, 'n, N, C: CSEController<Node = N>> {
controller: &'a mut C,
}

impl<N: TreeNode + Eq, C: CSEController<Node = N>> TreeNodeRewriter
for CSERewriter<'_, '_, N, C>
impl<N, C> TreeNodeRewriter for CSERewriter<'_, '_, N, C>
where
N: TreeNode + NormalizeEq,
C: CSEController<Node = N>,
{
type Node = N;

Expand All @@ -408,13 +467,30 @@ impl<N: TreeNode + Eq, C: CSEController<Node = N>> TreeNodeRewriter
self.down_index += 1;
}

let (node, alias) =
self.common_nodes.entry(node_id).or_insert_with(|| {
let node_alias = self.controller.generate_alias();
(node, node_alias)
});

let rewritten = self.controller.rewrite(node, alias);
// We *must* replace all original nodes with same `node_id`, not just the first
// node which is inserted into the common_nodes. This is because nodes with the same
// `node_id` are semantically equivalent, but not exactly the same.
//
// For example, `a + 1` and `1 + a` are semantically equivalent but not identical.
// In this case, we should replace the common expression `1 + a` with a new variable
// (e.g., `__common_cse_1`). So, `a + 1` and `1 + a` would both be replaced by
// `__common_cse_1`.
//
// The final result would be:
// - `__common_cse_1 as a + 1`
// - `__common_cse_1 as 1 + a`
//
// This way, we can efficiently handle semantically equivalent expressions without
// incorrectly treating them as identical.
let rewritten = if let Some((_, alias)) = self.common_nodes.get(&node_id)
{
self.controller.rewrite(&node, alias)
} else {
let node_alias = self.controller.generate_alias();
let rewritten = self.controller.rewrite(&node, &node_alias);
self.common_nodes.insert(node_id, (node, node_alias));
rewritten
};

return Ok(Transformed::new(rewritten, true, TreeNodeRecursion::Jump));
}
Expand All @@ -441,7 +517,11 @@ pub struct CSE<N, C: CSEController<Node = N>> {
controller: C,
}

impl<N: TreeNode + HashNode + Clone + Eq, C: CSEController<Node = N>> CSE<N, C> {
impl<N, C> CSE<N, C>
where
N: TreeNode + HashNode + Clone + NormalizeEq,
C: CSEController<Node = N>,
{
pub fn new(controller: C) -> Self {
Self {
random_state: RandomState::new(),
Expand Down Expand Up @@ -557,6 +637,7 @@ impl<N: TreeNode + HashNode + Clone + Eq, C: CSEController<Node = N>> CSE<N, C>
) -> Result<FoundCommonNodes<N>> {
let mut found_common = false;
let mut node_stats = NodeStats::new();

let id_arrays_list = nodes_list
.iter()
.map(|nodes| {
Expand Down Expand Up @@ -596,7 +677,10 @@ impl<N: TreeNode + HashNode + Clone + Eq, C: CSEController<Node = N>> CSE<N, C>
#[cfg(test)]
mod test {
use crate::alias::AliasGenerator;
use crate::cse::{CSEController, HashNode, IdArray, Identifier, NodeStats, CSE};
use crate::cse::{
CSEController, HashNode, IdArray, Identifier, NodeStats, NormalizeEq,
Normalizeable, CSE,
};
use crate::tree_node::tests::TestTreeNode;
use crate::Result;
use std::collections::HashSet;
Expand Down Expand Up @@ -662,6 +746,18 @@ mod test {
}
}

impl Normalizeable for TestTreeNode<String> {
fn can_normalize(&self) -> bool {
false
}
}

impl NormalizeEq for TestTreeNode<String> {
fn normalize_eq(&self, other: &Self) -> bool {
self == other
}
}

#[test]
fn id_array_visitor() -> Result<()> {
let alias_generator = AliasGenerator::new();
Expand Down
Loading

0 comments on commit 74480ac

Please sign in to comment.