diff --git a/tests/models/layers/test_dmoe.py b/tests/models/layers/test_dmoe.py
index 10bbe6427e..328140d4a3 100644
--- a/tests/models/layers/test_dmoe.py
+++ b/tests/models/layers/test_dmoe.py
@@ -2,17 +2,18 @@
 # SPDX-License-Identifier: Apache-2.0

 import copy
+import os
+import shutil
 from contextlib import nullcontext
 from functools import partial
 from typing import List, Optional

 import pytest
-import shutil
-import os
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
 import torch.optim as optim
+from composer.utils import dist as cdist
 from torch.distributed._tensor import DTensor, Placement, Replicate, Shard
 from torch.distributed._tensor.device_mesh import init_device_mesh
 from torch.distributed.checkpoint.state_dict import (StateDictOptions,
@@ -20,8 +21,6 @@
 from torch.distributed.tensor.parallel.ddp import _pre_dp_module_transform
 from torch.nn.parallel import DistributedDataParallel as DDP

-from composer.utils import dist as cdist
-
 from llmfoundry.models.layers.dmoe import dMoE
 from llmfoundry.models.layers.ffn import dtensorify_param
 from llmfoundry.models.mpt.configuration_mpt import MPTConfig
@@ -191,6 +190,7 @@ def test_dmoe(moe_num_experts: int, mlp_type: str, moe_world_size: int,
     mb_y = mb_dmoe(x)
     torch.testing.assert_close(torch_y, mb_y)

+
 # TODO(GRT-2435): Change to fixture
 def delete_transformers_cache():
     # Only delete the files on local rank 0, otherwise race conditions are created
@@ -207,6 +207,7 @@ def delete_transformers_cache():
     if os.path.exists(HF_MODULES_CACHE) and os.path.isdir(HF_MODULES_CACHE):
         shutil.rmtree(HF_MODULES_CACHE)

+
 @pytest.mark.skipif(not is_megablocks_imported,
                     reason='This test needs megablocks module')
 @pytest.mark.gpu