diff --git a/README.md b/README.md index 59869ba4bc..f51de3fe2d 100644 --- a/README.md +++ b/README.md @@ -77,6 +77,9 @@ Tutorial videos from the community: Something missing? Contribute with a PR! # Latest News +* [Blog: LLM Training and Inference with Intel Gaudi2 AI Accelerators](https://www.databricks.com/blog/llm-training-and-inference-intel-gaudi2-ai-accelerators) +* [Blog: Training LLMs at Scale with AMD MI250 GPUs](https://www.databricks.com/blog/training-llms-scale-amd-mi250-gpus) +* [Blog: Training LLMs with AMD MI250 GPUs and MosaicML](https://www.mosaicml.com/blog/amd-mi250) * [Blog: Announcing MPT-7B-8K: 8K Context Length for Document Understanding](https://www.mosaicml.com/blog/long-context-mpt-7b-8k) * [Blog: Training LLMs with AMD MI250 GPUs and MosaicML](https://www.mosaicml.com/blog/amd-mi250) * [Blog: MPT-30B: Raising the bar for open-source foundation models](https://www.mosaicml.com/blog/mpt-30b) @@ -186,6 +189,12 @@ Notes: 1. `attn_impl: triton` does not work. 1. We don't yet have a Docker image where everything works perfectly. You might need to up/downgrade some packages (in our case, we needed to downgrade to `numpy==1.23.5`) before everything works without issue. +### Intel Gaudi +Support for LLM Foundry on Intel Gaudi devices is experimental, please use the branch `habana_alpha` and see the [README on that branch](https://github.com/mosaicml/llm-foundry/blob/habana_alpha) which has [install instructions and known issues.](https://github.com/mosaicml/llm-foundry/tree/habana_alpha?tab=readme-ov-file#intel-gaudi) + +For training and inference performance results on Intel Gaudi2 accelerators, see our blog: https://www.databricks.com/blog/llm-training-and-inference-intel-gaudi2-ai-accelerators + + # Quickstart > **Note** diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index c79537c781..491d510188 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -8,7 +8,7 @@ import os import tempfile from pathlib import Path -from typing import Optional, Union +from typing import Optional, Sequence, Union import torch from composer.core import Callback, Event, State, Time, TimeUnit @@ -32,31 +32,40 @@ class HuggingFaceCheckpointer(Callback): """Save a huggingface formatted checkpoint during training. Args: - save_folder (str): Top level folder to save checkpoints to (can be a URI). It is likely that - this would be the same as your save_folder. - save_interval: Union[str, int, Time]: The interval describing how often checkpoints should be - saved. If an integer, it will be assumed to be in :attr:`.TimeUnit.EPOCH`. - Otherwise, the unit must be either :attr:`.TimeUnit.EPOCH`, :attr:`.TimeUnit.BATCH`, + save_folder (str): Top level folder to save checkpoints to (can be a + URI). It is likely that this would be the same as your save_folder. + save_interval: Union[str, int, Time]: The interval describing how often + checkpoints should be saved. If an integer, it will be assumed to be + in :attr:`.TimeUnit.EPOCH`. Otherwise, the unit must be either + :attr:`.TimeUnit.EPOCH`, :attr:`.TimeUnit.BATCH`, :attr:`.TimeUnit.TOKEN`, or :attr:`.TimeUnit.SAMPLE`. - huggingface_folder_name (str): Folder to save each checkpoint under (can be a format string). Default is ``ba{batch}``. - precision: The precision to save the model in. Default is ``float32``. Options are ``bfloat16``, ``float16``, or ``float32``. 
+ huggingface_folder_name (str): Folder to save each checkpoint under (can + be a format string). Default is ``ba{batch}``. + precision: The precision to save the model in. Default is ``float32``. + Options are ``bfloat16``, ``float16``, or ``float32``. overwrite (bool): Whether to overwrite previous checkpoints. - mlflow_registered_model_name (Optional[str]): The name to register the model under in the MLflow model registry. If ``None``, the model will not - be registered. Default is ``None``. - mlflow_logging_config (Optional[dict]): A dictionary of config arguments that will get passed along to the MLflow ``save_model`` call. - Expected to contain ``metadata`` and ``task`` keys. If either is unspecified, the defaults are ``'text-generation'`` and + mlflow_registered_model_name (Optional[str]): The name to register the + model under in the MLflow model registry. If ``None``, the model + will not be registered. Default is ``None``. + mlflow_logging_config (Optional[dict]): A dictionary of config arguments + that will get passed along to the MLflow ``save_model`` call. + Expected to contain ``metadata`` and ``task`` keys. If either is + unspecified, the defaults are ``'text-generation'`` and ``{'task': 'llm/v1/completions'}`` respectively. + flatten_imports (Sequence[str]): A sequence of import prefixes that will + be flattened when editing MPT files. """ def __init__( - self, - save_folder: str, - save_interval: Union[str, int, Time], - huggingface_folder_name: str = 'ba{batch}', - precision: str = 'float32', - overwrite: bool = True, - mlflow_registered_model_name: Optional[str] = None, - mlflow_logging_config: Optional[dict] = None, + self, + save_folder: str, + save_interval: Union[str, int, Time], + huggingface_folder_name: str = 'ba{batch}', + precision: str = 'float32', + overwrite: bool = True, + mlflow_registered_model_name: Optional[str] = None, + mlflow_logging_config: Optional[dict] = None, + flatten_imports: Sequence[str] = ('llmfoundry',), ): _, _, self.save_dir_format_str = parse_uri(save_folder) self.overwrite = overwrite @@ -66,6 +75,7 @@ def __init__( 'float16': torch.float16, 'bfloat16': torch.bfloat16, }[precision] + self.flatten_imports = flatten_imports # mlflow config setup self.mlflow_registered_model_name = mlflow_registered_model_name @@ -91,7 +101,7 @@ def __init__( if isinstance(save_interval, int): save_interval = Time(save_interval, TimeUnit.EPOCH) - self.save_interval = save_interval + self.save_interval: Time = save_interval self.check_interval = create_interval_scheduler( save_interval, include_end_of_training=True) self.remote_ud = maybe_create_remote_uploader_downloader_from_uri( @@ -229,7 +239,10 @@ def _save_checkpoint(self, state: State, logger: Logger): # Only need to edit files for MPT because it has custom code if original_model.config.model_type == 'mpt': log.debug('Editing MPT files for HuggingFace compatibility') - edit_files_for_hf_compatibility(temp_save_dir) + edit_files_for_hf_compatibility( + temp_save_dir, + self.flatten_imports, + ) if self.remote_ud is not None: log.info(f'Uploading HuggingFace formatted checkpoint') diff --git a/llmfoundry/data/text_data.py b/llmfoundry/data/text_data.py index 083cd48069..1c0894a451 100644 --- a/llmfoundry/data/text_data.py +++ b/llmfoundry/data/text_data.py @@ -325,9 +325,10 @@ def get_tokens_per_batch_func( """ def get_num_samples_in_batch(batch: Batch) -> int: - if not isinstance(batch, Mapping) or 'attention_mask' not in batch: + if not isinstance(batch, Mapping) or ('attention_mask' not in batch and 
+ 'input_ids' not in batch): raise ValueError( - 'get_tokens_per_batch_func() requires a batch with an attention_mask key' + 'get_tokens_per_batch_func() requires a batch with an attention_mask key or an input_ids key' ) if not decoder_only and 'decoder_attention_mask' not in batch: @@ -336,7 +337,10 @@ def get_num_samples_in_batch(batch: Batch) -> int: ) # Count number of non padding tokens in batch - input_ids_tokens = int(torch.sum(batch['attention_mask']).item()) + if 'attention_mask' in batch: + input_ids_tokens = int(torch.sum(batch['attention_mask']).item()) + else: + input_ids_tokens = batch['input_ids'].numel() # For encoder decoder models only decoder_input_ids_tokens = 0 diff --git a/llmfoundry/models/layers/ffn.py b/llmfoundry/models/layers/ffn.py index 560e8c31fc..fa3e109bf8 100644 --- a/llmfoundry/models/layers/ffn.py +++ b/llmfoundry/models/layers/ffn.py @@ -117,7 +117,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.down_proj(self.act(self.up_proj(x))) -class MPTGeGLU(MPTMLP): +class MPTGLU(MPTMLP): def __init__( self, @@ -138,19 +138,19 @@ def __init__( device=device, bias=bias, ) - self.gate = FC_CLASS_REGISTRY[fc_type]( + self.gate_proj = FC_CLASS_REGISTRY[fc_type]( d_model, self.up_proj.out_features, **self.fc_kwargs, ) def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.down_proj(self.act(self.up_proj(x)) * self.gate(x)) + return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x)) FFN_CLASS_REGISTRY = { 'mptmlp': MPTMLP, - 'mptgeglu': MPTGeGLU, + 'mptglu': MPTGLU, } if te is not None: @@ -169,10 +169,10 @@ def build_ffn( **kwargs: Any, ) -> nn.Module: ffn_type = kwargs.pop('ffn_type') - if ffn_type in ['mptmlp', 'mptgeglu']: + if ffn_type in ['mptmlp', 'mptglu']: if len(kwargs) > 0: raise ValueError( - f'MPTMLP (or MPTGeGLU) got an unexpected keyword argument: {kwargs}' + f'MPTMLP (or MPTGLU) got an unexpected keyword argument: {kwargs}' ) return FFN_CLASS_REGISTRY[ffn_type]( d_model=d_model, diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 6c4c286712..ae4754108c 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -107,7 +107,7 @@ def __init__( factor (float): Scaling factor to use if using 'linear' or 'dynamic' as rope_scaling.type. kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads. ffn_config (Dict): A dictionary used to configure the model's ffn module: - ffn_type (str): type of ffn to use. Options: mptmlp, mptgeglu, te_ln_mlp + ffn_type (str): type of ffn to use. Options: mptmlp, mptglu, te_ln_mlp init_device (str): The device to use for parameter initialization. logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value. no_bias (bool): Whether to use bias in all layers. @@ -291,7 +291,13 @@ def _validate_config(self) -> None: + 'pip install flash-attn==1.0.6 --no-build-isolation \n' + 'pip install git+https://github.com/NVIDIA/TransformerEngine.git@144e4888b2cdd60bd52e706d5b7a79cb9c1a7156' ) - if self.ffn_config['ffn_type'] in ['mptmlp', 'mptgeglu']: + if self.ffn_config['ffn_type'] == 'mptgeglu': + raise ValueError( + 'API CHANGE: `ffn_type=="mptgeglu"` changed to `ffn_type=="mptglu"`. ' + + + 'See [#829](https://github.com/mosaicml/llm-foundry/pull/829) for details.' 
+ ) + elif self.ffn_config['ffn_type'] in ['mptmlp', 'mptglu']: self.ffn_config['fc_type'] = self.fc_type elif self.ffn_config['ffn_type'] == 'te_ln_mlp': self.ffn_config['bias'] = not self.no_bias diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 8c134e2b9f..4c80b10ed9 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -63,7 +63,7 @@ # Otherwise, certain modules are missing. # isort: off from llmfoundry.models.utils.adapt_tokenizer import ( - AutoTokenizerForMOD, # type: ignore (see note), + AutoTokenizerForMOD, # type: ignore (see note) adapt_tokenizer_for_denoising, # type: ignore (see note) ) from llmfoundry.models.utils.hf_prefixlm_converter import ( diff --git a/llmfoundry/optim/lion8b.py b/llmfoundry/optim/lion8b.py index 9d1d1dda71..f76d29b1c7 100644 --- a/llmfoundry/optim/lion8b.py +++ b/llmfoundry/optim/lion8b.py @@ -235,9 +235,9 @@ def __init__(self, data: Optional[torch.Tensor], try_quantize: bool = True): self._f_encode = None self._f_decode = None if self._try_quantize: - from turbo import dequantize8b, quantize8b - self._f_encode = quantize8b - self._f_decode = dequantize8b + from turbo import dequantize_signed, quantize_signed + self._f_encode = quantize_signed + self._f_decode = dequantize_signed if data is not None: self.set_data(data) @@ -277,7 +277,7 @@ def set_data(self, data: torch.Tensor) -> None: f'on device {data.device} with shape {data.shape}.') self.data = None assert self._f_encode is not None # pyright - self.quantized, self.scales = self._f_encode(data) + self.quantized, self.scales, _ = self._f_encode(data) else: self.data = data.to(dtype=torch.float32) self.quantized = None diff --git a/llmfoundry/utils/checkpoint_conversion_helpers.py b/llmfoundry/utils/checkpoint_conversion_helpers.py index 35e77eab6c..dafeec94e1 100644 --- a/llmfoundry/utils/checkpoint_conversion_helpers.py +++ b/llmfoundry/utils/checkpoint_conversion_helpers.py @@ -112,7 +112,8 @@ def load_tokenizer( def _write_zero_bias(weight_name: str, weight_file_path: str, - bias_shape: Union[Tuple[int, ...], int]) -> None: + bias_shape: Union[Tuple[int, ...], + int], np_data_type: np.dtype) -> None: """Write zeros for bias when converting MPT to FasterTransformer weights. MPT model might not have bias while FT expects bias. @@ -121,6 +122,7 @@ def _write_zero_bias(weight_name: str, weight_file_path: str, weight_name (str): Name of the weight tensor. weight_file_path (str): Output path for storing the weight (NOT zero bias). bias_shape (Union[Tuple[int, ...], int]): Shape of the bias array. + np_data_type (np.dtype): The data type for bias. """ if 'weight' not in weight_file_path: raise RuntimeError( @@ -128,13 +130,14 @@ def _write_zero_bias(weight_name: str, weight_file_path: str, ) log.debug(f'zero bias for weight: {weight_name}') bias_file_path = weight_file_path.replace('.weight', '.bias') - bias = np.zeros(bias_shape, dtype=np.float32) + bias = np.zeros(bias_shape, dtype=np_data_type) bias.tofile(bias_file_path) def _convert_weight_to_ft_each(save_dir: str, infer_gpu_num: int, tensor_name: str, config: Dict[str, Any], - data: np.ndarray) -> None: + data: np.ndarray, + np_weight_data_type: np.dtype) -> None: """Convert each MPT weight to a FasterTransformer compatible format. 
Args: @@ -155,7 +158,9 @@ def _convert_weight_to_ft_each(save_dir: str, infer_gpu_num: int, save_path = os.path.join(save_dir, f'model.{tensor_name}.bin') data.tofile(save_path) if 'weight' in tensor_name and config['no_bias']: - _write_zero_bias(tensor_name, save_path, data.shape[-1]) + _write_zero_bias(tensor_name, save_path, data.shape[-1], + np_weight_data_type + ) # pyright: ignore [reportGeneralTypeIssues] elif tensor_name.find('attention.dense.weight') != -1: assert data.shape == ( @@ -170,11 +175,13 @@ def _convert_weight_to_ft_each(save_dir: str, infer_gpu_num: int, if config['no_bias']: fake_weight_path = os.path.join(save_dir, f'model.{tensor_name}.bin') - _write_zero_bias(tensor_name, fake_weight_path, data.shape[-1]) + _write_zero_bias(tensor_name, fake_weight_path, data.shape[-1], + np_weight_data_type + ) # pyright: ignore [reportGeneralTypeIssues] elif tensor_name.find('mlp.dense_4h_to_h.weight') != -1: assert data.shape == ( - config['d_model'], config['mlp_ratio'] * + config['d_model'], config['expansion_ratio'] * config['d_model']), f'unexpected dim for {tensor_name}' # nn.Linear weights are transposed data = data.T @@ -185,11 +192,13 @@ def _convert_weight_to_ft_each(save_dir: str, infer_gpu_num: int, if config['no_bias']: fake_weight_path = os.path.join(save_dir, f'model.{tensor_name}.bin') - _write_zero_bias(tensor_name, fake_weight_path, data.shape[-1]) + _write_zero_bias(tensor_name, fake_weight_path, data.shape[-1], + np_weight_data_type + ) # pyright: ignore [reportGeneralTypeIssues] elif tensor_name.find('mlp.dense_h_to_4h.weight') != -1: assert data.shape == ( - config['mlp_ratio'] * config['d_model'], + config['expansion_ratio'] * config['d_model'], config['d_model']), f'unexpected dim for {tensor_name}' # nn.Linear weights are transposed data = data.T @@ -200,11 +209,12 @@ def _convert_weight_to_ft_each(save_dir: str, infer_gpu_num: int, split_vals[j].tofile(save_path) if config['no_bias']: _write_zero_bias(tensor_name, save_path, - split_vals[j].shape[-1]) + split_vals[j].shape[-1], np_weight_data_type + ) # pyright: ignore [reportGeneralTypeIssues] elif tensor_name.find('mlp.dense_h_to_4h.bias') != -1: assert data.shape == ( - config['mlp_ratio'] * + config['expansion_ratio'] * config['d_model'],), f'unexpected dim for {tensor_name}' split_vals = np.split(data, infer_gpu_num, axis=-1) for j in range(infer_gpu_num): @@ -238,7 +248,9 @@ def _convert_weight_to_ft_each(save_dir: str, infer_gpu_num: int, split_vals[j].tofile(save_path) if config['no_bias']: _write_zero_bias(tensor_name, save_path, - (3, split_vals[j].shape[-1])) + (3, split_vals[j].shape[-1]), + np_weight_data_type + ) # pyright: ignore [reportGeneralTypeIssues] else: raise RuntimeError(f'Tensor with name {tensor_name} is not handled') @@ -306,7 +318,12 @@ def convert_and_save_ft_weights(named_params: dict, 'model.final_layernorm.weight.bin') data.tofile(save_path) if config['no_bias']: - _write_zero_bias(name, save_path, data.shape[-1]) + _write_zero_bias( + name, + save_path, + data.shape[-1], + np_weight_data_type # pyright: ignore [reportGeneralTypeIssues] + ) elif name == 'transformer.lm_head.weight': data.tofile(os.path.join(save_dir, 'model.lm_head.weight.bin')) else: @@ -315,5 +332,11 @@ def convert_and_save_ft_weights(named_params: dict, new_name = name.replace('transformer.blocks.', 'layers.').replace( mpt_pattern, ft_pattern) - _convert_weight_to_ft_each(save_dir, infer_gpu_num, - new_name, config, data) + _convert_weight_to_ft_each( + save_dir, + infer_gpu_num, + new_name, + config, + 
data, + np_weight_data_type # pyright: ignore [reportGeneralTypeIssues] + ) diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index 6680154e87..55576eaba0 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -120,6 +120,14 @@ def process_init_device(model_cfg: DictConfig, fsdp_config: Optional[Dict]): # Set defaults for mixed initialization fsdp_config.setdefault('use_orig_params', False) fsdp_config.setdefault('load_monolith_rank0_only', True) + # Always set `sync_module_states` to True when using hybrid sharding + if fsdp_config is not None and \ + fsdp_config.get('sharding_strategy', 'FULL_SHARD') in ['HYBRID_SHARD', '_HYBRID_SHARD_ZERO2'] \ + and not fsdp_config.get('sync_module_states', False): + warnings.warn( + ('Setting `sync_module_states = True` for FSDP. This is required ' + 'when using hybrid sharding.')) + fsdp_config['sync_module_states'] = True # no mixed precision needed for weights when they're already 16 bits master_dtype = model_cfg.get('master_weights_dtype') diff --git a/llmfoundry/utils/huggingface_hub_utils.py b/llmfoundry/utils/huggingface_hub_utils.py index 47d7f79bff..a74ab1cc35 100644 --- a/llmfoundry/utils/huggingface_hub_utils.py +++ b/llmfoundry/utils/huggingface_hub_utils.py @@ -4,14 +4,14 @@ import ast import importlib import os -from typing import List, Optional +from typing import Optional, Sequence __all__ = ['edit_files_for_hf_compatibility'] class DeleteSpecificNodes(ast.NodeTransformer): - def __init__(self, nodes_to_remove: List[ast.AST]): + def __init__(self, nodes_to_remove: list[ast.AST]): self.nodes_to_remove = nodes_to_remove def visit(self, node: ast.AST) -> Optional[ast.AST]: @@ -39,7 +39,26 @@ def find_module_file(module_name: str) -> str: return module_file -def process_file(file_path: str, folder_path: str) -> List[str]: +def _flatten_import( + node: ast.ImportFrom, + flatten_imports_prefix: Sequence[str], +) -> bool: + """Returns True if import should be flattened. + + Checks whether the node starts the same as any of the imports in + flatten_imports_prefix. 
+ """ + for import_prefix in flatten_imports_prefix: + if node.module is not None and node.module.startswith(import_prefix): + return True + return False + + +def process_file( + file_path: str, + folder_path: str, + flatten_imports_prefix: Sequence[str], +) -> list[str]: with open(file_path, 'r') as f: source = f.read() @@ -51,37 +70,35 @@ def process_file(file_path: str, folder_path: str) -> List[str]: new_files_to_process = [] nodes_to_remove = [] for node in ast.walk(tree): - # convert any llmfoundry imports into relative imports - if isinstance( - node, ast.ImportFrom - ) and node.module is not None and node.module.startswith('llmfoundry'): + # Convert any llmfoundry imports into relative imports + if (isinstance(node, ast.ImportFrom) and node.module is not None and + _flatten_import(node, flatten_imports_prefix)): module_path = find_module_file(node.module) node.module = convert_to_relative_import(node.module, parent_module_name) - # recursively process any llmfoundry files + # Recursively process any llmfoundry files new_files_to_process.append(module_path) - # remove any imports from composer or omegaconf + # Remove any imports from composer or omegaconf elif isinstance(node, ast.ImportFrom) and node.module is not None and ( node.module.startswith('composer') or node.module.startswith('omegaconf')): nodes_to_remove.append(node) - # remove the Composer* class - elif isinstance(node, - ast.ClassDef) and node.name.startswith('Composer'): + # Remove the Composer* class + elif (isinstance(node, ast.ClassDef) and + node.name.startswith('Composer')): nodes_to_remove.append(node) - # remove the __all__ declaration in any __init__.py files, whose enclosing module - # will be converted to a single file of the same name - elif isinstance(node, - ast.Assign) and len(node.targets) == 1 and isinstance( - node.targets[0], - ast.Name) and node.targets[0].id == '__all__': + # Remove the __all__ declaration in any __init__.py files, whose + # enclosing module will be converted to a single file of the same name + elif (isinstance(node, ast.Assign) and len(node.targets) == 1 and + isinstance(node.targets[0], ast.Name) and + node.targets[0].id == '__all__'): nodes_to_remove.append(node) transformer = DeleteSpecificNodes(nodes_to_remove) new_tree = transformer.visit(tree) new_filename = os.path.basename(file_path) - # special case for __init__.py to mimic the original submodule + # Special case for __init__.py to mimic the original submodule if new_filename == '__init__.py': new_filename = file_path.split('/')[-2] + '.py' new_file_path = os.path.join(folder_path, new_filename) @@ -92,7 +109,10 @@ def process_file(file_path: str, folder_path: str) -> List[str]: return new_files_to_process -def edit_files_for_hf_compatibility(folder: str) -> None: +def edit_files_for_hf_compatibility( + folder: str, + flatten_imports_prefix: Sequence[str] = ('llmfoundry',), +) -> None: files_to_process = [ os.path.join(folder, filename) for filename in os.listdir(folder) @@ -103,7 +123,7 @@ def edit_files_for_hf_compatibility(folder: str) -> None: while len(files_to_process) > 0: to_process = files_to_process.pop() if os.path.isfile(to_process) and to_process.endswith('.py'): - to_add = process_file(to_process, folder) + to_add = process_file(to_process, folder, flatten_imports_prefix) for file in to_add: if file not in files_processed_and_queued: files_to_process.append(file) diff --git a/setup.py b/setup.py index c030fe3268..8122bbb14f 100644 --- a/setup.py +++ b/setup.py @@ -47,15 +47,15 @@ ] install_requires = [ - 
'mosaicml[libcloud,wandb,mlflow,oci,gcs]>=0.17.1,<0.18', + 'mosaicml[libcloud,wandb,mlflow,oci,gcs]>=0.17.2,<0.18', 'accelerate>=0.25,<0.26', # for HF inference `device_map` 'transformers>=4.36,<4.37', - 'mosaicml-streaming>=0.7.1,<0.8', + 'mosaicml-streaming>=0.7.2,<0.8', 'torch>=2.1,<2.1.1', 'datasets==2.15.0', 'fsspec==2023.6.0', # newer version results in a bug in datasets that duplicates data 'sentencepiece==0.1.97', - 'einops==0.5.0', + 'einops==0.7.0', 'omegaconf>=2.2.3,<3', 'slack-sdk<4', 'mosaicml-cli>=0.5.27,<1', @@ -84,22 +84,22 @@ ] extra_deps['databricks'] = [ - 'mosaicml[databricks]>=0.17.1,<0.18', + 'mosaicml[databricks]>=0.17.2,<0.18', ] extra_deps['tensorboard'] = [ - 'mosaicml[tensorboard]>=0.17.1,<0.18', + 'mosaicml[tensorboard]>=0.17.2,<0.18', ] extra_deps['gpu'] = [ 'flash-attn==1.0.9', - 'mosaicml-turbo==0.0.4', + 'mosaicml-turbo==0.0.7', # PyPI does not support direct dependencies, so we remove this line before uploading from PyPI 'xentropy-cuda-lib@git+https://github.com/HazyResearch/flash-attention.git@v1.0.9#subdirectory=csrc/xentropy', ] extra_deps['gpu-flash2'] = [ 'flash-attn==2.3.6', - 'mosaicml-turbo==0.0.4', + 'mosaicml-turbo==0.0.7', ] extra_deps['peft'] = [ diff --git a/tests/data/test_dataloader.py b/tests/data/test_dataloader.py index ec0cff491d..68b7b98924 100644 --- a/tests/data/test_dataloader.py +++ b/tests/data/test_dataloader.py @@ -691,16 +691,17 @@ def test_token_counting_func(pad_token_id: int, batch_size: int, assert actual_token_count == expected_token_count -@pytest.mark.parametrize( - 'dataloader_type', - ['finetuning-hf', 'finetuning-streaming', 'denoising', 'text']) +@pytest.mark.parametrize('dataloader_type,tensor_input', + [('finetuning-hf', False), + ('finetuning-streaming', False), ('denoising', False), + ('text', True), ('text', False)]) @pytest.mark.parametrize('pad_token_id', [100, None]) @pytest.mark.parametrize('batch_size', [1, 8]) @pytest.mark.parametrize('model_max_length', [1024]) @pytest.mark.parametrize('padding_side', ['left']) def test_token_counting_func_dataloader_setting( - dataloader_type: str, pad_token_id: Optional[int], batch_size: int, - model_max_length: int, padding_side: str, + dataloader_type: str, tensor_input: bool, pad_token_id: Optional[int], + batch_size: int, model_max_length: int, padding_side: str, monkeypatch: pytest.MonkeyPatch): gptt = transformers.AutoTokenizer.from_pretrained('gpt2') gptt.pad_token_id = pad_token_id if pad_token_id is not None else gptt.eos_token_id @@ -710,9 +711,11 @@ def test_token_counting_func_dataloader_setting( batch_strings = [] expected_token_count = 0 for _ in range(batch_size): + # Get randomly different lengths if we are going to add padding sample_length = random.randint( 1, model_max_length // - 4) if pad_token_id is not None else model_max_length // 4 + 4) if (pad_token_id is not None and + not tensor_input) else model_max_length // 4 batch_strings.append(' '.join(['hello'] * sample_length)) expected_token_count += sample_length @@ -721,13 +724,18 @@ def test_token_counting_func_dataloader_setting( for b in batch_strings ] + if tensor_input: + batch_tokenized = [ + torch.tensor(b['input_ids']) for b in batch_tokenized + ] + if dataloader_type == 'denoising': expected_token_count += 2 * batch_size # for the two eos tokens expected_token_count += 5 * batch_size # for the corruption prefix tokens if dataloader_type in {'finetuning-hf', 'finetuning-streaming'}: for b in batch_tokenized: - b['labels'] = b['input_ids'].copy() + b['labels'] = b['input_ids'].copy() # type: 
ignore expected_token_count *= 2 expected_token_count += 1 * batch_size # for the eos token diff --git a/tests/models/test_model.py b/tests/models/test_model.py index 2419dbfa41..7bccad089d 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -350,7 +350,7 @@ def test_full_forward_and_backward_t5_small(batch_size: int = 2): [('torch', torch.float16), ('torch', torch.bfloat16), pytest.param('flash', torch.float16, marks=pytest.mark.gpu), pytest.param('flash', torch.bfloat16, marks=pytest.mark.gpu)]) -@pytest.mark.parametrize('ffn_type', ['mptmlp', 'mptgeglu']) +@pytest.mark.parametrize('ffn_type', ['mptmlp', 'mptglu']) @pytest.mark.parametrize('ffn_act_fn', [ None, { diff --git a/tests/utils/test_huggingface_hub_utils.py b/tests/utils/test_huggingface_hub_utils.py new file mode 100644 index 0000000000..5effb3a771 --- /dev/null +++ b/tests/utils/test_huggingface_hub_utils.py @@ -0,0 +1,16 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import ast + +from llmfoundry.utils.huggingface_hub_utils import _flatten_import + + +def test_flatten_import_true(): + node = ast.ImportFrom('y', ['x', 'y', 'z']) + assert _flatten_import(node, ('x', 'y', 'z')) + + +def test_flatten_import_false(): + node = ast.ImportFrom('y', ['x', 'y', 'z']) + assert not _flatten_import(node, ('x', 'z'))
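
The hf_checkpointer.py and huggingface_hub_utils.py hunks thread a new `flatten_imports` sequence from the callback down to `edit_files_for_hf_compatibility`, with the default `('llmfoundry',)` preserving the previous hard-coded behavior. A minimal construction sketch of the new argument follows; the save folder, interval, and precision are placeholder values chosen for illustration, not repository defaults.

# Sketch only: placeholder save_folder/save_interval/precision. Passing
# flatten_imports=('llmfoundry',) matches the new default, so this simply
# spells out the old behavior explicitly.
from llmfoundry.callbacks.hf_checkpointer import HuggingFaceCheckpointer

checkpointer = HuggingFaceCheckpointer(
    save_folder='./hf-checkpoints',   # local folder; may also be a URI
    save_interval='5000ba',           # every 5000 batches
    precision='bfloat16',
    flatten_imports=('llmfoundry',),  # import prefixes flattened in MPT files
)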
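
The text_data.py change lets `get_tokens_per_batch_func` handle batches that carry only `input_ids` (for example, pre-tokenized tensor batches) instead of requiring an attention mask. A condensed, self-contained sketch of that decision for the decoder-only case; the example batches are made up for illustration.

# Prefer the attention mask (counts only non-padding tokens); otherwise fall
# back to counting every element of input_ids, mirroring the updated
# get_num_samples_in_batch. Example tensors are illustrative.
from typing import Mapping

import torch


def count_tokens(batch: Mapping) -> int:
    if 'attention_mask' in batch:
        return int(torch.sum(batch['attention_mask']).item())
    return batch['input_ids'].numel()


padded = {
    'input_ids': torch.tensor([[5, 6, 7, 0]]),
    'attention_mask': torch.tensor([[1, 1, 1, 0]]),
}
unpadded = {'input_ids': torch.tensor([[5, 6, 7, 0]])}
assert count_tokens(padded) == 3    # padding excluded via the mask
assert count_tokens(unpadded) == 4  # no mask, so every element is counted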
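
The ffn.py hunk renames `MPTGeGLU` to `MPTGLU` and applies the activation to the gate projection rather than the up projection. Below is a standalone sketch of the gated feed-forward the renamed class now computes; the dimensions and the GELU activation are illustrative choices, not values pulled from the library.

# Standalone GLU-style feed-forward: act(gate_proj(x)) * up_proj(x), projected
# back down to d_model. Sizes and activation are illustrative.
import torch
import torch.nn as nn


class GLUFeedForwardSketch(nn.Module):

    def __init__(self, d_model: int = 128, expansion_ratio: int = 4):
        super().__init__()
        hidden = expansion_ratio * d_model
        self.up_proj = nn.Linear(d_model, hidden)
        self.gate_proj = nn.Linear(d_model, hidden)
        self.down_proj = nn.Linear(hidden, d_model)
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))


x = torch.randn(2, 16, 128)
assert GLUFeedForwardSketch()(x).shape == x.shape

Configs that still set `ffn_type: mptgeglu` now fail fast in `_validate_config` with a ValueError pointing at PR #829; renaming the value to `mptglu` is the config-side migration.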
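
The checkpoint_conversion_helpers.py changes plumb a numpy dtype through to `_write_zero_bias`, so the zero biases written for FasterTransformer match the weight precision instead of always being float32, and switch the MLP shape checks from `mlp_ratio` to `expansion_ratio`. A toy sketch of the dtype-aware zero-bias write; the file name and shape are placeholders.

# Write a zero bias next to a weight file, in the same dtype as the weights.
# The file name and bias shape here are placeholders for illustration.
import numpy as np

np_weight_data_type = np.float16  # e.g. when converting fp16 weights
weight_file_path = 'model.layers.0.mlp.dense_h_to_4h.weight.bin'
bias_file_path = weight_file_path.replace('.weight', '.bias')
bias = np.zeros(4 * 512, dtype=np_weight_data_type)
bias.tofile(bias_file_path)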
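
The config_utils.py addition forces `sync_module_states = True` whenever an FSDP hybrid sharding strategy is selected, emitting a warning instead of failing later. A small sketch of that guard, using a plain dict in place of the omegaconf-backed `fsdp_config`:

# Hybrid sharding strategies require sync_module_states=True; patch the config
# and warn, mirroring the new check in process_init_device. A plain dict stands
# in for the real fsdp_config object.
import warnings


def ensure_sync_module_states(fsdp_config: dict) -> dict:
    hybrid_strategies = {'HYBRID_SHARD', '_HYBRID_SHARD_ZERO2'}
    if (fsdp_config.get('sharding_strategy', 'FULL_SHARD') in hybrid_strategies
            and not fsdp_config.get('sync_module_states', False)):
        warnings.warn('Setting `sync_module_states = True` for FSDP. This is '
                      'required when using hybrid sharding.')
        fsdp_config['sync_module_states'] = True
    return fsdp_config


cfg = ensure_sync_module_states({'sharding_strategy': 'HYBRID_SHARD'})
assert cfg['sync_module_states'] is True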