From 0a5026e3887a23c7a2cdb58d6f03632cd805ba0c Mon Sep 17 00:00:00 2001 From: Daniel King Date: Fri, 12 Apr 2024 04:34:42 +0000 Subject: [PATCH] fix tests --- tests/models/layers/test_dmoe.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/models/layers/test_dmoe.py b/tests/models/layers/test_dmoe.py index 9c15745793..c8e7ec3e67 100644 --- a/tests/models/layers/test_dmoe.py +++ b/tests/models/layers/test_dmoe.py @@ -239,6 +239,10 @@ def test_fwd_equal_dmoe(seqlen: int, precision: str, mlp_type: str): torch_dmoe_config = copy.deepcopy(mb_dmoe_config) torch_dmoe_config.ffn_config['ffn_type'] = 'torch_dmoe' + del torch_dmoe_config.ffn_config['moe_world_size'] + del torch_dmoe_config.ffn_config['fc_type'] + del torch_dmoe_config.ffn_config['moe_loss_weight'] + del torch_dmoe_config.ffn_config['return_bias'] mb_dmoe_model = MPTForCausalLM(mb_dmoe_config).to(device=device, dtype=dtype)