diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py index 0b5061116e..71211017d7 100644 --- a/xdsl/dialects/builtin.py +++ b/xdsl/dialects/builtin.py @@ -1733,9 +1733,7 @@ class DenseIntOrFPElementsAttr(TypedAttribute, ContainerType[AnyDenseElement]): data: ParameterDef[ArrayAttr[AnyIntegerAttr] | ArrayAttr[AnyFloatAttr]] # The type stores the shape data - def get_shape(self) -> tuple[int, ...] | None: - if isinstance(self.type, UnrankedTensorType): - return None + def get_shape(self) -> tuple[int, ...]: return self.type.get_shape() def get_element_type(self) -> IntegerType | IndexType | AnyFloat: @@ -1744,8 +1742,6 @@ def get_element_type(self) -> IntegerType | IndexType | AnyFloat: @property def shape_is_complete(self) -> bool: shape = self.get_shape() - if shape is None or not len(shape): - return False n = 1 for dim in shape: diff --git a/xdsl/interpreters/builtin.py b/xdsl/interpreters/builtin.py index 5128205a8d..c87c0e0998 100644 --- a/xdsl/interpreters/builtin.py +++ b/xdsl/interpreters/builtin.py @@ -94,4 +94,4 @@ def dense_int_or_fp_elements_value( attr.get_element_type(), interpreter.index_bitwidth ), ) - return ShapedArray(data_ptr, list(shape) if shape is not None else []) + return ShapedArray(data_ptr, list(shape)) diff --git a/xdsl/interpreters/ml_program.py b/xdsl/interpreters/ml_program.py index 820590b994..af06accd68 100644 --- a/xdsl/interpreters/ml_program.py +++ b/xdsl/interpreters/ml_program.py @@ -28,8 +28,6 @@ def run_global_load_constant( global_value = global_op.value assert isinstance(global_value, DenseIntOrFPElementsAttr) shape = global_value.get_shape() - if shape is None: - raise NotImplementedError() xtype = xtype_for_el_type( global_value.get_element_type(), interpreter.index_bitwidth ) diff --git a/xdsl/interpreters/onnx.py b/xdsl/interpreters/onnx.py index 4c6acac1c2..be7251298d 100644 --- a/xdsl/interpreters/onnx.py +++ b/xdsl/interpreters/onnx.py @@ -145,7 +145,7 @@ def run_constant( op.value.get_element_type(), interpreter.index_bitwidth ), ) - return (ShapedArray(data_ptr, list(shape) if shape is not None else []),) + return (ShapedArray(data_ptr, list(shape)),) @impl(onnx.ReshapeOp) def run_reshape(