Skip to content

Commit

Permalink
Fix minor typo in example (#57)
Browse files Browse the repository at this point in the history
  • Loading branch information
zinccat authored Oct 21, 2024
1 parent f7c93de commit 0e0b44d
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions examples/flex_attn.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@
"outputs": [],
"source": [
"def checkerboard(score, batch, head, token_q, token_kv):\n",
" score = torch.where(torch.abs(token_kv - token_q) % 1 == 0, score * 0.5, score)\n",
" score = torch.where(torch.abs(token_kv - token_q) % 2 == 1, score * 0.5, score)\n",
" score = torch.where(torch.abs(token_kv - token_q) % 2 == 0, score * 2.0, score)\n",
" return score\n",
"\n",
Expand Down Expand Up @@ -316,7 +316,7 @@
"The implementation using a score_mod:\n",
"```Python\n",
"def causal_bias(score, b, h, q_idx, kv_idx):\n",
" return torch.where(q >= kv_idx, score, -float(\"inf\"))\n",
" return torch.where(q_idx >= kv_idx, score, -float(\"inf\"))\n",
"```\n",
"\n",
"Whenever you are writing a score_mod function that passes through the original score for some elements and sets others to -inf, you should likely be using a mask mod.\n",
Expand All @@ -326,7 +326,7 @@
"```Python\n",
"The implementation using a mask_mod:\n",
"def casual_mask(b,h,q_idx, kv_idx):\n",
" return q >= kv_idx\n",
" return q_idx >= kv_idx\n",
"```\n",
"As you can see they look very similar, both return scalar tensors. The key differences\n",
"1. mask_mods return boolean tensors where `True` indicates this score should be calculated, and `False` indicates we that we want to mask out this score\n",
Expand Down

0 comments on commit 0e0b44d

Please sign in to comment.