Skip to content

Commit

Permalink
reorganize the example dir (#7097)
Browse files Browse the repository at this point in the history
  • Loading branch information
JackCaoG authored May 22, 2024
1 parent 0ce06ec commit c294625
Show file tree
Hide file tree
Showing 9 changed files with 38 additions and 2 deletions.
2 changes: 2 additions & 0 deletions examples/data_parallel/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
## Recommendation
Please consider using `train_resnet_spmd_data_parallel.py`, since it uses SPMD internally and is very likely to yield better performance.
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
import sys
import os
example_folder = os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0])))
sys.path.append(example_folder)
from train_resnet_base import TrainResNetBase

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.optim as optim
Expand All @@ -21,4 +26,5 @@ def _mp_fn(index):


# Script entry point: print a pointer to the recommended alternative example,
# then pass _mp_fn to xmp.spawn with no extra arguments.
if __name__ == '__main__':
print('consider using train_resnet_spmd_data_parallel.py instead to get better performance')
xmp.spawn(_mp_fn, args=())
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import sys
import os
example_folder = os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0])))
sys.path.append(example_folder)
from train_resnet_base import TrainResNetBase

import numpy as np
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
import sys
import os
example_folder = os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0])))
sys.path.append(example_folder)
from train_resnet_base import TrainResNetBase

import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.core.xla_model as xm


# Subclass of TrainResNetBase that defines run_optimizer in terms of
# xm.optimizer_step.
class TrainResNetXLADDP(TrainResNetBase):

def run_optimizer(self):
# optimizer_step will call `optimizer.step()` and all_reduce the gradients
xm.optimizer_step(self.optimizer)


Expand All @@ -15,4 +21,5 @@ def _mp_fn(index):


# Script entry point: print a pointer to the recommended alternative example,
# then pass _mp_fn to xmp.spawn with no extra arguments.
if __name__ == '__main__':
print('consider using train_resnet_spmd_data_parallel.py instead to get better performance')
xmp.spawn(_mp_fn, args=())
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import sys
import os
example_folder = os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0])))
sys.path.append(example_folder)
from train_resnet_base import TrainResNetBase

import itertools
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import os
import sys

# Put the directory one level above the launched script's folder on the import
# path, so that modules living there (such as `train_resnet_base`) can be
# imported when this example is run directly.
example_folder = os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0])))
sys.path.append(example_folder)
from train_resnet_base import TrainResNetBase

import torch_xla.debug.profiler as xp

# check https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#environment-variables
Expand Down
2 changes: 2 additions & 0 deletions examples/fsdp/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
## Recommendation
Please consider using `train_decoder_only_fsdp_v2.py`, since it uses SPMD internally and is very likely to yield better performance.
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import sys
import os
example_folder = os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0])))
sys.path.append(example_folder)
import decoder_only_model
from train_decoder_only_base import TrainDecoderOnlyBase

Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import sys
import os
example_folder = os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0])))
sys.path.append(example_folder)
from train_resnet_base import TrainResNetBase
from functools import partial

Expand All @@ -10,7 +14,6 @@
from torch_xla.distributed.fsdp.wrap import (size_based_auto_wrap_policy,
transformer_auto_wrap_policy)


class TrainResNetXLAFSDP(TrainResNetBase):

def __init__(self):
Expand Down Expand Up @@ -49,4 +52,5 @@ def _mp_fn(index):


# Script entry point: print a pointer to the recommended alternative example,
# then pass _mp_fn to xmp.spawn with no extra arguments.
if __name__ == '__main__':
print('consider using train_decoder_only_fsdp_v2.py instead to get better performance')
xmp.spawn(_mp_fn, args=())

0 comments on commit c294625

Please sign in to comment.