forked from openxla/stablehlo
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Introduce a compiler pass to instrument modules (openxla#1841)
This PR adds a compilation pass to insert the recently introduced `interpreter.probe` instrumentation op (openxla#1784) into StableHLO programs. The pass is registered as `stablehlo-instrument` in the `stablehlo-opt` target (open to other naming suggestions). The pass will instrument all SSA values which are not the result of a `stablehlo.constant` operation (as there is no benefit in instrumenting unchanging values). The pass will attempt to use MLIR named location data to derive a suitable `probe_id`, and fallback to an increasing unsigned integer otherwise. For example, given the following simple function: ``` func.func @main(%arg0: tensor<1x2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { %0 = stablehlo.constant dense<[0, 0]> : tensor<1x2xi32> %1 = stablehlo.add %0, %arg0 : tensor<1x2xi32> loc("add1") %2 = stablehlo.add %1, %arg1 : tensor<1x2xi32> func.return %2 : tensor<1x2xi32> } ``` This pass would instrument it as follows: ``` func.func @main(%arg0: tensor<1x2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { %0 = stablehlo.constant dense<[0, 0]> : tensor<1x2xi32> %1 = stablehlo.add %0, %arg0 : tensor<1x2xi32> loc("add1") %2 = interpreter.probe %1, probe_id = "add1" : tensor<1x2xi32> %3 = stablehlo.add %2, %arg1 : tensor<1x2xi32> %4 = interpreter.probe %3, probe_id = "1" : tensor<1x2xi32> func.return %4 : tensor<1x2xi32> } ``` Where any operation without an explicit MLIR named location defaults to a unique, increasing `probe_id` integer value (i.e. in the example, the second add op). Note that as described above, constant operations are not instrumented. The pass walks all operations in the `mlir::ModuleOp`, so it will instrument values inside control flow structures as well (while, if, etc).
- Loading branch information
Showing
6 changed files
with
266 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
// RUN: stablehlo-opt --stablehlo-instrument-with-probe="useDebugInfo=true" --split-input-file --verify-diagnostics %s | FileCheck %s | ||
|
||
// CHECK-LABEL: func @instrument_basic_no_location | ||
func.func @instrument_basic_no_location(%arg0: tensor<1x2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { | ||
// CHECK: [[RESULT:%.*]] = interpreter.probe %0, probe_id = "probe1" : tensor<1x2xi32> | ||
// CHECK-NEXT: return [[RESULT]] | ||
%0 = stablehlo.add %arg0, %arg1 : tensor<1x2xi32> | ||
func.return %0 : tensor<1x2xi32> | ||
} | ||
|
||
// ----- | ||
|
||
// CHECK-LABEL: func @instrument_basic_location | ||
func.func @instrument_basic_location(%arg0: tensor<1x2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { | ||
// CHECK: [[RESULT:%.*]] = interpreter.probe %0, probe_id = "named_location.1" : tensor<1x2xi32> | ||
// CHECK-NEXT: return [[RESULT]] | ||
%0 = stablehlo.add %arg0, %arg1 : tensor<1x2xi32> loc("named_location") | ||
func.return %0 : tensor<1x2xi32> | ||
} | ||
|
||
// ----- | ||
|
||
// CHECK-LABEL: func @do_not_instrument_constant | ||
func.func @do_not_instrument_constant() -> tensor<1xi64> { | ||
// CHECK: [[RESULT:%.*]] = stablehlo.constant dense<0> : tensor<1xi64> | ||
// CHECK-NEXT: return [[RESULT]] | ||
%0 = stablehlo.constant dense<0> : tensor<1xi64> | ||
func.return %0 : tensor<1xi64> | ||
} | ||
|
||
// ----- | ||
|
||
// CHECK-LABEL: func @only_instrument_tensor_type | ||
func.func @only_instrument_tensor_type(%arg0: tensor<f32>) -> (!stablehlo.token, tuple<tensor<f32>>, tensor<f32>) { | ||
// CHECK: stablehlo.create_token | ||
// CHECK-NEXT: stablehlo.tuple | ||
// CHECK-NEXT: [[SUM:%.*]] = stablehlo.add | ||
// CHECK-NEXT: interpreter.probe [[SUM]] | ||
// CHECK-NEXT: return | ||
%0 = "stablehlo.create_token"() : () -> !stablehlo.token | ||
%1 = "stablehlo.tuple"(%arg0) : (tensor<f32>) -> tuple<tensor<f32>> | ||
%2 = stablehlo.add %arg0, %arg0 : tensor<f32> | ||
func.return %0, %1, %2 : !stablehlo.token, tuple<tensor<f32>>, tensor<f32> | ||
} | ||
|
||
// ----- | ||
|
||
// CHECK-LABEL: func @instrument_if | ||
func.func @instrument_if(%arg0: tensor<i1>, %arg1: tensor<2xi64>, %arg2: tensor<2xi64>) -> tensor<2xi64> { | ||
// CHECK: interpreter.probe {{.*}}, probe_id = "add.1" : tensor<2xi64> | ||
// CHECK: interpreter.probe {{.*}}, probe_id = "probe2" : tensor<2xi64> | ||
%result = "stablehlo.if"(%arg0) ({ | ||
%0 = stablehlo.constant dense<0> : tensor<2xi64> | ||
stablehlo.return %0 : tensor<2xi64> | ||
}, { | ||
%0 = stablehlo.add %arg1, %arg2 : tensor<2xi64> loc("add") | ||
stablehlo.return %0 : tensor<2xi64> | ||
}) : (tensor<i1>) -> (tensor<2xi64>) | ||
func.return %result : tensor<2xi64> | ||
} | ||
|
||
// ----- | ||
|
||
// CHECK-LABEL: func @instrument_loop | ||
func.func @instrument_loop() -> tensor<i64> { | ||
// Instrumented loop condition | ||
// CHECK: [[WHILE:%.*]]:2 = stablehlo.while | ||
// CHECK: [[COND:%.*]] = stablehlo.compare LT | ||
// CHECK-NEXT: [[PROBE1:%.*]] = interpreter.probe [[COND]] | ||
|
||
// Instrumented loop body | ||
// CHECK: interpreter.probe {{.*}}, probe_id = "add.2" : tensor<i64> | ||
// CHECK: interpreter.probe {{.*}}, probe_id = "add.3" : tensor<i64> | ||
|
||
// Instrumented loop return values | ||
// CHECK: interpreter.probe [[WHILE]]#1 | ||
// CHECK-NEXT: interpreter.probe [[WHILE]]#0 | ||
|
||
// int i = 0; | ||
// int sum = 0; | ||
// while (i < 2) { | ||
// sum += 1; | ||
// i += 1; | ||
// } | ||
%init_i = stablehlo.constant dense<0> : tensor<i64> | ||
%init_sum = stablehlo.constant dense<0> : tensor<i64> | ||
%one = stablehlo.constant dense<1> : tensor<i64> | ||
%two = stablehlo.constant dense<2> : tensor<i64> | ||
%results0, %results1 = stablehlo.while(%arg0 = %init_i, %arg1 = %init_sum) : tensor<i64>, tensor<i64> | ||
cond { | ||
%cond = stablehlo.compare LT, %arg0, %two : (tensor<i64>, tensor<i64>) -> tensor<i1> | ||
stablehlo.return %cond : tensor<i1> | ||
} do { | ||
%new_sum = stablehlo.add %arg1, %one : tensor<i64> loc("add") | ||
%new_i = stablehlo.add %arg0, %one : tensor<i64> loc("add") | ||
stablehlo.return %new_i, %new_sum : tensor<i64>, tensor<i64> | ||
} | ||
|
||
func.return %results1 : tensor<i64> | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
/* Copyright 2023 The StableHLO Authors. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
#include <string> | ||
|
||
#include "mlir/IR/BuiltinAttributes.h" | ||
#include "mlir/IR/BuiltinOps.h" | ||
#include "mlir/IR/Value.h" | ||
#include "mlir/Support/LLVM.h" | ||
#include "stablehlo/dialect/StablehloOps.h" | ||
#include "stablehlo/reference/InterpreterOps.h" | ||
#include "stablehlo/transforms/Passes.h" | ||
|
||
namespace mlir { | ||
namespace stablehlo { | ||
|
||
#define GEN_PASS_DEF_STABLEHLOINSTRUMENTWITHPROBEPASS | ||
#include "stablehlo/transforms/Passes.h.inc" | ||
|
||
namespace { | ||
|
||
class StablehloInstrumentWithProbePass | ||
: public impl::StablehloInstrumentWithProbePassBase< | ||
StablehloInstrumentWithProbePass> { | ||
public: | ||
StablehloInstrumentWithProbePass() | ||
: StablehloInstrumentWithProbePassBase< | ||
StablehloInstrumentWithProbePass>() {} | ||
StablehloInstrumentWithProbePass( | ||
const StablehloInstrumentWithProbePassOptions& opts) | ||
: StablehloInstrumentWithProbePassBase<StablehloInstrumentWithProbePass>( | ||
opts){}; | ||
void runOnOperation() override; | ||
|
||
private: | ||
// Create a uniquely identifying probe_id in the form `probe_id#` where | ||
// `probe_id` is either the MLIR location data (`NamedLoc(probe_id@...)` | ||
// followed by a . separator), or `probe` if debug information is not present | ||
// or used, and # is an increasing positive integer. | ||
std::string getLocationNameOrUniqueId(Location location, unsigned int id); | ||
|
||
// Instrument a specified operation by adding an `interpreter.probe` op for | ||
// each result produced by the operation. | ||
void probeValue(Value value, const std::string& probe_id, OpBuilder& builder); | ||
|
||
// Determine if a given operation is suitable for instrumentation. A suitable | ||
// operation is defined as any operation which is not a ConstantOp, and that | ||
// has at least 1 return value. | ||
bool shouldProbeOp(Operation& op) const; | ||
|
||
// Determine if a given value can be instrumented. Only values that are of | ||
// TensorType are suitable for instrumentation | ||
bool shouldProbeValue(Value value) const; | ||
}; | ||
|
||
std::string StablehloInstrumentWithProbePass::getLocationNameOrUniqueId( | ||
Location location, unsigned int id) { | ||
auto namedLocation = location.dyn_cast<NameLoc>(); | ||
std::string probeName = "probe"; | ||
|
||
if (useDebugInfoOption && namedLocation) | ||
// Append a '.' to the end of the MLIR location data to make it easy to | ||
// extract the location data from the unique ID. | ||
probeName = namedLocation.getName().strref().split('@').first.str() + '.'; | ||
|
||
return probeName + std::to_string(id); | ||
} | ||
|
||
void StablehloInstrumentWithProbePass::probeValue(Value value, | ||
const std::string& probe_id, | ||
OpBuilder& builder) { | ||
builder.setInsertionPointAfterValue(value); | ||
Value instrumentedValue = builder.create<interpreter::ProbeOp>( | ||
value.getLoc(), value, StringAttr::get(&getContext(), probe_id)); | ||
value.replaceAllUsesExcept(instrumentedValue, | ||
instrumentedValue.getDefiningOp()); | ||
} | ||
|
||
void StablehloInstrumentWithProbePass::runOnOperation() { | ||
ModuleOp module = getOperation(); | ||
OpBuilder builder(module); | ||
|
||
// Strictly increasing counter to uniquely identify probe operations when MLIR | ||
// location data is not available/used. | ||
unsigned int probeId = 0; | ||
|
||
module.walk([&](Operation* op) { | ||
if (!shouldProbeOp(*op)) return WalkResult::advance(); | ||
|
||
for (auto res : op->getResults()) { | ||
if (shouldProbeValue(res)) | ||
probeValue(res, getLocationNameOrUniqueId(op->getLoc(), ++probeId), | ||
builder); | ||
} | ||
|
||
return WalkResult::advance(); | ||
}); | ||
} | ||
|
||
bool StablehloInstrumentWithProbePass::shouldProbeOp(Operation& op) const { | ||
if (isa<ConstantOp>(op)) return false; | ||
|
||
// Operations that do not produce values should not be instrumented (ReturnOp, | ||
// TraceOp, etc.) | ||
if (op.getNumResults() == 0) return false; | ||
|
||
return true; | ||
} | ||
|
||
bool StablehloInstrumentWithProbePass::shouldProbeValue(Value value) const { | ||
return value.getType().isa<TensorType>(); | ||
} | ||
|
||
} // namespace | ||
} // namespace stablehlo | ||
} // namespace mlir |