bring back parallel_context.destroy()
NouamaneTazi committed Feb 13, 2024
1 parent 6dcb73d commit b64f04f
Showing 2 changed files with 4 additions and 5 deletions.
.github/workflows/small.yaml (2 changes: 1 addition & 1 deletion)
@@ -38,7 +38,7 @@ jobs:
          python -c "import torch; print('torch:', torch.__version__, torch)"
          python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
-      - name: Instal nanotron
+      - name: Install nanotron
        run: |
          python -m pip install --upgrade pip
          pip install packaging
tests/test_clip_grads.py (7 changes: 3 additions & 4 deletions)
@@ -192,7 +192,6 @@ def _test_clip_grads_with_pp(parallel_context: ParallelContext, norm_type: float
    print(parallel_context.__dir__())

    parallel_context.destroy()
-    parallel_context.destroyaa()


@pytest.mark.skipif(available_gpus() < 2, reason="test_clip_grads_with_tp requires at least 2 gpus")
@@ -343,7 +342,7 @@ def _test_clip_grads_with_tp(
    )
    torch.testing.assert_close(total_norm, ref_total_norm)

-    # parallel_context.destroy()
+    parallel_context.destroy()


@pytest.mark.skipif(available_gpus() < 2, reason="test_clip_grads_tied_weights requires at least 2 gpus")
@@ -438,7 +437,7 @@ def _test_clip_grads_tied_weights(parallel_context: ParallelContext, norm_type:
    assert torch.allclose(bias.grad, ref_bias.grad, rtol=1e-7, atol=1e-6)
    assert torch.allclose(total_norm, ref_total_norm, rtol=0, atol=0), f"Got {total_norm} and {ref_total_norm}"

-    # parallel_context.destroy()
+    parallel_context.destroy()


@pytest.mark.parametrize("half_precision", [torch.float16, torch.bfloat16])
@@ -628,4 +627,4 @@ def _test_clip_grads_fp32_accumulator(
        to_rank=reference_rank,
    )

-    # parallel_context.destroy()
+    parallel_context.destroy()
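
Each of these tests builds a distributed process group and must release it before the next test initializes its own; the restored parallel_context.destroy() calls are that teardown. Below is a minimal sketch of the same setup/teardown pattern using plain torch.distributed on a single rank with the gloo backend; nanotron's actual ParallelContext and the multi-GPU launch these tests use are not reproduced here, and the _run_single_rank_test helper is purely illustrative.

import os

import torch
import torch.distributed as dist


def _run_single_rank_test(norm_type: float = 2.0) -> None:
    """Illustrative single-rank stand-in for the clip-grads test pattern."""
    # Single-process group for illustration only; the real tests launch one
    # process per GPU and build tensor-/pipeline-/data-parallel groups.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)

    try:
        # ... the real tests clip gradients and compare against a reference ...
        grad = torch.randn(4)
        total_norm = torch.linalg.vector_norm(grad, ord=norm_type)
        assert torch.isfinite(total_norm)
    finally:
        # Rough equivalent of parallel_context.destroy(): release the process
        # group so the next test can call init_process_group again.
        dist.destroy_process_group()


if __name__ == "__main__":
    _run_single_rank_test()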
