Skip to content

Commit

Permalink
API: add Op suffix to all operations in xDSL (#3522)
Browse files Browse the repository at this point in the history
I recently received feedback that one of the confusing aspects of xDSL
is that it's not consistent in terms of operation class naming, unlike
MLIR. This PR changes this. To make these changes, I added a runtime
check to Operation's post_init that asserted the class name ended in
"Op", changed the class names until everything passed, and then
reverted that check. As a result, this PR converts all the classes that
are initialized at some point in our tests to this naming convention.
We might want to add that check permanently in the future, but I wanted
to discuss that in a separate PR.


This was the check I used:
``` python
        assert (
            "cmath" in self.name or self.__class__.__name__[-2:] == "Op"
        ), self.__class__.__name__
```
  • Loading branch information
superlopuh authored Nov 27, 2024
1 parent b7e13f2 commit 9fa78bf
Show file tree
Hide file tree
Showing 174 changed files with 2,303 additions and 2,212 deletions.
50 changes: 25 additions & 25 deletions docs/Toy/toy/rewrites/lower_toy_affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def convert_tensor_to_memref(type: toy.TensorTypeF64) -> MemrefTypeF64:

def insert_alloc_and_dealloc(
type: MemrefTypeF64, op: Operation, rewriter: PatternRewriter
) -> memref.Alloc:
) -> memref.AllocOp:
"""
Insert an allocation and deallocation for the given MemRefType.
"""
Expand All @@ -59,12 +59,12 @@ def insert_alloc_and_dealloc(
assert block.last_op is not None

# Make sure to allocate at the beginning of the block.
alloc = memref.Alloc.get(type.element_type, None, type.shape)
alloc = memref.AllocOp.get(type.element_type, None, type.shape)
rewriter.insert_op(alloc, InsertPoint.at_start(block))

# Make sure to deallocate this alloc at the end of the block. This is fine as toy
# functions have no control flow.
dealloc = memref.Dealloc.get(alloc)
dealloc = memref.DeallocOp.get(alloc)
rewriter.insert_op(dealloc, InsertPoint.before(block.last_op))

return alloc
Expand All @@ -83,7 +83,7 @@ def build_affine_for(
step: int,
iter_args: _ValueRange,
body_builder_fn: _AffineForOpBodyBuilderFn,
) -> affine.For:
) -> affine.ForOp:
"""
`body_builder_fn` is used to build the body of affine.for.
"""
Expand All @@ -103,7 +103,7 @@ def build_affine_for(
induction_var, *rest = block.args
region = Region(block)

op = affine.For.from_region(
op = affine.ForOp.from_region(
lb_operands,
ub_operands,
iter_args,
Expand All @@ -125,7 +125,7 @@ def build_affine_for_const(
step: int,
iter_args: _ValueRange,
body_builder_fn: _AffineForOpBodyBuilderFn,
) -> affine.For:
) -> affine.ForOp:
return build_affine_for(
builder,
(),
Expand Down Expand Up @@ -153,7 +153,7 @@ def build_affine_for_const(
_BodyBuilderFn: TypeAlias = Callable[[Builder, _ValueRange], None]
_LoopCreatorFn: TypeAlias = Callable[
[Builder, _BoundT, _BoundT, int, _AffineForOpBodyBuilderFn],
affine.For,
affine.ForOp,
]


Expand Down Expand Up @@ -193,7 +193,7 @@ def body(nested_builder: Builder, iv: SSAValue, iter_args: _ValueRange):
if i == e - 1:
body_builder_fn(nested_builder, ivs)

nested_builder.insert(affine.Yield.get())
nested_builder.insert(affine.YieldOp.get())

# Delegate actual loop creation to the callback in order to dispatch
# between constant- and variable-bound loops.
Expand All @@ -208,7 +208,7 @@ def build_affine_loop_from_constants(
ub: int,
step: int,
body_builder_fn: _AffineForOpBodyBuilderFn,
) -> affine.For:
) -> affine.ForOp:
"""
Creates an affine loop from the bounds known to be constants.
"""
Expand All @@ -221,17 +221,17 @@ def build_affine_loop_from_values(
ub: SSAValue,
step: int,
body_builder_fn: _AffineForOpBodyBuilderFn,
) -> affine.For:
) -> affine.ForOp:
"""
Creates an affine loop from the bounds that may or may not be constants.
"""
lb_const = lb.owner
ub_const = ub.owner

if (
isinstance(lb_const, arith.Constant)
isinstance(lb_const, arith.ConstantOp)
and isinstance(lb_const_value := lb_const.value, IntegerAttr)
and isinstance(ub_const, arith.Constant)
and isinstance(ub_const, arith.ConstantOp)
and isinstance(ub_const_value := ub_const.value, IntegerAttr)
):
lb_val = lb_const_value.value.data
Expand Down Expand Up @@ -301,7 +301,7 @@ def impl_loop(nested_builder: Builder, ivs: _ValueRange):
# loop induction variables. This function will return the value to store at the
# current index.
value_to_store = process_iteration(nested_builder, operands, ivs)
store_op = affine.Store(value_to_store, alloc.memref, ivs)
store_op = affine.StoreOp(value_to_store, alloc.memref, ivs)
nested_builder.insert(store_op)

builder = Builder.before(op)
Expand All @@ -326,9 +326,9 @@ def body(
builder: Builder, memref_operands: _ValueRange, loop_ivs: _ValueRange
) -> SSAValue:
# Generate loads for the element of 'lhs' and 'rhs' at the inner loop.
loaded_lhs = builder.insert(affine.Load(op.lhs, loop_ivs))
loaded_rhs = builder.insert(affine.Load(op.rhs, loop_ivs))
new_binop = builder.insert(arith.Addf(loaded_lhs, loaded_rhs))
loaded_lhs = builder.insert(affine.LoadOp(op.lhs, loop_ivs))
loaded_rhs = builder.insert(affine.LoadOp(op.rhs, loop_ivs))
new_binop = builder.insert(arith.AddfOp(loaded_lhs, loaded_rhs))
return new_binop.result

lower_op_to_loops(op, op.operands, rewriter, body)
Expand All @@ -341,9 +341,9 @@ def body(
builder: Builder, memref_operands: _ValueRange, loop_ivs: _ValueRange
) -> SSAValue:
# Generate loads for the element of 'lhs' and 'rhs' at the inner loop.
loaded_lhs = builder.insert(affine.Load(op.lhs, loop_ivs))
loaded_rhs = builder.insert(affine.Load(op.rhs, loop_ivs))
new_binop = builder.insert(arith.Mulf(loaded_lhs, loaded_rhs))
loaded_lhs = builder.insert(affine.LoadOp(op.lhs, loop_ivs))
loaded_rhs = builder.insert(affine.LoadOp(op.rhs, loop_ivs))
new_binop = builder.insert(arith.MulfOp(loaded_lhs, loaded_rhs))
return new_binop.result

lower_op_to_loops(op, op.operands, rewriter, body)
Expand All @@ -364,16 +364,16 @@ def match_and_rewrite(self, op: toy.ConstantOp, rewriter: PatternRewriter):
value_shape = memref_type.get_shape()

# Scalar constant values for elements of the tensor
constants: list[arith.Constant] = [
arith.Constant(FloatAttr(i.value.data, f64)) for i in constant_value.data
constants: list[arith.ConstantOp] = [
arith.ConstantOp(FloatAttr(i.value.data, f64)) for i in constant_value.data
]

# n-d indices of elements
_indices = product(*(range(d) for d in value_shape))

# For each n-d index into the tensor, store the corresponding scalar
stores = [
affine.Store(
affine.StoreOp(
constants[offset].result,
alloc.memref,
(),
Expand Down Expand Up @@ -427,7 +427,7 @@ def match_and_rewrite(self, op: toy.PrintOp, rewriter: PatternRewriter):

for indices in product(*(range(dim) for dim in shape)):
rewriter.insert_op_before_matched_op(
load := affine.Load(
load := affine.LoadOp(
op.input,
(),
AffineMapAttr(AffineMap.from_callable(lambda: indices)),
Expand All @@ -445,7 +445,7 @@ def match_and_rewrite(self, op: toy.ReturnOp, rewriter: PatternRewriter):
op.input is None
), "During this lowering, we expect that all function calls have been inlined."

rewriter.replace_matched_op(func.Return())
rewriter.replace_matched_op(func.ReturnOp())


class TransposeOpLowering(RewritePattern):
Expand All @@ -455,7 +455,7 @@ def body(
builder: Builder, mem_ref_operands: _ValueRange, loop_ivs: _ValueRange
) -> SSAValue:
# Transpose the elements by generating a load from the reverse indices.
load_op = affine.Load(op.arg, tuple(reversed(loop_ivs)))
load_op = affine.LoadOp(op.arg, tuple(reversed(loop_ivs)))
builder.insert(load_op)
return load_op.result

Expand Down
30 changes: 17 additions & 13 deletions docs/database_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@
"\n",
"\n",
"@irdl_op_definition\n",
"class Table(IRDLOperation):\n",
"class TableOp(IRDLOperation):\n",
" name = \"sql.table\"\n",
" table_name = attr_def(StringAttr)\n",
" result_bag = result_def(Bag)"
Expand Down Expand Up @@ -137,7 +137,9 @@
}
],
"source": [
"t = Table.build(attributes={\"table_name\": StringAttr(\"T\")}, result_types=[Bag([(i32)])])\n",
"t = TableOp.build(\n",
" attributes={\"table_name\": StringAttr(\"T\")}, result_types=[Bag([(i32)])]\n",
")\n",
"printer.print_op(t)"
]
},
Expand Down Expand Up @@ -189,7 +191,7 @@
"outputs": [],
"source": [
"@irdl_op_definition\n",
"class Selection(IRDLOperation):\n",
"class SelectionOp(IRDLOperation):\n",
" name = \"sql.selection\"\n",
" input_bag = operand_def(Bag)\n",
" filter = region_def()\n",
Expand Down Expand Up @@ -220,12 +222,12 @@
" # filter argument\n",
" (arg,) = args\n",
"\n",
" const1 = Constant.from_int_and_width(5, 32)\n",
" const2 = Constant.from_int_and_width(5, 32)\n",
" add = Addi(const1, const2)\n",
" cmp = Cmpi(arg, add, \"sgt\")\n",
" const1 = ConstantOp.from_int_and_width(5, 32)\n",
" const2 = ConstantOp.from_int_and_width(5, 32)\n",
" add = AddiOp(const1, const2)\n",
" cmp = CmpiOp(arg, add, \"sgt\")\n",
" # sgt stands for `signed greater than`. In xDSL, this is encoded as a predicate attribute with value 4.\n",
" Yield(cmp)"
" YieldOp(cmp)"
]
},
{
Expand All @@ -250,7 +252,7 @@
}
],
"source": [
"sel = Selection.build(result_types=[Bag([i32])], operands=[t], regions=[filter])\n",
"sel = SelectionOp.build(result_types=[Bag([i32])], operands=[t], regions=[filter])\n",
"\n",
"printer.print_op(sel)"
]
Expand All @@ -274,13 +276,15 @@
"@dataclass\n",
"class ConstantFolding(RewritePattern):\n",
" @op_type_rewrite_pattern\n",
" def match_and_rewrite(self, op: Addi, rewriter: PatternRewriter):\n",
" if isinstance(op.lhs.owner, Constant) and isinstance(op.rhs.owner, Constant):\n",
" def match_and_rewrite(self, op: AddiOp, rewriter: PatternRewriter):\n",
" if isinstance(op.lhs.owner, ConstantOp) and isinstance(\n",
" op.rhs.owner, ConstantOp\n",
" ):\n",
" lhs_data = cast(IntegerAttr[IntegerType], op.lhs.owner.value).value.data\n",
" rhs_data = cast(IntegerAttr[IntegerType], op.rhs.owner.value).value.data\n",
" lhs_type = cast(IntegerAttr[IntegerType], op.lhs.owner.value).type\n",
" rewriter.replace_matched_op(\n",
" Constant.from_int_and_width(lhs_data + rhs_data, lhs_type)\n",
" ConstantOp.from_int_and_width(lhs_data + rhs_data, lhs_type)\n",
" )"
]
},
Expand Down Expand Up @@ -357,7 +361,7 @@
"@dataclass\n",
"class DeadConstantElim(RewritePattern):\n",
" @op_type_rewrite_pattern\n",
" def match_and_rewrite(self, op: Constant, rewriter: PatternRewriter):\n",
" def match_and_rewrite(self, op: ConstantOp, rewriter: PatternRewriter):\n",
" if len(op.result.uses) == 0:\n",
" rewriter.erase_matched_op()"
]
Expand Down
Loading

0 comments on commit 9fa78bf

Please sign in to comment.