Skip to content

Commit

Permalink
try del cache
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg committed Apr 12, 2024
1 parent 6688615 commit 8981209
Showing 1 changed file with 21 additions and 0 deletions.
21 changes: 21 additions & 0 deletions tests/models/layers/test_dmoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from typing import List, Optional

import pytest
import shutil
import os
import torch
import torch.distributed as dist
import torch.nn.functional as F
Expand Down Expand Up @@ -187,6 +189,21 @@ def test_dmoe(moe_num_experts: int, mlp_type: str, moe_world_size: int,
mb_y = mb_dmoe(x)
torch.testing.assert_close(torch_y, mb_y)

# TODO(GRT-2435): Change to fixture
def delete_transformers_cache():
# Only delete the files on local rank 0, otherwise race conditions are created
if not dist.get_local_rank() == 0:
return

hf_cache_home = os.path.expanduser(
os.getenv(
'HF_HOME',
os.path.join(os.getenv('XDG_CACHE_HOME', '~/.cache'),
'huggingface')))
HF_MODULES_CACHE = os.getenv('HF_MODULES_CACHE',
os.path.join(hf_cache_home, 'modules'))
if os.path.exists(HF_MODULES_CACHE) and os.path.isdir(HF_MODULES_CACHE):
shutil.rmtree(HF_MODULES_CACHE)

@pytest.mark.skipif(not is_megablocks_imported,
reason='This test needs megablocks module')
Expand All @@ -195,6 +212,8 @@ def test_dmoe(moe_num_experts: int, mlp_type: str, moe_world_size: int,
@pytest.mark.parametrize('mlp_type', ['glu', 'mlp'])
@pytest.mark.parametrize('precision', ['bf16', 'fp32'])
def test_fwd_equal_dmoe(seqlen: int, precision: str, mlp_type: str):
delete_transformers_cache()

mb_dmoe_config = MPTConfig(d_model=1024,
n_heads=32,
n_layers=1,
Expand Down Expand Up @@ -261,3 +280,5 @@ def test_fwd_equal_dmoe(seqlen: int, precision: str, mlp_type: str):
mpt_logits = mb_dmoe_model(token_ids).logits
db_logits = torch_dmoe_model(token_ids).logits
assert torch.allclose(mpt_logits, db_logits, rtol=0.01, atol=0.01)

delete_transformers_cache()

0 comments on commit 8981209

Please sign in to comment.