diff --git a/API_GUIDE.md b/API_GUIDE.md
index edea6857d7c..3738a3a9f71 100644
--- a/API_GUIDE.md
+++ b/API_GUIDE.md
@@ -316,6 +316,20 @@ This will initialize a persistent compilation cache at the specified path. The
 write to the cache, which can be useful when a shared cache mount is used for
 an SPMD workload.
 
+To use the persistent compilation cache in multi-process training (with `torch_xla.launch` or `xmp.spawn`), use a different cache path for each process:
+
+```python
+def _mp_fn(index):
+  # Cache initialization needs to happen inside _mp_fn, in each process.
+  xr.initialize_cache(f'/tmp/xla_cache_{index}', readonly=False)
+  ...
+
+if __name__ == '__main__':
+  torch_xla.launch(_mp_fn, args=())
+```
+If you don't have access to the `index`, you can use `xr.global_ordinal()` instead. Check out the runnable example [here](https://github.com/pytorch/xla/blob/master/examples/data_parallel/train_resnet_xla_ddp.py).
+
+
 ## Further Reading
 
 Additional documentation is available at the
diff --git a/examples/data_parallel/train_resnet_xla_ddp.py b/examples/data_parallel/train_resnet_xla_ddp.py
index 4e98ba1f442..c52d62115f2 100644
--- a/examples/data_parallel/train_resnet_xla_ddp.py
+++ b/examples/data_parallel/train_resnet_xla_ddp.py
@@ -6,6 +6,7 @@
 
 import torch_xla
 import torch_xla.core.xla_model as xm
+import torch_xla.runtime as xr
 
 
 class TrainResNetXLADDP(TrainResNetBase):
@@ -16,6 +17,8 @@ def run_optimizer(self):
 
 
 def _mp_fn(index):
+  # Cache initialization needs to happen inside _mp_fn, in each process.
+  xr.initialize_cache(f'/tmp/xla_cache_{index}', readonly=False)
   xla_ddp = TrainResNetXLADDP()
   xla_ddp.start_training()
 
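
For illustration, here is a minimal sketch of the `xr.global_ordinal()` variant the guide text mentions but does not show. It assumes an XLA-capable host; the `/tmp/xla_cache_` path prefix and the `print` body are placeholders, not part of the patch:

```python
import torch_xla
import torch_xla.runtime as xr


def _mp_fn(index):
  # Derive the per-process cache directory from the global ordinal instead of
  # the spawn index; handy when the worker function does not receive the index
  # (e.g. it is invoked through a framework callback).
  # The '/tmp/xla_cache_' prefix is illustrative; any process-unique path works.
  xr.initialize_cache(f'/tmp/xla_cache_{xr.global_ordinal()}', readonly=False)
  # Training would go here; print the ordinal so the sketch does something.
  print(f'process {xr.global_ordinal()} initialized its own cache')


if __name__ == '__main__':
  torch_xla.launch(_mp_fn, args=())
```

Giving each process its own directory sidesteps concurrent writes to the same cache files, which is why the guide recommends a distinct path per process; the `readonly` flag remains the tool for sharing a single cache mount across SPMD workers.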