diff --git a/third_party/xla/xla/service/gpu/priority_fusion.cc b/third_party/xla/xla/service/gpu/priority_fusion.cc
index 491170ff67df00..06136e64333f9c 100644
--- a/third_party/xla/xla/service/gpu/priority_fusion.cc
+++ b/third_party/xla/xla/service/gpu/priority_fusion.cc
@@ -185,6 +185,9 @@ class GpuPriorityFusionQueue : public FusionQueue {
       fusion_step->set_consumer_name(std::string(original_consumer->name()));
     }
 
+    HloInstructionAdaptor fusion_adaptor(*fusion);
+    can_fuse_cache_.erase(fusion_adaptor);
+
     fusion_analysis_cache_.Invalidate(*fusion);
     fusion_analysis_cache_.Invalidate(*original_producer);
 
@@ -219,6 +222,9 @@ class GpuPriorityFusionQueue : public FusionQueue {
         continue;
       }
       producer_user_count_[operand] = operand->user_count();
+
+      HloInstructionAdaptor operand_adaptor(*operand);
+      can_fuse_cache_[operand_adaptor].erase(fusion_adaptor);
       to_update_priority_.insert(operand);
     }
     to_update_priority_.insert(fusion);
@@ -314,14 +320,44 @@ class GpuPriorityFusionQueue : public FusionQueue {
                                     run_times.time_fused);
   }
 
-  FusionDecision CanFuseWithAllUsers(HloInstruction* producer) const {
+  FusionDecision CanFuseCached(HloInstruction* producer,
+                               HloInstruction* consumer) {
+    HloInstructionAdaptor producer_adaptor(*producer);
+    HloInstructionAdaptor consumer_adaptor(*consumer);
+
+    {
+      absl::MutexLock lock(&can_fuse_cache_mutex_);
+      auto& producer_cache = can_fuse_cache_[producer_adaptor];
+
+      auto it = producer_cache.find(consumer_adaptor);
+      if (it != producer_cache.end()) {
+        return it->second;
+      }
+    }
+
+    auto fusion_decision =
+        can_fuse_(consumer, consumer->operand_index(producer));
+
+    // The lock is required, because writing to a flat_hash_map is not
+    // thread-safe even for different keys. We never call this computation
+    // concurrently for the same producer, so it's guaranteed that we don't
+    // override any value.
+    {
+      absl::MutexLock lock(&can_fuse_cache_mutex_);
+      can_fuse_cache_[producer_adaptor][consumer_adaptor] = fusion_decision;
+    }
+
+    return fusion_decision;
+  }
+
+  FusionDecision CanFuseWithAllUsers(HloInstruction* producer) {
     if (producer->users().size() == 0) {
       return "No users to fuse";
     }
 
     FusionDecision result;
     for (const auto& user : producer->users()) {
-      if (auto fusion_decision = can_fuse_(user, user->operand_index(producer));
+      if (auto fusion_decision = CanFuseCached(producer, user);
           !fusion_decision) {
         VLOG(10) << "Cannot fuse " << producer->name() << " with "
                  << user->name() << ", because: " << fusion_decision.Explain();
@@ -376,6 +412,14 @@ class GpuPriorityFusionQueue : public FusionQueue {
   tsl::thread::ThreadPool* thread_pool_;
 
   HloFusionAnalysisCache& fusion_analysis_cache_;
+
+  // Caches result of can_fuse for a (producer, consumer) pair. A cache entry is
+  // invalidated if producer or consumer is modified.
+  absl::flat_hash_map<
+      HloInstructionAdaptor,
+      absl::flat_hash_map<HloInstructionAdaptor, FusionDecision>>
+      can_fuse_cache_;
+  absl::Mutex can_fuse_cache_mutex_;
 };
 
 }  // namespace