diff --git a/msccl/language/mscclpp/instruction_dag.py b/msccl/language/mscclpp/instruction_dag.py index b7499d0..ce92b87 100644 --- a/msccl/language/mscclpp/instruction_dag.py +++ b/msccl/language/mscclpp/instruction_dag.py @@ -294,7 +294,7 @@ def _optimize_rrcs_rs(self): fused = False if op.inst == Instruction.read_reduce_copy or op.inst == Instruction.read_reduce_copy_send: for next_op in op.next: - fused = optimizer.try_merge_with_put(op, next_op, tb, queue, Instruction.read_reduce_copy_send) + fused = optimizer.try_merge_with_put(op, next_op, tb, queue, op.inst) if fused: break if fused: