diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td index 16d6b93482..0b3d9c2fe0 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td @@ -186,7 +186,7 @@ class TTIR_ElementwiseUnaryOp traits = []> : def TTIR_WhereOp : TTIR_DPSOp<"where"> { let summary = "Where operation."; let description = [{ - Select an element from on_true or on_false based on pred. + Selects an element from on_true or on_false based on pred. }]; let arguments = (ins AnyRankedTensor:$pred, diff --git a/include/ttmlir/Dialect/TTIR/Pipelines/TTIRPipelines.h b/include/ttmlir/Dialect/TTIR/Pipelines/TTIRPipelines.h index 58864f5a2e..f922c2501e 100644 --- a/include/ttmlir/Dialect/TTIR/Pipelines/TTIRPipelines.h +++ b/include/ttmlir/Dialect/TTIR/Pipelines/TTIRPipelines.h @@ -13,15 +13,10 @@ namespace mlir::tt::ttir { // struct StableHLOToTTIRPipelineOptions : public PassPipelineOptions { - // Option to enable --remove-dead-values optimization pass. Option removeDeadValuesEnabled{ *this, "enable-remove-dead-values", llvm::cl::desc("Enable --remove-dead-values optimization pass."), llvm::cl::init(true)}; - Option sparseConstantPropogationEnabled{ - *this, "enable-sparse-constant-propogation", - llvm::cl::desc("Enable --sccp optimization pass."), - llvm::cl::init(false)}; Option arithDialectConversionsEnabled{ *this, "enable-arith-to-stablehlo", llvm::cl::desc("Enable Arith to StableHLO conversion pass."), @@ -35,7 +30,6 @@ struct StableHLOToTTIRPipelineOptions // This pass will convert stablehlo.composite ops into func.call ops so // that the TTIR inliner pass may inline the ops. 
llvm::cl::init(true)}; - llvm::cl::desc("Enable --sccp optimization pass."), llvm::cl::init(true)}; }; void createStableHLOToTTIRPipeline( diff --git a/include/ttmlir/Dialect/TTIR/Transforms/Passes.td b/include/ttmlir/Dialect/TTIR/Transforms/Passes.td index 3e5298440f..1cee4cbb5c 100644 --- a/include/ttmlir/Dialect/TTIR/Transforms/Passes.td +++ b/include/ttmlir/Dialect/TTIR/Transforms/Passes.td @@ -108,32 +108,6 @@ def TTIRAllocate: Pass<"ttir-allocate", "::mlir::ModuleOp"> { }]; } -def TTIROptimizer: Pass<"ttir-optimizer", "::mlir::ModuleOp"> { - let summary = "Determine op configurations for maximum performance."; - let description = [{ - Go through the ops, set sharding specs for each op based on sharding analysis, - by updating layout attribute of each op. - }]; - let options = [ - Option<"overrideOutputLayout", "override-output-layout", - "llvm::StringMap", - /*default=*/"llvm::StringMap()", - "Override output tensor layout for specific ops.">, - Option<"shardingPassEnabled", "sharding-pass-enabled", - "bool", - /*default=*/"false", - "Enable sharding pass.">, - Option<"reshardingEnabled", "resharding-enabled", - "bool", - /*default=*/"false", - "Resharding pass. Temp disabled till we support all types of shard specs.">, - Option<"maxLegalLayouts", "max-legal-layouts", - "int64_t", - /*default=*/"64", - "Override maximum number of legal layouts for grid analysis."> - ]; -} - def TTIRLoadSystemDesc: Pass<"ttir-load-system-desc", "::mlir::ModuleOp"> { let summary = "Load system desc."; let description = [{ diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td index 80df4f6007..4866b9ab21 100644 --- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td @@ -149,7 +149,7 @@ class TTNN_ElementwiseBinaryOp traits = []> : def TTNN_WhereOp : TTNN_NamedDPSOp<"where"> { let summary = "Where op."; let description = [{ - Select operation. 
+ Selects an element from on_true or on_false based on pred. }]; let arguments = (ins AnyRankedTensor:$pred, diff --git a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp index ba499ddf26..9e3bda0997 100644 --- a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp +++ b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp @@ -679,15 +679,13 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx, // Other ops // patterns.add, - DefaultOpConversionPattern>(typeConverter, - ctx); + DefaultOpConversionPattern, + DefaultOpConversionPattern>(typeConverter, ctx); // CCL ops // patterns.add>(typeConverter, ctx); - DefaultOpConversionPattern, - DefaultOpConversionPattern>(typeConverter, ctx); } } // namespace mlir::tt diff --git a/lib/Dialect/TTIR/Pipelines/TTIRPipelines.cpp b/lib/Dialect/TTIR/Pipelines/TTIRPipelines.cpp index 5cd2d05888..a092a36e1c 100644 --- a/lib/Dialect/TTIR/Pipelines/TTIRPipelines.cpp +++ b/lib/Dialect/TTIR/Pipelines/TTIRPipelines.cpp @@ -31,9 +31,6 @@ void createStableHLOToTTIRPipeline( if (options.removeDeadValuesEnabled) { pm.addPass(mlir::createRemoveDeadValuesPass()); } - if (options.sparseConstantPropogationEnabled) { - pm.addPass(mlir::createSCCPPass()); - } } #endif diff --git a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp index a7e065cc19..7f3baaeaf7 100644 --- a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp +++ b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp @@ -52,7 +52,7 @@ void createTTNNPipelineAnalysisPasses( options.memoryLayoutAnalysisEnabled; optimizerOptions.memReconfigEnabled = options.memReconfigEnabled; optimizerOptions.maxLegalLayouts = options.maxLegalLayouts; - pm.addPass(mlir::tt::ttir::createTTIROptimizer(optimizerOptions)); + pm.addPass(mlir::tt::ttnn::createTTNNOptimizer(optimizerOptions)); } } @@ -62,8 +62,6 @@ void createTTNNPipelineLoweringPasses( pm.addPass(createConvertTTIRToTTNNPass()); // Add pass to remove unused values. 
pm.addPass(mlir::createRemoveDeadValuesPass()); - // Dealloc pass for tensor memory deallocation after last use. - pm.addPass(createTTNNDeallocate()); } void createTTNNPipelineLayoutDecompositionPass( diff --git a/runtime/lib/ttnn/operations/eltwise/where/where.cpp b/runtime/lib/ttnn/operations/eltwise/tertiary/where.cpp similarity index 54% rename from runtime/lib/ttnn/operations/eltwise/where/where.cpp rename to runtime/lib/ttnn/operations/eltwise/tertiary/where.cpp index e4d20f00ab..40b7fe38cc 100644 --- a/runtime/lib/ttnn/operations/eltwise/where/where.cpp +++ b/runtime/lib/ttnn/operations/eltwise/tertiary/where.cpp @@ -7,26 +7,18 @@ #include "tt/runtime/ttnn/operations/utils.h" namespace tt::runtime::ttnn::operations::where { -static void -runWhereOp(::tt::target::ttnn::WhereOp const *op, ProgramTensorPool &tensorPool, - std::function<::ttnn::Tensor( - const ::ttnn::Tensor &, - const std::optional>> &, - const bool, const std::optional<::tt::tt_metal::MemoryConfig> &, - const std::optional<::ttnn::DeviceComputeKernelConfig> &, float)> - ttnnOp) { + +void run(const ::tt::target::ttnn::WhereOp *op, ProgramContext &context) { + ProgramTensorPool &tensorPool = context.tensorPool; + ::tt::tt_metal::MemoryConfig outputMemoryConfig = utils::createMemoryConfig(op->out()); const ::ttnn::Tensor &in = tensorPool.at(op->in()->global_id()); - ::ttnn::Tensor out = ttnnOp(in, op->pred(), op->on_true(), op->on_false(), - outputMemoryConfig /* memory_config_arg */); + ::ttnn::Tensor out = + ::ttnn::where(in, op->pred(), op->on_true(), op->on_false(), + outputMemoryConfig /* memory_config_arg */); tensorPool.insert_or_assign(op->out()->global_id(), out); } - -void run(const ::tt::target::ttnn::WhereOp *op, ProgramContext &context) { - ProgramTensorPool &tensorPool = context.tensorPool; - runWhereOp(op, tensorPool, ::ttnn::where); -} } // namespace tt::runtime::ttnn::operations::where diff --git 
a/runtime/lib/ttnn/operations/eltwise/where/where.h b/runtime/lib/ttnn/operations/eltwise/tertiary/where.h similarity index 73% rename from runtime/lib/ttnn/operations/eltwise/where/where.h rename to runtime/lib/ttnn/operations/eltwise/tertiary/where.h index 9ead3a1a66..a51956a8a5 100644 --- a/runtime/lib/ttnn/operations/eltwise/where/where.h +++ b/runtime/lib/ttnn/operations/eltwise/tertiary/where.h @@ -2,8 +2,8 @@ // // SPDX-License-Identifier: Apache-2.0 -#ifndef TTNN_RUNTIME_WHERE_H -#define TTNN_RUNTIME_WHERE_H +#ifndef TTNN_RUNTIME_LIB_TTNN_OPERATIONS_ELTWISE_TERTIARY_WHERE_H +#define TTNN_RUNTIME_LIB_TTNN_OPERATIONS_ELTWISE_TERTIARY_WHERE_H #include "tt/runtime/ttnn/types.h" #include "ttmlir/Target/TTNN/program_generated.h" diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/select_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/select_op.mlir index b6d088f9f4..2d1c0afd48 100644 --- a/test/ttmlir/Conversion/StableHLOToTTIR/select_op.mlir +++ b/test/ttmlir/Conversion/StableHLOToTTIR/select_op.mlir @@ -5,8 +5,8 @@ module @jit_eltwise_select attributes {} { func.func public @test_select(%arg0: tensor<13x37xf32>, %arg1: tensor<13x37xf32>) -> tensor<13x37xf32> { %0 = stablehlo.compare EQ, %arg0, %arg1 : (tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xi1> %1 = stablehlo.select %0, %arg0, %arg1 : (tensor<13x37xi1>, tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xf32> - // CHECK: %[[C:.*]] = tensor.empty[[C:.*]] - // CHECK: %[[C:.*]] = "ttir.where"[[C:.*]] + // CHECK: [[VAL0:%[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<[0-9]+x[0-9]+x[0-9]+xf[0-9]+>]] + // CHECK: [[VAL1:%[0-9]+]] = "ttir.where"(%arg0, [[VAL0]]) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile]}> : ([[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] return %1 : tensor<13x37xf32> } } diff --git a/test/ttmlir/Dialect/TTNN/simple_where.mlir b/test/ttmlir/Dialect/TTNN/simple_where.mlir index 9df12c77f5..778aba3a77 100644 --- 
a/test/ttmlir/Dialect/TTNN/simple_where.mlir +++ b/test/ttmlir/Dialect/TTNN/simple_where.mlir @@ -6,8 +6,8 @@ module @jit_eltwise_where { %1 = "ttir.eq"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xf32> %2 = tensor.empty() : tensor<13x37xf32> %3 = "ttir.where"(%1, %arg0, %arg1, %2) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xf32> - // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] - // CHECK: %[[C:.*]] = "ttnn.where"[[C:.*]] + // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) + // CHECK: %{{[0-9]+}} = "ttnn.where"(%{{[0-9]+}}, [[VAL0]]) return %3 : tensor<13x37xf32> } }