Address comments -- fix typos and variable names
wonjoolee95 committed Nov 7, 2023
1 parent a98bfb2 commit 1808663
Showing 2 changed files with 16 additions and 14 deletions.
6 changes: 3 additions & 3 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -1698,10 +1698,10 @@ void InitXlaModuleBindings(py::module m) {
         [](const at::Tensor& input, const py::list& tile_assignment,
            const py::list& group_assignment, const py::list& replication_groups,
            int sharding_type) {
-          c10::List<at::IntArrayRef> time_assignment_list =
+          c10::List<at::IntArrayRef> tile_assignment_list =
               c10::List<at::IntArrayRef>();
           for (auto t : tile_assignment) {
-            time_assignment_list.push_back(
+            tile_assignment_list.push_back(
                 at::IntArrayRef(t.cast<std::vector<int64_t>>()));
           }
 
@@ -1720,7 +1720,7 @@ void InitXlaModuleBindings(py::module m) {
           }
 
           xla_mark_sharding_dynamo_custom_op(
-              input, time_assignment_list, group_assignment_list,
+              input, tile_assignment_list, group_assignment_list,
               replication_groups_list, sharding_type);
         });
   m.def("_xla_clear_sharding", [](const at::Tensor& input) {
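For reference, a minimal Python-side sketch (not part of this commit) of how the corrected binding is reached: the flattened tuple returned by Mesh.get_op_sharding(..., flatten_opsharding=True) maps one-to-one onto the binding's tile_assignment, group_assignment, replication_groups, and sharding_type parameters. The 2x2 mesh, the tensor shape, and the availability of a 4-device SPMD runtime are illustrative assumptions.

import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs

# Assumed setup: four XLA devices arranged as a 2x2 logical mesh.
mesh = xs.Mesh(np.arange(4), (2, 2), ('x', 'y'))
t = torch.randn(8, 8).to(xm.xla_device())

# Flattened OpSharding pieces, in the order get_op_sharding returns them.
tile_assignment, group_assignment, replication_groups, sharding_type = \
    mesh.get_op_sharding((0, 1), flatten_opsharding=True)

# Invokes the pybind11 binding fixed above; tile_assignment is a list of
# int lists, which the C++ side converts to c10::List<at::IntArrayRef>.
torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op(
    t, tile_assignment, group_assignment, replication_groups, sharding_type)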
24 changes: 13 additions & 11 deletions torch_xla/experimental/xla_sharding.py
@@ -83,7 +83,7 @@ def get_axis_name_idx(self, name: str) -> int:
   @functools.lru_cache(maxsize=None)
   def get_op_sharding(self,
                       partition_spec: Tuple,
-                      flatten=False) -> torch_xla._XLAC.OpSharding:
+                      flatten_opsharding = False) -> torch_xla._XLAC.OpSharding:
     """
     Return the OpSharding for the given partition spec. This is an expensive
     operation as the mesh grows, so the value is cached for reuse.
@@ -107,7 +107,7 @@ def get_op_sharding(self,
         sharding_type, tile_assignment, len(partition_spec), replicate_dims)
 
     # If flatten = True, return the flattened version of OpSharding
-    if flatten:
+    if flatten_opsharding:
       return (tile_assignment.tolist(), group_assignment, replication_groups,
               int(sharding_type))
     else:
@@ -459,7 +459,7 @@ def _translate_named_partition_spec(mesh: Mesh, partition_spec: Tuple):
 def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor],
                   mesh: Mesh,
                   partition_spec: Tuple[Union[Tuple, int, str, None]],
-                  dynamo_custom_op: bool = False) -> XLAShardedTensor:
+                  use_dynamo_custom_op: bool = False) -> XLAShardedTensor:
   """
   Annotates the tensor provided with XLA partition spec. Internally,
   it annotates the corresponding XLATensor as sharded for the XLA SpmdPartitioner pass.
@@ -508,7 +508,7 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor],
   assert len(t.shape) == len(partition_spec), \
     f"Partition spec length ({len(partition_spec)}) should be equal to the input rank ({len(t.shape)})."
 
-  if dynamo_custom_op:
+  if use_dynamo_custom_op:
     tile_assignment, group_assignment, replication_groups, sharding_type = mesh.get_op_sharding(
         partition_spec, flatten=True)
 
@@ -517,19 +517,21 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor],
           t.global_tensor, tile_assignment, group_assignment,
           replication_groups, sharding_type)
       return t
-    torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op(t, tile_assignment,
-                                                        group_assignment,
-                                                        replication_groups,
-                                                        sharding_type)
-    return XLAShardedTensor(t)
+    else:
+      torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op(t, tile_assignment,
+                                                          group_assignment,
+                                                          replication_groups,
+                                                          sharding_type)
+      return XLAShardedTensor(t)
   else:
     op_sharding = mesh.get_op_sharding(partition_spec)
 
     if isinstance(t, XLAShardedTensor):
       torch_xla._XLAC._xla_mark_sharding(t.global_tensor, op_sharding)
       return t
-    torch_xla._XLAC._xla_mark_sharding(t, op_sharding)
-    return XLAShardedTensor(t)
+    else:
+      torch_xla._XLAC._xla_mark_sharding(t, op_sharding)
+      return XLAShardedTensor(t)
 
 
 def clear_sharding(t: Union[torch.Tensor, XLAShardedTensor]) -> torch.Tensor:
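For context, a minimal usage sketch of the renamed use_dynamo_custom_op keyword (not part of this commit); the mesh shape and tensor size are illustrative, and an SPMD-enabled multi-device runtime is assumed.

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs

xr.use_spmd()  # assumes the SPMD execution mode is available in this build

num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ('data', 'model'))

t = torch.randn(16, 128).to(xm.xla_device())
# Defaults to the eager path; True routes through the dynamo custom op.
xs.mark_sharding(t, mesh, (0, 1), use_dynamo_custom_op=True)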
