
R2.5 #8220

Closed
wants to merge 11 commits

2 changes: 1 addition & 1 deletion .github/workflows/build_and_test.yml
@@ -28,7 +28,7 @@ jobs:
- id: commit
name: Get latest torch commit
run: |
echo "torch_commit=$(git ls-remote https://github.com/pytorch/pytorch.git HEAD | awk '{print $1}')" >> "$GITHUB_OUTPUT"
echo "torch_commit=$(git ls-remote https://github.com/pytorch/pytorch.git refs/heads/release/2.5 | awk '{print $1}')" >> "$GITHUB_OUTPUT"

build-torch-xla:
name: "Build PyTorch/XLA"
1 change: 1 addition & 0 deletions .torch_pin
@@ -0,0 +1 @@
release/2.5
4 changes: 2 additions & 2 deletions WORKSPACE
@@ -50,7 +50,7 @@ new_local_repository(
# curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
# and update the sha256 with the result.

xla_hash = 'be7eef5742089e328152908b8662e83e34bf73c1'
xla_hash = '32ebd694c4d0442e241d76324ff1a721831366b4'

http_archive(
name = "xla",
@@ -139,4 +139,4 @@ xla_workspace0()
load("@tsl//third_party/gpus:cuda_configure.bzl", "cuda_configure")
cuda_configure(name = "local_config_cuda")
load("@tsl//third_party/nccl:nccl_configure.bzl", "nccl_configure")
nccl_configure(name = "local_config_nccl")
nccl_configure(name = "local_config_nccl")
26 changes: 22 additions & 4 deletions docs/spmd_advanced.md
@@ -8,10 +8,28 @@ PyTorch/XLA SPMD takes a single-device program, shards and executes it in parallel
```python
# MpDeviceLoader returns ParallelLoader.per_device_loader as iterator
train_loader = pl.MpDeviceLoader(
train_loader, # wraps PyTorch DataLoader
device,
# assume 4d input and we want to shard at the batch dimension.
input_sharding=xs.ShardingSpec(input_mesh, ('data', None, None, None)))
train_loader, # wraps PyTorch DataLoader
device,
# assume 4d input and we want to shard at the batch dimension.
input_sharding=xs.ShardingSpec(input_mesh, ('data', None, None, None)))
```

It is also possible to specify a different `input_sharding` for each element of the batch if the elements have different shapes:

```python
# if batch = next(train_loader) looks like
# {'x': <tensor of shape [s1, s2, s3, s4]>, 'y': <tensor of shape [s1, s2]>}

# MpDeviceLoader returns ParallelLoader.per_device_loader as iterator
train_loader = pl.MpDeviceLoader(
train_loader, # wraps PyTorch DataLoader
device,
# specify different sharding for each input of the batch.
input_sharding={
'x': xs.ShardingSpec(input_mesh, ('data', None, None, None)),
'y': xs.ShardingSpec(input_mesh, ('data', None))
}
)
```
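
For completeness, a minimal sketch of how an `input_mesh` like the one in these examples might be created. This is an illustrative assumption (a one-dimensional mesh whose single axis is named `data` and spans all devices), not part of this change:

```python
import numpy as np

import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

# One-dimensional mesh over all devices, with the axis named 'data' to match
# the partition specs above; xs.get_1d_mesh('data') is an equivalent shorthand.
num_devices = xr.global_runtime_device_count()
input_mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('data',))
```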

### Virtual Device Optimization
4 changes: 2 additions & 2 deletions infra/ansible/config/cuda_deps.yaml
@@ -3,15 +3,15 @@
cuda_deps:
# List all libcudnn8 versions with `apt list -a libcudnn8`
libcudnn:
"12.4": libcudnn-cuda-12=9.1.1.17-1
"12.4": libcudnn9-cuda-12=9.1.1.17-1
"12.3": libcudnn9-cuda-12=9.0.0.312-1
"12.1": libcudnn8=8.9.2.26-1+cuda12.1
"12.0": libcudnn8=8.8.0.121-1+cuda12.0
"11.8": libcudnn8=8.7.0.84-1+cuda11.8
"11.7": libcudnn8=8.5.0.96-1+cuda11.7
"11.2": libcudnn8=8.1.1.33-1+cuda11.2
libcudnn-dev:
"12.4": libcudnn-dev-cuda-12=9.1.1.17-1
"12.4": libcudnn9-dev-cuda-12=9.1.1.17-1
"12.3": libcudnn9-dev-cuda-12=9.0.0.312-1
"12.1": libcudnn8-dev=8.9.2.26-1+cuda12.1
"12.0": libcudnn8-dev=8.8.0.121-1+cuda12.0
4 changes: 2 additions & 2 deletions setup.py
@@ -64,10 +64,10 @@

base_dir = os.path.dirname(os.path.abspath(__file__))

_date = '20240801'
_date = '20240916'
_libtpu_version = f'0.1.dev{_date}'
_libtpu_storage_path = f'https://storage.googleapis.com/libtpu-nightly-releases/wheels/libtpu-nightly/libtpu_nightly-{_libtpu_version}+nightly-py3-none-any.whl'
_jax_version = f'0.4.32.dev{_date}'
_jax_version = f'0.4.33'


def _get_build_mode():
39 changes: 38 additions & 1 deletion test/pjrt/test_collective_ops_tpu.py
@@ -139,7 +139,7 @@ def test_all_to_all(self, pin_layout):
list(range(world_size))]])


@absltest.skipIf(lambda: tpu.num_logical_cores_per_chip() >= 2,
@absltest.skipIf(tpu.num_logical_cores_per_chip() >= 2,
"Dynamo not supported on TPU v2/v3")
class TestDistCollectiveOpsTpu(parameterized.TestCase):
"""Test for collective ops from torch.distributed"""
@@ -246,6 +246,32 @@ def callable(output, input):
assert 'xla::reduce_scatter_tensor' in met.counter_names()
return output.cpu()

@staticmethod
def _all_to_all_single(use_dynamo: bool):
met.clear_all()
dist.init_process_group("xla", init_method='xla://')
device = xm.xla_device()

def callable(output, input):
dist.all_to_all_single(output, input)
return output

# check https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/torch/distributed/distributed_c10d.py#L3880
# for input and output tensor example
tensor_in = torch.tensor(
[xr.local_ordinal()] * tpu.num_expected_global_devices(),
dtype=torch.float,
device=device)
tensor_out = torch.zeros_like(tensor_in)
f = torch.compile(callable, backend='openxla') if use_dynamo else callable
output = f(tensor_out, tensor_in)
torch_xla.sync()
if not use_dynamo:
assert 'xla::AllToAll' in met.counter_names()
else:
assert 'xla::all_to_all_single' in met.counter_names()
return output.cpu()

@parameterized.named_parameters(('dynamo', True), ('nondynamo', False))
def test_all_reduce(self, use_dynamo):
results = pjrt.run_multiprocess(self._all_reduce, use_dynamo=use_dynamo)
@@ -287,6 +313,17 @@ def test_reduce_scatter(self, use_dynamo):
for index, val in results.items():
torch.testing.assert_close(val, expected[index])

@parameterized.named_parameters(('dynamo', True), ('nondynamo', False))
def test_all_to_all_single(self, use_dynamo):
results = pjrt.run_multiprocess(
self._all_to_all_single, use_dynamo=use_dynamo)
expected = torch.arange(
tpu.num_expected_global_devices(), dtype=torch.float)
# Note: the AllToAll XLA op does not guarantee the ordering of the all_to_all
# result, so the received values may not be arranged in rank order.
for _, val in results.items():
self.assertTrue(torch.allclose(val.sort().values, expected.sort().values))
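
For readers unfamiliar with `all_to_all_single`, a minimal host-side sketch (illustrative only, not part of this change; the `world_size` value is hypothetical) of why each rank expects a permutation of `torch.arange(world_size)` when rank `r` contributes a tensor filled with `r`:

```python
import torch

world_size = 4  # hypothetical device count
# Rank r contributes a tensor of length world_size filled with r.
inputs = [torch.full((world_size,), float(r)) for r in range(world_size)]
# all_to_all_single semantics: chunk j of rank r's output comes from
# chunk r of rank j's input (here each chunk is a single element).
outputs = [
    torch.stack([inputs[j][r] for j in range(world_size)])
    for r in range(world_size)
]
for out in outputs:
  assert torch.equal(out.sort().values, torch.arange(world_size).float())
```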


if __name__ == '__main__':
absltest.main()
1 change: 1 addition & 0 deletions test/run_tests.sh
@@ -234,6 +234,7 @@ function run_xla_op_tests3 {
run_test "$CDIR/stablehlo/test_unbounded_dynamism.py"
run_test "$CDIR/quantized_ops/test_quantized_matmul.py"
run_test "$CDIR/quantized_ops/test_dot_general.py"
run_test "$CDIR/spmd/test_mp_input_sharding.py"
run_test "$CDIR/spmd/test_xla_sharding.py"
run_test "$CDIR/spmd/test_xla_sharding_hlo.py"
run_test "$CDIR/spmd/test_xla_virtual_device.py"
151 changes: 151 additions & 0 deletions test/spmd/test_mp_input_sharding.py
@@ -0,0 +1,151 @@
import sys
import numpy as np
import unittest

import torch
import torch_xla
from torch_xla import runtime as xr
import torch_xla.core.xla_model as xm
from torch_xla.distributed.spmd import Mesh
import torch_xla.distributed.spmd as xs
import torch_xla.distributed.parallel_loader as pl

xr.use_spmd()


class MpInputShardingTest(unittest.TestCase):

class fake_dataloader:

def __init__(self, batch, size=1):
self.batch = batch
self.batch_size = size
self.counter = 0

def __iter__(self):
return self

def __next__(self):
if self.counter < self.batch_size:
self.counter += 1
return self.batch
raise StopIteration

@unittest.skipUnless(xr.global_runtime_device_count() > 1,
"Multiple devices required for tupled partition spec")
def test_multiple_inputs(self):
device = xm.xla_device()
batch = {'x': torch.randn((16, 128)), 'y': torch.randn((16, 128, 128))}
train_loader = self.fake_dataloader(batch)
num_devices = xr.global_runtime_device_count()
mesh = xs.get_1d_mesh('x')

train_loader = pl.MpDeviceLoader(
train_loader,
device,
input_sharding={
'x': xs.ShardingSpec(mesh, ('x', None)),
'y': xs.ShardingSpec(mesh, ('x', None, None))
})
train_loader = iter(train_loader)
data = next(train_loader)
annotation_x = '{devices=[%d,1]%s}' % (num_devices, ','.join(
[str(i) for i in range(num_devices)]))
annotation_y = '{devices=[%d,1,1]%s}' % (num_devices, ','.join(
[str(i) for i in range(num_devices)]))
self.assertEqual(annotation_x,
torch_xla._XLAC._get_xla_sharding_spec(data['x']))
self.assertEqual(annotation_y,
torch_xla._XLAC._get_xla_sharding_spec(data['y']))

@unittest.skipUnless(xr.global_runtime_device_count() > 1,
"Multiple devices required for tupled partition spec")
def test_single_tensor(self):
device = xm.xla_device()
batch = torch.randn((16, 128))
train_loader = self.fake_dataloader(batch)
num_devices = xr.global_runtime_device_count()
mesh = xs.get_1d_mesh('x')

train_loader = pl.MpDeviceLoader(
train_loader, device, input_sharding=xs.ShardingSpec(mesh, ('x', None)))
train_loader = iter(train_loader)
data = next(train_loader)
annotation = '{devices=[%d,1]%s}' % (num_devices, ','.join(
[str(i) for i in range(num_devices)]))
self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(data))

@unittest.skipUnless(xr.global_runtime_device_count() > 1,
"Multiple devices required for tupled partition spec")
def test_error_single_tensor_with_input_sharding_dict(self):
device = xm.xla_device()
batch = torch.randn((16, 128))
train_loader = self.fake_dataloader(batch)
num_devices = xr.global_runtime_device_count()
mesh = xs.get_1d_mesh('x')

train_loader = pl.MpDeviceLoader(
train_loader,
device,
input_sharding={'x': xs.ShardingSpec(mesh, ('x', None))})
train_loader = iter(train_loader)
with self.assertRaises(ValueError):
data = next(train_loader)

@unittest.skipUnless(xr.global_runtime_device_count() > 1,
"Multiple devices required for tupled partition spec")
def test_input_sharding_none(self):
device = xm.xla_device()
batch = {'x': torch.randn((16, 128)), 'y': torch.randn((16, 128, 128))}
train_loader = self.fake_dataloader(batch)
num_devices = xr.global_runtime_device_count()

train_loader = pl.MpDeviceLoader(train_loader, device, input_sharding=None)
train_loader = iter(train_loader)
data = next(train_loader)
annotation = '{replicated}'
self.assertEqual(annotation,
torch_xla._XLAC._get_xla_sharding_spec(data['x']))
self.assertEqual(annotation,
torch_xla._XLAC._get_xla_sharding_spec(data['y']))

@unittest.skipUnless(xr.global_runtime_device_count() > 1,
"Multiple devices required for tupled partition spec")
def test_error_missing_keys(self):
device = xm.xla_device()
batch = {'x': torch.randn((16, 128)), 'y': torch.randn((16, 128, 128))}
train_loader = self.fake_dataloader(batch)
mesh = xs.get_1d_mesh('x')
train_loader = pl.MpDeviceLoader(
train_loader,
device,
input_sharding={'x': xs.ShardingSpec(mesh, ('x', None))})
train_loader = iter(train_loader)
with self.assertRaises(KeyError):
data = next(train_loader)

@unittest.skipUnless(xr.global_runtime_device_count() > 1,
"Multiple devices required for tupled partition spec")
def test_input_sharding_not_dict(self):
device = xm.xla_device()
num_devices = xr.global_runtime_device_count()
batch = {'x': torch.randn((16, 128)), 'y': torch.randn((16, 128))}
train_loader = self.fake_dataloader(batch)
mesh = xs.get_1d_mesh('x')
train_loader = pl.MpDeviceLoader(
train_loader, device, input_sharding=xs.ShardingSpec(mesh, ('x', None)))
train_loader = iter(train_loader)
data = next(train_loader)
annotation_x = '{devices=[%d,1]%s}' % (num_devices, ','.join(
[str(i) for i in range(num_devices)]))
annotation_y = '{devices=[%d,1]%s}' % (num_devices, ','.join(
[str(i) for i in range(num_devices)]))
self.assertEqual(annotation_x,
torch_xla._XLAC._get_xla_sharding_spec(data['x']))
self.assertEqual(annotation_y,
torch_xla._XLAC._get_xla_sharding_spec(data['y']))


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
29 changes: 29 additions & 0 deletions test/test_bf16_autocast.py
@@ -0,0 +1,29 @@
import os
import re
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import unittest

device = xm.xla_device()


class TestAutocastXla(unittest.TestCase):

def test_cross_entropy_loss(self):
data = torch.randn(16, 10).to(torch.bfloat16).to(device)
target = torch.randn(16, 10).to(torch.bfloat16).to(device)
with torch.autocast("xla"):
loss = torch.nn.CrossEntropyLoss()(data, target)
hlo = torch_xla._XLAC._get_xla_tensors_hlo([loss])
self.assertTrue(
re.search(rf".*convert.*f32.*convert.*bf16", hlo) is not None)

self.assertTrue(
re.search(rf".*exponential.*f32.*exponential.*f32", hlo) is not None)

self.assertTrue(re.search(rf".*log.*f32.*log.*f32", hlo) is not None)


if __name__ == "__main__":
unittest.main()