Ifu release v1.11 nov #95
Conversation
Signed-off-by: Przemek Tredak <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
…iable (#1007) * Add enabled() to BasePrimitive * Add layernorm/rmsnorm fallback * Add cast_fp8 fallback * Add transpose/cast_transpose XLA fall back * Act_lu fallback * Add transpose fallback * Add softmax fallback * Unify the use of _cast_fp8 * Add tests for NVTE_JAX_CUSTOM_CALLS_RE --------- Signed-off-by: Reese Wang <[email protected]> Co-authored-by: Phuong Nguyen <[email protected]>
* Add option to pass kwargs to CUDA graph module Signed-off-by: Tim Moon <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Debug unit tests Signed-off-by: Tim Moon <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Tweak comments Signed-off-by: Tim Moon <[email protected]> --------- Signed-off-by: Tim Moon <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
* DGRAD_RS UB overlap Bug fixes Signed-off-by: Vasudevan Rengasamy <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Vasudevan Rengasamy <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
fix 261 compile Signed-off-by: Frank Lin (Engrg-Hardware 1) <[email protected]> Co-authored-by: Frank Lin (Engrg-Hardware 1) <[email protected]> Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
…duce-scatter overlap (#1023) * FP8 type switch macro now wraps only the FP8 kernel to avoid invalid type errors Signed-off-by: Alp Dener <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Alp Dener <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Update nvtx header Signed-off-by: Reese Wang <[email protected]>
* initialize output tensors to 0 for THD while waiting for cuDNN bug fix Signed-off-by: Charlene Yang <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * move fill_() to F16 loop Signed-off-by: Charlene Yang <[email protected]> * fix fused_attn_bwd() Signed-off-by: Charlene Yang <[email protected]> * correct typo in check_set_window_size Signed-off-by: Charlene Yang <[email protected]> * use nvtx3 instead Signed-off-by: Charlene Yang <[email protected]> --------- Signed-off-by: Charlene Yang <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
…e class (#1028) * Update sequential container constructor to handle modules in plain dicts Signed-off-by: Tim Moon <[email protected]> * Avoid initializing Sequential with dicts Signed-off-by: Tim Moon <[email protected]> --------- Signed-off-by: Tim Moon <[email protected]>
* Fixed convergence issues Signed-off-by: Selvaraj Anandaraj <[email protected]> * Update transformer_engine/pytorch/module/layernorm_linear.py Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Update transformer_engine/pytorch/module/layernorm_mlp.py Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Update transformer_engine/pytorch/module/linear.py Signed-off-by: Kirthi Shankar Sivamani <[email protected]> --------- Signed-off-by: Selvaraj Anandaraj <[email protected]> Signed-off-by: Kirthi Shankar Sivamani <[email protected]> Co-authored-by: Selvaraj Anandaraj <[email protected]> Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
* removed unwanted memcpyDtoD/fixed weight parametrisation Signed-off-by: Selvaraj Anandaraj <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Selvaraj Anandaraj <[email protected]> Co-authored-by: Selvaraj Anandaraj <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
…PR901 removal of MPI-dependence (#986) * Re-implementing PR901 (removing MPI-dependence in Userbuffers) with multi-node fixes * passing data-parallel rank/size info from torch.distributed to userbuffers Signed-off-by: Alp Dener <[email protected]> * multi-node example working with UB_SKIPMC=1 but not with multicast Signed-off-by: Alp Dener <[email protected]> * fixed multi-node hang in initialize_ub(), updated comm+GEMM overlap example to support multi-node mixed tensor/data parallelism, added README Signed-off-by: Alp Dener <[email protected]> * fixed use case when Userbuffers is asked to allocate the TP overlap buffer with UB_SKIPMC=1 Signed-off-by: Alp Dener <[email protected]> * corrected example problem to set device by local ordinal instead of global process rank Signed-off-by: Alp Dener <[email protected]> * double-free fix in userbuffers destructor Signed-off-by: Alp Dener <[email protected]> * removed unnecessary and incorrect torch.cuda.set_device(...) Signed-off-by: Alp Dener <[email protected]> * corrected inter-node ranks logic Signed-off-by: Alp Dener <[email protected]> * generalized node ID logic in initialize_ub to handle arbitrary world rank layouts within node Signed-off-by: Alp Dener <[email protected]> * added single-node comm+GEMM overlap unit tests Signed-off-by: Alp Dener <[email protected]> * LayerNormMLP example confirmed working with 2 nodes on Eos Signed-off-by: Alp Dener <[email protected]> * unit test cleanup Signed-off-by: Alp Dener <[email protected]> * corrected DP group ranks logic in LNMLP comm+GEMM overlap example Signed-off-by: Alp Dener <[email protected]> * corrected enums in unit test Signed-off-by: Alp Dener <[email protected]> * fixed incorrect Ubuf object init signature Signed-off-by: Alp Dener <[email protected]> * switched default backend for Userbuffer bootstrapping to Gloo with MPI and NCCL fallbacks, and initialize_ub option to manually select backend Signed-off-by: Alp Dener <[email protected]> * fixed all comm+GEMM overlap unit tests Signed-off-by: Alp Dener <[email protected]> * corrected all_gather use for Gloo backend Signed-off-by: Alp Dener <[email protected]> * changed userbuffers allgather callback to always use all_gather() instead of all_gather_into_tensor() Signed-off-by: Alp Dener <[email protected]> * restored and verified old MPI-based bootstrapping via NVTE_UB_WITH_MPI=1 option at compile time Signed-off-by: Alp Dener <[email protected]> * disabled scoped GIL release for comm+GEMM overlap algorithms Signed-off-by: Alp Dener <[email protected]> * avoid dist.init_device_mesh in comm+GEMM overlap example to support older PyTorch versions Signed-off-by: Alp Dener <[email protected]> * applied RS overlap FP8 fix from PR1004 Signed-off-by: Alp Dener <[email protected]> * fixed segfault in Userbuffers destructor Signed-off-by: Alp Dener <[email protected]> * corrected comm+GEMM overlap unit test arguments Signed-off-by: Alp Dener <[email protected]> * fixed unit test run command for when Userbuffers is compiled with MPI Signed-off-by: Alp Dener <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Refactored torch.distributed collectives into pure C++ callbacks Signed-off-by: Alp Dener <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Alp Dener <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* Set minimum CMake version to 3.21 Stop linking to nvtx. Signed-off-by: Tim Moon <[email protected]> * Update .github/workflows/build.yml Co-authored-by: Kirthi Shankar Sivamani <[email protected]> Signed-off-by: Tim Moon <[email protected]> * Revert Python version to 3.9 Signed-off-by: Tim Moon <[email protected]> --------- Signed-off-by: Tim Moon <[email protected]> Signed-off-by: Tim Moon <[email protected]> Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
* i Signed-off-by: Tian Zheng (Engrg-Hardware 1) <[email protected]> * . Signed-off-by: Tian Zheng (Engrg-Hardware 1) <[email protected]> --------- Signed-off-by: Tian Zheng (Engrg-Hardware 1) <[email protected]>
* Remove extra args to fused attention func Signed-off-by: Tim Moon <[email protected]> * Add missing arg to fused attention func Signed-off-by: Tim Moon <[email protected]> --------- Signed-off-by: Tim Moon <[email protected]>
* Fix build error with Paddle >2.6.1 Signed-off-by: Tim Moon <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Tim Moon <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Fix linter warnings Signed-off-by: Tim Moon <[email protected]>
* Specify python version Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Add classifiers for python Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Add utils to build wheels Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * make wheel scripts Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Add aarch Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Fixes Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Fix paddle wheel Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * PaddlePaddle only builds for x86 Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Add optional fwk deps Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Python3.8; catch install error Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * [wip] cudnn9 compile with paddle support Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * [wip] dont link cudnn Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * dlopen cudnn Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * fixes Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * fix Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Fixes Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * fix Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * dynamically load nvrtc Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Lint Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * remove residual packages; exclude stub from nvrtc .so search Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Exclude builtins from nvrtc .so search Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * properly include files for sdist Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * paddle wheel tie to python version Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Fix paddle build from src [wip] Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Fix workflow paddle build Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix paddle Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Fix paddle Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * fix lint from pr986 Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * fix Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Add sanity wheel test Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Add sanity import to wheel test Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * remove upper limit on paddlepaddle version Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Remove unused imports Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Remove pybind11 dependency Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Fix cpp tests Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Search .sos in cuda home Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixes Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * CLeanup, remove residual code Signed-off-by: Kirthi Shankar Sivamani <[email protected]> --------- Signed-off-by: Kirthi Shankar Sivamani <[email protected]> Co-authored-by: 
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* Fixes for wheels Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Fix paddle wheel test Signed-off-by: Kirthi Shankar Sivamani <[email protected]> --------- Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
…en. (#1031) This bug causes "[ERROR] failed (exitcode: -11) local_rank: 0 (pid: 1761020) of binary: ~/megatron/bin/python". That is because we miss the rng_states that are required in attention recompute (for dropout), but no hint is provided. It is very difficult to trace and cost me two weeks.
```python
before the start of training step] datetime: 2024-07-22 18:26:45
[2024-07-22 18:27:00,941] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: -11) local_rank: 0 (pid: 1761020) of binary: /home//miniconda3/envs/megatron/bin/python
Traceback (most recent call last):
  File "/home//miniconda3/envs/megatron/bin/torchrun", line 33, in <module>
    sys.exit(load_entry_point('torch==2.2.1+cu121', 'console_scripts', 'torchrun')())
  File "/home//miniconda3/envs/megatron/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 347, in wrapper
    return f(*args, **kwargs)
  File "/home//miniconda3/envs/megatron/lib/python3.10/site-packages/torch/distributed/run.py", line 812, in main
    run(args)
  File "/home//miniconda3/envs/megatron/lib/python3.10/site-packages/torch/distributed/run.py", line 803, in run
    elastic_launch(
  File "/home//miniconda3/envs/megatron/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 135, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/home//miniconda3/envs/megatron/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 268, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
```
Signed-off-by: 李金梁 <[email protected]>
fix tp_size for GQA Signed-off-by: Charlene Yang <[email protected]>
Update Paddle image Signed-off-by: Tian Zheng <[email protected]>
add deterministic option Signed-off-by: Shijie Wang <[email protected]> Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
…n (#1058) Rm unused import causing CI failures Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
* Load balanced offloading algorithm Signed-off-by: Selvaraj Anandaraj <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Selvaraj Anandaraj <[email protected]> Co-authored-by: Selvaraj Anandaraj <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
* Ensure that the inputs to custom calls are contiguous Signed-off-by: Przemek Tredak <[email protected]> * Fixes Signed-off-by: Przemek Tredak <[email protected]> * Added test Signed-off-by: Przemek Tredak <[email protected]> * Fixes Signed-off-by: Przemek Tredak <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixes from review Signed-off-by: Przemek Tredak <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix typo Signed-off-by: Kirthi Shankar Sivamani <[email protected]> --------- Signed-off-by: Przemek Tredak <[email protected]> Signed-off-by: Kirthi Shankar Sivamani <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <[email protected]> Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
* Added tests for silu/relu/swiglu/reglu Signed-off-by: Przemek Tredak <[email protected]> * Fixes Signed-off-by: Przemek Tredak <[email protected]> * Added other activations/backwards and fixed dqgelu Signed-off-by: Przemek Tredak <[email protected]> * Fix Signed-off-by: Przemek Tredak <[email protected]> * Fix 2 Signed-off-by: Przemek Tredak <[email protected]> * Actually adding srelu and qgelu tests Signed-off-by: Przemek Tredak <[email protected]> * Fix glu backward test Signed-off-by: Przemek Tredak <[email protected]> * Pruning unnecessary test configurations Signed-off-by: Przemek Tredak <[email protected]> --------- Signed-off-by: Przemek Tredak <[email protected]>
* fix workspaces and unfused bias in multi-stream cuBLAS * Expose num_streams via pybind * Fix C-compatibility * rm importing packaging in test_fused_attn.py --------- Signed-off-by: Xin Yao <[email protected]> Co-authored-by: Phuong Nguyen <[email protected]>
* use 2hd layout Signed-off-by: Xiaowei Ren <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * change qkv_format check Signed-off-by: Xiaowei Ren <[email protected]> * add a code comment Signed-off-by: Xiaowei Ren <[email protected]> * tensor shape bug fix Signed-off-by: Xiaowei Ren <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * tensor shape fix Signed-off-by: Xiaowei Ren <[email protected]> * add function to compute cu_seqlens of a cp rank Signed-off-by: Xiaowei Ren <[email protected]> * add cu_seqlens and cu_seqlens_padded to context parallelism Signed-off-by: Xiaowei Ren <[email protected]> * typo fix Signed-off-by: Xiaowei Ren <[email protected]> * minor change Signed-off-by: Xiaowei Ren <[email protected]> * fix FlashAttention output sequence length Signed-off-by: Xiaowei Ren <[email protected]> * fix cu_seqlens_kv_per_step calculation Signed-off-by: Xiaowei Ren <[email protected]> * zero dQKV for ending padded tokens Signed-off-by: Xiaowei Ren <[email protected]> * zero dQKV tensors of FlashAttention Signed-off-by: Xiaowei Ren <[email protected]> * fix softmax_lse correction Signed-off-by: Xiaowei Ren <[email protected]> * remove padded tokens of KV to save comounication Signed-off-by: Xiaowei Ren <[email protected]> * do not need to zero dkv for FlashAttention any mroe Signed-off-by: Xiaowei Ren <[email protected]> * zero out tensors Signed-off-by: Xiaowei Ren <[email protected]> * remove redundant code Signed-off-by: Xiaowei Ren <[email protected]> * fix CP unit test Signed-off-by: Xiaowei Ren <[email protected]> * fix kv shape of cp test with thd format Signed-off-by: Xiaowei Ren <[email protected]> * update cp unit test Signed-off-by: Xiaowei Ren <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove redundant code Signed-off-by: Xiaowei Ren <[email protected]> --------- Signed-off-by: Xiaowei Ren <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Charlene Yang <[email protected]> Co-authored-by: Xiaowei Ren <[email protected]>
…(#1175) * Check if network interface name is valid and show useful warning message when initializing Userbuffers Signed-off-by: Alp Dener <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix formatting issue in warning message. Co-authored-by: Tim Moon <[email protected]> Signed-off-by: Alp Dener <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Alp Dener <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <[email protected]>
* fix NVTE_UB_WITH_MPI read Signed-off-by: Sangkug Lym <[email protected]> * Add default value Signed-off-by: Sangkug Lym <[email protected]> --------- Signed-off-by: Sangkug Lym <[email protected]> Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
skip FP8 CP tests if hardware does not support FP8 Signed-off-by: Xiaowei Ren <[email protected]>
…with offsets (#1220) * Removing the unused options from GroupedLinear docs and fixing the bug with offsets Signed-off-by: Przemyslaw Tredak <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * offsets -> fp8_meta_offsets Signed-off-by: Przemyslaw Tredak <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Przemyslaw Tredak <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
move block_table arg to varlen_func section Signed-off-by: Charlene Yang <[email protected]>
transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu (resolved)
def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type):
    if (not IS_HIP_EXTENSION) and qkv_format == "thd" and get_device_compute_capability() < (9, 0):
        pytest.skip("THD format is only supported on sm90+!")
    if IS_HIP_EXTENSION and qkv_format == "thd":
Just a note: if you move the HIP THD and HIP FP8 conditions above the corresponding CUDA conditions, the latter no longer need (not IS_HIP_EXTENSION).
I think we still need the (not IS_HIP_EXTENSION) condition, since we also implemented get_device_compute_capability: on MI300 it returns (9, 4) because we are gfx942.
I mean that after checking (IS_HIP_EXTENSION and qkv_format == "thd"), the condition (qkv_format == "thd" and get_device_compute_capability() ...) can only be true in the CUDA case.
Which line are you referring to?
The line 116 condition "(not IS_HIP_EXTENSION) and qkv_format == "thd" and get_device_compute_capability() < (9, 0)" is for CUDA.
The line 118 condition "IS_HIP_EXTENSION and qkv_format == "thd"" is for ROCm.
I don't think either of them is redundant?
If you swap them and check (IS_HIP_EXTENSION and qkv_format == "thd") first, then (qkv_format == "thd" and get_device_compute_capability() < (9, 0)) will only be evaluated in the CUDA case, i.e. (not IS_HIP_EXTENSION) is not necessary. In other words: if qkv_format == "thd", skip when IS_HIP_EXTENSION, then skip on the compute-capability check; the (not IS_HIP_EXTENSION) guard is not needed there. See the sketch below.
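A minimal sketch of that reordering as a hypothetical helper. IS_HIP_EXTENSION, get_device_compute_capability, and the sm90+ message come from the test quoted above; the helper name and the ROCm skip message are illustrative, not from this PR:

```python
import pytest

def skip_unsupported_thd(qkv_format):
    # Hypothetical helper mirroring the reordered checks in
    # test_cp_with_fused_attention; assumes IS_HIP_EXTENSION and
    # get_device_compute_capability are already imported by the test module.
    if qkv_format != "thd":
        return
    if IS_HIP_EXTENSION:
        pytest.skip("THD format is skipped on ROCm here")  # illustrative message
    # Reached only when IS_HIP_EXTENSION is False, so the explicit
    # (not IS_HIP_EXTENSION) guard is no longer needed.
    if get_device_compute_capability() < (9, 0):
        pytest.skip("THD format is only supported on sm90+!")
```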
Addressed. Thanks
transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu (resolved)
…and address more reviewer commits
6ff8ed3 to 18ee57c
3ef33ef to fe9e149
e1f2ecc to aa8b3ff
aa8b3ff to ff495d1
ci/jax.sh (outdated)
install_praxis() {
    git clone https://github.com/google/praxis.git && cd praxis || return $?
    git checkout $_praxis_commit || return $?
    # Remove unnecessary dependencies for testing and make sure JAX is not upgraded
    sed -i -e 's/^flax/#flax/;s/^jax /#jax /;s/^opt/#opt/;s/^tensorflow/#tensorflow/' requirements.in || return $?
    pip list | awk '/jax/ { print $1"=="$2}' >> requirements.in
    pip list | awk '/transformer_engine/ { print $1"=="$2}' >> requirements.in
This should be a single awk call with the pattern /jax|transformer_engine/.
But why do we need it? Does Praxis depend on TE?
Thanks for your comments. I'm not very familiar with awk; can you show me how to combine the two awk calls into one line?
The reason I add "transformer_engine==1.11.0" to requirements.in: if I don't, the pydantic version mismatches and causes some JAX pytest failures.
OK then. Can you add a comment explaining that transformer_engine is added to keep its dependencies, and merge both lines into one using /jax|transformer_engine/ instead of the separate /jax/ and /transformer_engine/ patterns? Something like the sketch below.
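A minimal sketch of the combined call, based on the install_praxis snippet above (the explanatory comment wording is illustrative):

```bash
# Pin the currently installed jax* and transformer_engine packages so that
# installing Praxis' requirements does not upgrade or replace them.
pip list | awk '/jax|transformer_engine/ { print $1"=="$2 }' >> requirements.in
```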
Combined jax and transformer_engine in one awk line. Thanks Ilya for your suggestion!
ci/jax.sh (outdated)
install_flax() {
    pip list | awk '/jax/ { print $1"=="$2}' > reqs
    pip install flax -r reqs
}
Isn't Flax installed as a Praxis prerequisite? Why do we need a separate installation?
Also, please take care of the reqs file: which directory is it created in?
Removed it. Let me see whether it can pass the CI. Thanks
Description
IFU to upstream NVTE release v1.11, based on the current dev (v1.9).
Some JAX pytests are failing:
1) test_layer.py:
1.1) On both the old and the new rocm6.3 dockers, 100 tests fail due to https://ontrack-internal.amd.com/browse/SWDEV-499638; adding XLA_FLAGS="--xla_gpu_enable_dot_strength_reduction=false --xla_gpu_enable_command_buffer=CUSTOM_CALL" does not help.
1.2) On a rocm6.2 docker, the following warning messages are printed:
2024-12-12 22:23:14.379038: E external/xla/xla/service/gpu/gemm_algorithm_picker.cc:346] Results mismatch between different GEMM algorithms. This is likely a bug/unexpected loss of precision.
E1212 22:23:14.497169 1460702 buffer_comparator.cc:157] Difference at 340: -8.04183e+37, expected -inf
E1212 22:23:14.497207 1460702 buffer_comparator.cc:157] Difference at 1000: 1.01021e+38, expected inf
E1212 22:23:14.497216 1460702 buffer_comparator.cc:157] Difference at 3408: 2.41919e+38, expected inf
E1212 22:23:14.497226 1460702 buffer_comparator.cc:157] Difference at 6484: -1.95397e+38, expected -inf
E1212 22:23:14.497230 1460702 buffer_comparator.cc:157] Difference at 7582: -1.48209e+38, expected -inf
E1212 22:23:14.497233 1460702 buffer_comparator.cc:157] Difference at 8020: inf, expected 2.53883e+38
E1212 22:23:14.497236 1460702 buffer_comparator.cc:157] Difference at 8650: -inf, expected -1.32258e+38
E1212 22:23:14.497238 1460702 buffer_comparator.cc:157] Difference at 8680: 2.56541e+38, expected inf
E1212 22:23:14.497239 1460702 buffer_comparator.cc:157] Difference at 8746: inf, expected 1.00357e+38
E1212 22:23:14.497243 1460702 buffer_comparator.cc:157] Difference at 9552: 1.68147e+38, expected inf
2) test_distributed_fused_attn.py:
2.1) Segmentation fault due to https://ontrack-internal.amd.com/browse/SWDEV-482895; now resolved by setting XLA_FLAGS="--xla_gpu_enable_dot_strength_reduction=false --xla_gpu_enable_command_buffer=CUSTOM_CALL" (see the sketch after this list).
2.2) Another segmentation fault under a rocm6.3 docker image (compute-artifactory.amd.com:5000/rocm-plus-docker/framework/compute-rocm-rel-6.3:20_ubuntu24.04_py3.12_jax_rocm-jaxlib-v0.4.31-qa_7d0d2dd), but okay with a rocm6.2 docker image (rocm/jax-community:rocm6.2.3-jax0.4.31-py3.12.6).
2.3) On the new rocm6.3 docker, the tests run through without any issues.
3) test_distributed_softmax.py:
3.1) In the old rocm6.3 docker, fails due to the same segmentation fault as 2.2.
3.2) In the rocm6.2 docker, all tests pass.
3.3) In the new rocm6.3 docker, all tests pass.
4) test_distributed_layernorm_mlp:
4.1) hipblaslt issue under the rocm6.2 docker:
operation would make the legacy stream depend on a capturing blocking stream
4.2) hipblaslt issue under the rocm6.3 docker:
operation would make the legacy stream depend on a capturing blocking stream
4.3) Under the new rocm6.2 docker, after adding XLA_FLAGS="--xla_gpu_enable_dot_strength_reduction=false --xla_gpu_enable_command_buffer=CUSTOM_CALL", all tests pass, but with warnings from hipblaslt:
2024-12-12 23:05:59.040976: E external/xla/xla/service/gpu/gemm_algorithm_picker.cc:346] Results mismatch between different GEMM algorithms. This is likely a bug/unexpected loss of precision.
4.4) Under the new rocm6.2 docker, after adding XLA_FLAGS="--xla_gpu_enable_dot_strength_reduction=false --xla_gpu_enable_command_buffer=CUSTOM_CALL", all tests pass without any warning messages.
5) test_praxis_layer.py: unable to find a compatible Praxis installation.
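For reference, a minimal sketch of applying the XLA_FLAGS workaround mentioned in 2.1 and 4.3 when running the distributed JAX tests. The test path and pytest invocation are illustrative, not part of this PR:

```bash
# Workaround flags noted above: disable dot strength reduction and restrict
# XLA command buffers (graph capture) to custom calls only.
export XLA_FLAGS="--xla_gpu_enable_dot_strength_reduction=false --xla_gpu_enable_command_buffer=CUSTOM_CALL"

# Illustrative invocation of one of the affected test files.
python -m pytest -v tests/jax/test_distributed_fused_attn.py
```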
Type of change
Changes
1) Merge commit 377252
2) Resolving commit 2b6dc5f
Checklist: