From 2e3d14f6130ebad5a149c1c52f53fd07628e1006 Mon Sep 17 00:00:00 2001
From: Vincent Chen
Date: Tue, 17 Sep 2024 13:45:04 -0700
Subject: [PATCH] Add deprecation warning to fsdp_config (#1530)

Co-authored-by: v-chen_data
---
 llmfoundry/command_utils/eval.py              |  35 ++++-
 .../inference/test_convert_composer_to_hf.py  |   5 +-
 tests/eval/test_eval_deprecation.py           | 125 ++++++++++++++++++
 tests/models/hf/test_fsdp_weight_tying.py     |   2 +-
 tests/models/hf/test_hf_peft_wrapping.py      |   2 +-
 tests/models/test_fsdp_act_checkpoint.py      |   2 +-
 6 files changed, 163 insertions(+), 8 deletions(-)
 create mode 100644 tests/eval/test_eval_deprecation.py

diff --git a/llmfoundry/command_utils/eval.py b/llmfoundry/command_utils/eval.py
index f622ca182d..e644ad1f0f 100644
--- a/llmfoundry/command_utils/eval.py
+++ b/llmfoundry/command_utils/eval.py
@@ -4,6 +4,7 @@
 import logging
 import os
 import time
+import warnings
 from typing import Any, Optional, Union
 
 import pandas as pd
@@ -11,7 +12,7 @@
 from composer.core import Callback
 from composer.loggers.logger_destination import LoggerDestination
 from composer.trainer import Trainer
-from composer.utils import dist, get_device, reproducibility
+from composer.utils import dist, get_device, parallelism, reproducibility
 from omegaconf import DictConfig
 from omegaconf import OmegaConf as om
 
@@ -36,6 +37,7 @@
     process_init_device,
 )
 from llmfoundry.utils.registry_utils import import_file
+from llmfoundry.utils.warnings import VersionedDeprecationWarning
 
 log = logging.getLogger(__name__)
 
@@ -52,7 +54,6 @@ def evaluate_model(
     device_eval_batch_size: Union[int, float],
     eval_gauntlet_config: Optional[Union[str, dict[str, Any]]],
     eval_loader_config: Optional[Union[dict[str, Any], list[dict[str, Any]]]],
-    fsdp_config: Optional[dict[str, Any]],
     loggers: list[LoggerDestination],
     python_log_level: Optional[str],
     precision: str,
@@ -62,9 +63,33 @@ def evaluate_model(
     callback_configs: Optional[dict[str, Any]],
     metadata: Optional[dict[str, str]],
     logged_config: dict[str, Any],
+    fsdp_config: Optional[dict[str, Any]] = None,
+    parallelism_config: Optional[dict[str, Any]] = None,
     should_log_config: bool = True,
     load_path: Optional[str] = None,
 ):
+    if parallelism_config:
+        deprecated_fsdp_args = list(
+            parallelism.FSDPConfig.__annotations__.keys(),
+        )
+        for deprecated_arg in deprecated_fsdp_args:
+            if deprecated_arg in parallelism_config:
+                raise ValueError(
+                    'parallelism_config cannot contain deprecated fsdp_config arguments.',
+                )
+
+    if fsdp_config:
+        warnings.warn(
+            VersionedDeprecationWarning(
+                'The argument fsdp_config is deprecated. Please use parallelism_config instead.',
+                remove_version='0.13.0',
+            ),
+        )
+    if fsdp_config and parallelism_config:
+        raise ValueError(
+            'Both fsdp_config and parallelism_config cannot be provided at the same time. Please use parallelism_config.',
+        )
+
     log.info(f'Evaluating model: {model_name}')
     # Build tokenizer and model
     tokenizer_cfg = tokenizer
@@ -99,6 +124,10 @@ def evaluate_model(
             mosaicml_logger.log_metrics(metadata)
             mosaicml_logger._flush_metadata(force_flush=True)
 
+    fsdp_config = parallelism_config.get(
+        'fsdp_config',
+        None,
+    ) if parallelism_config else fsdp_config
     if fsdp_config and model.get('load_in_8bit', False):
         raise ValueError(
             'The FSDP config block is not supported when loading ' +
@@ -146,7 +175,7 @@ def evaluate_model(
         callbacks=callbacks,
         loggers=loggers,
         precision=precision,
-        fsdp_config=fsdp_config,
+        parallelism_config={'fsdp': fsdp_config},
         load_path=load_path,
         load_weights_only=True,
         progress_bar=False,
diff --git a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py
index 4f1bd63c62..66ec739a65 100644
--- a/tests/a_scripts/inference/test_convert_composer_to_hf.py
+++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py
@@ -1042,7 +1042,8 @@ def test_huggingface_conversion_callback(
         model=original_model,
         device='gpu',
         precision=trainer_precision,
-        fsdp_config=fsdp_config if fsdp_state_dict_type is not None else None,
+        parallelism_config={'fsdp': fsdp_config}
+        if fsdp_state_dict_type is not None else None,
         train_dataloader=train_dataloader,
         save_folder=os.path.join(tmp_path, 'checkpoints'),
         save_interval=save_interval,
@@ -1469,7 +1470,7 @@ def test_mptmoe_huggingface_conversion_callback(
     trainer = Trainer(
         model=original_model,
         device='gpu',
-        fsdp_config=fsdp_config,
+        parallelism_config={'fsdp': fsdp_config},
         train_dataloader=train_dataloader,
         save_folder=os.path.join(tmp_path, 'checkpoints'),
         save_interval=save_interval,
diff --git a/tests/eval/test_eval_deprecation.py b/tests/eval/test_eval_deprecation.py
new file mode 100644
index 0000000000..828186245a
--- /dev/null
+++ b/tests/eval/test_eval_deprecation.py
@@ -0,0 +1,125 @@
+# Copyright 2024 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+import warnings
+
+from llmfoundry.command_utils.eval import evaluate_model
+from llmfoundry.utils.warnings import VersionedDeprecationWarning
+
+
+class TestEvaluateModelDeprecation(unittest.TestCase):
+
+    def setUp(self):
+        self.common_args = {  # type: ignore
+            'tokenizer': {
+                'name': 'test_tokenizer',
+            },
+            'model': {
+                'name': 'test_model',
+            },
+            'model_name': 'test',
+            'dist_timeout': 60,
+            'run_name': 'test_run',
+            'seed': 42,
+            'icl_tasks': [],
+            'max_seq_len': 512,
+            'device_eval_batch_size': 1,
+            'eval_gauntlet_config': None,
+            'eval_loader_config': None,
+            'loggers': [],
+            'python_log_level': None,
+            'precision': 'fp32',
+            'eval_gauntlet_df': None,
+            'eval_subset_num_batches': 1,
+            'icl_subset_num_batches': None,
+            'callback_configs': None,
+            'metadata': None,
+            'logged_config': {},
+        }
+
+    def test_no_deprecation_warning(self):
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter('always')
+            import composer.utils.parallelism
+            deprecated_fsdp_args = list(
+                composer.utils.parallelism.FSDPConfig.__annotations__.keys(),
+            )
+            print(deprecated_fsdp_args)
+
+            try:
+                parallelism_config = {'fsdp': {'verbose': True}}
+                evaluate_model(
+                    **self.common_args,
+                    parallelism_config=parallelism_config,
+                )
+            except ValueError as ve:
+                if 'parallelism_config cannot contain deprecated fsdp_config arguments.' in str(
+                    ve,
+                ):
+                    self.fail(
+                        'Raised ValueError about deprecated fsdp_config arguments',
+                    )
+                elif 'Both fsdp_config and parallelism_config cannot be provided at the same time.' in str(
+                    ve,
+                ):
+                    self.fail(
+                        'Raised ValueError about both configs being provided',
+                    )
+            except Exception:
+                pass
+
+            deprecation_warnings = [
+                warning for warning in w
+                if isinstance(warning.message, VersionedDeprecationWarning)
+            ]
+            if deprecation_warnings:
+                self.fail('VersionedDeprecationWarning was raised')
+
+    def test_deprecation_warning_with_deprecated_arg(self):
+        # Use assertRaises to catch the expected ValueError
+        with self.assertRaises(ValueError) as context:
+            # Directly call evaluate_model; do not use try-except here
+            evaluate_model(
+                **self.common_args,
+                parallelism_config={'activation_checkpointing': True},
+            )
+
+        # Assert that the correct error message is in the exception
+        self.assertIn(
+            'parallelism_config cannot contain deprecated fsdp_config arguments.',
+            str(context.exception),
+        )
+
+    def test_deprecation_warning_with_fsdp_config(self):
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter('always')
+
+            try:
+                evaluate_model(
+                    **self.common_args,
+                    parallelism_config=None,
+                    fsdp_config={'verbose': True},
+                )
+            except Exception:
+                pass
+
+            self.assertTrue(
+                any(
+                    issubclass(warning.category, VersionedDeprecationWarning)
+                    for warning in w
+                ),
+            )
+
+    def test_error_with_both_fsdp_and_parallelism_config(self):
+        with self.assertRaises(ValueError) as context:
+            evaluate_model(
+                **self.common_args,
+                parallelism_config={'some_arg': True},
+                fsdp_config={'some_arg': True},
+            )
+
+        self.assertIn(
+            'Both fsdp_config and parallelism_config cannot be provided at the same time.',
+            str(context.exception),
+        )
diff --git a/tests/models/hf/test_fsdp_weight_tying.py b/tests/models/hf/test_fsdp_weight_tying.py
index 69ced673a1..8e6c113169 100644
--- a/tests/models/hf/test_fsdp_weight_tying.py
+++ b/tests/models/hf/test_fsdp_weight_tying.py
@@ -91,7 +91,7 @@ def test_fsdp_weight_tying(
     trainer = Trainer(
         model=original_model,
         device='gpu',
-        fsdp_config=fsdp_config,
+        parallelism_config={'fsdp': fsdp_config},
         train_dataloader=[],
         device_train_microbatch_size=1,
     )
diff --git a/tests/models/hf/test_hf_peft_wrapping.py b/tests/models/hf/test_hf_peft_wrapping.py
index 56cb36c8c1..01acc22a60 100644
--- a/tests/models/hf/test_hf_peft_wrapping.py
+++ b/tests/models/hf/test_hf_peft_wrapping.py
@@ -125,7 +125,7 @@ def test_lora_mixed_init(
     trainer = Trainer(
         model=original_model,
         device='gpu',
-        fsdp_config=fsdp_config,
+        parallelism_config={'fsdp': fsdp_config},
         train_dataloader=[],
         device_train_microbatch_size=1,
     )
diff --git a/tests/models/test_fsdp_act_checkpoint.py b/tests/models/test_fsdp_act_checkpoint.py
index a41574538a..366bcf7786 100644
--- a/tests/models/test_fsdp_act_checkpoint.py
+++ b/tests/models/test_fsdp_act_checkpoint.py
@@ -59,7 +59,7 @@ def test_fsdp_act_checkpoint(
     trainer = Trainer(
         model=model,
         device='gpu',
-        fsdp_config=fsdp_config,
+        parallelism_config={'fsdp': fsdp_config},
     )
 
     assert trainer.state.fsdp_enabled
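
The same Trainer-level migration recurs in every hunk above: the deprecated top-level fsdp_config argument is replaced by a parallelism_config dict whose 'fsdp' entry carries the unchanged FSDP settings. A minimal sketch of that pattern, assuming an illustrative FSDP settings dict and a model object built elsewhere (neither is taken from this patch):

    from composer.trainer import Trainer

    # Illustrative FSDP settings; any FSDP config block migrates the same way.
    fsdp_config = {'sharding_strategy': 'FULL_SHARD'}

    # Deprecated form: Trainer(model=model, fsdp_config=fsdp_config, ...)
    # Current form: nest the same dict under the 'fsdp' key of parallelism_config.
    trainer = Trainer(
        model=model,  # assumes a ComposerModel constructed elsewhere
        parallelism_config={'fsdp': fsdp_config},
    )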