Skip to content

Commit

Permalink
Enable sampling for inference profiles and expose the sampled stats in the inference profile tool.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 703734086
  • Loading branch information
cliveverghese authored and tensorflower-gardener committed Dec 7, 2024
1 parent 90021f3 commit 7cd13b1
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 6 deletions.
3 changes: 3 additions & 0 deletions tensorflow/core/profiler/convert/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,7 @@ cc_library(
"//tensorflow/core/profiler/utils:xplane_schema",
"//tensorflow/core/profiler/utils:xplane_utils",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@local_tsl//tsl/profiler/protobuf:xplane_proto_cc",
"@local_xla//xla/tsl/profiler/convert:xplane_to_trace_events",
Expand Down Expand Up @@ -1222,12 +1223,14 @@ cc_library(
":inference_stats",
":inference_stats_combiner",
":inference_stats_grouping",
":inference_stats_sampler",
":preprocess_single_host_xplane",
":repository",
"//tensorflow/core/profiler/protobuf:inference_stats_proto_cc",
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
"//tensorflow/core/profiler/utils:event_span",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings:string_view",
"@local_tsl//tsl/platform:statusor",
"@local_xla//xla/tsl/profiler/utils:device_utils",
"@local_xla//xla/tsl/profiler/utils:group_events",
Expand Down
1 change: 1 addition & 0 deletions tensorflow/core/profiler/convert/inference_stats_sampler.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ struct SampledPerModelInferenceStats {
};

// All the sampled inference stats of a profile.
// TODO: Move to use SampledInferenceStatsProto if feasible.
using SampledInferenceStats =
absl::flat_hash_map<int /*model_index*/, SampledPerModelInferenceStats>;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@ limitations under the License.
#include <vector>

#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "xla/tsl/profiler/utils/device_utils.h"
#include "xla/tsl/profiler/utils/group_events.h"
#include "xla/tsl/profiler/utils/tpu_xplane_utils.h"
#include "tensorflow/core/profiler/convert/inference_stats.h"
#include "tensorflow/core/profiler/convert/inference_stats_combiner.h"
#include "tensorflow/core/profiler/convert/inference_stats_grouping.h"
#include "tensorflow/core/profiler/convert/inference_stats_sampler.h"
#include "tensorflow/core/profiler/convert/preprocess_single_host_xplane.h"
#include "tensorflow/core/profiler/convert/repository.h"
#include "tensorflow/core/profiler/protobuf/inference_stats.pb.h"
Expand All @@ -33,8 +35,35 @@ limitations under the License.

namespace tensorflow::profiler {

namespace {
// Samples requests and batches from `inference_stats` (one set per model) and
// packs the samples into a SampledInferenceStatsProto. Each sampled
// request/batch is copied and annotated with the percentile assigned by the
// sampler; the source `inference_stats` is not mutated. `request_column` and
// `batch_column` name the columns used for sampling — their exact semantics
// are defined by SampleInferenceStats.
SampledInferenceStatsProto GetSampledInferenceStatsProto(
    const InferenceStats& inference_stats, absl::string_view request_column,
    absl::string_view batch_column) {
  SampledInferenceStatsProto result;
  SampledInferenceStats sampled_stats =
      SampleInferenceStats(request_column, batch_column, inference_stats);
  for (const auto& [model_index, samples] : sampled_stats) {
    // Build the per-model message directly inside the result map so the
    // fully-populated proto is not copied afterwards.
    SampledPerModelInferenceStatsProto& per_model_stats =
        (*result.mutable_sampled_inference_stats_per_model())[model_index];
    for (const auto& [request, percentile] : samples.sampled_requests) {
      // Copy once into the repeated field, then annotate the copy, so the
      // percentile does not leak back into the source stats.
      RequestDetail* request_detail = per_model_stats.add_sampled_requests();
      *request_detail = *request;
      request_detail->set_percentile(percentile);
    }
    for (const auto& [batch, percentile] : samples.sampled_batches) {
      BatchDetail* batch_detail = per_model_stats.add_sampled_batches();
      *batch_detail = *batch;
      batch_detail->set_percentile(percentile);
    }
  }
  return result;
}
}  // namespace

absl::Status ConvertMultiXSpaceToInferenceStats(
const SessionSnapshot& session_snapshot, InferenceStats* inference_stats) {
const SessionSnapshot& session_snapshot, absl::string_view request_column,
absl::string_view batch_column, InferenceStats* inference_stats) {
for (int i = 0; i < session_snapshot.XSpaceSize(); ++i) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<XSpace> xspace,
session_snapshot.GetXSpace(i));
Expand All @@ -51,6 +80,9 @@ absl::Status ConvertMultiXSpaceToInferenceStats(
CombineInferenceStatsResult(i, inference_stats_per_host, inference_stats);
}
RegroupInferenceStatsByModel(inference_stats);
*inference_stats->mutable_sampled_inference_stats() =
GetSampledInferenceStatsProto(*inference_stats, request_column,
batch_column);
return absl::OkStatus();
}
} // namespace tensorflow::profiler
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_MULTI_XSPACE_TO_INFERENCE_STATS_H_
#define TENSORFLOW_CORE_PROFILER_CONVERT_MULTI_XSPACE_TO_INFERENCE_STATS_H_

#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/profiler/convert/repository.h"
#include "tensorflow/core/profiler/protobuf/inference_stats.pb.h"

namespace tensorflow::profiler {

// Converts all XSpaces in `session_snapshot` into `inference_stats`.
// `request_column` and `batch_column` select the columns used to sample
// requests/batches for the resulting sampled stats; empty strings request the
// sampler's default behavior.
absl::Status ConvertMultiXSpaceToInferenceStats(
    const SessionSnapshot& session_snapshot, absl::string_view request_column,
    absl::string_view batch_column, InferenceStats* inference_stats);

}  // namespace tensorflow::profiler

#endif  // TENSORFLOW_CORE_PROFILER_CONVERT_MULTI_XSPACE_TO_INFERENCE_STATS_H_
13 changes: 9 additions & 4 deletions tensorflow/core/profiler/convert/xplane_to_tools_data.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ limitations under the License.
#include <utility>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/numbers.h"
#include "absl/strings/string_view.h"
#include "xla/tsl/profiler/convert/xplane_to_trace_events.h"
Expand Down Expand Up @@ -332,10 +333,14 @@ absl::StatusOr<std::string> ConvertDcnCollectiveStatsToToolData(
}

// Converts all XSpaces in `session_snapshot` into a serialized InferenceStats
// proto for the "inference_profile" tool.
// Tool options:
//   "request_column" / "batch_column": columns used to sample requests and
//   batches; both default to "" (the sampler's default behavior).
absl::StatusOr<std::string> ConvertMultiXSpacesToInferenceStats(
    const SessionSnapshot& session_snapshot, const ToolOptions& options) {
  InferenceStats inference_stats;
  const std::string request_column =
      GetParamWithDefault<std::string>(options, "request_column", "");
  const std::string batch_column =
      GetParamWithDefault<std::string>(options, "batch_column", "");
  TF_RETURN_IF_ERROR(ConvertMultiXSpaceToInferenceStats(
      session_snapshot, request_column, batch_column, &inference_stats));
  return inference_stats.SerializeAsString();
}

Expand Down Expand Up @@ -375,7 +380,7 @@ absl::StatusOr<std::string> ConvertMultiXSpacesToToolData(
} else if (tool_name == "_xplane.pb") { // internal test only.
return PreprocessXSpace(session_snapshot);
} else if (tool_name == "inference_profile") {
return ConvertMultiXSpacesToInferenceStats(session_snapshot);
return ConvertMultiXSpacesToInferenceStats(session_snapshot, options);
} else {
return errors::InvalidArgument(
"Can not find tool: ", tool_name,
Expand Down
13 changes: 13 additions & 0 deletions tensorflow/core/profiler/protobuf/inference_stats.proto
Original file line number Diff line number Diff line change
Expand Up @@ -281,5 +281,18 @@ message InferenceStats {
// A database of tensor patterns.
optional TensorPatternDatabase tensor_pattern_db = 6;

optional SampledInferenceStatsProto sampled_inference_stats = 7;

reserved 1, 2; // were processing_stats, session_run_times
}

// Sampled requests and batches for a single model, each annotated with the
// percentile assigned by the sampler.
message SampledPerModelInferenceStatsProto {
// Requests selected by the sampler (percentile field populated).
repeated RequestDetail sampled_requests = 1;
// Batches selected by the sampler (percentile field populated).
repeated BatchDetail sampled_batches = 2;
}

// All the sampled inference stats of a profile.
message SampledInferenceStatsProto {
// Map from model index to the Sampled Stats.
map<int32 /* model_index */, SampledPerModelInferenceStatsProto>
sampled_inference_stats_per_model = 1;
}

0 comments on commit 7cd13b1

Please sign in to comment.