From a132ac499d3b388b6fe658dfda03829a06edc3c3 Mon Sep 17 00:00:00 2001
From: Xiaowei Ren <103958965+xrennvidia@users.noreply.github.com>
Date: Wed, 27 Nov 2024 09:35:33 -0800
Subject: [PATCH] Fix cuda graph capture for grouped gemm (#1345)

* retain_graph=True for grouped gemm

Signed-off-by: Xiaowei Ren

* remove an unnecessary retain_graph=True

Signed-off-by: Xiaowei Ren

* make retain_graph in graph capture configurable

Signed-off-by: Xiaowei Ren

* typo fix

Signed-off-by: Xiaowei Ren

---------

Signed-off-by: Xiaowei Ren
---
 transformer_engine/pytorch/graph.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py
index 6c33cc72b9..f44500f7f2 100644
--- a/transformer_engine/pytorch/graph.py
+++ b/transformer_engine/pytorch/graph.py
@@ -64,6 +64,7 @@ def _make_graphed_callables(
     sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None,
     _order: Optional[List[int]] = None,
     pool: Optional[Tuple[int, ...]] = None,
+    retain_graph_in_backward: bool = False,
 ) -> SingleOrTuple[Callable]:
     """
     Helper method for `make_graphed_callables`
@@ -320,6 +321,7 @@ def hook_fn(module, inputs, outputs):  # pylint: disable=unused-argument
                         grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
                         only_inputs=True,
                         allow_unused=allow_unused_input,
+                        retain_graph=retain_graph_in_backward,
                     )
                 # Constructs a tuple suitable for returning from Graphed.backward:
                 # Pads out the actually-needed grads with Nones in gradient slots for inputs
@@ -371,6 +373,7 @@ def hook_fn(module, inputs, outputs):  # pylint: disable=unused-argument
                     grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
                     only_inputs=True,
                     allow_unused=allow_unused_input,
+                    retain_graph=retain_graph_in_backward,
                 )
             # Constructs a tuple suitable for returning from Graphed.backward:
             # Pads out the actually-needed grads with Nones in gradient slots for inputs that
@@ -606,6 +609,7 @@ def make_graphed_callables(
     fp8_weight_caching: bool = False,
     _order: Optional[List[int]] = None,
     pool: Optional[Tuple[int, ...]] = None,
+    retain_graph_in_backward: bool = False,
 ) -> Union[Callable, Tuple[Callable, ...]]:
     """
     Make CUDA graph version of Transformer Engine modules
@@ -632,6 +636,8 @@ def make_graphed_callables(
     pool: (tuple of) int, default = `None`, optional
          An instance returned from function `torch.cuda.graph_pool_handle` that hints
          this graph may share memory with the indicated pool.
+    retain_graph_in_backward: bool, default = `False`
+                              Whether to set retain_graph=True in backward graph capture.

     FP8-related parameters
     ----------------------
@@ -716,6 +722,7 @@ def forward_func(*args, **kwargs):
         sample_kwargs=sample_kwargs,
         _order=_order,
         pool=pool,
+        retain_graph_in_backward=retain_graph_in_backward,
     )

     # Ensures warmup does not affect numerics for ops such as dropout.
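
Not part of the patch: a minimal usage sketch of the `retain_graph_in_backward` flag introduced above. The `te.Linear` module, tensor shapes, and variable names are illustrative assumptions, not taken from the PR; only the keyword argument itself comes from this change.

# Sketch only: exercising retain_graph_in_backward added by this patch.
# Assumes a CUDA device and an installed Transformer Engine; te.Linear and
# the shapes below are arbitrary illustrative choices.
import torch
import transformer_engine.pytorch as te

module = te.Linear(1024, 1024).cuda()
sample_input = torch.randn(32, 1024, device="cuda", requires_grad=True)

# retain_graph_in_backward=True passes retain_graph=True to torch.autograd.grad
# during backward graph capture, so the autograd graph is not freed mid-capture
# (the failure mode this PR fixes for grouped GEMM).
graphed_module = te.make_graphed_callables(
    module,
    (sample_input,),
    retain_graph_in_backward=True,
)

out = graphed_module(sample_input)
out.sum().backward()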