diff --git a/llmfoundry/models/utils/act_ckpt.py b/llmfoundry/models/utils/act_ckpt.py index 957b243baf..a0628acb7a 100644 --- a/llmfoundry/models/utils/act_ckpt.py +++ b/llmfoundry/models/utils/act_ckpt.py @@ -49,7 +49,7 @@ def get_act_ckpt_module(mod_name: str) -> Any: mod_type = norms.get(mod_name) else: msg = ', '.join( - list(attention_classes.keys()) + list(ffns.get_all()) + + list(attention_classes.get_all()) + list(ffns.get_all()) + list(ffns_with_norm.get_all()) + list(ffns_with_megablocks.get_all()) + list(norms.get_all()) + ['MPTBlock'],