Commit 748df10: lint
JackCaoG committed May 11, 2024
1 parent 433cdc1

Showing 3 changed files with 71 additions and 63 deletions.

examples/train_resnet_base.py (49 additions & 47 deletions)
@@ -13,58 +13,60 @@
 
 time.ctime()
 
 
 def _train_update(step, loss, tracker, epoch):
-    print(f'epoch: {epoch}, step: {step}, loss: {loss}, rate: {tracker.rate()}')
+  print(f'epoch: {epoch}, step: {step}, loss: {loss}, rate: {tracker.rate()}')
 
 
 class TrainResNetBase():
-    def __init__(self):
-        img_dim = 224
-        self.batch_size = 128
-        self.num_steps = 300
-        self.num_epochs=1
-        train_dataset_len = 1200000 # Roughly the size of Imagenet dataset.
-        train_loader = xu.SampleGenerator(
-            data=(torch.zeros(self.batch_size, 3, img_dim, img_dim),
-                  torch.zeros(self.batch_size, dtype=torch.int64)),
-            sample_count=train_dataset_len // self.batch_size //
-            xm.xrt_world_size())
-
-        self.device = xm.xla_device()
-        self.train_device_loader = pl.MpDeviceLoader(
-            train_loader,
-            self.device)
-        self.model = torchvision.models.resnet50().to(self.device)
-        self.optimizer = optim.SGD(
-            self.model.parameters(),
-            weight_decay=1e-4)
-        self.loss_fn = nn.CrossEntropyLoss()
-
-    def run_optimizer(self):
-        self.optimizer.step()
-
-    def start_training(self):
-
-        def train_loop_fn(loader, epoch):
-            tracker = xm.RateTracker()
-            self.model.train()
-            for step, (data, target) in enumerate(loader):
-                self.optimizer.zero_grad()
-                output = self.model(data)
-                loss = self.loss_fn(output, target)
-                loss.backward()
-                self.run_optimizer()
-                tracker.add(self.batch_size)
-                if step % 10 == 0:
-                    xm.add_step_closure(_train_update, args=(step, loss, tracker, epoch))
-                if self.num_steps == step:
-                    break
-
-        for epoch in range(1, self.num_epochs + 1):
-            xm.master_print('Epoch {} train begin {}'.format(epoch, time.strftime('%l:%M%p %Z on %b %d, %Y')))
-            train_loop_fn(self.train_device_loader, epoch)
-            xm.master_print('Epoch {} train end {}'.format(epoch, time.strftime('%l:%M%p %Z on %b %d, %Y')))
-            xm.wait_device_ops()
+
+  def __init__(self):
+    img_dim = 224
+    self.batch_size = 128
+    self.num_steps = 300
+    self.num_epochs = 1
+    train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
+    train_loader = xu.SampleGenerator(
+        data=(torch.zeros(self.batch_size, 3, img_dim, img_dim),
+              torch.zeros(self.batch_size, dtype=torch.int64)),
+        sample_count=train_dataset_len // self.batch_size //
+        xm.xrt_world_size())
+
+    self.device = xm.xla_device()
+    self.train_device_loader = pl.MpDeviceLoader(train_loader, self.device)
+    self.model = torchvision.models.resnet50().to(self.device)
+    self.optimizer = optim.SGD(self.model.parameters(), weight_decay=1e-4)
+    self.loss_fn = nn.CrossEntropyLoss()
+
+  def run_optimizer(self):
+    self.optimizer.step()
+
+  def start_training(self):
+
+    def train_loop_fn(loader, epoch):
+      tracker = xm.RateTracker()
+      self.model.train()
+      for step, (data, target) in enumerate(loader):
+        self.optimizer.zero_grad()
+        output = self.model(data)
+        loss = self.loss_fn(output, target)
+        loss.backward()
+        self.run_optimizer()
+        tracker.add(self.batch_size)
+        if step % 10 == 0:
+          xm.add_step_closure(
+              _train_update, args=(step, loss, tracker, epoch))
+        if self.num_steps == step:
+          break
+
+    for epoch in range(1, self.num_epochs + 1):
+      xm.master_print('Epoch {} train begin {}'.format(
+          epoch, time.strftime('%l:%M%p %Z on %b %d, %Y')))
+      train_loop_fn(self.train_device_loader, epoch)
+      xm.master_print('Epoch {} train end {}'.format(
+          epoch, time.strftime('%l:%M%p %Z on %b %d, %Y')))
+      xm.wait_device_ops()
 
 
 if __name__ == '__main__':
-    base = TrainResNetBase()
-    base.start_training()
+  base = TrainResNetBase()
+  base.start_training()
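
For context: run_optimizer is factored out as a one-line hook so subclasses can change how the optimizer step happens without touching the training loop, and the two DDP variants below do exactly that. A minimal sketch of the pattern, with a hypothetical gradient-clipping subclass (the class name and the clipping tweak are illustrative, not from this repository), assuming examples/ is importable:

import torch

from train_resnet_base import TrainResNetBase


class TrainResNetClipped(TrainResNetBase):

  def run_optimizer(self):
    # Hypothetical tweak: clip gradients before stepping. The data loading,
    # forward/backward pass and epoch loop in start_training() are inherited
    # unchanged from TrainResNetBase.
    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
    self.optimizer.step()
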
examples/train_resnet_ddp.py (13 additions & 11 deletions)
@@ -6,17 +6,19 @@
 
 
 class TrainResNetDDP(TrainResNetBase):
-    def __init__(self):
-        super().__init__()
-        dist.init_process_group('xla', init_method='xla://')
-        self.model = DDP(self.model, gradient_as_bucket_view=True, broadcast_buffers=False)
-        self.optimizer = optim.SGD(
-            self.model.parameters(),
-            weight_decay=1e-4)
-
-
+
+  def __init__(self):
+    super().__init__()
+    dist.init_process_group('xla', init_method='xla://')
+    self.model = DDP(
+        self.model, gradient_as_bucket_view=True, broadcast_buffers=False)
+    self.optimizer = optim.SGD(self.model.parameters(), weight_decay=1e-4)
+
+
 def _mp_fn(index):
-    ddp = TrainResNetDDP()
-    ddp.start_training()
+  ddp = TrainResNetDDP()
+  ddp.start_training()
 
 
 if __name__ == '__main__':
-    xmp.spawn(_mp_fn, args=())
+  xmp.spawn(_mp_fn, args=())
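
Here init_method='xla://' lets the 'xla' process-group backend pick up rank and world size from the XLA runtime, so no MASTER_ADDR/MASTER_PORT bootstrapping is needed, while gradient_as_bucket_view=True and broadcast_buffers=False trim DDP's memory and communication overhead. A rough standalone sketch of the same setup, assuming one such process runs per XLA device (as xmp.spawn arranges below):

import torch.distributed as dist
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
import torch_xla.core.xla_model as xm
import torchvision

# Join the 'xla' process group; rank/world size come from the XLA runtime.
dist.init_process_group('xla', init_method='xla://')
device = xm.xla_device()
# Wrap the model so loss.backward() all-reduces gradients across replicas.
model = DDP(
    torchvision.models.resnet50().to(device),
    gradient_as_bucket_view=True,
    broadcast_buffers=False)
optimizer = optim.SGD(model.parameters(), weight_decay=1e-4)
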
examples/train_resnet_xla_ddp.py (9 additions & 5 deletions)
@@ -2,13 +2,17 @@
 import torch_xla.distributed.xla_multiprocessing as xmp
 import torch_xla.core.xla_model as xm
 
+
 class TrainResNetXLADDP(TrainResNetBase):
-    def run_optimizer(self):
-        xm.optimizer_step(self.optimizer)
+
+  def run_optimizer(self):
+    xm.optimizer_step(self.optimizer)
+
 
 def _mp_fn(index):
-    xla_ddp = TrainResNetXLADDP()
-    xla_ddp.start_training()
+  xla_ddp = TrainResNetXLADDP()
+  xla_ddp.start_training()
 
 
 if __name__ == '__main__':
-    xmp.spawn(_mp_fn, args=())
+  xmp.spawn(_mp_fn, args=())
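
TrainResNetXLADDP needs no model wrapper: xm.optimizer_step performs the cross-replica gradient reduction itself before stepping. As a sketch of what the overridden hook amounts to (a reading of the torch_xla API, not code from this commit):

import torch_xla.core.xla_model as xm


def run_optimizer_sketch(optimizer):
  # Mean-reduce gradients across all replicas, then apply the local step;
  # xm.optimizer_step bundles these two operations.
  xm.reduce_gradients(optimizer)
  optimizer.step()

Since MpDeviceLoader already marks the XLA step once per batch, no extra barrier is needed in the training loop.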
