diff --git a/test/run_tests.sh b/test/run_tests.sh
index f30818f09cc..5fab1de031d 100755
--- a/test/run_tests.sh
+++ b/test/run_tests.sh
@@ -146,6 +146,7 @@ function run_op_tests {
   run_test python3 "$CDIR/test_profiler.py"
   run_test python3 "$CDIR/test_ops.py"
   run_test python3 "$CDIR/test_metrics.py"
+  run_test python3 "$CDIR/test_zero1.py"
   run_test python3 "$CDIR/dynamo/test_dynamo_integrations_util.py"
   run_test python3 "$CDIR/dynamo/test_dynamo.py"
   run_test python3 "$CDIR/dynamo/test_bridge.py"
diff --git a/test/test_train_mp_mnist_zero1.py b/test/test_train_mp_mnist_zero1.py
new file mode 100644
index 00000000000..6f8d3964b52
--- /dev/null
+++ b/test/test_train_mp_mnist_zero1.py
@@ -0,0 +1,198 @@
+import args_parse
+
+FLAGS = args_parse.parse_common_options(
+    datadir='/tmp/mnist-data',
+    batch_size=128,
+    momentum=0.5,
+    lr=0.01,
+    target_accuracy=98.0,
+    num_epochs=18,
+)
+
+import os
+import shutil
+import sys
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torchvision import datasets, transforms
+import torch_xla
+import torch_xla.debug.metrics as met
+import torch_xla.distributed.parallel_loader as pl
+from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer
+import torch_xla.utils.utils as xu
+import torch_xla.core.xla_model as xm
+import torch_xla.distributed.xla_multiprocessing as xmp
+import torch_xla.test.test_utils as test_utils
+
+
+class MNIST(nn.Module):
+
+  def __init__(self):
+    super(MNIST, self).__init__()
+    self.conv1 = nn.Conv2d(1, 16, kernel_size=5)
+    self.bn1 = nn.BatchNorm2d(16)
+    self.conv2 = nn.Conv2d(16, 32, kernel_size=5)
+    self.bn2 = nn.BatchNorm2d(32)
+    self.fc1 = nn.Linear(512, 80)
+    self.fc2 = nn.Linear(80, 16)
+
+  def forward(self, x):
+    x = F.relu(F.max_pool2d(self.conv1(x), 2))
+    x = self.bn1(x)
+    x = F.relu(F.max_pool2d(self.conv2(x), 2))
+    x = self.bn2(x)
+    x = torch.flatten(x, 1)
+    x = F.relu(self.fc1(x))
+    x = self.fc2(x)
+    return F.log_softmax(x, dim=1)
+
+
+def _train_update(device, step, loss, tracker, epoch, writer):
+  test_utils.print_training_update(
+      device,
+      step,
+      loss.item(),
+      tracker.rate(),
+      tracker.global_rate(),
+      epoch,
+      summary_writer=writer)
+
+
+def train_mnist(flags, **kwargs):
+  torch.manual_seed(1)
+
+  if flags.fake_data:
+    train_loader = xu.SampleGenerator(
+        data=(torch.zeros(flags.batch_size, 1, 28,
+                          28), torch.zeros(flags.batch_size,
+                                           dtype=torch.int64)),
+        sample_count=60000 // flags.batch_size // xm.xrt_world_size())
+    test_loader = xu.SampleGenerator(
+        data=(torch.zeros(flags.batch_size, 1, 28,
+                          28), torch.zeros(flags.batch_size,
+                                           dtype=torch.int64)),
+        sample_count=10000 // flags.batch_size // xm.xrt_world_size())
+  else:
+    train_dataset = datasets.MNIST(
+        os.path.join(flags.datadir, str(xm.get_ordinal())),
+        train=True,
+        download=True,
+        transform=transforms.Compose(
+            [transforms.ToTensor(),
+             transforms.Normalize((0.1307,), (0.3081,))]))
+    test_dataset = datasets.MNIST(
+        os.path.join(flags.datadir, str(xm.get_ordinal())),
+        train=False,
+        download=True,
+        transform=transforms.Compose(
+            [transforms.ToTensor(),
+             transforms.Normalize((0.1307,), (0.3081,))]))
+    train_sampler = None
+    if xm.xrt_world_size() > 1:
+      train_sampler = torch.utils.data.distributed.DistributedSampler(
+          train_dataset,
+          num_replicas=xm.xrt_world_size(),
+          rank=xm.get_ordinal(),
+          shuffle=True)
+    train_loader = torch.utils.data.DataLoader(
+        train_dataset,
+        batch_size=flags.batch_size,
+        sampler=train_sampler,
+        drop_last=flags.drop_last,
+        shuffle=False if train_sampler else True,
+        num_workers=flags.num_workers)
+    test_loader = torch.utils.data.DataLoader(
+        test_dataset,
+        batch_size=flags.batch_size,
+        drop_last=flags.drop_last,
+        shuffle=False,
+        num_workers=flags.num_workers)
+
+  # Scale learning rate to num cores
+  lr = flags.lr * xm.xrt_world_size()
+
+  device = xm.xla_device()
+  model = MNIST().to(device)
+
+  writer = None
+  if xm.is_master_ordinal():
+    writer = test_utils.get_summary_writer(flags.logdir)
+  optimizer = ZeroRedundancyOptimizer(
+      model.parameters(),
+      optim.SGD,
+      lr=lr,
+      momentum=flags.momentum,
+      pin_layout=False)
+  loss_fn = nn.NLLLoss()
+
+  def train_loop_fn(loader, epoch):
+    tracker = xm.RateTracker()
+    model.train()
+    for step, (data, target) in enumerate(loader):
+      optimizer.zero_grad()
+      output = model(data)
+      loss = loss_fn(output, target)
+      loss.backward()
+      optimizer.step()
+      tracker.add(flags.batch_size)
+      if step % flags.log_steps == 0:
+        xm.add_step_closure(
+            _train_update,
+            args=(device, step, loss, tracker, epoch, writer),
+            run_async=flags.async_closures)
+
+  def test_loop_fn(loader):
+    total_samples = 0
+    correct = 0
+    model.eval()
+    for data, target in loader:
+      output = model(data)
+      pred = output.max(1, keepdim=True)[1]
+      correct += pred.eq(target.view_as(pred)).sum()
+      total_samples += data.size()[0]
+
+    accuracy = 100.0 * correct.item() / total_samples
+    accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
+    return accuracy
+
+  train_device_loader = pl.MpDeviceLoader(train_loader, device)
+  test_device_loader = pl.MpDeviceLoader(test_loader, device)
+  accuracy, max_accuracy = 0.0, 0.0
+  for epoch in range(1, flags.num_epochs + 1):
+    xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now()))
+    train_loop_fn(train_device_loader, epoch)
+    xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now()))
+
+    accuracy = test_loop_fn(test_device_loader)
+    xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(
+        epoch, test_utils.now(), accuracy))
+    max_accuracy = max(accuracy, max_accuracy)
+    test_utils.write_to_summary(
+        writer,
+        epoch,
+        dict_to_write={'Accuracy/test': accuracy},
+        write_xla_metrics=True)
+    if flags.metrics_debug:
+      xm.master_print(met.metrics_report())
+
+  test_utils.close_summary_writer(writer)
+  xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy))
+  return max_accuracy
+
+
+def _mp_fn(index, flags):
+  torch.set_default_tensor_type('torch.FloatTensor')
+  accuracy = train_mnist(flags)
+  if flags.tidy and os.path.isdir(flags.datadir):
+    shutil.rmtree(flags.datadir)
+  if accuracy < flags.target_accuracy:
+    print('Accuracy {} is below target {}'.format(accuracy,
+                                                  flags.target_accuracy))
+    sys.exit(21)
+
+
+if __name__ == '__main__':
+  xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS.num_cores)
diff --git a/test/test_zero1.py b/test/test_zero1.py
new file mode 100644
index 00000000000..11135cc5dab
--- /dev/null
+++ b/test/test_zero1.py
@@ -0,0 +1,55 @@
+import sys
+import torch
+import torch.nn as nn
+import torch_xla
+import torch_xla.core.xla_model as xm
+from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer
+
+import unittest
+
+
+class XlaZeRO1Test(unittest.TestCase):
+
+  def test_zero1(self):
+    device = xm.xla_device()
+
+    model = nn.Linear(8, 8)
+    x = torch.ones((8, 8))
+    model = model.to(device)
+    x = x.to(device)
+    y = model(x).sum()
+    y.backward()
+
+    opt1 = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
+    opt2 = ZeroRedundancyOptimizer(
+        model.parameters(),
+        torch.optim.SGD,
+        lr=0.01,
+        momentum=0.9,
+        grad_clipping=False)
+
+    opt1.step()
+    opt2.step()
+    assert str(opt1.state_dict()) == str(opt2.state_dict()['base'])
+
+    s1 = opt1.state_dict()
+    s2 = opt2.state_dict()
+    opt1.load_state_dict(s1)
+    opt2.load_state_dict(s2)
+    assert str(opt1.state_dict()) == str(opt2.state_dict()['base'])
+
+    # step still runnable
+    opt1.step()
+    opt2.step()
+    opt1.load_state_dict(s1)
+    opt2.load_state_dict(s2)
+    assert str(opt1.state_dict()) == str(opt2.state_dict()['base'])
+
+    # step still runnable
+    opt1.step()
+    opt2.step()
+
+
+if __name__ == '__main__':
+  test = unittest.main(exit=False)
+  sys.exit(0 if test.result.wasSuccessful() else 1)
diff --git a/torch_xla/distributed/zero_redundancy_optimizer.py b/torch_xla/distributed/zero_redundancy_optimizer.py
new file mode 100644
index 00000000000..12327f8fdd1
--- /dev/null
+++ b/torch_xla/distributed/zero_redundancy_optimizer.py
@@ -0,0 +1,162 @@
+from copy import deepcopy
+from typing import (
+    Any,
+    Iterator,
+    Optional,
+    Type,
+)
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+from torch.optim import Optimizer
+
+import torch_xla
+import torch_xla.core.xla_model as xm
+
+
+class ZeroRedundancyOptimizer(Optimizer):
+  r"""
+  ZeRO-1 wrapper. This class can wrap an arbitrary
+  :class:`torch.optim.Optimizer` and shards its state across ranks.
+
+  Arguments:
+    params (``Iterable``): an ``Iterable`` of :class:`torch.Tensor` s
+        or :class:`dict` s giving all parameters, which will be sharded
+        across ranks.
+    optimizer_class (:class:`torch.optim.Optimizer`): the class of the local
+        optimizer.
+    optimizer_dtype (:class:`torch.dtype`, optional): the desired data type
+        of the optimizer state. Default: ``torch.float32``
+    grad_clipping (bool, optional): enable (True) or disable (False) gradient
+        clipping. Default: True
+    max_norm (float, optional): max norm of the gradients, effective only
+        when ``grad_clipping`` is True. Default: 1.0
+    pin_layout (bool, optional): if ``True``, then pin the layout in the
+        collective ops (all_gather and reduce_scatter). See `xm.all_reduce`
+        for details on pinning layout. Default: True
+    **defaults: any trailing arguments, which are forwarded to the local
+        optimizer.
+
+  .. note:: This runs `step` on sharded parameters, which might lead to
+    accuracy disparities compared to using the original local optimizer:
+    optimizers that compute a global norm and a per-parameter norm
+    (e.g. LAMB) obtain different norm values from sharded parameters.
+  """
+
+  def __init__(
+      self,
+      params: Iterator[Tensor],
+      optimizer_class: Type[Optimizer],
+      optimizer_dtype: Optional[Any] = None,
+      grad_clipping: bool = True,
+      max_norm: Optional[float] = None,
+      pin_layout: bool = True,
+      **defaults: Any,
+  ):
+    self.params = list(params)
+    super().__init__(self.params, defaults)
+
+    self.device = self.params[0].device
+
+    self.rank = xm.get_ordinal()
+    self.world_size = xm.xrt_world_size()
+    self.cc_op_groups = [list(range(self.world_size))]
+
+    self.optimizer_dtype = optimizer_dtype if optimizer_dtype is not None else torch.float32
+    self.grad_clipping = grad_clipping
+    self.max_norm = max_norm if max_norm is not None else 1.0
+    self.pin_layout = pin_layout
+
+    # Shard parameters for use in optimizer
+    self.sharded_params = []
+    self._shard_parameters()
+    # Optimizer initialization
+    self.base_optimizer = optimizer_class(iter(self.sharded_params), **defaults)
+
+  def _shard_tensor(self, tensor: torch.Tensor):
+    """
+    Get the shard of the input tensor for the current rank.
+ """ + assert tensor.shape[0] % self.world_size == 0, "Not support padding now." + tensor = tensor.chunk(self.world_size)[self.rank] + return tensor + + def _shard_parameters(self): + """ + Shard all parameters. + """ + xm.unlazy(self.params) + for param in self.params: + shard_data = param.data.to(device="cpu") # move to cpu + shard_data = self._shard_tensor(shard_data) # slice it + if shard_data.dtype != self.optimizer_dtype: + shard_data = shard_data.to(dtype=self.optimizer_dtype) + shard_data = shard_data.to(device=self.device) # move to xla device + shard = nn.Parameter(shard_data, requires_grad=param.requires_grad) + self.sharded_params.append(shard) + + @torch.no_grad() + def step(self, closure=None, **kwargs): + """ + Performs a single optimizer step and syncs parameters across all ranks. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + # Reduce full gradients across ranks + # Assign gradient shards to the respective parameter shards + for param, shard in zip(self.params, self.sharded_params): + if param.grad is not None: + grad_shard = xm.reduce_scatter( + xm.REDUCE_SUM, + param.grad, + scale=1.0 / self.world_size, + scatter_dim=0, + shard_count=self.world_size, + pin_layout=self.pin_layout, + groups=self.cc_op_groups, + ) + + if grad_shard.dtype != self.optimizer_dtype: + grad_shard = grad_shard.to(dtype=self.optimizer_dtype) + shard.grad = grad_shard + + if self.grad_clipping: + # Update unscale/clip with sub partitions + torch.nn.utils.clip_grad_norm_( + self.sharded_params, max_norm=self.max_norm) + + # Step the wrapped optimizer + loss = self.base_optimizer.step(closure=closure, **kwargs) + # Remove shards' grads + self.base_optimizer.zero_grad(set_to_none=True) + + # All gather the new weights across the ranks and assign them to the full parameters + for param, shard in zip(self.params, self.sharded_params): + if param.grad is not None: + shard_data = shard.data + if param.dtype != self.optimizer_dtype: + shard_data = shard_data.to(dtype=param.dtype) + xm.all_gather( + shard_data, + dim=0, + output=param.data, + pin_layout=self.pin_layout, + groups=self.cc_op_groups, + ) + + return loss + + def state_dict(self): + state_dict = super().state_dict() + state_dict['base'] = self.base_optimizer.state_dict() + return state_dict + + def load_state_dict(self, state_dict): + state_dict = deepcopy(state_dict) + base = state_dict.pop('base') + super().load_state_dict(state_dict) + self.base_optimizer.load_state_dict(base)