Fix attention autocast (#94)
Torch autocasts attention weights into FP32 because of softmax, but
doesn't autocast back into the user-specified data type.

Up until recently, we explicitly passed the autocast dtype in all autograd function wrappers (reference: https://github.com/SHI-Labs/NATTEN/blob/3b54c76185904f3cb59a49fff7bc044e4513d106/src/natten/functional.py#L149), but this is wrong, because the user might be using BF16.

According to the latest torch documentation, this behavior has not changed since the first NATTEN release.

Because that approach is error-prone, this commit explicitly casts all attention tensors to match the dtype of value. If the dtypes already match, torch treats the cast as a no-op, so it shouldn't get in the way of AMP mechanics.
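
For illustration, a minimal standalone sketch of the mismatch this works around, using plain torch ops on a CUDA device rather than NATTEN's kernels (shapes and variable names here are made up):

```python
import torch

# Illustrative shapes: (batch, heads, tokens, head_dim).
q = torch.randn(2, 8, 64, 32, device="cuda")
k = torch.randn(2, 8, 64, 32, device="cuda")
value = torch.randn(2, 8, 64, 32, device="cuda")

with torch.autocast("cuda", dtype=torch.bfloat16):
    scores = q @ k.transpose(-2, -1)  # bfloat16: matmul follows the autocast dtype
    attn = scores.softmax(dim=-1)     # float32: softmax is autocast to FP32
    value = value.to(scores.dtype)    # bfloat16, the dtype the user actually asked for

    # The pattern this commit adds before the AV ops: match value's dtype.
    # When the dtypes already agree, the cast is a no-op, so it does not
    # interfere with AMP when nothing needs casting.
    attn = attn.to(value.dtype)
    assert attn.dtype == value.dtype  # both bfloat16
```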

Reference: #93
alihassanijr authored Jan 25, 2024
1 parent fc84218 commit 2e5a20f
Showing 8 changed files with 73 additions and 56 deletions.
3 changes: 0 additions & 3 deletions .gitignore
@@ -133,6 +133,3 @@ dmypy.json

# Output files
*.out

# Don't commit manifest.in
MANIFEST.in
25 changes: 15 additions & 10 deletions CHANGELOG.md
@@ -1,20 +1,25 @@
# Changelog

## [0.15.1] - 2024-01-24
* Attention tensors can now be views, which allows combining neighborhood and any other attention pattern (e.g. registers,
cross attention tokens, and the like) without extra copies. ([#85](https://github.com/SHI-Labs/NATTEN/pull/85) and [#87](https://github.com/SHI-Labs/NATTEN/pull/87)).
* Minor bug fixes ([#86](https://github.com/SHI-Labs/NATTEN/pull/86) and [#94](https://github.com/SHI-Labs/NATTEN/pull/94)).

## [0.15.0] - 2024-01-09
* Refactored kernels
* The backend is messy, particularly the CUDA backend. A step in the right direction is at least factoring out duplicated code.
* Out of the 7 operations in NATTEN's backend, 6 have duplicates (really 3 underlying ops with different inputs).
* See #26 for more details.
* See [#26](https://github.com/SHI-Labs/NATTEN/pull/26) for more details.
* 3D Neighborhood Attention: naive CPU and CUDA kernels were added.
* Major refactoring of the C++ API (#38, #47, #53, and #81)
* GEMM kernels (#38 and #47)
* New build system with cmake (#38, #53, #81)
* Bfloat16 support (#38 and #81)
* Kepler and Maxwell support (#81)
* Forward mode automatic differentiation support (#74)
* Experimental support for Nested Tensors (inference only) (#76)
* Type checking, clang format, and other typesetting/formatting changes (#80)
* Added profiling scripts (#81)
* Major refactoring of the C++ API ([#38](https://github.com/SHI-Labs/NATTEN/pull/38), [#47](https://github.com/SHI-Labs/NATTEN/pull/47), [#53](https://github.com/SHI-Labs/NATTEN/pull/53), and [#81](https://github.com/SHI-Labs/NATTEN/pull/81))
* GEMM kernels ([#38](https://github.com/SHI-Labs/NATTEN/pull/38) and [#47](https://github.com/SHI-Labs/NATTEN/pull/47))
* New build system with cmake ([#38](https://github.com/SHI-Labs/NATTEN/pull/38), [#53](https://github.com/SHI-Labs/NATTEN/pull/53), [#81](https://github.com/SHI-Labs/NATTEN/pull/81))
* Bfloat16 support ([#38](https://github.com/SHI-Labs/NATTEN/pull/38) and [#81](https://github.com/SHI-Labs/NATTEN/pull/81))
* Kepler and Maxwell support ([#81](https://github.com/SHI-Labs/NATTEN/pull/81))
* Forward mode automatic differentiation support ([#74](https://github.com/SHI-Labs/NATTEN/pull/74))
* Experimental support for Nested Tensors (inference only) ([#76](https://github.com/SHI-Labs/NATTEN/pull/76))
* Type checking, clang format, and other typesetting/formatting changes ([#80](https://github.com/SHI-Labs/NATTEN/pull/80))
* Added profiling scripts ([#81](https://github.com/SHI-Labs/NATTEN/pull/81))

## [0.14.6] - 2023-03-21
Just a really small update that syncs the changes to the private branch.
7 changes: 7 additions & 0 deletions MANIFEST.in
@@ -0,0 +1,7 @@
recursive-include csrc *.cpp *.cuh *.cu *.h *.txt *.hpp *.cc
recursive-include third_party/cutlass/include *.cpp *.cuh *.cu *.h *.txt *.hpp *.cc
prune webpage/
prune tools/
prune scripts/
prune tests/
graft assets/
1 change: 0 additions & 1 deletion Makefile
@@ -36,7 +36,6 @@ release:

clean:
@echo "Cleaning up"
rm -rf build/
rm -rf dist/
rm -rf natten.egg-info/
rm -rf src/natten/_C.*
2 changes: 1 addition & 1 deletion src/natten/__init__.py
@@ -59,4 +59,4 @@
"has_fp64_gemm",
]

__version__ = "0.15.0"
__version__ = "0.15.1"
6 changes: 6 additions & 0 deletions src/natten/functional.py
@@ -270,6 +270,7 @@ def forward(
dilation: int,
):
num_na_weights = kernel_size
attn = attn.to(value.dtype)
value = value.contiguous()
out = torch.empty_like(value)
out_add = None
@@ -318,6 +319,7 @@ def jvp(ctx, *grad_inputs: Any) -> Tensor:
"Expected either both additional_value_t and additional_value_p, or neither."
)

attn_t = attn_t.to(value_t.dtype)
attn_t = attn_t.contiguous()
value_t = value_t.contiguous()
out_0 = torch.empty_like(value_p)
@@ -544,6 +546,7 @@ def forward(
dilation: int,
) -> Tensor:
num_na_weights = kernel_size**2
attn = attn.to(value.dtype)
value = value.contiguous()
out = torch.empty_like(value)
out_add = None
@@ -592,6 +595,7 @@ def jvp(ctx, *grad_inputs: Any) -> Tensor:
"Expected either both additional_value_t and additional_value_p, or neither."
)

attn_t = attn_t.to(value_t.dtype)
attn_t = attn_t.contiguous()
value_t = value_t.contiguous()
out_0 = torch.empty_like(value_p)
@@ -849,6 +853,7 @@ def forward(
dilation: int,
) -> Tensor:
num_na_weights = kernel_size_d * kernel_size * kernel_size
attn = attn.to(value.dtype)
value = value.contiguous()
out = torch.empty_like(value)
out_add = None
@@ -901,6 +906,7 @@ def jvp(ctx, *grad_inputs: Any) -> Tensor:
"Expected either both additional_value_t and additional_value_p, or neither."
)

attn_t = attn_t.to(value_t.dtype)
attn_t = attn_t.contiguous()
value_t = value_t.contiguous()
out_0 = torch.empty_like(value_p)
3 changes: 3 additions & 0 deletions src/natten/nested.py
@@ -144,6 +144,7 @@ def na1d_av_nested(
"nested."
)

attn = attn.to(value.dtype)
out = torch.empty_like(value)
additional_values_list: List | Tensor = (
[None for _ in range(attn.size(0))]
@@ -273,6 +274,7 @@ def na2d_av_nested(
"nested."
)

attn = attn.to(value.dtype)
out = torch.empty_like(value)
additional_values_list: List | Tensor = (
[None for _ in range(attn.size(0))]
@@ -408,6 +410,7 @@ def na3d_av_nested(
"nested."
)

attn = attn.to(value.dtype)
out = torch.empty_like(value)
additional_values_list: List | Tensor = (
[None for _ in range(attn.size(0))]