diff --git a/dask_expr/_core.py b/dask_expr/_core.py index f545c3f4..3aac656c 100644 --- a/dask_expr/_core.py +++ b/dask_expr/_core.py @@ -267,7 +267,7 @@ def rewrite(self, kind: str): return expr - def simplify_once(self, dependents: defaultdict): + def simplify_once(self, dependents: defaultdict, simplified: dict): """Simplify an expression This leverages the ``._simplify_down`` and ``._simplify_up`` @@ -278,12 +278,18 @@ def simplify_once(self, dependents: defaultdict): dependents: defaultdict[list] The dependents for every node. + simplified: dict + Cache of simplified expressions for these dependents. Returns ------- expr: output expression """ + # Check if we've already simplified for these dependents + if self._name in simplified: + return simplified[self._name] + expr = self while True: @@ -314,7 +320,10 @@ def simplify_once(self, dependents: defaultdict): if isinstance(operand, Expr): # Bandaid for now, waiting for Singleton dependents[operand._name].append(weakref.ref(expr)) - new = operand.simplify_once(dependents=dependents) + new = operand.simplify_once( + dependents=dependents, simplified=simplified + ) + simplified[operand._name] = new if new._name != operand._name: changed = True else: @@ -332,7 +341,7 @@ def simplify(self) -> Expr: expr = self while True: dependents = collect_dependents(expr) - new = expr.simplify_once(dependents=dependents) + new = expr.simplify_once(dependents=dependents, simplified={}) if new._name == expr._name: break expr = new diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py index c019cfe1..fbab747f 100644 --- a/dask_expr/_expr.py +++ b/dask_expr/_expr.py @@ -1742,9 +1742,7 @@ def vals(self): @functools.cached_property def _meta(self): - args = [ - meta_nonempty(op._meta) if isinstance(op, Expr) else op for op in self._args - ] + args = [op._meta if isinstance(op, Expr) else op for op in self._args] return make_meta(self.operation(*args, **self._kwargs)) def _tree_repr_argument_construction(self, i, op, header):