diff --git a/tests/tp/test_tp_strategies.py b/tests/tp/test_tp_strategies.py index 5ce63e78e4..77ec61de26 100644 --- a/tests/tp/test_tp_strategies.py +++ b/tests/tp/test_tp_strategies.py @@ -2,12 +2,11 @@ # SPDX-License-Identifier: Apache-2.0 import os +import shutil from pathlib import Path -from shutil import rmtree from tempfile import TemporaryDirectory import pytest -from composer.utils import dist from omegaconf import OmegaConf as om from torch.distributed._tensor import Replicate, Shard from torch.distributed.tensor.parallel import ( @@ -122,7 +121,7 @@ def test_tp_train(tp_strategy: str): finally: # always remove data directory if os.path.isdir(data_dir): - rmtree(data_dir) + shutil.rmtree(data_dir) @pytest.mark.gpu