Skip to content

Commit

Permalink
[ObjectFifo] Add pass to flatten the logical objectFifo (#638)
Browse files Browse the repository at this point in the history
This PR is part to achieve
#644
  • Loading branch information
yzhang93 authored Aug 6, 2024
1 parent 9eb13b2 commit ab9fabd
Show file tree
Hide file tree
Showing 7 changed files with 151 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree-amd-aie/IR/AMDAIEOps.h"
#include "iree-amd-aie/Transforms/Passes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Iterators.h"

#define DEBUG_TYPE "iree-amdaie-flatten-logicalobjectfifo"

namespace mlir::iree_compiler::AMDAIE {

namespace {

class AMDAIEFlattenLogicalObjectFifoPass
: public impl::AMDAIEFlattenLogicalObjectFifoBase<
AMDAIEFlattenLogicalObjectFifoPass> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<AMDAIEDialect, memref::MemRefDialect>();
}

AMDAIEFlattenLogicalObjectFifoPass() = default;
AMDAIEFlattenLogicalObjectFifoPass(
const AMDAIEFlattenLogicalObjectFifoPass &pass){};
void runOnOperation() override;
};

void AMDAIEFlattenLogicalObjectFifoPass::runOnOperation() {
MLIRContext *context = &getContext();
ModuleOp moduleOp = getOperation();
IRRewriter rewriter(context);

moduleOp->walk([&](AMDAIE::LogicalObjectFifoFromMemrefOp op) {
// Get linearized size and new type.
MemRefType oldType = op.getMemrefType();
uint64_t linearizedSize = oldType.getNumElements();
MemRefType newType =
MemRefType::get(linearizedSize, oldType.getElementType(),
MemRefLayoutAttrInterface{}, oldType.getMemorySpace());

rewriter.setInsertionPoint(op);
auto newLogicalObjectFifo =
rewriter.create<AMDAIE::LogicalObjectFifoFromMemrefOp>(
rewriter.getUnknownLoc(), LogicalObjectFifoType::get(newType),
op.getMemref(), op.getTiles());
rewriter.replaceOp(op, newLogicalObjectFifo);

// Replace the access op and insert `memref.reinterpret_cast` to get to the
// original local shape as the objectfifo has a single type, while the DMA
// operations converted into objectfifos can have a different source and
// target type.
for (Operation *user : newLogicalObjectFifo->getUsers()) {
if (auto accessOp = dyn_cast<AMDAIE::LogicalObjectFifoAccessOp>(user)) {
rewriter.setInsertionPoint(accessOp);
auto newAccessOp = rewriter.create<AMDAIE::LogicalObjectFifoAccessOp>(
rewriter.getUnknownLoc(), newLogicalObjectFifo.getOutput(),
accessOp.getAccessType());

auto [strides, baseOffset] = getStridesAndOffset(oldType);
auto reinterpretOp = rewriter.create<memref::ReinterpretCastOp>(
rewriter.getUnknownLoc(), oldType, newAccessOp.getOutput(),
baseOffset, oldType.getShape(), strides);
rewriter.replaceAllUsesWith(accessOp, reinterpretOp);
}
}
});

// Erase old access operations.
moduleOp->walk([&](AMDAIE::LogicalObjectFifoAccessOp accessOp) {
if (accessOp->getUses().empty()) {
rewriter.eraseOp(accessOp);
}
});
}

} // namespace

std::unique_ptr<Pass> createAMDAIEFlattenLogicalObjectFifoPass() {
return std::make_unique<AMDAIEFlattenLogicalObjectFifoPass>();
}
} // namespace mlir::iree_compiler::AMDAIE
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ iree_cc_library(
"AMDAIEDmaLoopSubsumption.cpp"
"AMDAIEDmaToCircularDma.cpp"
"AMDAIEDmaUtils.cpp"
"AMDAIEFlattenLogicalObjectFifo.cpp"
"AMDAIEFuseConsumerIntoLoop.cpp"
"AMDAIEFuseFillIntoForall.cpp"
"AMDAIEFusePackIntoLoop.cpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ namespace mlir::iree_compiler::AMDAIE {
#define GEN_PASS_DEF_AMDAIEDISTRIBUTECORESANDOBJECTFIFOS
#define GEN_PASS_DEF_AMDAIEDMALOOPSUBSUMPTION
#define GEN_PASS_DEF_AMDAIEDMATOCIRCULARDMA
#define GEN_PASS_DEF_AMDAIEFLATTENLOGICALOBJECTFIFO
#define GEN_PASS_DEF_AMDAIEFUSECONSUMERINTOLOOP
#define GEN_PASS_DEF_AMDAIEFUSEFILLINTOFORALL
#define GEN_PASS_DEF_AMDAIEFUSEPACKINTOLOOP
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,9 @@ std::unique_ptr<Pass> createAMDAIEDmaLoopSubsumptionPass(
/// Create a pass to convert dma operations to circular dma operations.
std::unique_ptr<Pass> createAMDAIEDmaToCircularDmaPass();

/// Create a pass to flatten the logical objectFifos.
std::unique_ptr<Pass> createAMDAIEFlattenLogicalObjectFifoPass();

/// Create a pass to fuse the consumer op into the innermost last scf loop.
std::unique_ptr<Pass> createAMDAIEFuseConsumerIntoLoopPass(
AMDAIEFuseConsumerIntoLoopOptions options = {});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,12 @@ def AMDAIEDmaToCircularDma :
let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIEDmaToCircularDmaPass()";
}

def AMDAIEFlattenLogicalObjectFifo :
Pass<"iree-amdaie-flatten-logicalobjectfifo", "ModuleOp"> {
let summary = "Flatten the logical objectFifos.";
let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIEFlattenLogicalObjectFifoPass()";
}

def AMDAIEFuseConsumerIntoLoop :
InterfacePass<"iree-amdaie-fuse-consumer-into-loop", "mlir::FunctionOpInterface"> {
let summary = "Fuse the consumer operation into the innermost last scf loop.";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ iree_lit_test_suite(
"distribute_cores_and_objectfifos.mlir"
"dma_loop_subsumption.mlir"
"dma_to_circular_dma.mlir"
"flatten_logical_objectfifo.mlir"
"fuse_consumer_into_loop_scf_for.mlir"
"fuse_consumer_into_loop_scf_forall.mlir"
"fuse_fill_into_forall.mlir"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// RUN: iree-opt --pass-pipeline="builtin.module(iree-amdaie-flatten-logicalobjectfifo)" --split-input-file %s | FileCheck %s

// CHECK-LABEL: @access_logical_objectfifo
// CHECK: %[[FROM_MEMREF_0:.*]] = amdaie.logicalobjectfifo.from_memref
// CHECK-SAME: memref<1x1x8x4x8x4xi32, 2 : i32> -> !amdaie.logicalobjectfifo<memref<1024xi32, 2 : i32>>
// CHECK: %[[FROM_MEMREF_1:.*]] = amdaie.logicalobjectfifo.from_memref
// CHECK-SAME: memref<1x2x32x32xi32, 1 : i32> -> !amdaie.logicalobjectfifo<memref<2048xi32, 1 : i32>>
// CHECK: %[[DMA_0:.*]] = amdaie.circular_dma_cpy_nd
// CHECK-SAME: (!amdaie.logicalobjectfifo<memref<1024xi32, 2 : i32>>, !amdaie.logicalobjectfifo<memref<2048xi32, 1 : i32>>)
// CHECK: %[[FROM_MEMREF_2:.*]] = amdaie.logicalobjectfifo.from_memref
// CHECK-SAME: memref<1x1x8x8x4x4xi32, 2 : i32> -> !amdaie.logicalobjectfifo<memref<1024xi32, 2 : i32>>
// CHECK: amdaie.core
// CHECK: %[[ACCESS:.*]]= amdaie.logicalobjectfifo.access(%[[FROM_MEMREF_0]], Read) : !amdaie.logicalobjectfifo<memref<1024xi32, 2 : i32>> -> memref<1024xi32, 2 : i32>
// CHECK: %[[CAST:.*]] = memref.reinterpret_cast %[[ACCESS]]
// CHECK-SAME: memref<1024xi32, 2 : i32> to memref<1x1x8x4x8x4xi32, 2 : i32>
// CHECK: %[[ACCESS_2:.*]]= amdaie.logicalobjectfifo.access(%[[FROM_MEMREF_2]], None) : !amdaie.logicalobjectfifo<memref<1024xi32, 2 : i32>> -> memref<1024xi32, 2 : i32>
// CHECK: %[[CAST_2:.*]] = memref.reinterpret_cast %[[ACCESS_2]]
// CHECK-SAME: memref<1024xi32, 2 : i32> to memref<1x1x8x8x4x4xi32, 2 : i32>
// CHECK: linalg.fill ins(%{{.+}} : i32) outs(%[[CAST_2]]
module {
func.func @access_logical_objectfifo() {
%c1 = arith.constant 1 : index
%c0_i32 = arith.constant 0 : i32
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
%c4 = arith.constant 4 : index
%c8 = arith.constant 8 : index
%c32 = arith.constant 32 : index
%c1024 = arith.constant 1024 : index
amdaie.workgroup {
%alloc = memref.alloc() : memref<1x1x8x4x8x4xi32, 2 : i32>
%alloc_0 = memref.alloc() : memref<1x2x32x32xi32, 1 : i32>
%alloc_1 = memref.alloc() : memref<1x1x8x8x4x4xi32, 2 : i32>
%tile = amdaie.tile(%c0, %c1)
%tile_2 = amdaie.tile(%c0, %c2)
%0 = amdaie.logicalobjectfifo.from_memref %alloc, {%tile_2} : memref<1x1x8x4x8x4xi32, 2 : i32> -> !amdaie.logicalobjectfifo<memref<1x1x8x4x8x4xi32, 2 : i32>>
%1 = amdaie.logicalobjectfifo.from_memref %alloc_0, {%tile} : memref<1x2x32x32xi32, 1 : i32> -> !amdaie.logicalobjectfifo<memref<1x2x32x32xi32, 1 : i32>>
%2 = amdaie.circular_dma_cpy_nd(%0[%c0] [%c1024] [%c1], %1[%c0, %c0, %c0] [%c8, %c32, %c4] [%c4, %c32, %c1]) : (!amdaie.logicalobjectfifo<memref<1x1x8x4x8x4xi32, 2 : i32>>, !amdaie.logicalobjectfifo<memref<1x2x32x32xi32, 1 : i32>>)
%3 = amdaie.logicalobjectfifo.from_memref %alloc_1, {%tile_2} : memref<1x1x8x8x4x4xi32, 2 : i32> -> !amdaie.logicalobjectfifo<memref<1x1x8x8x4x4xi32, 2 : i32>>
%4 = amdaie.core(%tile_2) {
scf.forall (%arg0, %arg1) in (2, 2) {
%5 = amdaie.logicalobjectfifo.access(%0, Read) : !amdaie.logicalobjectfifo<memref<1x1x8x4x8x4xi32, 2 : i32>> -> memref<1x1x8x4x8x4xi32, 2 : i32>
%6 = amdaie.logicalobjectfifo.access(%3, None) : !amdaie.logicalobjectfifo<memref<1x1x8x8x4x4xi32, 2 : i32>> -> memref<1x1x8x8x4x4xi32, 2 : i32>
linalg.fill ins(%c0_i32 : i32) outs(%6 : memref<1x1x8x8x4x4xi32, 2 : i32>)
}
amdaie.end
}
amdaie.controlcode {
amdaie.end
}
}
return
}
}

0 comments on commit ab9fabd

Please sign in to comment.