diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index d4164279968863..e6e1940bff129c 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -7,6 +7,8 @@ import unittest from collections import namedtuple from contextlib import contextmanager +from dataclasses import dataclass +from itertools import product from typing import Callable, List, Optional, Tuple, Union from unittest import expectedFailure, skip, skipUnless from unittest.mock import patch @@ -1214,18 +1216,6 @@ def composed_score_mod(score, b, h, m, n): self.run_test(composed_score_mod, dtype) self.run_test_with_paged_attention(composed_score_mod, dtype) - @supported_platform - @expectedFailure # TODO: Remove this after supporting compiled flex attention with training bias - @common_utils.parametrize("dtype", test_dtypes) - def test_captured_buffers_req_grad(self, dtype: torch.dtype): - head_offset = torch.rand(8, device="cuda", dtype=dtype, requires_grad=True) - - def score_mod(score, b, h, m, n): - return score + head_offset[h] - - self.run_test(score_mod, dtype, 4, 8, 128, 128) - self.run_test_with_paged_attention(score_mod, dtype, 4, 8, 128, 128) - @supported_platform @common_utils.parametrize("dtype", test_dtypes) def test_captured_buffers_all_dims(self, dtype: torch.dtype): @@ -4203,9 +4193,608 @@ def causal_mask(b, h, q, kv): self._check_equal(golden_out, ref_out, paged_out, fudge_factor, "Out") +@dataclass +class Params: + batch_size: int + num_heads: int + seq_length: int + head_dim: int + dtype: torch.dtype + config_str: Optional[str] = None + + def __str__(self): + return f"batch:{self.batch_size}_head:{self.num_heads}_seq_len:{self.seq_length}_headdim:{self.head_dim}_dtype:{str(self.dtype).split('.')[-1]}" + + +def get_params(dtypes: List[torch.dtype]) -> List[Params]: + params = [] + seq_lengths = [37, 256, 277] + for seq_len, dtype in product(seq_lengths, dtypes): + params.append( + Params( + batch_size=2, 
num_heads=4, seq_length=seq_len, head_dim=16, dtype=dtype + ) + ) + return params + + +# ROCM BUG SEE: https://github.com/pytorch/pytorch/issues/140855 +supports_learnable_bias = unittest.skipUnless( + torch.cuda.is_available() + and torch.utils._triton.has_triton() + and torch.cuda.get_device_capability() >= (8, 0) + and not TEST_WITH_ROCM, + "Requires CUDA and Triton, and is not supported on ROCm", +) + + +@supports_learnable_bias +class TestLearnableBiases(InductorTestCase): + def setUp(self): + super().setUp() + self.device = "cuda" + self.dtype = torch.float32 + self.atol = 3e-2 + self.rtol = 3e-2 + + def _init_tensors(self, params: Params): + make_tensor = functools.partial( + torch.randn, + (params.batch_size, params.num_heads, params.seq_length, params.head_dim), + device=self.device, + dtype=params.dtype, + requires_grad=True, + ) + return (make_tensor(), make_tensor(), make_tensor()) + + @torch.no_grad() + def _gold_check(self, eager, compiled, gold, tensor_name, fudge_factor=1.35): + ref_error = rmse(eager, gold) + comp_error = rmse(compiled, gold) + # Note: This has been carefully tested that FlexAttention is within + # 20% of the average error of SDPA! Do not bump this tolerance + # unless you are absolutely sure you are not worsening the accuracy + # of FlexAttention! 
+ if eager.dtype == torch.float32: + fudge_factor = 10.0 * fudge_factor + + comp_error = comp_error.item() + ref_error = ref_error.item() * fudge_factor + + if ( + tensor_name == "out" + and eager.dtype == torch.float32 + and comp_error > ref_error + ): + self.skipTest("Compiled FlexAttention is less accurate than eager in fp32") + + self.assertLessEqual( + comp_error, + (ref_error * fudge_factor), + f"\nTensor: {tensor_name}\nCompiled error ({comp_error:.8f}) exceeds " + f"reference error ({ref_error:.8f}) * fudge_factor ({fudge_factor})", + ) + + def _check_outputs_and_grads( + self, out_eager, out_compiled, out_gold, tensors, names=None + ): + backwards_grad = torch.randn_like(out_eager) + grads_eager = torch.autograd.grad((out_eager,), tensors, backwards_grad) + grads_compiled = torch.autograd.grad((out_compiled,), tensors, backwards_grad) + grads_gold = torch.autograd.grad((out_gold,), tensors, backwards_grad) + + tensor_names = ( + ["out", "grad_query", "grad_key", "grad_value", "grad_bias"] + if names is None + else names + ) + + eager_tensors = (out_eager, *grads_eager) + compiled_tensors = (out_compiled, *grads_compiled) + gold_tensors = (out_gold, *grads_gold) + + for eager, compiled, gold, name in zip( + eager_tensors, compiled_tensors, gold_tensors, tensor_names, strict=True + ): + self._gold_check(eager, compiled, gold, name) + + @common_utils.parametrize( + "params", get_params(test_dtypes), name_fn=lambda x: f"{x}" + ) + def test_relative_1d_bias(self, params): + query, key, value = self._init_tensors(params) + bias = torch.randn( + 2 * params.seq_length, + device=self.device, + dtype=params.dtype, + requires_grad=True, + ) + + def bias_func(score, b, h, q_idx, kv_idx): + return score + bias[torch.abs(q_idx - kv_idx)] + + flex_compiled = torch.compile(flex_attention) + out_eager = flex_attention(query, key, value, score_mod=bias_func) + out_compiled = flex_compiled(query, key, value, score_mod=bias_func) + out_gold = flex_attention( + 
query.to(torch.float64), + key.to(torch.float64), + value.to(torch.float64), + score_mod=bias_func, + ) + + self._check_outputs_and_grads( + out_eager, + out_compiled, + out_gold, + (query, key, value, bias), + ) + + @common_utils.parametrize( + "params", get_params(test_dtypes), name_fn=lambda x: f"{x}" + ) + def test_absolute_2d_bias(self, params): + query, key, value = self._init_tensors(params) + bias = torch.randn( + params.seq_length, + params.seq_length, + device=self.device, + dtype=params.dtype, + requires_grad=True, + ) + + def bias_func(score, b, h, q_idx, kv_idx): + return score + bias[q_idx, kv_idx] + + flex_compiled = torch.compile(flex_attention) + out_eager = flex_attention(query, key, value, score_mod=bias_func) + out_compiled = flex_compiled(query, key, value, score_mod=bias_func) + out_gold = flex_attention( + query.to(torch.float64), + key.to(torch.float64), + value.to(torch.float64), + score_mod=bias_func, + ) + + self._check_outputs_and_grads( + out_eager, + out_compiled, + out_gold, + (query, key, value, bias), + ) + + @common_utils.parametrize( + "params", get_params(test_dtypes), name_fn=lambda x: f"{x}" + ) + def test_head_specific_bias(self, params): + query, key, value = self._init_tensors(params) + bias = torch.randn( + params.num_heads, + params.seq_length, + params.seq_length, + device=self.device, + dtype=params.dtype, + requires_grad=True, + ) + + def bias_func(score, b, h, q_idx, kv_idx): + return score + bias[h, q_idx, kv_idx] + + flex_compiled = torch.compile(flex_attention) + out_eager = flex_attention(query, key, value, score_mod=bias_func) + out_compiled = flex_compiled(query, key, value, score_mod=bias_func) + out_gold = flex_attention( + query.to(torch.float64), + key.to(torch.float64), + value.to(torch.float64), + score_mod=bias_func, + ) + + self._check_outputs_and_grads( + out_eager, + out_compiled, + out_gold, + (query, key, value, bias), + ) + + @common_utils.parametrize( + "params", get_params(test_dtypes), 
name_fn=lambda x: f"{x}" + ) + def test_batch_head_bias(self, params): + query, key, value = self._init_tensors(params) + bias = torch.randn( + params.batch_size, + params.num_heads, + params.seq_length, + params.seq_length, + device=self.device, + dtype=params.dtype, + requires_grad=True, + ) + + def bias_func(score, b, h, q_idx, kv_idx): + return score + bias[b, h, q_idx, kv_idx] + + flex_compiled = torch.compile(flex_attention) + out_eager = flex_attention(query, key, value, score_mod=bias_func) + out_compiled = flex_compiled(query, key, value, score_mod=bias_func) + out_gold = flex_attention( + query.to(torch.float64), + key.to(torch.float64), + value.to(torch.float64), + score_mod=bias_func, + ) + + self._check_outputs_and_grads( + out_eager, + out_compiled, + out_gold, + (query, key, value, bias), + ) + + @common_utils.parametrize( + "params", get_params(test_dtypes), name_fn=lambda x: f"{x}" + ) + def test_multiplicative_bias(self, params): + query, key, value = self._init_tensors(params) + bias = torch.randn( + params.seq_length, + device=self.device, + dtype=params.dtype, + requires_grad=True, + ) + + def bias_func(score, b, h, q_idx, kv_idx): + return score * bias[q_idx] + + flex_compiled = torch.compile(flex_attention) + out_eager = flex_attention(query, key, value, score_mod=bias_func) + out_compiled = flex_compiled(query, key, value, score_mod=bias_func) + out_gold = flex_attention( + query.to(torch.float64), + key.to(torch.float64), + value.to(torch.float64), + score_mod=bias_func, + ) + + self._check_outputs_and_grads( + out_eager, + out_compiled, + out_gold, + (query, key, value, bias), + ) + + @common_utils.parametrize( + "params", get_params(test_dtypes), name_fn=lambda x: f"{x}" + ) + def test_local_window_bias(self, params): + query, key, value = self._init_tensors(params) + window_size = 8 + bias = torch.randn( + 2 * window_size + 1, + device=self.device, + dtype=params.dtype, + requires_grad=True, + ) + + def bias_func(score, b, h, q_idx, 
kv_idx): + window_idx = torch.clamp(q_idx - kv_idx + window_size, 0, 2 * window_size) + return score + bias[window_idx] + + flex_compiled = torch.compile(flex_attention) + out_eager = flex_attention(query, key, value, score_mod=bias_func) + out_compiled = flex_compiled(query, key, value, score_mod=bias_func) + out_gold = flex_attention( + query.to(torch.float64), + key.to(torch.float64), + value.to(torch.float64), + score_mod=bias_func, + ) + + self._check_outputs_and_grads( + out_eager, + out_compiled, + out_gold, + (query, key, value, bias), + ) + + @common_utils.parametrize( + "params", get_params(test_dtypes), name_fn=lambda x: f"{x}" + ) + def test_global_tokens_bias(self, params): + query, key, value = self._init_tensors(params) + bias = torch.randn( + params.seq_length, + device=self.device, + dtype=params.dtype, + requires_grad=True, + ) + + def bias_func(score, b, h, q_idx, kv_idx): + return score + bias[kv_idx] + + flex_compiled = torch.compile(flex_attention) + out_eager = flex_attention(query, key, value, score_mod=bias_func) + out_compiled = flex_compiled(query, key, value, score_mod=bias_func) + out_gold = flex_attention( + query.to(torch.float64), + key.to(torch.float64), + value.to(torch.float64), + score_mod=bias_func, + ) + + self._check_outputs_and_grads( + out_eager, + out_compiled, + out_gold, + (query, key, value, bias), + ) + + @common_utils.parametrize( + "params", get_params(test_dtypes), name_fn=lambda x: f"{x}" + ) + def test_weird_bias(self, params): + query, key, value = self._init_tensors(params) + bias = torch.randn( + params.batch_size, + params.num_heads, + 4, + params.seq_length, + device=self.device, + dtype=params.dtype, + requires_grad=True, + ) + which_bias = torch.tensor(0, device=self.device) + + def bias_func(score, b, h, q_idx, kv_idx): + return score + bias[b, h, which_bias, q_idx] + + flex_compiled = torch.compile(flex_attention) + out_eager = flex_attention(query, key, value, score_mod=bias_func) + out_compiled = 
flex_compiled(query, key, value, score_mod=bias_func) + out_gold = flex_attention( + query.to(torch.float64), + key.to(torch.float64), + value.to(torch.float64), + score_mod=bias_func, + ) + + self._check_outputs_and_grads( + out_eager, + out_compiled, + out_gold, + (query, key, value, bias), + ) + + @common_utils.parametrize( + "params", get_params(test_dtypes), name_fn=lambda x: f"{x}" + ) + def test_indirect_bias(self, params): + query, key, value = self._init_tensors(params) + bias = torch.randn( + params.seq_length, + device=self.device, + dtype=params.dtype, + requires_grad=True, + ) + + offset = torch.randint( + 0, + params.seq_length, + (params.seq_length,), + device=self.device, + ) + + def bias_func(score, b, h, q_idx, kv_idx): + return score + bias[offset[q_idx]] + + flex_compiled = torch.compile(flex_attention) + out_eager = flex_attention(query, key, value, score_mod=bias_func) + out_compiled = flex_compiled(query, key, value, score_mod=bias_func) + out_gold = flex_attention( + query.to(torch.float64), + key.to(torch.float64), + value.to(torch.float64), + score_mod=bias_func, + ) + + self._check_outputs_and_grads( + out_eager, + out_compiled, + out_gold, + (query, key, value, bias), + ) + + @common_utils.parametrize( + "params", get_params([torch.float32]), name_fn=lambda x: f"{x}" + ) + def test_symmetric_bias(self, params): + query, key, value = self._init_tensors(params) + bias = torch.randn( + params.seq_length, + device=self.device, + dtype=params.dtype, + requires_grad=True, + ) + + def bias_func(score, b, h, q_idx, kv_idx): + return score + bias[q_idx] + bias[kv_idx] + + flex_compiled = torch.compile(flex_attention) + out_eager = flex_attention(query, key, value, score_mod=bias_func) + out_compiled = flex_compiled(query, key, value, score_mod=bias_func) + out_gold = flex_attention( + query.to(torch.float64), + key.to(torch.float64), + value.to(torch.float64), + score_mod=bias_func, + ) + # Error in backwards + with self.assertRaisesRegex( + 
torch._inductor.exc.LoweringException, + "Using multiple indexing operations on the same tensor that requires gradients", + ): + self._check_outputs_and_grads( + out_eager, + out_compiled, + out_gold, + (query, key, value, bias), + ) + + @common_utils.parametrize( + "params", get_params(test_dtypes), name_fn=lambda x: f"{x}" + ) + def test_flipped_indexed_bias(self, params): + query, key, value = self._init_tensors(params) + bias = torch.randn( + params.seq_length, + params.seq_length, + device=self.device, + dtype=params.dtype, + requires_grad=True, + ) + + def bias_func(score, b, h, q_idx, kv_idx): + return score + bias[kv_idx, q_idx] + + flex_compiled = torch.compile(flex_attention) + out_eager = flex_attention(query, key, value, score_mod=bias_func) + out_compiled = flex_compiled(query, key, value, score_mod=bias_func) + out_gold = flex_attention( + query.to(torch.float64), + key.to(torch.float64), + value.to(torch.float64), + score_mod=bias_func, + ) + + self._check_outputs_and_grads( + out_eager, + out_compiled, + out_gold, + (query, key, value, bias), + ) + + @common_utils.parametrize( + "params", get_params(test_dtypes), name_fn=lambda x: f"{x}" + ) + def test_head_specific_gate(self, params): + query, key, value = self._init_tensors(params) + gate_score = torch.randn( + params.num_heads, + device=self.device, + dtype=params.dtype, + requires_grad=True, + ) + + def bias_func(score, b, h, q_idx, kv_idx): + return score * torch.sigmoid(gate_score[h].to(torch.float32)) + + flex_compiled = torch.compile(flex_attention) + out_eager = flex_attention(query, key, value, score_mod=bias_func) + out_compiled = flex_compiled(query, key, value, score_mod=bias_func) + out_gold = flex_attention( + query.to(torch.float64), + key.to(torch.float64), + value.to(torch.float64), + score_mod=bias_func, + ) + + self._check_outputs_and_grads( + out_eager, + out_compiled, + out_gold, + (query, key, value, gate_score), + ) + + @common_utils.parametrize( + "params", 
get_params(test_dtypes), name_fn=lambda x: f"{x}" + ) + def test_distinct_biases(self, params): + query, key, value = self._init_tensors(params) + # Create two separate bias tensors + bias1 = torch.randn( + params.seq_length, + device=self.device, + dtype=params.dtype, + requires_grad=True, + ) + bias2 = torch.randn( + params.seq_length, + device=self.device, + dtype=params.dtype, + requires_grad=True, + ) + + def bias_func(score, b, h, q_idx, kv_idx): + return score + bias1[q_idx] + bias2[kv_idx] + + flex_compiled = torch.compile(flex_attention) + out_eager = flex_attention(query, key, value, score_mod=bias_func) + out_compiled = flex_compiled(query, key, value, score_mod=bias_func) + out_gold = flex_attention( + query.to(torch.float64), + key.to(torch.float64), + value.to(torch.float64), + score_mod=bias_func, + ) + + # Include both bias tensors in the tuple for gradient checking + self._check_outputs_and_grads( + out_eager, + out_compiled, + out_gold, + (query, key, value, bias1, bias2), + names=[ + "out", + "grad_query", + "grad_key", + "grad_value", + "grad_bias1", + "grad_bias2", + ], + ) + + @common_utils.parametrize( + "params", get_params(test_dtypes), name_fn=lambda x: f"{x}" + ) + def test_relative_1d_bias_only_grad(self, params): + query, key, value = self._init_tensors(params) + query = query.detach().requires_grad_(False) + key = key.detach().requires_grad_(False) + value = value.detach().requires_grad_(False) + + # Only bias requires gradients + bias = torch.randn( + 2 * params.seq_length, + device=self.device, + dtype=params.dtype, + requires_grad=True, # Only bias needs gradients + ) + + def bias_func(score, b, h, q_idx, kv_idx): + return score + bias[torch.abs(q_idx - kv_idx)] + + flex_compiled = torch.compile(flex_attention) + out_eager = flex_attention(query, key, value, score_mod=bias_func) + out_compiled = flex_compiled(query, key, value, score_mod=bias_func) + + out_gold = flex_attention( + query.to(torch.float64), + key.to(torch.float64), + 
value.to(torch.float64), + score_mod=bias_func, + ) + + # For gradient checking, we only pass the bias tensor since it's the only one requiring gradients + self._check_outputs_and_grads( + out_eager, out_compiled, out_gold, (bias,), names=["out", "bias"] + ) + + common_utils.instantiate_parametrized_tests(TestFlexAttention) common_utils.instantiate_parametrized_tests(TestBlockMask) common_utils.instantiate_parametrized_tests(TestPagedAttention) +common_utils.instantiate_parametrized_tests(TestLearnableBiases) if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/torch/_higher_order_ops/flex_attention.py b/torch/_higher_order_ops/flex_attention.py index 25d9f2f1d88c97..045d7c98ae20fd 100644 --- a/torch/_higher_order_ops/flex_attention.py +++ b/torch/_higher_order_ops/flex_attention.py @@ -586,7 +586,7 @@ def forward( ) assert ( not any_buffer_requires_grad - ), "Captured buffers from mask mod that require grad are not yet supported." + ), "Captured buffers from mask mod that require grad are not supported." 
ctx._fw_graph = fw_graph ctx._joint_graph = joint_graph ctx._mask_graph = block_mask[-1] @@ -707,7 +707,9 @@ def flex_attention_autograd( from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex with TransformGetItemToIndex(): - input_requires_grad = any(t.requires_grad for t in (query, key, value)) + input_requires_grad = any( + t.requires_grad for t in (query, key, value, *score_mod_other_buffers) + ) if torch.is_grad_enabled() and input_requires_grad: example_vals = ( query.new_zeros((), requires_grad=input_requires_grad), @@ -930,7 +932,8 @@ def trace_flex_attention_backward( mask_mod_other_buffers, ) - fw_example_vals = [query.new_zeros((), requires_grad=query.requires_grad)] + [ + requires_grad = any(pytree.tree_map(lambda x: x.requires_grad, (query, key))) + fw_example_vals = [query.new_zeros((), requires_grad=requires_grad)] + [ query.new_zeros((), dtype=torch.int) for _ in range(4) ] bw_example_vals = fw_example_vals + [query.new_zeros(())] diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index 80711827db0cbb..e38c01fc27f5f7 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -3,7 +3,8 @@ import logging import math -from typing import Any, List, Optional, Sequence, Tuple +from dataclasses import dataclass +from typing import Any, List, Optional, Sequence, Tuple, Union import sympy @@ -20,11 +21,23 @@ get_fill_order, InputBuffer, IRNode, + MutationLayoutSHOULDREMOVE, + Scatter, StorageBox, Subgraph, TensorBox, ) -from ..lowering import empty, empty_strided, lowerings, register_lowering +from ..lowering import ( + _full, + check_and_broadcast_indices, + empty, + empty_strided, + expand, + index_output_size_and_inner_fn, + lowerings, + register_lowering, + to_dtype, +) from ..select_algorithm import autotune_select_algorithm, realize_inputs, TritonTemplate @@ -93,10 +106,54 @@ def get_float32_precision(): return "'tf32'" -def 
build_subgraph_buffer( - args: List[TensorBox], - subgraph: Subgraph, -): +def zeros_and_scatter_lowering(shape: List[int], indices, values): + # Always accumulate into fp32 then cast + grad = _full(0, values.get_device(), torch.float32, shape) + assert isinstance(grad, TensorBox) + grad.realize() + x_size = grad.get_size() + values = to_dtype(values, grad.get_dtype()) + indices_loaders = [i.make_loader() if i is not None else None for i in indices] + indices, tensor_indices = check_and_broadcast_indices(indices, grad.get_device()) + # We can use the first one since they are all required to be the same size + tensor_size = list(indices[tensor_indices[0]].get_size()) + indexed_size = [x_size[i] for i in range(len(indices))] + + expected_vals_size, inner_fn = index_output_size_and_inner_fn( + x_size, + indices, + tensor_indices, + tensor_size, + indices_loaders, + indexed_size, + None, + check=True, + ) + + values = expand(values, expected_vals_size) + device = grad.get_device() + assert device is not None + scatter = Scatter( + device=device, + dtype=grad.get_dtype(), + inner_fn=values.make_loader(), + ranges=expected_vals_size, # iter_ranges, + output_indexer=inner_fn, + scatter_mode="atomic_add", + ) + + buffer = ComputedBuffer( + name=grad.data.data.name, # type: ignore[attr-defined] + layout=MutationLayoutSHOULDREMOVE(grad), + data=scatter, + ) + return buffer + + +SubgraphResults = Union[List[Optional[ComputedBuffer]], Optional[ComputedBuffer]] + + +def build_subgraph_buffer(args: List[TensorBox], subgraph: Subgraph) -> SubgraphResults: """This function's goal is to take in the required args and produce the subgraph buffer The subgraph buffer is a ComputedBuffer that will be inlined into the triton template @@ -107,17 +164,30 @@ def build_subgraph_buffer( from ..subgraph_lowering import PointwiseSubgraphLowering pw_subgraph = PointwiseSubgraphLowering( - subgraph.graph_module, root_graph_lowering=V.graph + subgraph.graph_module, + root_graph_lowering=V.graph, + 
allowed_mutations={torch.ops.flex_lib.zeros_and_scatter.default}, + additional_lowerings={ + torch.ops.flex_lib.zeros_and_scatter.default: zeros_and_scatter_lowering + }, ) with V.set_graph_handler(pw_subgraph): # type: ignore[arg-type] pw_subgraph.run(*args) - def convert_output_node_to_buffer(output): - if output is None: + # Since we are allowing mutations/buffer creation, we need to register any fresh buffers + # created during the pointwise subgraph lowering + if len(pw_subgraph.buffers) > 0: + for buffer in pw_subgraph.buffers: + V.graph.register_buffer(buffer) + + def convert_output_node_to_buffer(output_buffer) -> Optional[ComputedBuffer]: + if output_buffer is None: return None - output_buffer = output + if isinstance(output_buffer, ComputedBuffer): + # These nodes are coming from the output of zeros_and_scatter + return output_buffer assert isinstance(output_buffer, TensorBox), ( - "The output node for flex attention's subgraph must be a TensorBox, but got: ", + "The output node for flex attention's subgraph must be a TensorBox, but got: ", type(output_buffer), ) assert isinstance(output_buffer.data, StorageBox), ( @@ -135,8 +205,6 @@ def convert_output_node_to_buffer(output): ) return subgraph_buffer - # node.args[0] is either a single element or a list of elements - # representing all outputs of the function. 
return tree_map(convert_output_node_to_buffer, pw_subgraph.graph_outputs) @@ -361,10 +429,9 @@ def get_bounded_indices(indices, max_len=None): idx_d = tl.arange(0, V_HEAD_DIM)[None, :] mask = idx_m < Q_LEN - # TODO generalize and add proper mask support + {{store_output(("idx_zq", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}} - # TODO dont want to write this if we dont require grad if OUTPUT_LOGSUMEXP: off_hz = tl.program_id(1) l_ptrs = LSE + off_hz * Q_LEN + offs_m @@ -768,6 +835,7 @@ def flex_attention( subgraph_buffer = build_subgraph_buffer( placeholder_inps + list(score_mod_other_buffers), subgraph ) + mask_graph_placeholder_inps = [ create_placeholder(name, dtype, query.get_device()) for name, dtype in [ @@ -780,6 +848,7 @@ def flex_attention( mask_graph_buffer = build_subgraph_buffer( mask_graph_placeholder_inps + list(mask_mod_other_buffers), mask_graph ) + kernel_options = dict(kernel_options) kernel_options.setdefault("FLOAT32_PRECISION", get_float32_precision()) if _use_flex_decoding(query, kernel_options): @@ -1654,6 +1723,26 @@ def bwd_dkdv_block_mn( n="n", grad_score_mod="dsT" ) | indent_except_first(1) }} + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = offs_m1[None, :] < Q_LEN and offs_n1[:, None] < KV_LEN + {{ modification( + subgraph_number=3, + output_name=None, + mask="scatter_mask", + score="pre_mod_scores", + b="idx_b", + h="idx_h", + m="idx_m", + n="idx_n", + grad_score_mod="dsT" + ) | indent_except_first(1) }} + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if CHECK_BLOCK_BOUNDARY: grad_scores = tl.where(offs_n1[:, None] < KV_LEN, grad_scores, 0.0) @@ -1673,6 +1762,82 @@ def bwd_dkdv_block_mn( ) +def validate_joint_graph(joint_graph: torch.fx.Graph): + """We do some pre lowering graph checks in order to raise nicer error messages""" + for node in joint_graph.nodes: + if ( + node.op == "call_function" + and node.target == 
torch.ops.flex_lib.zeros_and_scatter.default + ): + for user in node.users: + if user.op != "output": + raise NotImplementedError( + "Using multiple indexing operations on the same tensor that requires gradients " + "in a score_mod function is not currently supported. " + "This typically happens when indexing the same tensor multiple times, like:\n\n" + " def score_mod(score, b, h, q_idx, kv_idx):\n" + " return score + bias[q_idx] + bias[kv_idx] # bias used twice!\n\n" + "A valid workaround is to clone() the tensors that will be indexed multiple times. For example:\n\n" + " bias1 = bias.clone()\n" + " def score_mod(score, b, h, q_idx, kv_idx):\n" + " return score + bias[q_idx] + bias1[kv_idx]\n\n" + "Note that this solution will use additional memory." + ) + return + + +@dataclass(frozen=True) +class JointOutputResult: + """Results from processing joint outputs.""" + + grad_input: ComputedBuffer + captured_grads_compute: List[ComputedBuffer] + captured_grads: List[Optional[TensorBox]] + mutated_grads: List[TensorBox] + + +def process_joint_outputs( + all_joint_outputs: SubgraphResults, num_placeholders: int +) -> JointOutputResult: + """Process joint outputs and extract various buffers needed for lowering + + Args: + all_joint_outputs: List of all the outputs from build_subgraphs + num_placeholders: The number of placeholder inputs, used to skip over unused backward compute buffers + + Returns: + JointOutputResult containing processed buffers and gradients + """ + assert isinstance(all_joint_outputs, List) + assert ( + all_joint_outputs[0] is not None + ), "joint_subgraph_buffer is None this is a bug!" 
+ + joint_buffer = all_joint_outputs[0] + other_grads = all_joint_outputs[num_placeholders - 1 :] + + # other_grads has the structure: len(other_buffer_grads); if a buffer doesn't require grad then its entry will be None + # We only grab the buffers that require grad for inlining into kernel + grads_compute = [buf for buf in other_grads if buf is not None] + + def get_out(buf): + if buf is None: + return None + assert isinstance(buf, ComputedBuffer) + assert buf.name is not None + return TensorBox.create(V.graph.get_buffer(buf.name)) + + grads_out = [get_out(x) for x in other_grads] + mutated_grads = [buf for buf in grads_out if buf is not None] + + return JointOutputResult( + grad_input=joint_buffer, + captured_grads_compute=grads_compute, + captured_grads=grads_out, + mutated_grads=mutated_grads, + ) + + # TODO: We probably also need a layout constraint? @register_lowering( torch.ops.higher_order.flex_attention_backward, type_promotion_kind=None @@ -1774,8 +1939,18 @@ def flex_attention_backward(*args, **kwargs): ] # Sometimes we have weird unused nodes here joint_graph.graph_module.graph.eliminate_dead_code() - joint_subgraph_buffer, *_ = build_subgraph_buffer( - joint_placeholder_inps + list(score_mod_other_buffers), joint_graph + + # It is hard to raise nice errors for some joint graphs during subgraph lowering + # This lets us do some checks before attempting to lower + validate_joint_graph(joint_graph.graph_module.graph) + + all_joint_outputs = build_subgraph_buffer( + joint_placeholder_inps + list(score_mod_other_buffers), + joint_graph, + ) + + joint_outputs = process_joint_outputs( + all_joint_outputs, len(joint_placeholder_inps) ) mask_graph_placeholder_inps = [ @@ -1791,6 +1966,8 @@ def flex_attention_backward(*args, **kwargs): mask_graph_placeholder_inps + list(mask_mod_other_buffers), mask_graph ) + mask_graph_buffer = mask_graph_buffer + layout_broadcasted_k = FixedLayout( key.get_device(), key.get_dtype(), @@ -1891,8 +2068,17 @@ def 
flex_attention_backward(*args, **kwargs): full_q_indices, ], layout=layout_broadcasted_k, # We use store_output only for grad_key - subgraphs=[fw_subgraph_buffer, joint_subgraph_buffer, mask_graph_buffer], - mutated_inputs=[grad_query, broadcasted_grad_value], + subgraphs=[ + fw_subgraph_buffer, + joint_outputs.grad_input, + mask_graph_buffer, + joint_outputs.captured_grads_compute, + ], + mutated_inputs=[ + grad_query, + broadcasted_grad_value, + *joint_outputs.mutated_grads, + ], call_sizes=query.get_size() + key.get_size()[1:3], num_stages=num_stages, num_warps=num_warps, @@ -1949,8 +2135,4 @@ def flex_attention_backward(*args, **kwargs): grad_key = lowerings[aten.sum](broadcasted_grad_key, axis=0, keepdims=True) grad_value = lowerings[aten.sum](broadcasted_grad_value, axis=0, keepdims=True) - return ( - grad_query, - grad_key, - grad_value, - ) + return (grad_query, grad_key, grad_value, tuple(joint_outputs.captured_grads)) diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 995ad3e75f94b3..14960eb510568a 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -75,7 +75,7 @@ log = logging.getLogger(__name__) -lowerings: Dict[Callable[..., Any], Callable[..., Any]] = {} +lowerings: Dict[Union[Callable[..., Any], str], Callable[..., Any]] = {} # Use maybe_layout_constraints to access this dict, we lazily register tag-based layout constraints _maybe_layout_constraints: Dict[ torch._ops.OpOverload, Optional[Callable[..., Any]] diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index bbf0b50a56546d..8b07f6528e7976 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -177,6 +177,57 @@ def finalize_all(self) -> str: ) +class ModificationWrapper(V.WrapperHandler): # type: ignore[name-defined] + """Handles placeholder substitutions during subgraph processing.""" + + def __init__( + self, + kernel, + subgraph_number: int, + fixed_inputs: Dict[str, 
Any], + mask: Optional[str], + ): + super().__init__(V.ops) + self.name = f"PlaceholderSubstitution_{subgraph_number}" + self.kernel = kernel + self.fixed_inputs = fixed_inputs + self.mask = mask + + def load(self, name: str, index: sympy.Expr): + """Handle loading from tensor or fixed input.""" + if name not in self.fixed_inputs: + index_str = self._process_indexing(index) + var = self._add_kernel_input(name) + return f"tl.load({var} + {index_str})" + return f"({self.fixed_inputs[name]})" + + def indirect_indexing(self, index_var: str, size, check, wrap_neg=True): + """Convert index variable to symbolic form.""" + return sympy_index_symbol(str(index_var)) + + def store(self, name, index, value, mode): + """Store value and track the store's mask and output value on the kernel. + + The template_mask and template_out are used by the indexing() method to properly + mask store operations in the generated Triton code. The mask ensures stores only + affect elements matching the mask condition. 
This is currently only used for scatter node's store + """ + assert ( + self.mask is not None + ), "Mask is required for inner stores in modifications" + self.kernel.template_out = value + self.kernel.template_mask = self.mask + return self._inner.store(name, index, value, mode) + + def _add_kernel_input(self, name: str): + """Add name as input to kernel and return input ref.""" + return self.kernel.args.input(name) + + def _process_indexing(self, index): + """Process and rename indexing, adding symbols as kernel inputs.""" + return self.kernel.kexpr(self.kernel.rename_indexing(index)) + + class TritonTemplateKernel(TritonKernel): def __init__( self, @@ -411,69 +462,78 @@ def stride(self, name, index=None): return texpr(self.rename_indexing(val[index])) return ", ".join([texpr(self.rename_indexing(i)) for i in val]) + def _get_subgraph(self, subgraph_number: int): + assert isinstance(subgraph_number, int) + assert isinstance(self.subgraphs, list) + assert subgraph_number < len( + self.subgraphs + ), f"Invalid subgraph number provided to create_modification, {subgraph_number} must be < {len(self.subgraphs)}" + assert ( + self.body.getvalue() == "" + ), "Body should be clear before adding a modification" + return self.subgraphs[subgraph_number] + + def _handle_scatter_graph(self, scatter_graph): + """Handle processing for a single scatter graph. 
+ + Args: + scatter_graph: The scatter graph to process + """ + assert isinstance( + scatter_graph, ir.ComputedBuffer + ), f"scatter_graph must be an instance of ComputeBuffer but got {type(scatter_graph)}" + + def contiguous_strides(x): + # We always create a fresh contiguous grad for scattering into + return sum( + x_i * stride for x_i, stride in zip(x, scatter_graph.get_stride()) + ) + + scatter_graph.data.store_output(scatter_graph.name, contiguous_strides, []) # type: ignore[attr-defined] + def modification( - self, subgraph_number: int, output_name: str, **fixed_inputs + self, + subgraph_number: int, + output_name: Optional[str], + mask: Optional[str] = None, + **fixed_inputs, ) -> str: """This creates a modification function for a subgraph. To use this inside a template, the first argument should specify which subgraph to codegen for Args: subgraph_number (int): The index of the subgraph in self.subgraphs + output_name (Optional[str]): The name of the output variable to store the result in + mask (Optional[str]): An optional mask to use for the store operation. If provided, this mask + will be applied to the store. 
""" - outer_self = self num = 0 + out = None while f"mod_{subgraph_number}_{num}" in self.subgraph_bodies: num += 1 with self.create_subgraph_body(f"mod_{subgraph_number}_{num}"): - assert isinstance(subgraph_number, int) - assert isinstance(self.subgraphs, list) - assert ( - self.body.getvalue() == "" - ), "Body should be clear before adding a modification" - assert subgraph_number < len( - self.subgraphs - ), f"Invalid subgraph number provided to create_modification, {subgraph_number} must be < {len(self.subgraphs)}" - - subgraph = self.subgraphs[subgraph_number] - - def add_input(name): - # This also implicitly adds name as an input to the kernel - return self.args.input(name) - - def print_and_rename_indexing(index): - # This also implicitly adds the indexing symbols as an input to - # the kernel - return self.kexpr(self.rename_indexing(index)) - - name = f"PlaceholderSubstitution_{subgraph_number}" - - class PlaceholderSubstitution(V.WrapperHandler): # type: ignore[name-defined] - self.name = name - - def load(self, name: str, index: sympy.Expr): - if name not in fixed_inputs: - # If it's not a fixed input, it's a load from a captured - # tensor - index_str = print_and_rename_indexing(index) - var = add_input(name) - return f"tl.load({var} + {index_str})" - - return f"({fixed_inputs[name]})" - - def indirect_indexing(self, index_var, size, check, wrap_neg=True): - return sympy_index_symbol(str(index_var)) - - with V.set_ops_handler(PlaceholderSubstitution(V.ops)): + subgraph = self._get_subgraph(subgraph_number) + modification_handler = ModificationWrapper( + self, subgraph_number, fixed_inputs, mask + ) + with V.set_ops_handler(modification_handler): assert isinstance( - subgraph, ir.ComputedBuffer - ), f"Expected the subgraph to be a ComputedBuffer, got {type(subgraph)}" - if isinstance(subgraph.data, ir.InputBuffer): + subgraph, (ir.ComputedBuffer, List) + ), f"Expected the subgraph to be a ComputedBuffer or a List[ComputedBuffer], got {type(subgraph)}" + # 
Handle scatter stores + if isinstance(subgraph, list): + for scatter_graph in subgraph: + self._handle_scatter_graph(scatter_graph) + elif isinstance(subgraph.data, ir.InputBuffer): out = subgraph.data.make_loader()(()) else: out = subgraph.data.inner_fn(()) self.codegen_body() - self.body.writeline(f"{output_name} = {out.value}") + if output_name is not None: + assert isinstance(output_name, str) + assert out is not None + self.body.writeline(f"{output_name} = {out.value}") body_val = self.body.getvalue() self.cse.invalidate(set()) # type: ignore[arg-type] @@ -783,19 +843,13 @@ def generate( # type: ignore[override] mod = PyCodeCache.load(code, extra) input_call_args = tuple(kernel.args.input_buffers.keys()) - output_call_args = tuple(kernel.args.output_buffers.keys()) # We expect the input_buffer order to be [*input_nodes, *captured_buffers] expected_input_args = tuple(unique(x.get_name() for x in input_nodes)) - expected_output_args = (fake_out.get_name(),) assert input_call_args[: len(expected_input_args)] == expected_input_args, ( input_call_args, expected_input_args, ) - assert output_call_args == expected_output_args, ( - output_call_args, - expected_output_args, - ) full_input_nodes = tuple([V.graph.get_buffer(k) for k in input_call_args]) extra_args = V.graph.sizevars.size_hints( diff --git a/torch/_inductor/subgraph_lowering.py b/torch/_inductor/subgraph_lowering.py index e0b36c41965244..1093f684e8b548 100644 --- a/torch/_inductor/subgraph_lowering.py +++ b/torch/_inductor/subgraph_lowering.py @@ -1,11 +1,21 @@ -"""Utilities for lowering subgraphs used by higher order operators - -""" +"""Utilities for lowering subgraphs used by higher order operators""" import functools import operator +from contextlib import contextmanager from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union +from typing import ( + Any, + Callable, + Dict, + Generator, + List, + Optional, + Set, + Tuple, + TypeVar, + Union, +) 
from typing_extensions import ParamSpec import torch @@ -19,6 +29,10 @@ T = TypeVar("T") _P = ParamSpec("_P") +OpOverload = torch._ops.OpOverload +LoweringDict = Dict[Union[OpOverload, str], Callable[..., Any]] +TargetType = Union[Callable[..., Any], str] + class PointwiseSubgraphLowering(torch.fx.Interpreter): """ @@ -27,46 +41,97 @@ class PointwiseSubgraphLowering(torch.fx.Interpreter): """ graph_outputs: Optional[List[ir.IRNode]] + root_graph: torch._inductor.graph.GraphLowering + _current_op: Optional[TargetType] + # For backwards of buffer_grads with scatters we allow mutations + allowed_mutations: Optional[Set[OpOverload]] + additional_lowerings: Optional[LoweringDict] + buffers: List[ir.Buffer] + mutated_buffers: Set[str] def __init__( self, gm: torch.fx.GraphModule, - root_graph_lowering: "torch._inductor.graph.GraphLowering", + root_graph_lowering: torch._inductor.graph.GraphLowering, + allowed_mutations: Optional[Set[OpOverload]] = None, + additional_lowerings: Optional[LoweringDict] = None, ) -> None: super().__init__(gm) self.graph_outputs = None self.root_graph = root_graph_lowering + self.allowed_mutations = allowed_mutations + self.additional_lowerings = additional_lowerings + self._current_op = None + + # Used to track buffers created during lowering + self.mutated_buffers = set() + self.buffers = [] + + @contextmanager + def _op_context(self, op: TargetType) -> Generator[None, None, None]: + """Set which op is being processed in call function to know if we can mutate buffers""" + previous = self._current_op + self._current_op = op + try: + yield + finally: + self._current_op = previous + + def _approved_mutator(self) -> bool: + return ( + self.allowed_mutations is not None + and self._current_op in self.allowed_mutations + ) def mark_buffer_mutated(self, name: str) -> None: - raise SubgraphLoweringException("Mutations are not supported in this context") + if self._approved_mutator(): + self.mutated_buffers.add(name) + else: + raise 
SubgraphLoweringException( + f"Buffer mutation detected during lowering of {self._current_op}. " + "Buffer mutations are only allowed in approved mutation ops. " + "This is an error in the lowering of the subgraph, please file a bug report." + ) - def register_buffer(self, buffer: ir.Buffer) -> str: - raise SubgraphLoweringException( - "Buffers cannot be created while lowering a pointwise subgraph. " - "This could be for a good reason (e.g. you're calling an op we can't codegen as a pointwise op), " - "but it could also be a bug. Please file a bug report if you think this should be supportable." - ) + def register_buffer(self, buffer: ir.Buffer, *, set_name: bool = False) -> str: + if self._approved_mutator(): + name = self.qualify_name(f"buf{len(self.buffers)}") + self.buffers.append(buffer) + return name + else: + raise SubgraphLoweringException( + "Buffers cannot be created while lowering a pointwise subgraph. " + "This could be for a good reason (e.g. you're calling an op we can't codegen as a pointwise op), " + "but it could also be a bug. Please file a bug report if you think this should be supportable." 
+ ) def __getattr__(self, name: str) -> Any: return getattr(self.root_graph, name) def call_function( self, - target: Callable[[Any], Any], # type: ignore[override] + target: TargetType, args: Any, kwargs: Dict[str, Any], ) -> Any: from .lowering import lowerings - if target is operator.getitem and isinstance(args[0], (list, tuple, dict)): - return super().call_function(target, args, kwargs) + with self._op_context(target): + if target is operator.getitem and isinstance(args[0], (list, tuple, dict)): + return super().call_function(target, args, kwargs) - if target not in lowerings: - raise SubgraphLoweringException( - f"{target} not supported in subgraph, (missing lowering)" - ) + # These takes precedence over the main lowerings + if self.additional_lowerings is not None: + if target in self.additional_lowerings: + assert isinstance(target, OpOverload) + return self.additional_lowerings[target](*args, **kwargs) + + if target not in lowerings: + raise SubgraphLoweringException( + f"{target} not supported in subgraph, (missing lowering)" + ) - return lowerings[target](*args, **kwargs) + return lowerings[target](*args, **kwargs) def output(self, target: str, args: Tuple[Any], kwargs: Dict[str, Any]) -> None: # type: ignore[override] assert len(args) == 1 diff --git a/torch/nn/attention/flex_attention.py b/torch/nn/attention/flex_attention.py index 31060d38c13da9..00774fac7a7609 100644 --- a/torch/nn/attention/flex_attention.py +++ b/torch/nn/attention/flex_attention.py @@ -1066,14 +1066,13 @@ def _apply_kernel_options( kernel_options.setdefault("ROWS_GUARANTEED_SAFE", False) kernel_options.setdefault("BLOCKS_ARE_CONTIGUOUS", False) - # If foward kernel needs to return logsumexp is decided by this rule internally. + # If forward kernel needs to return logsumexp is decided by this rule internally. 
assert "OUTPUT_LOGSUMEXP" not in kernel_options kernel_options["OUTPUT_LOGSUMEXP"] = True if not return_lse: - any_inputs_require_grad = ( - query.requires_grad or key.requires_grad or value.requires_grad - ) - output_logsumexp = any_inputs_require_grad and torch.is_grad_enabled() + # We used to check if q,k,v required grads but since captured buffers can require grad + # we always write unless in no_grad + output_logsumexp = torch.is_grad_enabled() kernel_options["OUTPUT_LOGSUMEXP"] = output_logsumexp return kernel_options @@ -1240,9 +1239,7 @@ def score_mod( block_mask = _create_empty_block_mask(query, key) elif ( not query.is_nested - and ( - query.requires_grad or key.requires_grad or value.requires_grad - ) # skip adjust block if no grad + and (query.requires_grad or key.requires_grad or value.requires_grad) and ( query.size(-2) < block_mask.kv_num_blocks.size(-1) * block_mask.BLOCK_SIZE[0]