
Commit

more debug cleanup
dakinggg committed Apr 12, 2024
1 parent f072487 commit 4ead37c
Showing 1 changed file with 0 additions and 31 deletions.
tests/models/layers/test_dmoe.py (31 changes: 0 additions & 31 deletions)
@@ -2,8 +2,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import copy
-import os
-import shutil
 from contextlib import nullcontext
 from functools import partial
 from typing import List, Optional
@@ -13,7 +11,6 @@
 import torch.distributed as dist
 import torch.nn.functional as F
 import torch.optim as optim
-from composer.utils import dist as cdist
 from torch.distributed._tensor import DTensor, Placement, Replicate, Shard
 from torch.distributed._tensor.device_mesh import init_device_mesh
 from torch.distributed.checkpoint.state_dict import (StateDictOptions,
@@ -191,39 +188,13 @@ def test_dmoe(moe_num_experts: int, mlp_type: str, moe_world_size: int,
     torch.testing.assert_close(torch_y, mb_y)
 
 
-# TODO(GRT-2435): Change to fixture
-def delete_transformers_cache():
-    # Only delete the files on local rank 0, otherwise race conditions are created
-    if not cdist.get_local_rank() == 0:
-        return
-
-    hf_cache_home = os.path.expanduser(
-        os.getenv(
-            'HF_HOME',
-            os.path.join(os.getenv('XDG_CACHE_HOME', '~/.cache'),
-                         'huggingface')))
-    HF_MODULES_CACHE = os.getenv('HF_MODULES_CACHE',
-                                 os.path.join(hf_cache_home, 'modules'))
-    if os.path.exists(HF_MODULES_CACHE) and os.path.isdir(HF_MODULES_CACHE):
-        shutil.rmtree(HF_MODULES_CACHE)
-
-
 @pytest.mark.skipif(not is_megablocks_imported,
                     reason='This test needs megablocks module')
 @pytest.mark.gpu
 @pytest.mark.parametrize('seqlen', [512])
 @pytest.mark.parametrize('mlp_type', ['glu', 'mlp'])
 @pytest.mark.parametrize('precision', ['bf16', 'fp32'])
 def test_fwd_equal_dmoe(seqlen: int, precision: str, mlp_type: str):
-    delete_transformers_cache()
-
-    from llmfoundry.layers_registry import module_init_fns
-    print(module_init_fns.get_all())
-
-    from llmfoundry.models.layers.ffn import resolve_ffn_act_fn  # type: ignore
-
-    print(module_init_fns.get_all())
-
     mb_dmoe_config = MPTConfig(d_model=1024,
                                n_heads=32,
                                n_layers=1,
@@ -290,5 +261,3 @@ def test_fwd_equal_dmoe(seqlen: int, precision: str, mlp_type: str):
     mpt_logits = mb_dmoe_model(token_ids).logits
     db_logits = torch_dmoe_model(token_ids).logits
     assert torch.allclose(mpt_logits, db_logits, rtol=0.01, atol=0.01)
-
-    delete_transformers_cache()
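
Note on the removed helper: the deleted delete_transformers_cache function carried a "# TODO(GRT-2435): Change to fixture" comment. Below is a minimal, hypothetical sketch of what such a pytest fixture could look like, reusing the same HF_HOME / XDG_CACHE_HOME resolution and local-rank-0 guard as the removed code. The fixture name and the before/after cleanup behavior are assumptions for illustration, not part of this commit or the repository.

# Hypothetical sketch only -- not in llm-foundry. Assumes composer is installed,
# as it was for the deleted helper (which imported composer.utils.dist as cdist).
import os
import shutil

import pytest
from composer.utils import dist as cdist


@pytest.fixture
def clean_hf_modules_cache():
    """Delete the HF modules cache around a test, on local rank 0 only."""

    def _delete():
        # Only delete on local rank 0, otherwise race conditions are created.
        if cdist.get_local_rank() != 0:
            return
        hf_cache_home = os.path.expanduser(
            os.getenv(
                'HF_HOME',
                os.path.join(os.getenv('XDG_CACHE_HOME', '~/.cache'),
                             'huggingface')))
        modules_cache = os.getenv('HF_MODULES_CACHE',
                                  os.path.join(hf_cache_home, 'modules'))
        if os.path.isdir(modules_cache):
            shutil.rmtree(modules_cache)

    _delete()  # clean up before the test body runs
    yield
    _delete()  # clean up again after the test finishes

A test would opt in by taking the fixture as an argument, e.g. def test_fwd_equal_dmoe(seqlen, precision, mlp_type, clean_hf_modules_cache), instead of calling the helper at the start and end of the test as the removed code did.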

0 comments on commit 4ead37c
