diff --git a/include/ttmlir/Conversion/ArithToStableHLO/ArithToStableHLO.h b/include/ttmlir/Conversion/ArithToStableHLO/ArithToStableHLO.h new file mode 100644 index 000000000..27d289f73 --- /dev/null +++ b/include/ttmlir/Conversion/ArithToStableHLO/ArithToStableHLO.h @@ -0,0 +1,19 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef TTMLIR_CONVERSION_ARITHTOSTABLEHLO_ARITHTOSTABLEHLO_H +#define TTMLIR_CONVERSION_ARITHTOSTABLEHLO_ARITHTOSTABLEHLO_H + +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir::tt { + +#ifdef TTMLIR_ENABLE_STABLEHLO +std::unique_ptr> createConvertArithToStableHLOPass(); +#endif + +} // namespace mlir::tt + +#endif // TTMLIR_CONVERSION_STABLEHLOTOTTIR_STABLEHLOTOTTIR_H diff --git a/include/ttmlir/Conversion/Passes.h b/include/ttmlir/Conversion/Passes.h index b5c564372..1f754a087 100644 --- a/include/ttmlir/Conversion/Passes.h +++ b/include/ttmlir/Conversion/Passes.h @@ -6,6 +6,7 @@ #define TTMLIR_CONVERSION_PASSES_H #ifdef TTMLIR_ENABLE_STABLEHLO +#include "ttmlir/Conversion/ArithToStableHLO/ArithToStableHLO.h" #include "ttmlir/Conversion/StableHLOToTTIR/StableHLOToTTIR.h" #endif #include "ttmlir/Conversion/TTIRToTTMetal/TTIRToTTMetal.h" diff --git a/include/ttmlir/Conversion/Passes.td b/include/ttmlir/Conversion/Passes.td index 340cb422f..98535eac2 100644 --- a/include/ttmlir/Conversion/Passes.td +++ b/include/ttmlir/Conversion/Passes.td @@ -13,6 +13,11 @@ let summary = "Convert StableHLO dialect to TTIR dialect."; let constructor = "createConvertStableHLOToTTIRPass()"; let dependentDialects = ["mlir::stablehlo::StablehloDialect", "mlir::tt::ttir::TTIRDialect"]; } +def ConvertArithToStableHLO : Pass<"convert-arith-to-stablehlo", "::mlir::ModuleOp"> { +let summary = "Convert Arith Dialect to StableHLO dialect."; + let constructor = "createConvertArithToStableHLOPass()"; + let dependentDialects = ["mlir::stablehlo::StablehloDialect", "mlir::arith::ArithDialect"]; +} #endif def ConvertTosaToTTIR : Pass<"convert-tosa-to-ttir", "::mlir::ModuleOp"> { diff --git a/include/ttmlir/Dialect/TTIR/Pipelines/TTIRPipelines.h b/include/ttmlir/Dialect/TTIR/Pipelines/TTIRPipelines.h index 38457d032..33628a65f 100644 --- a/include/ttmlir/Dialect/TTIR/Pipelines/TTIRPipelines.h +++ b/include/ttmlir/Dialect/TTIR/Pipelines/TTIRPipelines.h @@ -20,6 +20,13 @@ struct StableHLOToTTIRPipelineOptions // Currently this pass fails if module has a name, so keeping the // optimization OFF by default until that issue is fixed on llvm side. llvm::cl::init(false)}; + Option arithDialectConversionsEnabled{ + *this, "enable-arith-to-stablehlo", + llvm::cl::desc("Enable Arith to StableHLO conversion pass."), + // Currently torch-mlir front-end does not convert ConstantOp for Arith + // Dialect to StableHLO. This pass makes those conversions until this + // is fixed in the upstream torch-mlir. + llvm::cl::init(true)}; }; void createStableHLOToTTIRPipeline( diff --git a/lib/Conversion/StableHLOToTTIR/ArithToStableHLOPass.cpp b/lib/Conversion/StableHLOToTTIR/ArithToStableHLOPass.cpp new file mode 100644 index 000000000..587847582 --- /dev/null +++ b/lib/Conversion/StableHLOToTTIR/ArithToStableHLOPass.cpp @@ -0,0 +1,94 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttmlir/Conversion/ArithToStableHLO/ArithToStableHLO.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "ttmlir/Dialect/TT/IR/TT.h" +#include "ttmlir/Dialect/TTIR/IR/TTIR.h" + +using namespace mlir; +using namespace mlir::tt; + +namespace mlir::tt::ttir { + +#define GEN_PASS_DEF_CONVERTARITHTOSTABLEHLO +#include "ttmlir/Conversion/Passes.h.inc" + +} // namespace mlir::tt::ttir + +namespace { + +class ArithToStableHLOConstantOpConversionPattern + : public OpConversionPattern { + + using OpConversionPattern::OpConversionPattern; + +public: + LogicalResult + matchAndRewrite(mlir::arith::ConstantOp srcOp, + mlir::arith::ConstantOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + rewriter.replaceOpWithNewOp(srcOp, + srcOp.getValue()); + return success(); + } +}; + +struct ConvertArithToStableHLOPass + : public ttir::impl::ConvertArithToStableHLOBase< + ConvertArithToStableHLOPass> { + void runOnOperation() final { + mlir::ConversionTarget target(getContext()); + + target.addIllegalDialect(); + target.addLegalDialect(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + + // For now keep the same type assuming StableHLO ops operate on builtin + // tensor. + TypeConverter typeConverter; + typeConverter.addConversion([](Type type) { + assert(isa(type) && + "only ranked tensor type supported"); + return type; + }); + RewritePatternSet patterns(&getContext()); + + // Convert Arith ConstantOp to StableHLO ConstantOp + patterns.add(typeConverter, + &getContext()); + + // Apply conversion. + if (failed( + applyFullConversion(getOperation(), target, std::move(patterns)))) { + signalPassFailure(); + return; + } + } +}; + +} // namespace + +namespace mlir::tt { + +std::unique_ptr> createConvertArithToStableHLOPass() { + return std::make_unique(); +} + +} // namespace mlir::tt diff --git a/lib/Conversion/StableHLOToTTIR/CMakeLists.txt b/lib/Conversion/StableHLOToTTIR/CMakeLists.txt index 7649d5340..5f0deda9b 100644 --- a/lib/Conversion/StableHLOToTTIR/CMakeLists.txt +++ b/lib/Conversion/StableHLOToTTIR/CMakeLists.txt @@ -6,6 +6,7 @@ include_directories(${PROJECT_SOURCE_DIR}/include) add_mlir_library(TTMLIRStableHLOToTTIR StableHLOToTTIRPatterns.cpp StableHLOToTTIRPass.cpp + ArithToStableHLOPass.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/ttmlir/Conversion/StableHLOToTTIR diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPass.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPass.cpp index fdb857cc1..bd97f1dfe 100644 --- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPass.cpp +++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPass.cpp @@ -5,6 +5,7 @@ #include "ttmlir/Conversion/StableHLOToTTIR/StableHLOToTTIR.h" #include +#include #include #include #include diff --git a/lib/Dialect/TTIR/Pipelines/TTIRPipelines.cpp b/lib/Dialect/TTIR/Pipelines/TTIRPipelines.cpp index b4f3b5ee0..08bd321fd 100644 --- a/lib/Dialect/TTIR/Pipelines/TTIRPipelines.cpp +++ b/lib/Dialect/TTIR/Pipelines/TTIRPipelines.cpp @@ -17,6 +17,9 @@ namespace mlir::tt::ttir { #ifdef TTMLIR_ENABLE_STABLEHLO void createStableHLOToTTIRPipeline( OpPassManager &pm, const StableHLOToTTIRPipelineOptions &options) { + if (options.arithDialectConversionsEnabled) { + pm.addPass(createConvertArithToStableHLOPass()); + } pm.addPass(createConvertStableHLOToTTIRPass()); if (options.removeDeadValuesEnabled) { pm.addPass(mlir::createRemoveDeadValuesPass()); diff --git a/test/ttmlir/Conversion/ArithToStableHLO/constant_op.mlir b/test/ttmlir/Conversion/ArithToStableHLO/constant_op.mlir new file mode 100644 index 000000000..0cbe0385d --- /dev/null +++ b/test/ttmlir/Conversion/ArithToStableHLO/constant_op.mlir @@ -0,0 +1,15 @@ +// REQUIRES: stablehlo +// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s +module @jit_constant attributes {} { + func.func public @test_splat() -> tensor<64xf32> { + %0 = arith.constant dense<0.3> : tensor<64xf32> + // CHECK: %[[C:.*]] = "ttir.constant"[[C:.*]] + return %0 : tensor<64xf32> + } + + func.func public @test_multiple() -> tensor<2x2xf32> { + %0 = arith.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32> + // CHECK: %[[C:.*]] = "ttir.constant"[[C:.*]] + return %0 : tensor<2x2xf32> + } +}