Commit 9831b5e
ShashankMosaicML committed Dec 6, 2024
1 parent 2bb25ee commit 9831b5e
Showing 2 changed files with 4 additions and 3 deletions.
3 changes: 2 additions & 1 deletion llmfoundry/models/layers/attention.py
@@ -508,7 +508,8 @@ def _check_mod_list(mod_list: list[dict[str, Any]], name: str):
         flex_attn_mod_list.append({
             'mod_name': 'sliding_window_mask',
             'mod_kwargs': {
-                'sliding_window_size': sliding_window_size,
+                'sliding_window_size':
+                    torch.tensor(sliding_window_size, device=query.device),
             },
         })
     if sequence_id_info is not None and 'sequence_id' in sequence_id_info and sequence_id_info[
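The hunk above stops passing the window size as a plain Python int and instead materializes it as a tensor on the query's device before it is captured by the flex-attention mask mod. A minimal sketch of the pattern follows; the query tensor and the surrounding list construction are illustrative stand-ins, not the actual llm-foundry function body:

import torch

# Illustrative stand-ins for values that, in llm-foundry, arrive as arguments.
query = torch.randn(1, 4, 128, 64)   # (batch, heads, seq_len, head_dim)
sliding_window_size = 16

flex_attn_mod_list: list[dict] = []
flex_attn_mod_list.append({
    'mod_name': 'sliding_window_mask',
    'mod_kwargs': {
        # Hold the scalar as a tensor on the same device as the query, matching
        # the change in the diff, so the mask mod compares tensor against tensor.
        'sliding_window_size':
            torch.tensor(sliding_window_size, device=query.device),
    },
})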
4 changes: 2 additions & 2 deletions llmfoundry/models/layers/flex_attn_utils.py
@@ -83,9 +83,9 @@ def _mask_mod_fn(
     ) -> torch.Tensor:
         del sequence_id_info, b, h
         q_idx = q_idx + query_offset
-        return q_idx - kv_idx <= self.sliding_window_size
+        return torch.abs(q_idx - kv_idx) <= self.sliding_window_size

-    def __init__(self, sliding_window_size: int) -> None:
+    def __init__(self, sliding_window_size: torch.Tensor) -> None:
         super().__init__(mod_type='mask')
         self.sliding_window_size = sliding_window_size
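Together with the attention.py hunk, this makes the sliding-window mask mod receive a tensor-valued window size and makes the distance check symmetric via torch.abs. Below is a self-contained sketch of how such a mask mod is consumed by PyTorch's flex attention (assumes PyTorch 2.5+; this is not the llm-foundry mod class, and the causal check is folded in inline here only because a symmetric window on its own is not causal):

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

seq_len = 256
window = torch.tensor(32)  # window size held as a tensor, as in the diff

def sliding_window_causal(b, h, q_idx, kv_idx):
    # Symmetric window check (the torch.abs form from the hunk above), combined
    # with a causal check so each query sees at most `window` preceding tokens.
    in_window = torch.abs(q_idx - kv_idx) <= window
    causal = q_idx >= kv_idx
    return in_window & causal

# Build the sparse block mask once per sequence length; device="cpu" keeps the
# sketch hardware-agnostic, though some PyTorch builds support flex attention
# only on CUDA.
block_mask = create_block_mask(
    sliding_window_causal, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len,
    device="cpu",
)

q = torch.randn(1, 8, seq_len, 64)
k = torch.randn(1, 8, seq_len, 64)
v = torch.randn(1, 8, seq_len, 64)
out = flex_attention(q, k, v, block_mask=block_mask)  # (1, 8, seq_len, 64)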
