From 96b8f82e4b5bfaf2e1a94360012c076ebaddb17b Mon Sep 17 00:00:00 2001
From: Shashank Rajput
Date: Thu, 5 Dec 2024 01:41:12 -0800
Subject: [PATCH] ..

---
 llmfoundry/models/mpt/modeling_mpt.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index 0654a81839..69d562ad1e 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -424,12 +424,13 @@ def __init__(self, config: MPTConfig):
         self.shift_labels = True
 
         if self.attn_impl == 'flex':
+            flex_attn_compile = config.attn_config.pop('flex_attn_compile')
             self.compiled_flex_attention = torch.compile(
                 flex_attention,
-            ) if config.attn_config['flex_attn_compile'] else flex_attention
+            ) if flex_attn_compile else flex_attention
             self.compiled_create_block_mask = torch.compile(
                 create_block_mask,
-            ) if config.attn_config['flex_attn_compile'] else create_block_mask
+            ) if flex_attn_compile else create_block_mask
 
         self.blocks = self.construct_blocks(config=config,)
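
Note (not part of the patch): below is a minimal, standalone sketch of the pattern the diff applies, assuming PyTorch >= 2.5 where flex_attention and create_block_mask live in torch.nn.attention.flex_attention. The idea is to pop the flex_attn_compile flag from the attention config once, so the flag is consumed before the config is forwarded to construct_blocks, and to use the popped value to decide whether both functions get wrapped with torch.compile. The helper name build_flex_attention_fns and the literal attn_config dict are hypothetical and only for illustration; they are not llm-foundry APIs.

    # Standalone sketch of the conditional-compilation pattern in the diff.
    # Only torch.compile, flex_attention, and create_block_mask come from
    # PyTorch; build_flex_attention_fns is a hypothetical helper.
    import torch
    from torch.nn.attention.flex_attention import create_block_mask, flex_attention


    def build_flex_attention_fns(attn_config: dict):
        # Pop the flag so downstream consumers of attn_config (e.g. block
        # construction) never see a key they do not expect.
        flex_attn_compile = attn_config.pop('flex_attn_compile', False)

        compiled_flex_attention = (
            torch.compile(flex_attention) if flex_attn_compile else flex_attention
        )
        compiled_create_block_mask = (
            torch.compile(create_block_mask) if flex_attn_compile else create_block_mask
        )
        return compiled_flex_attention, compiled_create_block_mask


    if __name__ == '__main__':
        attn_config = {'attn_impl': 'flex', 'flex_attn_compile': True}
        attn_fn, block_mask_fn = build_flex_attention_fns(attn_config)
        # The flag has been consumed and is no longer present in the config.
        assert 'flex_attn_compile' not in attn_config

Reading the flag into a local variable before the two conditional expressions also avoids indexing attn_config twice for the same key, which is what the two replaced lines in the hunk previously did.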