Torch autocasts attention weights into FP32 because of softmax, but does not cast them back into the user-specified autocast dtype. Up until recently, we explicitly passed the autocast dtype in all autograd function wrappers (reference: https://github.com/SHI-Labs/NATTEN/blob/3b54c76185904f3cb59a49fff7bc044e4513d106/src/natten/functional.py#L149), but this is wrong, because the user might be running BF16 rather than FP16. According to the latest torch documentation, this softmax behavior has not changed since the first NATTEN release. Because hard-coding the dtype is error-prone, this commit explicitly casts all attention tensors to match the dtype of value. If the dtypes already match, torch treats the cast as a no-op, so it should not get in the way of AMP mechanics. Reference: #93
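A minimal sketch of the idea, assuming hypothetical names (`cast_attn_like_value`, `attn`, `value`); the actual change lives in NATTEN's autograd function wrappers and custom kernels, not in a plain matmul:

```python
import torch

def cast_attn_like_value(attn: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    # Under autocast, softmax leaves the attention weights in FP32, while
    # `value` keeps the user's autocast dtype (FP16 *or* BF16). Casting to
    # value.dtype covers both; .to() is a no-op when the dtypes already match.
    return attn.to(value.dtype)

# FP32 attention weights (as softmax would produce under autocast) and a
# BF16 value tensor, standing in for an AMP run:
attn = torch.randn(2, 8, 16, 16).softmax(dim=-1)          # float32
value = torch.randn(2, 8, 16, 32, dtype=torch.bfloat16)
out = cast_attn_like_value(attn, value) @ value            # dtypes now agree
```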
1 parent fc84218 · commit 2e5a20f
Showing 8 changed files with 73 additions and 56 deletions.
.gitignore
```diff
@@ -133,6 +133,3 @@ dmypy.json
 
 # Output files
 *.out
-
-# Don't commit manifest.in
-MANIFEST.in
```
MANIFEST.in (new file)
```diff
@@ -0,0 +1,7 @@
+recursive-include csrc *.cpp *.cuh *.cu *.h *.txt *.hpp *.cc
+recursive-include third_party/cutlass/include *.cpp *.cuh *.cu *.h *.txt *.hpp *.cc
+prune webpage/
+prune tools/
+prune scripts/
+prune tests/
+graft assets/
```
```diff
@@ -59,4 +59,4 @@
     "has_fp64_gemm",
 ]
 
-__version__ = "0.15.0"
+__version__ = "0.15.1"
```