Skip to content

Commit

Permalink
Introduce a compiler pass to instrument modules (openxla#1841)
Browse files Browse the repository at this point in the history
This PR adds a compilation pass to insert the recently introduced
`interpreter.probe` instrumentation op (openxla#1784) into StableHLO programs.
The pass is registered as `stablehlo-instrument-with-probe` in the
`stablehlo-opt` target (open to other naming suggestions).

The pass will instrument all SSA values which are not the result of a
`stablehlo.constant` operation (as there is no benefit in instrumenting
unchanging values). The pass will attempt to use MLIR named location
data to derive a suitable `probe_id`, and fall back to an increasing
unsigned integer otherwise. For example, given the following simple
function:

```
func.func @main(%arg0: tensor<1x2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
  %0 = stablehlo.constant dense<[0, 0]> : tensor<1x2xi32>
  %1 = stablehlo.add %0, %arg0 : tensor<1x2xi32> loc("add1")
  %2 = stablehlo.add %1, %arg1 : tensor<1x2xi32>
  func.return %2 : tensor<1x2xi32>
}
```

With the `useDebugInfo` option enabled, this pass would instrument it as
follows:

```
func.func @main(%arg0: tensor<1x2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
  %0 = stablehlo.constant dense<[0, 0]> : tensor<1x2xi32>
  %1 = stablehlo.add %0, %arg0 : tensor<1x2xi32> loc("add1")
  %2 = interpreter.probe %1, probe_id = "add1.1" : tensor<1x2xi32>
  %3 = stablehlo.add %2, %arg1 : tensor<1x2xi32>
  %4 = interpreter.probe %3, probe_id = "probe2" : tensor<1x2xi32>
  func.return %4 : tensor<1x2xi32>
}
```

Every `probe_id` ends in a unique, increasing integer. An operation with
an MLIR named location is identified as `<name>.<N>` (the first add op in
the example), while any operation without an explicit named location
defaults to `probe<N>` (the second add op). Note that as described above,
constant operations are not instrumented. The pass walks all operations in
the `mlir::ModuleOp`, so it will instrument values inside control flow
structures as well (while, if, etc).
  • Loading branch information
penagos authored Dec 1, 2023
1 parent ad6a900 commit afd8f5b
Show file tree
Hide file tree
Showing 6 changed files with 266 additions and 0 deletions.
2 changes: 2 additions & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -881,6 +881,7 @@ cc_library(
srcs = [
"stablehlo/transforms/PassPipelines.cpp",
"stablehlo/transforms/StablehloCanonicalizeDynamism.cpp",
"stablehlo/transforms/StablehloInstrumentWithProbe.cpp",
"stablehlo/transforms/StablehloLegalizeToVhlo.cpp",
"stablehlo/transforms/StablehloRefineShapes.cpp",
"stablehlo/transforms/VhloLegalizeToStablehlo.cpp",
Expand All @@ -894,6 +895,7 @@ cc_library(
deps = [
":base",
":chlo_ops",
":interpreter_ops",
":stablehlo_ops",
":stablehlo_ops_inc_gen",
":stablehlo_pass_inc_gen",
Expand Down
100 changes: 100 additions & 0 deletions stablehlo/tests/stablehlo_probe_instrumentation.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// RUN: stablehlo-opt --stablehlo-instrument-with-probe="useDebugInfo=true" --split-input-file --verify-diagnostics %s | FileCheck %s

// Baseline case: the add carries no named location, so its tensor result is
// probed under the counter-based id `probe1`, and the return consumes the
// probe result instead of the original value.
// CHECK-LABEL: func @instrument_basic_no_location
func.func @instrument_basic_no_location(%arg0: tensor<1x2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
// CHECK: [[RESULT:%.*]] = interpreter.probe %0, probe_id = "probe1" : tensor<1x2xi32>
// CHECK-NEXT: return [[RESULT]]
%0 = stablehlo.add %arg0, %arg1 : tensor<1x2xi32>
func.return %0 : tensor<1x2xi32>
}

// -----

// With `useDebugInfo=true` (see the RUN line) a named location supplies the
// id prefix: the location name, a '.' separator, then the running counter.
// CHECK-LABEL: func @instrument_basic_location
func.func @instrument_basic_location(%arg0: tensor<1x2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
// CHECK: [[RESULT:%.*]] = interpreter.probe %0, probe_id = "named_location.1" : tensor<1x2xi32>
// CHECK-NEXT: return [[RESULT]]
%0 = stablehlo.add %arg0, %arg1 : tensor<1x2xi32> loc("named_location")
func.return %0 : tensor<1x2xi32>
}

// -----

// Constants are never instrumented: the constant must feed the return
// directly, with no probe op in between.
// CHECK-LABEL: func @do_not_instrument_constant
func.func @do_not_instrument_constant() -> tensor<1xi64> {
// CHECK: [[RESULT:%.*]] = stablehlo.constant dense<0> : tensor<1xi64>
// CHECK-NEXT: return [[RESULT]]
%0 = stablehlo.constant dense<0> : tensor<1xi64>
func.return %0 : tensor<1xi64>
}

// -----

// Only tensor-typed results are probed: the token and tuple results below
// are left untouched, while the tensor result of the add gets a probe.
// CHECK-LABEL: func @only_instrument_tensor_type
func.func @only_instrument_tensor_type(%arg0: tensor<f32>) -> (!stablehlo.token, tuple<tensor<f32>>, tensor<f32>) {
// CHECK: stablehlo.create_token
// CHECK-NEXT: stablehlo.tuple
// CHECK-NEXT: [[SUM:%.*]] = stablehlo.add
// CHECK-NEXT: interpreter.probe [[SUM]]
// CHECK-NEXT: return
%0 = "stablehlo.create_token"() : () -> !stablehlo.token
%1 = "stablehlo.tuple"(%arg0) : (tensor<f32>) -> tuple<tensor<f32>>
%2 = stablehlo.add %arg0, %arg0 : tensor<f32>
func.return %0, %1, %2 : !stablehlo.token, tuple<tensor<f32>>, tensor<f32>
}

// -----

// Ops nested in branch regions are instrumented too. The constant in the
// first branch is skipped, the named add in the second branch receives the
// first id (`add.1`), and the result of the enclosing `stablehlo.if` is probed
// afterwards with the next counter value (`probe2`).
// CHECK-LABEL: func @instrument_if
func.func @instrument_if(%arg0: tensor<i1>, %arg1: tensor<2xi64>, %arg2: tensor<2xi64>) -> tensor<2xi64> {
// CHECK: interpreter.probe {{.*}}, probe_id = "add.1" : tensor<2xi64>
// CHECK: interpreter.probe {{.*}}, probe_id = "probe2" : tensor<2xi64>
%result = "stablehlo.if"(%arg0) ({
%0 = stablehlo.constant dense<0> : tensor<2xi64>
stablehlo.return %0 : tensor<2xi64>
}, {
%0 = stablehlo.add %arg1, %arg2 : tensor<2xi64> loc("add")
stablehlo.return %0 : tensor<2xi64>
}) : (tensor<i1>) -> (tensor<2xi64>)
func.return %result : tensor<2xi64>
}

// -----

// Ops inside both regions of a `stablehlo.while` are instrumented, as are the
// results of the while op itself. The two adds share the location name "add";
// the running counter suffix (`add.2`, `add.3`) keeps their ids distinct.
// CHECK-LABEL: func @instrument_loop
func.func @instrument_loop() -> tensor<i64> {
// Instrumented loop condition
// CHECK: [[WHILE:%.*]]:2 = stablehlo.while
// CHECK: [[COND:%.*]] = stablehlo.compare LT
// CHECK-NEXT: [[PROBE1:%.*]] = interpreter.probe [[COND]]

// Instrumented loop body
// CHECK: interpreter.probe {{.*}}, probe_id = "add.2" : tensor<i64>
// CHECK: interpreter.probe {{.*}}, probe_id = "add.3" : tensor<i64>

// Instrumented loop return values
// Each probe is inserted directly after the while op, so the probe for
// result #1 (created last) ends up ahead of the probe for result #0.
// CHECK: interpreter.probe [[WHILE]]#1
// CHECK-NEXT: interpreter.probe [[WHILE]]#0

// int i = 0;
// int sum = 0;
// while (i < 2) {
// sum += 1;
// i += 1;
// }
%init_i = stablehlo.constant dense<0> : tensor<i64>
%init_sum = stablehlo.constant dense<0> : tensor<i64>
%one = stablehlo.constant dense<1> : tensor<i64>
%two = stablehlo.constant dense<2> : tensor<i64>
%results0, %results1 = stablehlo.while(%arg0 = %init_i, %arg1 = %init_sum) : tensor<i64>, tensor<i64>
cond {
%cond = stablehlo.compare LT, %arg0, %two : (tensor<i64>, tensor<i64>) -> tensor<i1>
stablehlo.return %cond : tensor<i1>
} do {
%new_sum = stablehlo.add %arg1, %one : tensor<i64> loc("add")
%new_i = stablehlo.add %arg0, %one : tensor<i64> loc("add")
stablehlo.return %new_i, %new_sum : tensor<i64>, tensor<i64>
}

func.return %results1 : tensor<i64>
}
1 change: 1 addition & 0 deletions stablehlo/transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ add_mlir_dialect_library(StablehloPasses
PARTIAL_SOURCES_INTENDED
PassPipelines.cpp
StablehloCanonicalizeDynamism.cpp
StablehloInstrumentWithProbe.cpp
StablehloLegalizeToVhlo.cpp
StablehloRefineShapes.cpp
VhloLegalizeToStablehlo.cpp
Expand Down
1 change: 1 addition & 0 deletions stablehlo/transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ namespace mlir {
namespace stablehlo {

#define GEN_PASS_DECL_STABLEHLOCANONICALIZEDYNAMISMPASS
#define GEN_PASS_DECL_STABLEHLOINSTRUMENTWITHPROBEPASS
#define GEN_PASS_DECL_STABLEHLOLEGALIZETOVHLOPASS
#define GEN_PASS_DECL_STABLEHLOREFINESHAPESPASS
#define GEN_PASS_DECL_VHLOLEGALIZETOSTABLEHLOPASS
Expand Down
34 changes: 34 additions & 0 deletions stablehlo/transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,40 @@ def StablehloCanonicalizeDynamismPass : Pass<"stablehlo-canonicalize-dynamism",
}];
}

// Declarative definition of the probe-instrumentation pass. It is anchored on
// `ModuleOp` and registered under the `stablehlo-instrument-with-probe` flag.
def StablehloInstrumentWithProbePass : Pass<"stablehlo-instrument-with-probe", "ModuleOp"> {
let summary = "Inserts probe instrumentation instructions in a StableHLO "
"program.";
let options = [
// C++ member `useDebugInfoOption`, exposed on the command line as
// `useDebugInfo`; disabled by default.
Option<"useDebugInfoOption", "useDebugInfo", "bool", /*default=*/"false",
"Whether or not to use location debug data as `probe_id` values.">,
];

// The pass creates `interpreter.probe` ops, so the interpreter dialect is
// declared as a dependency.
let dependentDialects = ["mlir::stablehlo::interpreter::InterpreterDialect"];
let description = [{
Walks through a StableHLO program and inserts a probe instrumentation
operation after each suitable operation (see below for how a suitable
operation is defined). Instrumentation is used to extract intermediate
tensor values from the StableHLO reference interpreter for later comparison
with other runtimes.

All operations are considered suitable for instrumentation, except constant
ops, ops which do not have any tensor return values (i.e. an op that
produces a tuple or a token or no return values will not be instrumented).
Suitable operations will be instrumented regardless of their level of
nesting. That is, operations inside loop/branch regions will also be
instrumented.

Instrumented operations will have their return values written to disk using
the NumPy data format as they are executed. If the `useDebugInfo` pass
option is enabled, location debug information will be used when available to
uniquely identify instrumented tensor values (i.e. the pass will extract
`probe_id` from `NamedLoc(probe_id@<...>)` and use the format `probe_id`.#).
Otherwise, instrumented values will be referred to in the increasing
sequence: `probe1`, `probe2`, ... See `interpreter.probe` for additional
information on how data is serialized.
}];
}

def StablehloLegalizeToVhloPass : Pass<"stablehlo-legalize-to-vhlo", "ModuleOp"> {
let summary = "Legalize StableHLO to VHLO.";
let dependentDialects = ["mlir::vhlo::VhloDialect"];
Expand Down
128 changes: 128 additions & 0 deletions stablehlo/transforms/StablehloInstrumentWithProbe.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
/* Copyright 2023 The StableHLO Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <string>

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "stablehlo/reference/InterpreterOps.h"
#include "stablehlo/transforms/Passes.h"

namespace mlir {
namespace stablehlo {

#define GEN_PASS_DEF_STABLEHLOINSTRUMENTWITHPROBEPASS
#include "stablehlo/transforms/Passes.h.inc"

namespace {

// Module-level pass that inserts an `interpreter.probe` op after every
// suitable tensor-producing operation, so that intermediate tensor values can
// be extracted when the program is run.
class StablehloInstrumentWithProbePass
    : public impl::StablehloInstrumentWithProbePassBase<
          StablehloInstrumentWithProbePass> {
 public:
  StablehloInstrumentWithProbePass()
      : StablehloInstrumentWithProbePassBase<
            StablehloInstrumentWithProbePass>() {}
  // Constructs the pass from an explicit set of pass options.
  explicit StablehloInstrumentWithProbePass(
      const StablehloInstrumentWithProbePassOptions& opts)
      : StablehloInstrumentWithProbePassBase<StablehloInstrumentWithProbePass>(
            opts) {}
  void runOnOperation() override;

 private:
  // Builds the `probe_id` for an instrumented value. When the `useDebugInfo`
  // option is set and `location` is a `NameLoc`, the result is `<name>.<id>`,
  // where `<name>` is the location name up to (but excluding) the first '@'.
  // In every other case the result is `probe<id>`.
  std::string getLocationNameOrUniqueId(Location location, unsigned int id);

  // Instruments a single value: inserts an `interpreter.probe` op tagged with
  // `probe_id` right after the definition of `value`, and redirects every
  // other use of `value` to the result of that probe.
  void probeValue(Value value, const std::string& probe_id, OpBuilder& builder);

  // Returns true when `op` is a candidate for instrumentation, i.e. it is not
  // a ConstantOp and it produces at least one result. Whether each individual
  // result is actually probed is decided by `shouldProbeValue`.
  bool shouldProbeOp(Operation& op) const;

  // Returns true when `value` can be instrumented. Only values of TensorType
  // qualify.
  bool shouldProbeValue(Value value) const;
};

// Produces the identifier attached to a probe. The numeric `id` always forms
// the tail of the identifier; the head is either the literal `probe` or, when
// debug info is enabled and the location is named, the location name followed
// by a '.' separator.
std::string StablehloInstrumentWithProbePass::getLocationNameOrUniqueId(
    Location location, unsigned int id) {
  const std::string suffix = std::to_string(id);

  // Without the debug-info option, or without a named location to draw from,
  // use the generic counter-based form.
  auto nameLoc = location.dyn_cast<NameLoc>();
  if (!useDebugInfoOption || !nameLoc) return "probe" + suffix;

  // Keep only the text ahead of the first '@' in the location name. The '.'
  // separator makes it easy to recover that name from the final identifier.
  StringRef baseName = nameLoc.getName().strref().split('@').first;
  return baseName.str() + '.' + suffix;
}

// Inserts an `interpreter.probe` op immediately after the definition of
// `value`, labelled with `probe_id`, and makes every other user of `value`
// consume the probe's result instead.
void StablehloInstrumentWithProbePass::probeValue(Value value,
                                                  const std::string& probe_id,
                                                  OpBuilder& builder) {
  builder.setInsertionPointAfterValue(value);

  auto idAttr = StringAttr::get(&getContext(), probe_id);
  auto probeOp =
      builder.create<interpreter::ProbeOp>(value.getLoc(), value, idAttr);

  // The probe itself must keep reading the original value, so it is the one
  // user excluded from the replacement.
  value.replaceAllUsesExcept(probeOp.getResult(), probeOp.getOperation());
}

void StablehloInstrumentWithProbePass::runOnOperation() {
ModuleOp module = getOperation();
OpBuilder builder(module);

// Strictly increasing counter to uniquely identify probe operations when MLIR
// location data is not available/used.
unsigned int probeId = 0;

module.walk([&](Operation* op) {
if (!shouldProbeOp(*op)) return WalkResult::advance();

for (auto res : op->getResults()) {
if (shouldProbeValue(res))
probeValue(res, getLocationNameOrUniqueId(op->getLoc(), ++probeId),
builder);
}

return WalkResult::advance();
});
}

// Decides whether an operation is a candidate for instrumentation. Ops that
// yield no values (ReturnOp, TraceOp, etc.) have nothing to probe, and
// ConstantOp results are deliberately left alone.
bool StablehloInstrumentWithProbePass::shouldProbeOp(Operation& op) const {
  const bool producesValues = op.getNumResults() != 0;
  return producesValues && !isa<ConstantOp>(op);
}

// A value can be instrumented only when its type is a TensorType; other
// result types (e.g. tokens or tuples) are skipped.
bool StablehloInstrumentWithProbePass::shouldProbeValue(Value value) const {
  Type valueType = value.getType();
  return isa<TensorType>(valueType);
}

} // namespace
} // namespace stablehlo
} // namespace mlir

0 comments on commit afd8f5b

Please sign in to comment.