diff --git a/torch_xla/distributed/fsdp/utils.py b/torch_xla/distributed/fsdp/utils.py
index 01bfa5c2837..56615f6a7cf 100644
--- a/torch_xla/distributed/fsdp/utils.py
+++ b/torch_xla/distributed/fsdp/utils.py
@@ -64,13 +64,13 @@ class DummyReduceScatter:
   """A dummy op for debugging with the same output shape as reduce_scatter"""
 
   def __init__(self, shard_count):
-    assert shard_count == xm.xrt_world_size()
+    assert shard_count == xr.world_size()
     self.scale = 1.0
 
   def __call__(self, input, callback):
     full_size = input.size(0)
-    shard_size = full_size // xm.xrt_world_size()
-    begin = shard_size * xm.get_ordinal()
+    shard_size = full_size // xr.world_size()
+    begin = shard_size * xr.global_ordinal()
     end = begin + shard_size
     slices = [None] * input.dim()
     slices[0] = slice(begin, end)
diff --git a/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py b/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py
index 26b55dc4cf7..abddb3511ee 100644
--- a/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py
+++ b/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py
@@ -413,8 +413,8 @@ def __init__(
     # FSDP data parallelism with model parallelism (e.g. Megatron)
     self.sharding_groups = sharding_groups
     if sharding_groups is None:
-      self.rank = xm.get_ordinal()
-      self.world_size = xm.xrt_world_size()
+      self.rank = xr.global_ordinal()
+      self.world_size = xr.world_size()
     else:
       if sharding_rank is None or sharding_world_size is None:
         raise ValueError(
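
Note (not part of the patch): the diff swaps the deprecated xm.* topology helpers for their torch_xla.runtime counterparts. Below is a minimal sketch of the mapping, assuming a torch_xla build that ships the torch_xla.runtime module; the helper name shard_slice is hypothetical and only illustrates the same sharding arithmetic as DummyReduceScatter. Run it under an XLA multiprocess launcher to see a world size greater than one.

import torch_xla.runtime as xr  # new runtime API used by this patch


def shard_slice(full_size):
  # xr.world_size() replaces xm.xrt_world_size();
  # xr.global_ordinal() replaces xm.get_ordinal().
  shard_size = full_size // xr.world_size()
  begin = shard_size * xr.global_ordinal()
  # Slice of dim 0 owned by this process, as in DummyReduceScatter.__call__.
  return slice(begin, begin + shard_size)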