diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py index c00106ba..d014d806 100644 --- a/dask_expr/_expr.py +++ b/dask_expr/_expr.py @@ -1093,7 +1093,19 @@ def _task(self, index: int): return self.value -class Blockwise(Expr): +class BlockwiseOverlapping(Expr): + """Super class that operates like a blockwise op + + We require information from neighboring partitions, so we can't prune partitions + before lowering, but the spec is the same as for Blockwise ops, we don't reorder + things, alignment stays consistent, ... + + """ + + pass + + +class Blockwise(BlockwiseOverlapping): """Super-class for block-wise operations This is fairly generic, and includes definitions for `_meta`, `divisions`, @@ -1276,7 +1288,7 @@ def _task(self, index: int): ) -class MapOverlap(MapPartitions): +class MapOverlap(BlockwiseOverlapping): _parameters = [ "frame", "func", @@ -1296,6 +1308,29 @@ class MapOverlap(MapPartitions): "clear_divisions": False, } + def _broadcast_dep(self, dep: Expr): + return dep.npartitions == 1 + + @property + def args(self): + return [self.frame] + self.operands[len(self._parameters) :] + + def _divisions(self): + # Unknown divisions + if self.clear_divisions: + return (None,) * (self.frame.npartitions + 1) + + # (Possibly) known divisions + dfs = [arg for arg in self.args if isinstance(arg, Expr)] + return _get_divisions_map_partitions( + True, # Partitions must already be "aligned" + self.transform_divisions, + dfs, + self.func, + self.args, + self.kwargs, + ) + @functools.cached_property def _kwargs(self) -> dict: kwargs = self.kwargs @@ -2454,7 +2489,7 @@ def non_blockwise_ancestors(expr): e = stack.pop() if isinstance(e, IO): yield e - elif isinstance(e, Blockwise): + elif isinstance(e, BlockwiseOverlapping): dependencies = e.dependencies() stack.extend([expr for expr in dependencies if not is_broadcastable(expr)]) else: