Commit

Update the guide for multi process persistent cache (#7917)
JackCaoG authored Aug 27, 2024
1 parent c58680c commit 17a4ef5
Showing 2 changed files with 17 additions and 0 deletions.
API_GUIDE.md (14 additions, 0 deletions)
@@ -316,6 +316,20 @@ This will initialize a persistent compilation cache at the specified path. The
write to the cache, which can be useful when a shared cache mount is used for
an SPMD workload.
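
The surrounding guide text describes restricting which worker may write to a shared cache mount. A minimal sketch of one way to do that is below; the `/shared/xla_cache` mount path and the single-writer policy are illustrative assumptions, not part of this change:

```python
import torch_xla.runtime as xr

# Hypothetical shared cache mount: only the host process with index 0 writes,
# every other process opens the cache read-only.
xr.initialize_cache('/shared/xla_cache', readonly=xr.process_index() != 0)
```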

If you want to use the persistent compilation cache in multi-process training (with `torch_xla.launch` or `xmp.spawn`), you should use a different cache path for each process.

```python
import torch_xla
import torch_xla.runtime as xr


def _mp_fn(index):
  # Cache init needs to happen inside _mp_fn.
  xr.initialize_cache(f'/tmp/xla_cache_{index}', readonly=False)
  ...


if __name__ == '__main__':
  torch_xla.launch(_mp_fn, args=())
```
If you don't have access to the `index`, you can use `xr.global_ordinal()` instead. Check out the runnable example [here](https://github.com/pytorch/xla/blob/master/examples/data_parallel/train_resnet_xla_ddp.py).
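
For example, a minimal sketch of that variant, using a hypothetical `init_cache` helper (the cache path is only an illustration):

```python
import torch_xla
import torch_xla.runtime as xr


def init_cache():
  # Inside a spawned process the runtime knows this process's global ordinal,
  # so the per-process cache path can be derived without the `index` argument.
  xr.initialize_cache(f'/tmp/xla_cache_{xr.global_ordinal()}', readonly=False)


def _mp_fn(index):
  init_cache()
  ...


if __name__ == '__main__':
  torch_xla.launch(_mp_fn, args=())
```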


## Further Reading

Additional documentation is available at the
examples/data_parallel/train_resnet_xla_ddp.py (3 additions, 0 deletions)
@@ -6,6 +6,7 @@

import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr


class TrainResNetXLADDP(TrainResNetBase):
@@ -16,6 +17,8 @@ def run_optimizer(self):


def _mp_fn(index):
  # Cache init needs to happen inside _mp_fn.
  xr.initialize_cache(f'/tmp/xla_cache_{index}', readonly=False)
  xla_ddp = TrainResNetXLADDP()
  xla_ddp.start_training()

