Skip to content

Commit

Permalink
dialects: (builtin) shape for DenseIntORFPElementsAttr is always known (
Browse files Browse the repository at this point in the history
#3534)

As they only work for ranked structures, the shape will always be not
None
  • Loading branch information
jorendumoulin authored Nov 28, 2024
1 parent dff62c6 commit c624080
Show file tree
Hide file tree
Showing 4 changed files with 3 additions and 9 deletions.
6 changes: 1 addition & 5 deletions xdsl/dialects/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1733,9 +1733,7 @@ class DenseIntOrFPElementsAttr(TypedAttribute, ContainerType[AnyDenseElement]):
data: ParameterDef[ArrayAttr[AnyIntegerAttr] | ArrayAttr[AnyFloatAttr]]

# The type stores the shape data
def get_shape(self) -> tuple[int, ...] | None:
if isinstance(self.type, UnrankedTensorType):
return None
def get_shape(self) -> tuple[int, ...]:
return self.type.get_shape()

def get_element_type(self) -> IntegerType | IndexType | AnyFloat:
Expand All @@ -1744,8 +1742,6 @@ def get_element_type(self) -> IntegerType | IndexType | AnyFloat:
@property
def shape_is_complete(self) -> bool:
shape = self.get_shape()
if shape is None or not len(shape):
return False

n = 1
for dim in shape:
Expand Down
2 changes: 1 addition & 1 deletion xdsl/interpreters/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,4 +94,4 @@ def dense_int_or_fp_elements_value(
attr.get_element_type(), interpreter.index_bitwidth
),
)
return ShapedArray(data_ptr, list(shape) if shape is not None else [])
return ShapedArray(data_ptr, list(shape))
2 changes: 0 additions & 2 deletions xdsl/interpreters/ml_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ def run_global_load_constant(
global_value = global_op.value
assert isinstance(global_value, DenseIntOrFPElementsAttr)
shape = global_value.get_shape()
if shape is None:
raise NotImplementedError()
xtype = xtype_for_el_type(
global_value.get_element_type(), interpreter.index_bitwidth
)
Expand Down
2 changes: 1 addition & 1 deletion xdsl/interpreters/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def run_constant(
op.value.get_element_type(), interpreter.index_bitwidth
),
)
return (ShapedArray(data_ptr, list(shape) if shape is not None else []),)
return (ShapedArray(data_ptr, list(shape)),)

@impl(onnx.ReshapeOp)
def run_reshape(
Expand Down

0 comments on commit c624080

Please sign in to comment.