fix: missing import numpy (#5533)
* missing

* fix typo

---------

Co-authored-by: mochen.bmc <[email protected]>
wbmc and mochen.bmc authored Sep 11, 2023
1 parent e51d28b commit b9d567a
Showing 1 changed file with 2 additions and 1 deletion.
docs/spmd.md (3 changes: 2 additions & 1 deletion)
@@ -31,7 +31,7 @@ Also, this version of the SPMD is currently only tested.optimized on Google Clou
## PyTorch/XLA SPMD Design Overview


-### Simple Eexample & Sharding Aannotation API
+### Simple Example & Sharding Aannotation API

Users can annotate native PyTorch tensors using the `mark_sharding` API ([src](https://github.com/pytorch/xla/blob/9a5fdf3920c18275cf7dba785193636f1b39ced9/torch_xla/experimental/xla_sharding.py#L388)). This takes `torch.Tensor` as input and returns a `XLAShardedTensor` as output.

@@ -42,6 +42,7 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh, partitio
Invoking `mark_sharding` API takes a user defined logical [mesh](#mesh) and [partition\_spec](#partition-spec) and generates a sharding annotation for the XLA compiler. The sharding spec is attached to the XLATensor. Here is a simple usage example from the [RFC](https://github.com/pytorch/xla/issues/3871), to illustrate how the sharding annotation API works:

```python
+import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
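The added `import numpy as np` is needed because the usage example that follows in docs/spmd.md (collapsed in the diff above) builds the logical device mesh from a NumPy array of device IDs. Below is a minimal sketch of how that example fits together; the device-count helper, mesh shape, axis names, and partition spec are illustrative assumptions rather than an exact copy of the file's code.

```python
# Minimal sketch of the sharding-annotation flow described above.
# The mesh layout, axis names, and partition spec are assumptions
# chosen for illustration, not an exact copy of docs/spmd.md.
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs
from torch_xla.experimental.xla_sharding import Mesh

xr.use_spmd()  # enable SPMD execution mode (API name assumed)

# The logical device mesh is constructed from a NumPy array of device IDs,
# which is why the missing `numpy` import broke the example.
num_devices = xr.global_runtime_device_count()  # assumed helper for the device count
mesh_shape = (num_devices, 1)
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('data', 'model'))

# Annotate a tensor: map tensor dim 0 to the 'data' mesh axis and dim 1 to 'model'.
t = torch.randn(8, 4).to(xm.xla_device())
partition_spec = (0, 1)
sharded = xs.mark_sharding(t, mesh, partition_spec)  # returns an XLAShardedTensor
```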
