diff --git a/test/test_pallas.py b/test/test_pallas.py index ab88bf6070c..7b8755fc71e 100644 --- a/test/test_pallas.py +++ b/test/test_pallas.py @@ -22,6 +22,9 @@ class PallasTest(unittest.TestCase): + # This is to create a diagonal mask where only elements within the same segment + # can attend to each other. Since the mask is to mask out the unrelevant parts, + # therefore we use != instead of ==. def _make_attention_mask_from_segment_ids(self, q_segment_ids, kv_segment_ids): return q_segment_ids.view(q_segment_ids.shape[0], 1, @@ -32,6 +35,7 @@ def _make_attention_mask_from_segment_ids(self, q_segment_ids, def _attention(self, q, k, v, *, attn_mask=None): attn_weight = q @ k.transpose(-2, -1) if attn_mask is not None: + # Masked out the unrelevant parts. attn_weight = attn_weight.masked_fill(attn_mask, torch.finfo(attn_weight.dtype).min) attn_weight = nn.functional.softmax(attn_weight, dim=-1)