Move spmd out of experimental
yeounoh committed Nov 10, 2023
1 parent 2ac9dc0 commit 558a894
Showing 9 changed files with 11 additions and 11 deletions.
docs/spmd.md (6 changes: 3 additions & 3 deletions)
@@ -46,8 +46,8 @@ import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
-import torch_xla.experimental.spmd.xla_sharding as xs
-from torch_xla.experimental.spmd import Mesh
+import torch_xla.distributed.spmd.xla_sharding as xs
+from torch_xla.distributed.spmd import Mesh

# Enable XLA SPMD execution mode.
xr.use_spmd()
@@ -104,7 +104,7 @@ We abstract logical mesh with [Mesh API](https://github.com/pytorch/xla/blob/028

```python
import torch_xla.runtime as xr
-from torch_xla.experimental.spmd import Mesh
+from torch_xla.distributed.spmd import Mesh

# Assuming you are running on a TPU host that has 8 devices attached
num_devices = xr.global_runtime_device_count()
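Note: the second docs/spmd.md hunk is cut off after `num_devices = xr.global_runtime_device_count()`. For orientation, here is a minimal sketch of how the relocated modules are typically used together after this commit; the mesh shape, axis names, and tensor below are illustrative assumptions, not content from the diff.

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd.xla_sharding as xs
from torch_xla.distributed.spmd import Mesh

# Enable XLA SPMD execution mode.
xr.use_spmd()

# Assuming a TPU host with 8 attached devices (illustrative).
num_devices = xr.global_runtime_device_count()
mesh_shape = (num_devices, 1)              # 1D data-parallel mesh
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('data', 'model'))

# Shard dim 0 of an XLA tensor across mesh axis 0 ('data');
# dim 1 stays replicated.
t = torch.randn(16, 128, device=xm.xla_device())
xs.mark_sharding(t, mesh, (0, None))
```

The import lines mirror the new `torch_xla.distributed.spmd` paths introduced above; everything else follows the pattern already shown in docs/spmd.md.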
test/test_operations.py (2 changes: 1 addition & 1 deletion)
@@ -36,7 +36,7 @@
import torch_xla.debug.metrics as met
import torch_xla.debug.model_comparator as mc
import torch_xla.distributed.parallel_loader as pl
-import torch_xla.experimental.spmd.xla_sharding as xs
+import torch_xla.distributed.spmd.xla_sharding as xs
from torch_xla import runtime as xr
import torch_xla.test.test_utils as xtu
import torch_xla.utils.utils as xu
File renamed without changes.
File renamed without changes.
@@ -137,7 +137,7 @@ def sharding_spec(self):

@property
def sharding_type(self) -> 'ShardingType':
-from torch_xla.experimental.spmd import ShardingType
+from torch_xla.distributed.spmd import ShardingType
sharding_type = torch_xla._XLAC._get_xla_sharding_type(self.global_tensor)
return ShardingType(sharding_type)

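The hunk above touches the `sharding_type` property (apparently on `XLAShardedTensor`, given the neighbouring `sharding_spec` property and `self.global_tensor`), whose lazy import now targets the new package. A hedged sketch of how the property might be read, continuing the illustrative mesh example from the docs hunk; that `mark_sharding` returns an `XLAShardedTensor` wrapper is an assumption here, not stated in the diff.

```python
from torch_xla.distributed.spmd import ShardingType

# Assuming mark_sharding hands back an XLAShardedTensor wrapper,
# its sharding metadata can be inspected directly.
xt = xs.mark_sharding(t, mesh, (0, None))
print(xt.sharding_spec)   # HLO sharding annotation for the global tensor
print(xt.sharding_type)   # e.g. ShardingType.TILED for this layout
```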
File renamed without changes.
torch_xla/experimental/distributed_checkpoint/_helpers.py (4 changes: 2 additions & 2 deletions)
@@ -5,7 +5,7 @@
import dataclasses

import torch
-import torch_xla.experimental.spmd.xla_sharding as xs
+import torch_xla.distributed.spmd.xla_sharding as xs

from torch.distributed.checkpoint.planner import SavePlan
from typing import (
@@ -23,7 +23,7 @@
)
from torch.distributed.checkpoint.metadata import (MetadataIndex,
STATE_DICT_TYPE)
-from torch_xla.experimental.spmd import XLAShardedTensor, ShardingType
+from torch_xla.distributed.spmd import XLAShardedTensor, ShardingType
from torch.utils._pytree import tree_map

PATH_ITEM = Union[str, int]
torch_xla/experimental/xla_sharded_tensor.py (4 changes: 2 additions & 2 deletions)
@@ -4,7 +4,7 @@

warnings.warn(
"Importing from `torch_xla.experimental.xla_sharded_tensor` will be deprecated "
-"after 2.2 release. Please use `torch_xla.experimental.spmd` "
+"after 2.2 release. Please use `torch_xla.distributed.spmd` "
"instead.", DeprecationWarning, 2)

-from .spmd.xla_sharded_tensor import *
+from torch_xla.distributed.spmd.xla_sharded_tensor import *
torch_xla/experimental/xla_sharding.py (4 changes: 2 additions & 2 deletions)
@@ -4,7 +4,7 @@

warnings.warn(
"Importing from `torch_xla.experimental.xla_sharding` will be deprecated "
-"after 2.2 release. Please use `torch_xla.experimental.spmd` instead.",
+"after 2.2 release. Please use `torch_xla.distributed.spmd` instead.",
DeprecationWarning, 2)

-from .spmd.xla_sharding import *
+from torch_xla.distributed.spmd.xla_sharding import *
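Both shim modules above keep the old `torch_xla.experimental.*` import paths alive through the 2.2 release: they emit a `DeprecationWarning` and re-export everything from the new `torch_xla.distributed.spmd` location. A small sketch of what that means for downstream code (the warning capture is illustrative; the module and symbol names come from the diff):

```python
import warnings

# Old path: still importable, but warns and simply re-exports
# from torch_xla.distributed.spmd under the hood.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    import torch_xla.experimental.xla_sharding as xs_old  # deprecated shim
assert any(issubclass(w.category, DeprecationWarning) for w in caught)

# New path: import directly from the relocated package.
import torch_xla.distributed.spmd.xla_sharding as xs
from torch_xla.distributed.spmd import Mesh, XLAShardedTensor, ShardingType
```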
