From 8981209546ce9ebc6407f8c518df69df2fec7951 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Fri, 12 Apr 2024 05:34:13 +0000 Subject: [PATCH] try del cache --- tests/models/layers/test_dmoe.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/models/layers/test_dmoe.py b/tests/models/layers/test_dmoe.py index 9c15745793..082a1ad02a 100644 --- a/tests/models/layers/test_dmoe.py +++ b/tests/models/layers/test_dmoe.py @@ -7,6 +7,8 @@ from typing import List, Optional import pytest +import shutil +import os import torch import torch.distributed as dist import torch.nn.functional as F @@ -187,6 +189,21 @@ def test_dmoe(moe_num_experts: int, mlp_type: str, moe_world_size: int, mb_y = mb_dmoe(x) torch.testing.assert_close(torch_y, mb_y) +# TODO(GRT-2435): Change to fixture +def delete_transformers_cache(): + # Only delete the files on local rank 0, otherwise race conditions are created + if not dist.get_local_rank() == 0: + return + + hf_cache_home = os.path.expanduser( + os.getenv( + 'HF_HOME', + os.path.join(os.getenv('XDG_CACHE_HOME', '~/.cache'), + 'huggingface'))) + HF_MODULES_CACHE = os.getenv('HF_MODULES_CACHE', + os.path.join(hf_cache_home, 'modules')) + if os.path.exists(HF_MODULES_CACHE) and os.path.isdir(HF_MODULES_CACHE): + shutil.rmtree(HF_MODULES_CACHE) @pytest.mark.skipif(not is_megablocks_imported, reason='This test needs megablocks module') @@ -195,6 +212,8 @@ def test_dmoe(moe_num_experts: int, mlp_type: str, moe_world_size: int, @pytest.mark.parametrize('mlp_type', ['glu', 'mlp']) @pytest.mark.parametrize('precision', ['bf16', 'fp32']) def test_fwd_equal_dmoe(seqlen: int, precision: str, mlp_type: str): + delete_transformers_cache() + mb_dmoe_config = MPTConfig(d_model=1024, n_heads=32, n_layers=1, @@ -261,3 +280,5 @@ def test_fwd_equal_dmoe(seqlen: int, precision: str, mlp_type: str): mpt_logits = mb_dmoe_model(token_ids).logits db_logits = torch_dmoe_model(token_ids).logits assert torch.allclose(mpt_logits, db_logits, rtol=0.01, atol=0.01) + + delete_transformers_cache()