diff --git a/third_party/xla/xla/service/latency_hiding_scheduler.cc b/third_party/xla/xla/service/latency_hiding_scheduler.cc index 6d3e1d07b56af8..9ded2dc023a649 100644 --- a/third_party/xla/xla/service/latency_hiding_scheduler.cc +++ b/third_party/xla/xla/service/latency_hiding_scheduler.cc @@ -2314,6 +2314,21 @@ DefaultSchedulerCore::ScheduleComputation(const HloComputation* computation) { << memory_pressure_tracker.memory_usage(); sched_state.ready_set.insert(sched_state.ready_set.end(), roots.begin(), roots.end()); + for (HloGraphNode* root : roots) { + int64_t annotation = root->GetAnnotation(); + if (annotation != -1) { + sched_state.ready_num_nodes_with_annotation[annotation]++; + VLOG(2) << "Annotation: " << annotation + << " ready_num_nodes_with_annotation: " + << sched_state.ready_num_nodes_with_annotation[annotation] + << " num_root_instructions: " + << annotation_tracker_->GetNumRootInstructions(annotation); + if (annotation_tracker_->GetNumRootInstructions(annotation) == + sched_state.ready_num_nodes_with_annotation[annotation]) { + sched_state.ready_annotations.push_back(annotation); + } + } + } // Schedule in order bottom up. while (!sched_state.ready_set.empty() || !sched_state.nop_set.empty()) { VLOG(10) << "Current ready time: " << sched_state.current_time; @@ -2329,16 +2344,15 @@ DefaultSchedulerCore::ScheduleComputation(const HloComputation* computation) { return absl::StrJoin(sched_state.ready_set, "\n", LogFormatter()); }()); if (!sched_state.ready_annotations.empty()) { - // TODO (sacer): If more than one annotations are ready, decide the order - // with a heuristic. 
- for (int64_t annotation : sched_state.ready_annotations) { - VLOG(2) << "------- BEGIN ANNOTATION: " << annotation << " -------"; - sched_state.ongoing_annotation = annotation; - TF_RETURN_IF_ERROR(ScheduleAnnotation(annotation, &sched_state)); - VLOG(2) << "------- END ANNOTATION: " << annotation << " --------"; - sched_state.ongoing_annotation = -1; - } - sched_state.ready_annotations.clear(); + // TODO (sacer): If more than one annotations are ready, decide which one + // to schedule next with a heuristic. + int64_t annotation = sched_state.ready_annotations.back(); + sched_state.ready_annotations.pop_back(); + VLOG(2) << "------- BEGIN ANNOTATION: " << annotation << " -------"; + sched_state.ongoing_annotation = annotation; + TF_RETURN_IF_ERROR(ScheduleAnnotation(annotation, &sched_state)); + VLOG(2) << "------- END ANNOTATION: " << annotation << " --------"; + sched_state.ongoing_annotation = -1; continue; } TF_RETURN_IF_ERROR(SchedulingStep(&sched_state)); diff --git a/third_party/xla/xla/service/latency_hiding_scheduler.h b/third_party/xla/xla/service/latency_hiding_scheduler.h index ef158167a0cb3f..d24c734271f6e1 100644 --- a/third_party/xla/xla/service/latency_hiding_scheduler.h +++ b/third_party/xla/xla/service/latency_hiding_scheduler.h @@ -337,7 +337,7 @@ class SchedulerCore { class AnnotationTracker { public: explicit AnnotationTracker(const HloModule* module) : module_(module) { - for (const HloComputation* comp : module_->computations()) { + for (const HloComputation* comp : module_->MakeNonfusionComputations()) { absl::flat_hash_set<int64_t> annotations; for (const HloInstruction* instr : comp->instructions()) { if (auto annotation = GetAnnotation(instr)) { diff --git a/third_party/xla/xla/service/latency_hiding_scheduler_test.cc b/third_party/xla/xla/service/latency_hiding_scheduler_test.cc index 6d77e92fb4bef9..a21a7657736477 100644 --- a/third_party/xla/xla/service/latency_hiding_scheduler_test.cc +++ 
b/third_party/xla/xla/service/latency_hiding_scheduler_test.cc @@ -3595,4 +3595,106 @@ ENTRY entry { EXPECT_LT(GetIndex(new_instruction_sequence, "c0"), GetIndex(new_instruction_sequence, "cp2d")); } + +TEST_F(LatencyHidingSchedulerTest, SchedulingAnnotationMakesAnotherGroupReady) { + absl::string_view hlo_string = R"( +HloModule module, is_scheduled=true + +fused_computation { + param0 = f32[16,64,256]{2,1,0} parameter(0) + param1 = f32[16,64,256]{2,1,0} parameter(1) + ROOT c0 = f32[16,256,256]{2,1,0} convolution(param0, param1), window={size=16 stride=15 lhs_dilate=16}, dim_labels=0fb_0io->0fb, frontend_attributes={_scheduling_group_id="0"} +} + +fused_computation.1 { + param0.1 = f32[16,256,256]{2,1,0} parameter(0) + param1.1 = f32[16,256,256]{2,1,0} parameter(1) + ROOT c1 = f32[1,256,256]{2,1,0} convolution(param0.1, param1.1), window={size=16 stride=15}, dim_labels=0fb_0io->0fb, frontend_attributes={_scheduling_group_id="1"} +} + +ENTRY entry { + p0 = f32[16,64,256]{2,1,0} parameter(0) + p1 = f32[128,2048,2048]{2,1,0} parameter(1) + cp0s = (f32[128,2048,2048]{2,1,0}, f32[128,2048,2048]{2,1,0}, u32[], u32[]) collective-permute-start(p1), source_target_pairs={{1,0},{0,3},{3,2}}, frontend_attributes={_scheduling_group_id="0"} + cp0d = f32[128,2048,2048]{2,1,0} collective-permute-done(cp0s), frontend_attributes={_scheduling_group_id="0"} + cp1s = (f32[128,2048,2048]{2,1,0}, f32[128,2048,2048]{2,1,0}, u32[], u32[]) collective-permute-start(cp0d), source_target_pairs={{1,0},{0,3},{3,2}}, frontend_attributes={_scheduling_group_id="1"} + cp1d = f32[128,2048,2048]{2,1,0} collective-permute-done(cp1s), frontend_attributes={_scheduling_group_id="1"} + f0 = f32[16,256,256]{2,1,0} fusion(p0, p0), kind=kOutput, calls=fused_computation, frontend_attributes={_scheduling_group_id="0"} + f1 = f32[1,256,256]{2,1,0} fusion(f0, f0), kind=kOutput, calls=fused_computation.1, frontend_attributes={_scheduling_group_id="1"} + ROOT tuple = (f32[128,2048,2048]{2,1,0}, 
f32[1,256,256]{2,1,0}) tuple(cp1d, f1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloText(hlo_string)); + HloSchedule& module_schedule = hlo_module->schedule(); + EXPECT_TRUE(hlo_module->has_entry_computation()); + auto sched_config = GetDefaultSchedConfig(); + sched_config.aggressive_scheduling_policies = true; + EXPECT_TRUE(RunScheduler(hlo_module.get(), sched_config, + std::make_unique<TestLatencyEstimator>()) + .ok()); + EXPECT_TRUE(hlo_module->has_entry_computation()); + + std::vector<HloInstruction*> new_instruction_sequence = + module_schedule.sequence(hlo_module->entry_computation()).instructions(); + if (VLOG_IS_ON(1)) { + for (auto* new_i : new_instruction_sequence) { + VLOG(1) << new_i->ToString(); + } + } + + // cp0 and cp1 overlap f0 and f1, respectively. + EXPECT_LT(GetIndex(new_instruction_sequence, "cp0s"), + GetIndex(new_instruction_sequence, "f0")); + EXPECT_LT(GetIndex(new_instruction_sequence, "f0"), + GetIndex(new_instruction_sequence, "cp0d")); + EXPECT_LT(GetIndex(new_instruction_sequence, "cp1s"), + GetIndex(new_instruction_sequence, "f1")); + EXPECT_LT(GetIndex(new_instruction_sequence, "f1"), + GetIndex(new_instruction_sequence, "cp1d")); +} + +TEST_F(LatencyHidingSchedulerTest, AnnotatedRoot) { + absl::string_view hlo_string = R"( +HloModule module, is_scheduled=true + +fused_computation { + param0 = f32[16,64,256]{2,1,0} parameter(0) + param1 = f32[16,64,256]{2,1,0} parameter(1) + ROOT c0 = f32[16,256,256]{2,1,0} convolution(param0, param1), window={size=16 stride=15 lhs_dilate=16}, dim_labels=0fb_0io->0fb, frontend_attributes={_scheduling_group_id="0"} +} + +ENTRY entry { + p0 = f32[16,64,256]{2,1,0} parameter(0) + p1 = f32[128,2048,2048]{2,1,0} parameter(1) + cp0s = (f32[128,2048,2048]{2,1,0}, f32[128,2048,2048]{2,1,0}, u32[], u32[]) collective-permute-start(p1), source_target_pairs={{1,0},{0,3},{3,2}}, frontend_attributes={_scheduling_group_id="0"} + cp0d = f32[128,2048,2048]{2,1,0} collective-permute-done(cp0s), 
frontend_attributes={_scheduling_group_id="0"} + ROOT f0 = f32[16,256,256]{2,1,0} fusion(p0, p0), kind=kOutput, calls=fused_computation, frontend_attributes={_scheduling_group_id="0"} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloText(hlo_string)); + HloSchedule& module_schedule = hlo_module->schedule(); + EXPECT_TRUE(hlo_module->has_entry_computation()); + auto sched_config = GetDefaultSchedConfig(); + sched_config.aggressive_scheduling_policies = true; + EXPECT_TRUE(RunScheduler(hlo_module.get(), sched_config, + std::make_unique<TestLatencyEstimator>()) + .ok()); + EXPECT_TRUE(hlo_module->has_entry_computation()); + + std::vector<HloInstruction*> new_instruction_sequence = + module_schedule.sequence(hlo_module->entry_computation()).instructions(); + if (VLOG_IS_ON(1)) { + for (auto* new_i : new_instruction_sequence) { + VLOG(1) << new_i->ToString(); + } + } + + // cp0 overlaps f0. + EXPECT_LT(GetIndex(new_instruction_sequence, "cp0s"), + GetIndex(new_instruction_sequence, "f0")); + EXPECT_LT(GetIndex(new_instruction_sequence, "f0"), + GetIndex(new_instruction_sequence, "cp0d")); +} } // namespace xla