diff --git a/.github/workflows/docker.yaml b/.github/workflows/docker.yaml index 976b4241ab..c3fc9168ee 100644 --- a/.github/workflows/docker.yaml +++ b/.github/workflows/docker.yaml @@ -20,11 +20,9 @@ jobs: - name: "2.4.0_cu124" base_image: mosaicml/pytorch:2.4.0_cu124-python3.11-ubuntu20.04 dep_groups: "[all]" - te_commit: 901e5d2 - name: "2.4.0_cu124_aws" base_image: mosaicml/pytorch:2.4.0_cu124-python3.11-ubuntu20.04-aws dep_groups: "[all]" - te_commit: 901e5d2 steps: - name: Checkout @@ -91,4 +89,3 @@ jobs: BRANCH_NAME=${{ github.head_ref || github.ref_name }} BASE_IMAGE=${{ matrix.base_image }} DEP_GROUPS=${{ matrix.dep_groups }} - TE_COMMIT=${{ matrix.te_commit }} diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 3617732c8f..15c83035e0 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -95,7 +95,6 @@ jobs: build-args: | BASE_IMAGE=mosaicml/pytorch:2.4.0_cu124-python3.11-ubuntu20.04-aws BRANCH_NAME=${{ env.BRANCH_NAME }} - TE_COMMIT=901e5d2 DEP_GROUPS=[all] KEEP_FOUNDRY=true @@ -111,6 +110,5 @@ jobs: build-args: | BASE_IMAGE=mosaicml/pytorch:2.4.0_cu124-python3.11-ubuntu20.04 BRANCH_NAME=${{ env.BRANCH_NAME }} - TE_COMMIT=901e5d2 DEP_GROUPS=[all] KEEP_FOUNDRY=true diff --git a/Dockerfile b/Dockerfile index a9d44bfa27..f2566cd3cc 100644 --- a/Dockerfile +++ b/Dockerfile @@ -6,7 +6,6 @@ FROM $BASE_IMAGE ARG BRANCH_NAME ARG DEP_GROUPS -ARG TE_COMMIT ARG KEEP_FOUNDRY=false ENV TORCH_CUDA_ARCH_LIST="8.0 8.6 8.7 8.9 9.0" @@ -16,9 +15,6 @@ ENV TORCH_CUDA_ARCH_LIST="8.0 8.6 8.7 8.9 9.0" ADD https://raw.githubusercontent.com/mosaicml/llm-foundry/$BRANCH_NAME/setup.py setup.py RUN rm setup.py -# Install TransformerEngine -RUN NVTE_FRAMEWORK=pytorch CMAKE_BUILD_PARALLEL_LEVEL=4 MAX_JOBS=4 pip install git+https://github.com/NVIDIA/TransformerEngine.git@$TE_COMMIT - # Install and uninstall foundry to cache foundry requirements RUN git clone -b $BRANCH_NAME https://github.com/mosaicml/llm-foundry.git RUN pip install --no-cache-dir "./llm-foundry${DEP_GROUPS}" diff --git a/llmfoundry/_version.py b/llmfoundry/_version.py index 0cddcaf967..e6385a53a7 100644 --- a/llmfoundry/_version.py +++ b/llmfoundry/_version.py @@ -3,4 +3,4 @@ """The LLM Foundry Version.""" -__version__ = '0.13.0.dev0' +__version__ = '0.14.0.dev0' diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index a9edf4840b..161178a872 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -234,6 +234,7 @@ def __init__( mlflow_logging_config: Optional[dict] = None, flatten_imports: Sequence[str] = ('llmfoundry',), final_register_only: bool = False, + register_wait_seconds: int = 7200, ): _, _, self.save_dir_format_str = parse_uri(save_folder) self.overwrite = overwrite @@ -247,6 +248,7 @@ def __init__( self.using_peft = False self.final_register_only = final_register_only + self.register_wait_seconds = register_wait_seconds self.mlflow_registered_model_name = mlflow_registered_model_name if self.final_register_only and self.mlflow_registered_model_name is None: @@ -379,7 +381,7 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: self.using_peft = composer_model.using_peft elif event == Event.FIT_END: # Wait for all child processes spawned by the callback to finish. 
- timeout = 3600 + timeout = self.register_wait_seconds wait_start = time.time() while not self._all_register_processes_done(state.device): wait_time = time.time() - wait_start diff --git a/llmfoundry/command_utils/data_prep/convert_dataset_hf.py b/llmfoundry/command_utils/data_prep/convert_dataset_hf.py index 2667407110..3d54da6057 100644 --- a/llmfoundry/command_utils/data_prep/convert_dataset_hf.py +++ b/llmfoundry/command_utils/data_prep/convert_dataset_hf.py @@ -451,6 +451,7 @@ def convert_dataset_hf_from_args( ValueError: If the output directory already contains the requested splits ValueError: If `concat_tokens` is set but `tokenizer` is not """ + os.environ['WORLD_SIZE'] = '1' if tokenizer_kwargs: parsed_tokenizer_kwargs = json.loads(tokenizer_kwargs) else: diff --git a/llmfoundry/command_utils/data_prep/convert_dataset_json.py b/llmfoundry/command_utils/data_prep/convert_dataset_json.py index c6f7d51c02..918ce7e108 100644 --- a/llmfoundry/command_utils/data_prep/convert_dataset_json.py +++ b/llmfoundry/command_utils/data_prep/convert_dataset_json.py @@ -186,6 +186,7 @@ def convert_dataset_json_from_args( ValueError: If the out_root directory exists and contains files that overlap with the requested splits ValueError: If concat_tokens is set and a tokenizer is not provided """ + os.environ['WORLD_SIZE'] = '1' if os.path.isdir(out_root) and len( set(os.listdir(out_root)).intersection(set(split)), ) > 0: diff --git a/llmfoundry/command_utils/data_prep/convert_delta_to_json.py b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py index 2321d306ff..fb1ee1d0ca 100644 --- a/llmfoundry/command_utils/data_prep/convert_delta_to_json.py +++ b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py @@ -550,6 +550,9 @@ def validate_and_get_cluster_info( ).upper()[len('DATASECURITYMODE.'):] # NONE stands for No Isolation Shared + # This check actually checks for Unity Catalog governance compatibility and does not + # check for invalid cluster access for a particular user. Cluster access controls is + # difficult and there is no single existing API to check this. if data_security_mode == 'NONE': raise ClusterInvalidAccessMode( cluster_id=cluster_id, @@ -767,6 +770,7 @@ def convert_delta_to_json_from_args( use_serverless (bool): Use serverless or not. Make sure the workspace is entitled with serverless json_output_filename (str): The name of the combined final jsonl that combines all partitioned jsonl """ + os.environ['WORLD_SIZE'] = '1' _check_imports() from databricks.sdk import WorkspaceClient w = WorkspaceClient() diff --git a/llmfoundry/command_utils/data_prep/convert_finetuning_dataset.py b/llmfoundry/command_utils/data_prep/convert_finetuning_dataset.py index bb1197de57..cbd1bd275d 100644 --- a/llmfoundry/command_utils/data_prep/convert_finetuning_dataset.py +++ b/llmfoundry/command_utils/data_prep/convert_finetuning_dataset.py @@ -309,6 +309,7 @@ def convert_finetuning_dataset_from_args( ValueError: If the target settings are invalid. ValueError: If the output directory already contains the requested splits. 
""" + os.environ['WORLD_SIZE'] = '1' if os.path.isdir(out_root) and len( set(os.listdir(out_root)).intersection(set(splits)), ) > 0: diff --git a/llmfoundry/command_utils/data_prep/convert_text_to_mds.py b/llmfoundry/command_utils/data_prep/convert_text_to_mds.py index 9de13f9d5b..11eac121d0 100644 --- a/llmfoundry/command_utils/data_prep/convert_text_to_mds.py +++ b/llmfoundry/command_utils/data_prep/convert_text_to_mds.py @@ -557,6 +557,7 @@ def convert_text_to_mds_from_args( Raises: ValueError: If `use_tokenizer_eos` is True and `eos_text` is not None """ + os.environ['WORLD_SIZE'] = '1' if use_tokenizer_eos: # Ensure that eos text is not specified twice. if eos_text is not None: diff --git a/llmfoundry/command_utils/eval.py b/llmfoundry/command_utils/eval.py index 73127e8a07..f25f2b5cef 100644 --- a/llmfoundry/command_utils/eval.py +++ b/llmfoundry/command_utils/eval.py @@ -4,7 +4,6 @@ import logging import os import time -import warnings from typing import Any, Optional, Union import pandas as pd @@ -37,7 +36,6 @@ process_init_device, ) from llmfoundry.utils.registry_utils import import_file -from llmfoundry.utils.warnings import VersionedDeprecationWarning log = logging.getLogger(__name__) @@ -63,7 +61,6 @@ def evaluate_model( callback_configs: Optional[dict[str, Any]], metadata: Optional[dict[str, str]], logged_config: dict[str, Any], - fsdp_config: Optional[dict[str, Any]] = None, parallelism_config: Optional[dict[str, Any]] = None, should_log_config: bool = True, load_path: Optional[str] = None, @@ -78,18 +75,6 @@ def evaluate_model( 'parallelism_config cannot contain deprecated fsdp_config arguments.', ) - if fsdp_config: - warnings.warn( - VersionedDeprecationWarning( - 'The argument fsdp_config is deprecated. Please use parallelism_config instead.', - remove_version='0.14.0', - ), - ) - if fsdp_config and parallelism_config: - raise ValueError( - 'Both fsdp_config and parallelism_config cannot be provided at the same time. 
Please use parallelism_config.', - ) - log.info(f'Evaluating model: {model_name}') # Build tokenizer and model tokenizer_cfg = tokenizer @@ -125,9 +110,9 @@ def evaluate_model( mosaicml_logger._flush_metadata(force_flush=True) fsdp_config = parallelism_config.get( - 'fsdp_config', + 'fsdp', None, - ) if parallelism_config else fsdp_config + ) if parallelism_config else None if fsdp_config and model.get('load_in_8bit', False): raise ValueError( 'The FSDP config block is not supported when loading ' + @@ -175,7 +160,7 @@ def evaluate_model( callbacks=callbacks, loggers=loggers, precision=precision, - parallelism_config={'fsdp': fsdp_config}, + parallelism_config=parallelism_config, load_path=load_path, load_weights_only=True, progress_bar=False, @@ -268,8 +253,6 @@ def evaluate(cfg: DictConfig) -> tuple[list[Trainer], pd.DataFrame]: model_configs = eval_config.models eval_gauntlet_config = eval_config.eval_gauntlet or eval_config.eval_gauntlet_str - fsdp_config = eval_config.fsdp_config - # Mandatory Evaluation Parameters icl_tasks = eval_config.icl_tasks or eval_config.icl_tasks_str if icl_tasks is None: @@ -345,9 +328,9 @@ def evaluate(cfg: DictConfig) -> tuple[list[Trainer], pd.DataFrame]: device_eval_batch_size=eval_config.device_eval_batch_size, eval_gauntlet_config=eval_gauntlet_config, eval_loader_config=eval_loader_config, - fsdp_config=fsdp_config, loggers=loggers, python_log_level=eval_config.python_log_level, + parallelism_config={'fsdp': eval_config.fsdp_config}, precision=eval_config.precision, eval_gauntlet_df=eval_gauntlet_df, callback_configs=eval_config.callbacks, diff --git a/llmfoundry/command_utils/train.py b/llmfoundry/command_utils/train.py index f647565386..cb287b029c 100644 --- a/llmfoundry/command_utils/train.py +++ b/llmfoundry/command_utils/train.py @@ -311,8 +311,10 @@ def train(cfg: DictConfig) -> Trainer: eval_gauntlet_config = train_cfg.eval_gauntlet or train_cfg.eval_gauntlet_str # Optional parameters will be set to default values if not specified. - default_run_name: str = os.environ.get('RUN_NAME', 'llm') - run_name: str = train_cfg.run_name if train_cfg.run_name else default_run_name + env_run_name: Optional[str] = os.environ.get('RUN_NAME', None) + run_name: str = ( + train_cfg.run_name if train_cfg.run_name else env_run_name + ) or 'llm' is_state_dict_sharded: bool = ( fsdp_config.get('state_dict_type', 'full') == 'sharded' ) if fsdp_config else False @@ -320,8 +322,10 @@ def train(cfg: DictConfig) -> Trainer: save_filename: str = train_cfg.save_filename if train_cfg.save_filename else 'ep{epoch}-ba{batch}-rank{rank}.pt' # Enable autoresume from model checkpoints if possible + is_user_set_run_name: bool = train_cfg.run_name is not None or env_run_name is not None autoresume_default: bool = False - if train_cfg.save_folder is not None \ + if is_user_set_run_name and \ + train_cfg.save_folder is not None \ and not train_cfg.save_overwrite \ and not train_cfg.save_weights_only: autoresume_default = True @@ -588,6 +592,8 @@ def train(cfg: DictConfig) -> Trainer: profiler=profiler, compile_config=compile_config, spin_dataloaders=train_cfg.spin_dataloaders, + accumulate_train_batch_on_tokens=train_cfg. 
+ accumulate_train_batch_on_tokens, ) _sort_callbacks(trainer) diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index 92bbac561d..7c9d149fea 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -198,7 +198,8 @@ def build_finetuning_dataloader( allowed_dataset_config_keys = set( dataset_constructor_keys, ).union(_ALLOWED_DATASET_KEYS) - _validate_config( + + extraneous_keys = _validate_config( **dataset_cfg, allowed_dataset_keys=allowed_dataset_config_keys, ) @@ -253,13 +254,13 @@ def build_finetuning_dataloader( streams_cfg, ) if streams_cfg is not None else None - # Take the constructor args from above, minus args that have been created separately dataset_constructor_args = { k: v for k, v in dataset_cfg.items() - if k in dataset_constructor_keys and - k not in {'streams', 'packing_ratio'} + if k in set(dataset_constructor_keys).union(extraneous_keys) and + k not in {'streams', 'packing_ratio', 'replication'} } + streaming_dataset = dataset_constructor.build_from_streaming( tokenizer=tokenizer, streams=streams, @@ -366,7 +367,7 @@ def build_finetuning_dataloader( def _validate_config( max_seq_len: int, - decoder_only_format: bool = False, + decoder_only_format: Optional[bool] = None, hf_name: Optional[str] = None, local: Optional[str] = None, remote: Optional[str] = None, @@ -378,7 +379,7 @@ def _validate_config( target_responses: Optional[str] = None, allowed_dataset_keys: set[str] = _ALLOWED_DATASET_KEYS, **kwargs: dict[str, Any], -) -> None: +) -> set[str]: """Validates the dataset configuration. Makes sure that the dataset is properly configured for either @@ -389,7 +390,7 @@ def _validate_config( max_seq_len (int): The maximum length of sequences in the batch. See :class:`Seq2SeqFinetuningCollator` docstring for details. - decoder_only_format (bool): Whether to format the + decoder_only_format (bool, optional): Whether to format the examples for a decoder-only model. See :class:`Seq2SeqFinetuningCollator` docstring for details. hf_name (str, optional): The name of the HuggingFace dataset @@ -434,11 +435,21 @@ def _validate_config( Raises: ValueError: If the dataset configuration does not meet the requirements. + + Returns: + set[str]: Return the extraneous keys. """ - if not set(kwargs.keys()).issubset(allowed_dataset_keys): + if decoder_only_format is None: raise ValueError( + f'decoder_only_format must be set to either True or False, but it was {decoder_only_format}.', + ) + + extraneous_keys = set() + if not set(kwargs.keys()).issubset(allowed_dataset_keys): + extraneous_keys = set(kwargs.keys()) - allowed_dataset_keys + log.warning( 'The dataset config contains the following extraneous keys: ' +\ - ', '.join(set(kwargs.keys()) - allowed_dataset_keys), + ', '.join(extraneous_keys), ) if hf_name is not None: @@ -456,7 +467,7 @@ def _validate_config( 'Those keys are used when building from a streaming dataset, but ' +\ 'setting `hf_name` instructs the dataset to build from a HuggingFace dataset.', ) - elif remote is not None: + elif remote is not None or local is not None: # Using the streaming dataset codepath illegal_keys = { 'hf_name': hf_name, @@ -533,6 +544,8 @@ def _validate_config( decoder_only_format, ) + return extraneous_keys + def _download_remote_hf_dataset(remote_path: str, split: str) -> str: """Downloads a dataset from a remote object store. 
diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 179f017fd9..3d9ed056ef 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -515,6 +515,22 @@ def is_valid_ift_example( return True +def _get_num_processes() -> int: + """Get the number of processes to use for dataset processing.""" + detected_cpu_count = os.cpu_count() or 1 + detected_cpus_with_margin = detected_cpu_count - 8 + num_proc = max(1, detected_cpus_with_margin) + + # Check if the user has set the MAX_NUM_PROC environment variable + # which caps the number of processes used for dataset processing. + if 'MAX_NUM_PROC' in os.environ: + max_num_proc_env = int(os.environ['MAX_NUM_PROC']) + if max_num_proc_env < num_proc: + num_proc = max_num_proc_env + + return num_proc + + class StreamingFinetuningDataset(StreamingDataset): """Finetuning dataset with flexible tokenization using StreamingDataset. @@ -613,11 +629,6 @@ def __init__( **kwargs: Any, ): - if len(kwargs) > 0: - raise ValueError( - f'StreamingFinetuningDataset() got an unexpected keyword argument: {kwargs}', - ) - if token_encoding_type not in SUPPORTED_MDS_ENCODING_TYPES: raise ValueError( f'The token_encoding_type must be one of {SUPPORTED_MDS_ENCODING_TYPES}, but got {token_encoding_type}', @@ -658,6 +669,7 @@ def __init__( batching_method=batching_method, allow_unsafe_types=allow_unsafe_types, replication=replication, + **kwargs, ) self.tokenizer = tokenizer @@ -964,18 +976,16 @@ def dataset_mapper(example: dict): ) return mapping_fn(example, tokenizer) - detected_cpu_count = os.cpu_count() or 1 - detected_cpus_with_margin = detected_cpu_count - 8 - num_cpus_to_use = max(1, detected_cpus_with_margin) - if len(dataset) < num_cpus_to_use: - num_cpus_to_use = 1 + num_proc = _get_num_processes() + if len(dataset) < num_proc: + num_proc = 1 columns_to_remove = list(dataset[0].keys()) tokenized_dataset = dataset.map( dataset_mapper, batched=False, remove_columns=columns_to_remove, - num_proc=num_cpus_to_use, + num_proc=num_proc, desc='Tokenizing dataset', ) @@ -987,7 +997,7 @@ def dataset_mapper(example: dict): target_responses, decoder_only_format, ), - num_proc=num_cpus_to_use, + num_proc=num_proc, desc='Filtering out long prompts', ) diff --git a/llmfoundry/data/text_data.py b/llmfoundry/data/text_data.py index 3ce248e69f..37d4c32b23 100644 --- a/llmfoundry/data/text_data.py +++ b/llmfoundry/data/text_data.py @@ -138,11 +138,6 @@ def __init__( **kwargs: Any, ): - if len(kwargs) > 0: - raise ValueError( - f'StreamingTextDataset() got an unexpected keyword argument: {kwargs}', - ) - if token_encoding_type not in SUPPORTED_MDS_ENCODING_TYPES: raise ValueError( f'The token_encoding_type must be one of {SUPPORTED_MDS_ENCODING_TYPES}, but got {token_encoding_type}', @@ -188,6 +183,7 @@ def __init__( batching_method=batching_method, allow_unsafe_types=allow_unsafe_types, replication=replication, + **kwargs, ) self.tokenizer = tokenizer self.max_seq_len = max_seq_len @@ -332,10 +328,13 @@ def build_text_dataloader( StreamingTextDataset, ).parameters + valid_base_dataset_params = inspect.signature(StreamingDataset,).parameters + dataset_config_subset_for_streaming_text_dataset = { k: v for k, v in dataset_cfg.items() - if k in valid_streaming_text_dataset_parameters + if k in valid_streaming_text_dataset_parameters or + k in valid_base_dataset_params } # build dataset potentially with streams diff --git a/llmfoundry/models/hf/__init__.py b/llmfoundry/models/hf/__init__.py index 
2f25f92940..03df90e8cd 100644 --- a/llmfoundry/models/hf/__init__.py +++ b/llmfoundry/models/hf/__init__.py @@ -9,7 +9,6 @@ prepare_hf_model_for_fsdp, ) from llmfoundry.models.hf.hf_t5 import ComposerHFT5 -from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithFSDP __all__ = [ 'BaseHuggingFaceModel', @@ -18,5 +17,4 @@ 'prepare_hf_causal_lm_model_for_fsdp', 'prepare_hf_enc_dec_model_for_fsdp', 'prepare_hf_model_for_fsdp', - 'HuggingFaceModelWithFSDP', ] diff --git a/llmfoundry/models/hf/model_wrapper.py b/llmfoundry/models/hf/model_wrapper.py deleted file mode 100644 index f2b67db1ec..0000000000 --- a/llmfoundry/models/hf/model_wrapper.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright 2022 MosaicML LLM Foundry authors -# SPDX-License-Identifier: Apache-2.0 - -"""Re-usable :class:`.ComposerModel` for LLM HF Models.""" - -from __future__ import annotations - -import warnings -from collections import UserDict -from typing import TYPE_CHECKING, Mapping, Optional, Union - -import transformers -from composer.models.huggingface import HuggingFaceModel -from torchmetrics import Metric -from transformers import PreTrainedTokenizerBase -from transformers.utils.generic import ModelOutput - -from llmfoundry.models.hf.hf_fsdp import prepare_hf_model_for_fsdp -from llmfoundry.utils.warnings import VersionedDeprecationWarning - -if TYPE_CHECKING: - from peft import PeftConfig, PeftModel - -__all__ = ['HuggingFaceModelWithFSDP'] - -# HuggingFace hardcodes the ignore index to -100 -_HF_IGNORE_INDEX = -100 - - -class HuggingFaceModelWithFSDP(HuggingFaceModel): - """Wrapper around HuggingFaceModel. - - Handles preparation for FSDP wrapping. - """ - - def __init__( - self, - model: Union[transformers.PreTrainedModel, 'PeftModel'], - tokenizer: Optional[PreTrainedTokenizerBase] = None, - metrics: Optional[list[Metric]] = None, - eval_metrics: Optional[list[Metric]] = None, - shift_labels: bool = False, - allow_embedding_resizing: bool = False, - init_device: Optional[str] = None, - peft_config: Optional['PeftConfig'] = None, - should_save_peft_only: bool = True, - ): - warnings.warn( - VersionedDeprecationWarning( - '`HuggingFaceModelWithFSDP` is deprecated. In the future please use `BaseHuggingFaceModel`.', - remove_version='0.14.0', - ), - ) - super().__init__( - model, - tokenizer, - use_logits=True, - metrics=metrics, - eval_metrics=eval_metrics, - shift_labels=shift_labels, - allow_embedding_resizing=allow_embedding_resizing, - peft_config=peft_config, - should_save_peft_only=should_save_peft_only, - ) - - self.prepare_inner_model(self.model, init_device) - - def forward(self, batch: Mapping): - if isinstance(batch, dict) or isinstance(batch, UserDict): - # Further input validation is left to the huggingface forward call - batch = { - k: v for k, v in batch.items() if k in self.model_forward_args - } - output = self.model(**batch) # type: ignore (thirdparty) - else: - raise ValueError( - 'Unexpected batch type. Expected a dictionary with keys corresponding to the inputs to the forward function of the Huggingface model', - ) - return output - - def loss(self, outputs: ModelOutput, batch: Mapping): - if self.config.use_return_dict: - return outputs['loss'] - # loss is at index 0 in the output tuple, logits are at index 1 - return outputs[:2] - - @staticmethod - def prepare_inner_model( - model: Union[transformers.PreTrainedModel, 'PeftModel'], - init_device: Optional[str] = None, - ): - """Prepare the inner model for FSDP wrapping. - - Args: - model: The model to prepare. 
-            init_device: The device to initialize the model on.
-        """
-        # Note: We need to add the FSDP related attributes to the model AFTER the super init,
-        # so that the (possible) embedding resizing doesn't destroy them
-        prepare_hf_model_for_fsdp(model, init_device)
-
-        # This provides support for meta initialization when using FSDP
-        model.param_init_fn = lambda module: model._init_weights(module)
diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py
index dbcabdf5f9..1adb64dc21 100644
--- a/llmfoundry/models/mpt/configuration_mpt.py
+++ b/llmfoundry/models/mpt/configuration_mpt.py
@@ -371,11 +371,8 @@ def _validate_config(self) -> None:
                 del te  # unused
             except:
                 raise ImportError(
-                    'TransformerEngine import fail. `fc_type: te` requires TransformerEngine be installed. '
-                    +
-                    'The required version of transformer_engine also requires FlashAttention v1.0.6 is installed:\n'
-                    + 'pip install flash-attn==1.0.6 --no-build-isolation \n' +
-                    'pip install git+https://github.com/NVIDIA/TransformerEngine.git@144e4888b2cdd60bd52e706d5b7a79cb9c1a7156',
+                    'TransformerEngine import failed. `fc_type: te` requires TransformerEngine to be installed, '
+                    'e.g. pip install transformer-engine[pytorch]',
                 )

         self.ffn_config['fc_type'] = self.fc_type
diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py
index 18112c18aa..64514c528d 100644
--- a/llmfoundry/utils/config_utils.py
+++ b/llmfoundry/utils/config_utils.py
@@ -121,6 +121,7 @@ class TrainConfig:
    dist_timeout: Union[int, float] = 600.0
    fsdp_config: Optional[dict[str, Any]] = None
    tp_config: Optional[dict[str, Any]] = None
+    accumulate_train_batch_on_tokens: bool = False

    # Evaluation parameters
    eval_interval: Union[int, str] = 1
diff --git a/llmfoundry/utils/exceptions.py b/llmfoundry/utils/exceptions.py
index 1b9feb9a10..905a376ef3 100644
--- a/llmfoundry/utils/exceptions.py
+++ b/llmfoundry/utils/exceptions.py
@@ -310,7 +310,7 @@ def __init__(self, input: dict) -> None:
 ## Convert Delta to JSON exceptions


-class ClusterDoesNotExistError(NetworkError):
+class ClusterDoesNotExistError(UserError):
     """Error thrown when the cluster does not exist."""

     def __init__(self, cluster_id: str) -> None:
@@ -318,12 +318,12 @@ def __init__(self, cluster_id: str) -> None:
         super().__init__(message, cluster_id=cluster_id)


-class ClusterInvalidAccessMode(NetworkError):
-    """Error thrown when the cluster does not exist."""
+class ClusterInvalidAccessMode(UserError):
+    """Error thrown when the cluster has an invalid data security mode."""

     def __init__(self, cluster_id: str, access_mode: str) -> None:
-        message = f'Cluster with id {cluster_id} has access mode {access_mode}. ' + \
-            'please make sure the cluster used has access mode Shared or Single User!'
+        message = f'The cluster you have provided: {cluster_id} does not have data governance enabled. ' + \
+            'Please use a cluster with a data security mode other than NONE.'
super().__init__( message, cluster_id=cluster_id, diff --git a/llmfoundry/utils/warnings.py b/llmfoundry/utils/warnings.py index 6da0d5e605..1fe4d78db0 100644 --- a/llmfoundry/utils/warnings.py +++ b/llmfoundry/utils/warnings.py @@ -86,6 +86,7 @@ def experimental_class( def class_decorator(cls: Type): # noqa: UP006 original_init = cls.__init__ + cls.is_experimental = True def new_init(self: Any, *args: Any, **kwargs: Any): warnings.warn(ExperimentalWarning(feature_name)) diff --git a/mcli/mcli-1b-eval.yaml b/mcli/mcli-1b-eval.yaml index bd6a7b538a..da8884d4a3 100644 --- a/mcli/mcli-1b-eval.yaml +++ b/mcli/mcli-1b-eval.yaml @@ -1,7 +1,7 @@ integrations: - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.12.0 + git_branch: v0.13.0 # git_commit: # OR use your commit hash pip_install: .[gpu] ssh_clone: false # Should be true if using a private repo diff --git a/mcli/mcli-1b-max-seq-len-8k.yaml b/mcli/mcli-1b-max-seq-len-8k.yaml index 1d48cd8105..718a1b1c00 100644 --- a/mcli/mcli-1b-max-seq-len-8k.yaml +++ b/mcli/mcli-1b-max-seq-len-8k.yaml @@ -1,7 +1,7 @@ integrations: - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.12.0 + git_branch: v0.13.0 # git_commit: # OR use your commit hash pip_install: .[gpu] ssh_clone: false # Should be true if using a private repo diff --git a/mcli/mcli-1b.yaml b/mcli/mcli-1b.yaml index 71566d4c46..22b22f6572 100644 --- a/mcli/mcli-1b.yaml +++ b/mcli/mcli-1b.yaml @@ -1,7 +1,7 @@ integrations: - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.12.0 + git_branch: v0.13.0 # git_commit: # OR use your commit hash pip_install: .[gpu] ssh_clone: false # Should be true if using a private repo diff --git a/mcli/mcli-benchmark-mpt.yaml b/mcli/mcli-benchmark-mpt.yaml index 0c023f9a83..f588b00a94 100644 --- a/mcli/mcli-benchmark-mpt.yaml +++ b/mcli/mcli-benchmark-mpt.yaml @@ -11,7 +11,7 @@ image: mosaicml/llm-foundry:2.4.0_cu124-latest integrations: - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.12.0 + git_branch: v0.13.0 # git_commit: # OR use your commit hash pip_install: .[gpu] diff --git a/mcli/mcli-convert-composer-to-hf.yaml b/mcli/mcli-convert-composer-to-hf.yaml index a211e3baeb..201f4d427e 100644 --- a/mcli/mcli-convert-composer-to-hf.yaml +++ b/mcli/mcli-convert-composer-to-hf.yaml @@ -1,7 +1,7 @@ integrations: - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.12.0 + git_branch: v0.13.0 # git_commit: # OR use your commit hash pip_install: . 
ssh_clone: false # Should be true if using a private repo diff --git a/mcli/mcli-hf-eval.yaml b/mcli/mcli-hf-eval.yaml index 9bcebfbea0..bc7081fecc 100644 --- a/mcli/mcli-hf-eval.yaml +++ b/mcli/mcli-hf-eval.yaml @@ -1,7 +1,7 @@ integrations: - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.12.0 + git_branch: v0.13.0 # git_commit: # OR use your commit hash pip_install: .[gpu] ssh_clone: false # Should be true if using a private repo diff --git a/mcli/mcli-hf-generate.yaml b/mcli/mcli-hf-generate.yaml index 85a0f6b0e4..ea120d73c9 100644 --- a/mcli/mcli-hf-generate.yaml +++ b/mcli/mcli-hf-generate.yaml @@ -1,7 +1,7 @@ integrations: - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.12.0 + git_branch: v0.13.0 # git_commit: # OR use your commit hash pip_install: .[gpu] ssh_clone: false # Should be true if using a private repo diff --git a/mcli/mcli-llama2-finetune.yaml b/mcli/mcli-llama2-finetune.yaml index 210e8942b5..c5a0942d97 100644 --- a/mcli/mcli-llama2-finetune.yaml +++ b/mcli/mcli-llama2-finetune.yaml @@ -1,7 +1,7 @@ integrations: - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.12.0 + git_branch: v0.13.0 # git_commit: # OR use your commit hash pip_install: .[gpu] ssh_clone: false # Should be true if using a private repo diff --git a/mcli/mcli-openai-eval.yaml b/mcli/mcli-openai-eval.yaml index 987fc829a9..2f4e3f2a78 100644 --- a/mcli/mcli-openai-eval.yaml +++ b/mcli/mcli-openai-eval.yaml @@ -1,7 +1,7 @@ integrations: - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.12.0 + git_branch: v0.13.0 # git_commit: # OR use your commit hash pip_install: .[gpu,openai] ssh_clone: false # Should be true if using a private repo diff --git a/mcli/mcli-pretokenize-oci-upload.yaml b/mcli/mcli-pretokenize-oci-upload.yaml index a3e8c40b88..354fb8a8a9 100644 --- a/mcli/mcli-pretokenize-oci-upload.yaml +++ b/mcli/mcli-pretokenize-oci-upload.yaml @@ -14,7 +14,7 @@ integrations: - oci-cli==3.23.2 - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.12.0 + git_branch: v0.13.0 # git_commit: # OR use your commit hash pip_install: . 
ssh_clone: false # Should be true if using a private repo diff --git a/setup.py b/setup.py index 2b310bcfd0..ae98a36f5d 100644 --- a/setup.py +++ b/setup.py @@ -65,8 +65,8 @@ 'omegaconf>=2.2.3,<3', 'slack-sdk<4', 'mosaicml-cli>=0.6.10,<1', - 'onnx==1.16.2', - 'onnxruntime==1.19.0', + 'onnx==1.17.0', + 'onnxruntime==1.19.2', 'boto3>=1.21.45,<2', 'huggingface-hub>=0.19.0,<0.25', 'beautifulsoup4>=4.12.2,<5', # required for model download utils @@ -119,18 +119,22 @@ ] extra_deps['megablocks'] = [ - 'megablocks==0.6.1', + 'megablocks<1.0', 'grouped-gemm==0.1.6', ] +extra_deps['te'] = [ + 'transformer-engine[pytorch]>=1.11.0,<1.12', +] + extra_deps['databricks-serverless'] = { dep for key, deps in extra_deps.items() for dep in deps - if 'gpu' not in key and 'megablocks' not in key and + if 'gpu' not in key and 'megablocks' not in key and 'te' not in key and 'databricks-connect' not in dep } extra_deps['all-cpu'] = { dep for key, deps in extra_deps.items() for dep in deps - if 'gpu' not in key and 'megablocks' not in key + if 'gpu' not in key and 'megablocks' not in key and 'te' not in key } extra_deps['all'] = { dep for key, deps in extra_deps.items() for dep in deps diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py new file mode 100644 index 0000000000..de91429c85 --- /dev/null +++ b/tests/callbacks/test_callbacks.py @@ -0,0 +1,127 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import inspect +import typing + +import pytest +from composer.core import Callback + +from llmfoundry.callbacks.async_eval_callback import AsyncEval +from llmfoundry.callbacks.curriculum_learning_callback import CurriculumLearning +from llmfoundry.interfaces.callback_with_config import CallbackWithConfig +from llmfoundry.registry import callbacks, callbacks_with_config +from llmfoundry.utils.builders import build_callback + +primitive_types = {int, float, str, bool, dict, list} + +# Callbacks that we skip during testing because they require more complex inputs. +# They should be tested separately. 
+skip_callbacks = [ + AsyncEval, + CurriculumLearning, +] + + +def get_default_value( + param: str, + tpe: type, + inspected_param: typing.Optional[inspect.Parameter], +): + if typing.get_origin(tpe) is typing.Union: + args = typing.get_args(tpe) + return get_default_value(param, args[0], None) + elif typing.get_origin(tpe) is list or typing.get_origin(tpe) is list: + return [] + elif typing.get_origin(tpe) is dict or typing.get_origin(tpe) is dict: + return {} + elif tpe is int: + return 0 + elif tpe is float: + return 0.0 + elif tpe is str: + return '' + elif tpe is bool: + return False + elif tpe is dict: + return {} + elif tpe is list: + return [] + elif inspected_param is not None and tpe is typing.Any and inspected_param.kind is inspect.Parameter.VAR_KEYWORD: + return None + elif inspected_param is not None and tpe is typing.Any and inspected_param.kind is inspect.Parameter.VAR_POSITIONAL: + return None + else: + raise ValueError(f'Unsupported type: {tpe} for parameter {param}') + + +def get_default_kwargs(callback_class: type): + type_hints = typing.get_type_hints(callback_class.__init__) + inspected_params = inspect.signature(callback_class.__init__).parameters + + default_kwargs = {} + + for param, tpe in type_hints.items(): + if param == 'self' or param == 'return' or param == 'train_config': + continue + if inspected_params[param].default == inspect.Parameter.empty: + default_value = get_default_value( + param, + tpe, + inspected_params[param], + ) + if default_value is not None: + default_kwargs[param] = default_value + return default_kwargs + + +def maybe_skip_callback_test(callback_class: type): + if hasattr( + callback_class, + 'is_experimental', + ) and callback_class.is_experimental: # type: ignore + pytest.skip( + f'Skipping test for {callback_class.__name__} because it is experimental.', + ) + if callback_class in skip_callbacks: + pytest.skip( + f'Skipping test for {callback_class.__name__}. 
It should be tested elsewhere.', + ) + + +@pytest.mark.parametrize( + 'callback_name,callback_class', + callbacks.get_all().items(), +) +def test_build_callback(callback_name: str, callback_class: type): + maybe_skip_callback_test(callback_class) + get_default_kwargs(callback_class) + + callback = build_callback( + callback_name, + kwargs=get_default_kwargs(callback_class), + ) + + assert isinstance(callback, callback_class) + assert isinstance(callback, Callback) + + +@pytest.mark.parametrize( + 'callback_name,callback_class', + callbacks_with_config.get_all().items(), +) +def test_build_callback_with_config(callback_name: str, callback_class: type): + maybe_skip_callback_test(callback_class) + get_default_kwargs(callback_class) + + callback = build_callback( + callback_name, + kwargs=get_default_kwargs(callback_class), + train_config={ + 'save_folder': 'test', + 'save_interval': '1ba', + }, + ) + + assert isinstance(callback, callback_class) + assert isinstance(callback, CallbackWithConfig) diff --git a/tests/callbacks/test_system_metrics_monitor.py b/tests/callbacks/test_system_metrics_monitor.py deleted file mode 100644 index 47095604eb..0000000000 --- a/tests/callbacks/test_system_metrics_monitor.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2024 MosaicML LLM Foundry authors -# SPDX-License-Identifier: Apache-2.0 - -from composer.callbacks import SystemMetricsMonitor - -from llmfoundry.utils.builders import build_callback - - -def test_system_metrics_monitor_callback_builds(): - callback = build_callback( - 'system_metrics_monitor', - kwargs={}, - train_config={'train_loader': {}}, - ) - assert isinstance(callback, SystemMetricsMonitor) diff --git a/tests/data/test_dataloader.py b/tests/data/test_dataloader.py index d7f979713a..5f16c86eb9 100644 --- a/tests/data/test_dataloader.py +++ b/tests/data/test_dataloader.py @@ -8,7 +8,7 @@ from contextlib import nullcontext as does_not_raise from pathlib import Path from typing import Any, Callable, ContextManager, Literal, Optional, Union -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, mock_open, patch import catalogue import numpy as np @@ -1423,3 +1423,160 @@ def test_sharegpt_format( device_batch_size=device_batch_size, **cfg, ).dataloader + +def test_ft_dataloader_with_extra_keys(): + max_seq_len = 2 + cfg = { + 'dataset': { + 'remote': '/remote', + 'local': '/local', + 'split': 'train', + 'max_seq_len': 2048, + 'decoder_only_format': True, + 'shuffle': True, + 'num_canonical_nodes': 472, + 'target_responses': 'last', + 'target_prompts': 'none', + 'extra_key_1': 'extra_key_1', + 'extra_key_2': 'extra_key_2', + 'extra_key_3': 'extra_key_3', + }, + 'drop_last': False, + 'num_workers': 0, + 'pin_memory': False, + 'prefetch_factor': None, + 'persistent_workers': False, + 'timeout': 0, + } + + cfg = om.create(cfg) + + tokenizer = build_tokenizer( + tokenizer_name='gpt2', + tokenizer_kwargs={'model_max_length': max_seq_len}, + ) + + device_batch_size = 2 + + mock_stat = MagicMock() + mock_stat.st_size = 1024 # Mock st_size with a desired value + mock_stat.st_mode = 33188 # Regular file mode for Unix-based systems + + #with patch('streaming.base.stream.get_shards', return_value=None): + with patch('os.makedirs'), \ + patch('builtins.open', new_callable=mock_open, read_data='{"version": 2, "shards": []}'), \ + patch('json.load') as mock_json_load, \ + patch('os.stat', return_value=mock_stat), \ + patch('torch.distributed.is_available', return_value=True), \ + patch('torch.distributed.is_initialized', 
return_value=True), \ + patch('torch.distributed.broadcast_object_list'), \ + patch('torch.distributed.init_process_group'), \ + patch('torch.distributed.destroy_process_group'), \ + patch('torch.distributed.barrier'), \ + patch('streaming.base.dataset.StreamingDataset.get_item'): + + mock_json_load.return_value = { + 'version': + 2, + 'shards': [{ + 'column_names': ['column1', 'column2'], + 'column_encodings': ['int', 'float'], + 'column_sizes': [4, 8], + 'compression': None, + 'format': 'mds', + 'hashes': [], + 'raw_data': { + 'basename': 'shard.00000.mds', + 'bytes': 1024, + 'hashes': {}, + }, + 'samples': 1000, + 'size_limit': 67108864, + 'version': 2, + 'zip_data': None, + }], + } + + with pytest.raises(TypeError, match=f'.*got an unexpected keyword argument.*'): + _ = build_finetuning_dataloader( + **cfg, + tokenizer=tokenizer, + device_batch_size=device_batch_size, + ).dataloader + +@pytest.mark.xfail +def test_text_dataloader_with_extra_keys(): + max_seq_len = 1024 + cfg = { + 'dataset': { + 'remote': '/remote', + 'local': '/local', + 'split': 'train', + 'max_seq_len': max_seq_len, + 'shuffle': True, + 'num_canonical_nodes': 472, + 'extra_key_1': 'extra_key_1', + 'extra_key_2': 'extra_key_2', + 'extra_key_3': 'extra_key_3', + }, + 'drop_last': False, + 'num_workers': 0, + 'pin_memory': False, + 'prefetch_factor': None, + 'persistent_workers': False, + 'timeout': 0, + } + + cfg = om.create(cfg) + + tokenizer = build_tokenizer( + tokenizer_name='gpt2', + tokenizer_kwargs={'model_max_length': max_seq_len}, + ) + + device_batch_size = 2 + + mock_stat = MagicMock() + mock_stat.st_size = 1024 # Mock st_size with a desired value + mock_stat.st_mode = 33188 # Regular file mode for Unix-based systems + + #with patch('streaming.base.stream.get_shards', return_value=None): + with patch('os.makedirs'), \ + patch('builtins.open', new_callable=mock_open, read_data='{"version": 2, "shards": []}'), \ + patch('json.load') as mock_json_load, \ + patch('os.stat', return_value=mock_stat), \ + patch('torch.distributed.is_available', return_value=True), \ + patch('torch.distributed.is_initialized', return_value=True), \ + patch('torch.distributed.broadcast_object_list'), \ + patch('torch.distributed.init_process_group'), \ + patch('torch.distributed.destroy_process_group'), \ + patch('torch.distributed.barrier'), \ + patch('streaming.base.dataset.StreamingDataset.get_item'): + + mock_json_load.return_value = { + 'version': + 2, + 'shards': [{ + 'column_names': ['column1', 'column2'], + 'column_encodings': ['int', 'float'], + 'column_sizes': [4, 8], + 'compression': None, + 'format': 'mds', + 'hashes': [], + 'raw_data': { + 'basename': 'shard.00000.mds', + 'bytes': 1024, + 'hashes': {}, + }, + 'samples': 1000, + 'size_limit': 67108864, + 'version': 2, + 'zip_data': None, + }], + } + with pytest.raises(TypeError, match=f'.*got an unexpected keyword argument.*'): + _ = build_text_dataloader( + **cfg, + tokenizer=tokenizer, + device_batch_size=device_batch_size, + ).dataloader diff --git a/tests/data/test_dataset.py b/tests/data/test_dataset.py index 071c189b68..b89fcc4b37 100644 --- a/tests/data/test_dataset.py +++ b/tests/data/test_dataset.py @@ -1,15 +1,33 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import os from contextlib import nullcontext from typing import Optional from unittest import mock import pytest -from llmfoundry.data.finetuning.tasks import dataset_constructor +from llmfoundry.data.finetuning.tasks import ( + _get_num_processes, + 
dataset_constructor, +) from llmfoundry.utils.exceptions import DatasetTooSmallError +def test_get_num_processes(): + with mock.patch.dict(os.environ, {'MAX_NUM_PROC': '4'}): + with mock.patch('os.cpu_count', return_value=16): + assert _get_num_processes() == 4 + + with mock.patch.dict(os.environ, {'MAX_NUM_PROC': '32'}): + with mock.patch('os.cpu_count', return_value=16): + assert _get_num_processes() == 8 + + with mock.patch.dict(os.environ, {}): + with mock.patch('os.cpu_count', return_value=16): + assert _get_num_processes() == 8 + + @pytest.mark.parametrize('num_canonical_nodes', [None, 8, 2]) def test_finetuning_streaming_dataset_too_small( num_canonical_nodes: Optional[int], diff --git a/tests/eval/test_eval_deprecation.py b/tests/eval/test_eval_deprecation.py index 828186245a..e6b64cab05 100644 --- a/tests/eval/test_eval_deprecation.py +++ b/tests/eval/test_eval_deprecation.py @@ -90,36 +90,3 @@ def test_deprecation_warning_with_deprecated_arg(self): 'parallelism_config cannot contain deprecated fsdp_config arguments.', str(context.exception), ) - - def test_deprecation_warning_with_fsdp_config(self): - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') - - try: - evaluate_model( - **self.common_args, - parallelism_config=None, - fsdp_config={'verbose': True}, - ) - except Exception: - pass - - self.assertTrue( - any( - issubclass(warning.category, VersionedDeprecationWarning) - for warning in w - ), - ) - - def test_error_with_both_fsdp_and_parallelism_config(self): - with self.assertRaises(ValueError) as context: - evaluate_model( - **self.common_args, - parallelism_config={'some_arg': True}, - fsdp_config={'some_arg': True}, - ) - - self.assertIn( - 'Both fsdp_config and parallelism_config cannot be provided at the same time.', - str(context.exception), - ) diff --git a/tests/models/test_model.py b/tests/models/test_model.py index eeb6bf0d90..43067f5e47 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -15,7 +15,10 @@ from accelerate import init_empty_weights from composer.core.precision import Precision, get_precision_context from composer.distributed.dist_strategy import prepare_fsdp_module -from composer.models.huggingface import maybe_get_underlying_model +from composer.models.huggingface import ( + HuggingFaceModel, + maybe_get_underlying_model, +) from composer.optim import DecoupledAdamW from composer.utils import ( FSDPConfig, @@ -39,7 +42,6 @@ from llmfoundry import ComposerHFCausalLM from llmfoundry.layers_registry import norms -from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithFSDP from llmfoundry.models.layers import build_alibi_bias from llmfoundry.models.layers.attention import ( check_alibi_support, @@ -2560,7 +2562,7 @@ def test_hf_init( False, ) - model = HuggingFaceModelWithFSDP(model, tokenizer) + model = HuggingFaceModel(model, tokenizer) batch = gen_random_batch(batch_size, test_cfg) @@ -2609,7 +2611,7 @@ def test_head_dim_8_flash_mqa_attn(batch_size: int = 2): mpt = MPTForCausalLM(hf_config) - model = HuggingFaceModelWithFSDP(mpt, tokenizer, shift_labels=True) + model = HuggingFaceModel(mpt, tokenizer, shift_labels=True) model = model.to(test_cfg.device) batch = gen_random_batch(batch_size, test_cfg)
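
Note: a minimal, hypothetical YAML sketch of how the new knobs introduced by this patch might be wired into a training config. The `accumulate_train_batch_on_tokens` field and the checkpointer's `register_wait_seconds` argument come from this diff; the surrounding keys (`run_name`, `save_folder`, `save_interval`, the `hf_checkpointer` callback name) follow existing llm-foundry conventions and are assumptions, not part of this diff.

```yaml
# Hypothetical training-config fragment; paths and values are placeholders.
run_name: my-finetune-run                    # setting a run name explicitly keeps autoresume eligible
accumulate_train_batch_on_tokens: true       # new TrainConfig field added in this patch
callbacks:
  hf_checkpointer:
    save_folder: s3://my-bucket/checkpoints  # placeholder URI
    save_interval: 1ep
    mlflow_registered_model_name: my-model
    register_wait_seconds: 7200              # new argument; replaces the hard-coded 3600 s FIT_END wait
```

Dataset tokenization parallelism can similarly be capped with the new `MAX_NUM_PROC` environment variable (e.g. `MAX_NUM_PROC=8`), which `_get_num_processes` reads when mapping finetuning datasets.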