Cherry-pick 2.1 release branch into xrt through 9/18 (#5607)
* Handle dynamo function without input (#5565) (#5577)

* Make cpu tensor on XLA dynamo backend a warning instead of error (#5549) (#5576)

* [author: jluntamazon] Adding more explicit HLO lowering control by exposing LoweringContext… (#5431) (#5580)

* Adding more explicit HLO lowering control by exposing LoweringContext (and utilities) to python for Neuron

* fixing linter issues

* fixing spacing

* apply comments and fix compilation errors

* add test for new apis

* fix linter

* update test

* update test

* modify test

* reverse back to GetIrValue()

* update test inputs with random numbers

* skip unittest because it only fails in CI

---------

Co-authored-by: aws-kingrj <[email protected]>
Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: seanlatias <[email protected]>

* fixing num_local_processes typo (#5573) (#5579)

Co-authored-by: aws-kingrj <[email protected]>

* Move where clear pending IR is called to avoid crash (#5552) (#5582)

* Move where clear pending IR is called to avoid crash

* fix CI

* fix CI and add some debugging messages

* Fix release branch and tag patterns for GitHub Actions (#5587) (#5590)

* Improve bernoulli rng-bit-generation memory footprint (#5581) (#5589)

* Allow downcasting RngUniform generation for Bernoulli

Co-authored-by: Yeounoh Chung <[email protected]>

* Enable xla:gpu autocast for bfloat16 if not restricted (#5570) (#5591)

* Enable autocast for XLA:GPU

* linter fix

* XLA autocast test for GPU and TPU

* linter fix

* Ensure that xla autocast is properly enabled for GPU and does not crash when torch cuda is not available.

* linter fix

* Add tests

* Support bf16

* linter fix

* exclude unsupported test cases

* increase GPU test timeout to 300

Co-authored-by: Yeounoh Chung <[email protected]>

* Cherry-pick: Don't trigger CI build on release tag push (#5595)

Copy of #5594 on release branch

* formatting

---------

Co-authored-by: JackCaoG <[email protected]>
Co-authored-by: Wonjoo Lee <[email protected]>
Co-authored-by: aws-kingrj <[email protected]>
Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: seanlatias <[email protected]>
Co-authored-by: Manfei <[email protected]>
Co-authored-by: Yeounoh Chung <[email protected]>
8 people authored Sep 19, 2023
1 parent 7c32c0f commit 7c261ae
Showing 14 changed files with 423 additions and 77 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/build_and_test.yml
@@ -6,7 +6,6 @@ on:
push:
branches:
- master
tags:
- r[0-9]+.[0-9]+
workflow_dispatch:

@@ -44,7 +43,7 @@ jobs:
with:
docker-image: ${{ needs.build.outputs.docker-image }}
runner: linux.8xlarge.nvidia.gpu
- timeout-minutes: 240
+ timeout-minutes: 300
disable-xrt: 1
secrets:
gcloud-service-key: ${{ secrets.GCLOUD_SERVICE_KEY }}
95 changes: 62 additions & 33 deletions test/dynamo/test_dynamo.py
@@ -12,6 +12,9 @@
import torch._dynamo as dynamo
import torchvision
import unittest
import warnings

torch_xla._XLAC._init_computation_client()

# Setup import folders.
xla_test_folder = os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0])))
@@ -58,36 +61,6 @@ def test_random_op_different_result_each_run(self):
self.assertFalse(torch.allclose(dynamo_res_2, dynamo_res_3))


class DynamErrorMessageTest(unittest.TestCase):

def test_cpu_tensor(self):
device = xm.xla_device()
input = torch.randn(4, 3, 224, 224)
input_xla = input.clone().to(device)
resnet18 = torchvision.models.resnet18()
resnet18.eval()
xla_resnet18 = torchvision.models.resnet18()
xla_resnet18.to(device)
xla_resnet18.eval()
dynamo_resnet18 = torch.compile(xla_resnet18, backend='openxla')
dynamo_resnet18_cpu = torch.compile(resnet18, backend='openxla')
# input on cpu and model weight on xla
with self.assertRaises(Exception) as context:
res = dynamo_resnet18(input)
self.assertTrue(
'found two different devices' in context.exception.__str__())
# input on xla and model weight on cpu
with self.assertRaises(Exception) as context:
res = dynamo_resnet18_cpu(input_xla)
self.assertTrue(
'found two different devices' in context.exception.__str__())
# input and model weight on cpu
with self.assertRaises(Exception) as context:
res = dynamo_resnet18_cpu(input)
self.assertTrue(
'please move all tensors to XLA device' in context.exception.__str__())


class DynamoInferenceBasicTest(unittest.TestCase):

@classmethod
@@ -123,6 +96,20 @@ def test_simple_model(self):
res_cpu_3 = self.fn_simple(x + y, y * 3)
self.assertTrue(torch.allclose(res_cpu_3, res_xla_dynamo_3.cpu()))

def test_fn_without_input(self):

def fn_without_input(device):
constant = 0.835
expanded = torch.full((4, 4), constant, device=device)
arange = torch.arange(16, device=device).reshape(4, 4)
return expanded + arange

device = xm.xla_device()
compiled_fn = torch.compile(fn_without_input, backend='openxla')
res_cpu = fn_without_input('cpu')
res_xla_dynamo = compiled_fn(device)
self.assertTrue(torch.allclose(res_cpu, res_xla_dynamo.cpu()))

def test_simple_model_with_in_place_ops(self):

class TestModel(nn.Module):
@@ -286,21 +273,21 @@ def fn_fallback(t):
xla_dynamo_res = dynamo_fn(t_xla)
self.assertTrue(torch.allclose(cpu_res, xla_dynamo_res.cpu()))
self.assertEqual(met.metric_data('CompileTime')[0], 3)
- self.assertEqual(met.metric_data('ExecuteTime')[0], 10)
+ self.assertEqual(met.metric_data('ExecuteTime')[0], 11)

# Second tracing
met.clear_counters()
xla_dynamo_res_2 = dynamo_fn(t_xla)
self.assertTrue(torch.allclose(cpu_res, xla_dynamo_res_2.cpu()))
self.assertEqual(met.metric_data('CompileTime')[0], 3)
- self.assertEqual(met.metric_data('ExecuteTime')[0], 12)
+ self.assertEqual(met.metric_data('ExecuteTime')[0], 13)

# Verify that dynamo can handle different inputs
xla_dynamo_res_3 = dynamo_fn(t_xla * 3)
cpu_res_3 = fn_fallback(t * 3)
self.assertTrue(torch.allclose(cpu_res_3, xla_dynamo_res_3.cpu()))
self.assertEqual(met.metric_data('CompileTime')[0], 4)
- self.assertEqual(met.metric_data('ExecuteTime')[0], 15)
+ self.assertEqual(met.metric_data('ExecuteTime')[0], 16)


class DynamoTrainingBasicTest(unittest.TestCase):
@@ -516,6 +503,48 @@ def test_resnet18(self):
met.metric_data('RunCachedGraphOutputData')[0], sample_count * 3)


class DynamErrorMessageTest(unittest.TestCase):

def test_mixed_cpu_tensor(self):
device = xm.xla_device()
input = torch.randn(4, 3, 224, 224)
input_xla = input.clone().to(device)
resnet18 = torchvision.models.resnet18()
resnet18.eval()
xla_resnet18 = torchvision.models.resnet18()
xla_resnet18.to(device)
xla_resnet18.eval()
dynamo_resnet18 = torch.compile(xla_resnet18, backend='openxla')
dynamo_resnet18_cpu = torch.compile(resnet18, backend='openxla')
# input on cpu and model weight on xla
with self.assertRaises(Exception) as context:
res = dynamo_resnet18(input)
self.assertTrue(
'found two different devices' in context.exception.__str__())
# input on xla and model weight on cpu
with self.assertRaises(Exception) as context:
res = dynamo_resnet18_cpu(input_xla)
self.assertTrue(
'found two different devices' in context.exception.__str__())

def test_all_cpu_tensor(self):
met.clear_all()
input = torch.randn(4, 3, 224, 224)
resnet18 = torchvision.models.resnet18()
resnet18.eval()
dynamo_resnet18_cpu = torch.compile(resnet18, backend='openxla')
# input and model weight on cpu
with warnings.catch_warnings(record=True) as w:
res = dynamo_resnet18_cpu(input)
# there should be 18 parameters + 1 input
self.assertGreater(len(w), 15)
self.assertIn('Found tensor with shape torch.Size', str(w[0].message))
# No XLA operation should happen except an empty mark_step. The partitioner should offload all
# ops to CPU.
self.assertEqual(len(met.counter_names()), 1)
self.assertIn('MarkStep', met.counter_names())


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
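
For reference, a minimal standalone sketch of the warning behavior that the new test_all_cpu_tensor exercises (#5549): compiling with backend='openxla' and passing only CPU tensors is expected to warn and fall back to CPU instead of raising. This is not part of the diff; it assumes a torch_xla 2.1 install so that the 'openxla' backend is available, and the nn.Linear model is purely illustrative.

import warnings

import torch
import torch.nn as nn
import torch_xla  # torch_xla must be installed for backend='openxla' (assumption)

model = nn.Linear(4, 4).eval()
compiled = torch.compile(model, backend='openxla')

with warnings.catch_warnings(record=True) as caught:
  warnings.simplefilter('always')
  out = compiled(torch.randn(2, 4))  # CPU input and CPU weights

# Each CPU tensor is expected to trigger a "Found tensor with shape ..." warning,
# and the whole graph should run on CPU rather than error out.
print(len(caught), out.shape)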
20 changes: 20 additions & 0 deletions test/spmd/test_xla_spmd_python_api_interaction.py
@@ -6,6 +6,7 @@
import torch_xla
import torch_xla.core.xla_model as xm
from torch_xla import runtime as xr
from torch_xla.amp import autocast

import test_xla_sharding_base

@@ -112,6 +113,25 @@ def test_runtime_spmd_api(self):
os.environ["XLA_USE_SPMD"] = "1"


class BasicAutocastAPITest(test_xla_sharding_base.XlaShardingTest):

@classmethod
def setUpClass(cls):
xr.use_spmd()
super().setUpClass()

@unittest.skipIf(xr.device_type() not in ['GPU', 'TPU'],
f"TPU/GPU autocast test.")
def test_xla_autocast_api(self):
device = xm.xla_device()
t1 = torch.ones([2, 3], device=device, dtype=torch.float32)
t2 = torch.ones([3, 2], device=device, dtype=torch.float32)
with autocast(device, dtype=torch.bfloat16):
t3 = torch.matmul(t1, t2)
expected_dtype = torch.bfloat16 if xr.is_bf16_supported() else torch.float16
self.assertTrue(t3.dtype == expected_dtype)


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
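
A minimal sketch of the torch_xla.amp.autocast usage that BasicAutocastAPITest covers, assuming a TPU or GPU device; the bfloat16 result dtype is only expected where bf16 is supported, otherwise float16 is assumed per the test above.

import torch
import torch_xla.core.xla_model as xm
from torch_xla.amp import autocast

device = xm.xla_device()
t1 = torch.ones(2, 3, device=device, dtype=torch.float32)
t2 = torch.ones(3, 2, device=device, dtype=torch.float32)
with autocast(device, dtype=torch.bfloat16):
  t3 = torch.matmul(t1, t2)  # matmul is an autocast-eligible op
print(t3.dtype)  # torch.bfloat16 where bf16 is supported, else torch.float16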
58 changes: 56 additions & 2 deletions test/test_autocast.py
@@ -7,7 +7,6 @@
sys.argv = [sys.argv[0]] + leftovers

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import collections
import unittest
@@ -152,6 +151,48 @@ def __init__(self, dev):
self.methods_bf16 = [("__matmul__", mat0_bf16 + mat1_fp32)]


class AutocastCudaTestExtraLists(object):

def __init__(self, dev):
super().__init__()
n = 8
dimsets = ((n, n, n), (n, n, n, n), (n, n, n, n, n))
conv_args_fp32 = [(torch.randn(dimset, dtype=torch.float32, device=dev),
torch.randn(dimset, dtype=torch.float32, device=dev))
for dimset in dimsets]

mat0_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
mat1_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
mat2_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
mat3_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)

pointwise0_fp32 = (torch.randn(n, dtype=torch.float32, device=dev),)

element0_fp32 = (torch.randn(1, dtype=torch.float32, device=dev),)

# This is currently not part of AutocastTestLists and excludes `relu`, `addbmm`
self.torch_bf16 = [
("conv1d", conv_args_fp32[0]),
("conv2d", conv_args_fp32[1]),
("conv3d", conv_args_fp32[2]),
("bmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32),
torch.randn((n, n, n), device=dev, dtype=torch.float32))),
("mm", mat0_fp32 + mat1_fp32),
("matmul", mat0_fp32 + mat1_fp32),
("baddbmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32),
torch.randn((n, n, n), device=dev, dtype=torch.float32),
torch.randn((n, n, n), device=dev, dtype=torch.float32))),
("addmm", mat1_fp32 + mat2_fp32 + mat3_fp32),
("conv_tbc", (torch.randn((10, 7, 3), device=dev, dtype=torch.float32),
torch.randn((5, 3, 5), device=dev, dtype=torch.float32),
torch.randn(5, device=dev, dtype=torch.float32), 0)),
("conv_transpose1d", conv_args_fp32[0]),
("conv_transpose2d", conv_args_fp32[1]),
("conv_transpose3d", conv_args_fp32[2]),
("prelu", pointwise0_fp32 + element0_fp32),
]


class AutocastCudaTestUnsupportedLists(object):

def __init__(self):
@@ -301,8 +342,10 @@ class TestAutocastCuda(TestAutocastBase):

def setUp(self):
super(TestAutocastCuda, self).setUp()
- self.is_autocast_enabled = torch.is_autocast_enabled
+ self.is_autocast_enabled = torch.is_autocast_xla_enabled
self.autocast_lists = AutocastTestLists(torch.device(xm.xla_device()))
self.autocast_lists_extra = AutocastCudaTestExtraLists(
torch.device(xm.xla_device()))
self.autocast_unsupported_lists = AutocastCudaTestUnsupportedLists()

def test_autocast_nn_fp16(self):
@@ -334,6 +377,17 @@ def test_autocast_torch_fp32(self):
self._run_autocast_outofplace(
op, args, torch.float32, add_kwargs=maybe_kwargs)

def test_autocast_torch_bf16(self):
bf16_test_list = [
tp for tp in getattr(self.autocast_lists_extra, 'torch_bf16')
]
for op_with_args in bf16_test_list:
op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
# Expects float16, following the torch GPU autocast policy:
# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/autocast_mode.cpp
self._run_autocast_outofplace(
op, args, torch.float16, add_kwargs=maybe_kwargs)

def test_autocast_torch_need_autocast_promote(self):
for op, args in self.get_autocast_list('torch_need_autocast_promote'):
self._run_autocast_outofplace(op, args, torch.float32)
22 changes: 22 additions & 0 deletions test/test_operations.py
@@ -2074,6 +2074,28 @@ def test_multi_init_xla_backend(self):
self.assertEqual(met.counter_value("RegisterXLAFunctions"), 1)


# Only fails in CI https://github.com/pytorch/xla/pull/5431
@unittest.skip
class TestLoweringContext(test_utils.XlaTestCase):

def test_api(self):
device = xm.xla_device()
a = torch.rand(10, device=device)
b = torch.rand(10, device=device)
xm.mark_step()

result = a + b

ctx = torch_xla._XLAC.lowering.LoweringContext()
ctx.build([result])
hlo = ctx.hlo()
hlo_text = ctx.hlo_text()
self.assertTrue('opcode: "parameter"' in hlo_text)
self.assertTrue('opcode: "add"' in hlo_text)
mapping = ctx.parameter_id_tensor_mapping()
self.assertEqual(len(mapping), 2)


class TestGeneric(test_utils.XlaTestCase):

def test_zeros_like_patch(self):
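
A hedged sketch of the Python LoweringContext bindings added in #5431, mirroring the (currently CI-skipped) TestLoweringContext above; it assumes a torch_xla build that exposes torch_xla._XLAC.lowering as shown in the diff.

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
a = torch.rand(10, device=device)
b = torch.rand(10, device=device)
xm.mark_step()  # flush pending IR so only the add below is traced
result = a + b

ctx = torch_xla._XLAC.lowering.LoweringContext()
ctx.build([result])  # lower the IR graph rooted at `result` to HLO
print(ctx.hlo_text())  # textual HLO; should contain "parameter" and "add" opcodes
mapping = ctx.parameter_id_tensor_mapping()  # parameter id -> backing tensor
print(len(mapping))  # expect 2 parameters (a and b)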
2 changes: 1 addition & 1 deletion torch_xla/_internal/pjrt.py
@@ -142,7 +142,7 @@ def run_multiprocess(fn: Callable[..., R],
num_processes = gpu.num_local_processes()
gpu.initialize_distributed_runtime(num_processes)
elif runtime.device_type() == 'NEURON':
- num_processes = neuron.num_local_devices()
+ num_processes = neuron.num_local_processes()
else:
num_processes = 1

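
For context on the one-line fix above, run_multiprocess is normally reached through xla_multiprocessing.spawn; a minimal launcher sketch, assuming PJRT is configured (e.g. PJRT_DEVICE set to NEURON or GPU) so the local process count is resolved by the helpers shown in the hunk.

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
  # Each spawned process gets its own XLA device.
  print(f'process {index} -> {xm.xla_device()}')


if __name__ == '__main__':
  # With nprocs unset, the runtime picks the local process count
  # (gpu/neuron num_local_processes(), or 1 for single-device backends).
  xmp.spawn(_mp_fn, args=())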