From 9e8cd4ae15f9ace774ee76c62917c62798c8d560 Mon Sep 17 00:00:00 2001 From: drisspg Date: Mon, 25 Nov 2024 06:43:51 -0800 Subject: [PATCH] [FlexAttention] add support for learnable biases in Inductor (#137452) # Summary The follow up PR to: https://github.com/pytorch/pytorch/pull/137526. In this pr, we actually update the lowerings for the flex_attention backwards kernel to generate fused backward gradient calculations for any captured buffers that require grads. We are doing this using tl.atomic_add to scatter the correct gradients into zeroed out buffer for any captured buffers that required grads. Added many test cases and found. Along the way found some masking bugs. There are likely some performance cliffs here, specifically with D-types and on different GPUs. Planned to do this in a follow-up and profile the current strategy. We are explicitly choosing reduced memory over increased performance right now. By using atomics, we do not need to realize a full attention scores matrix. However, this comes with two downsides. One, this is potentially slower in some cases, and two, the gradient calculation for any captured buffers is non-deterministic. ## Worked Example Lets do the case where you are reading from one bias that doesn't require grad and using this to index into another that does. ScoreMod: ```Python bias = torch.randn( params.seq_length, device=self.device, dtype=params.dtype, requires_grad=True, ) offset = torch.randint( 0, params.seq_length, (params.seq_length,), device=self.device, ) def score_mod(score, b, h, q_idx, kv_idx): return score + bias[offset[q_idx]] ``` I am removing all but the new subgraph injected into the backwards: ``` Python dsT = pT * (dpT - Di[None, :]) # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ grad_scores = (dsT) # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ idx_b = off_z idx_h = off_hq idx_m = m idx_n = n scatter_mask = offs_m1[None, :] < Q_LEN and offs_n1[:, None] < KV_LEN tmp4 = (dsT).to(tl.float32) tl.atomic_add(out_ptr1 + (tl.broadcast_to(tl.load(in_ptr16 + idx_m), tmp4.shape)), tmp4, scatter_mask, sem='relaxed') # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ``` ## Key points * We always accumulate to float 32 grad buffers regardless of the type in the forward. This is because we normally do all computation intra kernel w/ fp32 accumulation and we want the same behavior for atomic additions * We are currently restricted to 1 scatter in the kenrel. I have some ideas on fx rewrites that would remove this restrictions but for now have nice error message w/ work around and will leave as a follow up. * Will do more extensive performance/ memory profiling in a follow up. ### Toy E2E example I have a toy E2E training example PR in the gym for now: https://github.com/pytorch-labs/attention-gym/pull/84/ I plan to update to a realistic learnable bias before landing Pull Request resolved: https://github.com/pytorch/pytorch/pull/137452 Approved by: https://github.com/Chillee --- test/inductor/test_flex_attention.py | 613 +++++++++++++++++++++- torch/_higher_order_ops/flex_attention.py | 9 +- torch/_inductor/kernel/flex_attention.py | 230 +++++++- torch/_inductor/lowering.py | 2 +- torch/_inductor/select_algorithm.py | 156 ++++-- torch/_inductor/subgraph_lowering.py | 105 +++- torch/nn/attention/flex_attention.py | 13 +- 7 files changed, 1009 insertions(+), 119 deletions(-) diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index d4164279968863..e6e1940bff129c 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -7,6 +7,8 @@ import unittest from collections import namedtuple from contextlib import contextmanager +from dataclasses import dataclass +from itertools import product from typing import Callable, List, Optional, Tuple, Union from unittest import expectedFailure, skip, skipUnless from unittest.mock import patch @@ -1214,18 +1216,6 @@ def composed_score_mod(score, b, h, m, n): self.run_test(composed_score_mod, dtype) self.run_test_with_paged_attention(composed_score_mod, dtype) - @supported_platform - @expectedFailure # TODO: Remove this after supporting compiled flex attention with training bias - @common_utils.parametrize("dtype", test_dtypes) - def test_captured_buffers_req_grad(self, dtype: torch.dtype): - head_offset = torch.rand(8, device="cuda", dtype=dtype, requires_grad=True) - - def score_mod(score, b, h, m, n): - return score + head_offset[h] - - self.run_test(score_mod, dtype, 4, 8, 128, 128) - self.run_test_with_paged_attention(score_mod, dtype, 4, 8, 128, 128) - @supported_platform @common_utils.parametrize("dtype", test_dtypes) def test_captured_buffers_all_dims(self, dtype: torch.dtype): @@ -4203,9 +4193,608 @@ def causal_mask(b, h, q, kv): self._check_equal(golden_out, ref_out, paged_out, fudge_factor, "Out") +@dataclass +class Params: + batch_size: int + num_heads: int + seq_length: int + head_dim: int + dtype: torch.dtype + config_str: Optional[str] = None + + def __str__(self): + return f"batch:{self.batch_size}_head:{self.num_heads}_seq_len:{self.seq_length}_headdim:{self.head_dim}_dtype:{str(self.dtype).split('.')[-1]}" + + +def get_params(dtypes: List[torch.dtype]) -> List[Params]: + params = [] + seq_lengths = [37, 256, 277] + for seq_len, dtype in product(seq_lengths, dtypes): + params.append( + Params( + batch_size=2, num_heads=4, seq_length=seq_len, head_dim=16, dtype=dtype + ) + ) + return params + + +# ROCM BUG SEE: https://github.com/pytorch/pytorch/issues/140855 +supports_learnable_bias = unittest.skipUnless( + torch.cuda.is_available() + and torch.utils._triton.has_triton() + and torch.cuda.get_device_capability() >= (8, 0) + and not TEST_WITH_ROCM, + "Requires CUDA and Triton, and is not supported on ROCm", +) + + +@supports_learnable_bias +class TestLearnableBiases(InductorTestCase): + def setUp(self): + super().setUp() + self.device = "cuda" + self.dtype = torch.float32 + self.atol = 3e-2 + self.rtol = 3e-2 + + def _init_tensors(self, params: Params): + make_tensor = functools.partial( + torch.randn, + (params.batch_size, params.num_heads, params.seq_length, params.head_dim), + device=self.device, + dtype=params.dtype, + requires_grad=True, + ) + return (make_tensor(), make_tensor(), make_tensor()) + + @torch.no_grad() + def _gold_check(self, eager, compiled, gold, tensor_name, fudge_factor=1.35): + ref_error = rmse(eager, gold) + comp_error = rmse(compiled, gold) + # Note: This has been carefully tested that FlexAttention is within + # 20% of the average error of SDPA! Do not bump this tolerance + # unless you are absolutely sure you are not worsening the accuracy + # of FlexAttention! + if eager.dtype == torch.float32: + fudge_factor = 10.0 * fudge_factor + + comp_error = comp_error.item() + ref_error = ref_error.item() * fudge_factor + + if ( + tensor_name == "out" + and eager.dtype == torch.float32 + and comp_error > ref_error + ): + self.skipTest("Compiled FlexAttention is less accurate than eager in fp32") + + self.assertLessEqual( + comp_error, + (ref_error * fudge_factor), + f"\nTensor: {tensor_name}\nCompiled error ({comp_error:.8f}) exceeds " + f"reference error ({ref_error:.8f}) * fudge_factor ({fudge_factor})", + ) + + def _check_outputs_and_grads( + self, out_eager, out_compiled, out_gold, tensors, names=None + ): + backwards_grad = torch.randn_like(out_eager) + grads_eager = torch.autograd.grad((out_eager,), tensors, backwards_grad) + grads_compiled = torch.autograd.grad((out_compiled,), tensors, backwards_grad) + grads_gold = torch.autograd.grad((out_gold,), tensors, backwards_grad) + + tensor_names = ( + ["out", "grad_query", "grad_key", "grad_value", "grad_bias"] + if names is None + else names + ) + + eager_tensors = (out_eager, *grads_eager) + compiled_tensors = (out_compiled, *grads_compiled) + gold_tensors = (out_gold, *grads_gold) + + for eager, compiled, gold, name in zip( + eager_tensors, compiled_tensors, gold_tensors, tensor_names, strict=True + ): + self._gold_check(eager, compiled, gold, name) + + @common_utils.parametrize( + "params", get_params(test_dtypes), name_fn=lambda x: f"{x}" + ) + def test_relative_1d_bias(self, params): + query, key, value = self._init_tensors(params) + bias = torch.randn( + 2 * params.seq_length, + device=self.device, + dtype=params.dtype, + requires_grad=True, + ) + + def bias_func(score, b, h, q_idx, kv_idx): + return score + bias[torch.abs(q_idx - kv_idx)] + + flex_compiled = torch.compile(flex_attention) + out_eager = flex_attention(query, key, value, score_mod=bias_func) + out_compiled = flex_compiled(query, key, value, score_mod=bias_func) + out_gold = flex_attention( + query.to(torch.float64), + key.to(torch.float64), + value.to(torch.float64), + score_mod=bias_func, + ) + + self._check_outputs_and_grads( + out_eager, + out_compiled, + out_gold, + (query, key, value, bias), + ) + + @common_utils.parametrize( + "params", get_params(test_dtypes), name_fn=lambda x: f"{x}" + ) + def test_absolute_2d_bias(self, params): + query, key, value = self._init_tensors(params) + bias = torch.randn( + params.seq_length, + params.seq_length, + device=self.device, + dtype=params.dtype, + requires_grad=True, + ) + + def bias_func(score, b, h, q_idx, kv_idx): + return score + bias[q_idx, kv_idx] + + flex_compiled = torch.compile(flex_attention) + out_eager = flex_attention(query, key, value, score_mod=bias_func) + out_compiled = flex_compiled(query, key, value, score_mod=bias_func) + out_gold = flex_attention( + query.to(torch.float64), + key.to(torch.float64), + value.to(torch.float64), + score_mod=bias_func, + ) + + self._check_outputs_and_grads( + out_eager, + out_compiled, + out_gold, + (query, key, value, bias), + ) + + @common_utils.parametrize( + "params", get_params(test_dtypes), name_fn=lambda x: f"{x}" + ) + def test_head_specific_bias(self, params): + query, key, value = self._init_tensors(params) + bias = torch.randn( + params.num_heads, + params.seq_length, + params.seq_length, + device=self.device, + dtype=params.dtype, + requires_grad=True, + ) + + def bias_func(score, b, h, q_idx, kv_idx): + return score + bias[h, q_idx, kv_idx] + + flex_compiled = torch.compile(flex_attention) + out_eager = flex_attention(query, key, value, score_mod=bias_func) + out_compiled = flex_compiled(query, key, value, score_mod=bias_func) + out_gold = flex_attention( + query.to(torch.float64), + key.to(torch.float64), + value.to(torch.float64), + score_mod=bias_func, + ) + + self._check_outputs_and_grads( + out_eager, + out_compiled, + out_gold, + (query, key, value, bias), + ) + + @common_utils.parametrize( + "params", get_params(test_dtypes), name_fn=lambda x: f"{x}" + ) + def test_batch_head_bias(self, params): + query, key, value = self._init_tensors(params) + bias = torch.randn( + params.batch_size, + params.num_heads, + params.seq_length, + params.seq_length, + device=self.device, + dtype=params.dtype, + requires_grad=True, + ) + + def bias_func(score, b, h, q_idx, kv_idx): + return score + bias[b, h, q_idx, kv_idx] + + flex_compiled = torch.compile(flex_attention) + out_eager = flex_attention(query, key, value, score_mod=bias_func) + out_compiled = flex_compiled(query, key, value, score_mod=bias_func) + out_gold = flex_attention( + query.to(torch.float64), + key.to(torch.float64), + value.to(torch.float64), + score_mod=bias_func, + ) + + self._check_outputs_and_grads( + out_eager, + out_compiled, + out_gold, + (query, key, value, bias), + ) + + @common_utils.parametrize( + "params", get_params(test_dtypes), name_fn=lambda x: f"{x}" + ) + def test_multiplicative_bias(self, params): + query, key, value = self._init_tensors(params) + bias = torch.randn( + params.seq_length, + device=self.device, + dtype=params.dtype, + requires_grad=True, + ) + + def bias_func(score, b, h, q_idx, kv_idx): + return score * bias[q_idx] + + flex_compiled = torch.compile(flex_attention) + out_eager = flex_attention(query, key, value, score_mod=bias_func) + out_compiled = flex_compiled(query, key, value, score_mod=bias_func) + out_gold = flex_attention( + query.to(torch.float64), + key.to(torch.float64), + value.to(torch.float64), + score_mod=bias_func, + ) + + self._check_outputs_and_grads( + out_eager, + out_compiled, + out_gold, + (query, key, value, bias), + ) + + @common_utils.parametrize( + "params", get_params(test_dtypes), name_fn=lambda x: f"{x}" + ) + def test_local_window_bias(self, params): + query, key, value = self._init_tensors(params) + window_size = 8 + bias = torch.randn( + 2 * window_size + 1, + device=self.device, + dtype=params.dtype, + requires_grad=True, + ) + + def bias_func(score, b, h, q_idx, kv_idx): + window_idx = torch.clamp(q_idx - kv_idx + window_size, 0, 2 * window_size) + return score + bias[window_idx] + + flex_compiled = torch.compile(flex_attention) + out_eager = flex_attention(query, key, value, score_mod=bias_func) + out_compiled = flex_compiled(query, key, value, score_mod=bias_func) + out_gold = flex_attention( + query.to(torch.float64), + key.to(torch.float64), + value.to(torch.float64), + score_mod=bias_func, + ) + + self._check_outputs_and_grads( + out_eager, + out_compiled, + out_gold, + (query, key, value, bias), + ) + + @common_utils.parametrize( + "params", get_params(test_dtypes), name_fn=lambda x: f"{x}" + ) + def test_global_tokens_bias(self, params): + query, key, value = self._init_tensors(params) + bias = torch.randn( + params.seq_length, + device=self.device, + dtype=params.dtype, + requires_grad=True, + ) + + def bias_func(score, b, h, q_idx, kv_idx): + return score + bias[kv_idx] + + flex_compiled = torch.compile(flex_attention) + out_eager = flex_attention(query, key, value, score_mod=bias_func) + out_compiled = flex_compiled(query, key, value, score_mod=bias_func) + out_gold = flex_attention( + query.to(torch.float64), + key.to(torch.float64), + value.to(torch.float64), + score_mod=bias_func, + ) + + self._check_outputs_and_grads( + out_eager, + out_compiled, + out_gold, + (query, key, value, bias), + ) + + @common_utils.parametrize( + "params", get_params(test_dtypes), name_fn=lambda x: f"{x}" + ) + def test_weird_bias(self, params): + query, key, value = self._init_tensors(params) + bias = torch.randn( + params.batch_size, + params.num_heads, + 4, + params.seq_length, + device=self.device, + dtype=params.dtype, + requires_grad=True, + ) + which_bias = torch.tensor(0, device=self.device) + + def bias_func(score, b, h, q_idx, kv_idx): + return score + bias[b, h, which_bias, q_idx] + + flex_compiled = torch.compile(flex_attention) + out_eager = flex_attention(query, key, value, score_mod=bias_func) + out_compiled = flex_compiled(query, key, value, score_mod=bias_func) + out_gold = flex_attention( + query.to(torch.float64), + key.to(torch.float64), + value.to(torch.float64), + score_mod=bias_func, + ) + + self._check_outputs_and_grads( + out_eager, + out_compiled, + out_gold, + (query, key, value, bias), + ) + + @common_utils.parametrize( + "params", get_params(test_dtypes), name_fn=lambda x: f"{x}" + ) + def test_indirect_bias(self, params): + query, key, value = self._init_tensors(params) + bias = torch.randn( + params.seq_length, + device=self.device, + dtype=params.dtype, + requires_grad=True, + ) + + offset = torch.randint( + 0, + params.seq_length, + (params.seq_length,), + device=self.device, + ) + + def bias_func(score, b, h, q_idx, kv_idx): + return score + bias[offset[q_idx]] + + flex_compiled = torch.compile(flex_attention) + out_eager = flex_attention(query, key, value, score_mod=bias_func) + out_compiled = flex_compiled(query, key, value, score_mod=bias_func) + out_gold = flex_attention( + query.to(torch.float64), + key.to(torch.float64), + value.to(torch.float64), + score_mod=bias_func, + ) + + self._check_outputs_and_grads( + out_eager, + out_compiled, + out_gold, + (query, key, value, bias), + ) + + @common_utils.parametrize( + "params", get_params([torch.float32]), name_fn=lambda x: f"{x}" + ) + def test_symmetric_bias(self, params): + query, key, value = self._init_tensors(params) + bias = torch.randn( + params.seq_length, + device=self.device, + dtype=params.dtype, + requires_grad=True, + ) + + def bias_func(score, b, h, q_idx, kv_idx): + return score + bias[q_idx] + bias[kv_idx] + + flex_compiled = torch.compile(flex_attention) + out_eager = flex_attention(query, key, value, score_mod=bias_func) + out_compiled = flex_compiled(query, key, value, score_mod=bias_func) + out_gold = flex_attention( + query.to(torch.float64), + key.to(torch.float64), + value.to(torch.float64), + score_mod=bias_func, + ) + # Error in backwards + with self.assertRaisesRegex( + torch._inductor.exc.LoweringException, + "Using multiple indexing operations on the same tensor that requires gradients", + ): + self._check_outputs_and_grads( + out_eager, + out_compiled, + out_gold, + (query, key, value, bias), + ) + + @common_utils.parametrize( + "params", get_params(test_dtypes), name_fn=lambda x: f"{x}" + ) + def test_flipped_indexed_bias(self, params): + query, key, value = self._init_tensors(params) + bias = torch.randn( + params.seq_length, + params.seq_length, + device=self.device, + dtype=params.dtype, + requires_grad=True, + ) + + def bias_func(score, b, h, q_idx, kv_idx): + return score + bias[kv_idx, q_idx] + + flex_compiled = torch.compile(flex_attention) + out_eager = flex_attention(query, key, value, score_mod=bias_func) + out_compiled = flex_compiled(query, key, value, score_mod=bias_func) + out_gold = flex_attention( + query.to(torch.float64), + key.to(torch.float64), + value.to(torch.float64), + score_mod=bias_func, + ) + + self._check_outputs_and_grads( + out_eager, + out_compiled, + out_gold, + (query, key, value, bias), + ) + + @common_utils.parametrize( + "params", get_params(test_dtypes), name_fn=lambda x: f"{x}" + ) + def test_head_specific_gate(self, params): + query, key, value = self._init_tensors(params) + gate_score = torch.randn( + params.num_heads, + device=self.device, + dtype=params.dtype, + requires_grad=True, + ) + + def bias_func(score, b, h, q_idx, kv_idx): + return score * torch.sigmoid(gate_score[h].to(torch.float32)) + + flex_compiled = torch.compile(flex_attention) + out_eager = flex_attention(query, key, value, score_mod=bias_func) + out_compiled = flex_compiled(query, key, value, score_mod=bias_func) + out_gold = flex_attention( + query.to(torch.float64), + key.to(torch.float64), + value.to(torch.float64), + score_mod=bias_func, + ) + + self._check_outputs_and_grads( + out_eager, + out_compiled, + out_gold, + (query, key, value, gate_score), + ) + + @common_utils.parametrize( + "params", get_params(test_dtypes), name_fn=lambda x: f"{x}" + ) + def test_distinct_biases(self, params): + query, key, value = self._init_tensors(params) + # Create two separate bias tensors + bias1 = torch.randn( + params.seq_length, + device=self.device, + dtype=params.dtype, + requires_grad=True, + ) + bias2 = torch.randn( + params.seq_length, + device=self.device, + dtype=params.dtype, + requires_grad=True, + ) + + def bias_func(score, b, h, q_idx, kv_idx): + return score + bias1[q_idx] + bias2[kv_idx] + + flex_compiled = torch.compile(flex_attention) + out_eager = flex_attention(query, key, value, score_mod=bias_func) + out_compiled = flex_compiled(query, key, value, score_mod=bias_func) + out_gold = flex_attention( + query.to(torch.float64), + key.to(torch.float64), + value.to(torch.float64), + score_mod=bias_func, + ) + + # Include both bias tensors in the tuple for gradient checking + self._check_outputs_and_grads( + out_eager, + out_compiled, + out_gold, + (query, key, value, bias1, bias2), + names=[ + "out", + "grad_query", + "grad_key", + "grad_value", + "grad_bias1", + "grad_bias2", + ], + ) + + @common_utils.parametrize( + "params", get_params(test_dtypes), name_fn=lambda x: f"{x}" + ) + def test_relative_1d_bias_only_grad(self, params): + query, key, value = self._init_tensors(params) + query = query.detach().requires_grad_(False) + key = key.detach().requires_grad_(False) + value = value.detach().requires_grad_(False) + + # Only bias requires gradients + bias = torch.randn( + 2 * params.seq_length, + device=self.device, + dtype=params.dtype, + requires_grad=True, # Only bias needs gradients + ) + + def bias_func(score, b, h, q_idx, kv_idx): + return score + bias[torch.abs(q_idx - kv_idx)] + + flex_compiled = torch.compile(flex_attention) + out_eager = flex_attention(query, key, value, score_mod=bias_func) + out_compiled = flex_compiled(query, key, value, score_mod=bias_func) + + out_gold = flex_attention( + query.to(torch.float64), + key.to(torch.float64), + value.to(torch.float64), + score_mod=bias_func, + ) + + # For gradient checking, we only pass the bias tensor since it's the only one requiring gradients + self._check_outputs_and_grads( + out_eager, out_compiled, out_gold, (bias,), names=["out", "bias"] + ) + + common_utils.instantiate_parametrized_tests(TestFlexAttention) common_utils.instantiate_parametrized_tests(TestBlockMask) common_utils.instantiate_parametrized_tests(TestPagedAttention) +common_utils.instantiate_parametrized_tests(TestLearnableBiases) if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/torch/_higher_order_ops/flex_attention.py b/torch/_higher_order_ops/flex_attention.py index 25d9f2f1d88c97..045d7c98ae20fd 100644 --- a/torch/_higher_order_ops/flex_attention.py +++ b/torch/_higher_order_ops/flex_attention.py @@ -586,7 +586,7 @@ def forward( ) assert ( not any_buffer_requires_grad - ), "Captured buffers from mask mod that require grad are not yet supported." + ), "Captured buffers from mask mod that require grad are not supported." ctx._fw_graph = fw_graph ctx._joint_graph = joint_graph ctx._mask_graph = block_mask[-1] @@ -707,7 +707,9 @@ def flex_attention_autograd( from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex with TransformGetItemToIndex(): - input_requires_grad = any(t.requires_grad for t in (query, key, value)) + input_requires_grad = any( + t.requires_grad for t in (query, key, value, *score_mod_other_buffers) + ) if torch.is_grad_enabled() and input_requires_grad: example_vals = ( query.new_zeros((), requires_grad=input_requires_grad), @@ -930,7 +932,8 @@ def trace_flex_attention_backward( mask_mod_other_buffers, ) - fw_example_vals = [query.new_zeros((), requires_grad=query.requires_grad)] + [ + requires_grad = any(pytree.tree_map(lambda x: x.requires_grad, (query, key))) + fw_example_vals = [query.new_zeros((), requires_grad=requires_grad)] + [ query.new_zeros((), dtype=torch.int) for _ in range(4) ] bw_example_vals = fw_example_vals + [query.new_zeros(())] diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index 80711827db0cbb..e38c01fc27f5f7 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -3,7 +3,8 @@ import logging import math -from typing import Any, List, Optional, Sequence, Tuple +from dataclasses import dataclass +from typing import Any, List, Optional, Sequence, Tuple, Union import sympy @@ -20,11 +21,23 @@ get_fill_order, InputBuffer, IRNode, + MutationLayoutSHOULDREMOVE, + Scatter, StorageBox, Subgraph, TensorBox, ) -from ..lowering import empty, empty_strided, lowerings, register_lowering +from ..lowering import ( + _full, + check_and_broadcast_indices, + empty, + empty_strided, + expand, + index_output_size_and_inner_fn, + lowerings, + register_lowering, + to_dtype, +) from ..select_algorithm import autotune_select_algorithm, realize_inputs, TritonTemplate @@ -93,10 +106,54 @@ def get_float32_precision(): return "'tf32'" -def build_subgraph_buffer( - args: List[TensorBox], - subgraph: Subgraph, -): +def zeros_and_scatter_lowering(shape: List[int], indices, values): + # Always accumulate into fp32 then cast + grad = _full(0, values.get_device(), torch.float32, shape) + assert isinstance(grad, TensorBox) + grad.realize() + x_size = grad.get_size() + values = to_dtype(values, grad.get_dtype()) + indices_loaders = [i.make_loader() if i is not None else None for i in indices] + indices, tensor_indices = check_and_broadcast_indices(indices, grad.get_device()) + # We can use the first one since they are all required to be the same size + tensor_size = list(indices[tensor_indices[0]].get_size()) + indexed_size = [x_size[i] for i in range(len(indices))] + + expected_vals_size, inner_fn = index_output_size_and_inner_fn( + x_size, + indices, + tensor_indices, + tensor_size, + indices_loaders, + indexed_size, + None, + check=True, + ) + + values = expand(values, expected_vals_size) + device = grad.get_device() + assert device is not None + scatter = Scatter( + device=device, + dtype=grad.get_dtype(), + inner_fn=values.make_loader(), + ranges=expected_vals_size, # iter_ranges, + output_indexer=inner_fn, + scatter_mode="atomic_add", + ) + + buffer = ComputedBuffer( + name=grad.data.data.name, # type: ignore[attr-defined] + layout=MutationLayoutSHOULDREMOVE(grad), + data=scatter, + ) + return buffer + + +SubgraphResults = Union[List[Optional[ComputedBuffer]], Optional[ComputedBuffer]] + + +def build_subgraph_buffer(args: List[TensorBox], subgraph: Subgraph) -> SubgraphResults: """This function's goal is to take in the required args and produce the subgraph buffer The subgraph buffer is a ComputedBuffer that will be inlined into the triton template @@ -107,17 +164,30 @@ def build_subgraph_buffer( from ..subgraph_lowering import PointwiseSubgraphLowering pw_subgraph = PointwiseSubgraphLowering( - subgraph.graph_module, root_graph_lowering=V.graph + subgraph.graph_module, + root_graph_lowering=V.graph, + allowed_mutations={torch.ops.flex_lib.zeros_and_scatter.default}, + additional_lowerings={ + torch.ops.flex_lib.zeros_and_scatter.default: zeros_and_scatter_lowering + }, ) with V.set_graph_handler(pw_subgraph): # type: ignore[arg-type] pw_subgraph.run(*args) - def convert_output_node_to_buffer(output): - if output is None: + # Since we are allowing mutations/buffer creation, we need to register any fresh buffers + # creating during the pointwise subgraph lowering + if len(pw_subgraph.buffers) > 0: + for buffer in pw_subgraph.buffers: + V.graph.register_buffer(buffer) + + def convert_output_node_to_buffer(output_buffer) -> Optional[ComputedBuffer]: + if output_buffer is None: return None - output_buffer = output + if isinstance(output_buffer, ComputedBuffer): + # These nodes are coming from the output of zeros_and_scatter + return output_buffer assert isinstance(output_buffer, TensorBox), ( - "The output node for flex attention's subgraph must be a TensorBox, but got: ", + "The output node for flex attention's subgraph must be a TensorBox, but got: ", type(output_buffer), ) assert isinstance(output_buffer.data, StorageBox), ( @@ -135,8 +205,6 @@ def convert_output_node_to_buffer(output): ) return subgraph_buffer - # node.args[0] is either a single element or a list of elements - # representing all outputs of the function. return tree_map(convert_output_node_to_buffer, pw_subgraph.graph_outputs) @@ -361,10 +429,9 @@ def get_bounded_indices(indices, max_len=None): idx_d = tl.arange(0, V_HEAD_DIM)[None, :] mask = idx_m < Q_LEN - # TODO generalize and add proper mask support + {{store_output(("idx_zq", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}} - # TODO dont want to write this if we dont require grad if OUTPUT_LOGSUMEXP: off_hz = tl.program_id(1) l_ptrs = LSE + off_hz * Q_LEN + offs_m @@ -768,6 +835,7 @@ def flex_attention( subgraph_buffer = build_subgraph_buffer( placeholder_inps + list(score_mod_other_buffers), subgraph ) + mask_graph_placeholder_inps = [ create_placeholder(name, dtype, query.get_device()) for name, dtype in [ @@ -780,6 +848,7 @@ def flex_attention( mask_graph_buffer = build_subgraph_buffer( mask_graph_placeholder_inps + list(mask_mod_other_buffers), mask_graph ) + kernel_options = dict(kernel_options) kernel_options.setdefault("FLOAT32_PRECISION", get_float32_precision()) if _use_flex_decoding(query, kernel_options): @@ -1654,6 +1723,26 @@ def bwd_dkdv_block_mn( n="n", grad_score_mod="dsT" ) | indent_except_first(1) }} + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = offs_m1[None, :] < Q_LEN and offs_n1[:, None] < KV_LEN + {{ modification( + subgraph_number=3, + output_name=None, + mask="scatter_mask", + score="pre_mod_scores", + b="idx_b", + h="idx_h", + m="idx_m", + n="idx_n", + grad_score_mod="dsT" + ) | indent_except_first(1) }} + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if CHECK_BLOCK_BOUNDARY: grad_scores = tl.where(offs_n1[:, None] < KV_LEN, grad_scores, 0.0) @@ -1673,6 +1762,82 @@ def bwd_dkdv_block_mn( ) +def validate_joint_graph(joint_graph: torch.fx.Graph): + """We do some pre lowering graph checks in order to raise nicer error messages""" + for node in joint_graph.nodes: + if ( + node.op == "call_function" + and node.target == torch.ops.flex_lib.zeros_and_scatter.default + ): + for user in node.users: + if user.op != "output": + raise NotImplementedError( + "Using multiple indexing operations on the same tensor that requires gradients " + "in a score_mod function is not currently supported. " + "This typically happens when indexing the same tensor multiple times, like:\n\n" + " def score_mod(score, b, h, q_idx, kv_idx):\n" + " return score + bias[q_idx] + bias[kv_idx] # bias used twice!\n\n" + "A valid workaround is to clone() the tensors that will be indexed multiple times. For example:\n\n" + " bias1 = bias.clone()\n" + " def score_mod(score, b, h, q_idx, kv_idx):\n" + " return score + bias[q_idx] + bias1[kv_idx]\n\n" + "Note that this solution will use additional memory." + ) + return + + +@dataclass(frozen=True) +class JointOutputResult: + """Results from processing joint outputs.""" + + grad_input: ComputedBuffer + captured_grads_compute: List[ComputedBuffer] + captured_grads: List[Optional[TensorBox]] + mutated_grads: List[TensorBox] + + +def process_joint_outputs( + all_joint_outputs: SubgraphResults, num_placeholders: int +) -> JointOutputResult: + """Process joint outputs and extract various buffers needed for lowering + + Args: + all_joint_outputs: List of all the outputs from build_subgraphs + num_placeholders: The number of placeholder inputs, used to skip over unused backward compute buffers + + Returns: + JointOutputResult containing processed buffers and gradients + """ + assert isinstance(all_joint_outputs, List) + assert ( + all_joint_outputs[0] is not None + ), "joint_subgraph_buffer is None this is a bug!" + + joint_buffer = all_joint_outputs[0] + other_grads = all_joint_outputs[num_placeholders - 1 :] + + # outer_grads has the structure: Len(other_buffer_grads) if buffer doesn't require grad than it will be None + # We only grab the buffers that require grad for inlining into kernel + grads_compute = [buf for buf in other_grads if buf is not None] + + def get_out(buf): + if buf is None: + return None + assert isinstance(buf, ComputedBuffer) + assert buf.name is not None + return TensorBox.create(V.graph.get_buffer(buf.name)) + + grads_out = [get_out(x) for x in other_grads] + mutated_grads = [buf for buf in grads_out if buf is not None] + + return JointOutputResult( + grad_input=joint_buffer, + captured_grads_compute=grads_compute, + captured_grads=grads_out, + mutated_grads=mutated_grads, + ) + + # TODO: We probably also need a layout constraint? @register_lowering( torch.ops.higher_order.flex_attention_backward, type_promotion_kind=None @@ -1774,8 +1939,18 @@ def flex_attention_backward(*args, **kwargs): ] # Sometimes we have weird unused nodes here joint_graph.graph_module.graph.eliminate_dead_code() - joint_subgraph_buffer, *_ = build_subgraph_buffer( - joint_placeholder_inps + list(score_mod_other_buffers), joint_graph + + # It is hard to raise nice errors for some joint graphs during subgraph lowering + # This lets us do some checks before attempting to lower + validate_joint_graph(joint_graph.graph_module.graph) + + all_joint_outputs = build_subgraph_buffer( + joint_placeholder_inps + list(score_mod_other_buffers), + joint_graph, + ) + + joint_outputs = process_joint_outputs( + all_joint_outputs, len(joint_placeholder_inps) ) mask_graph_placeholder_inps = [ @@ -1791,6 +1966,8 @@ def flex_attention_backward(*args, **kwargs): mask_graph_placeholder_inps + list(mask_mod_other_buffers), mask_graph ) + mask_graph_buffer = mask_graph_buffer + layout_broadcasted_k = FixedLayout( key.get_device(), key.get_dtype(), @@ -1891,8 +2068,17 @@ def flex_attention_backward(*args, **kwargs): full_q_indices, ], layout=layout_broadcasted_k, # We use store_output only for grad_key - subgraphs=[fw_subgraph_buffer, joint_subgraph_buffer, mask_graph_buffer], - mutated_inputs=[grad_query, broadcasted_grad_value], + subgraphs=[ + fw_subgraph_buffer, + joint_outputs.grad_input, + mask_graph_buffer, + joint_outputs.captured_grads_compute, + ], + mutated_inputs=[ + grad_query, + broadcasted_grad_value, + *joint_outputs.mutated_grads, + ], call_sizes=query.get_size() + key.get_size()[1:3], num_stages=num_stages, num_warps=num_warps, @@ -1949,8 +2135,4 @@ def flex_attention_backward(*args, **kwargs): grad_key = lowerings[aten.sum](broadcasted_grad_key, axis=0, keepdims=True) grad_value = lowerings[aten.sum](broadcasted_grad_value, axis=0, keepdims=True) - return ( - grad_query, - grad_key, - grad_value, - ) + return (grad_query, grad_key, grad_value, tuple(joint_outputs.captured_grads)) diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 995ad3e75f94b3..14960eb510568a 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -75,7 +75,7 @@ log = logging.getLogger(__name__) -lowerings: Dict[Callable[..., Any], Callable[..., Any]] = {} +lowerings: Dict[Union[Callable[..., Any], str], Callable[..., Any]] = {} # Use maybe_layout_constraints to access this dict, we lazily register tag-based layout constraints _maybe_layout_constraints: Dict[ torch._ops.OpOverload, Optional[Callable[..., Any]] diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index bbf0b50a56546d..8b07f6528e7976 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -177,6 +177,57 @@ def finalize_all(self) -> str: ) +class ModificationWrapper(V.WrapperHandler): # type: ignore[name-defined] + """Handles placeholder substitutions during subgraph processing.""" + + def __init__( + self, + kernel, + subgraph_number: int, + fixed_inputs: Dict[str, Any], + mask: Optional[str], + ): + super().__init__(V.ops) + self.name = f"PlaceholderSubstitution_{subgraph_number}" + self.kernel = kernel + self.fixed_inputs = fixed_inputs + self.mask = mask + + def load(self, name: str, index: sympy.Expr): + """Handle loading from tensor or fixed input.""" + if name not in self.fixed_inputs: + index_str = self._process_indexing(index) + var = self._add_kernel_input(name) + return f"tl.load({var} + {index_str})" + return f"({self.fixed_inputs[name]})" + + def indirect_indexing(self, index_var: str, size, check, wrap_neg=True): + """Convert index variable to symbolic form.""" + return sympy_index_symbol(str(index_var)) + + def store(self, name, index, value, mode): + """Store value and track the store's mask and output value on the kernel. + + The template_mask and template_out are used by the indexing() method to properly + mask store operations in the generated Triton code. The mask ensures stores only + affect elements matching the mask condition. This is currently only used for scatter node's store + """ + assert ( + self.mask is not None + ), "Mask is required for inner stores in modifications" + self.kernel.template_out = value + self.kernel.template_mask = self.mask + return self._inner.store(name, index, value, mode) + + def _add_kernel_input(self, name: str): + """Add name as input to kernel and return input ref.""" + return self.kernel.args.input(name) + + def _process_indexing(self, index): + """Process and rename indexing, adding symbols as kernel inputs.""" + return self.kernel.kexpr(self.kernel.rename_indexing(index)) + + class TritonTemplateKernel(TritonKernel): def __init__( self, @@ -411,69 +462,78 @@ def stride(self, name, index=None): return texpr(self.rename_indexing(val[index])) return ", ".join([texpr(self.rename_indexing(i)) for i in val]) + def _get_subgraph(self, subgraph_number: int): + assert isinstance(subgraph_number, int) + assert isinstance(self.subgraphs, list) + assert subgraph_number < len( + self.subgraphs + ), f"Invalid subgraph number provided to create_modification, {subgraph_number} must be < {len(self.subgraphs)}" + assert ( + self.body.getvalue() == "" + ), "Body should be clear before adding a modification" + return self.subgraphs[subgraph_number] + + def _handle_scatter_graph(self, scatter_graph): + """Handle processing for a single scatter graph. + + Args: + scatter_graph: The scatter graph to process + """ + assert isinstance( + scatter_graph, ir.ComputedBuffer + ), f"scatter_graph must be an instance of ComputeBuffer but got {type(scatter_graph)}" + + def contiguous_strides(x): + # We always create a fresh contiguous grad for scattering into + return sum( + x_i * stride for x_i, stride in zip(x, scatter_graph.get_stride()) + ) + + scatter_graph.data.store_output(scatter_graph.name, contiguous_strides, []) # type: ignore[attr-defined] + def modification( - self, subgraph_number: int, output_name: str, **fixed_inputs + self, + subgraph_number: int, + output_name: Optional[str], + mask: Optional[str] = None, + **fixed_inputs, ) -> str: """This creates a modification function for a subgraph. To use this inside a template, the first argument should specify which subgraph to codegen for Args: subgraph_number (int): The index of the subgraph in self.subgraphs + output_name (Optional[str]): The name of the output variable to store the result in + mask (Optional[str]): An optional mask to use for the store operation. If provided, this mask + will be applied to the store. """ - outer_self = self num = 0 + out = None while f"mod_{subgraph_number}_{num}" in self.subgraph_bodies: num += 1 with self.create_subgraph_body(f"mod_{subgraph_number}_{num}"): - assert isinstance(subgraph_number, int) - assert isinstance(self.subgraphs, list) - assert ( - self.body.getvalue() == "" - ), "Body should be clear before adding a modification" - assert subgraph_number < len( - self.subgraphs - ), f"Invalid subgraph number provided to create_modification, {subgraph_number} must be < {len(self.subgraphs)}" - - subgraph = self.subgraphs[subgraph_number] - - def add_input(name): - # This also implicitly adds name as an input to the kernel - return self.args.input(name) - - def print_and_rename_indexing(index): - # This also implicitly adds the indexing symbols as an input to - # the kernel - return self.kexpr(self.rename_indexing(index)) - - name = f"PlaceholderSubstitution_{subgraph_number}" - - class PlaceholderSubstitution(V.WrapperHandler): # type: ignore[name-defined] - self.name = name - - def load(self, name: str, index: sympy.Expr): - if name not in fixed_inputs: - # If it's not a fixed input, it's a load from a captured - # tensor - index_str = print_and_rename_indexing(index) - var = add_input(name) - return f"tl.load({var} + {index_str})" - - return f"({fixed_inputs[name]})" - - def indirect_indexing(self, index_var, size, check, wrap_neg=True): - return sympy_index_symbol(str(index_var)) - - with V.set_ops_handler(PlaceholderSubstitution(V.ops)): + subgraph = self._get_subgraph(subgraph_number) + modification_handler = ModificationWrapper( + self, subgraph_number, fixed_inputs, mask + ) + with V.set_ops_handler(modification_handler): assert isinstance( - subgraph, ir.ComputedBuffer - ), f"Expected the subgraph to be a ComputedBuffer, got {type(subgraph)}" - if isinstance(subgraph.data, ir.InputBuffer): + subgraph, (ir.ComputedBuffer, List) + ), f"Expected the subgraph to be a ComputedBuffer or a List[ComputedBuffer], got {type(subgraph)}" + # Handle scatter stores + if isinstance(subgraph, list): + for scatter_graph in subgraph: + self._handle_scatter_graph(scatter_graph) + elif isinstance(subgraph.data, ir.InputBuffer): out = subgraph.data.make_loader()(()) else: out = subgraph.data.inner_fn(()) self.codegen_body() - self.body.writeline(f"{output_name} = {out.value}") + if output_name is not None: + assert isinstance(output_name, str) + assert out is not None + self.body.writeline(f"{output_name} = {out.value}") body_val = self.body.getvalue() self.cse.invalidate(set()) # type: ignore[arg-type] @@ -783,19 +843,13 @@ def generate( # type: ignore[override] mod = PyCodeCache.load(code, extra) input_call_args = tuple(kernel.args.input_buffers.keys()) - output_call_args = tuple(kernel.args.output_buffers.keys()) # We expect the input_buffer order to be [*input_nodes, *captured_buffers] expected_input_args = tuple(unique(x.get_name() for x in input_nodes)) - expected_output_args = (fake_out.get_name(),) assert input_call_args[: len(expected_input_args)] == expected_input_args, ( input_call_args, expected_input_args, ) - assert output_call_args == expected_output_args, ( - output_call_args, - expected_output_args, - ) full_input_nodes = tuple([V.graph.get_buffer(k) for k in input_call_args]) extra_args = V.graph.sizevars.size_hints( diff --git a/torch/_inductor/subgraph_lowering.py b/torch/_inductor/subgraph_lowering.py index e0b36c41965244..1093f684e8b548 100644 --- a/torch/_inductor/subgraph_lowering.py +++ b/torch/_inductor/subgraph_lowering.py @@ -1,11 +1,21 @@ -"""Utilities for lowering subgraphs used by higher order operators - -""" +"""Utilities for lowering subgraphs used by higher order operators""" import functools import operator +from contextlib import contextmanager from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union +from typing import ( + Any, + Callable, + Dict, + Generator, + List, + Optional, + Set, + Tuple, + TypeVar, + Union, +) from typing_extensions import ParamSpec import torch @@ -19,6 +29,10 @@ T = TypeVar("T") _P = ParamSpec("_P") +OpOverload = torch._ops.OpOverload +LoweringDict = Dict[Union[OpOverload, str], Callable[..., Any]] +TargetType = Union[Callable[..., Any], str] + class PointwiseSubgraphLowering(torch.fx.Interpreter): """ @@ -27,46 +41,97 @@ class PointwiseSubgraphLowering(torch.fx.Interpreter): """ graph_outputs: Optional[List[ir.IRNode]] + root_graph: torch._inductor.graph.GraphLowering + _current_op: Optional[TargetType] + # For backwards of buffer_grads with scatters we allow mutations + allowed_mutations: Optional[Set[OpOverload]] + additional_lowerings: Optional[LoweringDict] + buffers: List[ir.Buffer] + mutated_buffers: Set[str] def __init__( self, gm: torch.fx.GraphModule, - root_graph_lowering: "torch._inductor.graph.GraphLowering", + root_graph_lowering: torch._inductor.graph.GraphLowering, + allowed_mutations: Optional[Set[OpOverload]] = None, + additional_lowerings: Optional[LoweringDict] = None, ) -> None: super().__init__(gm) self.graph_outputs = None self.root_graph = root_graph_lowering + self.allowed_mutations = allowed_mutations + self.additional_lowerings = additional_lowerings + self._current_op = None + + # Used to track buffers created during lowering + self.mutated_buffers = set() + self.buffers = [] + + @contextmanager + def _op_context(self, op: TargetType) -> Generator[None, None, None]: + """Set which op is being processed in call function to know if we can mutate buffers""" + previous = self._current_op + self._current_op = op + try: + yield + finally: + self._current_op = previous + + def _approved_mutator(self) -> bool: + return ( + self.allowed_mutations is not None + and self._current_op in self.allowed_mutations + ) def mark_buffer_mutated(self, name: str) -> None: - raise SubgraphLoweringException("Mutations are not supported in this context") + if self._approved_mutator(): + self.mutated_buffers.add(name) + else: + raise SubgraphLoweringException( + f"Buffer mutation detected during lowering of {self._current_op}. " + "Buffer mutations are only allowed in approved mutation ops. " + "This is an error in the lowering of the subgraph, please file a bug report." + ) - def register_buffer(self, buffer: ir.Buffer) -> str: - raise SubgraphLoweringException( - "Buffers cannot be created while lowering a pointwise subgraph. " - "This could be for a good reason (e.g. you're calling an op we can't codegen as a pointwise op), " - "but it could also be a bug. Please file a bug report if you think this should be supportable." - ) + def register_buffer(self, buffer: ir.Buffer, *, set_name: bool = False) -> str: + if self._approved_mutator(): + name = self.qualify_name(f"buf{len(self.buffers)}") + self.buffers.append(buffer) + return name + else: + raise SubgraphLoweringException( + "Buffers cannot be created while lowering a pointwise subgraph. " + "This could be for a good reason (e.g. you're calling an op we can't codegen as a pointwise op), " + "but it could also be a bug. Please file a bug report if you think this should be supportable." + ) def __getattr__(self, name: str) -> Any: return getattr(self.root_graph, name) def call_function( self, - target: Callable[[Any], Any], # type: ignore[override] + target: TargetType, args: Any, kwargs: Dict[str, Any], ) -> Any: from .lowering import lowerings - if target is operator.getitem and isinstance(args[0], (list, tuple, dict)): - return super().call_function(target, args, kwargs) + with self._op_context(target): + if target is operator.getitem and isinstance(args[0], (list, tuple, dict)): + return super().call_function(target, args, kwargs) - if target not in lowerings: - raise SubgraphLoweringException( - f"{target} not supported in subgraph, (missing lowering)" - ) + # These takes precedence over the main lowerings + if self.additional_lowerings is not None: + if target in self.additional_lowerings: + assert isinstance(target, OpOverload) + return self.additional_lowerings[target](*args, **kwargs) + + if target not in lowerings: + raise SubgraphLoweringException( + f"{target} not supported in subgraph, (missing lowering)" + ) - return lowerings[target](*args, **kwargs) + return lowerings[target](*args, **kwargs) def output(self, target: str, args: Tuple[Any], kwargs: Dict[str, Any]) -> None: # type: ignore[override] assert len(args) == 1 diff --git a/torch/nn/attention/flex_attention.py b/torch/nn/attention/flex_attention.py index 31060d38c13da9..00774fac7a7609 100644 --- a/torch/nn/attention/flex_attention.py +++ b/torch/nn/attention/flex_attention.py @@ -1066,14 +1066,13 @@ def _apply_kernel_options( kernel_options.setdefault("ROWS_GUARANTEED_SAFE", False) kernel_options.setdefault("BLOCKS_ARE_CONTIGUOUS", False) - # If foward kernel needs to return logsumexp is decided by this rule internally. + # If forward kernel needs to return logsumexp is decided by this rule internally. assert "OUTPUT_LOGSUMEXP" not in kernel_options kernel_options["OUTPUT_LOGSUMEXP"] = True if not return_lse: - any_inputs_require_grad = ( - query.requires_grad or key.requires_grad or value.requires_grad - ) - output_logsumexp = any_inputs_require_grad and torch.is_grad_enabled() + # We used to check if q,k,v required grads but since captured buffers can require grad + # we always write unless in no_grad + output_logsumexp = torch.is_grad_enabled() kernel_options["OUTPUT_LOGSUMEXP"] = output_logsumexp return kernel_options @@ -1240,9 +1239,7 @@ def score_mod( block_mask = _create_empty_block_mask(query, key) elif ( not query.is_nested - and ( - query.requires_grad or key.requires_grad or value.requires_grad - ) # skip adjust block if no grad + and (query.requires_grad or key.requires_grad or value.requires_grad) and ( query.size(-2) < block_mask.kv_num_blocks.size(-1) * block_mask.BLOCK_SIZE[0]