Skip to content

Commit

Permalink
[XLA:LatencyHidingScheduler] Fix issues with scheduling_group_id annotations. Added support for:
Browse files Browse the repository at this point in the history

1) Update `ready_num_nodes_with_annotation[id]` when processing the roots of the graph.
2) Allow appending an annotation id to the `ready_annotations` vector while scheduling another annotation.

PiperOrigin-RevId: 702898764
  • Loading branch information
seherellis authored and tensorflower-gardener committed Dec 5, 2024
1 parent 2892418 commit 9f45db6
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 11 deletions.
34 changes: 24 additions & 10 deletions third_party/xla/xla/service/latency_hiding_scheduler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2314,6 +2314,21 @@ DefaultSchedulerCore::ScheduleComputation(const HloComputation* computation) {
<< memory_pressure_tracker.memory_usage();
sched_state.ready_set.insert(sched_state.ready_set.end(), roots.begin(),
roots.end());
for (HloGraphNode* root : roots) {
int64_t annotation = root->GetAnnotation();
if (annotation != -1) {
sched_state.ready_num_nodes_with_annotation[annotation]++;
VLOG(2) << "Annotation: " << annotation
<< " ready_num_nodes_with_annotation: "
<< sched_state.ready_num_nodes_with_annotation[annotation]
<< " num_root_instructions: "
<< annotation_tracker_->GetNumRootInstructions(annotation);
if (annotation_tracker_->GetNumRootInstructions(annotation) ==
sched_state.ready_num_nodes_with_annotation[annotation]) {
sched_state.ready_annotations.push_back(annotation);
}
}
}
// Schedule in order bottom up.
while (!sched_state.ready_set.empty() || !sched_state.nop_set.empty()) {
VLOG(10) << "Current ready time: " << sched_state.current_time;
Expand All @@ -2329,16 +2344,15 @@ DefaultSchedulerCore::ScheduleComputation(const HloComputation* computation) {
return absl::StrJoin(sched_state.ready_set, "\n", LogFormatter());
}());
if (!sched_state.ready_annotations.empty()) {
// TODO (sacer): If more than one annotations are ready, decide the order
// with a heuristic.
for (int64_t annotation : sched_state.ready_annotations) {
VLOG(2) << "------- BEGIN ANNOTATION: " << annotation << " -------";
sched_state.ongoing_annotation = annotation;
TF_RETURN_IF_ERROR(ScheduleAnnotation(annotation, &sched_state));
VLOG(2) << "------- END ANNOTATION: " << annotation << " --------";
sched_state.ongoing_annotation = -1;
}
sched_state.ready_annotations.clear();
// TODO (sacer): If more than one annotations are ready, decide which one
// to schedule next with a heuristic.
int64_t annotation = sched_state.ready_annotations.back();
sched_state.ready_annotations.pop_back();
VLOG(2) << "------- BEGIN ANNOTATION: " << annotation << " -------";
sched_state.ongoing_annotation = annotation;
TF_RETURN_IF_ERROR(ScheduleAnnotation(annotation, &sched_state));
VLOG(2) << "------- END ANNOTATION: " << annotation << " --------";
sched_state.ongoing_annotation = -1;
continue;
}
TF_RETURN_IF_ERROR(SchedulingStep(&sched_state));
Expand Down
2 changes: 1 addition & 1 deletion third_party/xla/xla/service/latency_hiding_scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ class SchedulerCore {
class AnnotationTracker {
public:
explicit AnnotationTracker(const HloModule* module) : module_(module) {
for (const HloComputation* comp : module_->computations()) {
for (const HloComputation* comp : module_->MakeNonfusionComputations()) {
absl::flat_hash_set<int64_t> annotations;
for (const HloInstruction* instr : comp->instructions()) {
if (auto annotation = GetAnnotation(instr)) {
Expand Down
102 changes: 102 additions & 0 deletions third_party/xla/xla/service/latency_hiding_scheduler_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3595,4 +3595,106 @@ ENTRY entry {
EXPECT_LT(GetIndex(new_instruction_sequence, "c0"),
GetIndex(new_instruction_sequence, "cp2d"));
}

TEST_F(LatencyHidingSchedulerTest, SchedulingAnnotationMakesAnotherGroupReady) {
  // Two scheduling groups (ids 0 and 1): completing group 0 makes group 1
  // ready, so the scheduler must be able to enqueue a newly-ready annotation
  // while it is still in the middle of scheduling another one.
  absl::string_view hlo_text = R"(
HloModule module, is_scheduled=true
fused_computation {
param0 = f32[16,64,256]{2,1,0} parameter(0)
param1 = f32[16,64,256]{2,1,0} parameter(1)
ROOT c0 = f32[16,256,256]{2,1,0} convolution(param0, param1), window={size=16 stride=15 lhs_dilate=16}, dim_labels=0fb_0io->0fb, frontend_attributes={_scheduling_group_id="0"}
}
fused_computation.1 {
param0.1 = f32[16,256,256]{2,1,0} parameter(0)
param1.1 = f32[16,256,256]{2,1,0} parameter(1)
ROOT c1 = f32[1,256,256]{2,1,0} convolution(param0.1, param1.1), window={size=16 stride=15}, dim_labels=0fb_0io->0fb, frontend_attributes={_scheduling_group_id="1"}
}
ENTRY entry {
p0 = f32[16,64,256]{2,1,0} parameter(0)
p1 = f32[128,2048,2048]{2,1,0} parameter(1)
cp0s = (f32[128,2048,2048]{2,1,0}, f32[128,2048,2048]{2,1,0}, u32[], u32[]) collective-permute-start(p1), source_target_pairs={{1,0},{0,3},{3,2}}, frontend_attributes={_scheduling_group_id="0"}
cp0d = f32[128,2048,2048]{2,1,0} collective-permute-done(cp0s), frontend_attributes={_scheduling_group_id="0"}
cp1s = (f32[128,2048,2048]{2,1,0}, f32[128,2048,2048]{2,1,0}, u32[], u32[]) collective-permute-start(cp0d), source_target_pairs={{1,0},{0,3},{3,2}}, frontend_attributes={_scheduling_group_id="1"}
cp1d = f32[128,2048,2048]{2,1,0} collective-permute-done(cp1s), frontend_attributes={_scheduling_group_id="1"}
f0 = f32[16,256,256]{2,1,0} fusion(p0, p0), kind=kOutput, calls=fused_computation, frontend_attributes={_scheduling_group_id="0"}
f1 = f32[1,256,256]{2,1,0} fusion(f0, f0), kind=kOutput, calls=fused_computation.1, frontend_attributes={_scheduling_group_id="1"}
ROOT tuple = (f32[128,2048,2048]{2,1,0}, f32[1,256,256]{2,1,0}) tuple(cp1d, f1)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloText(hlo_text));
  EXPECT_TRUE(hlo_module->has_entry_computation());
  HloSchedule& schedule = hlo_module->schedule();

  auto config = GetDefaultSchedConfig();
  config.aggressive_scheduling_policies = true;
  EXPECT_TRUE(RunScheduler(hlo_module.get(), config,
                           std::make_unique<TestLatencyEstimator>())
                  .ok());
  EXPECT_TRUE(hlo_module->has_entry_computation());

  std::vector<HloInstruction*> new_sequence =
      schedule.sequence(hlo_module->entry_computation()).instructions();
  if (VLOG_IS_ON(1)) {
    for (auto* instr : new_sequence) {
      VLOG(1) << instr->ToString();
    }
  }

  // Each collective permute should overlap the fusion from its own group:
  // cp0 overlaps f0, cp1 overlaps f1.
  EXPECT_LT(GetIndex(new_sequence, "cp0s"), GetIndex(new_sequence, "f0"));
  EXPECT_LT(GetIndex(new_sequence, "f0"), GetIndex(new_sequence, "cp0d"));
  EXPECT_LT(GetIndex(new_sequence, "cp1s"), GetIndex(new_sequence, "f1"));
  EXPECT_LT(GetIndex(new_sequence, "f1"), GetIndex(new_sequence, "cp1d"));
}

TEST_F(LatencyHidingSchedulerTest, AnnotatedRoot) {
  // The entry computation's ROOT itself carries a scheduling annotation;
  // the scheduler must count roots toward an annotation group's readiness
  // and still overlap the group's collective with its compute.
  absl::string_view hlo_text = R"(
HloModule module, is_scheduled=true
fused_computation {
param0 = f32[16,64,256]{2,1,0} parameter(0)
param1 = f32[16,64,256]{2,1,0} parameter(1)
ROOT c0 = f32[16,256,256]{2,1,0} convolution(param0, param1), window={size=16 stride=15 lhs_dilate=16}, dim_labels=0fb_0io->0fb, frontend_attributes={_scheduling_group_id="0"}
}
ENTRY entry {
p0 = f32[16,64,256]{2,1,0} parameter(0)
p1 = f32[128,2048,2048]{2,1,0} parameter(1)
cp0s = (f32[128,2048,2048]{2,1,0}, f32[128,2048,2048]{2,1,0}, u32[], u32[]) collective-permute-start(p1), source_target_pairs={{1,0},{0,3},{3,2}}, frontend_attributes={_scheduling_group_id="0"}
cp0d = f32[128,2048,2048]{2,1,0} collective-permute-done(cp0s), frontend_attributes={_scheduling_group_id="0"}
ROOT f0 = f32[16,256,256]{2,1,0} fusion(p0, p0), kind=kOutput, calls=fused_computation, frontend_attributes={_scheduling_group_id="0"}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloText(hlo_text));
  EXPECT_TRUE(hlo_module->has_entry_computation());
  HloSchedule& schedule = hlo_module->schedule();

  auto config = GetDefaultSchedConfig();
  config.aggressive_scheduling_policies = true;
  EXPECT_TRUE(RunScheduler(hlo_module.get(), config,
                           std::make_unique<TestLatencyEstimator>())
                  .ok());
  EXPECT_TRUE(hlo_module->has_entry_computation());

  std::vector<HloInstruction*> new_sequence =
      schedule.sequence(hlo_module->entry_computation()).instructions();
  if (VLOG_IS_ON(1)) {
    for (auto* instr : new_sequence) {
      VLOG(1) << instr->ToString();
    }
  }

  // The collective permute (cp0) should overlap the fusion f0.
  EXPECT_LT(GetIndex(new_sequence, "cp0s"), GetIndex(new_sequence, "f0"));
  EXPECT_LT(GetIndex(new_sequence, "f0"), GetIndex(new_sequence, "cp0d"));
}
} // namespace xla

0 comments on commit 9f45db6

Please sign in to comment.