Skip to content

Commit

Permalink
[XLA:GPU] Add can_fuse cache.
Browse files Browse the repository at this point in the history
Avoid unnecessary calls to can_fuse_ to save a significant amount of compile time.

PiperOrigin-RevId: 585598642
  • Loading branch information
olegshyshkov authored and tensorflower-gardener committed Nov 27, 2023
1 parent 116db98 commit 9542da9
Showing 1 changed file with 46 additions and 2 deletions.
48 changes: 46 additions & 2 deletions third_party/xla/xla/service/gpu/priority_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@ class GpuPriorityFusionQueue : public FusionQueue {
fusion_step->set_consumer_name(std::string(original_consumer->name()));
}

HloInstructionAdaptor fusion_adaptor(*fusion);
can_fuse_cache_.erase(fusion_adaptor);

fusion_analysis_cache_.Invalidate(*fusion);
fusion_analysis_cache_.Invalidate(*original_producer);

Expand Down Expand Up @@ -219,6 +222,9 @@ class GpuPriorityFusionQueue : public FusionQueue {
continue;
}
producer_user_count_[operand] = operand->user_count();

HloInstructionAdaptor operand_adaptor(*operand);
can_fuse_cache_[operand_adaptor].erase(fusion_adaptor);
to_update_priority_.insert(operand);
}
to_update_priority_.insert(fusion);
Expand Down Expand Up @@ -314,14 +320,44 @@ class GpuPriorityFusionQueue : public FusionQueue {
run_times.time_fused);
}

FusionDecision CanFuseWithAllUsers(HloInstruction* producer) const {
// Returns the can_fuse_ decision for fusing `producer` into `consumer`.
// Decisions are memoized in can_fuse_cache_, keyed first by producer and then
// by consumer, so can_fuse_ is invoked only when no cached entry exists.
FusionDecision CanFuseCached(HloInstruction* producer,
                             HloInstruction* consumer) {
  HloInstructionAdaptor producer_key(*producer);
  HloInstructionAdaptor consumer_key(*consumer);

  // Lookup phase: hold the mutex only while probing the cache. The lock is
  // released before can_fuse_ runs, so the computation itself is unlocked.
  {
    absl::MutexLock lock(&can_fuse_cache_mutex_);
    auto& decisions_for_producer = can_fuse_cache_[producer_key];

    if (auto cached = decisions_for_producer.find(consumer_key);
        cached != decisions_for_producer.end()) {
      return cached->second;
    }
  }

  // Cache miss: evaluate the fusion decision without holding the mutex.
  const auto operand_index = consumer->operand_index(producer);
  auto decision = can_fuse_(consumer, operand_index);

  // Store phase. The lock is required because writing to a flat_hash_map is
  // not thread-safe even for different keys. This computation is never called
  // concurrently for the same producer, so it is guaranteed that no existing
  // value is overwritten here.
  {
    absl::MutexLock lock(&can_fuse_cache_mutex_);
    auto& decisions_for_producer = can_fuse_cache_[producer_key];
    decisions_for_producer[consumer_key] = decision;
  }

  return decision;
}

FusionDecision CanFuseWithAllUsers(HloInstruction* producer) {
if (producer->users().size() == 0) {
return "No users to fuse";
}

FusionDecision result;
for (const auto& user : producer->users()) {
if (auto fusion_decision = can_fuse_(user, user->operand_index(producer));
if (auto fusion_decision = CanFuseCached(producer, user);
!fusion_decision) {
VLOG(10) << "Cannot fuse " << producer->name() << " with "
<< user->name() << ", because: " << fusion_decision.Explain();
Expand Down Expand Up @@ -376,6 +412,14 @@ class GpuPriorityFusionQueue : public FusionQueue {
tsl::thread::ThreadPool* thread_pool_;

HloFusionAnalysisCache& fusion_analysis_cache_;

// Caches the result of can_fuse_ for each (producer, consumer) pair: the outer
// map is keyed by producer, the inner map by consumer. A cache entry is
// invalidated if the producer or the consumer is modified.
absl::flat_hash_map<
HloInstructionAdaptor,
absl::flat_hash_map<HloInstructionAdaptor, FusionDecision>>
can_fuse_cache_;
absl::Mutex can_fuse_cache_mutex_;
};

} // namespace
Expand Down

0 comments on commit 9542da9

Please sign in to comment.