diff --git a/lib/Dialect/AIEX/IR/AIEXDialect.cpp b/lib/Dialect/AIEX/IR/AIEXDialect.cpp index a007630212..6363626aa2 100644 --- a/lib/Dialect/AIEX/IR/AIEXDialect.cpp +++ b/lib/Dialect/AIEX/IR/AIEXDialect.cpp @@ -107,9 +107,9 @@ LogicalResult AIEX::IpuDmaMemcpyNdOp::verify() { LogicalResult AIEX::IpuDmaWaitOp::verify() { AIE::DeviceOp dev = (*this)->getParentOfType(); - if (!dev) - return emitOpError("couldn't find parent of type DeviceOp"); - if (!dev.lookupSymbol(getSymbol())) + // Some passes (e.g. aie-standard-lowering) use aiex ops outside a DeviceOp, + // so we can't expect the device to always exist. + if (dev && !dev.lookupSymbol(getSymbol())) return emitOpError("couldn't find symbol in parent device"); return success(); } diff --git a/lib/Dialect/AIEX/Transforms/AIEDmaToIpu.cpp b/lib/Dialect/AIEX/Transforms/AIEDmaToIpu.cpp index bf46aaa9eb..3841f73bf5 100644 --- a/lib/Dialect/AIEX/Transforms/AIEDmaToIpu.cpp +++ b/lib/Dialect/AIEX/Transforms/AIEDmaToIpu.cpp @@ -348,6 +348,9 @@ struct DmaWaitToIpuPattern : OpConversionPattern { matchAndRewrite(IpuDmaWaitOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { AIE::DeviceOp dev = op->getParentOfType(); + if (!dev) + return op.emitOpError("couldn't find parent of type DeviceOp"); + std::optional shimDmaAllocOp = getAllocOpForSymbol(dev, op.getSymbol()); if (!shimDmaAllocOp) { diff --git a/lib/Dialect/AIEX/Transforms/AIEXToStandard.cpp b/lib/Dialect/AIEX/Transforms/AIEXToStandard.cpp index bbffb10d56..40609c18e5 100644 --- a/lib/Dialect/AIEX/Transforms/AIEXToStandard.cpp +++ b/lib/Dialect/AIEX/Transforms/AIEXToStandard.cpp @@ -48,6 +48,7 @@ struct AIEXToStandardPass : AIEXToStandardBase { ConversionTarget target(getContext()); RewritePatternSet removepatterns(&getContext()); removepatterns.add>(m.getContext(), m); + removepatterns.add>(m.getContext(), m); removepatterns.add>(m.getContext(), m); removepatterns.add>(m.getContext(), m); diff --git a/test/Conversion/DmaToIpu/dma_to_ipu_invalid.mlir b/test/Conversion/DmaToIpu/dma_to_ipu_invalid.mlir index 37ab97bd65..89eff26d44 100644 --- a/test/Conversion/DmaToIpu/dma_to_ipu_invalid.mlir +++ b/test/Conversion/DmaToIpu/dma_to_ipu_invalid.mlir @@ -8,7 +8,7 @@ // //===----------------------------------------------------------------------===// -// RUN: aie-opt --aie-dma-to-ipu --verify-diagnostics %s +// RUN: aie-opt --split-input-file --aie-dma-to-ipu --verify-diagnostics %s module { aie.device(ipu) { diff --git a/test/dialect/AIEX/invalid.mlir b/test/dialect/AIEX/invalid.mlir index b6bc500cc2..9b57d84b70 100644 --- a/test/dialect/AIEX/invalid.mlir +++ b/test/dialect/AIEX/invalid.mlir @@ -10,14 +10,6 @@ // RUN: aie-opt --split-input-file --verify-diagnostics %s -func.func @ipu_dma_wait_no_device() { - // expected-error@+1 {{'aiex.ipu.dma_wait' op couldn't find parent of type DeviceOp}} - aiex.ipu.dma_wait {symbol = @out0} - return -} - -// ----- - aie.device(ipu) { func.func @ipu_dma_wait_no_symbol() { // expected-error@+1 {{'aiex.ipu.dma_wait' op couldn't find symbol in parent device}} diff --git a/test/dialect/AIEX/roundtrip.mlir b/test/dialect/AIEX/roundtrip.mlir index c49fada758..27611d5914 100644 --- a/test/dialect/AIEX/roundtrip.mlir +++ b/test/dialect/AIEX/roundtrip.mlir @@ -12,7 +12,6 @@ // CHECK-LABEL: func.func @ipu_dma_wait // CHECK: aiex.ipu.dma_wait {symbol = @out0} - aie.device(ipu) { memref.global "public" @out0 : memref<16xi32> func.func @ipu_dma_wait() { @@ -20,3 +19,12 @@ aie.device(ipu) { return } } + +// ----- + +// CHECK-LABEL: func.func @ipu_dma_wait_no_device +// CHECK: aiex.ipu.dma_wait {symbol = @out0} +func.func @ipu_dma_wait_no_device() { + aiex.ipu.dma_wait {symbol = @out0} + return +} diff --git a/test/lower-to-standard/aiex_standard_lowering.mlir b/test/lower-to-standard/aiex_standard_lowering.mlir new file mode 100644 index 0000000000..639dbc1e83 --- /dev/null +++ b/test/lower-to-standard/aiex_standard_lowering.mlir @@ -0,0 +1,26 @@ +//===- aiex_standard_lowering.mlir -----------------------------*- MLIR -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// (c) Copyright 2024 Advanced Micro Devices, Inc. +// +//===----------------------------------------------------------------------===// + +// RUN: aie-opt --split-input-file --aiex-standard-lowering %s | FileCheck %s + +// CHECK-LABEL: dma_and_wait +// CHECK-NOT: aiex.ipu.dma_memcpy_nd +// CHECK-NOT: aiex.ipu.dma_wait +module { + aie.device(ipu) { + memref.global "public" @toMem : memref<16xi32> + func.func @dma_and_wait(%arg0: memref<16xi32>, %arg1: memref<16xi32>) { + aiex.ipu.dma_memcpy_nd(0, 0, %arg0[0, 0, 0, 0][1, 1, 16, 16][0, 0, 64]) { metadata = @toMem, id = 1 : i64 } : memref<16xi32> + aiex.ipu.dma_wait {symbol = @toMem} + return + } + aie.shim_dma_allocation @toMem (MM2S, 1, 1) + } +}