Skip to content

Commit

Permalink
dialects: (gpu) Elide returns in __init__ methods (#1694)
Browse files Browse the repository at this point in the history
This PR:

- Elides the explicit `return` statements in `__init__` methods of the
`gpu` dialect.

This must have been a leftover from when `.get()` constructor methods
were used.
I had it in the backlog and my linter complained since I happened to
open the dialect source file recently.
  • Loading branch information
compor authored Oct 24, 2023
1 parent fcae7da commit 14c5320
Showing 1 changed file with 22 additions and 32 deletions.
54 changes: 22 additions & 32 deletions xdsl/dialects/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def __init__(
attributes: dict[str, Attribute] = (
{"hostShared": UnitAttr()} if host_shared else {}
)
return super().__init__(
super().__init__(
operands=[async_dependencies_vals, dynamic_sizes_vals, []],
result_types=[return_type, token_return],
attributes=attributes,
Expand Down Expand Up @@ -271,7 +271,7 @@ class BarrierOp(IRDLOperation):
name = "gpu.barrier"

def __init__(self):
return super().__init__()
super().__init__()


@irdl_op_definition
Expand All @@ -281,9 +281,7 @@ class BlockDimOp(IRDLOperation):
result: OpResult = result_def(IndexType)

def __init__(self, dim: DimensionAttr):
return super().__init__(
result_types=[IndexType()], properties={"dimension": dim}
)
super().__init__(result_types=[IndexType()], properties={"dimension": dim})


@irdl_op_definition
Expand All @@ -293,9 +291,7 @@ class BlockIdOp(IRDLOperation):
result: OpResult = result_def(IndexType)

def __init__(self, dim: DimensionAttr):
return super().__init__(
result_types=[IndexType()], properties={"dimension": dim}
)
super().__init__(result_types=[IndexType()], properties={"dimension": dim})


@irdl_op_definition
Expand All @@ -315,7 +311,7 @@ def __init__(
async_dependencies: Sequence[SSAValue | Operation] | None = None,
is_async: bool = False,
):
return super().__init__(
super().__init__(
operands=[async_dependencies, buffer],
result_types=[[AsyncTokenType()] if is_async else []],
)
Expand All @@ -340,7 +336,7 @@ def __init__(
async_dependencies: Sequence[SSAValue | Operation] | None = None,
is_async: bool = False,
):
return super().__init__(
super().__init__(
operands=[async_dependencies, destination, source],
result_types=[[AsyncTokenType()] if is_async else []],
)
Expand All @@ -360,7 +356,7 @@ class ModuleEndOp(IRDLOperation):
traits = traits_def(lambda: frozenset([IsTerminator(), HasParent(ModuleOp)]))

def __init__(self):
return super().__init__()
super().__init__()


@irdl_op_definition
Expand Down Expand Up @@ -450,9 +446,7 @@ class GlobalIdOp(IRDLOperation):
result: OpResult = result_def(IndexType)

def __init__(self, dim: DimensionAttr):
return super().__init__(
result_types=[IndexType()], properties={"dimension": dim}
)
super().__init__(result_types=[IndexType()], properties={"dimension": dim})


@irdl_op_definition
Expand All @@ -462,9 +456,7 @@ class GridDimOp(IRDLOperation):
result: OpResult = result_def(IndexType)

def __init__(self, dim: DimensionAttr):
return super().__init__(
result_types=[IndexType()], properties={"dimension": dim}
)
super().__init__(result_types=[IndexType()], properties={"dimension": dim})


@irdl_op_definition
Expand All @@ -484,7 +476,7 @@ class HostRegisterOp(IRDLOperation):
value: Operand = operand_def(memref.UnrankedMemrefType)

def __init__(self, memref: SSAValue | Operation):
return super().__init__(operands=[SSAValue.get(memref)])
super().__init__(operands=[SSAValue.get(memref)])


@irdl_op_definition
Expand All @@ -498,7 +490,7 @@ class HostUnregisterOp(IRDLOperation):
value: Operand = operand_def(memref.UnrankedMemrefType)

def __init__(self, memref: SSAValue | Operation):
return super().__init__(operands=[SSAValue.get(memref)])
super().__init__(operands=[SSAValue.get(memref)])


@irdl_op_definition
Expand All @@ -507,7 +499,7 @@ class LaneIdOp(IRDLOperation):
result: OpResult = result_def(IndexType)

def __init__(self):
return super().__init__(result_types=[IndexType()])
super().__init__(result_types=[IndexType()])


@irdl_op_definition
Expand Down Expand Up @@ -551,7 +543,7 @@ def __init__(
if dynamicSharedMemorySize is None
else [SSAValue.get(dynamicSharedMemorySize)]
]
return super().__init__(
super().__init__(
operands=operands,
result_types=[[AsyncTokenType()] if async_launch else []],
regions=[body],
Expand Down Expand Up @@ -634,7 +626,7 @@ def __init__(
if len(blockSize) != 3:
raise ValueError(f"LaunchOp must have 3 blockSizes, got {len(blockSize)}")

return super().__init__(
super().__init__(
operands=[
asyncDependencies,
*gridSize,
Expand All @@ -654,7 +646,7 @@ class NumSubgroupsOp(IRDLOperation):
result: OpResult = result_def(IndexType)

def __init__(self):
return super().__init__(result_types=[IndexType()])
super().__init__(result_types=[IndexType()])


@irdl_op_definition
Expand All @@ -666,7 +658,7 @@ class ReturnOp(IRDLOperation):
traits = frozenset([IsTerminator(), HasParent(FuncOp)])

def __init__(self, operands: Sequence[SSAValue | Operation]):
return super().__init__(operands=[operands])
super().__init__(operands=[operands])


@irdl_op_definition
Expand All @@ -675,7 +667,7 @@ class SetDefaultDeviceOp(IRDLOperation):
devIndex: Operand = operand_def(i32)

def __init__(self, devIndex: SSAValue | Operation):
return super().__init__(operands=[SSAValue.get(devIndex)])
super().__init__(operands=[SSAValue.get(devIndex)])


@irdl_op_definition
Expand All @@ -684,7 +676,7 @@ class SubgroupIdOp(IRDLOperation):
result: OpResult = result_def(IndexType)

def __init__(self):
return super().__init__(result_types=[IndexType()])
super().__init__(result_types=[IndexType()])


@irdl_op_definition
Expand All @@ -693,7 +685,7 @@ class SubgroupSizeOp(IRDLOperation):
result: OpResult = result_def(IndexType)

def __init__(self):
return super().__init__(result_types=[IndexType()])
super().__init__(result_types=[IndexType()])


@irdl_op_definition
Expand All @@ -703,7 +695,7 @@ class TerminatorOp(IRDLOperation):
traits = frozenset([HasParent(LaunchOp), IsTerminator()])

def __init__(self):
return super().__init__()
super().__init__()


@irdl_op_definition
Expand All @@ -713,9 +705,7 @@ class ThreadIdOp(IRDLOperation):
result: OpResult = result_def(IndexType)

def __init__(self, dim: DimensionAttr):
return super().__init__(
result_types=[IndexType()], properties={"dimension": dim}
)
super().__init__(result_types=[IndexType()], properties={"dimension": dim})


@irdl_op_definition
Expand All @@ -724,7 +714,7 @@ class YieldOp(IRDLOperation):
values: VarOperand = var_operand_def(Attribute)

def __init__(self, operands: Sequence[SSAValue | Operation]):
return super().__init__(operands=[operands])
super().__init__(operands=[operands])

traits = frozenset([IsTerminator()])

Expand Down

0 comments on commit 14c5320

Please sign in to comment.