diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 0654a81839..69d562ad1e 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -424,12 +424,13 @@ def __init__(self, config: MPTConfig): self.shift_labels = True if self.attn_impl == 'flex': + flex_attn_compile = config.attn_config.pop('flex_attn_compile') self.compiled_flex_attention = torch.compile( flex_attention, - ) if config.attn_config['flex_attn_compile'] else flex_attention + ) if flex_attn_compile else flex_attention self.compiled_create_block_mask = torch.compile( create_block_mask, - ) if config.attn_config['flex_attn_compile'] else create_block_mask + ) if flex_attn_compile else create_block_mask self.blocks = self.construct_blocks(config=config,)