Skip to content

Commit

Permalink
Correct typos and apply code review changes.
Browse files Browse the repository at this point in the history
  • Loading branch information
uazizTT committed Nov 5, 2024
1 parent 63421bd commit bb3f739
Show file tree
Hide file tree
Showing 11 changed files with 19 additions and 65 deletions.
2 changes: 1 addition & 1 deletion include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ class TTIR_ElementwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
def TTIR_WhereOp : TTIR_DPSOp<"where"> {
let summary = "Where operation.";
let description = [{
Select an element from on_true or on_false based on pred.
Selects an element from on_true or on_false based on pred.
}];

let arguments = (ins AnyRankedTensor:$pred,
Expand Down
6 changes: 0 additions & 6 deletions include/ttmlir/Dialect/TTIR/Pipelines/TTIRPipelines.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,10 @@ namespace mlir::tt::ttir {
//
struct StableHLOToTTIRPipelineOptions
: public PassPipelineOptions<StableHLOToTTIRPipelineOptions> {
// Option to enable --remove-dead-values optimization pass.
Option<bool> removeDeadValuesEnabled{
*this, "enable-remove-dead-values",
llvm::cl::desc("Enable --remove-dead-values optimization pass."),
llvm::cl::init(true)};
Option<bool> sparseConstantPropogationEnabled{
*this, "enable-sparse-constant-propogation",
llvm::cl::desc("Enable --sccp optimization pass."),
llvm::cl::init(false)};
Option<bool> arithDialectConversionsEnabled{
*this, "enable-arith-to-stablehlo",
llvm::cl::desc("Enable Arith to StableHLO conversion pass."),
Expand All @@ -35,7 +30,6 @@ struct StableHLOToTTIRPipelineOptions
// This pass will convert stablehlo.composite ops into func.call ops so
// that the TTIR inliner pass may inline the ops.
llvm::cl::init(true)};
llvm::cl::desc("Enable --sccp optimization pass."), llvm::cl::init(true)};
};

void createStableHLOToTTIRPipeline(
Expand Down
26 changes: 0 additions & 26 deletions include/ttmlir/Dialect/TTIR/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -108,32 +108,6 @@ def TTIRAllocate: Pass<"ttir-allocate", "::mlir::ModuleOp"> {
}];
}

def TTIROptimizer: Pass<"ttir-optimizer", "::mlir::ModuleOp"> {
let summary = "Determine op configurations for maximum performance.";
let description = [{
Go through the ops, set sharding specs for each op based on sharding analysis,
by updating layout attribute of each op.
}];
let options = [
Option<"overrideOutputLayout", "override-output-layout",
"llvm::StringMap<LayoutOverrideParams>",
/*default=*/"llvm::StringMap<LayoutOverrideParams>()",
"Override output tensor layout for specific ops.">,
Option<"shardingPassEnabled", "sharding-pass-enabled",
"bool",
/*default=*/"false",
"Enable sharding pass.">,
Option<"reshardingEnabled", "resharding-enabled",
"bool",
/*default=*/"false",
"Resharding pass. Temp disabled till we support all types of shard specs.">,
Option<"maxLegalLayouts", "max-legal-layouts",
"int64_t",
/*default=*/"64",
"Override maximum number of legal layouts for grid analysis.">
];
}

def TTIRLoadSystemDesc: Pass<"ttir-load-system-desc", "::mlir::ModuleOp"> {
let summary = "Load system desc.";
let description = [{
Expand Down
2 changes: 1 addition & 1 deletion include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ class TTNN_ElementwiseBinaryOp<string mnemonic, list<Trait> traits = []> :
def TTNN_WhereOp : TTNN_NamedDPSOp<"where"> {
let summary = "Where op.";
let description = [{
Select operation.
Selects an element from on_true or on_false based on pred.
}];

let arguments = (ins AnyRankedTensor:$pred,
Expand Down
6 changes: 2 additions & 4 deletions lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -679,15 +679,13 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
// Other ops
//
patterns.add<DefaultOpConversionPattern<ttnn::SoftmaxOp>,
DefaultOpConversionPattern<ttnn::EmbeddingOp>>(typeConverter,
ctx);
DefaultOpConversionPattern<ttnn::EmbeddingOp>,
DefaultOpConversionPattern<ttnn::WhereOp>>(typeConverter, ctx);

// CCL ops
//
patterns.add<DefaultOpConversionPattern<ttnn::AllGatherOp>>(typeConverter,
ctx);
DefaultOpConversionPattern<ttnn::EmbeddingOp>,
DefaultOpConversionPattern<ttnn::WhereOp>>(typeConverter, ctx);
}

} // namespace mlir::tt
3 changes: 0 additions & 3 deletions lib/Dialect/TTIR/Pipelines/TTIRPipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,6 @@ void createStableHLOToTTIRPipeline(
if (options.removeDeadValuesEnabled) {
pm.addPass(mlir::createRemoveDeadValuesPass());
}
if (options.sparseConstantPropogationEnabled) {
pm.addPass(mlir::createSCCPPass());
}
}
#endif

Expand Down
4 changes: 1 addition & 3 deletions lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ void createTTNNPipelineAnalysisPasses(
options.memoryLayoutAnalysisEnabled;
optimizerOptions.memReconfigEnabled = options.memReconfigEnabled;
optimizerOptions.maxLegalLayouts = options.maxLegalLayouts;
pm.addPass(mlir::tt::ttir::createTTIROptimizer(optimizerOptions));
pm.addPass(mlir::tt::ttnn::createTTNNOptimizer(optimizerOptions));
}
}

Expand All @@ -62,8 +62,6 @@ void createTTNNPipelineLoweringPasses(
pm.addPass(createConvertTTIRToTTNNPass());
// Add pass to remove unused values.
pm.addPass(mlir::createRemoveDeadValuesPass());
// Dealloc pass for tensor memory deallocation after last use.
pm.addPass(createTTNNDeallocate());
}

void createTTNNPipelineLayoutDecompositionPass(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,19 @@
#include "tt/runtime/ttnn/operations/utils.h"

namespace tt::runtime::ttnn::operations::where {
static void
runWhereOp(::tt::target::ttnn::WhereOp const *op, ProgramTensorPool &tensorPool,
std::function<::ttnn::Tensor(
const ::ttnn::Tensor &,
const std::optional<std::variant<int, std::vector<int>>> &,
const bool, const std::optional<::tt::tt_metal::MemoryConfig> &,
const std::optional<::ttnn::DeviceComputeKernelConfig> &, float)>
ttnnOp) {

void run(const ::tt::target::ttnn::WhereOp *op, ProgramContext &context) {
ProgramTensorPool &tensorPool = context.tensorPool;
runWhereOp(op, tensorPool, ::ttnn::where);

::tt::tt_metal::MemoryConfig outputMemoryConfig =
utils::createMemoryConfig(op->out());
const ::ttnn::Tensor &in = tensorPool.at(op->in()->global_id());

::ttnn::Tensor out = ttnnOp(in, op->pred(), op->on_true(), op->on_false(),
outputMemoryConfig /* memory_config_arg */);
::ttnn::Tensor out =
::ttnn::where(in, op->pred(), op->on_true(), op->on_false(),
outputMemoryConfig /* memory_config_arg */);

tensorPool.insert_or_assign(op->out()->global_id(), out);
}

void run(const ::tt::target::ttnn::WhereOp *op, ProgramContext &context) {
ProgramTensorPool &tensorPool = context.tensorPool;
runWhereOp(op, tensorPool, ::ttnn::where);
}
} // namespace tt::runtime::ttnn::operations::where
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTNN_RUNTIME_WHERE_H
#define TTNN_RUNTIME_WHERE_H
#ifndef TTNN_RUNTIME_LIB_TTNN_OPERATIONS_ELTWISE_TERTIARY_WHERE_H
#define TTNN_RUNTIME_LIB_TTNN_OPERATIONS_ELTWISE_TERTIARY_WHERE_H

#include "tt/runtime/ttnn/types.h"
#include "ttmlir/Target/TTNN/program_generated.h"
Expand Down
4 changes: 2 additions & 2 deletions test/ttmlir/Conversion/StableHLOToTTIR/select_op.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ module @jit_eltwise_select attributes {} {
func.func public @test_select(%arg0: tensor<13x37xf32>, %arg1: tensor<13x37xf32>) -> tensor<13x37xf32> {
%0 = stablehlo.compare EQ, %arg0, %arg1 : (tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xi1>
%1 = stablehlo.select %0, %arg0, %arg1 : (tensor<13x37xi1>, tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xf32>
// CHECK: %[[C:.*]] = tensor.empty[[C:.*]]
// CHECK: %[[C:.*]] = "ttir.where"[[C:.*]]
// CHECK: [[VAL0:%[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<[0-9]+x[0-9]+x[0-9]+xf[0-9]+>]]
// CHECK: [[VAL1:%[0-9]+]] = "ttir.where"(%arg0, [[VAL0]]) <{operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device_tile, #any_device_tile]}> : ([[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]]
return %1 : tensor<13x37xf32>
}
}
4 changes: 2 additions & 2 deletions test/ttmlir/Dialect/TTNN/simple_where.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ module @jit_eltwise_where {
%1 = "ttir.eq"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xf32>
%2 = tensor.empty() : tensor<13x37xf32>
%3 = "ttir.where"(%1, %arg0, %arg1, %2) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xf32>
// CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
// CHECK: %[[C:.*]] = "ttnn.where"[[C:.*]]
// CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}})
// CHECK: %{{[0-9]+}} = "ttnn.where"(%{{[0-9]+}}, [[VAL0]])
return %3 : tensor<13x37xf32>
}
}

0 comments on commit bb3f739

Please sign in to comment.