Skip to content

Commit

Permalink
add dag
Browse files Browse the repository at this point in the history
  • Loading branch information
zjjott committed Mar 7, 2024
1 parent 91ac2fe commit c0d7559
Show file tree
Hide file tree
Showing 7 changed files with 356 additions and 40 deletions.
6 changes: 4 additions & 2 deletions xla/hlo/experimental/auto_reorder/auto_reorder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,10 @@ AutoReorderPass::ScheduleComputation(HloComputation* computation) {
std::vector<HloInstruction*> new_schedule;
auto sorted_nodes = solver_->GetSortedNodes();
for (auto node : sorted_nodes) {
new_schedule.push_back(
const_cast<xla::HloInstruction*>(node->GetValue()));
auto insts = node->GetValues();
for (auto inst : insts) {
new_schedule.push_back(const_cast<xla::HloInstruction*>(inst));
}
}
return new_schedule;
}
Expand Down
78 changes: 76 additions & 2 deletions xla/hlo/experimental/auto_reorder/auto_reorder_solver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@
#ifndef LPSchedulerFunc(return_type)
#define LPSchedulerFunc(return_type) template <typename ContainerType,typename ElementType> return_type LinearProgramScheduler<ContainerType,ElementType>
#endif

#ifndef LPContainerDAGFunc(return_type)
#define LPContainerDAGFunc(return_type) template <typename ElementType> return_type LPContainerDAG<ElementType>
#endif

namespace xla {
using IntVar = operations_research::sat::IntVar;
using CpModelBuilder = operations_research::sat::CpModelBuilder;
Expand Down Expand Up @@ -100,7 +105,7 @@ LPSchedulerFunc(tsl::Status)::Solve() {
max_execution_time += cost;
}
}
SetHorizon(max_execution_time * reorder::kChannelNumber);
SetHorizon(reorder::get_horizon(max_execution_time));

for (auto node : nodes_) {
VLOG(3) << "Add to scheduler" << node->GetName();
Expand Down Expand Up @@ -141,7 +146,7 @@ LPSchedulerFunc(tsl::Status)::Solve() {
parameters.set_log_to_stdout(true);
parameters.set_log_search_progress(true);
}
parameters.set_num_search_workers(8);
parameters.set_num_search_workers(reorder::get_cpu_number());
const operations_research::sat::CpSolverResponse response =
operations_research::sat::SolveWithParameters(cp_model_.Build(),
parameters);
Expand Down Expand Up @@ -294,6 +299,75 @@ LPSchedulerFunc(std::vector<ContainerType*>)::GetSortedNodes() {
[this](ContainerType* a, ContainerType* b) { return a->GetStart() < b->GetStart(); });
return sorted_nodes;
}

LPContainerDAGFunc(bool)::IsIn(LPContainer<ElementType>* a) {
  // True when `a` is currently tracked as an external operand of this DAG.
  return operands_.count(a) > 0;
}
// Absorb `child` as an inner element of this DAG.
// After the call, `child` is no longer an external operand, and all of
// child's own dependencies become operands of the merged DAG.
LPContainerDAGFunc(void)::AddToDAG(LPContainer<ElementType>* child) {
  inner_elements.push_back(child);
  // A node that now lives inside the DAG must not also be an operand of it.
  if (IsIn(child)) {
    operands_.erase(child);
  }
  // Inherit child's dependencies as operands of this DAG.
  // TODO: decide whether the edge cost (std::get<1>(dep_pair)) also needs to
  // be stored here; it is currently dropped (was an unused local before).
  for (auto& dep_pair : child->GetDeps()) {
    operands_.insert(std::get<0>(dep_pair));
  }
}
// Merge `other` into this DAG: edges from this DAG's inner elements to
// other's inner elements are converted into inner edges, and other's
// elements are absorbed via AddToDAG. Returns an error Status when the two
// DAGs are not connected (no dependency between them).
LPContainerDAGFunc(Status)::MergeFrom(LPContainerDAG<ElementType>* other) {
  /*
  step 1: this inner_elements must have dep to other's inner_elements, so that
  links to other's inner_elements change to inner edges
  */

  // Maintain this LPContainerDAG inner_elements's deps so that inner edges
  // can be created after the merge.
  // Layout: {dep uuid: [<element1, cost>, <element2, cost>]} where each
  // element depends on the dep with the given edge cost.
  std::unordered_map<int,
                     std::vector<std::tuple<LPContainer<ElementType>*, CostType>>>
      dep_operands2element;

  for (LPContainer<ElementType>* element : GetInnerElements()) {
    // From dep to element there is an outer edge; it may become an inner
    // edge once the dep is merged into this DAG.
    for (auto& dep_pair : element->GetDeps()) {
      auto dep = std::get<0>(dep_pair);
      auto cost = std::get<1>(dep_pair);
      dep_operands2element[dep->UUID()].push_back(
          std::make_tuple(element, cost));
    }
  }
  // Walk other's elements and convert their outer edges into inner edges.
  for (auto child : other->GetInnerElements()) {
    // Each child must be a dependency of at least one of our inner elements.
    // NOTE(review): the original tested `== end()`, which contradicts its own
    // error message and rejects exactly the connected case; inverted to
    // match the stated intent.
    TF_RET_CHECK(dep_operands2element.find(child->UUID()) !=
                 dep_operands2element.end())
        << "child is not in dep_operands2element";
    // Create an inner edge for every element that depends on this child.
    // NOTE(review): the original re-looked-up dep_operands2element with the
    // element's UUID (a second indirection) and shadowed `cost`; each
    // (element, cost) pair already describes the edge to record.
    for (auto& element_pair : dep_operands2element[child->UUID()]) {
      auto element = std::get<0>(element_pair);
      auto cost = std::get<1>(element_pair);
      DAGEdge edge;
      edge.from = element;
      edge.to = child;
      edge.cost = cost;
      edges_.push_back(edge);
    }

    AddToDAG(child);
  }
  // The original fell off the end of a Status-returning function (UB).
  return OkStatus();
}
template class LPContainer<const HloInstruction*>;
template class LinearProgramScheduler<LPContainer<const HloInstruction*>, const HloInstruction*>;


template class LPContainerDAG<const HloInstruction*>;
// template class LinearProgramScheduler<LPContainerDAG<const HloInstruction*>, const HloInstruction*>;

} // namespace xla
83 changes: 62 additions & 21 deletions xla/hlo/experimental/auto_reorder/auto_reorder_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#include <tuple>
#include <unordered_map>
#include <set>

#include <thread>
#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_module.h"
Expand All @@ -17,10 +17,19 @@ namespace xla {
using CpModelBuilder = operations_research::sat::CpModelBuilder;
using IntervalVar = operations_research::sat::IntervalVar;
namespace reorder{
// Tunables for the CP-SAT based linear-program scheduler.
// All entities are `inline` (or templates of) because this is a header:
// plain definitions here would violate the ODR when the header is included
// from more than one translation unit.

// Overall solver timeout, in seconds.
inline constexpr uint32_t ksolveTimeout = 60;  // 60s

// Number of parallel communication channels assumed by the scheduler.
inline constexpr int kChannelNumber = 2;

// Horizon (latest allowed finish time) handed to the solver: the summed
// cost scaled by 20% slack so the model stays feasible.
inline int get_horizon(int max_time) {
  return static_cast<int>(max_time * 1.2);
}

// NOTE(review): verbose solver logging is enabled here; consider defaulting
// to false before shipping.
inline bool solve_debug = true;

// Number of search workers for the CP-SAT solver: one per hardware thread.
// NOTE: std::thread::hardware_concurrency() may return 0 when the value
// cannot be determined — TODO confirm the solver accepts 0 (auto).
inline int get_cpu_number() {
  return std::thread::hardware_concurrency();
}
}
enum class NodeType { kCompute = 0, kCommunication = 1 };

Expand Down Expand Up @@ -59,7 +68,7 @@ class LPNode{
template <typename ElementType>
class LPContainer{
public:

//create a LPContainer with inner_element, cost and type
LPContainer(ElementType inner_element, CostType cost, NodeType type)
: inner_element_(inner_element), cost_(cost), type_(type) {
uuid_ = reinterpret_cast<uintptr_t>(this);
Expand All @@ -71,15 +80,20 @@ class LPContainer{
CostType GetCost() const { return cost_; }
void SetStart(CostType start) { startat_ = start; }
CostType GetStart() { return startat_; }
// Get the type of the container: compute or communication
bool IsComputation() const { return type_ == NodeType::kCompute; }
bool IsCommunication() const { return type_ == NodeType::kCommunication; }
NodeType GetType() const { return type_; }
bool HasValue() const { return inner_element_ != nullptr; }
ElementType GetValue() const { return inner_element_; }

const bool HasValue() { return inner_element_ != nullptr; }
const std::vector<ElementType> GetValues() { return std::vector<ElementType>{inner_element_}; }
// Add a dep of this container, cost is the cost of the edge; this Container will be executed after dep
void AddDep(LPContainer* dep, CostType cost);
// Get all deps of the container
const std::vector<std::tuple<LPContainer*, CostType>> GetDeps() const {
return deps_;
}
//when a container is frozen, it can not be add deps
void Freeze() { frozen_ = true; }

private:
Expand All @@ -97,32 +111,59 @@ class LPContainer{
// LPContainerDAG is a graph of container, it can be used to store the DAG of container
// be used as a atomic unit of LPContainer
template <typename ElementType>
class LPContainerDAG{
class LPContainerDAG: public LPContainer<ElementType>{
//we can use InstructionDAG to get memory effect order
public:
// maintain a DAG of inner elements
struct DAGEdge{
LPContainerDAG* from;
LPContainerDAG* to;
LPContainer<ElementType>* from;
LPContainer<ElementType>* to;
CostType cost;
};
//create a LPContainerDAG with
LPContainerDAG(ElementType inner_element, CostType cost, NodeType type): cost_(cost), type_(type){
inner_elements.push_back(LPContainer<ElementType>(inner_element, cost, type));
};
bool IsIn(LPContainerDAG<ElementType> *a){
return users_.find(a) != users_.end();
//create a LPContainerDAG with one element
LPContainerDAG(ElementType inner_element, CostType cost, NodeType type): LPContainer<ElementType>(inner_element,cost,type){
//TODO: there should not create element?
auto ele = new LPContainer<ElementType>(inner_element, cost, type);
inner_elements.push_back(ele);
};
bool IsIn(LPContainer<ElementType>* a);
//which container can be put together:1. they have the same type 2. they have dep between them
static bool CanFused(LPContainerDAG<ElementType>* a, LPContainerDAG<ElementType>* b){
// static bool CanFused(LPContainerDAG<ElementType>* a, LPContainerDAG<ElementType>* b);

};
// AddChild
//override LPContainer
const std::string GetName(){
std::string name = "LPContainerDAG{";
for(auto ele: inner_elements){
name += ele->GetName();
name+="\n";
}
name+="}";
return name;
}
const int UUID() { return inner_elements[0]->UUID(); }
const bool HasValue() { return inner_elements.size()>0;}
const std::vector<ElementType> GetValues() {
std::vector<ElementType> values;
for(auto ele: inner_elements){
for(auto inst:ele->GetValues()){
values.push_back(inst);
}
}
return values;
}
// AddChild, child should maintain the deps before
void AddToDAG(LPContainer<ElementType>* child);
const std::vector<LPContainer<ElementType>*> GetInnerElements() const{
return inner_elements;
}
//merge other LPContainerDAG to this LPContainerDAG,then destroy other LPContainerDAG
Status MergeFrom(LPContainerDAG<ElementType>* other);
private:

std::set<ElementType> users_;
std::set<ElementType> operands_;
std::vector<LPContainer<ElementType>> inner_elements;
std::set<LPContainer<ElementType>*> operands_;
std::vector<LPContainer<ElementType>*> inner_elements;
//maintain edges between inner_elements
std::vector<DAGEdge> edges_;
CostType cost_;
CostType startat_;
NodeType type_;
Expand Down
32 changes: 19 additions & 13 deletions xla/hlo/experimental/auto_reorder/auto_reorder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ ENTRY %elementwise {
insts2cost.push_back(std::make_tuple(ar_done, 1));

insts_list.push_back(ar_done);
edge2cost.push_back(std::make_tuple(ar_done, cost_gen()));
edge2cost.push_back(std::make_tuple(ar_done, cost_gen()+50));
not_used_insts.insert(ar_done);
}
}
Expand Down Expand Up @@ -861,9 +861,7 @@ TEST_F(AutoReorderingTest, ReorderScheduleComputation) {
std::unique_ptr<LatencyEstimator> latency_estimator;
int pointer_size_ = 4;
Backend& test_backend = backend();
const se::DeviceDescription& gpu_device_info =
test_backend.default_stream_executor()->GetDeviceDescription();
// auto gpu_device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo();
auto gpu_device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo();

VLOG(2) << "threads_per_block_limit:"
<< gpu_device_info.threads_per_block_limit() << " threads_per_warp"
Expand Down Expand Up @@ -914,8 +912,7 @@ TEST_F(AutoReorderingTest, ReorderPass) {
EXPECT_TRUE(st.ok());
int pointer_size_ = 4;
Backend& test_backend = backend();
const se::DeviceDescription& gpu_device_info =
test_backend.default_stream_executor()->GetDeviceDescription();
auto gpu_device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo();
const int64_t scheduler_mem_limit = xla::gpu::GetSchedulerMemoryLimit(
hlo_module.get(), gpu_device_info, pointer_size_);
SchedulerConfig config = GetSchedulerConfig(scheduler_mem_limit);
Expand Down Expand Up @@ -954,8 +951,7 @@ TEST_F(AutoReorderingTest, ReorderPassWithDefaultEstimator) {
EXPECT_TRUE(st.ok());
int pointer_size_ = 4;
Backend& test_backend = backend();
const se::DeviceDescription& gpu_device_info =
test_backend.default_stream_executor()->GetDeviceDescription();
auto gpu_device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo();
const int64_t scheduler_mem_limit = xla::gpu::GetSchedulerMemoryLimit(
hlo_module.get(), gpu_device_info, pointer_size_);
SchedulerConfig config = GetSchedulerConfig(scheduler_mem_limit);
Expand All @@ -972,8 +968,8 @@ TEST_F(AutoReorderingTest, ReorderPassWithDefaultEstimator) {
EXPECT_TRUE(status.ok());
}
TEST_F(AutoReorderingTest, ReorderPassWithRandom) {
// GTEST_SKIP() << "Skipping single test";
std::srand(kRandomSeed);
// communication rate from 0.05 to 0.95,step is 0.05
auto hlo_module = CreateNewUnverifiedModule();
auto gpu_latency_estimator = std::make_unique<SavedInstLatencyEstimator>();
SchedulerConfig sched_config = GetDefaultSchedConfig();
Expand Down Expand Up @@ -1027,12 +1023,20 @@ TEST_F(AutoReorderingTest, ReorderPassWithRandom) {
}
// skip this test
TEST_F(AutoReorderingTest, ReorderPassDataAnalyse) {
// GTEST_SKIP() << "Skipping single test";
GTEST_SKIP() << "Skipping single test";
std::srand(kRandomSeed);
auto gen = std::mt19937{kRandomSeed};
int repeat_time = 3;
uint32_t nnodes = 50;
std::vector<float> communication_rates = {0.1,0.15,0.2,0.25,0.3,0.65,0.7,0.75,0.8,0.85};
int repeat_time = 1;
uint32_t nnodes = 100;
std::vector<float> communication_rates;
// = {
// 0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9
// };
for (float current=0.1; current < 0.9; current+=0.05)
{
communication_rates.push_back(current);
}

// communication rate from 0.05 to 0.95,step is 0.05
std::ofstream csv_out("/tmp/test_ret.csv");
csv_out<<"exp_id,nnodes,communication_rate,auto_reorder_cost,post_order_cost,xla_hiding_order_cost,xla_hiding_solve_time,auto_reorder_solve_time"<<std::endl;
Expand All @@ -1050,6 +1054,8 @@ TEST_F(AutoReorderingTest, ReorderPassDataAnalyse) {
/*communication rate*/ communication_rate,
/* gen */gen);
EXPECT_TRUE(st.ok());
// auto latency_estimator = create_latency_estimator();

auto gpu_latency_estimator2 = gpu_latency_estimator->clone();
auto gpu_latency_estimator3 = gpu_latency_estimator->clone();
// run AutoReorder for compare
Expand Down
1 change: 1 addition & 0 deletions xla/service/gpu/gpu_hlo_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_schedule.h"
#include "xla/hlo/utils/hlo_query.h"
#include "xla/service/buffer_value.h"
#include "xla/hlo/experimental/auto_reorder/auto_reorder.h"
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/service/gpu/cublas_cudnn.h"
#include "xla/service/gpu/gpu_schedule_postprocessing.h"
Expand Down
Loading

0 comments on commit c0d7559

Please sign in to comment.