Skip to content

Commit

Permalink
Adding dumping functionality for HloUnoptimizedSnapshot.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 703076882
  • Loading branch information
Aliia Khasanova authored and tensorflower-gardener committed Dec 5, 2024
1 parent 4c3aa84 commit d6ccc6a
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 0 deletions.
1 change: 1 addition & 0 deletions third_party/xla/xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,7 @@ cc_library(
local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
deps = [
":hlo_graph_dumper",
":hlo_proto_cc",
":hlo_proto_util",
"//xla:util",
"//xla:xla_proto_cc",
Expand Down
31 changes: 31 additions & 0 deletions third_party/xla/xla/service/dump.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/service/hlo.pb.h"
#include "xla/service/hlo_graph_dumper.h"
#include "xla/service/hlo_proto_util.h"
#include "xla/tsl/lib/io/zlib_compression_options.h"
Expand Down Expand Up @@ -884,6 +885,36 @@ void DumpHloSnapshotIfEnabled(const HloSnapshot& snapshot,
DumpToFileInDirImpl(filename, pb, canonical_opts);
}

void DumpHloUnoptimizedSnapshotIfEnabled(
const HloUnoptimizedSnapshot& hlo_snapshot, const DebugOptions& opts) {
CanonicalDebugOptions canonical_opts(opts);
std::string name = hlo_snapshot.hlo_module().name();
if (!canonical_opts.should_dump_module(name) ||
!canonical_opts.dump_unoptimized_snapshots) {
return;
}

if (hlo_snapshot.partitions_size() == 0) {
LOG(ERROR) << "Refusing to write unoptimized HLO snapshot proto for module "
<< name << ": no partitions input found.";
return;
}
int64_t execution_count;
{
static absl::Mutex mu(absl::kConstInit);
static auto& module_id_to_execution_count ABSL_GUARDED_BY(mu) =
*new absl::flat_hash_map<int64_t, int64_t>();
absl::MutexLock lock(&mu);
execution_count =
module_id_to_execution_count[hlo_snapshot.hlo_module().id()]++;
}
std::string filename = FilenameFor(
hlo_snapshot.hlo_module().id(), hlo_snapshot.hlo_module().name(), "",
absl::StrFormat("execution_%04d.hlo_unoptimized_snapshot",
execution_count));
DumpProtobufToFile(hlo_snapshot, opts, filename, nullptr);
}

void DumpHloModuleMetadataIfEnabled(const std::vector<HloModule*>& modules) {
absl::flat_hash_set<int64_t> dumped_module_ids;
for (const HloModule* module : modules) {
Expand Down
5 changes: 5 additions & 0 deletions third_party/xla/xla/service/dump.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,11 @@ void DumpHloSnapshotIfEnabled(const HloModule& module,
void DumpHloSnapshotIfEnabled(const HloSnapshot& snapshot,
const DebugOptions& opts);

// Dumps the given HloUnoptimisedSnapshot to the module's xla_dump_dir, if this
// is enabled.
void DumpHloUnoptimizedSnapshotIfEnabled(
const HloUnoptimizedSnapshot& hlo_snapshot, const DebugOptions& opts);

void DumpHloModuleMetadataIfEnabled(const std::vector<HloModule*>& modules);

// Returns true if we should dump data for an HloModule. This is useful if you
Expand Down
29 changes: 29 additions & 0 deletions third_party/xla/xla/service/dump_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -181,5 +181,34 @@ TEST(DumpTest, DumpFdoProfileToFileWhenEnabled) {
EXPECT_TRUE(absl::StrContains(data, fdo_profile));
}

TEST(DumpTest, DumpHloUnoptimizedSnapshot) {
HloUnoptimizedSnapshot hlo_snapshot;
HloModuleProto module;
module.set_name("hello");
*hlo_snapshot.mutable_hlo_module() = module;
*hlo_snapshot.add_partitions() = HloInputs();

HloModuleConfig config;
DebugOptions options = config.debug_options();

options.set_xla_dump_to(tsl::testing::TmpDir());
options.set_xla_dump_hlo_as_text(true);
options.set_xla_gpu_dump_hlo_unoptimized_snapshots(true);
config.set_debug_options(options);

DumpHloUnoptimizedSnapshotIfEnabled(hlo_snapshot, options);

std::vector<std::string> matches;
std::string pattern_filename =
tsl::io::JoinPath(tsl::testing::TmpDir(), "*hlo_unoptimized_snapshot*");
TF_ASSERT_OK(
tsl::Env::Default()->GetMatchingPaths(pattern_filename, &matches));
EXPECT_THAT(matches, Not(IsEmpty()));

HloUnoptimizedSnapshot hlo_snapshot_loaded;
TF_ASSERT_OK(tsl::ReadTextProto(tsl::Env::Default(), matches.front(),
&hlo_snapshot_loaded));
EXPECT_EQ(hlo_snapshot_loaded.hlo_module().name(), module.name());
}
} // namespace
} // namespace xla

0 comments on commit d6ccc6a

Please sign in to comment.