Skip to content

Commit

Permalink
[AIEX] Add IpuDmaWaitOp to standard lowering pass and fix device chec…
Browse files Browse the repository at this point in the history
…ks (#1229)
  • Loading branch information
jtuyls authored Apr 12, 2024
1 parent d12d183 commit e1c5878
Show file tree
Hide file tree
Showing 7 changed files with 43 additions and 13 deletions.
6 changes: 3 additions & 3 deletions lib/Dialect/AIEX/IR/AIEXDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,9 @@ LogicalResult AIEX::IpuDmaMemcpyNdOp::verify() {

LogicalResult AIEX::IpuDmaWaitOp::verify() {
AIE::DeviceOp dev = (*this)->getParentOfType<AIE::DeviceOp>();
if (!dev)
return emitOpError("couldn't find parent of type DeviceOp");
if (!dev.lookupSymbol(getSymbol()))
// Some passes (e.g. aie-standard-lowering) use aiex ops outside a DeviceOp,
// so we can't expect the device to always exist.
if (dev && !dev.lookupSymbol(getSymbol()))
return emitOpError("couldn't find symbol in parent device");
return success();
}
Expand Down
3 changes: 3 additions & 0 deletions lib/Dialect/AIEX/Transforms/AIEDmaToIpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,9 @@ struct DmaWaitToIpuPattern : OpConversionPattern<IpuDmaWaitOp> {
matchAndRewrite(IpuDmaWaitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
AIE::DeviceOp dev = op->getParentOfType<AIE::DeviceOp>();
if (!dev)
return op.emitOpError("couldn't find parent of type DeviceOp");

std::optional<AIE::ShimDMAAllocationOp> shimDmaAllocOp =
getAllocOpForSymbol(dev, op.getSymbol());
if (!shimDmaAllocOp) {
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/AIEX/Transforms/AIEXToStandard.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ struct AIEXToStandardPass : AIEXToStandardBase<AIEXToStandardPass> {
ConversionTarget target(getContext());
RewritePatternSet removepatterns(&getContext());
removepatterns.add<AIEXOpRemoval<IpuDmaMemcpyNdOp>>(m.getContext(), m);
removepatterns.add<AIEXOpRemoval<IpuDmaWaitOp>>(m.getContext(), m);
removepatterns.add<AIEXOpRemoval<IpuShimTilePushQueueOp>>(m.getContext(),
m);
removepatterns.add<AIEXOpRemoval<IpuWriteRTPOp>>(m.getContext(), m);
Expand Down
2 changes: 1 addition & 1 deletion test/Conversion/DmaToIpu/dma_to_ipu_invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
//
//===----------------------------------------------------------------------===//

// RUN: aie-opt --aie-dma-to-ipu --verify-diagnostics %s
// RUN: aie-opt --split-input-file --aie-dma-to-ipu --verify-diagnostics %s

module {
aie.device(ipu) {
Expand Down
8 changes: 0 additions & 8 deletions test/dialect/AIEX/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,6 @@

// RUN: aie-opt --split-input-file --verify-diagnostics %s

func.func @ipu_dma_wait_no_device() {
// expected-error@+1 {{'aiex.ipu.dma_wait' op couldn't find parent of type DeviceOp}}
aiex.ipu.dma_wait {symbol = @out0}
return
}

// -----

aie.device(ipu) {
func.func @ipu_dma_wait_no_symbol() {
// expected-error@+1 {{'aiex.ipu.dma_wait' op couldn't find symbol in parent device}}
Expand Down
10 changes: 9 additions & 1 deletion test/dialect/AIEX/roundtrip.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,19 @@

// CHECK-LABEL: func.func @ipu_dma_wait
// CHECK: aiex.ipu.dma_wait {symbol = @out0}

aie.device(ipu) {
memref.global "public" @out0 : memref<16xi32>
func.func @ipu_dma_wait() {
aiex.ipu.dma_wait {symbol = @out0}
return
}
}

// -----

// CHECK-LABEL: func.func @ipu_dma_wait_no_device
// CHECK: aiex.ipu.dma_wait {symbol = @out0}
func.func @ipu_dma_wait_no_device() {
aiex.ipu.dma_wait {symbol = @out0}
return
}
26 changes: 26 additions & 0 deletions test/lower-to-standard/aiex_standard_lowering.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
//===- aiex_standard_lowering.mlir -----------------------------*- MLIR -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// (c) Copyright 2024 Advanced Micro Devices, Inc.
//
//===----------------------------------------------------------------------===//

// RUN: aie-opt --split-input-file --aiex-standard-lowering %s | FileCheck %s

// CHECK-LABEL: dma_and_wait
// CHECK-NOT: aiex.ipu.dma_memcpy_nd
// CHECK-NOT: aiex.ipu.dma_wait
module {
aie.device(ipu) {
memref.global "public" @toMem : memref<16xi32>
func.func @dma_and_wait(%arg0: memref<16xi32>, %arg1: memref<16xi32>) {
aiex.ipu.dma_memcpy_nd(0, 0, %arg0[0, 0, 0, 0][1, 1, 16, 16][0, 0, 64]) { metadata = @toMem, id = 1 : i64 } : memref<16xi32>
aiex.ipu.dma_wait {symbol = @toMem}
return
}
aie.shim_dma_allocation @toMem (MM2S, 1, 1)
}
}

0 comments on commit e1c5878

Please sign in to comment.