From 18afcc5f1ab08a664e035034f83d32881c3fa199 Mon Sep 17 00:00:00 2001
From: Shashank Rajput
Date: Thu, 5 Dec 2024 09:43:06 -0800
Subject: [PATCH] ..

---
 llmfoundry/models/layers/blocks.py    | 1 +
 llmfoundry/models/mpt/modeling_mpt.py | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py
index e9ca5c17ba..d2e81a886b 100644
--- a/llmfoundry/models/layers/blocks.py
+++ b/llmfoundry/models/layers/blocks.py
@@ -51,6 +51,7 @@ def __init__(
     ):
         if attn_config is None:
             attn_config = attn_config_defaults
+        attn_config.pop('flex_attn_compile', None)
 
         if ffn_config is None:
             self.ffn_config: dict[str, Any] = {
diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index 22b442104b..d3eb834365 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -423,7 +423,7 @@ def __init__(self, config: MPTConfig):
         self.mb_args = None
         self.shift_labels = True
 
-        flex_attn_compile = config.attn_config.pop('flex_attn_compile')
+        flex_attn_compile = config.attn_config.pop('flex_attn_compile', False)
         if self.attn_impl == 'flex':
             self.compiled_flex_attention = torch.compile(
                 flex_attention,
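
Editor's note (not part of the patch): the modeling_mpt.py hunk replaces dict.pop(key) with dict.pop(key, default). Without a default, pop raises KeyError whenever 'flex_attn_compile' is absent from attn_config; with False as the fallback, the lookup degrades gracefully. The blocks.py hunk pops the same key, presumably so it is never forwarded into the block's attention-layer kwargs. A minimal, self-contained Python sketch of the pop behavior (the dict below is a made-up example, not the real MPT attn_config):

    attn_config = {'attn_impl': 'flex'}  # 'flex_attn_compile' was never set

    # attn_config.pop('flex_attn_compile') would raise KeyError here.
    # With a default, the missing key falls back to False instead:
    flex_attn_compile = attn_config.pop('flex_attn_compile', False)
    assert flex_attn_compile is False

    # Popping with a default is likewise a safe no-op when the key is absent:
    attn_config.pop('flex_attn_compile', None)
    assert 'flex_attn_compile' not in attn_config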