Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
ydshieh committed Nov 3, 2024
1 parent 07a9b59 commit ad647f2
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions tests/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4004,6 +4004,8 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):
config.rms_norm_eps = 1.0
config.layer_norm_eps = 1.0
config.norm_eps = 1.0
config.norm_epsilon = 1.0
config.layer_norm_epsilon = 1.0
if hasattr(config, "text_config"):
config.text_config.rms_norm_eps = 1.0
config.text_config.layer_norm_eps = 1.0
Expand Down Expand Up @@ -4033,6 +4035,13 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):
)
model_eager = model_eager.eval().to(torch_device, dtype=torch_dtype)

for x in model_eager.modules():
if isinstance(x, nn.LayerNorm):
x.eps = 1.0
for x in model_sdpa.modules():
if isinstance(x, nn.LayerNorm):
x.eps = 1.0

# We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 16 times the model,
# but it would be nicer to have an efficient way to use parameterized.expand
fail_cases = []
Expand Down

0 comments on commit ad647f2

Please sign in to comment.