Commit

fix
dakinggg committed Apr 12, 2024
1 parent 4fe928c commit 88a6511
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions tests/models/layers/test_dmoe.py
@@ -2,26 +2,25 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import copy
+import os
+import shutil
 from contextlib import nullcontext
 from functools import partial
 from typing import List, Optional
 
 import pytest
-import shutil
-import os
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
 import torch.optim as optim
+from composer.utils import dist as cdist
 from torch.distributed._tensor import DTensor, Placement, Replicate, Shard
 from torch.distributed._tensor.device_mesh import init_device_mesh
 from torch.distributed.checkpoint.state_dict import (StateDictOptions,
                                                      get_model_state_dict)
 from torch.distributed.tensor.parallel.ddp import _pre_dp_module_transform
 from torch.nn.parallel import DistributedDataParallel as DDP
 
-from composer.utils import dist as cdist
-
 from llmfoundry.models.layers.dmoe import dMoE
 from llmfoundry.models.layers.ffn import dtensorify_param
 from llmfoundry.models.mpt.configuration_mpt import MPTConfig
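The net effect of that hunk is simpler than the diff makes it look: os and shutil move into the alphabetized stdlib group, and the composer import moves out of its own trailing block into the third-party group. A sketch of the resulting grouping (the three-group stdlib / third-party / first-party layout is isort's convention; the commit does not name the tool, so that attribution is my assumption):

import copy                # stdlib group, alphabetized
import os
import shutil

import pytest              # third-party group: composer now sorts in here
from composer.utils import dist as cdist

from llmfoundry.models.layers.dmoe import dMoE   # first-party group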
@@ -191,6 +190,7 @@ def test_dmoe(moe_num_experts: int, mlp_type: str, moe_world_size: int,
     mb_y = mb_dmoe(x)
     torch.testing.assert_close(torch_y, mb_y)
 
+
 # TODO(GRT-2435): Change to fixture
 def delete_transformers_cache():
     # Only delete the files on local rank 0, otherwise race conditions are created
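The comment on the last line above carries the real reasoning: if every rank ran the deletion, concurrent processes would race on the shared cache directory. A minimal sketch of that guard pattern, assuming composer's dist helpers (imported as cdist in this file) supply the rank query and barrier; the function's actual body is collapsed in this view, and the HF_MODULES_CACHE default shown is hypothetical:

import os
import shutil

from composer.utils import dist as cdist

# Hypothetical default; transformers derives this path from HF_HOME.
HF_MODULES_CACHE = os.getenv(
    'HF_MODULES_CACHE',
    os.path.expanduser('~/.cache/huggingface/modules'))

def delete_transformers_cache():
    # Only delete the files on local rank 0, otherwise race conditions are created
    if cdist.get_local_rank() == 0:
        if os.path.exists(HF_MODULES_CACHE) and os.path.isdir(HF_MODULES_CACHE):
            shutil.rmtree(HF_MODULES_CACHE)
    cdist.barrier()  # assumed: hold every rank until the deletion has finished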
@@ -207,6 +207,7 @@ def delete_transformers_cache():
     if os.path.exists(HF_MODULES_CACHE) and os.path.isdir(HF_MODULES_CACHE):
         shutil.rmtree(HF_MODULES_CACHE)
 
+
 @pytest.mark.skipif(not is_megablocks_imported,
                     reason='This test needs megablocks module')
 @pytest.mark.gpu
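The TODO(GRT-2435) in the hunk above asks for delete_transformers_cache to become a fixture. One plausible shape, purely a sketch on my part and not part of this commit, is a yield fixture that wraps the existing helper:

import pytest

@pytest.fixture
def clean_transformers_cache():
    # Hypothetical fixture for TODO(GRT-2435): clean before and after the test.
    delete_transformers_cache()
    yield
    delete_transformers_cache()

A test would then request it by name, e.g. def test_dmoe_hf(clean_transformers_cache): ..., instead of calling the helper inline.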
