diff --git a/python/dialects/aiex.py b/python/dialects/aiex.py index 9d62861da1..a1047df6f2 100644 --- a/python/dialects/aiex.py +++ b/python/dialects/aiex.py @@ -853,12 +853,8 @@ def shim_dma_bd( strides = [0] * 3 + [1] if transfer_len is None: - if len(sizes) >= 4: - # For shim dma bd, highest dimension is repeat count which is not included in the length - transfer_len = np.prod(sizes[1:]) - else: - # If does not have highest dimension, then we can take the product of all dimensions - transfer_len = np.prod(sizes) + transfer_len = np.prod(sizes[-3:]) + dimensions = list(zip(sizes, strides)) dma_bd(mem, offset=offset, len=transfer_len, dimensions=dimensions)