@experimental_class('LossPerpVsContextLengthLogger')
class LossPerpVsContextLengthLogger(Callback):
    """Logs the average loss and perplexity for every context length.

    Note: Currently only works with MLFlow logger.

    Args:
        log_batch_interval (int): The interval for logging. Currently logging takes longer because MLFlow downloads the table, appends rows to it, and then re-uploads it. Once this is fixed, log_batch_interval will be removed and this will always log as soon as the metric is computed.
        compute_batch_interval (int): The interval for computing the metric.
        ignore_index (int): Specifies a target value that is ignored for computing loss.

    Raises:
        ValueError: If ``compute_batch_interval`` is not positive, or if it is
            larger than ``log_batch_interval``.
    """

    def __init__(
        self,
        log_batch_interval: int,
        compute_batch_interval: int,
        ignore_index: int = -100,
    ):
        # Both intervals are used as modulo divisors below; reject values that
        # would raise ZeroDivisionError (or behave oddly) much later in training.
        if compute_batch_interval <= 0:
            raise ValueError(
                'compute_batch_interval must be a positive integer for LossPerpVsContextLengthLogger.',
            )
        if compute_batch_interval > log_batch_interval:
            raise ValueError(
                'log_batch_interval is shorter than the compute_batch_interval for LossPerpVsContextLengthLogger.',
            )
        self.log_batch_interval = log_batch_interval
        self.compute_batch_interval = compute_batch_interval
        self.ignore_index = ignore_index
        # Maps metric name -> list of rows buffered since the last table upload.
        self.metric_dict = {}
        self.loss_perp_v_len = LossPerpVLen(ignore_index)

    def init(self, state: State, logger: Logger) -> None:
        """Validate that the model and logger setup is supported.

        Raises:
            ValueError: If the model is not a ``ComposerMPTCausalLM`` or its
                ``shift_labels`` attribute is ``None``.
            NotImplementedError: If no ``MLFlowLogger`` is among the logger
                destinations.
        """
        if not isinstance(state.model, ComposerMPTCausalLM):
            raise ValueError(
                'LossPerpVsContextLengthLogger only supported for ComposerMPTCausalLM models.',
            )
        if state.model.shift_labels is None:
            raise ValueError(
                'state.model.shift_labels should be set for LossPerpVsContextLengthLogger.',
            )
        if all(
            not isinstance(destination, MLFlowLogger)
            for destination in logger.destinations
        ):
            raise NotImplementedError(
                'Did not find MLflow in the list of loggers. LossPerpVsContextLengthLogger is only implemented for the MLflow logger.',
            )

    def after_backward(self, state: State, logger: Logger) -> None:
        """Feed the current batch's labels and logits into the metric.

        Only runs on batches whose index is a multiple of
        ``compute_batch_interval``.

        Raises:
            Exception: If ``state.outputs`` is neither a mapping nor a tensor.
            ValueError: If labels and logits disagree on the sequence length.
        """
        if state.timestamp.batch.value % self.compute_batch_interval != 0:
            return

        sequence_id = state.batch['sequence_id'
                                 ] if 'sequence_id' in state.batch else None
        labels = state.batch['labels']
        if state.model.shift_labels:
            # Build the shifted targets in a fresh tensor. Writing into
            # ``state.batch['labels']`` would silently corrupt the batch for
            # every later consumer (metrics, other callbacks).
            shifted_labels = torch.full_like(labels, self.ignore_index)
            shifted_labels[:, :-1] = labels[:, 1:]
            labels = shifted_labels

        seq_parallel_world_size = getattr(
            state.model.model.transformer,
            'seq_parallel_world_size',
            1,
        )
        seq_parallel_rank = state.model.model.transformer.seq_parallel_rank if seq_parallel_world_size > 1 else 0

        if isinstance(state.outputs, Mapping):
            logits = state.outputs['logits']  # type: ignore
        elif isinstance(state.outputs, torch.Tensor):
            logits = state.outputs
        else:
            raise Exception(
                f'Type {type(state.outputs)} for the output is unsupported.',
            )

        if labels.shape[1] != logits.shape[1]:
            raise ValueError(
                f'The length of labels, {labels.shape[1]=} does not match the length of logits {logits.shape[1]=}.',
            )

        labels, logits = self.preprocess_metric_inputs(
            sequence_id,
            labels,
            logits,
            seq_parallel_world_size,
            seq_parallel_rank,
        )

        self.loss_perp_v_len.update(
            labels,
            logits,
            sequence_id,
            state.model.loss_fn,
        )

    def batch_end(self, state: State, logger: Logger) -> None:
        """Buffer the current metric values and periodically upload them.

        ``compute()`` is invoked on every rank; only global rank 0 buffers
        rows and writes tables. Tables are uploaded on batches that are
        multiples of both ``compute_batch_interval`` and
        ``log_batch_interval``, after which the buffer is cleared.
        """
        batch_index = state.timestamp.batch.value
        if batch_index % self.compute_batch_interval != 0:
            return

        current_metric_dict = self.loss_perp_v_len.compute()
        if dist.get_global_rank() != 0:
            return

        for k, v in current_metric_dict.items():
            row = v.tolist()
            # Add the current batch index as the last column
            row.append(batch_index)
            if k not in self.metric_dict:
                self.metric_dict[k] = []
            self.metric_dict[k].append(row)

        if batch_index % self.log_batch_interval != 0:
            return

        for k, v in self.metric_dict.items():
            # len(v[0]) - 1 because the last column is the batch index
            columns = [f'context_length_{i}' for i in range(len(v[0]) - 1)]
            columns.append('batch_index')
            for destination in logger.destinations:
                if isinstance(destination, MLFlowLogger):
                    destination.log_table(
                        columns=columns,
                        rows=v,
                        name=f'metrics/train/LossPerpVLenTable/{k}',
                        step=batch_index,
                    )
        self.metric_dict = {}

    def preprocess_metric_inputs(
        self,
        sequence_id: Optional[torch.Tensor],
        labels: torch.Tensor,
        logits: torch.Tensor,
        seq_parallel_world_size: int,
        seq_parallel_rank: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the labels and logits unchanged.

        ``sequence_id`` and ``seq_parallel_rank`` are accepted but unused.

        Raises:
            ValueError: If ``seq_parallel_world_size`` is greater than 1.
        """
        del sequence_id, seq_parallel_rank
        if seq_parallel_world_size > 1:
            raise ValueError(
                'LossPerpVsContextLengthLogger does not support sequence parallelism',
            )

        return labels, logits
class LossPerpVLen(Metric):
    """Accumulates per-position loss and perplexity sums.

    Two families of running sums are kept, each a 1-D tensor indexed by token
    position: one indexed by the absolute position in the packed row, and one
    (``*_seq_id``) indexed by the position inside each sequence identified by
    ``sequence_id``. All states are reduced across processes with ``sum``.

    Args:
        ignore_index (int): Label value excluded from every sum and count.
        dist_sync_on_step (bool): Forwarded to ``Metric.__init__``.
    """

    full_state_update = False

    def __init__(
        self,
        ignore_index: int,
        dist_sync_on_step: bool = False,
    ):
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        self.ignore_index = ignore_index
        # States start empty and are sized on the first ``update`` call,
        # once the sequence length is known.
        for state_name in (
            'sum_loss',
            'sum_perplexity',
            'sum_length',
            'sum_loss_seq_id',
            'sum_perplexity_seq_id',
            'sum_length_seq_id',
        ):
            self.add_state(
                state_name,
                default=torch.Tensor(),
                dist_reduce_fx='sum',
            )

    def update(
        self,
        labels: torch.Tensor,
        logits: torch.Tensor,
        sequence_id: Optional[torch.Tensor],
        loss_fn: Any,
    ) -> None:
        """Updates the internal state with results from a new batch.

        Args:
            labels (torch.Tensor): A 2-D ``(bsz, seq_len)`` Tensor of ground-truth token ids.
            logits (torch.Tensor): A Tensor of logits; it is viewed as ``(bsz * seq_len, -1)`` before being passed to ``loss_fn``.
            sequence_id (torch.Tensor | None): The sequence ids for tokens. When ``None`` the ``*_seq_id`` states are not touched.
            loss_fn (Any): The cross entropy loss to use. Its output is viewed as ``(bsz, seq_len)``, so it must return one loss value per token.
        """
        valid_labels_mask = torch.where(
            labels != self.ignore_index,
            torch.ones_like(labels),
            torch.zeros_like(labels),
        )
        is_valid = valid_labels_mask.bool()

        bsz, seq_len = labels.shape
        loss = loss_fn(logits.view(bsz * seq_len, -1), labels.view(-1))
        loss = loss.view(bsz, seq_len)
        # Ignored tokens must contribute nothing to either sum. Their loss is
        # normally 0 already, but exp(0) == 1 would otherwise be added to the
        # perplexity sum for tokens that ``sum_length`` does not count.
        loss = torch.where(is_valid, loss, torch.zeros_like(loss))
        perplexity = torch.where(
            is_valid,
            torch.exp(loss),
            torch.zeros_like(loss),
        )

        if self.sum_loss.numel() == 0:
            # First batch: allocate every state with one slot per position.
            self.sum_loss = torch.zeros(  # type: ignore
                seq_len,
                device=loss.device,
                dtype=loss.dtype,
            )
            self.sum_perplexity = torch.zeros(  # type: ignore
                seq_len,
                device=loss.device,
                dtype=loss.dtype,
            )
            self.sum_length = torch.zeros(  # type: ignore
                seq_len,
                device=loss.device,
                dtype=torch.long,
            )
            self.sum_loss_seq_id = torch.zeros(  # type: ignore
                seq_len,
                device=loss.device,
                dtype=loss.dtype,
            )
            self.sum_perplexity_seq_id = torch.zeros(  # type: ignore
                seq_len,
                device=loss.device,
                dtype=loss.dtype,
            )
            self.sum_length_seq_id = torch.zeros(  # type: ignore
                seq_len,
                device=loss.device,
                dtype=torch.long,
            )

        self.sum_loss += torch.sum(loss, dim=0)
        self.sum_perplexity += torch.sum(perplexity, dim=0)
        self.sum_length += valid_labels_mask.sum(dim=0)

        if sequence_id is None:
            return

        # (bsz, num_seq, seq_len): one-hot membership of each token in each
        # sequence. NOTE(review): one_hot requires non-negative ids - confirm
        # that callers never pass negative padding ids here.
        seq_id_expanded = torch.nn.functional.one_hot(
            sequence_id,
        ).transpose(-1, -2)
        seq_lens = seq_id_expanded.sum(dim=-1)
        max_num_seq = seq_lens.shape[1]

        # For every (row, sequence) pair build the indices of that sequence's
        # tokens inside the packed row: offset of the sequence start plus the
        # position inside the sequence.
        seq_tok_ids = torch.arange(seq_len, device=sequence_id.device)[
            None, None, :].expand(bsz, max_num_seq, -1)
        in_sequence = seq_tok_ids < seq_lens[:, :, None]
        seq_len_offsets = torch.nn.functional.pad(
            seq_lens.cumsum(dim=1)[:, :-1],
            (1, 0),
            value=0,
        )
        seq_tok_ids = seq_tok_ids + seq_len_offsets[:, :, None]
        # Positions past the end of a sequence gather index 0 (any in-range
        # index works) and are zeroed out again by ``in_sequence`` below.
        seq_tok_ids = torch.where(
            in_sequence,
            seq_tok_ids,
            torch.zeros_like(seq_tok_ids),
        )

        loss = loss[:, None, :].expand(-1, max_num_seq, -1)
        perplexity = perplexity[:, None, :].expand(-1, max_num_seq, -1)
        valid_labels_mask = valid_labels_mask[:, None, :].expand(
            -1,
            max_num_seq,
            -1,
        )
        loss = torch.where(
            in_sequence,
            torch.gather(input=loss, dim=2, index=seq_tok_ids),
            torch.zeros_like(loss),
        )
        perplexity = torch.where(
            in_sequence,
            torch.gather(input=perplexity, dim=2, index=seq_tok_ids),
            torch.zeros_like(perplexity),
        )
        valid_in_sequence = torch.where(
            in_sequence,
            torch.gather(input=valid_labels_mask, dim=2, index=seq_tok_ids),
            torch.zeros_like(valid_labels_mask),
        )

        self.sum_loss_seq_id += torch.sum(loss, dim=(0, 1))
        self.sum_perplexity_seq_id += torch.sum(perplexity, dim=(0, 1))
        self.sum_length_seq_id += torch.sum(valid_in_sequence, dim=(0, 1))

    def compute(self) -> Dict[str, torch.Tensor]:
        """Aggregate the state over all processes to compute the metric.

        Returns:
            A dict with the per-position mean loss and mean perplexity
            (``mean_loss_v_len``, ``mean_perplexity_v_len``), their
            per-sequence-position counterparts (``mean_loss_seq_id_v_len``,
            ``mean_perplexity_seq_id_v_len``) and the raw token counts
            (``sum_length``, ``sum_length_seq_id``). Positions with a count
            of zero report a mean of ``-1``.
        """
        has_tokens = self.sum_length != 0
        # Empty positions: numerator -1 over denominator 1 gives the -1 marker
        # and avoids a division by zero.
        sum_perplexity = torch.where(has_tokens, self.sum_perplexity, -1)
        sum_loss = torch.where(has_tokens, self.sum_loss, -1)
        sum_length = torch.where(has_tokens, self.sum_length, 1)

        has_tokens_seq_id = self.sum_length_seq_id != 0
        sum_perplexity_seq_id = torch.where(
            has_tokens_seq_id,
            self.sum_perplexity_seq_id,
            -1,
        )
        sum_loss_seq_id = torch.where(
            has_tokens_seq_id,
            self.sum_loss_seq_id,
            -1,
        )
        sum_length_seq_id = torch.where(
            has_tokens_seq_id,
            self.sum_length_seq_id,
            1,
        )

        return {
            'mean_loss_v_len':
                sum_loss / sum_length,
            'mean_perplexity_v_len':
                sum_perplexity / sum_length,
            'sum_length':
                self.sum_length,
            'mean_loss_seq_id_v_len':
                sum_loss_seq_id / sum_length_seq_id,
            'mean_perplexity_seq_id_v_len':
                sum_perplexity_seq_id / sum_length_seq_id,
            'sum_length_seq_id':
                self.sum_length_seq_id,
        }
def attach_ffn_mb_args(
    ffn: nn.Module,
    expert_parallel_group: ProcessGroup,
    args: 'megablocks.layers.arguments.Arguments',
):
    """Copy the values needed by parameter initialization onto the expert MLP.

    Sets ``hidden_size``, ``expert_parallel_group`` and
    ``weight_parallel_group`` on ``ffn.experts.mlp``.

    Args:
        ffn (nn.Module): The FFN module; must expose ``experts.mlp``.
        expert_parallel_group (ProcessGroup): The expert parallel process group.
        args (megablocks.layers.arguments.Arguments): The arguments for MegaBlocks;
            ``ffn_hidden_size`` and ``weight_parallel_group`` are read from it.
    """
    expert_mlp = ffn.experts.mlp
    expert_mlp.hidden_size = args.ffn_hidden_size
    expert_mlp.expert_parallel_group = expert_parallel_group
    expert_mlp.weight_parallel_group = args.weight_parallel_group
def moe_fused_init_setup(ffn: nn.Module,):
    """Set ``_stack_dim = 0`` on the expert MLP of ``ffn``.

    Args:
        ffn (nn.Module): The FFN module; must expose ``experts.mlp``.
    """
    expert_mlp = ffn.experts.mlp
    expert_mlp._stack_dim = 0


def dmoe_fused_init_setup(
    ffn: nn.Module,
    args: 'megablocks.layers.arguments.Arguments',
    moe_world_size: int,
):
    """Set the ``_fused`` attribute on the expert MLP of a dMoE ``ffn``.

    ``_fused`` is a ``(dim, split_points)`` tuple with ``dim`` fixed to 0 and
    split points at multiples of ``args.ffn_hidden_size``.

    Args:
        ffn (nn.Module): The FFN module; must expose ``experts.mlp``.
        args (megablocks.layers.arguments.Arguments): The arguments for MegaBlocks.
        moe_world_size (int): The MoE world size.
    """
    # NOTE(review): ``min(1, ...)`` caps the count at 1, so the list of split
    # points below is always empty. If one split per local expert was the
    # intent this would need ``max`` - confirm before changing, since it
    # affects parameter initialization.
    local_expert_count = min(1, args.moe_num_experts // moe_world_size)
    split_points = [
        args.ffn_hidden_size * (expert_idx + 1)
        for expert_idx in range(local_expert_count - 1)
    ]
    ffn.experts.mlp._fused = (0, split_points)
def construct_blocks(self, config: MPTConfig) -> nn.ModuleList:
    """Build the ``nn.ModuleList`` holding the Transformer blocks.

    Every block is an ``MPTBlock`` created on ``config.init_device`` with the
    keyword arguments that ``extract_block_args`` derives from
    ``config.to_dict()``.

    Args:
        config (MPTConfig): The configuration object.

    Returns:
        nn.ModuleList: ``config.n_layers`` Transformer blocks.
    """
    shared_block_kwargs = self.extract_block_args(config.to_dict())

    layers = []
    for _ in range(config.n_layers):
        layers.append(
            MPTBlock(
                device=config.init_device,
                **shared_block_kwargs,
            ),
        )
    return nn.ModuleList(layers)
`loss_fn` must be one of [`fused_crossentropy`, `torch_crossentropy`].', ) + @property + def model_class(self) -> Type[MPTForCausalLM]: + return MPTForCausalLM + + @property + def config_class(self) -> Type[MPTConfig]: + return MPTConfig + def get_targets(self, batch: Mapping) -> torch.Tensor: targets = torch.roll(batch['labels'], shifts=-1) targets[:, -1] = -100 diff --git a/llmfoundry/models/utils/config_moe_args.py b/llmfoundry/models/utils/config_moe_args.py index 29f8d1bfcc..e27514275c 100644 --- a/llmfoundry/models/utils/config_moe_args.py +++ b/llmfoundry/models/utils/config_moe_args.py @@ -4,11 +4,12 @@ """Helper function to configure MPT with MoEs.""" import inspect -from typing import Union +from typing import Callable, Optional, Union import torch from packaging import version from torch import distributed +from torch.distributed._tensor import DeviceMesh from llmfoundry.layers_registry import ffns_with_megablocks from llmfoundry.models.layers.ffn import resolve_ffn_hidden_size @@ -64,11 +65,47 @@ def create_set_process_group(k: int): return create_process_group_ranks(ranks) +def get_megablocks_device_mesh( + device_mesh_cfg: Optional[tuple[int]], + moe_world_size: int, + world_size: int, +) -> DeviceMesh: + """Helper function to get the device mesh for MegaBlocks MoE. + + Args: + device_mesh_cfg (Optional[tuple[int]]): The device mesh configuration specification. + moe_world_size (int): The MoE world size. + world_size (int): The world size. + + Raises: + ValueError: If the device mesh configuration is not valid. + + Returns: + The device mesh for MegaBlocks MoE. 
+ """ + from torch.distributed._tensor.device_mesh import init_device_mesh + + if device_mesh_cfg is None or len(device_mesh_cfg) == 1: + if device_mesh_cfg is not None: + world_size = device_mesh_cfg[0] + sharding_group_dim = world_size // moe_world_size + device_mesh = init_device_mesh( + 'cuda', + (sharding_group_dim, moe_world_size), + mesh_dim_names=('weight_parallel', 'expert_parallel'), + ) + else: + raise ValueError(f'{device_mesh_cfg=} must be length 1') + + return device_mesh + + def config_megablocks_moe_args( ffn_config: dict, d_model: int, expansion_ratio: Union[int, float], n_layers: int, + get_device_mesh: Callable, ) -> dict: """Configures `ffn_config` for MegaBlocks MoE. @@ -80,6 +117,7 @@ def config_megablocks_moe_args( d_model (int): Hidden size of the network. expansion_ratio (Union[int, float]): Expansion ratio in FFN. n_layers (int): Number of blocks used in the network. + get_device_mesh (Callable): Function to get the device mesh. Takes in the device mesh config and the MoE world size. Returns: ffn_config (dict): FFN configuration with MegaBlocks MoE configured. 
@@ -112,26 +150,17 @@ def config_megablocks_moe_args( 'MoE world size > 1 is not supported in torch version {torch.__version__}<2.2.', ) - from torch.distributed._tensor.device_mesh import init_device_mesh - world_size = distributed.get_world_size() if world_size < moe_world_size or world_size % moe_world_size: raise ValueError( f'Invalid world size configuration: {world_size=} and {moe_world_size=}', ) - # FSDP - if device_mesh_cfg is None or len(device_mesh_cfg) == 1: - if device_mesh_cfg is not None: - world_size = device_mesh_cfg[0] - sharding_group_dim = world_size // moe_world_size - device_mesh = init_device_mesh( - 'cuda', - (sharding_group_dim, moe_world_size), - mesh_dim_names=('weight_parallel', 'expert_parallel'), - ) - else: - raise ValueError(f'{device_mesh_cfg=} must be length 1') + device_mesh = get_device_mesh( + device_mesh_cfg=device_mesh_cfg, + moe_world_size=moe_world_size, + world_size=world_size, + ) ffn_config['moe_expert_model_parallelism'] = True ffn_config['expert_parallel_group'] = device_mesh[ @@ -202,6 +231,7 @@ def config_moe_args( d_model=d_model, expansion_ratio=expansion_ratio, n_layers=n_layers, + get_device_mesh=get_megablocks_device_mesh, ) else: raise ValueError(f'Invalid ffn_type ({ffn_config["ffn_type"]}).') diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index 6010b19b6f..72ca19834b 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -589,7 +589,7 @@ def _process_data_source( ('uc_volume', source_dataset_path[len('dbfs:'):], true_split), ) # Check for HF path - elif 'hf_name' in dataset: + elif 'hf_name' in dataset and dataset['hf_name']: hf_path = dataset['hf_name'] backend, _, _ = parse_uri(hf_path) if backend: @@ -600,7 +600,7 @@ def _process_data_source( else: data_paths.append(('hf', hf_path, true_split)) # Check for remote path - elif 'remote' in dataset: + elif 'remote' in dataset and dataset['remote']: remote_path = dataset['remote'] backend, _, _ = 
@pytest.mark.gpu
@pytest.mark.parametrize('shift_labels', [True, False])
def test_loss_perp_v_len_callback(
    shift_labels: bool,
    monkeypatch: pytest.MonkeyPatch,
):
    """Check that the per-position losses average back to the model loss.

    A single packed row with two sequences is pushed through a small MPT
    model; the callback's metric is then updated once and the length-weighted
    mean of both of its loss tables is compared against ``model.loss``.
    """
    # Skip (rather than fail) when the optional fused cross entropy kernel is
    # missing. Catch only ImportError so unrelated errors are not swallowed.
    try:
        from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCrossEntropyLoss  # type: ignore # isort: skip
    except ImportError:
        pytest.skip('Fused cross entropy was not installed')

    composer_device = get_device(None)

    model_max_length = 12

    gptt = transformers.AutoTokenizer.from_pretrained('gpt2')
    gptt.pad_token_id = gptt.eos_token_id
    gptt.model_max_length = model_max_length
    gptt.padding_side = 'right'

    cfg = {
        'dataset': {
            'local': 'dummy-path',
            'remote': 'dummy-path',
            'split': 'train',
            'max_seq_len': model_max_length,
            'shuffle': True,
            'shuffle_seed': 0,
            'eos_token_id': gptt.eos_token_id,
        },
        'drop_last': False,
        'num_workers': 0,
        'prefetch_factor': None,
        'pin_memory': False,
        'persistent_workers': False,
        'timeout': 0,
    }

    # The dataset is replaced by a mock: only the dataloader's collate_fn is
    # used below, no data is ever read from 'dummy-path'.
    ds_mock = MagicMock(spec=StreamingTextDataset)
    ds_mock.tokenizer = gptt
    monkeypatch.setattr(
        'llmfoundry.data.text_data.StreamingTextDataset',
        lambda *args,
        **kwargs: ds_mock,
    )
    dl = build_text_dataloader(
        **cfg,
        tokenizer=gptt,
        device_batch_size=1,
    )

    # Two sequences separated by EOS inside one row.
    batch_strings = [
        'hello hey' + gptt.eos_token + ' the quick brown fox jumps',
    ]

    batch_tokenized = [gptt(b, padding=False) for b in batch_strings]

    batch_tokenized = [b['input_ids'] for b in batch_tokenized]

    batch = dl.dataloader.collate_fn(batch_tokenized)  # type: ignore

    for k, v in batch.items():  # type: ignore
        if isinstance(v, torch.Tensor):
            batch[k] = composer_device.tensor_to_device(v)  # type: ignore

    attention_impl = 'flash'

    conf_path = 'scripts/train/yamls/pretrain/testing.yaml'
    with open(conf_path) as f:
        test_cfg = om.load(f)

    assert isinstance(test_cfg, DictConfig)

    attn_config = {
        'attn_type': 'grouped_query_attention',
        'attn_impl': attention_impl,
        'attn_uses_sequence_id': True,
        'alibi': False,
        'rope': True,
        'rope_theta': 10000,
        'rope_impl': 'dail',
        'rope_dail_config': {
            'type': 'original',
            'pos_idx_in_fp32': True,
            'xpos_scale_base': 512,
        },
    }
    attn_config['kv_n_heads'] = 4

    test_cfg.model.init_device = 'cpu'
    test_cfg.model.init_config = {
        'name': 'baseline_',
        'init_std': 0.02,
    }
    test_cfg.model.attn_config = attn_config
    test_cfg.model.n_layers = 2
    test_cfg.model.n_heads = 8
    test_cfg.model.d_model = 128

    test_cfg = dict(om.to_container(test_cfg, resolve=True))  # type: ignore

    model = build_composer_model(
        name=test_cfg['model']['name'],
        cfg=test_cfg['model'],
        tokenizer=gptt,
    )
    assert model.shift_labels == True
    model.shift_labels = shift_labels

    model = composer_device.module_to_device(model)

    # The callback is driven inside the same precision context as the
    # reference loss so both are computed with the same numerics.
    with get_precision_context('amp_bf16'):
        output = model(batch)
        loss = model.loss(output, batch)

        assert isinstance(loss, torch.Tensor)

        callback = construct_from_registry(
            name='loss_perp_v_len',
            registry=registry.callbacks,
            kwargs={
                'log_batch_interval': 100,
                'compute_batch_interval': 1,
            },
        )

        callback.loss_perp_v_len = callback.loss_perp_v_len.to(loss.device)
        state = State(
            model=model,
            rank_zero_seed=0,
            run_name='test_state',
            device=DeviceGPU(),
        )
        logger = Logger(state)
        state.outputs = output
        state.batch = batch

        callback.after_backward(state, logger)
        current_metric_dict = callback.loss_perp_v_len.compute()

        # Length-weighted mean over positions recovers the mean token loss.
        mean_loss_seq_id = torch.sum(
            current_metric_dict['mean_loss_seq_id_v_len'] *
            current_metric_dict['sum_length_seq_id'],
        ) / torch.sum(current_metric_dict['sum_length_seq_id'])
        mean_loss = torch.sum(
            current_metric_dict['mean_loss_v_len'] *
            current_metric_dict['sum_length'],
        ) / torch.sum(current_metric_dict['sum_length'])
        assert torch.allclose(loss, mean_loss_seq_id)
        assert torch.allclose(loss, mean_loss)