Commit

..
ShashankMosaicML committed Dec 5, 2024
1 parent 5093efd commit 96b8f82
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -424,12 +424,13 @@ def __init__(self, config: MPTConfig):
         self.shift_labels = True
 
         if self.attn_impl == 'flex':
+            flex_attn_compile = config.attn_config.pop('flex_attn_compile')
             self.compiled_flex_attention = torch.compile(
                 flex_attention,
-            ) if config.attn_config['flex_attn_compile'] else flex_attention
+            ) if flex_attn_compile else flex_attention
             self.compiled_create_block_mask = torch.compile(
                 create_block_mask,
-            ) if config.attn_config['flex_attn_compile'] else create_block_mask
+            ) if flex_attn_compile else create_block_mask
 
         self.blocks = self.construct_blocks(config=config,)
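
For context, the change replaces two direct reads of config.attn_config['flex_attn_compile'] with a single pop(), presumably so the compile flag is consumed once at construction time and removed from attn_config before the rest of the config is passed along. Below is a minimal, self-contained sketch of the same pattern, assuming PyTorch 2.5+ (where flex_attention and create_block_mask live in torch.nn.attention.flex_attention); the toy attn_config dict is illustrative, not the actual MPTConfig:

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# Illustrative stand-in for config.attn_config in modeling_mpt.py.
attn_config = {'attn_impl': 'flex', 'flex_attn_compile': True}

# Pop the flag so it is read exactly once and is no longer present when
# attn_config is forwarded to code that does not expect the key.
flex_attn_compile = attn_config.pop('flex_attn_compile')

# Wrap both kernels in torch.compile only when the flag is set;
# otherwise fall back to the eager implementations.
compiled_flex_attention = (
    torch.compile(flex_attention) if flex_attn_compile else flex_attention
)
compiled_create_block_mask = (
    torch.compile(create_block_mask) if flex_attn_compile else create_block_mask
)

assert 'flex_attn_compile' not in attn_config  # flag was consumed by pop()
```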
