Skip to content

Commit

Permalink
Add deprecation warning to fsdp_config (#1530)
Browse files Browse the repository at this point in the history
Co-authored-by: v-chen_data <[email protected]>
  • Loading branch information
KuuCi and v-chen_data authored Sep 17, 2024
1 parent 7a23f60 commit 2e3d14f
Show file tree
Hide file tree
Showing 6 changed files with 163 additions and 8 deletions.
35 changes: 32 additions & 3 deletions llmfoundry/command_utils/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
import logging
import os
import time
import warnings
from typing import Any, Optional, Union

import pandas as pd
import torch
from composer.core import Callback
from composer.loggers.logger_destination import LoggerDestination
from composer.trainer import Trainer
from composer.utils import dist, get_device, reproducibility
from composer.utils import dist, get_device, parallelism, reproducibility
from omegaconf import DictConfig
from omegaconf import OmegaConf as om

Expand All @@ -36,6 +37,7 @@
process_init_device,
)
from llmfoundry.utils.registry_utils import import_file
from llmfoundry.utils.warnings import VersionedDeprecationWarning

log = logging.getLogger(__name__)

Expand All @@ -52,7 +54,6 @@ def evaluate_model(
device_eval_batch_size: Union[int, float],
eval_gauntlet_config: Optional[Union[str, dict[str, Any]]],
eval_loader_config: Optional[Union[dict[str, Any], list[dict[str, Any]]]],
fsdp_config: Optional[dict[str, Any]],
loggers: list[LoggerDestination],
python_log_level: Optional[str],
precision: str,
Expand All @@ -62,9 +63,33 @@ def evaluate_model(
callback_configs: Optional[dict[str, Any]],
metadata: Optional[dict[str, str]],
logged_config: dict[str, Any],
fsdp_config: Optional[dict[str, Any]] = None,
parallelism_config: Optional[dict[str, Any]] = None,
should_log_config: bool = True,
load_path: Optional[str] = None,
):
if parallelism_config:
deprecated_fsdp_args = list(
parallelism.FSDPConfig.__annotations__.keys(),
)
for deprecated_arg in deprecated_fsdp_args:
if deprecated_arg in parallelism_config:
raise ValueError(
'parallelism_config cannot contain deprecated fsdp_config arguments.',
)

if fsdp_config:
warnings.warn(
VersionedDeprecationWarning(
'The argument fsdp_config is deprecated. Please use parallelism_config instead.',
remove_version='0.13.0',
),
)
if fsdp_config and parallelism_config:
raise ValueError(
'Both fsdp_config and parallelism_config cannot be provided at the same time. Please use parallelism_config.',
)

log.info(f'Evaluating model: {model_name}')
# Build tokenizer and model
tokenizer_cfg = tokenizer
Expand Down Expand Up @@ -99,6 +124,10 @@ def evaluate_model(
mosaicml_logger.log_metrics(metadata)
mosaicml_logger._flush_metadata(force_flush=True)

fsdp_config = parallelism_config.get(
'fsdp_config',
None,
) if parallelism_config else fsdp_config
if fsdp_config and model.get('load_in_8bit', False):
raise ValueError(
'The FSDP config block is not supported when loading ' +
Expand Down Expand Up @@ -146,7 +175,7 @@ def evaluate_model(
callbacks=callbacks,
loggers=loggers,
precision=precision,
fsdp_config=fsdp_config,
parallelism_config={'fsdp': fsdp_config},
load_path=load_path,
load_weights_only=True,
progress_bar=False,
Expand Down
5 changes: 3 additions & 2 deletions tests/a_scripts/inference/test_convert_composer_to_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1042,7 +1042,8 @@ def test_huggingface_conversion_callback(
model=original_model,
device='gpu',
precision=trainer_precision,
fsdp_config=fsdp_config if fsdp_state_dict_type is not None else None,
parallelism_config={'fsdp': fsdp_config}
if fsdp_state_dict_type is not None else None,
train_dataloader=train_dataloader,
save_folder=os.path.join(tmp_path, 'checkpoints'),
save_interval=save_interval,
Expand Down Expand Up @@ -1469,7 +1470,7 @@ def test_mptmoe_huggingface_conversion_callback(
trainer = Trainer(
model=original_model,
device='gpu',
fsdp_config=fsdp_config,
parallelism_config={'fsdp': fsdp_config},
train_dataloader=train_dataloader,
save_folder=os.path.join(tmp_path, 'checkpoints'),
save_interval=save_interval,
Expand Down
125 changes: 125 additions & 0 deletions tests/eval/test_eval_deprecation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import unittest
import warnings

from llmfoundry.command_utils.eval import evaluate_model
from llmfoundry.utils.warnings import VersionedDeprecationWarning


class TestEvaluateModelDeprecation(unittest.TestCase):

def setUp(self):
self.common_args = { # type: ignore
'tokenizer': {
'name': 'test_tokenizer',
},
'model': {
'name': 'test_model',
},
'model_name': 'test',
'dist_timeout': 60,
'run_name': 'test_run',
'seed': 42,
'icl_tasks': [],
'max_seq_len': 512,
'device_eval_batch_size': 1,
'eval_gauntlet_config': None,
'eval_loader_config': None,
'loggers': [],
'python_log_level': None,
'precision': 'fp32',
'eval_gauntlet_df': None,
'eval_subset_num_batches': 1,
'icl_subset_num_batches': None,
'callback_configs': None,
'metadata': None,
'logged_config': {},
}

def test_no_deprecation_warning(self):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
import composer.utils.parallelism
deprecated_fsdp_args = list(
composer.utils.parallelism.FSDPConfig.__annotations__.keys(),
)
print(deprecated_fsdp_args)

try:
parallelism_config = {'fsdp': {'verbose': True}}
evaluate_model(
**self.common_args,
parallelism_config=parallelism_config,
)
except ValueError as ve:
if 'parallelism_config cannot contain deprecated fsdp_config arguments.' in str(
ve,
):
self.fail(
'Raised ValueError about deprecated fsdp_config arguments',
)
elif 'Both fsdp_config and parallelism_config cannot be provided at the same time.' in str(
ve,
):
self.fail(
'Raised ValueError about both configs being provided',
)
except Exception:
pass

deprecation_warnings = [
warning for warning in w
if isinstance(warning.message, VersionedDeprecationWarning)
]
if deprecation_warnings:
self.fail('VersionedDeprecationWarning was raised')

def test_deprecation_warning_with_deprecated_arg(self):
# Use assertRaises to catch the expected ValueError
with self.assertRaises(ValueError) as context:
# Directly call evaluate_model; do not use try-except here
evaluate_model(
**self.common_args,
parallelism_config={'activation_checkpointing': True},
)

# Assert that the correct error message is in the exception
self.assertIn(
'parallelism_config cannot contain deprecated fsdp_config arguments.',
str(context.exception),
)

def test_deprecation_warning_with_fsdp_config(self):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')

try:
evaluate_model(
**self.common_args,
parallelism_config=None,
fsdp_config={'verbose': True},
)
except Exception:
pass

self.assertTrue(
any(
issubclass(warning.category, VersionedDeprecationWarning)
for warning in w
),
)

def test_error_with_both_fsdp_and_parallelism_config(self):
with self.assertRaises(ValueError) as context:
evaluate_model(
**self.common_args,
parallelism_config={'some_arg': True},
fsdp_config={'some_arg': True},
)

self.assertIn(
'Both fsdp_config and parallelism_config cannot be provided at the same time.',
str(context.exception),
)
2 changes: 1 addition & 1 deletion tests/models/hf/test_fsdp_weight_tying.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def test_fsdp_weight_tying(
trainer = Trainer(
model=original_model,
device='gpu',
fsdp_config=fsdp_config,
parallelism_config={'fsdp': fsdp_config},
train_dataloader=[],
device_train_microbatch_size=1,
)
Expand Down
2 changes: 1 addition & 1 deletion tests/models/hf/test_hf_peft_wrapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def test_lora_mixed_init(
trainer = Trainer(
model=original_model,
device='gpu',
fsdp_config=fsdp_config,
parallelism_config={'fsdp': fsdp_config},
train_dataloader=[],
device_train_microbatch_size=1,
)
Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_fsdp_act_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def test_fsdp_act_checkpoint(
trainer = Trainer(
model=model,
device='gpu',
fsdp_config=fsdp_config,
parallelism_config={'fsdp': fsdp_config},
)

assert trainer.state.fsdp_enabled
Expand Down

0 comments on commit 2e3d14f

Please sign in to comment.