Skip to content

Commit

Permalink
Adds more comments
Browse files Browse the repository at this point in the history
  • Loading branch information
alanwaketan committed Apr 30, 2024
1 parent 15e9918 commit a7eb513
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions test/test_pallas.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@

class PallasTest(unittest.TestCase):

# This creates a diagonal mask where only elements within the same segment
# can attend to each other. Since the mask marks the irrelevant parts to be
# masked out, we use != instead of ==.
def _make_attention_mask_from_segment_ids(self, q_segment_ids,
kv_segment_ids):
return q_segment_ids.view(q_segment_ids.shape[0], 1,
Expand All @@ -32,6 +35,7 @@ def _make_attention_mask_from_segment_ids(self, q_segment_ids,
def _attention(self, q, k, v, *, attn_mask=None):
attn_weight = q @ k.transpose(-2, -1)
if attn_mask is not None:
# Mask out the irrelevant parts.
attn_weight = attn_weight.masked_fill(attn_mask,
torch.finfo(attn_weight.dtype).min)
attn_weight = nn.functional.softmax(attn_weight, dim=-1)
Expand Down

0 comments on commit a7eb513

Please sign in to comment.