From 41acb509ccb567b330e15dd68cbcd40986ec6db3 Mon Sep 17 00:00:00 2001 From: Yifei Teng Date: Thu, 22 Aug 2024 14:48:41 -0700 Subject: [PATCH] scan and apply_layers: milestone 1 This commit adds the lowering of scan to HLO While op. It also introduce apply_layers which can sequentially apply a bunch of layers using scan underneath. In this milestone we use AOTAutograd to obtain the backward of the function being scanned. Users can either save the activations in fn or recompute them by passing different graph partitioners to AOTAutograd. --- examples/decoder_only_model.py | 8 +- test/run_tests.sh | 3 +- test/scan/test_scan.py | 477 ++++++++++++++++++ test/scan/test_scan_layers.py | 280 +++++++++++ test/test_operations.py | 24 + test/test_scan.py | 107 ----- test/tpu/run_tests.sh | 5 +- torch_xla/csrc/init_python_bindings.cpp | 10 +- torch_xla/csrc/lowering_context.cpp | 15 +- torch_xla/csrc/lowering_context.h | 6 + torch_xla/experimental/pytreeify.py | 50 ++ torch_xla/experimental/scan.py | 614 +++++++++++++++++++++++- torch_xla/experimental/scan_layers.py | 142 ++++++ 13 files changed, 1597 insertions(+), 144 deletions(-) create mode 100644 test/scan/test_scan.py create mode 100644 test/scan/test_scan_layers.py delete mode 100644 test/test_scan.py create mode 100644 torch_xla/experimental/pytreeify.py create mode 100644 torch_xla/experimental/scan_layers.py diff --git a/examples/decoder_only_model.py b/examples/decoder_only_model.py index 712423d79ad..79040e5d24d 100644 --- a/examples/decoder_only_model.py +++ b/examples/decoder_only_model.py @@ -7,16 +7,16 @@ from torch import nn -# the default config is intentionally kept low to make it runable on a sigle tpu v2-8 core. +# the default config is intentionally kept low to make it runnable on a single tpu v2-8 core. @dataclass class DecoderOnlyConfig: hidden_size: int = 1024 num_hidden_layers: int = 2 num_attention_heads: int = 8 num_key_value_heads: int = 4 - intermediate_size = 32 * 1024 - vocab_size = 3200 - use_flash_attention = False + intermediate_size: int = 32 * 1024 + vocab_size: int = 3200 + use_flash_attention: bool = False def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: diff --git a/test/run_tests.sh b/test/run_tests.sh index 0912d53ded5..543bc5f8403 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -208,7 +208,8 @@ function run_xla_op_tests1 { function run_xla_op_tests2 { run_test "$CDIR/pjrt/test_dtypes.py" run_test "$CDIR/test_while_loop.py" - run_test "$CDIR/test_scan.py" + run_test "$CDIR/scan/test_scan.py" + run_test "$CDIR/scan/test_scan_layers.py" run_test "$CDIR/test_autocast.py" run_test "$CDIR/eager/test_eager.py" run_test "$CDIR/eager/test_eager_with_xla_compile.py" diff --git a/test/scan/test_scan.py b/test/scan/test_scan.py new file mode 100644 index 00000000000..f344386c852 --- /dev/null +++ b/test/scan/test_scan.py @@ -0,0 +1,477 @@ +import sys +import os +import re +import unittest +from functools import reduce + +import torch +from functorch.compile import default_partition, min_cut_rematerialization_partition # type: ignore +from torch.utils._pytree import tree_map, tree_flatten, tree_iter, tree_leaves, PyTree + +import torch_xla +from torch_xla.experimental.scan import scan, value_and_grad_partitioned, tree_flatten_none + +parent_folder = os.path.dirname(os.path.dirname(__file__)) +sys.path.append(parent_folder) +from test_utils import XlaTestCase # type:ignore + + +def _loopy_scan(fn, init, xs): + """A simple scan implemented with for loops serving as reference + implementation.""" + carry = init + ys = [] + xs_len = len(next(iter(tree_iter(xs)))) + for i in range(xs_len): + carry, y = fn(carry, tree_map(lambda x: x[i], xs)) + ys.append(y) + + def none_stack(*ys): + if len(ys) == 0: + return None + if ys[0] is None: + assert all(y is None for y in ys) + return None + return torch.stack(ys) + + ys = tree_map(none_stack, *ys) + return carry, ys + + +class TestBase(XlaTestCase): + + def setUp(self): + super().setUp() + self.device = torch_xla.device() + + def compare_pytree(self, expected_pytree, actual_pytree): + flat_expected_pytree, expected_spec = tree_flatten(expected_pytree) + flat_actual_pytree, actual_spec = tree_flatten(actual_pytree) + assert expected_spec == actual_spec, f"{expected_spec} != {actual_spec}" + # If there are `None`, they must happen in the same location. + for expected, actual in zip(flat_expected_pytree, flat_actual_pytree): + assert (expected is None) == (actual is None), \ + f"Mismatched None. expected: {expected}, actual: {actual}" + # Get rid of `None` before passing to compareResults. + flat_expected_pytree = [x for x in flat_expected_pytree if x is not None] + flat_actual_pytree = [x for x in flat_actual_pytree if x is not None] + super().compareResults(flat_expected_pytree, flat_actual_pytree) + + +class ScanTest(TestBase): + + def run_test(self, + fn, + init: PyTree, + xs: PyTree, + partition_fn=default_partition): + """Compares the result of scanning with `fn` with our optimized HLO implementation + against a for loop implementation. Checks both output values and gradients. + """ + squish = lambda t: reduce( + lambda a, b: a + b, + map(lambda v: v.sum() + if v is not None else 0, tree_leaves(t)), torch.tensor(0.0)) + dupe = lambda v: v.detach().clone().requires_grad_(v.requires_grad) + + # Actual output + init_scan = tree_map(dupe, init) + xs_scan = tree_map(dupe, xs) + final_carry, ys = scan(fn, init_scan, xs_scan, partition_fn=partition_fn) + # Add up all leaves and `backward()` once. + (squish(final_carry) + squish(ys)).backward() + torch_xla.sync() + + # Expected output + init_loop = tree_map(dupe, init) + xs_loop = tree_map(dupe, xs) + expected_final_carry, expected_ys = _loopy_scan(fn, init_loop, xs_loop) + # Add up all leaves and `backward()` once. + (squish(expected_final_carry) + squish(expected_ys)).backward() + torch_xla.sync() + + # Compare values + self.compare_pytree(expected_final_carry, final_carry) + self.compare_pytree(expected_ys, ys) + + # Compare gradients + self.compare_pytree( + tree_map(lambda v: v.grad, init_loop), + tree_map(lambda v: v.grad, init_scan)) + self.compare_pytree( + tree_map(lambda v: v.grad, xs_loop), tree_map(lambda v: v.grad, + xs_scan)) + + return final_carry, ys + + def test_scan_simple(self): + """This test uses `scan` to implement `torch.cumsum`.""" + + def step_fn(carry, x): + new_carry = carry + x + y = new_carry + return new_carry, y + + init = torch.tensor([0.0, 0.0], requires_grad=True, device=self.device) + xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], + requires_grad=True, + device=self.device) + final_carry, ys = self.run_test(step_fn, init, xs) + + # Also ensure that our loop-based scan is correct, with manual checks + # that replicate the step_fn. + expected_final_carry = torch.sum(xs, dim=0) + init + expected_ys = torch.cumsum(xs, dim=0) + self.compare_pytree(expected_final_carry, final_carry) + self.compare_pytree(expected_ys, ys) + + def test_scan_fn_not_callable(self): + init = torch.tensor([1.0, 1.0], device=self.device) + xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], device=self.device) + with self.assertRaises(ValueError): + scan(1000, init, xs) # type: ignore + + def test_scan_incompatible_length(self): + init = torch.tensor([1.0, 1.0], device=self.device) + xs_1 = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], + device=self.device) + xs_2 = torch.tensor([[1.0, 2.0], [3.0, 4.0]], device=self.device) + with self.assertRaises(ValueError): + scan(lambda a, b: (a, b), init, (xs_1, xs_2)) + + def test_scan_tuples(self): + """Test scanning over the leading axis of a tuple of tensors simultaneously, + which is a simple PyTree.""" + + def fn(carry, x): + carry1, carry2 = carry + x1, x2 = x + new_carry1 = carry1 + x1.sum() + new_carry2 = carry2 + x2.sum() + y1 = x1 * 2 + torch.sum(new_carry1) + y2 = x2 * 2 + torch.sum(new_carry2) + return (new_carry1, new_carry2), (y1, y2) + + init = (torch.tensor([0.0], requires_grad=True, device=self.device), + torch.tensor([1.0, 2.0], requires_grad=True, device=self.device)) + + xs = (torch.tensor([[1.0, 2.0], [3.0, 4.0]], + requires_grad=True, + device=self.device), + torch.tensor([[5.0, 6.0, 7.0], [8.0, 9.0, 10.0]], + requires_grad=True, + device=self.device)) + + self.run_test(fn, init, xs) + + def test_scan_create_tensors(self): + """Test scanning over a function that internally creates tensors.""" + + def fn(carry, x): + a = torch.tensor([1.0, 2.0], device=self.device) + b = torch.tensor([3.0, 4.0], device=self.device) + return carry + a, x + b + + init = torch.tensor([0.0, 0.0], requires_grad=True, device=self.device) + xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], + requires_grad=True, + device=self.device) + self.run_test(fn, init, xs) + + def test_scan_internal_in_place_mutation(self): + """ + Test internal in-place mutations inside the `fn` to be scanned over. + """ + + def fn(carry, x): + carry = carry.clone() + carry.add_(x) + y = x.clone() + y.add_(42) + return carry, y + + init = torch.tensor([0.0, 0.0], requires_grad=True, device=self.device) + xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], + requires_grad=True, + device=self.device) + self.run_test(fn, init, xs) + + def test_scan_external_in_place_mutation(self): + """ + Test that external in-place mutations raise an exception instead of silently + giving wrong results. + """ + # TODO(yifeit): Modify this test when external in-place mutation is eventually supported. + weird_global = torch.tensor([0.0, 0.0], device=torch_xla.device()) + + def step_fn(carry, x): + new_carry = carry + x + weird_global.add_(1.0) + y = new_carry + weird_global + return new_carry, y + + init = torch.tensor([0.0, 0.0], device=torch_xla.device()) + xs = torch.tensor([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], + device=torch_xla.device()) + + with self.assertRaisesRegex(AssertionError, "FakeTensor"): + scan(step_fn, init, xs) + + def test_scan_gradness(self): + """ + Test the gradient output of `scan` when various inputs require or doesn't + require gradients. + """ + + def test_case(init_requires_grad: bool, xs_requires_grad: bool): + + def fn(carry, x): + new_carry = carry * x + y = new_carry + x + return new_carry, y + + init = torch.tensor([1.0, 1.0], + requires_grad=init_requires_grad, + device=self.device) + xs = torch.tensor([[2.0, 3.0], [4.0, 5.0], [6.0, 7.0]], + requires_grad=xs_requires_grad, + device=self.device) + self.run_test(fn, init, xs) + + test_case(True, True) + test_case(True, False) + test_case(False, True) + + def test_scan_output_none(self): + """ + Test scan when `fn` returns `None` as output. This case is exercised by + `scan_layers`, which only needs the carry. + """ + + def fn(carry, x): + return torch.cos(carry) + x, None + + init = torch.tensor([1.0, 1.0], requires_grad=True, device=self.device) + xs = torch.tensor([[2.0, 3.0], [4.0, 5.0], [6.0, 7.0]], + requires_grad=True, + device=self.device) + _final_carry, ys = self.run_test(fn, init, xs) + self.assertIsNone(ys) + + def test_scan_output_unit(self): + """ + Test scan when `fn` returns `()` as output. + """ + + def fn(carry, x): + return torch.cos(carry) + x, () + + init = torch.tensor([1.0, 1.0], requires_grad=True, device=self.device) + xs = torch.tensor([[2.0, 3.0], [4.0, 5.0], [6.0, 7.0]], + requires_grad=True, + device=self.device) + _final_carry, ys = self.run_test(fn, init, xs) + self.assertEqual(ys, ()) + + def test_scan_rand_in_fn(self): + """ + Test that the RNG state in each iteration of `fn` is not the same. + """ + + def step_fn(carry, x): + new_carry = carry + x + y = new_carry + torch.rand(2, device=torch_xla.device()) + return new_carry, y + + init = torch.tensor([0.0, 0.0], device=torch_xla.device()) + xs = torch.tensor([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], + device=torch_xla.device()) + _, ys = scan(step_fn, init, xs) + # ys should be a 2D tensor with this shape. + self.assertEqual(ys.shape, (3, 2)) + # Values across the first dimension should not be the same. + self.assertNotEqual(ys[0][0], ys[1][0]) + self.assertNotEqual(ys[0][1], ys[1][1]) + + def test_scan_with_rematerialization(self): + """ + Test scanning `fn` but also the backward pass recomputes the forward. + """ + + def fn(carry, x): + for _ in range(10): + carry = torch.sin(carry) + for _ in range(10): + x = torch.sin(x) + return carry, x + + carry = torch.randn(4, 4, requires_grad=True, device=self.device) + xs = torch.randn(20, 4, 4, requires_grad=True, device=self.device) + + # Check the gradients and also cross-check with results from a run + # where we don't have activation checkpointing. + final_carry_remat, ys_remat = self.run_test( + fn, carry, xs, partition_fn=min_cut_rematerialization_partition) + final_carry, ys = self.run_test(fn, carry, xs) + super().compareResults(final_carry, final_carry_remat) + super().compareResults(ys, ys_remat) + torch_xla.sync() + + SINE_OP = re.compile(r" sine\(f32\b") + + def count_number_of_sines(partition_fn): + """ + Uses `partition_fn` to partition `fn` into forward and backward passes + while building the scan operation, then counts the number of `sine` HLO + operators in the joint graph. + + The intention is that if `partition_fn` recomputes some forward ops + during the backward, we'll see a larger number of `sine` operations since + `fn` consists of only `torch.sin` in this test. + """ + own_carry = carry.clone().detach().requires_grad_() + own_xs = xs.clone().detach().requires_grad_() + final_carry, ys = scan(fn, own_carry, own_xs, partition_fn=partition_fn) + torch_xla.sync() + (torch.sum(final_carry) + torch.sum(ys)).backward() + assert own_carry.grad is not None + assert own_xs.grad is not None + text: str = torch_xla._XLAC._get_xla_tensors_hlo( + [own_carry.grad, own_xs.grad]) + return len(SINE_OP.findall(text)) + + # Check the HLO to verify that `sine(...)` recomputation happens in the backward + # in the version using `min_cut_rematerialization_partition`, and never happens + # in the default partition. + self.assertGreater( + count_number_of_sines(min_cut_rematerialization_partition), 10) + self.assertEqual(count_number_of_sines(default_partition), 0) + + +class PyTreeTest(TestBase): + + def test_tree_flatten_none(self): + pytree = ((1, 2), (None, 3), None) + flat, unflatten = tree_flatten_none(pytree) + assert tuple(flat) == (1, 2, 3) + assert unflatten(flat) == ((1, 2), (None, 3), None) + + +class ValueAndGradPartitionedTest(TestBase): + + def test_transform_linear_layer(self): + + def fn(carry, x): + new_carry = carry @ x + y = new_carry + return new_carry, y + + init = torch.tensor([[1.0, 2.0], [3.0, 4.0]], + requires_grad=True, + device=self.device) + xs = torch.tensor([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]], + requires_grad=True, + device=self.device) + forward, backward = value_and_grad_partitioned(fn, init, xs) + + # Forward should return `(new_carry, (y, (carry, x)))`, + # because `(carry, x)` are the two intermediate activations (primals), + # and they will be packed alongside the original output `y`. + out = forward(init, xs[0]) + torch_xla.sync() + carry = init + x = xs[0] + new_carry = init @ x + y = new_carry + self.compare_pytree(out, (new_carry, (y, (carry, x)))) + + # Backward should take in `(grad_new_carry, (grad_y, (carry, x)))`, and + # return `(grad_carry, grad_x)`. `(carry, x)` are the two intermediate + # activations (primals), and they are packed alongside the gradient with + # respect to y, `grad_y`. + grad_new_carry = torch.ones_like(new_carry) + grad_y = torch.ones_like(y) + out = backward(grad_new_carry, (grad_y, (carry, x))) + torch_xla.sync() + grad_carry = (grad_new_carry + grad_y) @ x.T + grad_x = carry.T @ (grad_new_carry + grad_y) + self.compare_pytree(out, (grad_carry, grad_x)) + + def test_transform_non_trivial_pytree(self): + """ + `fn` simulates two linear layers operating on two values a and b. + Test that we can trace `fn` when it uses non-trivial pytree, and + compare gradients against those from torch.autograd. + """ + + def fn(carry, x): + weights = x['weights'] + biases = x['biases'] + carry_a = carry['a'] + carry_b = carry['b'] + new_carry_a = torch.sin((carry_a @ weights) + biases) + new_carry_b = torch.cos((carry_b @ weights) + biases) + y = torch.sigmoid(new_carry_a + new_carry_b) + return {'a': new_carry_a, 'b': new_carry_b}, y + + init = { + 'a': torch.randn(2, 3, requires_grad=True, device=self.device), + 'b': torch.randn(2, 3, requires_grad=True, device=self.device) + } + x = { + 'weights': torch.randn(3, 3, requires_grad=True, device=self.device), + 'biases': torch.randn(2, 3, requires_grad=True, device=self.device) + } + + # Get the forward and backward functions using value_and_grad_partitioned + forward, backward = value_and_grad_partitioned( + fn, init, tree_map(lambda v: v.unsqueeze(0), x)) + + # Run the forward function + carry_out, (y_out, activations) = forward(init, x) + torch_xla.sync() + + # Compute expected outputs and gradients using PyTorch autograd + def compute_outputs_and_gradients(carry, x): + # Clone inputs to ensure they're independent + carry = tree_map(lambda v: v.clone().detach().requires_grad_(True), carry) + x = tree_map(lambda v: v.clone().detach().requires_grad_(True), x) + + # Forward pass + new_carry, y = fn(carry, x) + + # Run backward to compute gradients. + out, _ = tree_flatten((new_carry, y)) + torch.autograd.backward(out, tree_map(lambda v: torch.ones_like(v), out)) + + # Collect gradients + grads = { + 'init': tree_map(lambda v: v.grad, carry), + 'x': tree_map(lambda v: v.grad, x), + } + outputs = {'carry': new_carry, 'y': y} + return outputs, grads + + # Compute expected outputs and gradients + expected_outputs, expected_grads = compute_outputs_and_gradients(init, x) + + # Compare the outputs from the forward function with the expected outputs + self.compare_pytree(carry_out, expected_outputs['carry']) + self.compare_pytree(y_out, expected_outputs['y']) + + # Prepare gradients for the backward function + grad_carry = tree_map(lambda v: torch.ones_like(v), carry_out) + grad_y = torch.ones_like(y_out) + + # Run the backward function + grad_init, grad_x = backward(grad_carry, (grad_y, activations)) + torch_xla.sync() + + # Compare the gradients from the backward function with the expected gradients + self.compare_pytree(grad_init, expected_grads['init']) + self.compare_pytree(grad_x, expected_grads['x']) + + +if __name__ == '__main__': + test = unittest.main() + sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/scan/test_scan_layers.py b/test/scan/test_scan_layers.py new file mode 100644 index 00000000000..a1eb68bd7d2 --- /dev/null +++ b/test/scan/test_scan_layers.py @@ -0,0 +1,280 @@ +import sys +import os +example_folder = os.path.dirname(os.path.dirname( + os.path.dirname(__file__))) + "/examples" +sys.path.append(example_folder) +from decoder_only_model import DecoderOnlyConfig, DecoderOnlyModel # type:ignore + +import unittest +from copy import deepcopy +from typing import Iterable + +import torch +import torch.nn as nn + +import torch_xla +from torch_xla.experimental.scan_layers import scan_layers + +parent_folder = os.path.dirname(os.path.dirname(__file__)) +sys.path.append(parent_folder) +from test_utils import XlaTestCase # type:ignore + + +class ScanLayersTest(XlaTestCase): + + def setUp(self): + super().setUp() + + self.device = torch_xla.device() + + def assert_different_tensor(self, a: torch.Tensor, b: torch.Tensor): + assert a is not b, f"Expected {a} and {b} to be different tensors" + assert a.data is not b.data, f"Expected {a} and {b} to have different storage" + + def assert_while_found_in_hlo(self, tensor: torch.Tensor): + hlo_text = torch_xla._XLAC._get_xla_tensors_hlo([tensor]) + self.assertIn("while(", hlo_text) + self.assertIn("condition=", hlo_text) + self.assertIn("body=", hlo_text) + + def test_empty_layers(self): + layers = [] + input_data = torch.randn(64).to(self.device) + output = scan_layers(layers, input_data.clone()) + super().compareResults(output, input_data, abs_err=0.0001, rel_err=0.001) + + def test_linear_layers(self): + # Fix the random seed to avoid flakes. + with torch.random.fork_rng(): + with torch_xla.xm.fork_rng(): + torch.random.manual_seed(42) + torch_xla.xm.set_rng_state(42) + # We want to apply these layers sequentially + layers = [nn.Linear(64, 64).to(self.device) for _ in range(10)] + input_data = torch.randn(64).to(self.device) + torch_xla.sync(wait=True) + + layers_for_scan = deepcopy(layers) + layers_for_loop = deepcopy(layers) + torch_xla.sync() + + output = scan_layers(layers_for_scan, input_data.clone()) + self.assert_while_found_in_hlo(output) + output.sum().backward() + torch_xla.sync() + + # Test that the result is the same as for loop. + loop_output = input_data.clone() + for layer in layers_for_loop: + loop_output = layer(loop_output) + torch_xla.sync() + + super().compareResults(loop_output, output, abs_err=0.0001, rel_err=0.001) + self.assert_different_tensor(loop_output, output) + + loop_output.sum().backward() + torch_xla.sync() + + # Test that the gradients are the same too. + for layer_scan, layer_loop in zip(layers_for_scan, layers_for_loop): + assert layer_scan.weight.grad is not None + assert layer_loop.weight.grad is not None + assert layer_scan.bias.grad is not None + assert layer_loop.bias.grad is not None + super().compareResults( + layer_scan.weight.grad, + layer_loop.weight.grad, + abs_err=0.0001, + rel_err=0.001) + super().compareResults( + layer_scan.bias.grad, + layer_loop.bias.grad, + abs_err=0.0001, + rel_err=0.001) + self.assert_different_tensor(layer_scan.weight.grad, + layer_loop.weight.grad) + self.assert_different_tensor(layer_scan.bias.grad, layer_loop.bias.grad) + + def test_tuple_layers(self): + """Test applying layers that consume and return tuples. Construct a module + that transforms each element in the tuple. + """ + + class TupleModule(torch.nn.Module): + + def __init__(self): + super(TupleModule, self).__init__() + self.linear = nn.Linear(64, 64) + self.w = nn.Parameter(torch.randn(64, 64, requires_grad=True)) + + def forward(self, x, y, z): + return self.linear(x).sin(), self.linear( + y).cos(), self.linear(z) @ self.w + + layers = [TupleModule().to(self.device) for _ in range(10)] + torch_xla.sync() + + layers_for_scan = deepcopy(layers) + layers_for_loop = deepcopy(layers) + torch_xla.sync() + + # Also make input data some non-trivial graph instead of just device data. + input_data = (torch.randn(64).to(self.device) * 100, + torch.randn(64).to(self.device) * 200, + torch.randn(64).to(self.device) * 300) + a = torch.randn(64).to(self.device) + input_data = tuple(t + a for t in input_data) + output = scan_layers(layers_for_scan, input_data) + self.assert_while_found_in_hlo(output[0]) + self.assert_while_found_in_hlo(output[1]) + output[0].sum().backward() + torch_xla.sync() + + # Test that the result is the same as for loop. + loop_output = input_data + for layer in layers_for_loop: + loop_output = layer(*loop_output) + torch_xla.sync() + + super().compareResults(loop_output, output, abs_err=0.0001, rel_err=0.001) + self.assert_different_tensor(loop_output[0], output[0]) + + loop_output[0].sum().backward() + torch_xla.sync() + + # Test that the gradients are the same too. + for layer_scan, layer_loop in zip(layers_for_scan, layers_for_loop): + assert layer_scan.linear.weight.grad is not None + assert layer_loop.linear.weight.grad is not None + assert layer_scan.linear.bias.grad is not None + assert layer_loop.linear.bias.grad is not None + super().compareResults( + layer_scan.linear.weight.grad, + layer_loop.linear.weight.grad, + abs_err=0.0001, + rel_err=0.001) + super().compareResults( + layer_scan.linear.bias.grad, + layer_loop.linear.bias.grad, + abs_err=0.0001, + rel_err=0.001) + self.assert_different_tensor(layer_scan.linear.weight.grad, + layer_loop.linear.weight.grad) + self.assert_different_tensor(layer_scan.linear.bias.grad, + layer_loop.linear.bias.grad) + + def test_decoder_model(self): + # Define a decoder model that composes the decoder model in the example, + # but adds the ability to run the layers with the `scan` operator. + class DecoderOnlyModelWithScan(torch.nn.Module): + + def __init__(self, **kwargs): + super(DecoderOnlyModelWithScan, self).__init__() + self.decoder = DecoderOnlyModel(**kwargs) + + @property + def layers(self) -> Iterable[torch.nn.Module]: + return self.decoder.layers + + def forward( + self, + input_ids: torch.Tensor, + ) -> torch.Tensor: + return self.decoder.forward(input_ids) + + def forward_scan( + self, + input_ids: torch.Tensor, + ) -> torch.Tensor: + inputs_embeds = self.decoder.embed_tokens(input_ids) + # embed positions + assert isinstance(inputs_embeds, torch.Tensor) + # decoder layers + hidden_states = scan_layers(self.decoder.layers, inputs_embeds) + hidden_states = self.decoder.norm(hidden_states) + # [B, S, H] -> [B, S, V] + return self.decoder.output(hidden_states) + + # Fix the random seed to avoid flakes. + with torch.random.fork_rng(): + with torch_xla.xm.fork_rng(): + torch.random.manual_seed(42) + torch_xla.xm.set_rng_state(42) + + # Make it smaller for fast model run and comparisons. + config = DecoderOnlyConfig( + hidden_size=128, intermediate_size=8 * 128, vocab_size=256) + model = DecoderOnlyModelWithScan(config=config).to(self.device) + batch_size = 2 + sequence_length = 8 + + # Generate random input_ids within the range of the vocabulary size + input_ids = torch.randint(0, config.vocab_size, + (batch_size, sequence_length)).to(self.device) + + loop_model = deepcopy(model) + scan_model = deepcopy(model) + torch_xla.sync(wait=True) + + # Run the loop-based model. + loop_output = loop_model(input_ids.clone()) + loop_output.sum().backward() + torch_xla.sync() + + # Run again, this time using `scan` + scan_output = scan_model.forward_scan(input_ids.clone()) + scan_output.sum().backward() + + # Before materializing the tensors, check that tensor HLO has `While` in it. + self.assert_while_found_in_hlo(scan_output) + for layer_scan in scan_model.layers: + for (name, param_scan) in layer_scan.named_parameters(): + if param_scan.grad is not None: + self.assert_while_found_in_hlo(param_scan.grad) + + torch_xla.sync() + + # Compare results + super().compareResults( + scan_output, loop_output, abs_err=0.0001, rel_err=0.0001) + + # Check gradients + checks = 0 + for layer_scan, layer_loop in zip(scan_model.layers, loop_model.layers): + for (name, + param_scan), (name2, + param_loop) in zip(layer_scan.named_parameters(), + layer_loop.named_parameters()): + assert name == name2 + # Either the parameter should have gradient in both, or it should not + # have gradient in both. + assert (param_scan.grad is not None) == (param_loop.grad is not None) + # Check gradients + if param_scan.grad is not None and param_loop.grad is not None: + # Check that they are not the same tensor + assert id(param_scan.grad) != id(param_loop.grad) + assert id(param_scan.grad.untyped_storage()) != id( + param_loop.grad.untyped_storage()) + super().compareResults( + param_scan.grad, param_loop.grad, abs_err=0.0001, rel_err=0.0001) + checks = checks + 1 + assert checks > 0 + + def test_heterogenous_layers(self): + layer1 = nn.Linear(128, 128).to(torch_xla.device()) + layer2 = nn.Sequential(nn.Linear(128, 128).to(torch_xla.device())) + with self.assertRaisesRegex(ValueError, "mismatched keys"): + scan_layers([layer1, layer2], + torch.zeros((128,), device=torch_xla.device())) + + def test_mismatched_shapes(self): + layer1 = nn.Linear(128, 128).to(torch_xla.device()) + layer2 = nn.Linear(128, 129).to(torch_xla.device()) + with self.assertRaisesRegex(ValueError, "Shape mismatch"): + scan_layers([layer1, layer2], + torch.zeros((128,), device=torch_xla.device())) + + +if __name__ == '__main__': + test = unittest.main() + sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/test_operations.py b/test/test_operations.py index 8d772a140f5..22e03c196b9 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -21,6 +21,7 @@ import itertools import math from numbers import Number +from functools import reduce import numpy import random import re @@ -2648,6 +2649,29 @@ def test_api(self): mapping = ctx.parameter_id_tensor_mapping() self.assertEqual(len(mapping), 2) + def test_get_parameters_scalar(self): + """Scalar tensors parameters may be shared in the HLO graph if their + numerical values are equal. `parameter_id_tensor_mapping` needs to handle + that appropriately. + """ + + device = torch_xla.device() + tensors = [] + for i in range(10): + # Add three copies of the same value. + tensors.append(torch.tensor(i, device=device)) + tensors.append(torch.tensor(i, device=device)) + tensors.append(torch.tensor(i, device=device)) + result = reduce(lambda a, b: a + b, tensors) + ctx = torch_xla._XLAC.lowering.LoweringContext() + ctx.build([result]) + mapping = ctx.parameter_id_tensor_mapping() + + import json + hlo_json = json.loads(ctx.hlo_json()) + num_parameters = len(hlo_json["hostProgramShape"]["parameters"]) + self.assertEqual(len(mapping), num_parameters) + class TestGeneric(test_utils.XlaTestCase): diff --git a/test/test_scan.py b/test/test_scan.py deleted file mode 100644 index 6926c01fb01..00000000000 --- a/test/test_scan.py +++ /dev/null @@ -1,107 +0,0 @@ -import sys -import unittest -import torch_xla -import torch -from torch_xla.experimental.scan import scan -from torch.utils._pytree import tree_map, tree_flatten, tree_iter - -from test_utils import XlaTestCase - - -def _loopy_scan(fn, init, xs): - """A simple scan implemented with for loops serving as reference - implementation.""" - carry = init - ys = [] - xs_len = len(next(iter(tree_iter(xs)))) - for i in range(xs_len): - carry, y = fn(carry, tree_map(lambda x: x[i], xs)) - ys.append(y) - ys = tree_map(lambda *x: torch.stack(x), *ys) - return carry, ys - - -class ScanTest(XlaTestCase): - - def setUp(self): - self.device = torch_xla.device() - - def compare_pytree(self, expected_pytree, actual_pytree): - flat_expected_pytree, expected_spec = tree_flatten(expected_pytree) - flat_actual_pytree, actual_spec = tree_flatten(actual_pytree) - assert expected_spec == actual_spec - super().compareResults(flat_expected_pytree, flat_actual_pytree) - - def run_test(self, step_fn, init, xs): - # Actual output - final_carry, ys = scan(step_fn, init, xs) - torch_xla.sync() - - # Expected output - expected_final_carry, expected_ys = _loopy_scan(step_fn, init, xs) - torch_xla.sync() - - # Compare - self.compare_pytree(expected_final_carry, final_carry) - self.compare_pytree(expected_ys, ys) - - return final_carry, ys - - def test_scan_forward_simple(self): - """This test uses `scan` to implement `torch.cumsum`.""" - - def step_fn(carry, x): - new_carry = carry + x - y = new_carry - return new_carry, y - - init = torch.tensor([0.0, 0.0], device=self.device) - xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], device=self.device) - final_carry, ys = self.run_test(step_fn, init, xs) - - # Also ensure that our loop-based scan is correct, with manual checks - # that replicate the step_fn. - expected_final_carry = torch.sum(xs, dim=0) + init - expected_ys = torch.cumsum(xs, dim=0) - self.compare_pytree(expected_final_carry, final_carry) - self.compare_pytree(expected_ys, ys) - - def test_scan_fn_not_callable(self): - init = torch.tensor([1.0, 1.0], device=self.device) - xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], device=self.device) - with self.assertRaises(ValueError): - scan(1000, init, xs) # type: ignore - - def test_scan_incompatible_length(self): - init = torch.tensor([1.0, 1.0], device=self.device) - xs_1 = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], - device=self.device) - xs_2 = torch.tensor([[1.0, 2.0], [3.0, 4.0]], device=self.device) - with self.assertRaises(ValueError): - scan(lambda a, b: (a, b), init, (xs_1, xs_2)) - - def test_scan_forward_tuples(self): - """Test scanning over the leading axis of a tuple of tensors simultaneously, - which is a simple PyTree.""" - - def step_fn(carry, x): - carry1, carry2 = carry - x1, x2 = x - new_carry1 = carry1 + x1.sum() - new_carry2 = carry2 + x2.sum() - y1 = x1 * 2 - y2 = x2 * 2 - return (new_carry1, new_carry2), (y1, y2) - - init = (torch.tensor([0.0], device=self.device), - torch.tensor([1.0, 2.0], device=self.device)) - - xs = (torch.tensor([[1.0, 2.0], [3.0, 4.0]], device=self.device), - torch.tensor([[5.0, 6.0, 7.0], [8.0, 9.0, 10.0]], device=self.device)) - - self.run_test(step_fn, init, xs) - - -if __name__ == '__main__': - test = unittest.main() - sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh index c29e6f42be5..648920b3258 100755 --- a/test/tpu/run_tests.sh +++ b/test/tpu/run_tests.sh @@ -25,8 +25,9 @@ XLA_PARAMETER_WRAPPING_THREADSHOLD=1 python test/spmd/test_spmd_parameter_wrappi python3 test/pjrt/test_dtypes.py python3 test/pjrt/test_dynamic_plugin_tpu.py python3 test/test_while_loop.py -python3 test/test_scan.py -python3 test/test_pallas.py -v +python3 test/scan/test_scan.py +python3 test/scan/test_scan_layers.py +python3 test/test_pallas.py python3 test/test_pallas_spmd.py python3 test/test_tpu_paged_attention_kernel.py python3 test/test_input_output_aliases.py diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 7e36a20c0e5..959bbfe2dd8 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1069,7 +1069,6 @@ class PyLoweringContext { // etc.) std::unordered_map GetParameterIdTensorMapping() { // Find parameters in the lowering - const std::vector& param_ids = lowering_ctx.GetParameterSequence(); const std::vector& device_data = lowering_ctx.GetParametersData(); @@ -1086,7 +1085,9 @@ class PyLoweringContext { at::ScalarType dtype = MaybeUpcastToHostTorchType(literal.shape().element_type()); at::Tensor input = MakeTensorFromXlaLiteral(literal, dtype); - results[param_ids[i]] = input; + std::optional param_id = lowering_ctx.GetParameterId(device_data[i]); + XLA_CHECK(param_id.has_value()); + results[param_id.value()] = input; } return results; } @@ -1109,12 +1110,13 @@ class PyLoweringContext { torch::lazy::BackendData::Handle handle = data->GetHandle(); // Linearly search parameters and compare opaque handles - const std::vector& param_ids = lowering_ctx.GetParameterSequence(); const std::vector& device_data = lowering_ctx.GetParametersData(); for (int i = 0; i < device_data.size(); ++i) { if (device_data[i]->GetHandle() == handle) { - return param_ids[i]; + std::optional param_id = lowering_ctx.GetParameterId(device_data[i]); + XLA_CHECK(param_id.has_value()); + return param_id.value(); } } return -1; diff --git a/torch_xla/csrc/lowering_context.cpp b/torch_xla/csrc/lowering_context.cpp index c104be7c438..c2db9b36309 100644 --- a/torch_xla/csrc/lowering_context.cpp +++ b/torch_xla/csrc/lowering_context.cpp @@ -3,6 +3,7 @@ #include #include +#include #include #include #include @@ -136,6 +137,16 @@ xla::XlaOp LoweringContext::GetParameter( return it->second.param; } +std::optional LoweringContext::GetParameterId( + const std::shared_ptr& data) const { + torch::lazy::BackendData::Handle handle = data->GetHandle(); + auto it = parameters_map_.find(handle); + if (it == parameters_map_.end()) { + return std::nullopt; + } + return it->second.index; +} + const std::vector& LoweringContext::GetParametersData() const { return parameters_; @@ -195,13 +206,14 @@ void LoweringContext::AssignOutputOp(const torch::lazy::Output& output, xla::XlaOp LoweringContext::GetOutputOp(const torch::lazy::Output& output) { auto it = emitted_outputs_.find(output); + if (it == emitted_outputs_.end()) { auto post_order = torch::lazy::Util::ComputePostOrder(output.node, &emit_status_); for (auto node : post_order) { LowerNode(node); } - // At this point the outpout better be present, otherwise there is an issue + // At this point the output better be present, otherwise there is an issue // with the lowering code. it = emitted_outputs_.find(output); XLA_CHECK(it != emitted_outputs_.end()) @@ -216,6 +228,7 @@ XlaOpVector LoweringContext::LowerNode(const torch::lazy::Node* node) { HloMetadataSetter meta_setter(this, node); const XlaNode* casted = dynamic_cast(node); + result_ops = casted->Lower(this); if (!casted->dynamic_dims().empty()) { xla::internal::XlaBuilderFriend builder_friend; diff --git a/torch_xla/csrc/lowering_context.h b/torch_xla/csrc/lowering_context.h index e645f959af0..3a36695e1c0 100644 --- a/torch_xla/csrc/lowering_context.h +++ b/torch_xla/csrc/lowering_context.h @@ -6,6 +6,7 @@ #include #include +#include #include #include #include @@ -52,6 +53,11 @@ class LoweringContext : public torch::lazy::LoweringContext { const std::shared_ptr& data, const std::unordered_set& dynamic_dims = {}); + // If a parameter associated with data has already been declared, returns its + // ID. Otherwise, returns `std::nullopt`. + std::optional GetParameterId( + const std::shared_ptr& data) const; + // Retrieves the vector holding all the tensors associated with the parameter // instructions which have been created. const std::vector& GetParametersData() const; diff --git a/torch_xla/experimental/pytreeify.py b/torch_xla/experimental/pytreeify.py new file mode 100644 index 00000000000..9fb0d282526 --- /dev/null +++ b/torch_xla/experimental/pytreeify.py @@ -0,0 +1,50 @@ +import torch.utils._pytree as pytree +from torch.autograd import Function + + +# Taken from https://github.com/pytorch/pytorch/issues/96337 +# +# The main purpose is to support autograd in the `scan` operator, which takes in +# PyTrees and outputs PyTrees. Builtin PyTorch autograd ignores tensors in +# non-trivial PyTrees such as dictionaries of tensors. This decorator adds +# arbitrary PyTree support by flattening the PyTree before handing to PyTorch and +# unflattening on the way back. +def pytreeify(cls): + assert issubclass(cls, Function) + + orig_fw = cls.forward + orig_bw = cls.backward + orig_apply = cls.apply + + def new_apply(*inp): + flat_inp, struct = pytree.tree_flatten(inp) + out_struct_holder = [] + flat_out = orig_apply(struct, out_struct_holder, *flat_inp) + assert flat_out is not None + assert len(out_struct_holder) == 1 + return pytree.tree_unflatten(flat_out, out_struct_holder[0]) + + def new_forward(ctx, struct, out_struct_holder, *flat_inp): + inp = pytree.tree_unflatten(flat_inp, struct) + out = orig_fw(ctx, *inp) + flat_out, out_struct = pytree.tree_flatten(out) + ctx._inp_struct = struct + ctx._out_struct = out_struct + out_struct_holder.append(out_struct) + return tuple(flat_out) + + def new_backward(ctx, *flat_grad_outputs): + grad_outputs = pytree.tree_unflatten(flat_grad_outputs, ctx._out_struct) + if not isinstance(grad_outputs, tuple): + grad_outputs = (grad_outputs,) + grad_inputs = orig_bw(ctx, *grad_outputs) + flat_grad_inputs, grad_inputs_struct = pytree.tree_flatten(grad_inputs) + if grad_inputs_struct != ctx._inp_struct: + raise RuntimeError("The backward generated an arg structure that doesn't " + "match the forward's input.") + return (None, None) + tuple(flat_grad_inputs) + + cls.apply = new_apply + cls.forward = new_forward + cls.backward = new_backward + return cls diff --git a/torch_xla/experimental/scan.py b/torch_xla/experimental/scan.py index 9008e03dbd9..2e0ced89927 100644 --- a/torch_xla/experimental/scan.py +++ b/torch_xla/experimental/scan.py @@ -2,12 +2,50 @@ Reference: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html +# High level design + +The implementation is factored into two layers: core and autograd. The core +layer focuses on the numerical scan operation without any gradient tracking, and +the autograd layer adds forward and backward support using the scan primitive in +core. + +## Core + +The `_scan_impl_flat` function implements the core logic of scan on flattened +tensors. It uses XLA's `While` op to iterate over the leading dimension of the +input tensors. The body of the `While` loop calls `fn` and updates the carry and +output tensors. + +The `_scan_impl_pytree` function adds PyTree support on top. It flattens the +input PyTrees, calls `_scan_impl_flat` to perform the scan on the flattened +tensors, and then unflattens the results. Because gradients are sometimes +`None`, it also hides any `None`s in PyTrees from `_scan_impl_flat`, +simplifying the latter's implementation. + +## Autograd + +The `value_and_grad_partitioned` function symbolically traces the user-provided +function `fn` to obtain the forward and backward computation graphs. It then +creates two functions, `forward` and `backward`, that can be used in the +`Scan.forward` and `Scan.backward` methods. + +The `scan` operator is implemented as a PyTorch autograd Function, `Scan`. +The `Scan.forward` method scans the forward graph over the inputs. +The `Scan.backward` method scans the backward graph over the gradients and +activations. """ -from typing import Callable, TypeVar +import itertools +from typing import Callable, Dict, Sequence, TypeVar, Tuple, List, Optional, overload import torch -from torch.utils._pytree import tree_map, tree_iter +import torch.autograd +from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten, tree_iter, PyTree +from functorch.compile import aot_function, make_boxed_func, default_partition # type: ignore + +import torch_xla +import torch_xla.core.xla_builder as xb +from torch_xla.experimental.pytreeify import pytreeify Carry = TypeVar('Carry') X = TypeVar('X') @@ -18,11 +56,13 @@ def scan( fn: Callable[[Carry, X], tuple[Carry, Y]], init: Carry, xs: X, + partition_fn=default_partition, + # TODO: consider exposing knobs to control the RNG seed used in each `fn` iteration. ) -> tuple[Carry, Y]: """Apply a function over leading dimension of tensors while carrying along state. - + This is similar to the JAX `jax.lax.scan` function found in [1]. - + You may use it to loop over the leading dimension of tensors efficiently. If `xs` is a single tensor, this function is roughly equal to the following Python code: @@ -33,33 +73,65 @@ def scan(fn, init, xs): carry, y = fn(carry, xs[i]) ys.append(y) return carry, torch.stack(ys, dim=0) - + In the general case, `Carry`, `X`, and `Y` can be arbitrary PyTrees. This function will iterate through the leading dimension of every leaf element of `xs` simultaneously, and pass a slice of those elements to `fn` as another PyTree. This means you may scan over multiple tensors and produce multiple output tensors at once. - - Args: - fn: a Python callable that accepts two PyTrees of tensors: the carry object and the - slices of `xs` along its leading dimension. It should return two PyTrees: the carry - object and the slices of the output. The returned carry object will be passed to - the next invocation of `fn`. + Notes: - init: the initial carry object passed to the first invocation of `fn`. + `fn` must be AOTAutograd traceable. That requires PyTorch to understand the operations + within. For example if you invoke a custom kernel inside `fn`, you need to register the + custom kernel. See https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html. + + Args: + fn: A Python callable that accepts two PyTrees of tensors: the carry object and the + slices of `xs` along its leading dimension. It should return two PyTrees: the carry + object and the slices of the output. The returned carry object will be passed to + the next invocation of `fn`. + + init: The initial carry object passed to the first invocation of `fn`. - xs: the input PyTree to scan over. If `xs` is a tensor, then `fn` will get slices along - the leading dimension (`xs[i]`). If `xs` is some other PyTree (e.g. tuple of - tensor), `fn` will get PyTrees of slices. In that case the leading dimension size - of the leaves in the PyTree must be the same. + xs: The input PyTree to scan over. If `xs` is a tensor, then `fn` will get slices along + the leading dimension (`xs[i]`). If `xs` is some other PyTree (e.g. tuple of + tensor), `fn` will get PyTrees of slices. In that case the leading dimension size + of the leaves in the PyTree must be the same. + + partition_fn: Since `scan` uses AOTAutograd to trace `fn`, you may override what + computation happen in the forward and backward passes by specifying different partition + functions. `default_partition` implies no activation checkpointing. You may specify + `functorch.compile.min_cut_rematerialization_partition` to use min-cut based + activation checkpointing. You may also write your own partitioner to insert any + custom logic such as host offloading of activations. Returns: - (carry, ys): A tuple where `carry` is the last carry object returned by `fn`, and `ys` is a PyTree with the same structure as `xs`, but where the leaves are formed by stacking the leaf outputs of `fn` respectively. This means if your `fn` returns `(carry, (y1, y2))` then this function will return `(carry, (torch.stack(all_y1), torch.stack(all_y2)))`. + + Example: + + >>> # Example of using `scan` to implement `torch.cumsum`. + >>> import torch_xla.runtime + >>> import torch + >>> from torch_xla.experimental.scan import scan + >>> + >>> def fn(carry, x): + >>> new_carry = carry + x + >>> y = new_carry + >>> return new_carry, y + >>> + >>> with torch_xla.runtime.xla_device(): + >>> init = torch.tensor([0.0, 0.0], requires_grad=True) + >>> xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], + >>> requires_grad=True) + >>> final_carry, ys = scan(fn, init, xs) + >>> torch_xla.sync() + >>> print(final_carry) # Should be [9.0, 12.0] + >>> print(ys) # Should be [[1.0, 2.0], [4.0, 6.0], [9.0, 12.0]] [1]: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html """ @@ -82,14 +154,506 @@ def scan(fn, init, xs): if xs_length is None: raise ValueError(f"`xs` {xs} is an empty PyTree.") - carry = init - ys = [] + forward, backward = value_and_grad_partitioned( + fn, init, xs, partition_fn=partition_fn) + carry, ys = Scan.apply(forward, backward, init, xs) # type: ignore + return carry, ys + + +def value_and_grad_partitioned( + fn: Callable[[Carry, X], tuple[Carry, Y]], + init: Carry, + xs: X, + partition_fn=default_partition) -> tuple[Callable, Callable]: + """ + Given a user `fn` to be scanned over the leading dimension of the input `xs` + PyTree and an initial carry object `init`, symbolically traces `fn` and + returns two functions, `forward` and `backward`, which wrap the forward and + backward graphs of `fn` and plumbs through intermediate activations. + Specifically, given + + `fn(carry, x) -> (new_carry, y)` + + this function will build and return + + `forward(carry, x) -> (new_carry, (y, activations))` + + `backward(grad_new_carry, (grad_y, activations)) -> (grad_carry, grad_x)` + + where `grad_y` is the gradient w.r.t `y`, and `grad_new_carry` is the gradient + w.r.t. `new_carry`. + + `activations` will always be a flat list of tensors. + + This is similar to the `value_and_grad` transform found in JAX, but additionally + partitions and returns separate forward/backward passes, so that we may later + use them in the `autograd.Function` implementation of `Scan`. + + Args: + fn: (Callable[[Carry, X], tuple[Carry, Y]]) A callable with signature + `fn(carry, x_t) -> (new_carry, y_t)`, representing the function to be scanned. + + init: (Carry) The initial carry object. + + xs: (X) A PyTree of inputs to be scanned over. + + partition_fn: An optional partitioning function used to partition fn into + forward and backward graphs. + + Returns: + A tuple of `(forward, backward)`, detailed in the docstring of this function. + """ + + # Make some fake tensors to trace the user function and obtain the + # forward and backward graphs. Note that the init/carry fake tensor + # always requires grad. That's because even if the user passed in some + # `init` that does not require grad, we still want gradients to flow + # through the `carry` from one iteration of the user function to the + # next. In summary, the `carry` argument used to trace a user function + # to get a correct backward pass always requires grad. + def make_fake_tensor(v: torch.Tensor, requires_grad=True) -> torch.Tensor: + return torch.empty_like( + v, dtype=v.dtype, device=v.device, requires_grad=requires_grad) + + fake_carry_pytree = tree_map(make_fake_tensor, init) + fake_x_pytree = tree_map( + lambda v: make_fake_tensor(v[0], requires_grad=v.requires_grad), xs) + + with torch.enable_grad(): + fw_compiler, get_fwd = _make_get_graph_compiler() + bw_compiler, get_bwd = _make_get_graph_compiler() + fn_compiled = aot_function( + fn, + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=partition_fn) + _, unflatten_bwd_out = tree_flatten_none((fake_carry_pytree, fake_x_pytree)) + out = fn_compiled(fake_carry_pytree, fake_x_pytree) + # How many outputs out of the fwd_graph is actually outputs of `fn`, and not + # intermediate activations. + num_out = len(list(tree_iter(out))) + # Capture the backward. + out, unflatten_fwd_out = tree_flatten_none(out) + torch.autograd.backward(out, tree_map(lambda v: torch.ones_like(v), out)) + + fwd_graph = get_fwd() + bwd_graph = get_bwd() + + def forward(carry, x): + flat_carry, _ = tree_flatten(carry) + flat_x, _ = tree_flatten(x) + out = fwd_graph(*flat_carry, *flat_x) + actual_out, activations = split(out, num_out) + carry, y = unflatten_fwd_out(actual_out) + y = (y, activations) + return carry, y + + def backward(carry, x): + grad_new_carry, _ = tree_flatten(carry) + (grad_y, activations) = x + grad_y, _ = tree_flatten_none(grad_y) + out = bwd_graph(*activations, *grad_new_carry, *grad_y) + grad_carry, grad_x = unflatten_bwd_out(out) + return grad_carry, grad_x + + return forward, backward + + +def _make_get_graph_compiler(): + """ + Creates a compiler that records the graph, and a getter + function to retrieve them. + """ + graph: List[Optional[torch.fx.GraphModule]] = [None] + + def forward_comp(fx_module: torch.fx.GraphModule, _): + assert graph[0] is None + graph[0] = fx_module + return make_boxed_func(fx_module) + + def get_graph(): + g = graph[0] + assert g is not None + return g + + return forward_comp, get_graph + + +@pytreeify +class Scan(torch.autograd.Function): + + @staticmethod + def forward(ctx, forward, backward, init, xs): + # Forward pass, save activations for backward + ctx._backward = backward + with torch.no_grad(): + carry, ys = _scan_impl_pytree(forward, init, xs) + ys, activations = ys + ctx.save_for_backward(*activations) + return carry, ys + + @staticmethod + def backward(ctx, grad_carry, grad_ys): + activations = ctx.saved_tensors + backward = ctx._backward + with torch.no_grad(): + # Reverse loop to propagate gradients from last iteration to first. + grad_init, grad_xs = _scan_impl_pytree( + backward, grad_carry, (grad_ys, activations), reverse=True) + return None, None, grad_init, grad_xs + + +def _scan_impl_pytree(fn, init, xs, reverse: bool = False): + """Forward logic of scan without gradient tracking. `fn` operates on + PyTrees. `init` and `xs` are also PyTrees. + + See the `Scan` class which implements an autograd `Function` and builds + autograd support on top of `_scan_impl`. + """ + flat_init, unflatten_carry = tree_flatten_none(init) + flat_xs, unflatten_xs = tree_flatten_none(xs) + unflatten_y: Callable[..., PyTree] = lambda _: () # Set by `flat_fn`. + + def flat_fn( + carry: Sequence[torch.Tensor], x: Sequence[torch.Tensor] + ) -> Tuple[Sequence[torch.Tensor], Sequence[torch.Tensor]]: + nonlocal unflatten_y + carry_pytree = unflatten_carry(carry) + x_pytree = unflatten_xs(x) + carry_pytree, y_pytree = fn(carry_pytree, x_pytree) + flat_carry, _ = tree_flatten_none(carry_pytree) + flat_y, unflatten_y = tree_flatten_none(y_pytree) + return flat_carry, flat_y + + flat_carry, flat_y = _scan_impl_flat( + flat_fn, flat_init, flat_xs, reverse=reverse) + return unflatten_carry(flat_carry), unflatten_y(flat_y) + - for i in range(xs_length): - carry, y = fn(carry, tree_map(lambda x: x[i], xs)) - ys.append(y) +def tree_flatten_none(pytree: PyTree): + """ + Flattens input `pytree`, and filters out any `None` leaf PyTree nodes. + Returns the flattened list, and an unflatten function and also adds back + the removed `None`s in their correct location. + """ + flat, spec = tree_flatten(pytree) + flat, add_none = _remove_none(flat) + + def unflatten(flat): + flat = add_none(flat) + return tree_unflatten(flat, spec) + + return flat, unflatten + + +def _remove_none(s: Sequence[Optional[torch.Tensor]]): + """ + Filters out `None` values from `s`. Returns the filtered sequence, + and another function that will add back the `None` values when given a + sequence of the same structure. + """ + filtered = [v for v in s if v is not None] + none_mask = [v is None for v in s] + + def add_back_nones(s_filtered): + res = [] + idx_filtered = 0 + for is_none in none_mask: + if is_none: + res.append(None) + else: + res.append(s_filtered[idx_filtered]) + idx_filtered += 1 + return res + + return filtered, add_back_nones + + +def dynamic_update_slice(ys: xb.Op, y: xb.Op, idx: xb.Op) -> xb.Op: + # See https://openxla.org/xla/operation_semantics#dynamicupdateslice. + y = y.broadcast([1]) + indices = [idx] + for _ in range(ys.shape().rank - 1): + indices.append(idx.zeros_like()) + return ys.dynamic_update_slice(y, indices) + + +def dynamic_slice(xs: xb.Op, idx: xb.Op) -> xb.Op: + indices = [idx] + for _ in range(xs.shape().rank - 1): + indices.append(idx.zeros_like()) + slice_shape = list(xs.shape().sizes) + slice_shape[0] = 1 + sliced = xs.dynamic_slice(indices, slice_shape) + shape = list(xs.shape().sizes) + shape = shape[1:] + return sliced.reshape(shape) + + +class Builder: + + def __init__(self, name: str): + self._builder = xb.create_builder(name) + self._params = [] + self._param_tensors = [] + + def add_param(self, val: torch.Tensor): + idx = len(self._params) + param = xb.mkparam(self._builder, idx, xb.tensor_shape(val)) + self._params.append(param) + self._param_tensors.append(val) + return idx + + def params(self) -> Tuple[xb.Op, ...]: + return tuple(self._params) + + def param_tensors(self) -> Tuple[torch.Tensor, ...]: + return tuple(self._param_tensors) + + def num_params(self) -> int: + return len(self._params) + + +def _scan_impl_flat(fn, + init: Sequence[torch.Tensor], + xs: Sequence[torch.Tensor], + reverse: bool = False): + """Forward logic of scan without gradient tracking. `fn` operates on + two flat list of tensors. `init` and `xs` are also flat lists of tensors. None + of the tensors will be `None`. + + See the `Scan` class which implements an autograd `Function` and builds + autograd support on top of `_scan_impl`. + + ## Handling of random numbers + + When `fn` generates random numbers (e.g. it uses a dropout layer), we need to + ensure that each iteration of `fn` within the scan yields different random + numbers, despite running the same HLO operations. JAX requires the user to + explicitly fork the RNG state and pass it to `fn`. In PyTorch, the RNG state + is an implicit global variable. Therefore, we take a slightly different + approach: + + - Identify usage of RNG state via `_get_tensors_xla_device_data_node`. + - Create N different copies of the RNG state contained in a tensor. + - While building the `While` op body, index into the RNG state tensor at the + current iteration and provide that seed value to `fn`. + + ## Handling of HLO parameters + + Let's say the user writes a `fn` like this: + + def fn(carry, x): + foo = torch.zeros(8) + return carry, x + foo + + `fn` will lower into an HLO computation like this: + + HloModule Fn, entry_computation_layout={ + (f32[8], f32[8], f32[8]) -> (f32[8], f32[8]) + } + + The HLO computation takes three parameters while `fn` takes two arguments. + That's because IR lowering does not distinguish if a leaf data tensor comes from + a function argument or from within the function. All data tensors are lowered + into HLO parameters. We'll call them "hoisted variables" or `hoisted_vars`, since + instead of baking the value of those tensors as literals in the HLO graph, + they are turned into additional parameters of the computation. + """ + carry_len = len(init) + xs_len = len(xs) + + # Abstractly trace and lower `fn`. + # Later we will include `fn_computation` within the while loop body. + def make_fake_tensor(v: torch.Tensor) -> torch.Tensor: + return torch.empty( + v.size(), dtype=v.dtype).to(device).requires_grad_(v.requires_grad) - # Combine the list of PyTrees into one PyTree, where the leaves are - # stacked into a new major axis. - ys = tree_map(lambda *x: torch.stack(x), *ys) + device = torch_xla.device() + fake_carry = tree_map(make_fake_tensor, init) + fake_x = tree_map(lambda v: make_fake_tensor(v[0]), xs) + fake_output_carry, fake_output_y = fn(fake_carry, fake_x) + + y_len = len(fake_output_y) + fn_outputs = fake_output_carry + fake_output_y + + fn_ctx = torch_xla._XLAC.lowering.LoweringContext() + fn_ctx.set_name_string("fn_ctx") + fn_ctx.build(list(fn_outputs)) + fn_hlo = fn_ctx.hlo() + fn_computation = xb.computation_from_module_proto("fn_computation", fn_hlo) + + # Figure out the shape of `ys` from the abstract tracing. + fn_carry_out, fn_y_out = split(fn_outputs, carry_len) + assert carry_len + y_len == len(fn_outputs) + fn_carry_shapes = [v.shape for v in fn_carry_out] + fn_y_shapes = [v.shape for v in fn_y_out] + for fn_carry_shape, init_leaf in zip(fn_carry_shapes, init): + assert fn_carry_shape == init_leaf.shape, f"`fn` must keep the `carry` shape unchanged. \ + Got {fn_carry_shape} but expected {init_leaf.shape}" + + builder = Builder('scan') + num_iters = next(iter(tree_iter(xs))).size(0) + ys = [ + torch.zeros((num_iters, *fn_y_shape), device=device) + for fn_y_shape in fn_y_shapes + ] + # Start the `curr_iter` loop variable at zero. + zero = torch.tensor(0, device=device) + builder.add_param(zero) + + # We are building a bigger XLA computation (the while loop) that calls + # a smaller computation (`fn_computation`). This is a mapping from + # `fn_computation` param ID to While computation param ID. + fn_param_id_to_while_param_id: Dict[int, int] = {} + + # Add carry and x. + for real, fake in ((init, fake_carry), (xs, fake_x)): + for val, fake_val in zip(real, fake): + idx = builder.add_param(val) + param_id = fn_ctx.tensor_parameter_id(fake_val) + if param_id != -1: + fn_param_id_to_while_param_id[param_id] = idx + + # Add the output as a param since our While computation consumes it, updates + # one slice, and returns the updated ys in each iteration. + for val in ys: + builder.add_param(val) + + # Detect hoisted variables. + hoisted_vars: Dict[int, torch.Tensor] = fn_ctx.parameter_id_tensor_mapping() + for v in itertools.chain(fake_carry, fake_x): + param_id = fn_ctx.tensor_parameter_id(v) + if param_id != -1: + del hoisted_vars[param_id] + + # Detect RNG seed usage within the scanned function within hoisted variables. + ids, i_values = torch_xla._XLAC._get_tensors_xla_device_data_node(fn_outputs) + seed_info_id = torch_xla._XLAC._get_seed_info_id() + seed_parameter_id = None + if seed_info_id in ids: + seed_idx = ids.index(seed_info_id) + seed_parameter_id = fn_ctx.tensor_parameter_id(i_values[seed_idx]) + assert seed_parameter_id != -1, "`fn` uses random seed, but random seed is not \ + a parameter to the traced HLO graph" + + # Replace the single seed value with a tensor of seeds, one per iteration. + seed_tensor = hoisted_vars[seed_parameter_id] + assert seed_tensor.dtype == torch.int64 + hoisted_vars[seed_parameter_id] = torch.randint( + 0, 2**62, (num_iters,), dtype=torch.int64, device=torch_xla.device()) + + # Add hoisted variables as While computation params as well, + # including the potentially updated seed tensor. + for param_id, tensor in hoisted_vars.items(): + idx = builder.add_param(tensor.to(torch_xla.device())) + fn_param_id_to_while_param_id[param_id] = idx + + # Since we are threading five objects through the body_fn: + # + # - curr_iter: the current loop iteration + # - carry: the scan state + # - xs: the flattened input pytree + # - ys: the flattened output of fn + # - hoisted_vars: tensors not provided as arguments to fn but still used by fn. + # + # We need to concatenate all into one big list prior to entering `body_fn` and + # `cond_fn`, and split them back which is easier to work with after that. This + # pair of `pack`, `unpack` functions is for that purpose. + T = TypeVar('T') + + def pack(curr_iter: T, carry: Sequence[T], xs: Sequence[T], ys: Sequence[T], + hoisted_vars: Sequence[T]) -> Tuple[T, ...]: + return tuple(itertools.chain((curr_iter,), carry, xs, ys, hoisted_vars)) + + def unpack(seq: Sequence[T]) -> Tuple[T, List[T], List[T], List[T], List[T]]: + curr_iter, carry, xs, ys, hoisted_vars = split( + list(seq), 1, carry_len, xs_len, y_len) + curr_iter = curr_iter[0] + return curr_iter, carry, xs, ys, hoisted_vars + + def replace_rng_seed(curr_iter: xb.Op, *while_params: xb.Op): + """Slices the pre-generated seed tensor for the current iteration.""" + if seed_parameter_id is None: + return while_params + idx = fn_param_id_to_while_param_id[seed_parameter_id] + replaced = list(while_params) + replaced[idx] = dynamic_slice(replaced[idx], curr_iter) + return replaced + + def call_fn_computation(*while_params: xb.Op) -> xb.Op: + # We need to order the tensors in increasing parameter ID order when + # passing them to `xb.Op.call`. + fn_inputs = [ + while_params[fn_param_id_to_while_param_id[i]] + for i in range(len(fn_param_id_to_while_param_id)) + ] + return xb.Op.call(fn_computation, fn_inputs) + + def cond_fn(curr_iter: xb.Op, *rest): + return curr_iter < xb.Op.scalar( + curr_iter.builder(), num_iters, dtype=xb.Type.S64) + + def body_fn(*while_params: xb.Op): + curr_iter, carry, xs, ys, hoisted_vars = unpack(while_params) + if reverse: + max_iter = xb.Op.scalar( + curr_iter.builder(), num_iters - 1, dtype=xb.Type.S64) + idx = max_iter - curr_iter + else: + idx = curr_iter + x = [dynamic_slice(v, idx) for v in xs] + result = call_fn_computation( + *replace_rng_seed(idx, curr_iter, *carry, *x, *ys, *hoisted_vars)) + for i in range(carry_len): + carry[i] = result.get_tuple_element(i) + for i in range(y_len): + y = result.get_tuple_element(i + carry_len) + ys[i] = dynamic_update_slice(ys[i], y, idx) + one = xb.Op.scalar(curr_iter.builder(), 1, dtype=xb.Type.S64) + return pack(curr_iter + one, carry, xs, ys, hoisted_vars) + + res = xb.Op.mkwhile(builder.params(), cond_fn, body_fn) + computation = res.build('scan') + outputs = torch_xla._XLAC._xla_user_computation('xla::scan', + builder.param_tensors(), + computation) + _curr_iter, carry, xs, ys, _hoisted_vars = unpack(outputs) return carry, ys + + +U = TypeVar('U') + + +@overload +def split(seq: List[U], *part_lengths: int) -> Tuple[List[U], ...]: + ... + + +@overload +def split(seq: Tuple[U, ...], *part_lengths: int) -> Tuple[Tuple[U, ...], ...]: + ... + + +def split(seq: Sequence[U], *part_lengths: int) -> Tuple[Sequence[U], ...]: + """Splits a sequence into subsequences with given lengths. + + Args: + seq: The sequence (list or tuple) to split. + *part_lengths: The lengths of the subsequences, except the last subsequence. + + Example: + + a, b, c = split((1, 2, 3, 4, 5), 2, 2) + # a == (1, 2), b == (3, 4), c == (5, ) + + Returns: + A tuple of subsequences (lists or tuples). + """ + parts = [] + start = 0 + for length in part_lengths: + parts.append(seq[start:start + length]) + start += length + parts.append(seq[start:]) + return tuple(parts) diff --git a/torch_xla/experimental/scan_layers.py b/torch_xla/experimental/scan_layers.py new file mode 100644 index 00000000000..9d0beb525a8 --- /dev/null +++ b/torch_xla/experimental/scan_layers.py @@ -0,0 +1,142 @@ +from typing import Iterable, Mapping, Sequence, Dict, Tuple + +import torch +import torch.nn as nn +from torch.utils._pytree import tree_map +from functorch.compile import default_partition + +from torch_xla.experimental.scan import scan + + +def scan_layers(layers: Iterable[torch.nn.Module], + input_data, + partition_fn=default_partition): + """Runs each layer in `layers` sequentially, starting with `input_data`. + + `input_data` is provided as input to the first layer in `layers`. The output of one + layer is provided as input to next layer. + + All modules in `layers` must have the same structure, and they must perform the same + calculations given the same model parameters and inputs. In practice, this means you + cannot use different dropout probabilities, parameter shapes, activation functions etc., + across the `layers`. + + Under these conditions, this function is equivalent to + + sequential = torch.nn.Sequential(*layers) + sequential(input_data) + + This function can be faster to compile since it reuses the XLA computation of the + first layer to perform the computation of all other layers. + + Args: + layers: (Iterable[torch.nn.Module]) A list of layers to run. + + input_data: The input to be given to the first layer from `layers`. + + partition_fn: The graph parition function passed to AOTAutograd. Since this function + uses AOTAutograd to trace `fn`, you may override what computation happen in the + forward and backward passes by specifying different partition functions. + `default_partition` implies no activation checkpointing. You may specify + `functorch.compile.min_cut_rematerialization_partition` to use min-cut based + activation checkpointing. You may also write your own partitioner to insert any custom + logic such as host offloading of activations. + + Returns: + The output of the last layer from `layers`. + + Example: + + >>> import torch_xla.runtime + >>> import torch + >>> import torch.nn as nn + >>> from torch_xla.experimental.scan_layers import scan_layers + >>> with torch_xla.runtime.xla_device(): + >>> layers = [nn.Linear(16, 16) for i in range(10)] + >>> input = torch.randn(16) + >>> output = scan_layers(layers, input) + >>> assert output.shape == (16,) # Output is the 10-th layer output + >>> print(output) # Some random numbers + """ + # Handle empty layers case. + try: + first_layer = next(iter(layers)) + except StopIteration: + return input_data + + # Extract and stack the parameters and buffers into pytrees. + params_and_buffers = [_extract_weights_and_buffers(layer) for layer in layers] + params_list = [p for p, _ in params_and_buffers] + buffers_list = [b for _, b in params_and_buffers] + + _ensure_same_structure(params_list) + _ensure_same_structure(buffers_list) + + stacked_params = tree_map(lambda *tensors: torch.stack(tensors, dim=0), + *params_list) + stacked_buffers = tree_map(lambda *tensors: torch.stack(tensors, dim=0), + *buffers_list) + + # Use the first layer as the example/template layer. + from copy import deepcopy + example_layer = deepcopy(first_layer) + + # Define the function to apply at each step + def one_layer(carry, params_buffers): + # Apply the current layer's weights and biases to the example layer, + # then run the resulting layer. + output = torch.func.functional_call( # type: ignore + example_layer, params_buffers, carry, strict=True) + return output, None + + stacked_params_buffers = (stacked_params, stacked_buffers) + final_carry, _ = scan( + one_layer, input_data, stacked_params_buffers, partition_fn=partition_fn) + + return final_carry + + +def _extract_weights_and_buffers( + module: nn.Module +) -> Tuple[Dict[str, torch.nn.Parameter], Dict[str, torch.Tensor]]: + """ + Extracts the parameters and buffers from a PyTorch module and + stores them in separate dictionaries. + """ + weights_dict = {name: param for name, param in module.named_parameters()} + buffers_dict = {name: buffer for name, buffer in module.named_buffers()} + return weights_dict, buffers_dict + + +def _ensure_same_structure(dicts: Sequence[Mapping[str, torch.Tensor]]): + """ + Verifies that all dictionaries in `dicts` have the same structure: + they have the same keys and all the values have the same shape. + """ + if not dicts: + return + + reference_keys = set(dicts[0].keys()) + reference_shapes = {key: dicts[0][key].shape for key in reference_keys} + + for idx, current_dict in enumerate(dicts[1:], start=1): + current_keys = set(current_dict.keys()) + + # Check if keys match + if current_keys != reference_keys: + missing_keys = reference_keys - current_keys + extra_keys = current_keys - reference_keys + error_message = f"Layer {idx} has mismatched keys." + if missing_keys: + error_message += f" Missing keys: {missing_keys}." + if extra_keys: + error_message += f" Extra keys: {extra_keys}." + raise ValueError(error_message) + + # Check if shapes match for each key + for key in reference_keys: + ref_shape = reference_shapes[key] + current_shape = current_dict[key].shape + if ref_shape != current_shape: + raise ValueError(f"Shape mismatch for '{key}' in layer {idx}: " + f"expected {ref_shape}, got {current_shape}.")