From 7c261aed8dedf93cc3d16150ec2f98af2688fe4c Mon Sep 17 00:00:00 2001
From: Will Cromar
Date: Tue, 19 Sep 2023 10:43:02 -0700
Subject: [PATCH] Cherry-pick 2.1 release branch into `xrt` through 9/18 (#5607)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Handle dynamo function without input (#5565) (#5577)
* Make cpu tensor on XLA dynamo backend a warning instead of error (#5549) (#5576)
* [author: jluntamazon] Adding more explicit HLO lowering control by exposing LoweringContext… (#5431) (#5580)
* Adding more explicit HLO lowering control by exposing LoweringContext (and utilities) to python for Neuron
* fixing linter issues
* fixing spacing
* apply comments and fix compilation errors
* add test for new apis
* fix linter
* update test
* update test
* modify test
* reverse back to GetIrValue()
* update test inputs with random numbers
* skip unittest because it only fails in CI

---------

Co-authored-by: aws-kingrj <78175353+aws-kingrj@users.noreply.github.com>
Co-authored-by: Ubuntu
Co-authored-by: seanlatias

* fixing num_local_processes typo (#5573) (#5579)

Co-authored-by: aws-kingrj <78175353+aws-kingrj@users.noreply.github.com>

* Move where clear pending IR is called to avoid crash (#5552) (#5582)
* Move where clear pending IR is called to avoid crash
* fix CI
* fix CI and add some debugging messages
* Fix release branch and tag patterns for GitHub Actions (#5587) (#5590)
* Improve bernoulli rng-bit-generation memory footprint (#5581) (#5589)
* Allow downcasting RngUniform generation for Bernoulli

Co-authored-by: Yeounoh Chung

* Enable xla:gpu autocast for bfloat16 if not restricted (#5570) (#5591)
* Enable autocast for XLA:GPU
* linter fix
* XLA autocast test for GPU and TPU
* linter fix
* Ensure that xla autocast is properly enabled for GPU and does not crash when torch cuda is not available.
* linter fix * Add tests * Support bf16 * linter fix * exclude unsupported test cases * increase GPU test timeout to 300 Co-authored-by: Yeounoh Chung * Cherry-pick: Don't trigger CI build on release tag push (#5595) Copy of #5594 on release branch * formatting --------- Co-authored-by: JackCaoG <59073027+JackCaoG@users.noreply.github.com> Co-authored-by: Wonjoo Lee Co-authored-by: aws-kingrj <78175353+aws-kingrj@users.noreply.github.com> Co-authored-by: Ubuntu Co-authored-by: seanlatias Co-authored-by: Manfei <41607353+ManfeiBai@users.noreply.github.com> Co-authored-by: Yeounoh Chung --- .github/workflows/build_and_test.yml | 3 +- test/dynamo/test_dynamo.py | 95 +++++++---- .../test_xla_spmd_python_api_interaction.py | 20 +++ test/test_autocast.py | 58 ++++++- test/test_operations.py | 22 +++ torch_xla/_internal/pjrt.py | 2 +- torch_xla/amp/autocast_mode.py | 61 +++++-- torch_xla/core/dynamo_bridge.py | 46 +++--- torch_xla/csrc/init_python_bindings.cpp | 151 ++++++++++++++++++ torch_xla/csrc/random.cpp | 18 ++- torch_xla/csrc/random.h | 8 +- torch_xla/csrc/xla_graph_executor.cpp | 3 + torch_xla/csrc/xla_lower_util.cpp | 3 +- torch_xla/runtime.py | 10 ++ 14 files changed, 423 insertions(+), 77 deletions(-) diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index df51a839d58..fdbfeff7825 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -6,7 +6,6 @@ on: push: branches: - master - tags: - r[0-9]+.[0-9]+ workflow_dispatch: @@ -44,7 +43,7 @@ jobs: with: docker-image: ${{ needs.build.outputs.docker-image }} runner: linux.8xlarge.nvidia.gpu - timeout-minutes: 240 + timeout-minutes: 300 disable-xrt: 1 secrets: gcloud-service-key: ${{ secrets.GCLOUD_SERVICE_KEY }} diff --git a/test/dynamo/test_dynamo.py b/test/dynamo/test_dynamo.py index 0d33b7002c9..7cf83c84eff 100644 --- a/test/dynamo/test_dynamo.py +++ b/test/dynamo/test_dynamo.py @@ -12,6 +12,9 @@ import torch._dynamo as dynamo import torchvision import unittest +import warnings + +torch_xla._XLAC._init_computation_client() # Setup import folders. 
xla_test_folder = os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0]))) @@ -58,36 +61,6 @@ def test_random_op_different_result_each_run(self): self.assertFalse(torch.allclose(dynamo_res_2, dynamo_res_3)) -class DynamErrorMessageTest(unittest.TestCase): - - def test_cpu_tensor(self): - device = xm.xla_device() - input = torch.randn(4, 3, 224, 224) - input_xla = input.clone().to(device) - resnet18 = torchvision.models.resnet18() - resnet18.eval() - xla_resnet18 = torchvision.models.resnet18() - xla_resnet18.to(device) - xla_resnet18.eval() - dynamo_resnet18 = torch.compile(xla_resnet18, backend='openxla') - dynamo_resnet18_cpu = torch.compile(resnet18, backend='openxla') - # input on cpu and model weight on xla - with self.assertRaises(Exception) as context: - res = dynamo_resnet18(input) - self.assertTrue( - 'found two different devices' in context.exception.__str__()) - # input on xla and model weight on cpu - with self.assertRaises(Exception) as context: - res = dynamo_resnet18_cpu(input_xla) - self.assertTrue( - 'found two different devices' in context.exception.__str__()) - # input and model weight on cpu - with self.assertRaises(Exception) as context: - res = dynamo_resnet18_cpu(input) - self.assertTrue( - 'please move all tensors to XLA device' in context.exception.__str__()) - - class DynamoInferenceBasicTest(unittest.TestCase): @classmethod @@ -123,6 +96,20 @@ def test_simple_model(self): res_cpu_3 = self.fn_simple(x + y, y * 3) self.assertTrue(torch.allclose(res_cpu_3, res_xla_dynamo_3.cpu())) + def test_fn_without_input(self): + + def fn_without_input(device): + constant = 0.835 + expanded = torch.full((4, 4), constant, device=device) + arange = torch.arange(16, device=device).reshape(4, 4) + return expanded + arange + + device = xm.xla_device() + compiled_fn = torch.compile(fn_without_input, backend='openxla') + res_cpu = fn_without_input('cpu') + res_xla_dynamo = compiled_fn(device) + self.assertTrue(torch.allclose(res_cpu, res_xla_dynamo.cpu())) + def test_simple_model_with_in_place_ops(self): class TestModel(nn.Module): @@ -286,21 +273,21 @@ def fn_fallback(t): xla_dynamo_res = dynamo_fn(t_xla) self.assertTrue(torch.allclose(cpu_res, xla_dynamo_res.cpu())) self.assertEqual(met.metric_data('CompileTime')[0], 3) - self.assertEqual(met.metric_data('ExecuteTime')[0], 10) + self.assertEqual(met.metric_data('ExecuteTime')[0], 11) # Second tracing met.clear_counters() xla_dynamo_res_2 = dynamo_fn(t_xla) self.assertTrue(torch.allclose(cpu_res, xla_dynamo_res_2.cpu())) self.assertEqual(met.metric_data('CompileTime')[0], 3) - self.assertEqual(met.metric_data('ExecuteTime')[0], 12) + self.assertEqual(met.metric_data('ExecuteTime')[0], 13) # Verify that dynamo can handle different inputs xla_dynamo_res_3 = dynamo_fn(t_xla * 3) cpu_res_3 = fn_fallback(t * 3) self.assertTrue(torch.allclose(cpu_res_3, xla_dynamo_res_3.cpu())) self.assertEqual(met.metric_data('CompileTime')[0], 4) - self.assertEqual(met.metric_data('ExecuteTime')[0], 15) + self.assertEqual(met.metric_data('ExecuteTime')[0], 16) class DynamoTrainingBasicTest(unittest.TestCase): @@ -516,6 +503,48 @@ def test_resnet18(self): met.metric_data('RunCachedGraphOutputData')[0], sample_count * 3) +class DynamErrorMessageTest(unittest.TestCase): + + def test_mixed_cpu_tensor(self): + device = xm.xla_device() + input = torch.randn(4, 3, 224, 224) + input_xla = input.clone().to(device) + resnet18 = torchvision.models.resnet18() + resnet18.eval() + xla_resnet18 = torchvision.models.resnet18() + xla_resnet18.to(device) + 
xla_resnet18.eval() + dynamo_resnet18 = torch.compile(xla_resnet18, backend='openxla') + dynamo_resnet18_cpu = torch.compile(resnet18, backend='openxla') + # input on cpu and model weight on xla + with self.assertRaises(Exception) as context: + res = dynamo_resnet18(input) + self.assertTrue( + 'found two different devices' in context.exception.__str__()) + # input on xla and model weight on cpu + with self.assertRaises(Exception) as context: + res = dynamo_resnet18_cpu(input_xla) + self.assertTrue( + 'found two different devices' in context.exception.__str__()) + + def test_all_cpu_tensor(self): + met.clear_all() + input = torch.randn(4, 3, 224, 224) + resnet18 = torchvision.models.resnet18() + resnet18.eval() + dynamo_resnet18_cpu = torch.compile(resnet18, backend='openxla') + # input and model weight on cpu + with warnings.catch_warnings(record=True) as w: + res = dynamo_resnet18_cpu(input) + # there should be 18 paramters + 1 input + self.assertGreater(len(w), 15) + self.assertIn('Found tensor with shape torch.Size', str(w[0].message)) + # no XLA operation should happens except a empty mark_step. Partitioner should offload all CPU + # ops to CPU. + self.assertEqual(len(met.counter_names()), 1) + self.assertIn('MarkStep', met.counter_names()) + + if __name__ == '__main__': test = unittest.main() sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/spmd/test_xla_spmd_python_api_interaction.py b/test/spmd/test_xla_spmd_python_api_interaction.py index 1d061f0d400..8ea4db3e051 100644 --- a/test/spmd/test_xla_spmd_python_api_interaction.py +++ b/test/spmd/test_xla_spmd_python_api_interaction.py @@ -6,6 +6,7 @@ import torch_xla import torch_xla.core.xla_model as xm from torch_xla import runtime as xr +from torch_xla.amp import autocast import test_xla_sharding_base @@ -112,6 +113,25 @@ def test_runtime_spmd_api(self): os.environ["XLA_USE_SPMD"] = "1" +class BasicAutocastAPITest(test_xla_sharding_base.XlaShardingTest): + + @classmethod + def setUpClass(cls): + xr.use_spmd() + super().setUpClass() + + @unittest.skipIf(xr.device_type() not in ['GPU', 'TPU'], + f"TPU/GPU autocast test.") + def test_xla_autocast_api(self): + device = xm.xla_device() + t1 = torch.ones([2, 3], device=device, dtype=torch.float32) + t2 = torch.ones([3, 2], device=device, dtype=torch.float32) + with autocast(device, dtype=torch.bfloat16): + t3 = torch.matmul(t1, t2) + expected_dtype = torch.bfloat16 if xr.is_bf16_supported() else torch.float16 + self.assertTrue(t3.dtype == expected_dtype) + + if __name__ == '__main__': test = unittest.main() sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/test_autocast.py b/test/test_autocast.py index 3c801068df8..837b6c623c8 100644 --- a/test/test_autocast.py +++ b/test/test_autocast.py @@ -7,7 +7,6 @@ sys.argv = [sys.argv[0]] + leftovers import torch -import torch_xla import torch_xla.core.xla_model as xm import collections import unittest @@ -152,6 +151,48 @@ def __init__(self, dev): self.methods_bf16 = [("__matmul__", mat0_bf16 + mat1_fp32)] +class AutocastCudaTestExtraLists(object): + + def __init__(self, dev): + super().__init__() + n = 8 + dimsets = ((n, n, n), (n, n, n, n), (n, n, n, n, n)) + conv_args_fp32 = [(torch.randn(dimset, dtype=torch.float32, device=dev), + torch.randn(dimset, dtype=torch.float32, device=dev)) + for dimset in dimsets] + + mat0_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),) + mat1_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),) + mat2_fp32 = (torch.randn((n, n), dtype=torch.float32, 
device=dev),) + mat3_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),) + + pointwise0_fp32 = (torch.randn(n, dtype=torch.float32, device=dev),) + + element0_fp32 = (torch.randn(1, dtype=torch.float32, device=dev),) + + # This is currently not part of AutocastTestLists and excludes `relu`, `addbmm` + self.torch_bf16 = [ + ("conv1d", conv_args_fp32[0]), + ("conv2d", conv_args_fp32[1]), + ("conv3d", conv_args_fp32[2]), + ("bmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32), + torch.randn((n, n, n), device=dev, dtype=torch.float32))), + ("mm", mat0_fp32 + mat1_fp32), + ("matmul", mat0_fp32 + mat1_fp32), + ("baddbmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32), + torch.randn((n, n, n), device=dev, dtype=torch.float32), + torch.randn((n, n, n), device=dev, dtype=torch.float32))), + ("addmm", mat1_fp32 + mat2_fp32 + mat3_fp32), + ("conv_tbc", (torch.randn((10, 7, 3), device=dev, dtype=torch.float32), + torch.randn((5, 3, 5), device=dev, dtype=torch.float32), + torch.randn(5, device=dev, dtype=torch.float32), 0)), + ("conv_transpose1d", conv_args_fp32[0]), + ("conv_transpose2d", conv_args_fp32[1]), + ("conv_transpose3d", conv_args_fp32[2]), + ("prelu", pointwise0_fp32 + element0_fp32), + ] + + class AutocastCudaTestUnsupportedLists(object): def __init__(self): @@ -301,8 +342,10 @@ class TestAutocastCuda(TestAutocastBase): def setUp(self): super(TestAutocastCuda, self).setUp() - self.is_autocast_enabled = torch.is_autocast_enabled + self.is_autocast_enabled = torch.is_autocast_xla_enabled self.autocast_lists = AutocastTestLists(torch.device(xm.xla_device())) + self.autocast_lists_extra = AutocastCudaTestExtraLists( + torch.device(xm.xla_device())) self.autocast_unsupported_lists = AutocastCudaTestUnsupportedLists() def test_autocast_nn_fp16(self): @@ -334,6 +377,17 @@ def test_autocast_torch_fp32(self): self._run_autocast_outofplace( op, args, torch.float32, add_kwargs=maybe_kwargs) + def test_autocast_torch_bf16(self): + bf16_test_list = [ + tp for tp in getattr(self.autocast_lists_extra, 'torch_bf16') + ] + for op_with_args in bf16_test_list: + op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args) + # Expects float16, following the torch GPU autocast policy: + # https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/autocast_mode.cpp + self._run_autocast_outofplace( + op, args, torch.float16, add_kwargs=maybe_kwargs) + def test_autocast_torch_need_autocast_promote(self): for op, args in self.get_autocast_list('torch_need_autocast_promote'): self._run_autocast_outofplace(op, args, torch.float32) diff --git a/test/test_operations.py b/test/test_operations.py index cd8eb8afd5d..d47f8f36138 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -2074,6 +2074,28 @@ def test_multi_init_xla_backend(self): self.assertEqual(met.counter_value("RegisterXLAFunctions"), 1) +# Only fails in CI https://github.com/pytorch/xla/pull/5431 +@unittest.skip +class TestLoweringContext(test_utils.XlaTestCase): + + def test_api(self): + device = xm.xla_device() + a = torch.rand(10, device=device) + b = torch.rand(10, device=device) + xm.mark_step() + + result = a + b + + ctx = torch_xla._XLAC.lowering.LoweringContext() + ctx.build([result]) + hlo = ctx.hlo() + hlo_text = ctx.hlo_text() + self.assertTrue('opcode: "parameter"' in hlo_text) + self.assertTrue('opcode: "add"' in hlo_text) + mapping = ctx.parameter_id_tensor_mapping() + self.assertEqual(len(mapping), 2) + + class TestGeneric(test_utils.XlaTestCase): def test_zeros_like_patch(self): diff --git 
a/torch_xla/_internal/pjrt.py b/torch_xla/_internal/pjrt.py index 8c0a5c8ee4b..9e7533955e4 100644 --- a/torch_xla/_internal/pjrt.py +++ b/torch_xla/_internal/pjrt.py @@ -142,7 +142,7 @@ def run_multiprocess(fn: Callable[..., R], num_processes = gpu.num_local_processes() gpu.initialize_distributed_runtime(num_processes) elif runtime.device_type() == 'NEURON': - num_processes = neuron.num_local_devices() + num_processes = neuron.num_local_processes() else: num_processes = 1 diff --git a/torch_xla/amp/autocast_mode.py b/torch_xla/amp/autocast_mode.py index eacc11008a8..222b5dee348 100644 --- a/torch_xla/amp/autocast_mode.py +++ b/torch_xla/amp/autocast_mode.py @@ -1,41 +1,84 @@ import torch import torch_xla.core.xla_model as xm +from torch_xla import runtime as xr from typing import Any class autocast(torch.amp.autocast_mode.autocast): r""" - See :class:`torch.autocast`. - ``torch_xla.amp.autocast(device, **kwargs)`` is equivalent to - ``torch.autocast("xla", **kwargs)`` for TPUs - ``torch.autocast("cuda", **kwargs)`` for GPUs. - """ + `torch.autocast` for XLA backend devices. See :class:`torch.autocast`. + ``torch_xla.amp.autocast(device, **kwargs)`` is equivalent to + ``torch.autocast("xla", **kwargs)`` for XLA:GPU and XLA:TPU for dtype torch.bfloat16, + ``torch.autocast("cuda", **kwargs)`` for XLA:GPU and other dtypes. + """ def __init__(self, device, enabled: bool = True, dtype: torch.dtype = None, cache_enabled: bool = True): - if xm.xla_device_hw(device) == 'GPU': + # `torch_xla.amp.autocast` is intended for XLA backend, with AutocastXLA dispatch key. + assert 'xla' in device.__str__( + ), "torch_xla.autocast is available for XLA:TPU, XLA:GPU" + + self._enabled = enabled + self._xla_device = xm.xla_device_hw(device) + if self._xla_device == 'GPU': + backend = 'cuda' if dtype is None: dtype = torch.float16 + elif dtype == torch.bfloat16: + if xr.is_bf16_supported() and not torch.cuda.is_available(): + # XLA:GPU with bfloat16 should run on `xla` backend + # unless torch.autocast is compiled with cuda. + backend = 'xla' + else: + # This has been the default behavior for unsupported bfloat16 dtype + dtype = torch.float16 + error_message = "In XLA:GPU autocast, but bfloat16 is not supported on this HW.\n" + error_message += ("Using the default cuda autocast dtype float16.") + self._dtype = dtype super().__init__( - "cuda", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled) - elif xm.xla_device_hw(device) == 'TPU': + backend, + enabled=enabled, + dtype=self._dtype, + cache_enabled=cache_enabled) + elif self._xla_device == 'TPU': if dtype is None: dtype = torch.bfloat16 + if dtype != torch.bfloat16: + error_message = "In XLA:TPU autocast, but the target dtype is not supported. Disabling autocast.\n" + error_message += ( + "TPU Autocast only supports dtype of torch.bfloat16 currently.") + warnings.warn(error_message) + enabled = False + self._dtype = dtype super().__init__( - "xla", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled) + "xla", + enabled=enabled, + dtype=self._dtype, + cache_enabled=cache_enabled) else: print( 'Warning: AMP only supported for XLA:TPU and XLA:GPU. Ignoring autocast.' ) def __enter__(self): + # This ensures that xla autocast is enabled even for XLA:GPU, which calls + # `torch.amp.autocast_mode.autocast` with `cuda` backend. 
+ if self._xla_device == 'GPU': + self.prev = torch.is_autocast_xla_enabled() # type: ignore[attr-defined] + self.prev_dtype = torch.get_autocast_xla_dtype( + ) # type: ignore[attr-defined] + torch.set_autocast_xla_enabled(self._enabled) + torch.set_autocast_xla_dtype(self._dtype) return super().__enter__() def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any): # type: ignore[override] + if self._xla_device == 'GPU': + torch.set_autocast_xla_enabled(self.prev) + torch.set_autocast_xla_dtype(self.prev_dtype) return super().__exit__(exc_type, exc_val, exc_tb) def __call__(self, func): diff --git a/torch_xla/core/dynamo_bridge.py b/torch_xla/core/dynamo_bridge.py index b2fcceb8155..b1d796ca04d 100644 --- a/torch_xla/core/dynamo_bridge.py +++ b/torch_xla/core/dynamo_bridge.py @@ -1,6 +1,7 @@ import copy import dataclasses import operator +import warnings import functools import itertools @@ -17,7 +18,7 @@ import torch_xla.runtime as xr import torch_xla.utils.utils as xu -debug = os.environ.get("TORCH_XLA_DEBUG") == "1" +debug = os.environ.get("XLA_DYNAMO_DEBUG") == "1" @dataclasses.dataclass @@ -321,6 +322,10 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule): def extract_internal(xla_model: torch.fx.GraphModule): + if debug: + for xla_arg in xla_model.xla_args: + print(torch_xla._XLAC._get_xla_tensor_debug_info(xla_arg)) + xm.mark_step() (xla_args_sharding_spec, args_and_out, graph_hash, arg_index_to_need_update_index, none_remover, graph_input_matcher, dumb_return_handler, xla_args_need_update) = extract_graph_helper(xla_model) @@ -368,7 +373,6 @@ def optimized_mod(*args): if len(args_and_out) == 0: return () - assert len(args) > 0 # can not handle no args case for now graph_input = graph_input_matcher(args) start_ts = time.time() res = torch_xla._XLAC._run_cached_graph(graph_hash, graph_input) @@ -457,8 +461,10 @@ def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args): for xla_arg in xla_args: if xla_arg.device.type != 'xla': - raise RuntimeError( - 'For openxla dynamo backend, please move all tensors to XLA device') + warnings.warn( + "Found tensor with shape " + str(xla_arg.size()) + " on " + + str(xla_arg.device) + + ". Please move all tensors to xla device to execute on XLA device.") cloned_args = [ torch.clone(xla_arg) if isinstance(xla_arg, torch.Tensor) else xla_arg @@ -469,6 +475,23 @@ def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args): collector = FallBackNodeCollector(xla_model) collector.run(*xla_args) fallback_ops = collector.get_fallback_ops() + if debug and len(fallback_ops) > 0: + print('fallback ops are' + str(fallback_ops)) + + # This logic, needed for supporting in-place operations, is a duplicate of + # the one in the main `extract_internal` function above. We need to do this + # check for fetching fallback ops as well. + # TODO (@wonjoo): Make this duplicate code a bit cleaner. + args_need_update_bool = torch_xla._XLAC._check_tensor_need_materialization( + all_xla_args) + + # Again, same logic in the `extract_internal` above to support in-place operations. + # TODO (@wonjoo): Make this duplicate code a bit cleaner. 
+ for i, need_update in enumerate(args_need_update_bool): + if need_update and isinstance(all_xla_args[i], torch.Tensor): + all_xla_args[i].copy_(cloned_args[i]) + + torch_xla._XLAC._clear_pending_irs(str(xm.xla_device())) class XlaOperatorSupport(torch.fx.passes.operator_support.OperatorSupport): @@ -491,21 +514,6 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: partitioned_graph = partitioner.fuse_partitions(partitions) InputCollector(partitioned_graph).run(*xla_args) - # This logic, needed for supporting in-place operations, is a duplicate of - # the one in the main `extract_internal` function above. We need to do this - # check for fetching fallback ops as well. - # TODO (@wonjoo): Make this duplicate code a bit cleaner. - args_need_update_bool = torch_xla._XLAC._check_tensor_need_materialization( - all_xla_args) - - # Again, same logic in the `extract_internal` above to support in-place operations. - # TODO (@wonjoo): Make this duplicate code a bit cleaner. - for i, need_update in enumerate(args_need_update_bool): - if need_update and isinstance(all_xla_args[i], torch.Tensor): - all_xla_args[i].copy_(cloned_args[i]) - - torch_xla._XLAC._clear_pending_irs(str(xm.xla_device())) - # compile each submodule and replace it with a call for node in partitioned_graph.graph.nodes: if node.op == "call_module" and "fused_" in node.name: diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index d206d52eb13..fa9459f336b 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include #include @@ -39,6 +40,7 @@ #include "torch_xla/csrc/ir.h" #include "torch_xla/csrc/ir_dump_util.h" #include "torch_xla/csrc/ops/device_data.h" +#include "torch_xla/csrc/ops/xla_ops.h" #include "torch_xla/csrc/runtime/mesh_service.h" #include "torch_xla/csrc/runtime/metrics.h" #include "torch_xla/csrc/runtime/metrics_analysis.h" @@ -762,6 +764,154 @@ void BuildProfilerSubmodule(py::module* m) { }); } +class PyLoweringContext { + public: + PyLoweringContext() : PyLoweringContext(GetCurrentDevice()) {} + + PyLoweringContext(torch::lazy::BackendDevice device) + : lowering_ctx("PyLoweringContext", device) {} + + // Builds a HLO graph given a set of output tensors. + void Build(std::vector tensors) { + // Get the backing XLA tensors from the output torch tensor handles + std::vector xtensors = + GetXlaTensors(tensors, /*want_all=*/true); + + // Get the lazy IR value from the output XLA tensors + std::vector ir_values; + for (auto& xtensor : xtensors) { + torch::lazy::Value value = xtensor->GetIrValue(); + ir_values.push_back(value); + } + + // Lower the graph using the output IR values + for (auto& ir_value : ir_values) { + xla::XlaOp root = lowering_ctx.GetOutputOp( + torch::lazy::Output(ir_value.node.get(), ir_value.index)); + lowering_ctx.AddResult(root); + } + computation = ConsumeValue(lowering_ctx.BuildXla()); + } + + // Get a mapping from the HLO input parameters to the backing Tensor values. + // This allows the caller to get all parameter information regardless of + // how the parameter was allocated (inline tensor, nn.Parameter, constant, + // etc.) 
+ std::unordered_map GetParameterIdTensorMapping() { + // Find parameters in the lowering + const std::vector& param_ids = lowering_ctx.GetParameterSequence(); + const std::vector& device_data = + lowering_ctx.GetParametersData(); + + // Fetch this parameter data + std::vector literals = + runtime::GetComputationClient()->TransferFromServer( + UnwrapXlaData(device_data)); + + // Create a mapping from paramater id to the tensor data + std::unordered_map results; + for (int i = 0; i < device_data.size(); ++i) { + xla::Literal& literal = literals[i]; + xla::XlaOp op = lowering_ctx.GetParameter(device_data[i]); + at::ScalarType dtype = + TensorTypeFromXlaType(literal.shape().element_type()); + at::Tensor input = MakeTensorFromXlaLiteral(literal, dtype); + results[param_ids[i]] = input; + } + return results; + } + + // Get the parameter identifier of a given tensor. If the tensor is not a + // parameter this will always return -1. This is useful in conjunction with + // GetParameterIdTensorMapping to identify which values can be baked into + // the graph and which values must remain parameters. + int64_t GetTensorParameterId(at::Tensor tensor) { + // Convert tensor into the backing lazy node + XLATensorPtr xtensor = bridge::GetXlaTensor(tensor); + torch::lazy::Value value = xtensor->GetIrValue(); + const torch::lazy::Node* node = value.node.get(); + if (node->op() != xla_device_data) { + return -1; + } + + // Convert lazy node data into opaque handle id + torch::lazy::BackendDataPtr data = DeviceData::Cast(node)->data(); + torch::lazy::BackendData::Handle handle = data->GetHandle(); + + // Linearly search parameters and compare opaque handles + const std::vector& param_ids = lowering_ctx.GetParameterSequence(); + const std::vector& device_data = + lowering_ctx.GetParametersData(); + for (int i = 0; i < device_data.size(); ++i) { + if (device_data[i]->GetHandle() == handle) { + return param_ids[i]; + } + } + return -1; + } + + // Create a serialized HloModule protobuf from a lowered graph + py::bytes GetHlo() { + const xla::HloModuleProto& proto = computation.proto(); + std::string result; + proto.SerializeToString(&result); + return result; + } + + // Create human-readable HloModule protobuf text from a lowered graph + std::string GetHloText() { + const xla::HloModuleProto& proto = computation.proto(); + std::string result; + google::protobuf::TextFormat::PrintToString(proto, &result); + return result; + } + + private: + LoweringContext lowering_ctx; + xla::XlaComputation computation; +}; + +// Add a submodule which exposes the LoweringContext to python. 
+void BuildLoweringContextSubmodule(py::module* m) { + /** + * Example Python Usage: + * + * import torch + * import torch_xla + * import torch_xla.core.xla_model as xm + * + * device = xm.xla_device() + * example = torch.tensor([1.0, 2.0, 3.0, 4.0], device=device) + * + * def network(x): + * return x + 2.0 + * + * result = network(example) + * + * ctx = torch_xla._XLAC.lowering.LoweringContext() + * ctx.build([result]) + * hlo = ctx.hlo() + * hlo_text = ctx.hlo_text() + * mapping = ctx.parameter_id_tensor_mapping() + * input_parameter_id = ctx.tensor_parameter_id(example) + * + **/ + + py::module lowering = + m->def_submodule("lowering", "Lowering context and utilities"); + + py::class_> + lowering_context_class(lowering, "LoweringContext", py::module_local()); + + lowering_context_class.def(py::init<>()) + .def("build", &PyLoweringContext::Build) + .def("hlo", &PyLoweringContext::GetHlo) + .def("hlo_text", &PyLoweringContext::GetHloText) + .def("parameter_id_tensor_mapping", + &PyLoweringContext::GetParameterIdTensorMapping) + .def("tensor_parameter_id", &PyLoweringContext::GetTensorParameterId); +} + void InitXlaModuleBindings(py::module m) { m.def("_prepare_to_exit", []() { PrepareToExit(); }); m.def("_get_git_revs", []() { return GetRevisions(); }); @@ -1716,6 +1866,7 @@ void InitXlaModuleBindings(py::module m) { }); BuildProfilerSubmodule(&m); + BuildLoweringContextSubmodule(&m); m.def("_get_tensors_handle", [](const std::vector& tensors) -> std::vector { diff --git a/torch_xla/csrc/random.cpp b/torch_xla/csrc/random.cpp index 357edc18e7e..b7ac89fa765 100644 --- a/torch_xla/csrc/random.cpp +++ b/torch_xla/csrc/random.cpp @@ -56,23 +56,25 @@ xla::XlaOp MakeSeed(xla::XlaOp seed) { return xla::ConvertElementType(seed, xla::PrimitiveType::U64); } -xla::XlaOp MakeUniformBoundaryValue(xla::XlaOp val) { +xla::XlaOp MakeUniformBoundaryValue(xla::XlaOp val, bool downcast = false) { xla::PrimitiveType element_type = XlaHelpers::TypeOfXlaOp(val); if (element_type == xla::PrimitiveType::BF16 || element_type == xla::PrimitiveType::F16) { - return xla::ConvertElementType(val, xla::PrimitiveType::F32); + auto dtype = downcast ? xla::PrimitiveType::F16 : xla::PrimitiveType::F32; + return xla::ConvertElementType(val, dtype); } else if (xla::primitive_util::IsComplexType(element_type)) { return xla::Real(val); } return val; } -xla::Shape MakeRngShape(const xla::Shape& shape) { +xla::Shape MakeRngShape(const xla::Shape& shape, bool downcast = false) { xla::PrimitiveType element_type = shape.element_type(); xla::Shape rng_shape(shape); if (element_type == xla::PrimitiveType::BF16 || element_type == xla::PrimitiveType::F16) { - rng_shape.set_element_type(xla::PrimitiveType::F32); + auto dtype = downcast ? 
xla::PrimitiveType::F16 : xla::PrimitiveType::F32; + rng_shape.set_element_type(dtype); } else if (xla::primitive_util::IsComplexType(element_type)) { rng_shape.set_element_type( xla::primitive_util::ComplexComponentType(element_type)); @@ -106,11 +108,11 @@ xla::XlaOp RngDiscreteUniform(xla::XlaOp seed, const xla::Shape& shape, } xla::XlaOp RngUniform(xla::XlaOp seed, const xla::Shape& shape, - xla::XlaOp minval, xla::XlaOp maxval) { + xla::XlaOp minval, xla::XlaOp maxval, bool downcast) { xla::XlaOp rng_seed = MakeSeed(seed); - xla::Shape rng_shape = MakeRngShape(shape); - xla::XlaOp rng_minval = MakeUniformBoundaryValue(minval); - xla::XlaOp rng_maxval = MakeUniformBoundaryValue(maxval); + xla::Shape rng_shape = MakeRngShape(shape, downcast); + xla::XlaOp rng_minval = MakeUniformBoundaryValue(minval, downcast); + xla::XlaOp rng_maxval = MakeUniformBoundaryValue(maxval, downcast); xla::XlaOp initial_state = xla::Zero(rng_seed.builder(), xla::PrimitiveType::U64); switch (shape.element_type()) { diff --git a/torch_xla/csrc/random.h b/torch_xla/csrc/random.h index 252e89ddaee..86a36936f25 100644 --- a/torch_xla/csrc/random.h +++ b/torch_xla/csrc/random.h @@ -5,8 +5,12 @@ namespace torch_xla { +// Set downcast to true if the caller knows the |maxval - minval| is appropriate +// for f16 dtype. We avoid computing the range on-the-fly since it incurs an XLA +// computation. xla::XlaOp RngUniform(xla::XlaOp seed, const xla::Shape& shape, - xla::XlaOp minval, xla::XlaOp maxval); + xla::XlaOp minval, xla::XlaOp maxval, + bool downcast = false); xla::XlaOp RngDiscreteUniform(xla::XlaOp seed, const xla::Shape& shape, xla::XlaOp minval, xla::XlaOp maxval); @@ -16,4 +20,4 @@ xla::XlaOp RngNormal(xla::XlaOp seed, const xla::Shape& shape, xla::XlaOp mean, } // namespace torch_xla -#endif // XLA_TORCH_XLA_CSRC_RANDOM_H_ \ No newline at end of file +#endif // XLA_TORCH_XLA_CSRC_RANDOM_H_ diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index 269cc023eba..899539389b8 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -496,6 +496,9 @@ void XLAGraphExecutor::ClearPendingIrs( runtime::GetComputationClient()->CreateDataPlaceholder( device.toString(), std::move(shape))); tensors[i]->data()->handle = handle; + TF_VLOG(4) << "Replacing the IR " << ir_value.node.get()->ToString() + << " of Tensor with ID " << tensors[i]->GetUniqueId() + << " with placeholder"; } tensors[i]->AssignIrValue(torch::lazy::Value()); tensors[i]->data()->view = nullptr; diff --git a/torch_xla/csrc/xla_lower_util.cpp b/torch_xla/csrc/xla_lower_util.cpp index 22c551282ac..ba853f9f5ed 100644 --- a/torch_xla/csrc/xla_lower_util.cpp +++ b/torch_xla/csrc/xla_lower_util.cpp @@ -476,7 +476,8 @@ xla::XlaOp BuildBernoulli(xla::XlaOp probability, xla::XlaOp seed, xla::Zero(probability.builder(), probability_shape.element_type()); xla::XlaOp one = xla::One(probability.builder(), probability_shape.element_type()); - xla::XlaOp noise = RngUniform(seed, probability_shape, zero, one); + xla::XlaOp noise = + RngUniform(seed, probability_shape, zero, one, /*downcast=*/true); return xla::ConvertElementType(xla::Lt(noise, probability), type); } diff --git a/torch_xla/runtime.py b/torch_xla/runtime.py index 7805a5580a9..9a1bd58e34f 100644 --- a/torch_xla/runtime.py +++ b/torch_xla/runtime.py @@ -88,6 +88,16 @@ def wrapper(*args, **kwargs): return wrapper +def is_bf16_supported(): + """Returns whether torch.bfloat16 is supported on this environment. 
+ """ + try: + torch.tensor([1.], dtype=torch.bfloat16, device=xm.xla_device()) + return True + except Exception as e: + return False + + @requires_pjrt def xla_device(n: Optional[int] = None, devkind: Optional[str] = None) -> torch.device: