Skip to content

Commit

Permalink
fix insert allocs
Browse files Browse the repository at this point in the history
Signed-off-by: dchigarev <[email protected]>
  • Loading branch information
dchigarev authored and AndreyPavlenko committed Aug 27, 2024
1 parent ee45972 commit 899fc02
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 2 deletions.
4 changes: 4 additions & 0 deletions lib/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@ add_mlir_library(IMEXTransforms
MLIRSupport
MLIRTransformUtils
MLIRVectorTransforms
IMEXXeTileDialect
IMEXRegionDialect

DEPENDS
IMEXTransformsPassIncGen
IMEXXeTilePassIncGen
MLIRRegionOpsIncGen
)
11 changes: 10 additions & 1 deletion lib/Transforms/InsertGPUAllocs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class InsertGPUAllocsPass final
: public imex::impl::InsertGPUAllocsBase<InsertGPUAllocsPass> {

public:
explicit InsertGPUAllocsPass() : m_clientAPI("vulkan") {}
explicit InsertGPUAllocsPass() : m_clientAPI("opencl") {}
explicit InsertGPUAllocsPass(const mlir::StringRef &clientAPI)
: m_clientAPI(clientAPI) {}

Expand Down Expand Up @@ -411,6 +411,15 @@ class InsertGPUAllocsPass final
use.set(newAlloc.getResult());
}
}

// remove 'memref.dealloc' (it's later replaced with gpu.dealloc)
auto memory = alloc->getResult(0);
for (auto u : memory.getUsers()) {
if (auto dealloc = mlir::dyn_cast<mlir::memref::DeallocOp>(u)) {
dealloc.erase();
}
}

alloc.replaceAllUsesWith(allocResult);
builder.create<mlir::gpu::DeallocOp>(loc, std::nullopt, allocResult);
alloc.erase();
Expand Down
2 changes: 1 addition & 1 deletion lib/Transforms/SetSPIRVAbiAttribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ namespace {
class SetSPIRVAbiAttributePass
: public imex::impl::SetSPIRVAbiAttributeBase<SetSPIRVAbiAttributePass> {
public:
explicit SetSPIRVAbiAttributePass() { m_clientAPI = "vulkan"; }
explicit SetSPIRVAbiAttributePass() { m_clientAPI = "opencl"; }
explicit SetSPIRVAbiAttributePass(const mlir::StringRef &clientAPI)
: m_clientAPI(clientAPI) {}

Expand Down

0 comments on commit 899fc02

Please sign in to comment.