Commit 656f944

fix linter
vanbasten23 committed Jan 16, 2024
1 parent b98ef93 commit 656f944
Showing 5 changed files with 22 additions and 12 deletions.
9 changes: 6 additions & 3 deletions test/cpp/test_replication.cpp
@@ -46,14 +46,17 @@ void TestSingleReplication(
     instances.emplace_back(CreateCrsComputation(shape), device_str,
                            all_device_strings, &shape);
   }
-  std::vector<torch_xla::runtime::ComputationClient::ComputationPtr> compiled_computations =
-      torch_xla::runtime::GetComputationClient()->Compile(std::move(instances));
+  std::vector<torch_xla::runtime::ComputationClient::ComputationPtr>
+      compiled_computations =
+          torch_xla::runtime::GetComputationClient()->Compile(
+              std::move(instances));

   std::vector<at::Tensor> tensors;
   for (size_t i = 0; i < device_strings.size(); ++i) {
     tensors.push_back(at::ones({8, 8}, at::TensorOptions(at::kFloat)));
   }
-  std::vector<torch::lazy::BackendDataPtr> tensors_data = CreateTensorsData(tensors, device_strings);
+  std::vector<torch::lazy::BackendDataPtr> tensors_data =
+      CreateTensorsData(tensors, device_strings);

   std::vector<std::vector<torch_xla::runtime::ComputationClient::DataPtr>>
       results(device_strings.size());
2 changes: 0 additions & 2 deletions test/pjrt/test_runtime.py
@@ -71,13 +71,11 @@ def test_num_local_devices(self):
                    xr.addressable_device_count())
     if xr.device_type() == 'CUDA':
       self.assertEqual(self.num_cuda_devices, xr.addressable_device_count())
-

   def test_num_global_devices(self):
     self.assertLen(torch_xla._XLAC._xla_get_all_devices(),
                    xr.global_device_count())
     if xr.device_type() == 'CUDA':
-      print('test_num_global_devices is run for cuda. self.num_cuda_devices=', self.num_cuda_devices)
       self.assertEqual(self.num_cuda_devices, xr.global_device_count())

   def test_world_size(self):
4 changes: 4 additions & 0 deletions test/pjrt/test_torchrun.py
@@ -15,6 +15,10 @@ def setUp(self):

   def tearDown(self) -> None:
     dist.destroy_process_group()
+
+  def test_addressable_device_count(self):
+    devices_per_process = xr.addressable_device_count()
+    self.assertEqual(devices_per_process, 1)

   def test_all_gather(self):
     dist_world_size = xu.getenv_as('WORLD_SIZE', int)
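
For context, the new test_addressable_device_count encodes the expectation that, under torchrun, each spawned process sees exactly one addressable XLA device. A minimal standalone sketch of the same check (the helper name is hypothetical; it assumes torch_xla.runtime is importable as xr, as in the test file):

# Hedged sketch, not part of the commit: mirrors what the new test asserts.
import torch_xla.runtime as xr

def assert_one_device_per_process():
  # Under torchrun, every process should own exactly one addressable device.
  assert xr.addressable_device_count() == 1
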
14 changes: 10 additions & 4 deletions test/test_operations.py
@@ -223,8 +223,11 @@ def forward(self, x):
     x = self.fc2(x)
     return F.log_softmax(x, dim=1)

-@unittest.skipIf(xr.device_type() == 'CUDA',
-                 'Parallelism for DataParallel uses multi-threads. But cuda assumes one GPU device per process instead of relying on threads.')
+
+@unittest.skipIf(
+    xr.device_type() == 'CUDA',
+    'Parallelism for DataParallel uses multi-threads. But cuda assumes one GPU device per process instead of relying on threads.'
+)
 class TestParallelTensorMNIST(test_utils.XlaTestCase):

   def test(self):
@@ -255,8 +258,11 @@ def loop_fn(model, loader, device, context):
     model_parallel = dp.DataParallel(XlaMNIST, device_ids=devices)
     model_parallel(loop_fn, train_loader)

-@unittest.skipIf(xr.device_type() == 'CUDA',
-                 'Parallelism for DataParallel uses multi-threads. But cuda assumes one GPU device per process instead of relying on threads.')
+
+@unittest.skipIf(
+    xr.device_type() == 'CUDA',
+    'Parallelism for DataParallel uses multi-threads. But cuda assumes one GPU device per process instead of relying on threads.'
+)
 class TestParallelTensorResnet18(test_utils.XlaTestCase):

   def test(self):
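
The skip reason in these two hunks carries the rationale: dp.DataParallel drives devices from multiple threads inside one process, while the CUDA path here assumes one GPU device per process. A minimal, hypothetical illustration of the same class-level guard (class and test names are illustrative, not from the commit):

import unittest

import torch_xla.runtime as xr


@unittest.skipIf(xr.device_type() == 'CUDA',
                 'thread-based DataParallel conflicts with one GPU device per process on CUDA')
class ExampleThreadedXlaTest(unittest.TestCase):

  def test_placeholder(self):
    # Runs on non-CUDA backends; the whole class is skipped on CUDA.
    self.assertTrue(True)
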
5 changes: 2 additions & 3 deletions torch_xla/csrc/runtime/pjrt_registry.cc
@@ -101,10 +101,9 @@ InitializePjRt(const std::string& device_type) {
     bool spmd = sys_util::GetEnvBool("XLA_USE_SPMD", false);
     if (!spmd) {
       allowed_devices = std::set{local_process_rank};
-    }
-    else if (global_world_size > 1) {
+    } else if (global_world_size > 1) {
       allowed_devices =
-        std::make_optional<std::set<int>>(std::set{local_process_rank});
+          std::make_optional<std::set<int>>(std::set{local_process_rank});
       // Use the XlaCoordinator as the distributed key-value store.
       coordinator = std::make_unique<XlaCoordinator>(
           global_process_rank, global_world_size, master_addr, port);
