
Ifu release v1.11 nov #95

Merged: 122 commits merged into dev from ifu_release_v1.11_Nov on Dec 25, 2024

Conversation

@wangye805 (Contributor) commented Nov 26, 2024

Description

IFU of upstream NVTE release v1.11 into the current dev branch (v1.9).

Some JAX pytests are failing:
1) test_layer.py:
1.1) On both the old and new rocm6.3 dockers, 100 tests failed due to https://ontrack-internal.amd.com/browse/SWDEV-499638; adding XLA_FLAGS="--xla_gpu_enable_dot_strength_reduction=false --xla_gpu_enable_command_buffer=CUSTOM_CALL" does not help.
1.2) On a rocm6.2 docker, the following warning messages are printed:
2024-12-12 22:23:14.379038: E external/xla/xla/service/gpu/gemm_algorithm_picker.cc:346] Results mismatch between different GEMM algorithms. This is likely a bug/unexpected loss of precision.
E1212 22:23:14.497169 1460702 buffer_comparator.cc:157] Difference at 340: -8.04183e+37, expected -inf
E1212 22:23:14.497207 1460702 buffer_comparator.cc:157] Difference at 1000: 1.01021e+38, expected inf
E1212 22:23:14.497216 1460702 buffer_comparator.cc:157] Difference at 3408: 2.41919e+38, expected inf
E1212 22:23:14.497226 1460702 buffer_comparator.cc:157] Difference at 6484: -1.95397e+38, expected -inf
E1212 22:23:14.497230 1460702 buffer_comparator.cc:157] Difference at 7582: -1.48209e+38, expected -inf
E1212 22:23:14.497233 1460702 buffer_comparator.cc:157] Difference at 8020: inf, expected 2.53883e+38
E1212 22:23:14.497236 1460702 buffer_comparator.cc:157] Difference at 8650: -inf, expected -1.32258e+38
E1212 22:23:14.497238 1460702 buffer_comparator.cc:157] Difference at 8680: 2.56541e+38, expected inf
E1212 22:23:14.497239 1460702 buffer_comparator.cc:157] Difference at 8746: inf, expected 1.00357e+38
E1212 22:23:14.497243 1460702 buffer_comparator.cc:157] Difference at 9552: 1.68147e+38, expected inf

2) test_distributed_fused_attn.py:
2.1) Segmentation fault due to https://ontrack-internal.amd.com/browse/SWDEV-482895, now resolved by setting XLA_FLAGS="--xla_gpu_enable_dot_strength_reduction=false --xla_gpu_enable_command_buffer=CUSTOM_CALL" (see the sketch after this list).
2.2) Another segmentation fault under a rocm6.3 docker image (compute-artifactory.amd.com:5000/rocm-plus-docker/framework/compute-rocm-rel-6.3:20_ubuntu24.04_py3.12_jax_rocm-jaxlib-v0.4.31-qa_7d0d2dd), but okay with a rocm6.2 docker image (rocm/jax-community:rocm6.2.3-jax0.4.31-py3.12.6).
2.3) On the new rocm6.3 docker, the tests run through without any issues.

3) test_distributed_softmax.py:
3.1) In the old rocm6.3 docker, failed due to the same segmentation fault as 2.2.
3.2) In the rocm6.2 docker, all tests pass.
3.3) In the new rocm6.3 docker, all tests pass.

4) test_distributed_layernorm_mlp:
4.1) hipblaslt issue under the rocm6.2 docker: "operation would make the legacy stream depend on a capturing blocking stream".
4.2) hipblaslt issue under the rocm6.3 docker: "operation would make the legacy stream depend on a capturing blocking stream".
4.3) Under the new rocm6.2 docker, after adding XLA_FLAGS="--xla_gpu_enable_dot_strength_reduction=false --xla_gpu_enable_command_buffer=CUSTOM_CALL", all tests pass but with hipblaslt warnings:
2024-12-12 23:05:59.040976: E external/xla/xla/service/gpu/gemm_algorithm_picker.cc:346] Results mismatch between different GEMM algorithms. This is likely a bug/unexpected loss of precision.
4.4) Under the new rocm6.3 docker, after adding the same XLA_FLAGS, all tests passed without any warning messages.

5) test_praxis_layer.py: unable to find a compatible praxis installation.
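
For reference, the XLA_FLAGS workaround mentioned above is applied by exporting the flags before launching the affected tests. A minimal sketch (the pytest paths are illustrative and assume the tests/jax layout of this repository):

```bash
# Illustrative sketch: disable XLA dot strength reduction and restrict command
# buffers to custom calls, then run the affected JAX tests. The exact test
# selection may differ from what the CI scripts do.
export XLA_FLAGS="--xla_gpu_enable_dot_strength_reduction=false --xla_gpu_enable_command_buffer=CUSTOM_CALL"
python -m pytest tests/jax/test_distributed_fused_attn.py
python -m pytest tests/jax/test_distributed_layernorm_mlp.py
```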

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactor

Changes

1) Merge commit 377252
2) Resolving commit 2b6dc5f

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

ptrendx and others added 30 commits July 15, 2024 17:54
Signed-off-by: Przemek Tredak <[email protected]>
…iable (#1007)

* Add enabled() to BasePrimitive

* Add layernorm/rmsnorm fallback

* Add cast_fp8 fallback

* Add transpose/cast_transpose XLA fall back

* Act_lu fallback

* Add transpose fallback

* Add softmax fallback

* Unify the use of _cast_fp8

* Add tests for NVTE_JAX_CUSTOM_CALLS_RE

---------

Signed-off-by: Reese Wang <[email protected]>
Co-authored-by: Phuong Nguyen <[email protected]>
* Add option to pass kwargs to CUDA graph module

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Debug unit tests

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Tweak comments

Signed-off-by: Tim Moon <[email protected]>

---------

Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
* DGRAD_RS UB overlap Bug fixes

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Vasudevan Rengasamy <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
fix 261 compile

Signed-off-by: Frank Lin (Engrg-Hardware 1) <[email protected]>
Co-authored-by: Frank Lin (Engrg-Hardware 1) <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
…duce-scatter overlap (#1023)

* FP8 type switch macro now wraps only the FP8 kernel to avoid invalid type errors

Signed-off-by: Alp Dener <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Alp Dener <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Update nvtx header

Signed-off-by: Reese Wang <[email protected]>
* initialize output tensors to 0 for THD while waiting for cuDNN bug fix

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* move fill_() to F16 loop

Signed-off-by: Charlene Yang <[email protected]>

* fix fused_attn_bwd()

Signed-off-by: Charlene Yang <[email protected]>

* correct typo in check_set_window_size

Signed-off-by: Charlene Yang <[email protected]>

* use nvtx3 instead

Signed-off-by: Charlene Yang <[email protected]>

---------

Signed-off-by: Charlene Yang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
…e class (#1028)

* Update sequential container constructor to handle modules in plain dicts

Signed-off-by: Tim Moon <[email protected]>

* Avoid initializing Sequential with dicts

Signed-off-by: Tim Moon <[email protected]>

---------

Signed-off-by: Tim Moon <[email protected]>
* Fixed convergence issues

Signed-off-by: Selvaraj Anandaraj <[email protected]>

* Update transformer_engine/pytorch/module/layernorm_linear.py

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Update transformer_engine/pytorch/module/layernorm_mlp.py

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Update transformer_engine/pytorch/module/linear.py

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

---------

Signed-off-by: Selvaraj Anandaraj <[email protected]>
Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
Co-authored-by: Selvaraj Anandaraj <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
* removed unwanted memcpyDtoD/fixed weight parametrisation

Signed-off-by: Selvaraj Anandaraj <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Selvaraj Anandaraj <[email protected]>
Co-authored-by: Selvaraj Anandaraj <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
…PR901 removal of MPI-dependence (#986)

* Re-implementing PR901 (removing MPI-dependence in Userbuffers) with multi-node fixes

* passing data-parallel rank/size info from torch.distributed to userbuffers

Signed-off-by: Alp Dener <[email protected]>

* multi-node example working with UB_SKIPMC=1 but not with multicast

Signed-off-by: Alp Dener <[email protected]>

* fixed multi-node hang in initialize_ub(), updated comm+GEMM overlap example to support multi-node mixed tensor/data parallelism, added README

Signed-off-by: Alp Dener <[email protected]>

* fixed use case when Userbuffers is asked to allocate the TP overlap buffer with UB_SKIPMC=1

Signed-off-by: Alp Dener <[email protected]>

* corrected example problem to set device by local ordinal instead of global process rank

Signed-off-by: Alp Dener <[email protected]>

* double-free fix in userbuffers destructor

Signed-off-by: Alp Dener <[email protected]>

* removed unnecessary and incorrect torch.cuda.set_device(...)

Signed-off-by: Alp Dener <[email protected]>

* corrected inter-node ranks logic

Signed-off-by: Alp Dener <[email protected]>

* generalized node ID logic in initialize_ub to handle arbitrary world rank layouts within node

Signed-off-by: Alp Dener <[email protected]>

* added single-node comm+GEMM overlap unit tests

Signed-off-by: Alp Dener <[email protected]>

* LayerNormMLP example confirmed working with 2 nodes on Eos

Signed-off-by: Alp Dener <[email protected]>

* unit test cleanup

Signed-off-by: Alp Dener <[email protected]>

* corrected DP group ranks logic in LNMLP comm+GEMM overlap example

Signed-off-by: Alp Dener <[email protected]>

* corrected enums in unit test

Signed-off-by: Alp Dener <[email protected]>

* fixed incorrect Ubuf object init signature

Signed-off-by: Alp Dener <[email protected]>

* switched default backend for Userbuffer bootstrapping to Gloo with MPI and NCCL fallbacks, and initialize_ub option to manually select backend

Signed-off-by: Alp Dener <[email protected]>

* fixed all comm+GEMM overlap unit tests

Signed-off-by: Alp Dener <[email protected]>

* corrected all_gather use for Gloo backend

Signed-off-by: Alp Dener <[email protected]>

* changed userbuffers allgather callback to always use all_gather() instead of all_gather_into_tensor()

Signed-off-by: Alp Dener <[email protected]>

* restored and verified old MPI-based bootstrapping via NVTE_UB_WITH_MPI=1 option at compile time

Signed-off-by: Alp Dener <[email protected]>

* disabled scoped GIL release for comm+GEMM overlap algorithms

Signed-off-by: Alp Dener <[email protected]>

* avoid dist.init_device_mesh in comm+GEMM overlap example to support older PyTorch versions

Signed-off-by: Alp Dener <[email protected]>

* applied RS overlap FP8 fix from PR1004

Signed-off-by: Alp Dener <[email protected]>

* fixed segfault in Userbuffers destructor

Signed-off-by: Alp Dener <[email protected]>

* corrected comm+GEMM overlap unit test arguments

Signed-off-by: Alp Dener <[email protected]>

* fixed unit test run command for when Userbuffers is compiled with MPI

Signed-off-by: Alp Dener <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Refactored torch.distributed collectives into pure C++ callbacks

Signed-off-by: Alp Dener <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Alp Dener <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* Set minimum CMake version to 3.21

Stop linking to nvtx.

Signed-off-by: Tim Moon <[email protected]>

* Update .github/workflows/build.yml

Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
Signed-off-by: Tim Moon <[email protected]>

* Revert Python version to 3.9

Signed-off-by: Tim Moon <[email protected]>

---------

Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
* i

Signed-off-by: Tian Zheng (Engrg-Hardware 1) <[email protected]>

* .

Signed-off-by: Tian Zheng (Engrg-Hardware 1) <[email protected]>

---------

Signed-off-by: Tian Zheng (Engrg-Hardware 1) <[email protected]>
* Remove extra args to fused attention func

Signed-off-by: Tim Moon <[email protected]>

* Add missing arg to fused attention func

Signed-off-by: Tim Moon <[email protected]>

---------

Signed-off-by: Tim Moon <[email protected]>
* Fix build error with Paddle >2.6.1

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Fix linter warnings

Signed-off-by: Tim Moon <[email protected]>
* Specify python version

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Add classifiers for python

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Add utils to build wheels

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* make wheel scripts

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Add aarch

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Fixes

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Fix paddle wheel

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* PaddlePaddle only builds for x86

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Add optional fwk deps

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Python3.8; catch install error

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* [wip] cudnn9 compile with paddle support

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* [wip] dont link cudnn

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* dlopen cudnn

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* fixes

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* fix

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Fixes

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* fix

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* dynamically load nvrtc

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Lint

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* remove residual packages; exclude stub from nvrtc .so search

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Exclude builtins from nvrtc .so search

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* properly include files for sdist

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* paddle wheel tie to python version

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Fix paddle build from src [wip]

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Fix workflow paddle build

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix paddle

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Fix paddle

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* fix lint from pr986

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* fix

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Add sanity wheel test

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Add sanity import to wheel test

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* remove upper limit on paddlepaddle version

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Remove unused imports

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Remove pybind11 dependency

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Fix cpp tests

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Search .sos in cuda home

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Cleanup, remove residual code

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

---------

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* Fixes for wheels

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Fix paddle wheel test

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

---------

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
…en. (#1031)

This bug causes the error [ERROR] failed (exitcode: -11) local_rank: 0 (pid: 1761020) of binary: ~/megatron/bin/python.

That is because the rng_states required in attention recompute (for dropout) are missing, but no hint is provided.

It was very difficult to trace and cost me two weeks.

```python
[before the start of training step] datetime: 2024-07-22 18:26:45
[2024-07-22 18:27:00,941] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: -11) local_rank: 0 (pid: 1761020) of binary: /home//miniconda3/envs/megatron/bin/python
Traceback (most recent call last):
  File "/home//miniconda3/envs/megatron/bin/torchrun", line 33, in <module>
    sys.exit(load_entry_point('torch==2.2.1+cu121', 'console_scripts', 'torchrun')())
  File "/home//miniconda3/envs/megatron/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 347, in wrapper
    return f(*args, **kwargs)
  File "/home//miniconda3/envs/megatron/lib/python3.10/site-packages/torch/distributed/run.py", line 812, in main
    run(args)
  File "/home//miniconda3/envs/megatron/lib/python3.10/site-packages/torch/distributed/run.py", line 803, in run
    elastic_launch(
  File "/home//miniconda3/envs/megatron/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 135, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/home//miniconda3/envs/megatron/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 268, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
```

Signed-off-by: 李金梁 <[email protected]>
fix tp_size for GQA

Signed-off-by: Charlene Yang <[email protected]>
Update Paddle image

Signed-off-by: Tian Zheng <[email protected]>
add deterministic option

Signed-off-by: Shijie Wang <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
…n (#1058)

Rm unused import causing CI failures

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
* Load balanced offloading algorithm

Signed-off-by: Selvaraj Anandaraj <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Selvaraj Anandaraj <[email protected]>
Co-authored-by: Selvaraj Anandaraj <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
* Ensure that the inputs to custom calls are contiguous

Signed-off-by: Przemek Tredak <[email protected]>

* Fixes

Signed-off-by: Przemek Tredak <[email protected]>

* Added test

Signed-off-by: Przemek Tredak <[email protected]>

* Fixes

Signed-off-by: Przemek Tredak <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixes from review

Signed-off-by: Przemek Tredak <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix typo

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

---------

Signed-off-by: Przemek Tredak <[email protected]>
Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
* Added tests for silu/relu/swiglu/reglu

Signed-off-by: Przemek Tredak <[email protected]>

* Fixes

Signed-off-by: Przemek Tredak <[email protected]>

* Added other activations/backwards and fixed dqgelu

Signed-off-by: Przemek Tredak <[email protected]>

* Fix

Signed-off-by: Przemek Tredak <[email protected]>

* Fix 2

Signed-off-by: Przemek Tredak <[email protected]>

* Actually adding srelu and qgelu tests

Signed-off-by: Przemek Tredak <[email protected]>

* Fix glu backward test

Signed-off-by: Przemek Tredak <[email protected]>

* Pruning unnecessary test configurations

Signed-off-by: Przemek Tredak <[email protected]>

---------

Signed-off-by: Przemek Tredak <[email protected]>
* fix workspaces and unfused bias in multi-stream cuBLAS

* Expose num_streams via pybind

* Fix C-compatibility

* rm importing packaging in test_fused_attn.py

---------

Signed-off-by: Xin Yao <[email protected]>
Co-authored-by: Phuong Nguyen <[email protected]>
* use 2hd layout

Signed-off-by: Xiaowei Ren <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* change qkv_format check

Signed-off-by: Xiaowei Ren <[email protected]>

* add a code comment

Signed-off-by: Xiaowei Ren <[email protected]>

* tensor shape bug fix

Signed-off-by: Xiaowei Ren <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* tensor shape fix

Signed-off-by: Xiaowei Ren <[email protected]>

* add function to compute cu_seqlens of a cp rank

Signed-off-by: Xiaowei Ren <[email protected]>

* add cu_seqlens and cu_seqlens_padded to context parallelism

Signed-off-by: Xiaowei Ren <[email protected]>

* typo fix

Signed-off-by: Xiaowei Ren <[email protected]>

* minor change

Signed-off-by: Xiaowei Ren <[email protected]>

* fix FlashAttention output sequence length

Signed-off-by: Xiaowei Ren <[email protected]>

* fix cu_seqlens_kv_per_step calculation

Signed-off-by: Xiaowei Ren <[email protected]>

* zero dQKV for ending padded tokens

Signed-off-by: Xiaowei Ren <[email protected]>

* zero dQKV tensors of FlashAttention

Signed-off-by: Xiaowei Ren <[email protected]>

* fix softmax_lse correction

Signed-off-by: Xiaowei Ren <[email protected]>

* remove padded tokens of KV to save communication

Signed-off-by: Xiaowei Ren <[email protected]>

* do not need to zero dkv for FlashAttention any more

Signed-off-by: Xiaowei Ren <[email protected]>

* zero out tensors

Signed-off-by: Xiaowei Ren <[email protected]>

* remove redundant code

Signed-off-by: Xiaowei Ren <[email protected]>

* fix CP unit test

Signed-off-by: Xiaowei Ren <[email protected]>

* fix kv shape of cp test with thd format

Signed-off-by: Xiaowei Ren <[email protected]>

* update cp unit test

Signed-off-by: Xiaowei Ren <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove redundant code

Signed-off-by: Xiaowei Ren <[email protected]>

---------

Signed-off-by: Xiaowei Ren <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Charlene Yang <[email protected]>
Co-authored-by: Xiaowei Ren <[email protected]>
denera and others added 7 commits September 27, 2024 11:40
…(#1175)

* Check if network interface name is valid and show useful warning message when initializing Userbuffers

Signed-off-by: Alp Dener <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix formatting issue in warning message.

Co-authored-by: Tim Moon <[email protected]>
Signed-off-by: Alp Dener <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Alp Dener <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <[email protected]>
* fix NVTE_UB_WITH_MPI read

Signed-off-by: Sangkug Lym <[email protected]>

* Add default value

Signed-off-by: Sangkug Lym <[email protected]>

---------

Signed-off-by: Sangkug Lym <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
skip FP8 CP tests if hardware does not support FP8

Signed-off-by: Xiaowei Ren <[email protected]>
…with offsets (#1220)

* Removing the unused options from GroupedLinear docs and fixing the bug
with offsets

Signed-off-by: Przemyslaw Tredak <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* offsets -> fp8_meta_offsets

Signed-off-by: Przemyslaw Tredak <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Przemyslaw Tredak <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
move block_table arg to varlen_func section

Signed-off-by: Charlene Yang <[email protected]>
build_tools/utils.py (resolved)
build_tools/utils.py (outdated, resolved)
transformer_engine/common/CMakeLists.txt (resolved)
transformer_engine/pytorch/csrc/type_shim.h (resolved)
def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type):
    if (not IS_HIP_EXTENSION) and qkv_format == "thd" and get_device_compute_capability() < (9, 0):
        pytest.skip("THD format is only supported on sm90+!")
    if IS_HIP_EXTENSION and qkv_format == "thd":
Contributor:

Just a note: if you move the HIP THD and HIP FP8 conditions above the corresponding CUDA conditions, the latter do not need (not IS_HIP_EXTENSION).

Contributor Author:

I think we still need this (not IS_HIP_EXTENSION) condition since we also implemented get_device_compute_capability. On MI300, get_device_compute_capability will return (9, 4) since we are gfx942.

Contributor:

I mean that after (IS_HIP_EXTENSION and qkv_format == "thd"), the condition (qkv_format == "thd" and get_device_compute_capability() ...) can only be true in the CUDA case.

Contributor Author:

Which line are you referring to?
The line 116 condition (not IS_HIP_EXTENSION) and qkv_format == "thd" and get_device_compute_capability() < (9, 0) is for CUDA.
The line 118 condition IS_HIP_EXTENSION and qkv_format == "thd" is for ROCm.
I don't think either of them is redundant?

Contributor:

If you swap them and make (IS_HIP_EXTENSION and qkv_format == "thd") first, then (qkv_format == "thd" and get_device_compute_capability() < (9, 0)) will only be evaluated in the CUDA case, i.e. (not IS_HIP_EXTENSION) is not necessary. You can represent it as: if (qkv_format == "thd") { if (IS_HIP_EXTENSION) { skip }; if (/* not IS_HIP_EXTENSION is not needed here */ get_device_compute_capability() < (9, 0)) { skip } }.

Contributor Author:

Addressed. Thanks.

transformer_engine/pytorch/csrc/type_shim.h (resolved)
qa/L0_pytorch_unittest/test.sh (resolved)
@wangye805 force-pushed the ifu_release_v1.11_Nov branch from 6ff8ed3 to 18ee57c on December 12, 2024 21:07
build_tools/utils.py (resolved)
ci/pytorch.sh (outdated, resolved)
@wangye805 force-pushed the ifu_release_v1.11_Nov branch from 3ef33ef to fe9e149 on December 18, 2024 22:27
@wangye805 requested a review from ipanfilo on December 18, 2024 22:27
@wangye805 force-pushed the ifu_release_v1.11_Nov branch 2 times, most recently from e1f2ecc to aa8b3ff on December 20, 2024 20:57
@wangye805 force-pushed the ifu_release_v1.11_Nov branch from aa8b3ff to ff495d1 on December 22, 2024 22:07
ci/jax.sh (resolved)
ci/jax.sh (outdated)
install_praxis() {
    git clone https://github.com/google/praxis.git && cd praxis || return $?
    git checkout $_praxis_commit || return $?
    # Remove unnecessary dependencies for testing and make sure JAX is not upgraded
    sed -i -e 's/^flax/#flax/;s/^jax /#jax /;s/^opt/#opt/;s/^tensorflow/#tensorflow/' requirements.in || return $?
    pip list | awk '/jax/ { print $1"=="$2}' >> requirements.in
    pip list | awk '/transformer_engine/ { print $1"=="$2}' >> requirements.in
Contributor:

This should be one awk call with the pattern /jax|transformer_engine/.
But why do we need it? Does Praxis depend on TE?

Contributor Author:

Thanks for your comments. I'm not quite familiar with awk usage. Can you show me how to write the two awk calls as one line?

The reason I add "transformer_engine==1.11.0" to requirements.in: if I don't add it, the pydantic version will mismatch and cause some jax pytest failures.

Contributor:

Ok then. Can you add a comment that transformer_engine is added to keep its dependencies, and merge both lines into one with /jax|transformer_engine/ instead of separate /jax/ and /transformer_engine/?

Contributor Author:

Combined jax and transformer_engine in one awk line. Thanks Ilya for your suggestion!
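
For reference, the single awk call discussed above could look roughly like the following sketch (same pip list parsing as the two original lines, just with a combined pattern; the exact change in the PR may differ):

```bash
# Illustrative sketch: pin the installed jax* packages and transformer_engine
# at their current versions in requirements.in, so installing praxis does not
# upgrade or replace them.
pip list | awk '/jax|transformer_engine/ { print $1"=="$2 }' >> requirements.in
```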

ci/jax.sh (outdated)
install_flax() {
    pip list | awk '/jax/ { print $1"=="$2}' > reqs
    pip install flax -r reqs
}
Contributor:

Isn't Flax installed as a Praxis prerequisite? Why do we need a separate installation?
Also, please take care of the reqs file. Which directory is it created in?

Contributor Author:

Removed it. Let me see whether it can pass the CI. Thanks.

@wangye805 requested a review from ipanfilo on December 23, 2024 17:01
@wangye805 merged commit b279c5a into dev on Dec 25, 2024
7 checks passed