From c0d7559c7ee1b8a116d1b7198e0fa8d3e1241e62 Mon Sep 17 00:00:00 2001
From: zjjott
Date: Tue, 27 Feb 2024 17:27:07 +0800
Subject: [PATCH] add LPContainerDAG support to the auto-reorder LP scheduler

---
 .../experimental/auto_reorder/auto_reorder.cc |   6 +-
 .../auto_reorder/auto_reorder_solver.cc       |  78 ++++++-
 .../auto_reorder/auto_reorder_solver.h        |  83 ++++++--
 .../auto_reorder/auto_reorder_test.cc         |  32 +--
 xla/service/gpu/gpu_hlo_schedule.h            |   1 +
 xla/service/latency_hiding_scheduler_test.cc  | 192 ++++++++++++++++++
 xla/xla.proto                                 |   4 +-
 7 files changed, 356 insertions(+), 40 deletions(-)

diff --git a/xla/hlo/experimental/auto_reorder/auto_reorder.cc b/xla/hlo/experimental/auto_reorder/auto_reorder.cc
index cccd83a2d2278..b8f782a25625d 100644
--- a/xla/hlo/experimental/auto_reorder/auto_reorder.cc
+++ b/xla/hlo/experimental/auto_reorder/auto_reorder.cc
@@ -130,8 +130,10 @@ AutoReorderPass::ScheduleComputation(HloComputation* computation) {
   std::vector<HloInstruction*> new_schedule;
   auto sorted_nodes = solver_->GetSortedNodes();
   for (auto node : sorted_nodes) {
-    new_schedule.push_back(
-        const_cast<HloInstruction*>(node->GetValue()));
+    auto insts = node->GetValues();
+    for (auto inst : insts) {
+      new_schedule.push_back(const_cast<HloInstruction*>(inst));
+    }
   }
   return new_schedule;
 }
diff --git a/xla/hlo/experimental/auto_reorder/auto_reorder_solver.cc b/xla/hlo/experimental/auto_reorder/auto_reorder_solver.cc
index 991e03c135311..9e9dfb80f4ac3 100644
--- a/xla/hlo/experimental/auto_reorder/auto_reorder_solver.cc
+++ b/xla/hlo/experimental/auto_reorder/auto_reorder_solver.cc
@@ -5,6 +5,11 @@
 #ifndef LPSchedulerFunc(return_type)
 #define LPSchedulerFunc(return_type) template <typename ContainerType, typename ElementType> return_type LinearProgramScheduler<ContainerType, ElementType>
 #endif
+
+#ifndef LPContainerDAGFunc(return_type)
+#define LPContainerDAGFunc(return_type) template <typename ElementType> return_type LPContainerDAG<ElementType>
+#endif
+
 namespace xla {
 using IntVar = operations_research::sat::IntVar;
 using CpModelBuilder = operations_research::sat::CpModelBuilder;
@@ -100,7 +105,7 @@ LPSchedulerFunc(tsl::Status)::Solve() {
       max_execution_time += cost;
     }
   }
-  SetHorizon(max_execution_time * reorder::kChannelNumber);
+  SetHorizon(reorder::get_horizon(max_execution_time));
   for (auto node : nodes_) {
     VLOG(3) << "Add to scheduler" << node->GetName();
@@ -141,7 +146,7 @@
     parameters.set_log_to_stdout(true);
     parameters.set_log_search_progress(true);
   }
-  parameters.set_num_search_workers(8);
+  parameters.set_num_search_workers(reorder::get_cpu_number());
   const operations_research::sat::CpSolverResponse response =
       operations_research::sat::SolveWithParameters(cp_model_.Build(), parameters);
@@ -294,6 +299,75 @@ LPSchedulerFunc(std::vector<ContainerType*>)::GetSortedNodes() {
       [this](ContainerType* a, ContainerType* b) { return a->GetStart() < b->GetStart(); });
   return sorted_nodes;
 }
+
+LPContainerDAGFunc(bool)::IsIn(LPContainer<ElementType>* a) {
+  return operands_.find(a) != operands_.end();
+};
+LPContainerDAGFunc(void)::AddToDAG(LPContainer<ElementType>* child) {
+  inner_elements.push_back(child);
+  if (IsIn(child)) {
+    operands_.erase(child);
+  }
+  for (auto dep_pair : child->GetDeps()) {
+    auto dep = std::get<0>(dep_pair);
+    auto cost = std::get<1>(dep_pair);  // TODO: decide whether the edge cost needs to be stored here
+    operands_.insert(dep);
+  }
+}
+LPContainerDAGFunc(Status)::MergeFrom(LPContainerDAG* other) {
+  /*
+  Step 1: this DAG's inner_elements must have deps on other's inner_elements,
+  so that links to other's inner_elements can be converted into inner edges.
+  */
+
+  // Maintain the deps of this LPContainerDAG's inner_elements so that inner
+  // edges can be created after the merge.
+  // {dep uuid: [(element, cost), ...]}
+  std::unordered_map<
+      int,
+      std::vector<std::tuple<LPContainer<ElementType>*, CostType>>
+  > dep_operands2element;
+
+  for (LPContainer<ElementType>* element : GetInnerElements()) {
+    // Edges from an operand to an element are outer edges; they may be
+    // converted into inner edges.
+    for (auto dep_pair : element->GetDeps()) {
+      auto dep = std::get<0>(dep_pair);
+      auto cost = std::get<1>(dep_pair);
+      if (dep_operands2element.find(dep->UUID()) == dep_operands2element.end()) {
+        dep_operands2element[dep->UUID()] =
+            std::vector<std::tuple<LPContainer<ElementType>*, CostType>>();
+      }
+      dep_operands2element[dep->UUID()].push_back(std::make_tuple(element, cost));
+    }
+  }
+  // other's inner elements
+  for (auto child : other->GetInnerElements()) {
+    // every child must already appear as a dep of some inner element
+    TF_RET_CHECK(dep_operands2element.find(child->UUID()) != dep_operands2element.end())
+        << "child is not in dep_operands2element";
+    for (auto dep_pair : dep_operands2element[child->UUID()]) {
+      auto dep = std::get<0>(dep_pair);
+      auto cost = std::get<1>(dep_pair);
+      if (dep_operands2element.find(dep->UUID()) != dep_operands2element.end()) {
+        for (auto element_pair : dep_operands2element[dep->UUID()]) {
+          auto element = std::get<0>(element_pair);
+          auto cost = std::get<1>(element_pair);
+          // create an inner edge between element and child
+          DAGEdge edge;
+          edge.from = element;
+          edge.to = child;
+          edge.cost = cost;
+          edges_.push_back(edge);
+        }
+      }
+    }
+
+    AddToDAG(child);
+  }
+  return OkStatus();
+}
 template class LPContainer<const HloInstruction*>;
 template class LinearProgramScheduler<LPContainer<const HloInstruction*>, const HloInstruction*>;
+
+
+template class LPContainerDAG<const HloInstruction*>;
+// template class LinearProgramScheduler<LPContainerDAG<const HloInstruction*>, const HloInstruction*>;
+
 }  // namespace xla
diff --git a/xla/hlo/experimental/auto_reorder/auto_reorder_solver.h b/xla/hlo/experimental/auto_reorder/auto_reorder_solver.h
index 6f0246657a244..9d4bdb8a76115 100644
--- a/xla/hlo/experimental/auto_reorder/auto_reorder_solver.h
+++ b/xla/hlo/experimental/auto_reorder/auto_reorder_solver.h
@@ -5,7 +5,7 @@
 #include
 #include
 #include
-
+#include <thread>
 #include "absl/container/flat_hash_map.h"
 #include "absl/strings/string_view.h"
 #include "xla/hlo/ir/hlo_module.h"
@@ -17,10 +17,19 @@ namespace xla {
 using CpModelBuilder = operations_research::sat::CpModelBuilder;
 using IntervalVar = operations_research::sat::IntervalVar;
 namespace reorder {
-  const uint32_t ksolveTimeout = 30;  // 30s
+  const uint32_t ksolveTimeout = 60;  // 60s
   static const int kChannelNumber = 2;
-  bool solve_debug = false;
+  // scale the critical-path estimate to get the scheduling horizon
+  int get_horizon(int max_time) {
+    return max_time * 1.2;
+  }
+  bool solve_debug = true;
+  // number of CPUs on the current machine
+  int get_cpu_number() {
+    // return 8;
+    return std::thread::hardware_concurrency();
+  }
 }
 enum class NodeType { kCompute = 0, kCommunication = 1 };
@@ -59,7 +68,7 @@ class LPNode{
 template <typename ElementType>
 class LPContainer {
  public:
-
+  // create an LPContainer holding inner_element, with the given cost and type
   LPContainer(ElementType inner_element, CostType cost, NodeType type)
       : inner_element_(inner_element), cost_(cost), type_(type) {
     uuid_ = reinterpret_cast<uintptr_t>(this);
   }
@@ -71,15 +80,20 @@ class LPContainer{
   CostType GetCost() const { return cost_; }
   void SetStart(CostType start) { startat_ = start; }
   CostType GetStart() { return startat_; }
+  // the type of the container: compute or communication
   bool IsComputation() const { return type_ == NodeType::kCompute; }
   bool IsCommunication() const { return type_ == NodeType::kCommunication; }
   NodeType GetType() const { return type_; }
-  bool HasValue() const { return inner_element_ != nullptr; }
-  ElementType GetValue() const { return inner_element_; }
+
+  const bool HasValue() { return inner_element_ != nullptr; }
+  const std::vector<ElementType> GetValues() {
+    return std::vector<ElementType>{inner_element_};
+  }
+
+  // add a dep of this container; cost is the cost of the edge. This container
+  // will be executed after dep.
   void AddDep(LPContainer* dep, CostType cost);
+  // get all deps of the container
   const std::vector<std::tuple<LPContainer*, CostType>> GetDeps() const { return deps_; }
+  // once a container is frozen, no more deps can be added to it
   void Freeze() { frozen_ = true; }

  private:
@@ -97,32 +111,59 @@ class LPContainer{
 // LPContainerDAG is a graph of containers; it stores a DAG of containers and
 // is used as an atomic scheduling unit of LPContainer
 template <typename ElementType>
-class LPContainerDAG {
+class LPContainerDAG : public LPContainer<ElementType> {
   // we can use InstructionDAG to get memory effect order
  public:
   // maintain a DAG of inner elements
   struct DAGEdge {
-    LPContainerDAG* from;
-    LPContainerDAG* to;
+    LPContainer<ElementType>* from;
+    LPContainer<ElementType>* to;
     CostType cost;
   };
-  // create a LPContainerDAG with
-  LPContainerDAG(ElementType inner_element, CostType cost, NodeType type)
-      : cost_(cost), type_(type) {
-    inner_elements.push_back(LPContainer(inner_element, cost, type));
-  };
-  bool IsIn(LPContainerDAG* a) {
-    return users_.find(a) != users_.end();
-  };
+  // create a LPContainerDAG with one element
+  LPContainerDAG(ElementType inner_element, CostType cost, NodeType type)
+      : LPContainer<ElementType>(inner_element, cost, type) {
+    // TODO: should the element really be created here?
+    auto ele = new LPContainer<ElementType>(inner_element, cost, type);
+    inner_elements.push_back(ele);
+  };
+  bool IsIn(LPContainer<ElementType>* a);
   // which containers can be put together: 1. they have the same type; 2. there is a dep between them
-  static bool CanFused(LPContainerDAG* a, LPContainerDAG* b){
+  // static bool CanFused(LPContainerDAG* a, LPContainerDAG* b);

-  };
-  // AddChild
+  // override LPContainer
+  const std::string GetName() {
+    std::string name = "LPContainerDAG{";
+    for (auto ele : inner_elements) {
+      name += ele->GetName();
+      name += "\n";
+    }
+    name += "}";
+    return name;
+  }
+  const int UUID() { return inner_elements[0]->UUID(); }
+  const bool HasValue() { return inner_elements.size() > 0; }
+  const std::vector<ElementType> GetValues() {
+    std::vector<ElementType> values;
+    for (auto ele : inner_elements) {
+      for (auto inst : ele->GetValues()) {
+        values.push_back(inst);
+      }
+    }
+    return values;
+  }
+  // add a child; the child should already carry its deps
+  void AddToDAG(LPContainer<ElementType>* child);
+  const std::vector<LPContainer<ElementType>*> GetInnerElements() const {
+    return inner_elements;
+  }
+  // merge another LPContainerDAG into this one, then destroy the other
+  Status MergeFrom(LPContainerDAG* other);

  private:
-  std::set<LPContainerDAG*> users_;
-  std::set<LPContainerDAG*> operands_;
-  std::vector<LPContainer<ElementType>> inner_elements;
+  std::set<LPContainer<ElementType>*> operands_;
+  std::vector<LPContainer<ElementType>*> inner_elements;
+  // edges between inner_elements
+  std::vector<DAGEdge> edges_;
   CostType cost_;
   CostType startat_;
   NodeType type_;
diff --git a/xla/hlo/experimental/auto_reorder/auto_reorder_test.cc b/xla/hlo/experimental/auto_reorder/auto_reorder_test.cc
index 0abfe8c55faa0..dabfb604601d8 100644
--- a/xla/hlo/experimental/auto_reorder/auto_reorder_test.cc
+++ b/xla/hlo/experimental/auto_reorder/auto_reorder_test.cc
@@ -436,7 +436,7 @@ ENTRY %elementwise {
       insts2cost.push_back(std::make_tuple(ar_done, 1));
       insts_list.push_back(ar_done);
-      edge2cost.push_back(std::make_tuple(ar_done, cost_gen()));
+      edge2cost.push_back(std::make_tuple(ar_done, cost_gen() + 50));
       not_used_insts.insert(ar_done);
     }
   }
@@ -861,9 +861,7 @@ TEST_F(AutoReorderingTest,
ReorderScheduleComputation) { std::unique_ptr latency_estimator; int pointer_size_ = 4; Backend& test_backend = backend(); - const se::DeviceDescription& gpu_device_info = - test_backend.default_stream_executor()->GetDeviceDescription(); - // auto gpu_device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + auto gpu_device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo(); VLOG(2) << "threads_per_block_limit:" << gpu_device_info.threads_per_block_limit() << " threads_per_warp" @@ -914,8 +912,7 @@ TEST_F(AutoReorderingTest, ReorderPass) { EXPECT_TRUE(st.ok()); int pointer_size_ = 4; Backend& test_backend = backend(); - const se::DeviceDescription& gpu_device_info = - test_backend.default_stream_executor()->GetDeviceDescription(); + auto gpu_device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo(); const int64_t scheduler_mem_limit = xla::gpu::GetSchedulerMemoryLimit( hlo_module.get(), gpu_device_info, pointer_size_); SchedulerConfig config = GetSchedulerConfig(scheduler_mem_limit); @@ -954,8 +951,7 @@ TEST_F(AutoReorderingTest, ReorderPassWithDefaultEstimator) { EXPECT_TRUE(st.ok()); int pointer_size_ = 4; Backend& test_backend = backend(); - const se::DeviceDescription& gpu_device_info = - test_backend.default_stream_executor()->GetDeviceDescription(); + auto gpu_device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo(); const int64_t scheduler_mem_limit = xla::gpu::GetSchedulerMemoryLimit( hlo_module.get(), gpu_device_info, pointer_size_); SchedulerConfig config = GetSchedulerConfig(scheduler_mem_limit); @@ -972,8 +968,8 @@ TEST_F(AutoReorderingTest, ReorderPassWithDefaultEstimator) { EXPECT_TRUE(status.ok()); } TEST_F(AutoReorderingTest, ReorderPassWithRandom) { + // GTEST_SKIP() << "Skipping single test"; std::srand(kRandomSeed); - // communication rate from 0.05 to 0.95,step is 0.05 auto hlo_module = CreateNewUnverifiedModule(); auto gpu_latency_estimator = std::make_unique(); SchedulerConfig sched_config = GetDefaultSchedConfig(); @@ -1027,12 +1023,20 @@ TEST_F(AutoReorderingTest, ReorderPassWithRandom) { } // skip this test TEST_F(AutoReorderingTest, ReorderPassDataAnalyse) { - // GTEST_SKIP() << "Skipping single test"; + GTEST_SKIP() << "Skipping single test"; std::srand(kRandomSeed); auto gen = std::mt19937{kRandomSeed}; - int repeat_time = 3; - uint32_t nnodes = 50; - std::vector communication_rates = {0.1,0.15,0.2,0.25,0.3,0.65,0.7,0.75,0.8,0.85}; + int repeat_time = 1; + uint32_t nnodes = 100; + std::vector communication_rates; + // = { + // 0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9 + // }; + for (float current=0.1; current < 0.9; current+=0.05) + { + communication_rates.push_back(current); + } + // communication rate from 0.05 to 0.95,step is 0.05 std::ofstream csv_out("/tmp/test_ret.csv"); csv_out<<"exp_id,nnodes,communication_rate,auto_reorder_cost,post_order_cost,xla_hiding_order_cost,xla_hiding_solve_time,auto_reorder_solve_time"<clone(); auto gpu_latency_estimator3 = gpu_latency_estimator->clone(); // run AutoReorder for compare diff --git a/xla/service/gpu/gpu_hlo_schedule.h b/xla/service/gpu/gpu_hlo_schedule.h index 173e50dc1f794..5737114ee00bf 100644 --- a/xla/service/gpu/gpu_hlo_schedule.h +++ b/xla/service/gpu/gpu_hlo_schedule.h @@ -44,6 +44,7 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/service/buffer_value.h" +#include "xla/hlo/experimental/auto_reorder/auto_reorder.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/gpu/gpu_schedule_postprocessing.h" diff --git a/xla/service/latency_hiding_scheduler_test.cc b/xla/service/latency_hiding_scheduler_test.cc index 4edcc1a682c8b..9f4e15eab066c 100644 --- a/xla/service/latency_hiding_scheduler_test.cc +++ b/xla/service/latency_hiding_scheduler_test.cc @@ -2977,4 +2977,196 @@ ENTRY main { // not create a failure of scheduling by the async done checks. EXPECT_TRUE(RunScheduler(hlo_module.get(), sched_config).ok()); } +TEST_F(LatencyHidingSchedulerTest, RunRandomComputation) { + absl::string_view hlo_string = R"( + HloModule ReorderPassWithRandom, is_scheduled=true, entry_computation_layout={(f32[4,256,256]{2,1,0}, f32[4,256,256]{2,1,0})->(f32[4,256,256]{2,1,0}, f32[4,256,256]{2,1,0}, f32[4,256,256]{2,1,0}, f32[4,256,256]{2,1,0})} + +%add (x: f32[4,256,256], y: f32[4,256,256]) -> f32[4,256,256] { + %y = f32[4,256,256] parameter(1) + %x = f32[4,256,256] parameter(0) + ROOT %add = f32[4,256,256] add(f32[4,256,256] %x, f32[4,256,256] %y) +} + +%add.1 (x.1: f32[4,256,256], y.1: f32[4,256,256]) -> f32[4,256,256] { + %y.1 = f32[4,256,256] parameter(1) + %x.1 = f32[4,256,256] parameter(0) + ROOT %add.1 = f32[4,256,256] add(f32[4,256,256] %x.1, f32[4,256,256] %y.1) +} + +%add.2 (x.2: f32[4,256,256], y.2: f32[4,256,256]) -> f32[4,256,256] { + %y.2 = f32[4,256,256] parameter(1) + %x.2 = f32[4,256,256] parameter(0) + ROOT %add.2 = f32[4,256,256] add(f32[4,256,256] %x.2, f32[4,256,256] %y.2) +} + +%add.3 (x.3: f32[4,256,256], y.3: f32[4,256,256]) -> f32[4,256,256] { + %y.3 = f32[4,256,256] parameter(1) + %x.3 = f32[4,256,256] parameter(0) + ROOT %add.3 = f32[4,256,256] add(f32[4,256,256] %x.3, f32[4,256,256] %y.3) +} + +%add.4 (x.4: f32[4,256,256], y.4: f32[4,256,256]) -> f32[4,256,256] { + %y.4 = f32[4,256,256] parameter(1) + %x.4 = f32[4,256,256] parameter(0) + ROOT %add.4 = f32[4,256,256] add(f32[4,256,256] %x.4, f32[4,256,256] %y.4) +} + +%add.5 (x.5: f32[4,256,256], y.5: f32[4,256,256]) -> f32[4,256,256] { + %y.5 = f32[4,256,256] parameter(1) + %x.5 = f32[4,256,256] parameter(0) + ROOT %add.5 = f32[4,256,256] add(f32[4,256,256] %x.5, f32[4,256,256] %y.5) +} + +%add.6 (x.6: f32[4,256,256], y.6: f32[4,256,256]) -> f32[4,256,256] { + %y.6 = f32[4,256,256] parameter(1) + %x.6 = f32[4,256,256] parameter(0) + ROOT %add.6 = f32[4,256,256] add(f32[4,256,256] %x.6, f32[4,256,256] %y.6) +} + +%add.7 (x.7: f32[4,256,256], y.7: f32[4,256,256]) -> f32[4,256,256] { + %y.7 = f32[4,256,256] parameter(1) + %x.7 = f32[4,256,256] parameter(0) + ROOT %add.7 = f32[4,256,256] add(f32[4,256,256] %x.7, f32[4,256,256] %y.7) +} + +%add.8 (x.8: f32[4,256,256], y.8: f32[4,256,256]) -> f32[4,256,256] { + %y.8 = f32[4,256,256] parameter(1) + %x.8 = f32[4,256,256] parameter(0) + ROOT %add.8 = f32[4,256,256] add(f32[4,256,256] %x.8, f32[4,256,256] %y.8) +} + +%add.9 (x.9: f32[4,256,256], y.9: f32[4,256,256]) -> f32[4,256,256] { + %y.9 = f32[4,256,256] parameter(1) + %x.9 = f32[4,256,256] parameter(0) + ROOT %add.9 = f32[4,256,256] add(f32[4,256,256] %x.9, f32[4,256,256] %y.9) +} + +%add.10 (x.10: f32[4,256,256], y.10: f32[4,256,256]) -> f32[4,256,256] { + %y.10 = f32[4,256,256] parameter(1) + %x.10 = f32[4,256,256] parameter(0) + ROOT %add.10 = f32[4,256,256] add(f32[4,256,256] %x.10, f32[4,256,256] %y.10) +} + +%add.11 
(x.11: f32[4,256,256], y.11: f32[4,256,256]) -> f32[4,256,256] { + %y.11 = f32[4,256,256] parameter(1) + %x.11 = f32[4,256,256] parameter(0) + ROOT %add.11 = f32[4,256,256] add(f32[4,256,256] %x.11, f32[4,256,256] %y.11) +} + +%add.12 (x.12: f32[4,256,256], y.12: f32[4,256,256]) -> f32[4,256,256] { + %y.12 = f32[4,256,256] parameter(1) + %x.12 = f32[4,256,256] parameter(0) + ROOT %add.12 = f32[4,256,256] add(f32[4,256,256] %x.12, f32[4,256,256] %y.12) +} + +%add.13 (x.13: f32[4,256,256], y.13: f32[4,256,256]) -> f32[4,256,256] { + %y.13 = f32[4,256,256] parameter(1) + %x.13 = f32[4,256,256] parameter(0) + ROOT %add.13 = f32[4,256,256] add(f32[4,256,256] %x.13, f32[4,256,256] %y.13) +} + +%add.14 (x.14: f32[4,256,256], y.14: f32[4,256,256]) -> f32[4,256,256] { + %y.14 = f32[4,256,256] parameter(1) + %x.14 = f32[4,256,256] parameter(0) + ROOT %add.14 = f32[4,256,256] add(f32[4,256,256] %x.14, f32[4,256,256] %y.14) +} + +%add.15 (x.15: f32[4,256,256], y.15: f32[4,256,256]) -> f32[4,256,256] { + %y.15 = f32[4,256,256] parameter(1) + %x.15 = f32[4,256,256] parameter(0) + ROOT %add.15 = f32[4,256,256] add(f32[4,256,256] %x.15, f32[4,256,256] %y.15) +} + +ENTRY %ReorderPassWithRandom (p0: f32[4,256,256], p1: f32[4,256,256]) -> (f32[4,256,256], f32[4,256,256], f32[4,256,256], f32[4,256,256]) { + %p1 = f32[4,256,256]{2,1,0} parameter(1) + %p0 = f32[4,256,256]{2,1,0} parameter(0) + %add.16 = f32[4,256,256]{2,1,0} add(f32[4,256,256]{2,1,0} %p0, f32[4,256,256]{2,1,0} %p1) + %add.17 = f32[4,256,256]{2,1,0} add(f32[4,256,256]{2,1,0} %p1, f32[4,256,256]{2,1,0} %add.16) + %add.19 = f32[4,256,256]{2,1,0} add(f32[4,256,256]{2,1,0} %add.16, f32[4,256,256]{2,1,0} %add.17) + %add.18 = f32[4,256,256]{2,1,0} add(f32[4,256,256]{2,1,0} %add.16, f32[4,256,256]{2,1,0} %add.17) + %add.21 = f32[4,256,256]{2,1,0} add(f32[4,256,256]{2,1,0} %add.17, f32[4,256,256]{2,1,0} %add.19) + %all-reduce-start.3 = f32[4,256,256]{2,1,0} all-reduce-start(f32[4,256,256]{2,1,0} %add.21), replica_groups={{0,1}}, to_apply=%add.3 + %all-reduce-done.3 = f32[4,256,256]{2,1,0} all-reduce-done(f32[4,256,256]{2,1,0} %all-reduce-start.3) + %all-reduce-start.1 = f32[4,256,256]{2,1,0} all-reduce-start(f32[4,256,256]{2,1,0} %add.18), replica_groups={{0,1}}, to_apply=%add.1 + %all-reduce-done.1 = f32[4,256,256]{2,1,0} all-reduce-done(f32[4,256,256]{2,1,0} %all-reduce-start.1) + %all-reduce-start = f32[4,256,256]{2,1,0} all-reduce-start(f32[4,256,256]{2,1,0} %add.19), replica_groups={{0,1}}, to_apply=%add + %all-reduce-done = f32[4,256,256]{2,1,0} all-reduce-done(f32[4,256,256]{2,1,0} %all-reduce-start) + %add.20 = f32[4,256,256]{2,1,0} add(f32[4,256,256]{2,1,0} %all-reduce-done, f32[4,256,256]{2,1,0} %all-reduce-done.1) + %add.23 = f32[4,256,256]{2,1,0} add(f32[4,256,256]{2,1,0} %add.17, f32[4,256,256]{2,1,0} %add.20) + %add.22 = f32[4,256,256]{2,1,0} add(f32[4,256,256]{2,1,0} %add.20, f32[4,256,256]{2,1,0} %add.21) + %add.24 = f32[4,256,256]{2,1,0} add(f32[4,256,256]{2,1,0} %add.22, f32[4,256,256]{2,1,0} %add.23) + %all-reduce-start.6 = f32[4,256,256]{2,1,0} all-reduce-start(f32[4,256,256]{2,1,0} %add.22), replica_groups={{0,1}}, to_apply=%add.6 + %all-reduce-done.6 = f32[4,256,256]{2,1,0} all-reduce-done(f32[4,256,256]{2,1,0} %all-reduce-start.6) + %all-reduce-start.14 = f32[4,256,256]{2,1,0} all-reduce-start(f32[4,256,256]{2,1,0} %p0), replica_groups={{0,1}}, to_apply=%add.14 + %all-reduce-done.14 = f32[4,256,256]{2,1,0} all-reduce-done(f32[4,256,256]{2,1,0} %all-reduce-start.14) + %all-reduce-start.2 = f32[4,256,256]{2,1,0} 
all-reduce-start(f32[4,256,256]{2,1,0} %p0), replica_groups={{0,1}}, to_apply=%add.2 + %all-reduce-done.2 = f32[4,256,256]{2,1,0} all-reduce-done(f32[4,256,256]{2,1,0} %all-reduce-start.2) + %add.25 = f32[4,256,256]{2,1,0} add(f32[4,256,256]{2,1,0} %all-reduce-done.3, f32[4,256,256]{2,1,0} %all-reduce-done.2) + %add.31 = f32[4,256,256]{2,1,0} add(f32[4,256,256]{2,1,0} %p1, f32[4,256,256]{2,1,0} %add.25) + %add.26 = f32[4,256,256]{2,1,0} add(f32[4,256,256]{2,1,0} %add.25, f32[4,256,256]{2,1,0} %add.24) + %all-reduce-start.7 = f32[4,256,256]{2,1,0} all-reduce-start(f32[4,256,256]{2,1,0} %add.31), replica_groups={{0,1}}, to_apply=%add.7 + %all-reduce-done.7 = f32[4,256,256]{2,1,0} all-reduce-done(f32[4,256,256]{2,1,0} %all-reduce-start.7) + %add.29 = f32[4,256,256]{2,1,0} add(f32[4,256,256]{2,1,0} %p0, f32[4,256,256]{2,1,0} %all-reduce-done.2) + %all-reduce-start.5 = f32[4,256,256]{2,1,0} all-reduce-start(f32[4,256,256]{2,1,0} %add.29), replica_groups={{0,1}}, to_apply=%add.5 + %all-reduce-done.5 = f32[4,256,256]{2,1,0} all-reduce-done(f32[4,256,256]{2,1,0} %all-reduce-start.5) + %add.27 = f32[4,256,256]{2,1,0} add(f32[4,256,256]{2,1,0} %all-reduce-done.1, f32[4,256,256]{2,1,0} %all-reduce-done.2) + %add.28 = f32[4,256,256]{2,1,0} add(f32[4,256,256]{2,1,0} %add.26, f32[4,256,256]{2,1,0} %add.27) + %add.43 = f32[4,256,256]{2,1,0} add(f32[4,256,256]{2,1,0} %add.18, f32[4,256,256]{2,1,0} %add.28) + %all-reduce-start.4 = f32[4,256,256]{2,1,0} all-reduce-start(f32[4,256,256]{2,1,0} %add.28), replica_groups={{0,1}}, to_apply=%add.4 + %all-reduce-done.4 = f32[4,256,256]{2,1,0} all-reduce-done(f32[4,256,256]{2,1,0} %all-reduce-start.4) + %add.30 = f32[4,256,256]{2,1,0} add(f32[4,256,256]{2,1,0} %all-reduce-done.4, f32[4,256,256]{2,1,0} %all-reduce-done.5) + %add.32 = f32[4,256,256]{2,1,0} add(f32[4,256,256]{2,1,0} %add.30, f32[4,256,256]{2,1,0} %add.31) + %add.33 = f32[4,256,256]{2,1,0} add(f32[4,256,256]{2,1,0} %p1, f32[4,256,256]{2,1,0} %all-reduce-done.4) + %add.34 = f32[4,256,256]{2,1,0} add(f32[4,256,256]{2,1,0} %add.32, f32[4,256,256]{2,1,0} %add.33) + %add.35 = f32[4,256,256]{2,1,0} add(f32[4,256,256]{2,1,0} %all-reduce-done.7, f32[4,256,256]{2,1,0} %add.34) + %add.41 = f32[4,256,256]{2,1,0} add(f32[4,256,256]{2,1,0} %add.28, f32[4,256,256]{2,1,0} %add.35) + %add.36 = f32[4,256,256]{2,1,0} add(f32[4,256,256]{2,1,0} %all-reduce-done.6, f32[4,256,256]{2,1,0} %add.35) + %add.37 = f32[4,256,256]{2,1,0} add(f32[4,256,256]{2,1,0} %all-reduce-done.1, f32[4,256,256]{2,1,0} %add.36) + %add.38 = f32[4,256,256]{2,1,0} add(f32[4,256,256]{2,1,0} %add.36, f32[4,256,256]{2,1,0} %add.37) + %add.39 = f32[4,256,256]{2,1,0} add(f32[4,256,256]{2,1,0} %all-reduce-done.2, f32[4,256,256]{2,1,0} %add.38) + %all-reduce-start.8 = f32[4,256,256]{2,1,0} all-reduce-start(f32[4,256,256]{2,1,0} %add.38), replica_groups={{0,1}}, to_apply=%add.8 + %all-reduce-done.8 = f32[4,256,256]{2,1,0} all-reduce-done(f32[4,256,256]{2,1,0} %all-reduce-start.8) + %all-reduce-start.9 = f32[4,256,256]{2,1,0} all-reduce-start(f32[4,256,256]{2,1,0} %add.39), replica_groups={{0,1}}, to_apply=%add.9 + %all-reduce-done.9 = f32[4,256,256]{2,1,0} all-reduce-done(f32[4,256,256]{2,1,0} %all-reduce-start.9) + %add.40 = f32[4,256,256]{2,1,0} add(f32[4,256,256]{2,1,0} %all-reduce-done.8, f32[4,256,256]{2,1,0} %all-reduce-done.9) + %add.51 = f32[4,256,256]{2,1,0} add(f32[4,256,256]{2,1,0} %add.21, f32[4,256,256]{2,1,0} %add.40) + %add.47 = f32[4,256,256]{2,1,0} add(f32[4,256,256]{2,1,0} %add.19, f32[4,256,256]{2,1,0} %add.40) + %add.42 = 
f32[4,256,256]{2,1,0} add(f32[4,256,256]{2,1,0} %add.40, f32[4,256,256]{2,1,0} %add.41) + %add.44 = f32[4,256,256]{2,1,0} add(f32[4,256,256]{2,1,0} %add.42, f32[4,256,256]{2,1,0} %add.43) + %add.57 = f32[4,256,256]{2,1,0} add(f32[4,256,256]{2,1,0} %all-reduce-done.4, f32[4,256,256]{2,1,0} %add.44) + %all-reduce-start.15 = f32[4,256,256]{2,1,0} all-reduce-start(f32[4,256,256]{2,1,0} %add.41), replica_groups={{0,1}}, to_apply=%add.15 + %all-reduce-done.15 = f32[4,256,256]{2,1,0} all-reduce-done(f32[4,256,256]{2,1,0} %all-reduce-start.15) + %all-reduce-start.10 = f32[4,256,256]{2,1,0} all-reduce-start(f32[4,256,256]{2,1,0} %add.34), replica_groups={{0,1}}, to_apply=%add.10 + %all-reduce-done.10 = f32[4,256,256]{2,1,0} all-reduce-done(f32[4,256,256]{2,1,0} %all-reduce-start.10) + %add.49 = f32[4,256,256]{2,1,0} add(f32[4,256,256]{2,1,0} %add.24, f32[4,256,256]{2,1,0} %all-reduce-done.10) + %all-reduce-start.13 = f32[4,256,256]{2,1,0} all-reduce-start(f32[4,256,256]{2,1,0} %add.49), replica_groups={{0,1}}, to_apply=%add.13 + %all-reduce-done.13 = f32[4,256,256]{2,1,0} all-reduce-done(f32[4,256,256]{2,1,0} %all-reduce-start.13) + %all-reduce-start.11 = f32[4,256,256]{2,1,0} all-reduce-start(f32[4,256,256]{2,1,0} %add.44), replica_groups={{0,1}}, to_apply=%add.11 + %all-reduce-done.11 = f32[4,256,256]{2,1,0} all-reduce-done(f32[4,256,256]{2,1,0} %all-reduce-start.11) + %add.45 = f32[4,256,256]{2,1,0} add(f32[4,256,256]{2,1,0} %add.44, f32[4,256,256]{2,1,0} %all-reduce-done.11) + %add.46 = f32[4,256,256]{2,1,0} add(f32[4,256,256]{2,1,0} %all-reduce-done.10, f32[4,256,256]{2,1,0} %add.45) + %add.48 = f32[4,256,256]{2,1,0} add(f32[4,256,256]{2,1,0} %add.46, f32[4,256,256]{2,1,0} %add.47) + %all-reduce-start.12 = f32[4,256,256]{2,1,0} all-reduce-start(f32[4,256,256]{2,1,0} %add.48), replica_groups={{0,1}}, to_apply=%add.12 + %all-reduce-done.12 = f32[4,256,256]{2,1,0} all-reduce-done(f32[4,256,256]{2,1,0} %all-reduce-start.12) + %add.50 = f32[4,256,256]{2,1,0} add(f32[4,256,256]{2,1,0} %all-reduce-done.12, f32[4,256,256]{2,1,0} %all-reduce-done.13) + %add.52 = f32[4,256,256]{2,1,0} add(f32[4,256,256]{2,1,0} %add.51, f32[4,256,256]{2,1,0} %add.50) + %add.53 = f32[4,256,256]{2,1,0} add(f32[4,256,256]{2,1,0} %add.45, f32[4,256,256]{2,1,0} %add.50) + %add.54 = f32[4,256,256]{2,1,0} add(f32[4,256,256]{2,1,0} %add.52, f32[4,256,256]{2,1,0} %add.53) + %add.55 = f32[4,256,256]{2,1,0} add(f32[4,256,256]{2,1,0} %add.54, f32[4,256,256]{2,1,0} %all-reduce-done.15) + %add.56 = f32[4,256,256]{2,1,0} add(f32[4,256,256]{2,1,0} %all-reduce-done.14, f32[4,256,256]{2,1,0} %add.55) + %add.58 = f32[4,256,256]{2,1,0} add(f32[4,256,256]{2,1,0} %add.56, f32[4,256,256]{2,1,0} %add.57) + ROOT %tuple = (f32[4,256,256]{2,1,0}, f32[4,256,256]{2,1,0}, f32[4,256,256]{2,1,0}, f32[4,256,256]{2,1,0}) tuple(f32[4,256,256]{2,1,0} %add.46, f32[4,256,256]{2,1,0} %add.45, f32[4,256,256]{2,1,0} %all-reduce-done.11, f32[4,256,256]{2,1,0} %add.58) +} + +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloText(hlo_string)); + HloSchedule& module_schedule = hlo_module->schedule(); + EXPECT_TRUE(hlo_module->has_entry_computation()); + HloComputation* entry_computation = hlo_module->entry_computation(); + std::vector original_instruction_sequence = + module_schedule.sequence(entry_computation).instructions(); + auto sched_config = GetDefaultSchedConfig(); + EXPECT_TRUE(RunScheduler(hlo_module.get(), sched_config).ok()); +} } // namespace xla diff --git a/xla/xla.proto b/xla/xla.proto index 0618c243fe69c..1ece9d2fdf7f3 100644 --- 
a/xla/xla.proto +++ b/xla/xla.proto @@ -564,7 +564,7 @@ message DebugOptions { bool xla_gpu_enable_latency_hiding_scheduler = 186; bool xla_gpu_enable_highest_priority_async_stream = 216; bool xla_gpu_enable_analytical_latency_estimator = 255; - bool xla_gpu_enable_linear_program_scheduler = 258; + bool xla_gpu_enable_linear_program_scheduler = 266; bool xla_gpu_lhs_enable_gpu_async_tracker = 204; string xla_gpu_pgle_profile_file_or_directory_path = 210; @@ -675,7 +675,7 @@ message DebugOptions { // Threshold to enable windowed einsum (collective matmul) in MB. int64 xla_gpu_threshold_for_windowed_einsum_mib = 265; - // Next id: 266 + // Next id: 267 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend.
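
Reviewer note: the solver hunks above only swap `reorder::get_horizon()` and `reorder::get_cpu_number()` into the CP-SAT parameters; the surrounding model construction is not part of this patch. As context, the following is a minimal, self-contained OR-Tools CP-SAT sketch that uses the same two knobs (a horizon bounding every start/end variable, and a solver worker count). The two tasks, their costs, and all variable names are illustrative and are not taken from LinearProgramScheduler.

// Hypothetical two-task model: B depends on A, and both share one channel.
#include <cstdint>
#include <cstdio>
#include <thread>

#include "ortools/sat/cp_model.h"
#include "ortools/sat/cp_model_solver.h"
#include "ortools/sat/sat_parameters.pb.h"

int main() {
  using namespace operations_research::sat;
  CpModelBuilder cp_model;

  const int64_t cost_a = 10, cost_b = 4;  // per-node costs (made up)
  const int64_t horizon =
      static_cast<int64_t>((cost_a + cost_b) * 1.2);  // cf. reorder::get_horizon()

  IntVar start_a = cp_model.NewIntVar({0, horizon});
  IntVar end_a = cp_model.NewIntVar({0, horizon});
  IntervalVar task_a = cp_model.NewIntervalVar(start_a, cost_a, end_a);

  IntVar start_b = cp_model.NewIntVar({0, horizon});
  IntVar end_b = cp_model.NewIntVar({0, horizon});
  IntervalVar task_b = cp_model.NewIntervalVar(start_b, cost_b, end_b);

  cp_model.AddGreaterOrEqual(start_b, end_a);  // dependency edge A -> B
  cp_model.AddNoOverlap({task_a, task_b});     // one channel: tasks may not overlap
  cp_model.Minimize(end_b);                    // minimize the makespan

  SatParameters parameters;
  parameters.set_max_time_in_seconds(60.0);  // cf. reorder::ksolveTimeout
  parameters.set_num_search_workers(
      std::thread::hardware_concurrency());  // cf. reorder::get_cpu_number()
  const CpSolverResponse response =
      SolveWithParameters(cp_model.Build(), parameters);
  if (response.status() == CpSolverStatus::OPTIMAL) {
    std::printf("A starts at %lld, B starts at %lld\n",
                static_cast<long long>(SolutionIntegerValue(response, start_a)),
                static_cast<long long>(SolutionIntegerValue(response, start_b)));
  }
  return 0;
}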
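Reviewer note: the core idea of LPContainerDAG::MergeFrom is that when DAG B is merged into DAG A, any dependency edge from one of A's elements to one of B's elements stops being an external operand and becomes an inner edge of the merged DAG. The toy sketch below illustrates only that idea with simplified, hypothetical types (Node, Edge, Dag); it is not the LPContainerDAG API and omits the cost bookkeeping done above.

#include <cstdio>
#include <unordered_map>
#include <utility>
#include <vector>

struct Node {
  int uuid;
  std::vector<std::pair<Node*, int>> deps;  // (dependency, edge cost)
};

struct Edge { Node* from; Node* to; int cost; };

struct Dag {
  std::vector<Node*> nodes;
  std::vector<Edge> inner_edges;

  void MergeFrom(const Dag& other) {
    // uuid of a node already in this DAG -> that node
    std::unordered_map<int, Node*> mine;
    for (Node* n : nodes) mine[n->uuid] = n;
    for (Node* child : other.nodes) {
      // Any dep of `child` that lives in this DAG becomes an inner edge.
      for (auto& [dep, cost] : child->deps) {
        auto it = mine.find(dep->uuid);
        if (it != mine.end()) inner_edges.push_back({it->second, child, cost});
      }
      nodes.push_back(child);
      mine[child->uuid] = child;
    }
  }
};

int main() {
  Node a{1, {}}, b{2, {{&a, 5}}};  // b depends on a with edge cost 5
  Dag da{{&a}, {}}, db{{&b}, {}};
  da.MergeFrom(db);
  std::printf("nodes=%zu inner_edges=%zu\n", da.nodes.size(), da.inner_edges.size());
  // prints: nodes=2 inner_edges=1
}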