Skip to content

Commit

Permalink
..
Browse files Browse the repository at this point in the history
  • Loading branch information
ShashankMosaicML committed Dec 5, 2024
1 parent f1ad991 commit 18afcc5
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 1 deletion.
1 change: 1 addition & 0 deletions llmfoundry/models/layers/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(
):
if attn_config is None:
attn_config = attn_config_defaults
attn_config.pop('flex_attn_compile', None)

if ffn_config is None:
self.ffn_config: dict[str, Any] = {
Expand Down
2 changes: 1 addition & 1 deletion llmfoundry/models/mpt/modeling_mpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ def __init__(self, config: MPTConfig):
self.mb_args = None
self.shift_labels = True

flex_attn_compile = config.attn_config.pop('flex_attn_compile')
flex_attn_compile = config.attn_config.pop('flex_attn_compile', False)
if self.attn_impl == 'flex':
self.compiled_flex_attention = torch.compile(
flex_attention,
Expand Down

0 comments on commit 18afcc5

Please sign in to comment.