From b9d567ad58919affc19e6d02ecbf5bd4eecdfb17 Mon Sep 17 00:00:00 2001
From: wbmc
Date: Tue, 12 Sep 2023 01:09:14 +0800
Subject: [PATCH] fix: missing import numpy (#5533)

* missing

* fix typo

---------

Co-authored-by: mochen.bmc
---
 docs/spmd.md | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/docs/spmd.md b/docs/spmd.md
index 69dd1f14853..8337bbd5af9 100644
--- a/docs/spmd.md
+++ b/docs/spmd.md
@@ -31,7 +31,7 @@ Also, this version of the SPMD is currently only tested.optimized on Google Clou
 
 ## PyTorch/XLA SPMD Design Overview
 
-### Simple Eexample & Sharding Aannotation API
+### Simple Example & Sharding Annotation API
 
-Users can annotate native PyTorch tensors using the `mark_sharding` API ([src](https://github.com/pytorch/xla/blob/9a5fdf3920c18275cf7dba785193636f1b39ced9/torch_xla/experimental/xla_sharding.py#L388)). This takes `torch.Tensor` as input and returns a `XLAShardedTensor` as output.
+Users can annotate native PyTorch tensors using the `mark_sharding` API ([src](https://github.com/pytorch/xla/blob/9a5fdf3920c18275cf7dba785193636f1b39ced9/torch_xla/experimental/xla_sharding.py#L388)). This takes a `torch.Tensor` as input and returns an `XLAShardedTensor` as output.
 
@@ -42,6 +42,7 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh, partitio
-Invoking `mark_sharding` API takes a user defined logical [mesh](#mesh) and [partition\_spec](#partition-spec) and generates a sharding annotation for the XLA compiler. The sharding spec is attached to the XLATensor. Here is a simple usage example from the [[RFC](https://github.com/pytorch/xla/issues/3871), to illustrate how the sharding annotation API works:
+The `mark_sharding` API takes a user-defined logical [mesh](#mesh) and [partition\_spec](#partition-spec) and generates a sharding annotation for the XLA compiler. The sharding spec is attached to the XLATensor. Here is a simple usage example from the [RFC](https://github.com/pytorch/xla/issues/3871) to illustrate how the sharding annotation API works:
 
 ```python
+import numpy as np
 import torch
 import torch_xla.core.xla_model as xm
 import torch_xla.runtime as xr
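
For context on why the `import numpy as np` line is needed, here is a minimal sketch of how the example in `docs/spmd.md` plausibly continues, based on the surrounding doc text and the [RFC](https://github.com/pytorch/xla/issues/3871). The `xs` alias, the 2D mesh shape, and the tensor size are illustrative assumptions, not part of this patch; the `Mesh` and `mark_sharding` names come from the `torch_xla/experimental/xla_sharding.py` module linked above.

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs
from torch_xla.experimental.xla_sharding import Mesh

# Enable XLA SPMD execution mode.
xr.use_spmd()

# Build a logical device mesh. numpy is used to construct the device-id
# array, which is exactly why this patch adds `import numpy as np`.
num_devices = xr.global_runtime_device_count()
mesh_shape = (2, num_devices // 2)  # assumed 2D mesh; adjust to your device count
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))

# Annotate a tensor: partition_spec (0, 1) maps tensor dim 0 to mesh
# axis 'x' and tensor dim 1 to mesh axis 'y'.
t = torch.randn(8, 4).to(xm.xla_device())
sharded = xs.mark_sharding(t, mesh, (0, 1))
```

Without the numpy import, the `np.array(...)` call in the mesh construction step raises a `NameError`, which is the bug this patch fixes.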