
Update doc from commit 370089a
torchxlabot2 committed Mar 14, 2024
1 parent bc4c6f7 commit 75a3e41
Showing 14 changed files with 52 additions and 17 deletions.
2 changes: 1 addition & 1 deletion master/_modules/index.html
@@ -225,7 +225,7 @@

<div class="version">
- master (2.3.0+gitfe3f23c )
+ master (2.3.0+git370089a )
</div>

2 changes: 1 addition & 1 deletion master/_modules/torch_xla/core/functions.html
@@ -225,7 +225,7 @@

<div class="version">
- master (2.3.0+gitfe3f23c )
+ master (2.3.0+git370089a )
</div>

2 changes: 1 addition & 1 deletion master/_modules/torch_xla/core/xla_model.html
@@ -225,7 +225,7 @@

<div class="version">
- master (2.3.0+gitfe3f23c )
+ master (2.3.0+git370089a )
</div>

2 changes: 1 addition & 1 deletion master/_modules/torch_xla/distributed/parallel_loader.html
@@ -225,7 +225,7 @@

<div class="version">
- master (2.3.0+gitfe3f23c )
+ master (2.3.0+git370089a )
</div>

@@ -225,7 +225,7 @@

<div class="version">
- master (2.3.0+gitfe3f23c )
+ master (2.3.0+git370089a )
</div>

2 changes: 1 addition & 1 deletion master/_modules/torch_xla/utils/serialization.html
@@ -225,7 +225,7 @@

<div class="version">
- master (2.3.0+gitfe3f23c )
+ master (2.3.0+git370089a )
</div>

2 changes: 1 addition & 1 deletion master/_modules/torch_xla/utils/utils.html
@@ -225,7 +225,7 @@

<div class="version">
- master (2.3.0+gitfe3f23c )
+ master (2.3.0+git370089a )
</div>

2 changes: 1 addition & 1 deletion master/genindex.html
@@ -226,7 +226,7 @@

<div class="version">
- master (2.3.0+gitfe3f23c )
+ master (2.3.0+git370089a )
</div>

45 changes: 40 additions & 5 deletions master/index.html
@@ -225,7 +225,7 @@

<div class="version">
- master (2.3.0+gitfe3f23c )
+ master (2.3.0+git370089a )
</div>

@@ -426,9 +426,10 @@
</li>
<li><a class="reference internal" href="#here-mesh-is-a-2x2-mesh-with-axes-x-and-y">Here, mesh is a 2x2 mesh with axes ‘x’ and ‘y’</a></li>
<li><a class="reference internal" href="#a-tensor-s-sharding-can-be-visualized-using-the-visualize-tensor-sharding-method">A tensor’s sharding can be visualized using the <code class="docutils literal notranslate"><span class="pre">visualize_tensor_sharding</span></code> method</a></li>
<li><a class="reference internal" href="#currently-model-should-be-loaded-to-xla-device-via-distribute-module">Currently, model should be loaded to xla device via distribute_module.</a></li>
<li><a class="reference internal" href="#fully-sharded-data-parallel-via-spmd">Fully Sharded Data Parallel via SPMD</a><ul>
<li><a class="reference internal" href="#sharding-output">Sharding output</a></li>
<li><a class="reference internal" href="#id41">Gradient checkpointing</a></li>
<li><a class="reference internal" href="#id50">Gradient checkpointing</a></li>
<li><a class="reference internal" href="#huggingface-llama-2-example">HuggingFace Llama 2 Example</a></li>
</ul>
</li>
@@ -3710,6 +3711,39 @@ <h1>A tensor’s sharding can be visualized using the <code class="docutils lite
</div>
<a class="reference external image-reference" href="assets/spmd_debug_2.png"><img alt="alt_text" src="assets/spmd_debug_2.png" /></a>
<p>You can use these examples on a single TPU/GPU/CPU host and modify them to run on multiple hosts. You can also modify them to use the sharding styles <code class="docutils literal notranslate"><span class="pre">tiled</span></code>, <code class="docutils literal notranslate"><span class="pre">partial_replication</span></code>, and <code class="docutils literal notranslate"><span class="pre">replicated</span></code>.</p>
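<p>A minimal sketch of these sharding styles, assuming the r2.3 module paths <code class="docutils literal notranslate"><span class="pre">torch_xla.distributed.spmd</span></code> and <code class="docutils literal notranslate"><span class="pre">torch_xla.distributed.spmd.debugging</span></code> and a device count divisible by 2; the style is determined by the partition spec passed to <code class="docutils literal notranslate"><span class="pre">mark_sharding</span></code>:</p>
<div class="highlight-python3 notranslate"><div class="highlight"><pre>import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding

xr.use_spmd()

# Build a 2 x (N/2) logical mesh over all available devices.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (2, num_devices // 2), ('x', 'y'))

# tiled: both tensor dimensions are sharded over the mesh axes.
t = torch.randn(8, 8).to(xm.xla_device())
xs.mark_sharding(t, mesh, ('x', 'y'))
visualize_tensor_sharding(t, use_color=False)

# partial_replication: dim 0 is sharded over 'x', dim 1 is replicated.
u = torch.randn(8, 8).to(xm.xla_device())
xs.mark_sharding(u, mesh, ('x', None))

# replicated: no dimension is sharded; every device holds a full copy.
v = torch.randn(8, 8).to(xm.xla_device())
xs.mark_sharding(v, mesh, (None, None))
</pre></div></div>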
<p>We are introducing a new PyTorch/XLA SPMD feature called <code class="docutils literal notranslate"><span class="pre">auto-sharding</span></code> (<a class="reference external" href="https://github.com/pytorch/xla/issues/6322">RFC</a>). This is an experimental feature in <code class="docutils literal notranslate"><span class="pre">r2.3</span></code> and <code class="docutils literal notranslate"><span class="pre">nightly</span></code> that supports <code class="docutils literal notranslate"><span class="pre">XLA:TPU</span></code> on a single TPU VM host.</p>
<p>PyTorch/XLA auto-sharding can be enabled by one of the following:</p>
<ul>
<li><p>Setting envvar <code class="docutils literal notranslate"><span class="pre">XLA_SPMD_AUTO=1</span></code></p></li>
<li><p>Calling the SPMD API at the beginning of your code:</p>
<div class="highlight-python3 notranslate"><div class="highlight"><pre>import torch_xla.runtime as xr

xr.use_spmd(auto=True)
</pre></div></div>
</li>
<li><p>Calling <code class="docutils literal notranslate"><span class="pre">pytorch.distributed._tensor.distribute_module</span></code> with <code class="docutils literal notranslate"><span class="pre">auto-policy</span></code> and <code class="docutils literal notranslate"><span class="pre">xla</span></code>:
<a href="#id40"><span class="problematic" id="id41">``</span></a><a href="#id42"><span class="problematic" id="id43">`</span></a>python
import torch_xla.runtime as xr
from torch.distributed._tensor import DeviceMesh, distribute_module
from torch_xla.distributed.spmd import auto_policy</p></li>
</ul>
<p>device_count = xr.global_runtime_device_count()
device_mesh = DeviceMesh(“xla”, list(range(device_count)))</p>
</div>
<div class="section" id="currently-model-should-be-loaded-to-xla-device-via-distribute-module">
<h1>Currently, model should be loaded to xla device via distribute_module.<a class="headerlink" href="#currently-model-should-be-loaded-to-xla-device-via-distribute-module" title="Permalink to this headline"></a></h1>
<div class="highlight-python3 notranslate"><div class="highlight"><pre>model = MyModule()  # nn.Module
sharded_model = distribute_module(model, device_mesh, auto_policy)
</pre></div></div>
<p>Optionally, one can set the following options/env-vars to control the behavior of
the XLA-based auto-sharding pass (see the sketch after this list):</p>
<ul class="simple">
<li><p><code class="docutils literal notranslate"><span class="pre">XLA_AUTO_USE_GROUP_SHARDING</span></code>: group resharding of the parameters. Set by default.</p></li>
<li><p><code class="docutils literal notranslate"><span class="pre">XLA_AUTO_SPMD_MESH</span></code>: logical mesh shape to be used for auto-sharding. For example,
<code class="docutils literal notranslate"><span class="pre">XLA_AUTO_SPMD_MESH=2,2</span></code> corresponds to a 2-by-2 mesh with 4 global devices. If unset,
a default device mesh shape of <code class="docutils literal notranslate"><span class="pre">num_devices,1</span></code> will be used.</p></li>
</ul>
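<p>A minimal sketch of setting these options from Python; setting them at the very start of the program, before <code class="docutils literal notranslate"><span class="pre">torch_xla</span></code> is imported, is a conservative assumption made here rather than a documented requirement:</p>
<div class="highlight-python3 notranslate"><div class="highlight"><pre>import os

# Tune the auto-sharding pass before torch_xla is imported (conservative assumption).
os.environ["XLA_AUTO_USE_GROUP_SHARDING"] = "1"   # group resharding of the parameters
os.environ["XLA_AUTO_SPMD_MESH"] = "2,2"          # 2-by-2 logical mesh over 4 global devices

import torch_xla.runtime as xr

# Enable SPMD auto-sharding (equivalent to setting XLA_SPMD_AUTO=1).
xr.use_spmd(auto=True)
</pre></div></div>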
</div>
<div class="section" id="fully-sharded-data-parallel-via-spmd">
<h1>Fully Sharded Data Parallel via SPMD<a class="headerlink" href="#fully-sharded-data-parallel-via-spmd" title="Permalink to this headline"></a></h1>
@@ -3759,8 +3793,8 @@ <h2>Sharding output<a class="headerlink" href="#sharding-output" title="Permalin
</pre></div>
</div>
</div>
<div class="section" id="id41">
<h2>Gradient checkpointing<a class="headerlink" href="#id41" title="Permalink to this headline"></a></h2>
<div class="section" id="id50">
<h2>Gradient checkpointing<a class="headerlink" href="#id50" title="Permalink to this headline"></a></h2>
<p>Currently, gradient checkpointing needs to be applied to the module before the FSDP wrapper; otherwise, recursively looping into child modules will end up in an infinite loop. We will fix this issue in future releases.</p>
<p>Example usage:</p>
<div class="highlight-python3 notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">torch_xla.distributed.fsdp</span> <span class="kn">import</span> <span class="n">checkpoint_module</span>
@@ -3980,9 +4014,10 @@ <h2>HuggingFace Llama 2 Example<a class="headerlink" href="#huggingface-llama-2-
</li>
<li><a class="reference internal" href="#here-mesh-is-a-2x2-mesh-with-axes-x-and-y">Here, mesh is a 2x2 mesh with axes ‘x’ and ‘y’</a></li>
<li><a class="reference internal" href="#a-tensor-s-sharding-can-be-visualized-using-the-visualize-tensor-sharding-method">A tensor’s sharding can be visualized using the <code class="docutils literal notranslate"><span class="pre">visualize_tensor_sharding</span></code> method</a></li>
<li><a class="reference internal" href="#currently-model-should-be-loaded-to-xla-device-via-distribute-module">Currently, model should be loaded to xla device via distribute_module.</a></li>
<li><a class="reference internal" href="#fully-sharded-data-parallel-via-spmd">Fully Sharded Data Parallel via SPMD</a><ul>
<li><a class="reference internal" href="#sharding-output">Sharding output</a></li>
<li><a class="reference internal" href="#id41">Gradient checkpointing</a></li>
<li><a class="reference internal" href="#id50">Gradient checkpointing</a></li>
<li><a class="reference internal" href="#huggingface-llama-2-example">HuggingFace Llama 2 Example</a></li>
</ul>
</li>
2 changes: 1 addition & 1 deletion master/notes/source_of_recompilation.html
@@ -225,7 +225,7 @@

<div class="version">
- master (2.3.0+gitfe3f23c )
+ master (2.3.0+git370089a )
</div>

Binary file modified master/objects.inv
Binary file not shown.
2 changes: 1 addition & 1 deletion master/py-modindex.html
@@ -228,7 +228,7 @@

<div class="version">
- master (2.3.0+gitfe3f23c )
+ master (2.3.0+git370089a )
</div>

2 changes: 1 addition & 1 deletion master/search.html
@@ -225,7 +225,7 @@

<div class="version">
- master (2.3.0+gitfe3f23c )
+ master (2.3.0+git370089a )
</div>

2 changes: 1 addition & 1 deletion master/searchindex.js

Large diffs are not rendered by default.
