Fix cuda graph capture for grouped gemm (#1345)
* retain_graph=True for grouped gemm

Signed-off-by: Xiaowei Ren <[email protected]>

* remove an unnecessary retain_graph=True

Signed-off-by: Xiaowei Ren <[email protected]>

* make retain_graph in graph capture configurable

Signed-off-by: Xiaowei Ren <[email protected]>

* typo fix

Signed-off-by: Xiaowei Ren <[email protected]>

---------

Signed-off-by: Xiaowei Ren <[email protected]>
xrennvidia authored Nov 27, 2024
1 parent 60ce21f commit a132ac4
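Context for the change: PyTorch frees an autograd graph as soon as `torch.autograd.grad` (or `backward`) runs over it without `retain_graph=True`, so a second backward pass over the same graph raises a RuntimeError. The sketch below shows that stock PyTorch behavior only; it is not taken from the grouped GEMM code path, which presumably needs the graph kept alive across repeated backward passes during capture.

```python
# Illustrative only: the stock PyTorch semantics that motivate retain_graph_in_backward.
import torch

x = torch.randn(4, requires_grad=True)
y = (x * x).sum()

# With retain_graph=True the saved intermediates survive this call.
(g1,) = torch.autograd.grad(y, x, retain_graph=True)

# A second backward over the same graph now works; without retain_graph=True above,
# this call would raise "Trying to backward through the graph a second time".
(g2,) = torch.autograd.grad(y, x)

assert torch.allclose(g1, g2)
```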
7 changes: 7 additions & 0 deletions transformer_engine/pytorch/graph.py
@@ -64,6 +64,7 @@ def _make_graphed_callables(
 sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None,
 _order: Optional[List[int]] = None,
 pool: Optional[Tuple[int, ...]] = None,
+retain_graph_in_backward: bool = False,
 ) -> SingleOrTuple[Callable]:
 """
 Helper method for `make_graphed_callables`

@@ -320,6 +321,7 @@ def hook_fn(module, inputs, outputs):  # pylint: disable=unused-argument
 grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
 only_inputs=True,
 allow_unused=allow_unused_input,
+retain_graph=retain_graph_in_backward,
 )
 # Constructs a tuple suitable for returning from Graphed.backward:
 # Pads out the actually-needed grads with Nones in gradient slots for inputs

@@ -371,6 +373,7 @@ def hook_fn(module, inputs, outputs):  # pylint: disable=unused-argument
 grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
 only_inputs=True,
 allow_unused=allow_unused_input,
+retain_graph=retain_graph_in_backward,
 )
 # Constructs a tuple suitable for returning from Graphed.backward:
 # Pads out the actually-needed grads with Nones in gradient slots for inputs that

@@ -606,6 +609,7 @@ def make_graphed_callables(
 fp8_weight_caching: bool = False,
 _order: Optional[List[int]] = None,
 pool: Optional[Tuple[int, ...]] = None,
+retain_graph_in_backward: bool = False,
 ) -> Union[Callable, Tuple[Callable, ...]]:
 """
 Make CUDA graph version of Transformer Engine modules

@@ -632,6 +636,8 @@
 pool: (tuple of) int, default = `None`, optional
 An instance returned from function `torch.cuda.graph_pool_handle` that hints
 this graph may share memory with the indicated pool.
+retain_graph_in_backward: bool, default = `False`
+Whether to set retain_graph=True in backward graph capture.
 FP8-related parameters
 ----------------------

@@ -716,6 +722,7 @@ def forward_func(*args, **kwargs):
 sample_kwargs=sample_kwargs,
 _order=_order,
 pool=pool,
+retain_graph_in_backward=retain_graph_in_backward,
 )

 # Ensures warmup does not affect numerics for ops such as dropout.
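A minimal usage sketch of the new flag follows. Only the `retain_graph_in_backward` keyword comes from this commit; the module, tensor shapes, and device handling are placeholder assumptions.

```python
# Hedged usage sketch: te.Linear and the shapes below are placeholders; the only
# piece taken from this commit is the retain_graph_in_backward keyword argument.
import torch
import transformer_engine.pytorch as te

module = te.Linear(1024, 1024).cuda()
sample_args = (torch.randn(8, 1024, device="cuda"),)

graphed_module = te.make_graphed_callables(
    module,
    sample_args,
    retain_graph_in_backward=True,  # new flag added by this commit; default is False
)

out = graphed_module(*sample_args)
out.sum().backward()
```

Per the signature above, passing a tuple of modules returns a tuple of graphed callables, and the single boolean applies to every captured backward graph.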
