diff --git a/docs/nanoset.md b/docs/nanoset.md index 9dce21b7..61393438 100644 --- a/docs/nanoset.md +++ b/docs/nanoset.md @@ -79,7 +79,7 @@ To work with `Nanosets`, we just need to configure 1 argument: Finally, to use the `Nanosets`, launch the training with [`run_train.py`](../run_train.py). ```shell -torchrun --nproc-per-node 8 run_train.py --config configs/config_nanoset.yaml +torchrun --nproc-per-node 1 run_train.py --config examples/config_nanoset.yaml ``` ## Under the hood diff --git a/examples/doremi/README.md b/examples/doremi/README.md index 5a726bd1..dfc9ea40 100644 --- a/examples/doremi/README.md +++ b/examples/doremi/README.md @@ -87,3 +87,7 @@ For evaluation, we do uniform sampling on the test set to evaluate a 2.5B model - 2.5B llama trained using the optimized weights: https://huggingface.co/nanotron/doremi-llama-2.5b-optimized-weights and the dataset: https://huggingface.co/datasets/nanotron/the-pile-for-doremi + +#### Thoughts + +DoReMi is useful if you don't initially have an idea of what a good distribution for your training data would be, or if you want a quick way to find a better baseline than the uniform distribution before tuning the data distribution by hand. In my previous experiments, DoReMi matched the pretraining performance of the data distribution used for Mamba training but couldn't outperform it. I suspect it doesn't work well when the differences are subtle, i.e. when the gap between your known best distribution and a better one isn't significant. diff --git a/examples/mamba/README.md b/examples/mamba/README.md index 5c31d07f..8eefa9c2 100644 --- a/examples/mamba/README.md +++ b/examples/mamba/README.md @@ -18,6 +18,20 @@ pip install -r requirements.txt > https://wandb.ai/bouteille/test/reports/Mamba-loss--Vmlldzo2OTgwNDM5 +## Bug related to nanotron +Encountered the following issue when running `train_mamba.sh`: +``` +causal_conv1d_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZNK3c1017SymbolicShapeMeta18init_is_contiguousEv +``` +Solved this by running: +``` +pip uninstall mamba-ssm +pip install causal_conv1d==1.1.1 +pip install mamba-ssm --no-cache-dir +``` +See https://github.com/state-spaces/mamba/issues/169 for details. + + ## Credits Credits to the following repositories from which the code was adapted: - https://github.com/state-spaces/mamba diff --git a/examples/mup/README.md b/examples/mup/README.md index c86850ca..ed94c1fb 100644 --- a/examples/mup/README.md +++ b/examples/mup/README.md @@ -32,3 +32,8 @@ We trained a 350m model with spectral µTransfer and standard parametrization us Please check the directory [[./examples/mup/configs]](/examples/mup/configs) for the configurations we used to reproduce the experiments. ![LLaMA](./assets/llama.png) + + +#### Thoughts + +We only used spectral µTransfer in experiments on an MLP-only model [link] and a 300m LLaMA [link] (links to the experiment configs are in this README). However, when we tested it on 1B/8B models (if we recall correctly), the loss blew up for reasons we haven't pinned down. So, we'd recommend trying μTransfer, not spectral μTransfer.
diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index 3c619a9e..4f8cc1c2 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -128,7 +128,7 @@ def __post_init__(self): class DataArgs: """Arguments related to the data and data files processing""" - dataset: Union[PretrainDatasetsArgs, NanosetDatasetsArgs, ChatDatasetsArgs] + dataset: Optional[Union[PretrainDatasetsArgs, NanosetDatasetsArgs, ChatDatasetsArgs]] seed: Optional[int] num_loading_workers: Optional[int] = 1 @@ -162,6 +162,7 @@ class CheckpointsArgs: checkpoints_path: Path checkpoint_interval: int save_initial_state: Optional[bool] = False + save_final_state: Optional[bool] = False resume_checkpoint_path: Optional[Path] = None checkpoints_path_is_shared_file_system: Optional[bool] = False diff --git a/src/nanotron/config/parallelism_config.py b/src/nanotron/config/parallelism_config.py index 5912425b..7f20ad99 100644 --- a/src/nanotron/config/parallelism_config.py +++ b/src/nanotron/config/parallelism_config.py @@ -23,6 +23,7 @@ class ParallelismArgs: pp_engine: Pipeline engine to use between "1f1b" and "afab" tp_mode: TP mode to use between "all_reduce" and "reduce_scatter": all_reduce is normal, reduce_scatter activate sequence parallelism tp_linear_async_communication: Whether to use async communication in TP linear layers + recompute_layer: Whether to recompute each Transformer layer to save memory. """ dp: int @@ -31,6 +32,9 @@ class ParallelismArgs: pp_engine: Optional[PipelineEngine] = None tp_mode: Optional[TensorParallelLinearMode] = None tp_linear_async_communication: Optional[bool] = None + recompute_layer: bool = False + + tp_recompute_allgather: bool = True expert_parallel_size: int = 1 diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 2411e5fa..28a2e30f 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -14,10 +14,11 @@ # limitations under the License. """PyTorch LLaMa model.""" -from typing import Dict, Optional, Union +from typing import Dict, Optional, Union, List import torch from torch import nn +from torch.utils.checkpoint import CheckpointFunction from nanotron import distributed as dist from nanotron import logging @@ -154,6 +155,7 @@ def __init__( bias=False, async_communication=tp_linear_async_communication, contiguous_chunks=gate_up_contiguous_chunks, + tp_recompute_allgather=parallel_config.tp_recompute_allgather, ) self.down_proj = TensorParallelRowLinear( config.intermediate_size, @@ -163,8 +165,7 @@ def __init__( bias=False, async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER, ) - # TODO @nouamane: why can't we torch.jit.script GLUActivation? - self.split_silu_mul = GLUActivation(config.hidden_act) + self.split_silu_mul = torch.compile(GLUActivation(config.hidden_act)) def forward(self, hidden_states): # [seq_length, batch_size, hidden_dim] merged_states = self.gate_up_proj(hidden_states) @@ -301,6 +302,7 @@ def __init__( bias=False, async_communication=tp_linear_async_communication, contiguous_chunks=qkv_contiguous_chunks, + tp_recompute_allgather=parallel_config.tp_recompute_allgather, ) # TODO(kunhao): We want to have only one version per device and not one version per layer. 
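For reference, a minimal sketch of how the new `recompute_layer`, `tp_recompute_allgather`, and `save_final_state` flags above might be set programmatically. The field names come from the dataclasses in this diff; the import path and the concrete values are illustrative assumptions (in the YAML configs these would presumably sit under the `parallelism:` and `checkpoints:` sections).

```python
# Sketch only: field names follow the dataclasses in this diff; the import path
# and values are assumptions for illustration.
from pathlib import Path

from nanotron.config import CheckpointsArgs, ParallelismArgs  # assumed export path

parallelism = ParallelismArgs(
    dp=2,
    pp=1,
    tp=4,
    recompute_layer=True,          # new: recompute each decoder layer during backward
    tp_recompute_allgather=True,   # new: re-gather TP inputs in backward instead of caching them
)

checkpoints = CheckpointsArgs(
    checkpoints_path=Path("checkpoints"),
    checkpoint_interval=1000,
    save_final_state=True,         # new: also write a checkpoint when training finishes
)
```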
self.rotary_embedding = RotaryEmbedding( @@ -591,12 +593,14 @@ def __init__( self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg) - - def forward( + + self.recompute_layer = parallel_config.recompute_layer + + def _core_forward( self, hidden_states: Union[torch.Tensor, TensorPointer], sequence_mask: Union[torch.Tensor, TensorPointer], - ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: + ) -> List[Union[torch.Tensor, TensorPointer]]: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -609,12 +613,31 @@ def forward( hidden_states = self.mlp(hidden_states=hidden_states)["hidden_states"] hidden_states = hidden_states + residual + return hidden_states, output["sequence_mask"] + + def _checkpointed_forward( + self, + hidden_states: torch.Tensor, + sequence_mask: torch.Tensor, + ) -> List[torch.Tensor]: + return CheckpointFunction.apply(self._core_forward, True, hidden_states, sequence_mask) + + def forward( + self, + hidden_states: Union[torch.Tensor, TensorPointer], + sequence_mask: Union[torch.Tensor, TensorPointer], + ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: + + if self.recompute_layer and not isinstance(hidden_states, TensorPointer): + hidden_states, sequence_mask = self._checkpointed_forward(hidden_states, sequence_mask) + else: + hidden_states, sequence_mask = self._core_forward(hidden_states, sequence_mask) + return { "hidden_states": hidden_states, - "sequence_mask": output["sequence_mask"], + "sequence_mask": sequence_mask, } - class Embedding(nn.Module, AttachableStore): def __init__(self, tp_pg: dist.ProcessGroup, config: LlamaConfig, parallel_config: Optional[ParallelismArgs]): super().__init__() @@ -716,6 +739,7 @@ def __init__( # TODO @thomasw21: refactor so that we store that default in a single place. 
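The `recompute_layer` path above wraps `_core_forward` in `CheckpointFunction.apply(fn, True, *args)`, i.e. reentrant activation checkpointing (the `True` is `preserve_rng_state`). A self-contained sketch of the same pattern on a generic module, plain PyTorch rather than nanotron code:

```python
# Generic illustration of layer recomputation (gradient checkpointing).
import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Sequential(
    torch.nn.Linear(1024, 4096),
    torch.nn.GELU(),
    torch.nn.Linear(4096, 1024),
)
x = torch.randn(8, 1024, requires_grad=True)

# forward: intermediate activations inside `layer` are not stored
y = checkpoint(layer, x, use_reentrant=True)

# backward: `layer` runs a second time to rebuild the activations it needs
y.sum().backward()
```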
"mode": self.tp_mode, "async_communication": tp_linear_async_communication, + "tp_recompute_allgather": parallel_config.tp_recompute_allgather, }, module_input_keys={"x"}, module_output_keys={"logits"}, diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py index 873d77df..bd41347a 100644 --- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py +++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py @@ -85,7 +85,8 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]): @staticmethod def backward(ctx, grad_output): group = ctx.group - return DifferentiableReduceScatterSum.apply(grad_output, group), None + out = DifferentiableReduceScatterSum.apply(grad_output, group) + return out, None class DifferentiableReduceScatterSum(torch.autograd.Function): @@ -113,7 +114,7 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]): *rest_size, device=tensor.device, dtype=tensor.dtype, - requires_grad=tensor.requires_grad, + requires_grad=False, ) dist.reduce_scatter_tensor(sharded_tensor, tensor, group=group, op=dist.ReduceOp.SUM) return sharded_tensor diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index fdef48ac..e2ee3a29 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -20,13 +20,12 @@ import nanotron.distributed as dist from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import ( - differentiable_all_gather, differentiable_all_reduce_sum, differentiable_identity, differentiable_reduce_scatter_sum, ) from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode -from nanotron.parallel.utils import assert_cuda_max_connections_set_to_1 +from nanotron.parallel.utils import MemoryBuffer, assert_cuda_max_connections_set_to_1 class _ShardedCrossEntropy(torch.autograd.Function): @@ -89,10 +88,10 @@ def forward( @staticmethod def backward(ctx, grad_output): - # Retreive tensors from the forward path. + # Retrieve tensors from the forward path. softmax, target_mask, masked_target_1d = ctx.saved_tensors - # All the inputs have softmax as thier gradient. + # All the inputs have softmax as their gradient. grad_input = softmax # For simplicity, work with the 2D gradient. 
sharded_hidden_size = softmax.size()[-1] @@ -121,10 +120,12 @@ class _ColumnLinearAsyncCommunication(torch.autograd.Function): @staticmethod @assert_cuda_max_connections_set_to_1 - def forward(ctx, tensor, weight, bias, group, tp_mode): + def forward(ctx, tensor, weight, bias, group, tp_mode, tp_recompute_allgather): ctx.use_bias = bias is not None ctx.tp_mode = tp_mode ctx.group = group + ctx.tp_recompute_allgather = tp_recompute_allgather + ctx.tensor_shape = tensor.size() if tp_mode is TensorParallelLinearMode.ALL_REDUCE: gathered_tensor = tensor @@ -141,7 +142,7 @@ def forward(ctx, tensor, weight, bias, group, tp_mode): # `tensor` can sometimes not be contiguous # https://cs.github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L317 tensor = tensor.contiguous() - ctx.save_for_backward(tensor, weight) + # ctx.save_for_backward(tensor, weight) # TODO @thomasw21: gather along another dimension sharded_batch_size, *intermediate_size, hidden_size = tensor.shape @@ -149,14 +150,19 @@ def forward(ctx, tensor, weight, bias, group, tp_mode): group = dist.distributed_c10d._get_default_group() gathered_batch_size = sharded_batch_size * group.size() - gathered_tensor = torch.empty( - gathered_batch_size, - *intermediate_size, - hidden_size, - device=tensor.device, - dtype=tensor.dtype, - requires_grad=tensor.requires_grad, - ) + if tp_recompute_allgather: + gathered_tensor = MemoryBuffer().get( + "allgather", (gathered_batch_size, *intermediate_size, hidden_size), dtype=tensor.dtype + ) + else: + gathered_tensor = torch.empty( + gathered_batch_size, + *intermediate_size, + hidden_size, + device=tensor.device, + dtype=tensor.dtype, + requires_grad=False, + ) handle = dist.all_gather_into_tensor(gathered_tensor, tensor, group=group, async_op=True) @@ -204,6 +210,10 @@ def forward(ctx, tensor, weight, bias, group, tp_mode): # Wait communication handle.wait() + if tp_recompute_allgather: + ctx.save_for_backward(tensor, weight) + else: + ctx.save_for_backward(gathered_tensor, weight) # Compute all the other shards that are obtained from AllGather # weights: w0 w1 w2 w3 @@ -261,8 +271,8 @@ def backward(ctx, grad_output): use_bias = ctx.use_bias tp_mode = ctx.tp_mode - handle: Optional[dist.Work] = None - if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: + handle1: Optional[dist.Work] = None + if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER and ctx.tp_recompute_allgather: # TODO @thomasw21: gather along another dimension sharded_batch_size, *rest_size = tensor.shape if group is None: @@ -273,14 +283,10 @@ def backward(ctx, grad_output): else: unsharded_batch_size = sharded_batch_size * group.size() - unsharded_tensor = torch.empty( - unsharded_batch_size, - *rest_size, - device=tensor.device, - dtype=tensor.dtype, - requires_grad=False, + unsharded_tensor = MemoryBuffer().get( + "allgather", (unsharded_batch_size, *rest_size), dtype=tensor.dtype ) - handle = dist.all_gather_into_tensor(unsharded_tensor, tensor, group=group, async_op=True) + handle1 = dist.all_gather_into_tensor(unsharded_tensor, tensor, group=group, async_op=True) # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the # gather is scheduled before the tensor gradient computation total_tensor = unsharded_tensor @@ -289,9 +295,6 @@ def backward(ctx, grad_output): grad_tensor = grad_output.matmul(weight) - if handle is not None: - handle.wait() - # Doing gather + slicing during the NeMo forward pass can make this tensor # not be contiguous. 
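What `tp_recompute_allgather` trades off in `_ColumnLinearAsyncCommunication`: with the flag on, only the sequence-sharded input is saved for backward and the gathered tensor lives in the shared `MemoryBuffer`; with it off, the full gathered tensor is saved. A back-of-the-envelope comparison with made-up sizes:

```python
# Illustrative numbers only; shapes follow a [seq, batch, hidden] activation layout.
seq_len, micro_batch, hidden, tp = 4096, 1, 4096, 8
bytes_per_element = 2  # bf16

saved_with_recompute = (seq_len // tp) * micro_batch * hidden * bytes_per_element  # sharded input only
saved_without_recompute = seq_len * micro_batch * hidden * bytes_per_element       # full gathered input

print(f"kept for backward, recompute on : {saved_with_recompute / 2**20:.1f} MiB")
print(f"kept for backward, recompute off: {saved_without_recompute / 2**20:.1f} MiB")
```

The price of the smaller footprint is one extra all-gather per column linear during the backward pass.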
PyTorch only checks if the tensor is contiguous, and only # clones it if it's not contiguous: @@ -303,41 +306,128 @@ def backward(ctx, grad_output): grad_output = grad_output.view(math.prod(grad_output_first_dims), grad_output_last_dim) total_tensor = total_tensor.view(math.prod(total_tensor_first_dims), total_tensor_last_dim) - handle: Optional[dist.Work] = None + handle2: Optional[dist.Work] = None if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: if group.size() == 1: sub_grad_tensor = grad_tensor else: sub_grad_tensor = torch.empty( - tensor.shape, dtype=grad_tensor.dtype, device=grad_tensor.device, requires_grad=False + ctx.tensor_shape, dtype=grad_tensor.dtype, device=grad_tensor.device, requires_grad=False ) # reduce_scatter - handle = dist.reduce_scatter_tensor(sub_grad_tensor, grad_tensor, group=group, async_op=True) + handle2 = dist.reduce_scatter_tensor(sub_grad_tensor, grad_tensor, group=group, async_op=True) # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the # reduce scatter is scheduled before the weight gradient computation elif tp_mode is TensorParallelLinearMode.ALL_REDUCE: # Asynchronous all-reduce - handle = dist.all_reduce(grad_tensor, group=group, async_op=True) + handle2 = dist.all_reduce(grad_tensor, group=group, async_op=True) # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the # all-reduce is scheduled before the weight gradient computation else: raise ValueError() + grad_bias = grad_output.sum(dim=0) if use_bias else None + + if handle1 is not None: + handle1.wait() + # TODO @thomasw21: This sounds like we don't have the optimal physical layout grad_weight = grad_output.t().matmul(total_tensor) - grad_bias = grad_output.sum(dim=0) if use_bias else None - if handle is not None: - handle.wait() + if handle2 is not None: + handle2.wait() if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: - return sub_grad_tensor, grad_weight, grad_bias, None, None + return sub_grad_tensor, grad_weight, grad_bias, None, None, None elif tp_mode is TensorParallelLinearMode.ALL_REDUCE: - return grad_tensor, grad_weight, grad_bias, None, None + return grad_tensor, grad_weight, grad_bias, None, None, None else: raise ValueError(f"Got unexpected mode: {tp_mode}.") +class _ColumnLinearNoAsyncCommunicationReduceScatterMode(torch.autograd.Function): + """ + Column linear with memory_buffer for the allgather, context parallel + enabled (i.e. tp_mode = TensorParallelLinearMode.REDUCE_SCATTER) and + async communication disabled. + """ + + @staticmethod + def forward( + ctx, + input: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + group: dist.ProcessGroup, + tp_recompute_allgather: bool, + ): + + # Do allgather. + sharded_batch_size, *rest_size = input.shape + unsharded_batch_size = sharded_batch_size * group.size() + if group.size() == 1: + total_input = input.contiguous() + elif tp_recompute_allgather: + total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype) + dist.all_gather_into_tensor(total_input, input.contiguous(), group=group) + else: + total_input = torch.empty(unsharded_batch_size, *rest_size, dtype=input.dtype, device=input.device) + dist.all_gather_into_tensor(total_input, input.contiguous(), group=group) + + # Prepare context. 
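`_ColumnLinearNoAsyncCommunicationReduceScatterMode.forward` above is an all-gather followed by a plain `F.linear` with the output-sharded weight. The sharding identity it relies on, checked on CPU with ordinary tensors:

```python
import torch

torch.manual_seed(0)
tp, batch, in_features, out_features = 4, 5, 16, 32
x = torch.randn(batch, in_features)
w = torch.randn(out_features, in_features)

# column parallelism: the weight is sharded along the output dimension, and every
# "rank" applies its shard to the (gathered) input
w_shards = w.chunk(tp, dim=0)
partial_outputs = [torch.nn.functional.linear(x, ws) for ws in w_shards]

# concatenating the per-rank outputs reproduces the unsharded linear
torch.testing.assert_close(torch.cat(partial_outputs, dim=-1), torch.nn.functional.linear(x, w))
```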
+ ctx.group = group + ctx.tp_recompute_allgather = tp_recompute_allgather + ctx.input_size = input.shape + if tp_recompute_allgather: + ctx.save_for_backward(input, weight, bias) + else: + ctx.save_for_backward(total_input, weight, bias) + + # Get linear output. + out = F.linear(total_input, weight, bias) + return out + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + # Either allgather the inputs again or get them from context. + group = ctx.group + tp_recompute_allgather = ctx.tp_recompute_allgather + input_size = ctx.input_size + if group.size() == 1 or not tp_recompute_allgather: + total_input, weight, bias = ctx.saved_tensors + else: + input, weight, bias = ctx.saved_tensors + sharded_batch_size, *rest_size = input.shape + total_input = sharded_batch_size * group.size() + unsharded_batch_size = sharded_batch_size * group.size() + total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype) + dist.all_gather_into_tensor(total_input, input.contiguous(), group=group) + + # Convert the tensor shapes to 2D for execution compatibility + grad_output = grad_output.contiguous() + grad_output_first_dims, grad_output_last_dim = grad_output.shape[:-1], grad_output.shape[-1] + total_input_first_dims, total_input_last_dim = total_input.shape[:-1], total_input.shape[-1] + grad_output = grad_output.view(math.prod(grad_output_first_dims), grad_output_last_dim) + total_input = total_input.view(math.prod(total_input_first_dims), total_input_last_dim) + + # Compute gradients. + grad_weight = grad_output.T @ total_input + grad_input = grad_output @ weight + if group.size() == 1: + sub_grad_input = grad_input + else: + # Seems that `reduce_scatter` need contiguous tensors: https://github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L305 + # We set grad_input to be contiguous in case it isn't already. 
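The backward above computes the gradients by hand, `grad_weight = grad_output.T @ total_input` and `grad_input = grad_output @ weight`, before reduce-scattering `grad_input`. A quick CPU check of those two formulas against autograd:

```python
import torch

batch, in_f, out_f = 6, 8, 4
x = torch.randn(batch, in_f, requires_grad=True)
w = torch.randn(out_f, in_f, requires_grad=True)

out = torch.nn.functional.linear(x, w)
grad_output = torch.randn_like(out)
out.backward(grad_output)

# same formulas as in the manual backward
torch.testing.assert_close(w.grad, grad_output.t() @ x)
torch.testing.assert_close(x.grad, grad_output @ w)
```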
+ grad_input = grad_input.contiguous() + sub_grad_input = torch.empty( + input_size, dtype=total_input.dtype, device=total_input.device, requires_grad=False + ) + dist.reduce_scatter_tensor(sub_grad_input, grad_input, group=group, op=dist.ReduceOp.SUM) + grad_bias = torch.sum(grad_output, dim=0) if bias is not None else None + + return sub_grad_input, grad_weight, grad_bias, None, None + + def column_linear( input: torch.Tensor, weight: torch.Tensor, @@ -345,18 +435,19 @@ def column_linear( group: dist.ProcessGroup, tp_mode: TensorParallelLinearMode, async_communication: bool, + tp_recompute_allgather: bool = True, ): if async_communication: - return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode) + return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode, tp_recompute_allgather) if tp_mode is TensorParallelLinearMode.ALL_REDUCE: input = differentiable_identity(input, group=group) - elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: - input = differentiable_all_gather(input, group=group) - else: - raise ValueError(f"Got unexpected mode: {tp_mode}.") - - return F.linear(input, weight, bias) + return F.linear(input, weight, bias) + if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: + return _ColumnLinearNoAsyncCommunicationReduceScatterMode.apply( + input, weight, bias, group, tp_recompute_allgather + ) + raise ValueError(f"Got unexpected mode: {tp_mode}.") class _RowLinearAsyncCommunication(torch.autograd.Function): @@ -387,8 +478,7 @@ def backward(ctx, grad_output): group = ctx.group use_bias = ctx.use_bias - handle_0: Optional[dist.Work] = None - handle_1: Optional[dist.Work] = None + handle: Optional[dist.Work] = None # TODO @thomasw21: gather along another dimension sharded_batch_size, *rest_size = grad_output.shape @@ -398,12 +488,8 @@ def backward(ctx, grad_output): else: unsharded_batch_size = sharded_batch_size * group.size() - total_grad_output = torch.empty( - unsharded_batch_size, - *rest_size, - device=grad_output.device, - dtype=grad_output.dtype, - requires_grad=False, + total_grad_output = MemoryBuffer().get( + "allgather2", (unsharded_batch_size, *rest_size), dtype=tensor.dtype ) # Doing gather + slicing during the NeMo forward pass can make this tensor @@ -412,31 +498,69 @@ def backward(ctx, grad_output): # https://github.com/pytorch/pytorch/blob/c47cf9bc7f9e02f649ab4ed53fe4d35732c92ab6/torch/_refs/__init__.py#L2761 grad_output = grad_output.contiguous() - handle_0 = dist.all_gather_into_tensor(total_grad_output, grad_output, group=group, async_op=True) - - grad_tensor = grad_output.matmul(weight) - - # wait for the first all_gather to finish before starting the second all_gather - if handle_0 is not None: - handle_0.wait() + handle = dist.all_gather_into_tensor(total_grad_output, grad_output, group=group, async_op=True) - # TODO @thomasw21: gather along another dimension - sharded_batch_size, *rest_size = grad_tensor.shape + # total_grad_output: [b, s, h_out] + # weight: [h_out, h_in/n] + # total_grad_tensor: [b, s, h_in/n] + # grad_output: [b/n, s, h_out] + sharded_batch_size, *rest_size_grad_output = grad_output.shape + rest_size_grad_tensor = rest_size_grad_output[:-1] + [weight.shape[1]] if group.size() == 1: - total_grad_tensor = grad_tensor + total_grad_tensor = grad_output.matmul(weight) else: unsharded_batch_size = sharded_batch_size * group.size() - total_grad_tensor = torch.empty( unsharded_batch_size, - *rest_size, - device=grad_tensor.device, - dtype=grad_tensor.dtype, + *rest_size_grad_tensor, + 
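For contrast with the column-parallel case, `_RowLinearAsyncCommunication` shards the weight along the input dimension, so each rank produces a partial, full-width output that has to be summed (all-reduce, or reduce-scatter in sequence-parallel mode). A CPU sketch of that identity:

```python
import torch

torch.manual_seed(0)
tp, batch, in_features, out_features = 4, 5, 16, 8
x = torch.randn(batch, in_features)
w = torch.randn(out_features, in_features)

# row parallelism: weight and input are both sharded along the input dimension
x_shards = x.chunk(tp, dim=-1)
w_shards = w.chunk(tp, dim=1)
partials = [torch.nn.functional.linear(xs, ws) for xs, ws in zip(x_shards, w_shards)]

# summing the partials is what the all-reduce (or reduce-scatter) does
torch.testing.assert_close(torch.stack(partials).sum(dim=0), torch.nn.functional.linear(x, w))
```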
device=grad_output.device, + dtype=grad_output.dtype, requires_grad=False, ) + before_shard_grad_tensor, same_device_shard_grad_tensor, after_shard_grad_tensor = torch.split( + total_grad_tensor, + split_size_or_sections=[ + sharded_batch_size * dist.get_rank(group), + sharded_batch_size, + sharded_batch_size * (group.size() - dist.get_rank(group) - 1), + ], + dim=0, + ) + # compute local shard + torch.mm( + input=grad_output.view(-1, grad_output.shape[-1]), + mat2=weight, + out=same_device_shard_grad_tensor.view(-1, weight.shape[1]), + ) - handle_1 = dist.all_gather_into_tensor(total_grad_tensor, grad_tensor, group=group, async_op=True) + if handle is not None: + handle.wait() + + before_shard_grad_output, _, after_shard_grad_output = torch.split( + total_grad_output, + split_size_or_sections=[ + sharded_batch_size * dist.get_rank(group), + sharded_batch_size, + sharded_batch_size * (group.size() - dist.get_rank(group) - 1), + ], + dim=0, + ) + + # before shard compute + if before_shard_grad_tensor.numel() > 0: + torch.mm( + input=before_shard_grad_output.view(-1, before_shard_grad_output.shape[-1]), + mat2=weight, + out=before_shard_grad_tensor.view(-1, weight.shape[1]), + ) + # after shard compute + if after_shard_grad_tensor.numel() > 0: + torch.mm( + input=after_shard_grad_output.view(-1, after_shard_grad_output.shape[-1]), + mat2=weight, + out=after_shard_grad_tensor.view(-1, weight.shape[1]), + ) # Convert the tensor shapes to 2D for execution compatibility tensor = tensor.contiguous() @@ -454,9 +578,6 @@ def backward(ctx, grad_output): grad_weight = total_grad_output.t().matmul(tensor) grad_bias = total_grad_output.sum(dim=0) if use_bias else None - if handle_1 is not None: - handle_1.wait() - return total_grad_tensor, grad_weight, grad_bias, None, None diff --git a/src/nanotron/parallel/tensor_parallel/nn.py b/src/nanotron/parallel/tensor_parallel/nn.py index 40e89968..4c7325cd 100644 --- a/src/nanotron/parallel/tensor_parallel/nn.py +++ b/src/nanotron/parallel/tensor_parallel/nn.py @@ -51,6 +51,7 @@ def __init__( dtype=None, async_communication: bool = False, contiguous_chunks: Optional[Tuple[int, ...]] = None, + tp_recompute_allgather: bool = True, ): self.pg = pg self.world_size = pg.size() @@ -59,6 +60,7 @@ def __init__( self.in_features = in_features self.out_features = out_features // self.world_size + self.tp_recompute_allgather = tp_recompute_allgather super().__init__( in_features=self.in_features, @@ -91,6 +93,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: group=self.pg, tp_mode=self.mode, async_communication=self.async_communication, + tp_recompute_allgather=self.tp_recompute_allgather, ) def extra_repr(self) -> str: diff --git a/src/nanotron/parallel/utils.py b/src/nanotron/parallel/utils.py index b9ac12ae..f694b0e6 100644 --- a/src/nanotron/parallel/utils.py +++ b/src/nanotron/parallel/utils.py @@ -1,11 +1,31 @@ import functools +import operator import os +import torch from torch import nn from nanotron import distributed as dist from nanotron.parallel import ParallelContext from nanotron.parallel.tied_parameters import get_tied_id_to_param +from nanotron.utils import Singleton + + +class MemoryBuffer(metaclass=Singleton): + """ + Global memory buffer to store intermediate activations that need not to be cached for the backward pass. 
+ """ + + def __init__(self): + self.buffer = {} + + def get(self, name: str, shape: tuple[int], dtype: torch.dtype = torch.bfloat16) -> torch.Tensor: + required_numel = functools.reduce(operator.mul, shape, 1) + if (name, dtype) not in self.buffer or self.buffer[name, dtype].numel() < required_numel: + self.buffer[name, dtype] = torch.empty( + required_numel, dtype=dtype, device=torch.cuda.current_device(), requires_grad=False + ) + return self.buffer[name, dtype][:required_numel].view(shape) def assert_cuda_max_connections_set_to_1(func): diff --git a/src/nanotron/serialize/main.py b/src/nanotron/serialize/main.py index 286008ac..346ad573 100644 --- a/src/nanotron/serialize/main.py +++ b/src/nanotron/serialize/main.py @@ -236,6 +236,7 @@ def load( load_optimizer(optimizer=optimizer, parallel_context=parallel_context, root_folder=root_folder) load_lr_scheduler( lr_scheduler=lr_scheduler, + parallel_context=parallel_context, root_folder=root_folder, ) return checkpoint_metadata diff --git a/src/nanotron/serialize/optimizer.py b/src/nanotron/serialize/optimizer.py index 68a3b1a0..f11210da 100644 --- a/src/nanotron/serialize/optimizer.py +++ b/src/nanotron/serialize/optimizer.py @@ -30,9 +30,9 @@ def optimizer_filename(parallel_context: ParallelContext, is_zero: bool): return f"{ObjectType.OPTIMIZER.value}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}_tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}_exp-{dist.get_rank(parallel_context.expert_pg)}-of-{parallel_context.expert_parallel_size}.pt" -def lr_scheduler_filename(): +def lr_scheduler_filename(parallel_context: ParallelContext): """The lr_scheduler is the same for all processes.""" - return f"{ObjectType.LR_SCHEDULER.value}.pt" + return f"{ObjectType.LR_SCHEDULER.value}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}_tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}_exp-{dist.get_rank(parallel_context.expert_pg)}-of-{parallel_context.expert_parallel_size}.pt" def save_optimizer( @@ -109,9 +109,6 @@ def save_lr_scheduler( root_folder: Path, ): """Saves lr scheduler states""" - if dist.get_rank(parallel_context.world_pg) > 0: - # Only WORLD-RANK 0 saves the lr scheduler state - return root_folder = root_folder / "lr_scheduler" root_folder.mkdir(exist_ok=True, parents=True) @@ -119,7 +116,7 @@ def save_lr_scheduler( # We dump the optimizer state using `torch.save` torch.save( lr_scheduler.state_dict(), - root_folder / lr_scheduler_filename(), + root_folder / lr_scheduler_filename(parallel_context), ) @@ -313,9 +310,10 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) - def load_lr_scheduler( lr_scheduler, + parallel_context: ParallelContext, root_folder: Path, ): root_folder = root_folder / "lr_scheduler" - state_dict = torch.load(root_folder / lr_scheduler_filename()) + state_dict = torch.load(root_folder / lr_scheduler_filename(parallel_context)) lr_scheduler.load_state_dict(state_dict) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 31b63cbb..3994ddd3 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -209,6 +209,7 @@ def __init__( if self.init_checkpoint_path is not None: load_lr_scheduler( lr_scheduler=self.lr_scheduler, + parallel_context=self.parallel_context, root_folder=self.init_checkpoint_path, ) @@ -461,6 +462,9 @@ def train( dist.barrier() # let's wait for everyone before leaving + if self.config.checkpoints.save_final_state: + 
self.save_checkpoint() + self.post_training() def training_step( @@ -887,8 +891,8 @@ def save_checkpoint(self) -> Path: ), # We only save the weights on DP==0 should_save_optimizer=True, should_save_lr_scheduler=bool( - dist.get_rank(self.parallel_context.world_pg) == 0 - ), # We only save the lr_scheduler on world_rank==0 + dist.get_rank(self.parallel_context.dp_pg) == 0 + ), # We only save the lr_scheduler on DP==0 should_save_config=bool( dist.get_rank(self.parallel_context.world_pg) == 0 ), # We only save the config on world_rank==0 diff --git a/src/nanotron/utils.py b/src/nanotron/utils.py index 14fe1ca8..b3831801 100644 --- a/src/nanotron/utils.py +++ b/src/nanotron/utils.py @@ -1,11 +1,10 @@ import functools import inspect -import math import os import random import socket from contextlib import ExitStack, contextmanager -from typing import Callable, ContextManager, List, Optional +from typing import ContextManager, List, Optional import torch from packaging import version @@ -15,6 +14,25 @@ from nanotron import distributed as dist +class Singleton(type): + """ + Singleton metaclass. + Create objects using this class as the metaclass to enable singleton behaviour. + For instance: + ``` + class Logger(metaclass=Singleton): + ... + ``` + """ + + _instances = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) + return cls._instances[cls] + + class ContextManagers: """ Wrapper for `contextlib.ExitStack` which enters a collection of context managers. Adaptation of `ContextManagers` @@ -52,7 +70,7 @@ def main_rank_first(group: dist.ProcessGroup): @contextmanager def local_ranks_zero_first(group: Optional[dist.ProcessGroup] = None): """Context manager that executes the code in the context with all the local rank zero of the group going first. - Usefull to run only once per node first (e.g. to create local files, etc) + Useful to run only once per node first (e.g. to create local files, etc) """ is_main = int(os.environ.get("LOCAL_RANK", 0)) == 0 if is_main: @@ -123,6 +141,7 @@ def get_untyped_storage(tensor: torch.Tensor) -> torch.UntypedStorage: else: return tensor.storage().untyped() + def tensor_from_untyped_storage(untyped_storage: torch.UntypedStorage, dtype: torch.dtype): # TODO @thomasw21: Figure out what's the best Pytorch way of building a tensor from a storage. 
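A minimal demonstration of the `Singleton` metaclass added to `nanotron/utils.py`, following the `Logger` example from its own docstring:

```python
from nanotron.utils import Singleton


class Logger(metaclass=Singleton):
    def __init__(self):
        self.lines = []


assert Logger() is Logger()          # every call returns the same instance
Logger().lines.append("hello")
assert Logger().lines == ["hello"]   # state is shared because only one object exists
```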
device = untyped_storage.device diff --git a/tests/test_tensor_parallel.py b/tests/test_tensor_parallel.py index 127ba2fa..16008eaa 100644 --- a/tests/test_tensor_parallel.py +++ b/tests/test_tensor_parallel.py @@ -18,17 +18,30 @@ @pytest.mark.parametrize("tp,dp,pp", [pytest.param(i, 1, 1) for i in range(1, min(4, available_gpus()) + 1)]) @pytest.mark.parametrize("tp_mode", list(TensorParallelLinearMode)) @pytest.mark.parametrize("async_communication", [False, True]) +@pytest.mark.parametrize("tp_recompute_allgather", [False, True]) @rerun_if_address_is_in_use() -def test_column_linear(tp: int, dp: int, pp: int, tp_mode: TensorParallelLinearMode, async_communication: bool): +def test_column_linear( + tp: int, + dp: int, + pp: int, + tp_mode: TensorParallelLinearMode, + async_communication: bool, + tp_recompute_allgather: bool, +): if tp_mode is TensorParallelLinearMode.ALL_REDUCE and async_communication: pytest.skip("ALL_REDUCE mode does not support async communication") + if tp_mode is TensorParallelLinearMode.ALL_REDUCE and tp_recompute_allgather: + pytest.skip("ALL_REDUCE mode is unaffected by tp_recompute_allgather") init_distributed(tp=tp, dp=dp, pp=pp)(_test_column_linear)( - tp_mode=tp_mode, async_communication=async_communication + tp_mode=tp_mode, async_communication=async_communication, tp_recompute_allgather=tp_recompute_allgather ) def _test_column_linear( - parallel_context: ParallelContext, tp_mode: TensorParallelLinearMode, async_communication: bool + parallel_context: ParallelContext, + tp_mode: TensorParallelLinearMode, + async_communication: bool, + tp_recompute_allgather: bool, ): if async_communication: os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" @@ -44,6 +57,7 @@ def _test_column_linear( mode=tp_mode, device="cuda", async_communication=async_communication, + tp_recompute_allgather=tp_recompute_allgather, ) # Un-sharded @@ -86,7 +100,7 @@ def _test_column_linear( random_input = sharded_random_input else: ValueError(f"Unsupported mode: {tp_mode}") - # It's important that `random_input` and `sharded_random_input` are two seperate tensors with seperate storage + # It's important that `random_input` and `sharded_random_input` are two separate tensors with separate storage sharded_random_input = sharded_random_input.clone() random_input.requires_grad = True sharded_random_input.requires_grad = True @@ -150,15 +164,32 @@ def _test_column_linear( @pytest.mark.parametrize("tp,dp,pp", [pytest.param(i, 1, 1) for i in range(1, min(4, available_gpus()) + 1)]) @pytest.mark.parametrize("tp_mode", list(TensorParallelLinearMode)) @pytest.mark.parametrize("async_communication", [False, True]) +@pytest.mark.parametrize("tp_recompute_allgather", [False, True]) @rerun_if_address_is_in_use() -def test_row_linear(tp: int, dp: int, pp: int, tp_mode: TensorParallelLinearMode, async_communication: bool): +def test_row_linear( + tp: int, + dp: int, + pp: int, + tp_mode: TensorParallelLinearMode, + async_communication: bool, + tp_recompute_allgather: bool, +): if tp_mode is TensorParallelLinearMode.ALL_REDUCE and async_communication: pytest.skip("ALL_REDUCE mode does not support async communication") + if tp_mode is TensorParallelLinearMode.ALL_REDUCE and tp_recompute_allgather: + pytest.skip("ALL_REDUCE mode is not affected by tp_recompute_allgather") - init_distributed(tp=tp, dp=dp, pp=pp)(_test_row_linear)(tp_mode=tp_mode, async_communication=async_communication) + init_distributed(tp=tp, dp=dp, pp=pp)(_test_row_linear)( + tp_mode=tp_mode, async_communication=async_communication, 
tp_recompute_allgather=tp_recompute_allgather + ) -def _test_row_linear(parallel_context: ParallelContext, tp_mode: TensorParallelLinearMode, async_communication: bool): +def _test_row_linear( + parallel_context: ParallelContext, + tp_mode: TensorParallelLinearMode, + async_communication: bool, + tp_recompute_allgather: bool, +): if async_communication: os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" out_features = 3 @@ -208,14 +239,19 @@ def _test_row_linear(parallel_context: ParallelContext, tp_mode: TensorParallelL random_input = torch.randn(batch_size, in_features, device="cuda") # synchronize random_input across tp dist.all_reduce(random_input, op=dist.ReduceOp.AVG, group=parallel_context.tp_pg) - + random_input.requires_grad = True # Row linear receives as input sharded input - random_sharded_input = random_input[ - :, - dist.get_rank(parallel_context.tp_pg) - * in_features_per_rank : (dist.get_rank(parallel_context.tp_pg) + 1) - * in_features_per_rank, - ] + random_sharded_input = ( + random_input[ + :, + dist.get_rank(parallel_context.tp_pg) + * in_features_per_rank : (dist.get_rank(parallel_context.tp_pg) + 1) + * in_features_per_rank, + ] + .detach() + .clone() + ) + random_sharded_input.requires_grad = True # Test that we get the same output after forward pass # TODO @kunhao: We may want to have our custom error type @@ -261,6 +297,16 @@ def _test_row_linear(parallel_context: ParallelContext, tp_mode: TensorParallelL else: assert row_linear.bias is None + torch.testing.assert_close( + random_sharded_input.grad, + random_input.grad[ + :, + dist.get_rank(parallel_context.tp_pg) + * in_features_per_rank : (dist.get_rank(parallel_context.tp_pg) + 1) + * in_features_per_rank, + ], + ) + parallel_context.destroy()