[Dist] Add ZeRO-1 Optimizer (#4648)
* init

* test

* lint

* address comments
hgt312 authored Feb 21, 2023
1 parent 023d763 commit 1d313bb
Showing 4 changed files with 415 additions and 0 deletions.
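
For orientation, here is a minimal sketch of how the new ZeroRedundancyOptimizer is driven, assembled from the constructor arguments and training-loop calls that appear in the tests below; the tiny nn.Linear model and the hyperparameters are illustrative placeholders, not part of this commit.

import torch
import torch.nn as nn
import torch.optim as optim
import torch_xla.core.xla_model as xm
from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer

device = xm.xla_device()
model = nn.Linear(8, 8).to(device)

# ZeRO-1: pass the *class* of the base optimizer; its states are sharded
# across the data-parallel ranks instead of being replicated on every one.
optimizer = ZeroRedundancyOptimizer(
    model.parameters(),
    optim.SGD,
    lr=0.01,
    momentum=0.9,
    grad_clipping=False)

loss = model(torch.ones(8, 8, device=device)).sum()
loss.backward()
optimizer.step()
xm.mark_step()

Note that the MNIST script below calls optimizer.step() directly rather than xm.optimizer_step(optimizer), which suggests the wrapper performs the cross-replica gradient reduction itself.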
1 change: 1 addition & 0 deletions test/run_tests.sh
@@ -146,6 +146,7 @@ function run_op_tests {
run_test python3 "$CDIR/test_profiler.py"
run_test python3 "$CDIR/test_ops.py"
run_test python3 "$CDIR/test_metrics.py"
run_test python3 "$CDIR/test_zero1.py"
run_test python3 "$CDIR/dynamo/test_dynamo_integrations_util.py"
run_test python3 "$CDIR/dynamo/test_dynamo.py"
run_test python3 "$CDIR/dynamo/test_bridge.py"
198 changes: 198 additions & 0 deletions test/test_train_mp_mnist_zero1.py
@@ -0,0 +1,198 @@
import args_parse

FLAGS = args_parse.parse_common_options(
datadir='/tmp/mnist-data',
batch_size=128,
momentum=0.5,
lr=0.01,
target_accuracy=98.0,
num_epochs=18,
)

import os
import shutil
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.test.test_utils as test_utils


class MNIST(nn.Module):

def __init__(self):
super(MNIST, self).__init__()
self.conv1 = nn.Conv2d(1, 16, kernel_size=5)
self.bn1 = nn.BatchNorm2d(16)
self.conv2 = nn.Conv2d(16, 32, kernel_size=5)
self.bn2 = nn.BatchNorm2d(32)
self.fc1 = nn.Linear(512, 80)
self.fc2 = nn.Linear(80, 16)

def forward(self, x):
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = self.bn1(x)
x = F.relu(F.max_pool2d(self.conv2(x), 2))
x = self.bn2(x)
x = torch.flatten(x, 1)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)


def _train_update(device, step, loss, tracker, epoch, writer):
test_utils.print_training_update(
device,
step,
loss.item(),
tracker.rate(),
tracker.global_rate(),
epoch,
summary_writer=writer)


def train_mnist(flags, **kwargs):
torch.manual_seed(1)

if flags.fake_data:
train_loader = xu.SampleGenerator(
data=(torch.zeros(flags.batch_size, 1, 28,
28), torch.zeros(flags.batch_size,
dtype=torch.int64)),
sample_count=60000 // flags.batch_size // xm.xrt_world_size())
test_loader = xu.SampleGenerator(
data=(torch.zeros(flags.batch_size, 1, 28,
28), torch.zeros(flags.batch_size,
dtype=torch.int64)),
sample_count=10000 // flags.batch_size // xm.xrt_world_size())
else:
train_dataset = datasets.MNIST(
os.path.join(flags.datadir, str(xm.get_ordinal())),
train=True,
download=True,
transform=transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))]))
test_dataset = datasets.MNIST(
os.path.join(flags.datadir, str(xm.get_ordinal())),
train=False,
download=True,
transform=transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))]))
train_sampler = None
if xm.xrt_world_size() > 1:
train_sampler = torch.utils.data.distributed.DistributedSampler(
train_dataset,
num_replicas=xm.xrt_world_size(),
rank=xm.get_ordinal(),
shuffle=True)
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=flags.batch_size,
sampler=train_sampler,
drop_last=flags.drop_last,
shuffle=False if train_sampler else True,
num_workers=flags.num_workers)
test_loader = torch.utils.data.DataLoader(
test_dataset,
batch_size=flags.batch_size,
drop_last=flags.drop_last,
shuffle=False,
num_workers=flags.num_workers)

# Scale learning rate to num cores
lr = flags.lr * xm.xrt_world_size()

device = xm.xla_device()
model = MNIST().to(device)

writer = None
if xm.is_master_ordinal():
writer = test_utils.get_summary_writer(flags.logdir)
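  # ZeRO-1: wrap the base optimizer class so that each replica keeps only its
  # shard of the optimizer state rather than a full copy.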
optimizer = ZeroRedundancyOptimizer(
model.parameters(),
optim.SGD,
lr=lr,
momentum=flags.momentum,
pin_layout=False)
loss_fn = nn.NLLLoss()

def train_loop_fn(loader, epoch):
tracker = xm.RateTracker()
model.train()
for step, (data, target) in enumerate(loader):
optimizer.zero_grad()
output = model(data)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
tracker.add(flags.batch_size)
if step % flags.log_steps == 0:
xm.add_step_closure(
_train_update,
args=(device, step, loss, tracker, epoch, writer),
run_async=flags.async_closures)

def test_loop_fn(loader):
total_samples = 0
correct = 0
model.eval()
for data, target in loader:
output = model(data)
pred = output.max(1, keepdim=True)[1]
correct += pred.eq(target.view_as(pred)).sum()
total_samples += data.size()[0]

accuracy = 100.0 * correct.item() / total_samples
accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
return accuracy

train_device_loader = pl.MpDeviceLoader(train_loader, device)
test_device_loader = pl.MpDeviceLoader(test_loader, device)
accuracy, max_accuracy = 0.0, 0.0
for epoch in range(1, flags.num_epochs + 1):
xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now()))
train_loop_fn(train_device_loader, epoch)
xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now()))

accuracy = test_loop_fn(test_device_loader)
xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(
epoch, test_utils.now(), accuracy))
max_accuracy = max(accuracy, max_accuracy)
test_utils.write_to_summary(
writer,
epoch,
dict_to_write={'Accuracy/test': accuracy},
write_xla_metrics=True)
if flags.metrics_debug:
xm.master_print(met.metrics_report())

test_utils.close_summary_writer(writer)
xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy))
return max_accuracy


def _mp_fn(index, flags):
torch.set_default_tensor_type('torch.FloatTensor')
accuracy = train_mnist(flags)
if flags.tidy and os.path.isdir(flags.datadir):
shutil.rmtree(flags.datadir)
if accuracy < flags.target_accuracy:
print('Accuracy {} is below target {}'.format(accuracy,
flags.target_accuracy))
sys.exit(21)


if __name__ == '__main__':
xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS.num_cores)
54 changes: 54 additions & 0 deletions test/test_zero1.py
@@ -0,0 +1,54 @@
import torch
import torch.nn as nn
import torch_xla
import torch_xla.core.xla_model as xm
from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer

import sys
import unittest


class XlaZeRO1Test(unittest.TestCase):

def test_zero1(self):
device = xm.xla_device()

model = nn.Linear(8, 8)
x = torch.ones((8, 8))
model = model.to(device)
x = x.to(device)
y = model(x).sum()
y.backward()

opt1 = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
opt2 = ZeroRedundancyOptimizer(
model.parameters(),
torch.optim.SGD,
lr=0.01,
momentum=0.9,
grad_clipping=False)

opt1.step()
opt2.step()
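    # The ZeRO-1 wrapper exposes the underlying optimizer's state under the
    # 'base' key of its state dict; here it should match the plain SGD state.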
assert str(opt1.state_dict()) == str(opt2.state_dict()['base'])

s1 = opt1.state_dict()
s2 = opt2.state_dict()
opt1.load_state_dict(s1)
opt2.load_state_dict(s2)
assert str(opt1.state_dict()) == str(opt2.state_dict()['base'])

# step still runnable
opt1.step()
opt2.step()
opt1.load_state_dict(s1)
opt2.load_state_dict(s2)
assert str(opt1.state_dict()) == str(opt2.state_dict()['base'])

# step still runnable
opt1.step()
opt2.step()


if __name__ == '__main__':
test = unittest.main(exit=False)
sys.exit(0 if test.result.wasSuccessful() else 1)