diff --git a/.github/workflows/code-quality.yaml b/.github/workflows/code-quality.yaml index ff11deea14..24052fc295 100644 --- a/.github/workflows/code-quality.yaml +++ b/.github/workflows/code-quality.yaml @@ -34,7 +34,7 @@ jobs: uses: actions/checkout@v3 with: repository: mosaicml/ci-testing - ref: v0.2.0 + ref: v0.2.2 path: ./ci-testing - uses: ./ci-testing/.github/actions/code-quality with: diff --git a/.github/workflows/coverage.yaml b/.github/workflows/coverage.yaml index 9ecae2eaa6..e748a3119b 100644 --- a/.github/workflows/coverage.yaml +++ b/.github/workflows/coverage.yaml @@ -16,7 +16,7 @@ jobs: uses: actions/checkout@v3 with: repository: mosaicml/ci-testing - ref: v0.2.0 + ref: v0.2.2 path: ./ci-testing - uses: ./ci-testing/.github/actions/coverage with: diff --git a/.github/workflows/docker.yaml b/.github/workflows/docker.yaml index 0bb0b4087a..976b4241ab 100644 --- a/.github/workflows/docker.yaml +++ b/.github/workflows/docker.yaml @@ -17,12 +17,14 @@ jobs: strategy: matrix: include: - - name: "2.3.1_cu121" - base_image: mosaicml/pytorch:2.3.1_cu121-python3.11-ubuntu20.04 + - name: "2.4.0_cu124" + base_image: mosaicml/pytorch:2.4.0_cu124-python3.11-ubuntu20.04 dep_groups: "[all]" - - name: "2.3.1_cu121_aws" - base_image: mosaicml/pytorch:2.3.1_cu121-python3.11-ubuntu20.04-aws + te_commit: 901e5d2 + - name: "2.4.0_cu124_aws" + base_image: mosaicml/pytorch:2.4.0_cu124-python3.11-ubuntu20.04-aws dep_groups: "[all]" + te_commit: 901e5d2 steps: - name: Checkout @@ -89,3 +91,4 @@ jobs: BRANCH_NAME=${{ github.head_ref || github.ref_name }} BASE_IMAGE=${{ matrix.base_image }} DEP_GROUPS=${{ matrix.dep_groups }} + TE_COMMIT=${{ matrix.te_commit }} diff --git a/.github/workflows/pr-cpu.yaml b/.github/workflows/pr-cpu.yaml index 45913ae0bd..056b070143 100644 --- a/.github/workflows/pr-cpu.yaml +++ b/.github/workflows/pr-cpu.yaml @@ -17,19 +17,18 @@ jobs: pytest-cpu: name: ${{ matrix.name }} runs-on: ubuntu-latest + container: ${{ matrix.container }} strategy: matrix: include: - - name: "cpu-2.3.1" + - name: "cpu-2.4.0" pip_deps: "[all-cpu]" - container: mosaicml/pytorch:2.3.1_cpu-python3.11-ubuntu20.04 + container: mosaicml/pytorch:2.4.0_cpu-python3.11-ubuntu20.04 markers: "not gpu" pytest_command: "coverage run -m pytest" steps: - - name: Checkout code - uses: actions/checkout@v2 - name: Run PR CPU Tests - uses: mosaicml/ci-testing/.github/actions/pytest-cpu@v0.2.0 + uses: mosaicml/ci-testing/.github/actions/pytest-cpu@v0.2.2 with: name: ${{ matrix.name }} container: ${{ matrix.container }} diff --git a/.github/workflows/pr-gpu.yaml b/.github/workflows/pr-gpu.yaml index e3dfd35070..5b91d54442 100644 --- a/.github/workflows/pr-gpu.yaml +++ b/.github/workflows/pr-gpu.yaml @@ -22,15 +22,15 @@ jobs: fail-fast: false matrix: include: - - name: "gpu-2.3.1-1" - container: mosaicml/llm-foundry:2.3.1_cu121-latest + - name: "gpu-2.4.0-1" + container: mosaicml/llm-foundry:2.4.0_cu124-latest markers: "gpu" pip_deps: "[all]" pytest_command: "coverage run -m pytest" - ci_repo_gpu_test_ref: v0.2.0 + ci_repo_gpu_test_ref: v0.2.2 steps: - name: Run PR GPU Tests - uses: mosaicml/ci-testing/.github/actions/pytest-gpu@v0.2.0 + uses: mosaicml/ci-testing/.github/actions/pytest-gpu@v0.2.2 with: container: ${{ matrix.container }} git_repo: mosaicml/llm-foundry @@ -51,15 +51,15 @@ jobs: fail-fast: false matrix: include: - - name: "gpu-2.3.1-2" - container: mosaicml/llm-foundry:2.3.1_cu121-latest + - name: "gpu-2.4.0-2" + container: mosaicml/llm-foundry:2.4.0_cu124-latest markers: "gpu" pip_deps: "[all]" pytest_command: "coverage run -m pytest" - ci_repo_gpu_test_ref: v0.2.0 + ci_repo_gpu_test_ref: v0.2.2 steps: - name: Run PR GPU Tests - uses: mosaicml/ci-testing/.github/actions/pytest-gpu@v0.2.0 + uses: mosaicml/ci-testing/.github/actions/pytest-gpu@v0.2.2 with: container: ${{ matrix.container }} git_repo: mosaicml/llm-foundry @@ -80,15 +80,15 @@ jobs: fail-fast: false matrix: include: - - name: "gpu-2.3.1-4" - container: mosaicml/llm-foundry:2.3.1_cu121-latest + - name: "gpu-2.4.0-4" + container: mosaicml/llm-foundry:2.4.0_cu124-latest markers: "gpu" pip_deps: "[all]" pytest_command: "coverage run -m pytest" - ci_repo_gpu_test_ref: v0.2.0 + ci_repo_gpu_test_ref: v0.2.2 steps: - name: Run PR GPU Tests - uses: mosaicml/ci-testing/.github/actions/pytest-gpu@v0.2.0 + uses: mosaicml/ci-testing/.github/actions/pytest-gpu@v0.2.2 with: container: ${{ matrix.container }} git_repo: mosaicml/llm-foundry diff --git a/.github/workflows/smoketest.yaml b/.github/workflows/smoketest.yaml index 2df6175743..a0a91671f0 100644 --- a/.github/workflows/smoketest.yaml +++ b/.github/workflows/smoketest.yaml @@ -32,7 +32,7 @@ jobs: uses: actions/checkout@v3 with: repository: mosaicml/ci-testing - ref: v0.2.0 + ref: v0.2.2 path: ./ci-testing - uses: ./ci-testing/.github/actions/smoketest with: diff --git a/Dockerfile b/Dockerfile index cee7063cdd..ca52532395 100644 --- a/Dockerfile +++ b/Dockerfile @@ -6,6 +6,7 @@ FROM $BASE_IMAGE ARG BRANCH_NAME ARG DEP_GROUPS +ARG TE_COMMIT ENV TORCH_CUDA_ARCH_LIST="8.0 8.6 8.7 8.9 9.0" @@ -15,7 +16,7 @@ ADD https://raw.githubusercontent.com/mosaicml/llm-foundry/$BRANCH_NAME/setup.py RUN rm setup.py # Install TransformerEngine -RUN NVTE_FRAMEWORK=pytorch CMAKE_BUILD_PARALLEL_LEVEL=4 MAX_JOBS=4 pip install git+https://github.com/NVIDIA/TransformerEngine.git@b5a7c9f +RUN NVTE_FRAMEWORK=pytorch CMAKE_BUILD_PARALLEL_LEVEL=4 MAX_JOBS=4 pip install git+https://github.com/NVIDIA/TransformerEngine.git@$TE_COMMIT # Install and uninstall foundry to cache foundry requirements RUN git clone -b $BRANCH_NAME https://github.com/mosaicml/llm-foundry.git diff --git a/README.md b/README.md index e8a6708c5a..0fabb98653 100644 --- a/README.md +++ b/README.md @@ -107,14 +107,14 @@ Something missing? Contribute with a PR! # Hardware and Software Requirements -This codebase has been tested with PyTorch 2.2 with NVIDIA A100s and H100s. +This codebase has been tested with PyTorch 2.4 with NVIDIA A100s and H100s. This codebase may also work on systems with other devices, such as consumer NVIDIA cards and AMD cards, but we are not actively testing these systems. If you have success/failure using LLM Foundry on other systems, please let us know in a Github issue and we will update the support matrix! | Device | Torch Version | Cuda Version | Status | | -------------- | ------------- | ------------ | ---------------------------- | -| A100-40GB/80GB | 2.3.1 | 12.1 | :white_check_mark: Supported | -| H100-80GB | 2.3.1 | 12.1 | :white_check_mark: Supported | +| A100-40GB/80GB | 2.4.0 | 12.4 | :white_check_mark: Supported | +| H100-80GB | 2.4.0 | 12.4 | :white_check_mark: Supported | ## MosaicML Docker Images We highly recommend using our prebuilt Docker images. You can find them here: https://hub.docker.com/orgs/mosaicml/repositories. @@ -122,15 +122,15 @@ We highly recommend using our prebuilt Docker images. You can find them here: ht The `mosaicml/pytorch` images are pinned to specific PyTorch and CUDA versions, and are stable and rarely updated. The `mosaicml/llm-foundry` images are built with new tags upon every commit to the `main` branch. -You can select a specific commit hash such as `mosaicml/llm-foundry:2.3.1_cu121-36ab1ba` or take the latest one using `mosaicml/llm-foundry:2.3.1_cu121-latest`. +You can select a specific commit hash such as `mosaicml/llm-foundry:2.4.0_cu124-36ab1ba` or take the latest one using `mosaicml/llm-foundry:2.4.0_cu124-latest`. **Please Note:** The `mosaicml/llm-foundry` images do not come with the `llm-foundry` package preinstalled, just the dependencies. You will still need to `pip install llm-foundry` either from PyPi or from source. | Docker Image | Torch Version | Cuda Version | LLM Foundry dependencies installed? | | ------------------------------------------------------ | ------------- | ----------------- | ----------------------------------- | -| `mosaicml/pytorch:2.3.1_cu121-python3.11-ubuntu20.04` | 2.3.1 | 12.1 (Infiniband) | No | -| `mosaicml/llm-foundry:2.3.1_cu121-latest` | 2.3.1 | 12.1 (Infiniband) | Yes | -| `mosaicml/llm-foundry:2.3.1_cu121_aws-latest` | 2.3.1 | 12.1 (EFA) | Yes | +| `mosaicml/pytorch:2.4.0_cu124-python3.11-ubuntu20.04` | 2.4.0 | 12.4 (Infiniband) | No | +| `mosaicml/llm-foundry:2.4.0_cu124-latest` | 2.4.0 | 12.4 (Infiniband) | Yes | +| `mosaicml/llm-foundry:2.4.0_cu124_aws-latest` | 2.4.0 | 12.4 (EFA) | Yes | # Installation diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index f05e7322a8..65bdcb3b6c 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -10,6 +10,7 @@ import shutil import tempfile import time +import warnings from multiprocessing.context import SpawnProcess from pathlib import Path from typing import Any, Optional, Sequence, Union @@ -18,6 +19,7 @@ import torch import torch.nn as nn from composer.core import Callback, Event, Precision, State, Time, TimeUnit +from composer.devices import Device from composer.loggers import Logger, MLFlowLogger from composer.models import HuggingFaceModel from composer.utils import ( @@ -161,6 +163,10 @@ class HuggingFaceCheckpointer(Callback): keys ``input_example`` and ``signature``. flatten_imports (Sequence[str]): A sequence of import prefixes that will be flattened when editing MPT files. + final_register_only (bool): If true, only register the model in the MLFlow + registry on the last batch and do not save the HuggingFace checkpoint. If + registration fails or mlflow_registered_model_name is not set, then we will + fallback to saving the HuggingFace checkpoint. """ def __init__( @@ -173,6 +179,7 @@ def __init__( mlflow_registered_model_name: Optional[str] = None, mlflow_logging_config: Optional[dict] = None, flatten_imports: Sequence[str] = ('llmfoundry',), + final_register_only: bool = False, ): _, _, self.save_dir_format_str = parse_uri(save_folder) self.overwrite = overwrite @@ -185,8 +192,18 @@ def __init__( self.flatten_imports = flatten_imports self.using_peft = False - # mlflow config setup + self.final_register_only = final_register_only + self.mlflow_registered_model_name = mlflow_registered_model_name + if self.final_register_only and self.mlflow_registered_model_name is None: + self.final_register_only = False + warnings.warn( + 'final_register_only is set to True, but mlflow_registered_model_name is not set. ' + + + f'Defaulting to final_register_only=False and saving the HuggingFace checkpoint to {save_folder=}.', + ) + + # mlflow config setup if mlflow_logging_config is None: mlflow_logging_config = {} if self.mlflow_registered_model_name is not None: @@ -249,7 +266,7 @@ def __init__( self.last_checkpoint_batch: Optional[Time] = None self.mlflow_loggers = [] - self.child_processes: list[SpawnProcess] = [] + self.register_processes: list[SpawnProcess] = [] # Temporary save directory used by child_processes. self.temp_save_dir = None @@ -259,7 +276,17 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: state, event, ) and self.last_checkpoint_batch != state.timestamp.batch: - self._save_checkpoint(state, logger) + is_last_batch = self._is_last_batch(state) + self._save_checkpoint( + state, + logger, + register_to_mlflow=( + self.mlflow_registered_model_name is not None and + is_last_batch + ), + upload_to_save_folder=not self.final_register_only or + not is_last_batch, + ) elif event == Event.INIT: if not isinstance(state.model, HuggingFaceModel): raise ValueError( @@ -300,7 +327,7 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: # Wait for all child processes spawned by the callback to finish. timeout = 3600 wait_start = time.time() - while not self._all_child_processes_done(): + while not self._all_register_processes_done(state.device): wait_time = time.time() - wait_start if wait_time > timeout: raise TimeoutError( @@ -308,6 +335,19 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: ) time.sleep(2) + if self._any_register_processes_error( + state.device, + ) and self.final_register_only: + log.error( + 'An error occurred in one or more registration processes. Fallback to saving the HuggingFace checkpoint.', + ) + self._save_checkpoint( + state, + logger, + upload_to_save_folder=True, + register_to_mlflow=False, + ) + # Clean up temporary save directory; all processes are done with it. if self.temp_save_dir is not None: shutil.rmtree(self.temp_save_dir) @@ -339,12 +379,23 @@ def _is_last_batch(self, state: State): return False - def _all_child_processes_done(self) -> bool: - not_done = any(process.is_alive() for process in self.child_processes) - x = torch.tensor(1 if not_done else 0).to(device='cuda') + def _all_register_processes_done(self, device: Device) -> bool: + not_done = any( + process.is_alive() for process in self.register_processes + ) + x = device.tensor_to_device(torch.tensor(1 if not_done else 0)) dist.all_reduce(x, reduce_operation='MAX') return x.item() == 0 + def _any_register_processes_error(self, device: Device) -> bool: + has_errors = any( + process.exitcode is not None and process.exitcode != 0 + for process in self.register_processes + ) + x = device.tensor_to_device(torch.tensor(1 if has_errors else 0)) + dist.all_reduce(x, reduce_operation='MAX') + return x.item() == 1 + def transform_model_and_tokenizer( self, model: PreTrainedModel, @@ -412,7 +463,21 @@ def transform_model_pre_registration( """ return model - def _save_checkpoint(self, state: State, logger: Logger): + def _save_checkpoint( + self, + state: State, + logger: Logger, + upload_to_save_folder: bool, + register_to_mlflow: bool, + ): + """Save a HuggingFace formatted checkpoint. + + Args: + state (State): The training state. + logger (Logger): The logger. + upload_to_save_folder (bool): Whether to upload the HF checkpoint to the save folder. + register_to_mlflow (bool): Whether to register the model to MLFlow + """ del logger # unused self.last_checkpoint_batch = state.timestamp.batch @@ -520,6 +585,7 @@ def tensor_hook( new_base_model_instance, original_model.peft_config[active_adapter], ) + del new_base_model_instance else: new_model_instance = type(original_model)(new_config) new_model_instance.generation_config.update( @@ -548,50 +614,53 @@ def tensor_hook( ].base_model_name_or_path = self.pretrained_model_name log.debug('Saving Hugging Face checkpoint to disk') - # This context manager casts the TE extra state in io.BytesIO format to tensor format - # Needed for proper hf ckpt saving. - context_manager = te.onnx_export( - True, - ) if is_te_imported and state.precision == Precision.AMP_FP8 else contextlib.nullcontext( - ) - with context_manager: - new_model_instance.save_pretrained(temp_save_dir) - if original_tokenizer is not None: - assert isinstance( - original_tokenizer, - PreTrainedTokenizerBase, - ) - original_tokenizer.save_pretrained(temp_save_dir) - - # Only need to edit files for MPT because it has custom code - if new_model_instance.config.model_type == 'mpt': - log.debug('Editing MPT files for HuggingFace compatibility') - edit_files_for_hf_compatibility( - temp_save_dir, - self.flatten_imports, - ) - if self.remote_ud is not None: - for filename in os.listdir(temp_save_dir): - remote_file_name = os.path.join(save_dir, filename) - remote_file_uri = self.remote_ud.remote_backend.get_uri( - remote_file_name, - ) - log.info( - f'Uploading HuggingFace formatted checkpoint to {remote_file_uri}', + if upload_to_save_folder: + # This context manager casts the TE extra state in io.BytesIO format to tensor format + # Needed for proper hf ckpt saving. + context_manager = te.onnx_export( + True, + ) if is_te_imported and state.precision == Precision.AMP_FP8 else contextlib.nullcontext( + ) + with context_manager: + new_model_instance.save_pretrained(temp_save_dir) + if original_tokenizer is not None: + assert isinstance( + original_tokenizer, + PreTrainedTokenizerBase, ) - self.remote_ud.upload_file( - state=state, - remote_file_name=remote_file_name, - file_path=Path(os.path.join(temp_save_dir, filename)), - overwrite=self.overwrite, + original_tokenizer.save_pretrained(temp_save_dir) + + # Only need to edit files for MPT because it has custom code + if new_model_instance.config.model_type == 'mpt': + log.debug('Editing MPT files for HuggingFace compatibility') + edit_files_for_hf_compatibility( + temp_save_dir, + self.flatten_imports, ) + if self.remote_ud is not None: + for filename in os.listdir(temp_save_dir): + remote_file_name = os.path.join(save_dir, filename) + remote_file_uri = self.remote_ud.remote_backend.get_uri( + remote_file_name, + ) + log.info( + f'Uploading HuggingFace formatted checkpoint to {remote_file_uri}', + ) + self.remote_ud.upload_file( + state=state, + remote_file_name=remote_file_name, + file_path=Path( + os.path.join(temp_save_dir, filename), + ), + overwrite=self.overwrite, + ) + dist.barrier() if dist.get_global_rank() == 0: - if self.mlflow_registered_model_name and self._is_last_batch(state): - + if register_to_mlflow: new_model_instance = self.transform_model_pre_registration( new_model_instance, ) @@ -680,7 +749,7 @@ def tensor_hook( # Restore the monitor process. if monitor_process is not None: mlflow_logger.monitor_process = monitor_process # type: ignore - self.child_processes.append(process) + self.register_processes.append(process) # Save the temporary directory to be cleaned up later. if use_temp_dir: diff --git a/llmfoundry/command_utils/data_prep/convert_delta_to_json.py b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py index 50d11b1222..666d0278c6 100644 --- a/llmfoundry/command_utils/data_prep/convert_delta_to_json.py +++ b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py @@ -22,6 +22,7 @@ ClusterDoesNotExistError, FailedToConnectToDatabricksError, FailedToCreateSQLConnectionError, + InsufficientPermissionsError, ) if TYPE_CHECKING: @@ -454,6 +455,12 @@ def fetch( sparkSession, ) except Exception as e: + from pyspark.errors import AnalysisException + if isinstance(e, AnalysisException): + if 'INSUFFICIENT_PERMISSIONS' in e.message: # pyright: ignore + raise InsufficientPermissionsError( + action=f'reading from {tablename}', + ) from e raise RuntimeError( f'Error in get rows from {tablename}. Restart sparkSession and try again', ) from e diff --git a/llmfoundry/command_utils/data_prep/convert_text_to_mds.py b/llmfoundry/command_utils/data_prep/convert_text_to_mds.py index 7c40a7e698..9a1f8a912d 100644 --- a/llmfoundry/command_utils/data_prep/convert_text_to_mds.py +++ b/llmfoundry/command_utils/data_prep/convert_text_to_mds.py @@ -478,7 +478,9 @@ def convert_text_to_mds( index_path = os.path.join(local_output_folder, 'index.json') with open(index_path, 'r') as index_file: if not json.load(index_file)['shards']: - raise DatasetTooSmallError() + raise DatasetTooSmallError( + reason='No shards were created when converting text to MDS.', + ) # Write a done file with the args and object names write_done_file(local_output_folder, args_str, object_names) diff --git a/llmfoundry/command_utils/eval.py b/llmfoundry/command_utils/eval.py index f622ca182d..e644ad1f0f 100644 --- a/llmfoundry/command_utils/eval.py +++ b/llmfoundry/command_utils/eval.py @@ -4,6 +4,7 @@ import logging import os import time +import warnings from typing import Any, Optional, Union import pandas as pd @@ -11,7 +12,7 @@ from composer.core import Callback from composer.loggers.logger_destination import LoggerDestination from composer.trainer import Trainer -from composer.utils import dist, get_device, reproducibility +from composer.utils import dist, get_device, parallelism, reproducibility from omegaconf import DictConfig from omegaconf import OmegaConf as om @@ -36,6 +37,7 @@ process_init_device, ) from llmfoundry.utils.registry_utils import import_file +from llmfoundry.utils.warnings import VersionedDeprecationWarning log = logging.getLogger(__name__) @@ -52,7 +54,6 @@ def evaluate_model( device_eval_batch_size: Union[int, float], eval_gauntlet_config: Optional[Union[str, dict[str, Any]]], eval_loader_config: Optional[Union[dict[str, Any], list[dict[str, Any]]]], - fsdp_config: Optional[dict[str, Any]], loggers: list[LoggerDestination], python_log_level: Optional[str], precision: str, @@ -62,9 +63,33 @@ def evaluate_model( callback_configs: Optional[dict[str, Any]], metadata: Optional[dict[str, str]], logged_config: dict[str, Any], + fsdp_config: Optional[dict[str, Any]] = None, + parallelism_config: Optional[dict[str, Any]] = None, should_log_config: bool = True, load_path: Optional[str] = None, ): + if parallelism_config: + deprecated_fsdp_args = list( + parallelism.FSDPConfig.__annotations__.keys(), + ) + for deprecated_arg in deprecated_fsdp_args: + if deprecated_arg in parallelism_config: + raise ValueError( + 'parallelism_config cannot contain deprecated fsdp_config arguments.', + ) + + if fsdp_config: + warnings.warn( + VersionedDeprecationWarning( + 'The argument fsdp_config is deprecated. Please use parallelism_config instead.', + remove_version='0.13.0', + ), + ) + if fsdp_config and parallelism_config: + raise ValueError( + 'Both fsdp_config and parallelism_config cannot be provided at the same time. Please use parallelism_config.', + ) + log.info(f'Evaluating model: {model_name}') # Build tokenizer and model tokenizer_cfg = tokenizer @@ -99,6 +124,10 @@ def evaluate_model( mosaicml_logger.log_metrics(metadata) mosaicml_logger._flush_metadata(force_flush=True) + fsdp_config = parallelism_config.get( + 'fsdp_config', + None, + ) if parallelism_config else fsdp_config if fsdp_config and model.get('load_in_8bit', False): raise ValueError( 'The FSDP config block is not supported when loading ' + @@ -146,7 +175,7 @@ def evaluate_model( callbacks=callbacks, loggers=loggers, precision=precision, - fsdp_config=fsdp_config, + parallelism_config={'fsdp': fsdp_config}, load_path=load_path, load_weights_only=True, progress_bar=False, diff --git a/llmfoundry/command_utils/train.py b/llmfoundry/command_utils/train.py index 8e6309175a..14b7980d57 100644 --- a/llmfoundry/command_utils/train.py +++ b/llmfoundry/command_utils/train.py @@ -10,6 +10,7 @@ import torch import torch.distributed from composer import ComposerModel, Trainer +from composer.callbacks.checkpoint_saver import CheckpointSaver from composer.core.callback import Callback from composer.profiler import ( JSONTraceHandler, @@ -187,6 +188,24 @@ def _initialize_dist_with_barrier(dist_timeout: Union[int, float]): log.debug('Barrier test passed with device.') +def _sort_callbacks(trainer: Trainer): + """Sort callback so that checkpoint saving callbacks go first. + + Args: + trainer (Trainer): Trainer object + """ + + def _sort_key(c: Callback) -> int: + # CheckpointSaver goes before HuggingFaceCheckpointer because the blocking time is shortest while upload is async. + if isinstance(c, CheckpointSaver): + return 1 + if isinstance(c, HuggingFaceCheckpointer): + return 2 + return 0 + + trainer.state.callbacks = sorted(trainer.state.callbacks, key=_sort_key) + + def train(cfg: DictConfig) -> Trainer: code_paths = cfg.get('code_paths', []) # Import any user provided code @@ -548,6 +567,8 @@ def train(cfg: DictConfig) -> Trainer: spin_dataloaders=train_cfg.spin_dataloaders, ) + _sort_callbacks(trainer) + # Optionally just save an HF checkpoint if train_cfg.only_hf_checkpoint: hf_checkpointer_callbacks = [ @@ -563,7 +584,12 @@ def train(cfg: DictConfig) -> Trainer: ) hf_checkpointer_callback = hf_checkpointer_callbacks[0] - hf_checkpointer_callback._save_checkpoint(trainer.state, trainer.logger) + hf_checkpointer_callback._save_checkpoint( + trainer.state, + trainer.logger, + upload_to_save_folder=True, + register_to_mlflow=True, + ) return trainer if train_cfg.only_composer_checkpoint: diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 801813b3ff..e8f6484ef2 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -73,6 +73,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: ALLOWED_RESPONSE_KEYS, ChatTemplateError, ConsecutiveRepeatedChatRolesError, + DatasetTooSmallError, IncorrectMessageKeyQuantityError, InvalidContentTypeError, InvalidExampleTypeError, @@ -1013,7 +1014,7 @@ def dataset_mapper(example: dict): raise MisconfiguredHfDatasetError( dataset_name=dataset_name, split=split, - ) + ) from error if error is not None: log.error('Error during data prep') raise error @@ -1033,7 +1034,24 @@ def build_from_streaming( *args: Any, **kwargs: Any, ) -> StreamingFinetuningDataset: - return self.streaming_dataset_class(*args, **kwargs) + dataset = self.streaming_dataset_class(*args, **kwargs) + num_canonical_nodes = dataset.num_canonical_nodes + num_samples = dataset.num_samples + if num_canonical_nodes is None: + num_physical_nodes = dist.get_world_size( + ) // dist.get_local_world_size() + if num_samples < num_physical_nodes: + raise DatasetTooSmallError( + f'{num_samples=} is less than {dist.get_world_size() // dist.get_local_world_size()}, the number of physical nodes. ', + ) + + if num_canonical_nodes is not None and num_samples < num_canonical_nodes: + raise DatasetTooSmallError( + f'{num_samples=} is less than {num_canonical_nodes=}. ' + + 'Please check your index.json file and ensure that your dataset has been written out correctly.' + + 'If this was intended, reduce num_canonical_nodes.', + ) + return dataset dataset_constructor = DatasetConstructor() diff --git a/llmfoundry/models/hf/hf_base.py b/llmfoundry/models/hf/hf_base.py index 6b693f2d21..d193e1067f 100644 --- a/llmfoundry/models/hf/hf_base.py +++ b/llmfoundry/models/hf/hf_base.py @@ -69,7 +69,7 @@ def __init__( config_overrides: Optional[dict[str, Any]] = None, use_logits: bool = True, shift_labels: bool = False, - peft_config: Optional['PeftConfig'] = None, + peft_config: Optional[dict[str, Any]] = None, allow_embedding_resizing: bool = False, use_train_metrics: bool = True, additional_train_metrics: Optional[list] = None, @@ -92,8 +92,6 @@ def __init__( model = self.transform_model(model) - self.prepare_inner_model(model, init_device) - metrics, eval_metrics = self.build_metrics( use_train_metrics=use_train_metrics, additional_train_metrics=additional_train_metrics, @@ -121,6 +119,10 @@ def __init__( should_save_peft_only=should_save_peft_only, ) + # Prepare for FSDP needs to happen after the super init, so that any model + # architecture changes are completed + self.prepare_inner_model(self.model, init_device) + def loss(self, outputs: ModelOutput, batch: Mapping): if self.config.use_return_dict: return outputs['loss'] diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index 5af5481f0a..c88cf33d1b 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -164,11 +164,15 @@ def forward( flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None, prev_layer_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + key_value_states: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[ torch.Tensor, torch.Tensor]]]: extra_kwargs = {} if prev_layer_key_value is not None: extra_kwargs['prev_layer_key_value'] = prev_layer_key_value + if key_value_states is not None: + extra_kwargs['key_value_states'] = key_value_states + if self.fuse_norm_attn_norm: x, m, attn_weights, past_key_value = self.norm_attn_norm( x, @@ -327,12 +331,16 @@ def forward( flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None, prev_layer_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + key_value_states: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]: a = self.norm_1(x) extra_kwargs = {} if prev_layer_key_value is not None: extra_kwargs['prev_layer_key_value'] = prev_layer_key_value + if key_value_states is not None: + extra_kwargs['key_value_states'] = key_value_states + b, attn_weights, past_key_value = self.attn( a, past_key_value=past_key_value, diff --git a/llmfoundry/models/layers/ffn.py b/llmfoundry/models/layers/ffn.py index 38ca253f80..5b5a6b1449 100644 --- a/llmfoundry/models/layers/ffn.py +++ b/llmfoundry/models/layers/ffn.py @@ -397,7 +397,6 @@ def attach_ffn_mb_args( """ ffn.experts.mlp.hidden_size = args.ffn_hidden_size ffn.experts.mlp.expert_parallel_group = expert_parallel_group - ffn.experts.mlp.weight_parallel_group = args.weight_parallel_group def get_fsdp_submesh_2d(device_mesh: DeviceMesh): diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 2db18120ac..9212f5594d 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -396,6 +396,7 @@ def __init__(self, config: MPTConfig): self.wte = SharedEmbedding( config.vocab_size, config.d_model, + padding_idx=config.pad_token_id, device=config.init_device, ) if self.learned_pos_emb: diff --git a/llmfoundry/models/utils/mpt_param_count.py b/llmfoundry/models/utils/mpt_param_count.py index d7b61354c7..bd8f279ad5 100644 --- a/llmfoundry/models/utils/mpt_param_count.py +++ b/llmfoundry/models/utils/mpt_param_count.py @@ -62,13 +62,6 @@ def megablocks_n_total_params(mpt_model) -> int: # type: ignore moe_world_size = mpt_model.config.ffn_config.get('moe_world_size') - if mpt_model.config.ffn_config.get('moe_weight_parallelism', False): - # If MegaBlocks shards experts, the total sharding world size - # must be increased by the degree to which MegaBlocks shards the - # experts. - mb_args = mpt_model.model.transformer.mb_args - moe_world_size *= mb_args.weight_parallel_group.size() - n_total_params = 0 for module in mpt_model.modules(): if isinstance( @@ -109,9 +102,6 @@ def megablocks_n_active_params(mpt_model) -> int: # type: ignore moe_world_size = mpt_model.config.ffn_config.get('moe_world_size') local_experts = moe_num_experts / moe_world_size # if local_experts is < 1, then the expert is sharded - if mpt_model.config.ffn_config.get('moe_weight_parallelism', False): - mb_args = mpt_model.model.transformer.mb_args - local_experts /= mb_args.weight_parallel_group.size() moe_top_k = mpt_model.config.ffn_config.get('moe_top_k', 1) n_active_params = 0 diff --git a/llmfoundry/models/utils/param_init_fns.py b/llmfoundry/models/utils/param_init_fns.py index 9941c2d049..8ad6e77c57 100644 --- a/llmfoundry/models/utils/param_init_fns.py +++ b/llmfoundry/models/utils/param_init_fns.py @@ -224,6 +224,9 @@ def embedding_init( emb_init_fn_ = init_fn_ emb_init_fn_(module.weight) + if module.padding_idx is not None: + with torch.no_grad(): + module.weight[module.padding_idx].fill_(0) return True @@ -484,19 +487,12 @@ def _megablocks_sparse_mlp_generic_param_init_fn_( div_is_residual (float): The value by which parameter initialization is divided if init_div_is_residual flag is enabled. """ - expert_process_group_size, rank, weight_parallel_group_size, weight_parallel_group_rank = 1, 0, 1, 0 + expert_process_group_size, rank = 1, 0 if module.expert_parallel_group is not None: expert_process_group_size = int( module.expert_parallel_group.size(), ) # type: ignore rank = int(module.expert_parallel_group.rank()) # type: ignore - if module.weight_parallel_group is not None: - weight_parallel_group_size = int( - module.weight_parallel_group.size(), - ) # type: ignore - weight_parallel_group_rank = int( - module.weight_parallel_group.rank(), - ) # type: ignore hidden_size = int(module.hidden_size) # type: ignore @@ -505,8 +501,7 @@ def _megablocks_sparse_mlp_generic_param_init_fn_( if isinstance(w1, DTensor): w1 = w1._local_tensor w1_size = list(w1.shape) # type: ignore - w1_size[ - 0] = w1_size[0] * expert_process_group_size * weight_parallel_group_size + w1_size[0] = w1_size[0] * expert_process_group_size n_exp = w1_size[0] // hidden_size _fused = (0, [(n + 1) * hidden_size for n in range(n_exp - 1)]) @@ -514,26 +509,21 @@ def _megablocks_sparse_mlp_generic_param_init_fn_( _w1 = w1.new_empty(w1_size) # type: ignore fused_param_init_helper(_w1, init_fn_, _fused) _w1_local = _w1.chunk(expert_process_group_size, dim=0)[rank] - _w1_local_slice = _w1_local.chunk(weight_parallel_group_size, - dim=0)[weight_parallel_group_rank] with torch.no_grad(): - w1.copy_(_w1_local_slice) # type: ignore + w1.copy_(_w1_local) # type: ignore # Initialize w2 w2 = module.w2 if isinstance(w2, DTensor): w2 = w2._local_tensor w2_size = list(w2.shape) # type: ignore - w2_size[ - 0] = w2_size[0] * expert_process_group_size * weight_parallel_group_size + w2_size[0] = w2_size[0] * expert_process_group_size _w2 = w2.new_empty(w2_size) # type: ignore # MegaBlocks operates on w2 as x @ w2, so needs flipped fan mode fused_param_init_helper(_w2, _flip_fan_mode(init_fn_), _fused) _w2_local = _w2.chunk(expert_process_group_size, dim=0)[rank] - _w2_local_slice = _w2_local.chunk(weight_parallel_group_size, - dim=0)[weight_parallel_group_rank] with torch.no_grad(): - w2.copy_(_w2_local_slice) # type: ignore + w2.copy_(_w2_local) # type: ignore if init_div_is_residual is not False: with torch.no_grad(): w2.div_(div_is_residual) # type: ignore @@ -567,19 +557,12 @@ def _megablocks_sparse_glu_generic_param_init_fn_( ) # Init ported from _megablocks_sparse_mlp_generic_param_init_fn_ for v1 - expert_process_group_size, rank, weight_parallel_group_size, weight_parallel_group_rank = 1, 0, 1, 0 + expert_process_group_size, rank = 1, 0 if module.expert_parallel_group is not None: expert_process_group_size = int( module.expert_parallel_group.size(), ) # type: ignore rank = int(module.expert_parallel_group.rank()) # type: ignore - if module.weight_parallel_group is not None: - weight_parallel_group_size = int( - module.weight_parallel_group.size(), - ) # type: ignore - weight_parallel_group_rank = int( - module.weight_parallel_group.rank(), - ) # type: ignore hidden_size = int(module.hidden_size) # type: ignore @@ -588,8 +571,7 @@ def _megablocks_sparse_glu_generic_param_init_fn_( if isinstance(v1, DTensor): v1 = v1._local_tensor v1_size = list(v1.shape) # type: ignore - v1_size[ - 0] = v1_size[0] * expert_process_group_size * weight_parallel_group_size + v1_size[0] = v1_size[0] * expert_process_group_size n_exp = v1_size[0] // hidden_size _fused = (0, [(n + 1) * hidden_size for n in range(n_exp - 1)]) @@ -597,10 +579,8 @@ def _megablocks_sparse_glu_generic_param_init_fn_( _v1 = v1.new_empty(v1_size) # type: ignore fused_param_init_helper(_v1, init_fn_, _fused) _v1_local = _v1.chunk(expert_process_group_size, dim=0)[rank] - _v1_local_slice = _v1_local.chunk(weight_parallel_group_size, - dim=0)[weight_parallel_group_rank] with torch.no_grad(): - v1.copy_(_v1_local_slice) # type: ignore + v1.copy_(_v1_local) # type: ignore def _megablocks_mlp_generic_param_init_fn_( @@ -623,41 +603,32 @@ def _megablocks_mlp_generic_param_init_fn_( div_is_residual (float): The value by which parameter initialization is divided if init_div_is_residual flag is enabled. """ - expert_process_group_size, rank, weight_parallel_group_size, w_rank = 1, 0, 1, 0 + expert_process_group_size, rank = 1, 0 if module.expert_parallel_group is not None: expert_process_group_size = int( module.expert_parallel_group.size(), ) # type: ignore rank = int(module.expert_parallel_group.rank()) # type: ignore - if module.weight_parallel_group is not None: - weight_parallel_group_size = int( - module.weight_parallel_group.size(), - ) # type: ignore - w_rank = int(module.weight_parallel_group.rank()) # type: ignore _init_fn_ = _flip_fan_mode(init_fn_) # Initialize w1 w1_size = list(module.w1.shape) # type: ignore w1_size[0] = w1_size[0] * expert_process_group_size - w1_size[1] = w1_size[1] * weight_parallel_group_size _w1 = module.w1.new_empty(w1_size) # type: ignore stacked_param_init_helper(_w1, _init_fn_, module._stack_dim) # type: ignore _w1_local = _w1.chunk(expert_process_group_size, dim=0)[rank] - _w1_local_slice = _w1_local.chunk(weight_parallel_group_size, dim=1)[w_rank] with torch.no_grad(): - module.w1.copy_(_w1_local_slice) # type: ignore + module.w1.copy_(_w1_local) # type: ignore # Initialize w2 w2_size = list(module.w2.shape) # type: ignore w2_size[0] = w2_size[0] * expert_process_group_size - w2_size[1] = w2_size[1] * weight_parallel_group_size _w2 = module.w2.new_empty(w2_size) # type: ignore stacked_param_init_helper(_w2, _init_fn_, module._stack_dim) # type: ignore _w2_local = _w2.chunk(expert_process_group_size, dim=0)[rank] - _w2_local_slice = _w2_local.chunk(weight_parallel_group_size, dim=1)[w_rank] with torch.no_grad(): - module.w2.copy_(_w2_local_slice) # type: ignore + module.w2.copy_(_w2_local) # type: ignore if init_div_is_residual is not False: with torch.no_grad(): module.w2.div_(div_is_residual) # type: ignore diff --git a/llmfoundry/utils/exceptions.py b/llmfoundry/utils/exceptions.py index 345a254407..11895564f2 100644 --- a/llmfoundry/utils/exceptions.py +++ b/llmfoundry/utils/exceptions.py @@ -376,9 +376,9 @@ def __init__(self, dataset_name: str, split: str) -> None: class DatasetTooSmallError(UserError): """Error thrown when the dataset is too small to be processed.""" - def __init__(self) -> None: - message = f'Your dataset is too small and produced no complete samples during preprocessing. Please provide more data.' - super().__init__(message) + def __init__(self, reason: str) -> None: + message = f'Your dataset is too small and produced no complete samples or too few samples. Please provide more data. {reason}' + super().__init__(message, reason=reason) class RunTimeoutError(InternalError): @@ -427,3 +427,11 @@ def __init__( window_size=window_size, loss_window=loss_window, ) + + +class InsufficientPermissionsError(UserError): + """Error thrown when the user does not have sufficient permissions.""" + + def __init__(self, action: str) -> None: + message = f'Insufficient permissions when {action}. Please check your permissions.' + super().__init__(message, action=action) diff --git a/scripts/inference/convert_hf_to_onnx.py b/scripts/inference/convert_hf_to_onnx.py index 0f62917ef8..19f96d4fb1 100644 --- a/scripts/inference/convert_hf_to_onnx.py +++ b/scripts/inference/convert_hf_to_onnx.py @@ -158,7 +158,7 @@ def export_to_onnx( ort_session = ort.InferenceSession(str(output_file)) for key, value in sample_input.items(): - sample_input[key] = value.cpu().numpy() + sample_input[key] = value.cpu().numpy() # pyright: ignore loaded_model_out = ort_session.run(None, sample_input) diff --git a/scripts/train/yamls/pretrain/testing-moe.yaml b/scripts/train/yamls/pretrain/testing-moe.yaml index e61e3e451e..ee9483ffd0 100644 --- a/scripts/train/yamls/pretrain/testing-moe.yaml +++ b/scripts/train/yamls/pretrain/testing-moe.yaml @@ -23,7 +23,6 @@ model: moe_num_experts: 4 moe_top_k: 2 moe_world_size: 1 - moe_weight_parallelism: false uniform_expert_assignment: false n_heads: 2 n_layers: 2 diff --git a/setup.py b/setup.py index 1316626559..ebc66fdacf 100644 --- a/setup.py +++ b/setup.py @@ -53,11 +53,11 @@ install_requires = [ 'mosaicml[libcloud,wandb,oci,gcs,mlflow]>=0.24.1,<0.25', - 'mlflow>=2.14.1,<2.16', + 'mlflow>=2.14.1,<2.17', 'accelerate>=0.25,<0.34', # for HF inference `device_map` 'transformers>=4.43.2,<4.44', 'mosaicml-streaming>=0.8.1,<0.9', - 'torch>=2.3.0,<2.4', + 'torch>=2.4.0,<2.4.1', 'datasets>=2.19,<2.20', 'fsspec==2023.6.0', # newer version results in a bug in datasets that duplicates data 'sentencepiece==0.2.0', @@ -92,6 +92,7 @@ extra_deps['databricks'] = [ 'mosaicml[databricks]>=0.24.1,<0.25', + 'numpy<2', 'databricks-sql-connector>=3,<4', 'databricks-connect==14.1.0', 'lz4>=4,<5', @@ -118,8 +119,8 @@ ] extra_deps['megablocks'] = [ - 'megablocks==0.5.1', - 'grouped-gemm==0.1.4', + 'megablocks==0.6.1', + 'grouped-gemm==0.1.6', ] extra_deps['databricks-serverless'] = { diff --git a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py index 260988dc31..66ec739a65 100644 --- a/tests/a_scripts/inference/test_convert_composer_to_hf.py +++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py @@ -9,6 +9,7 @@ import shutil from argparse import Namespace from typing import Any, Callable, Optional, cast +from unittest import mock from unittest.mock import ANY, MagicMock, patch import catalogue @@ -314,9 +315,15 @@ class MockSpawnProcess: multiprocessing, so we need to patch SpawnProcess for tests. """ - def __init__(self, target: Callable, kwargs: dict[str, Any]): + def __init__( + self, + target: Callable, + kwargs: dict[str, Any], + exitcode: int = 0, + ): self.target = target self.kwargs = kwargs + self.exitcode = exitcode def start(self): self.target(**self.kwargs) @@ -325,6 +332,133 @@ def is_alive(self) -> bool: return False +def _create_mlflow_logger_mock() -> MagicMock: + mlflow_logger_mock = MagicMock(spec=MLFlowLogger) + mlflow_logger_mock.state_dict = lambda *args, **kwargs: {} + mlflow_logger_mock.save_model = MagicMock(wraps=_save_model_mock) + mlflow_logger_mock.register_model_with_run_id = MagicMock() + mlflow_logger_mock.model_registry_prefix = '' + mlflow_logger_mock._experiment_id = 'mlflow-experiment-id' + mlflow_logger_mock._run_id = 'mlflow-run-id' + mlflow_logger_mock._enabled = True + mlflow_logger_mock.run_url = 'fake-url' + return mlflow_logger_mock + + +def _create_optimizer(original_model: torch.nn.Module) -> torch.optim.Optimizer: + optimizer_config = _OPTIMIZER_CFG() + optimizer_name = optimizer_config.pop('name') + return build_optimizer( + original_model, + optimizer_name, + optimizer_config, + ) + + +@pytest.mark.gpu +@pytest.mark.parametrize('mlflow_registry_error', [True, False]) +@pytest.mark.parametrize( + 'mlflow_registered_model_name', + [None, 'dummy-registered-name'], +) +@patch('os.cpu_count', MagicMock(return_value=1)) +@patch( + 'llmfoundry.callbacks.hf_checkpointer.SpawnProcess', + new=MockSpawnProcess, +) +def test_final_register_only( + mlflow_registry_error: bool, + mlflow_registered_model_name: Optional[str], + tiny_ft_dataloader: DataLoader, + tmp_path: pathlib.Path, + build_tiny_mpt: Callable, +): + if mlflow_registry_error and mlflow_registered_model_name is None: + pytest.skip( + 'Cannot test mlflow_registry_error without mlflow_registered_model_name', + ) + + delete_transformers_cache() + + dist.initialize_dist(get_device('gpu')) + + precision_str = 'bfloat16' + + checkpointer_callback = HuggingFaceCheckpointer( + save_folder=os.path.join(tmp_path, 'checkpoints'), + save_interval='1dur', + precision=precision_str, + mlflow_registered_model_name=mlflow_registered_model_name, + final_register_only=True, + ) + + original_model = build_tiny_mpt() + + optimizer = _create_optimizer(original_model) + + mlflow_logger_mock = _create_mlflow_logger_mock() + + checkpointer_callback._save_checkpoint = MagicMock( + wraps=checkpointer_callback._save_checkpoint, + ) + trainer = Trainer( + model=original_model, + device='gpu', + train_dataloader=tiny_ft_dataloader, + max_duration='1ba', + callbacks=[checkpointer_callback], + loggers=[mlflow_logger_mock], + optimizers=optimizer, + save_latest_filename=None, + ) + + with mock.patch( + 'llmfoundry.callbacks.hf_checkpointer.SpawnProcess', + new=lambda target, + kwargs: MockSpawnProcess( + target, + kwargs, + exitcode=1 if mlflow_registry_error else 0, + ), + ): + trainer.fit() + + if mlflow_registered_model_name is not None: + # We should always attempt to register the model once + assert mlflow_logger_mock.register_model_with_run_id.call_count == 1 + if mlflow_registry_error: + # If the registry fails, we should still save the model + assert mlflow_logger_mock.register_model_with_run_id.call_count == 1 + assert checkpointer_callback._save_checkpoint.call_count == 2 + assert checkpointer_callback._save_checkpoint.call_args_list[ + 0].kwargs == { + 'register_to_mlflow': True, + 'upload_to_save_folder': False, + } + assert checkpointer_callback._save_checkpoint.call_args_list[ + 1].kwargs == { + 'register_to_mlflow': False, + 'upload_to_save_folder': True, + } + else: + # No mlflow_registry_error, so we should only register the model + assert checkpointer_callback._save_checkpoint.call_count == 1 + assert checkpointer_callback._save_checkpoint.call_args_list[ + 0].kwargs == { + 'register_to_mlflow': True, + 'upload_to_save_folder': False, + } + else: + # No mlflow_registered_model_name, so we should only save the checkpoint + assert mlflow_logger_mock.register_model_with_run_id.call_count == 0 + assert checkpointer_callback._save_checkpoint.call_count == 1 + assert checkpointer_callback._save_checkpoint.call_args_list[ + 0].kwargs == { + 'register_to_mlflow': False, + 'upload_to_save_folder': True, + } + + @pytest.mark.gpu @pytest.mark.parametrize('log_to_mlflow', [True, False]) @pytest.mark.parametrize( @@ -368,23 +502,9 @@ def test_huggingface_conversion_callback_interval( original_model = build_tiny_mpt() - optimizer_config = _OPTIMIZER_CFG() - optimizer_name = optimizer_config.pop('name') - optimizer = build_optimizer( - original_model, - optimizer_name, - optimizer_config, - ) + optimizer = _create_optimizer(original_model) - mlflow_logger_mock = MagicMock(spec=MLFlowLogger) - mlflow_logger_mock.state_dict = lambda *args, **kwargs: {} - mlflow_logger_mock.save_model = MagicMock(wraps=_save_model_mock) - mlflow_logger_mock.register_model_with_run_id = MagicMock() - mlflow_logger_mock.model_registry_prefix = '' - mlflow_logger_mock._experiment_id = 'mlflow-experiment-id' - mlflow_logger_mock._run_id = 'mlflow-run-id' - mlflow_logger_mock._enabled = True - mlflow_logger_mock.run_url = 'fake-url' + mlflow_logger_mock = _create_mlflow_logger_mock() checkpointer_callback.transform_model_pre_registration = MagicMock( wraps=checkpointer_callback.transform_model_pre_registration, ) @@ -519,7 +639,6 @@ def _get_model_and_tokenizer( 'moe_num_experts': 4, 'moe_top_k': 2, 'moe_world_size': 1, - 'moe_weight_parallelism': False, 'uniform_expert_assignment': False, }, 'max_seq_len': max_seq_len, @@ -923,7 +1042,8 @@ def test_huggingface_conversion_callback( model=original_model, device='gpu', precision=trainer_precision, - fsdp_config=fsdp_config if fsdp_state_dict_type is not None else None, + parallelism_config={'fsdp': fsdp_config} + if fsdp_state_dict_type is not None else None, train_dataloader=train_dataloader, save_folder=os.path.join(tmp_path, 'checkpoints'), save_interval=save_interval, @@ -1251,8 +1371,6 @@ def test_mptmoe_huggingface_conversion_callback( 2, 'moe_world_size': 2, - 'moe_weight_parallelism': - False, 'uniform_expert_assignment': True, 'mlp_impl': @@ -1352,7 +1470,7 @@ def test_mptmoe_huggingface_conversion_callback( trainer = Trainer( model=original_model, device='gpu', - fsdp_config=fsdp_config, + parallelism_config={'fsdp': fsdp_config}, train_dataloader=train_dataloader, save_folder=os.path.join(tmp_path, 'checkpoints'), save_interval=save_interval, diff --git a/tests/a_scripts/train/test_train.py b/tests/a_scripts/train/test_train.py index 1f724a6070..9af96f9868 100644 --- a/tests/a_scripts/train/test_train.py +++ b/tests/a_scripts/train/test_train.py @@ -5,14 +5,18 @@ import os import pathlib from typing import Optional +from unittest.mock import Mock import pytest +from composer.callbacks import CheckpointSaver from composer.loggers import InMemoryLogger from omegaconf import DictConfig, ListConfig from omegaconf import OmegaConf as om +from llmfoundry.callbacks import HuggingFaceCheckpointer, RunTimeoutCallback from llmfoundry.command_utils import TrainConfig # noqa: E402 from llmfoundry.command_utils import TRAIN_CONFIG_KEYS, train, validate_config +from llmfoundry.command_utils.train import _sort_callbacks from llmfoundry.utils.config_utils import ( make_dataclass_and_log_config, update_batch_size_info, @@ -110,6 +114,20 @@ def test_train_gauntlet(averages: Optional[dict], tmp_path: pathlib.Path): -1][-1] == 0 +def test_sort_callbacks(): + trainer_mock = Mock() + trainer_mock.state.callbacks = [ + CheckpointSaver(), + HuggingFaceCheckpointer('save-folder', '1ba'), + RunTimeoutCallback(), + ] + _sort_callbacks(trainer_mock) + + assert isinstance(trainer_mock.state.callbacks[0], RunTimeoutCallback) + assert isinstance(trainer_mock.state.callbacks[1], CheckpointSaver) + assert isinstance(trainer_mock.state.callbacks[2], HuggingFaceCheckpointer) + + def test_train_multi_eval(tmp_path: pathlib.Path): """Test training run with multiple eval datasets.""" c4_dataset_name = create_c4_dataset_xxsmall(tmp_path) diff --git a/tests/data/test_dataset.py b/tests/data/test_dataset.py new file mode 100644 index 0000000000..071c189b68 --- /dev/null +++ b/tests/data/test_dataset.py @@ -0,0 +1,44 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 +from contextlib import nullcontext +from typing import Optional +from unittest import mock + +import pytest + +from llmfoundry.data.finetuning.tasks import dataset_constructor +from llmfoundry.utils.exceptions import DatasetTooSmallError + + +@pytest.mark.parametrize('num_canonical_nodes', [None, 8, 2]) +def test_finetuning_streaming_dataset_too_small( + num_canonical_nodes: Optional[int], +): + num_samples = 2 + + class MockDataset: + + def __init__(self): + self.num_canonical_nodes = num_canonical_nodes + self.num_samples = num_samples + + class MockDist: + + def get_world_size(self): + return 32 + + def get_local_world_size(self): + return 8 + + result_context = nullcontext( + ) if num_canonical_nodes == 2 else pytest.raises(DatasetTooSmallError) + with result_context: + with mock.patch( + 'llmfoundry.data.finetuning.tasks.dist', + new=MockDist(), + ): + with mock.patch( + 'llmfoundry.data.finetuning.tasks.DatasetConstructor.streaming_dataset_class', + new=MockDataset, + ): + dataset_constructor.build_from_streaming() diff --git a/tests/eval/test_eval_deprecation.py b/tests/eval/test_eval_deprecation.py new file mode 100644 index 0000000000..828186245a --- /dev/null +++ b/tests/eval/test_eval_deprecation.py @@ -0,0 +1,125 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import unittest +import warnings + +from llmfoundry.command_utils.eval import evaluate_model +from llmfoundry.utils.warnings import VersionedDeprecationWarning + + +class TestEvaluateModelDeprecation(unittest.TestCase): + + def setUp(self): + self.common_args = { # type: ignore + 'tokenizer': { + 'name': 'test_tokenizer', + }, + 'model': { + 'name': 'test_model', + }, + 'model_name': 'test', + 'dist_timeout': 60, + 'run_name': 'test_run', + 'seed': 42, + 'icl_tasks': [], + 'max_seq_len': 512, + 'device_eval_batch_size': 1, + 'eval_gauntlet_config': None, + 'eval_loader_config': None, + 'loggers': [], + 'python_log_level': None, + 'precision': 'fp32', + 'eval_gauntlet_df': None, + 'eval_subset_num_batches': 1, + 'icl_subset_num_batches': None, + 'callback_configs': None, + 'metadata': None, + 'logged_config': {}, + } + + def test_no_deprecation_warning(self): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + import composer.utils.parallelism + deprecated_fsdp_args = list( + composer.utils.parallelism.FSDPConfig.__annotations__.keys(), + ) + print(deprecated_fsdp_args) + + try: + parallelism_config = {'fsdp': {'verbose': True}} + evaluate_model( + **self.common_args, + parallelism_config=parallelism_config, + ) + except ValueError as ve: + if 'parallelism_config cannot contain deprecated fsdp_config arguments.' in str( + ve, + ): + self.fail( + 'Raised ValueError about deprecated fsdp_config arguments', + ) + elif 'Both fsdp_config and parallelism_config cannot be provided at the same time.' in str( + ve, + ): + self.fail( + 'Raised ValueError about both configs being provided', + ) + except Exception: + pass + + deprecation_warnings = [ + warning for warning in w + if isinstance(warning.message, VersionedDeprecationWarning) + ] + if deprecation_warnings: + self.fail('VersionedDeprecationWarning was raised') + + def test_deprecation_warning_with_deprecated_arg(self): + # Use assertRaises to catch the expected ValueError + with self.assertRaises(ValueError) as context: + # Directly call evaluate_model; do not use try-except here + evaluate_model( + **self.common_args, + parallelism_config={'activation_checkpointing': True}, + ) + + # Assert that the correct error message is in the exception + self.assertIn( + 'parallelism_config cannot contain deprecated fsdp_config arguments.', + str(context.exception), + ) + + def test_deprecation_warning_with_fsdp_config(self): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + + try: + evaluate_model( + **self.common_args, + parallelism_config=None, + fsdp_config={'verbose': True}, + ) + except Exception: + pass + + self.assertTrue( + any( + issubclass(warning.category, VersionedDeprecationWarning) + for warning in w + ), + ) + + def test_error_with_both_fsdp_and_parallelism_config(self): + with self.assertRaises(ValueError) as context: + evaluate_model( + **self.common_args, + parallelism_config={'some_arg': True}, + fsdp_config={'some_arg': True}, + ) + + self.assertIn( + 'Both fsdp_config and parallelism_config cannot be provided at the same time.', + str(context.exception), + ) diff --git a/tests/models/hf/test_fsdp_weight_tying.py b/tests/models/hf/test_fsdp_weight_tying.py index 69ced673a1..8e6c113169 100644 --- a/tests/models/hf/test_fsdp_weight_tying.py +++ b/tests/models/hf/test_fsdp_weight_tying.py @@ -91,7 +91,7 @@ def test_fsdp_weight_tying( trainer = Trainer( model=original_model, device='gpu', - fsdp_config=fsdp_config, + parallelism_config={'fsdp': fsdp_config}, train_dataloader=[], device_train_microbatch_size=1, ) diff --git a/tests/models/hf/test_hf_peft_wrapping.py b/tests/models/hf/test_hf_peft_wrapping.py index 522fc5db57..01acc22a60 100644 --- a/tests/models/hf/test_hf_peft_wrapping.py +++ b/tests/models/hf/test_hf_peft_wrapping.py @@ -11,6 +11,7 @@ from composer import Trainer from peft import LoraConfig, get_peft_model +from llmfoundry.models.hf.hf_causal_lm import ComposerHFCausalLM from llmfoundry.models.hf.hf_fsdp import prepare_hf_model_for_fsdp from llmfoundry.utils.builders import build_composer_model, build_tokenizer @@ -36,6 +37,27 @@ def test_peft_wraps(): assert m._fsdp_wrap +def test_causal_lm_peft_wraps(): + model = ComposerHFCausalLM( + tokenizer=None, + pretrained_model_name_or_path='mosaicml/mpt-7b', + pretrained=False, + trust_remote_code=True, + config_overrides={'n_layers': 2}, + peft_config={ + 'peft_type': 'LORA', + 'task_type': 'CAUSAL_LM', + }, + ) + + for n, m in model.named_modules(): + if 'lora' in n and 'default' in n: + has_parameters = any(True for _ in m.parameters()) + has_buffers = any(True for _ in m.buffers()) + if has_parameters or has_buffers: + assert m._fsdp_wrap + + @pytest.mark.world_size(2) @pytest.mark.gpu @pytest.mark.parametrize( @@ -103,7 +125,7 @@ def test_lora_mixed_init( trainer = Trainer( model=original_model, device='gpu', - fsdp_config=fsdp_config, + parallelism_config={'fsdp': fsdp_config}, train_dataloader=[], device_train_microbatch_size=1, ) diff --git a/tests/models/test_fsdp_act_checkpoint.py b/tests/models/test_fsdp_act_checkpoint.py index a41574538a..366bcf7786 100644 --- a/tests/models/test_fsdp_act_checkpoint.py +++ b/tests/models/test_fsdp_act_checkpoint.py @@ -59,7 +59,7 @@ def test_fsdp_act_checkpoint( trainer = Trainer( model=model, device='gpu', - fsdp_config=fsdp_config, + parallelism_config={'fsdp': fsdp_config}, ) assert trainer.state.fsdp_enabled diff --git a/tests/models/test_mpt_gen.py b/tests/models/test_mpt_gen.py index 134ca35ec0..820da5e71f 100644 --- a/tests/models/test_mpt_gen.py +++ b/tests/models/test_mpt_gen.py @@ -190,7 +190,6 @@ def test_gen_mpt_moe( 'moe_num_experts': 4, 'moe_top_k': 2, 'moe_world_size': 1, - 'moe_weight_parallelism': False, 'uniform_expert_assignment': False, }, ) diff --git a/tests/models/test_onnx.py b/tests/models/test_onnx.py index 95732cfd8f..32d46607b5 100644 --- a/tests/models/test_onnx.py +++ b/tests/models/test_onnx.py @@ -85,7 +85,7 @@ def test_onnx_export(tie_word_embeddings: bool, tmp_path: pathlib.Path): ort_session = ort.InferenceSession(str(tmp_path / 'mpt.onnx')) for key, value in sample_input.items(): - sample_input[key] = value.cpu().numpy() + sample_input[key] = value.cpu().numpy() # pyright: ignore loaded_model_out = ort_session.run(None, sample_input) diff --git a/tests/models/utils/test_param_init_fns.py b/tests/models/utils/test_param_init_fns.py index 0eaf60c869..11d9fba430 100644 --- a/tests/models/utils/test_param_init_fns.py +++ b/tests/models/utils/test_param_init_fns.py @@ -199,3 +199,30 @@ def test_emb_init(emb_init_cfg: Optional[tuple[str, Union[int, list[int]]]]): emb_init_uniform_lim, ) == 2 and emb_init_uniform_lim[0] == emb_init_uniform_lim[1]: assert (model.emb.weight == emb_init_uniform_lim[0]).all() + + +@pytest.mark.parametrize( + 'padding_idx', + [0, 2], +) +def test_emb_padding_init(padding_idx: int,): + cfg: dict[str, Union[int, list[int]]] = { + 'vocab_size': 64, + 'in_features': 16, + 'n_layers': 2, + 'padding_idx': padding_idx, + 'emb_init_std': 5, + } + dict_cfg = om.create(cfg) + + model = nn.Embedding( + dict_cfg.vocab_size, + dict_cfg.in_features, + dict_cfg.padding_idx, + ) + + model.apply(partial(param_init_fns.get('kaiming_normal_'), **dict_cfg)) + assert isinstance(model, torch.nn.Embedding) + + if dict_cfg.get('emb_init_std') is not None: + assert (model.weight[padding_idx] == 0).all()