diff --git a/scripts/train/train.py b/scripts/train/train.py index c9e2d67bf4..fa98674cda 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -105,7 +105,7 @@ def validate_config(train_config: TrainConfig): train_config.model.get('fc_type', 'torch') == 'te' or 'te' in train_config.model.get('ffn_config', {}).get('ffn_type', 'mptmlp') ): - fsdp_config = train_config.fsdp_config + fsdp_config = train_config.parallelism_config['fsdp_config'] act_ckpt = fsdp_config.get( 'activation_checkpointing', False, @@ -265,7 +265,8 @@ def main(cfg: DictConfig) -> Trainer: train_loader_config = train_cfg.train_loader # Optional fsdp data, fine-tuning, and eval configs - fsdp_config: Optional[Dict[str, Any]] = train_cfg.fsdp_config + fsdp_config: Optional[Dict[str, Any] + ] = train_cfg.parallelism_config['fsdp_config'] eval_loader_config = train_cfg.eval_loader if train_cfg.eval_loader is not None else train_cfg.eval_loaders icl_tasks_config = train_cfg.icl_tasks or train_cfg.icl_tasks_str @@ -504,7 +505,7 @@ def main(cfg: DictConfig) -> Trainer: precision=train_cfg.precision, algorithms=algorithms, device_train_microbatch_size=train_cfg.device_train_microbatch_size, - fsdp_config=fsdp_config, + parallelism_config={'fsdp': fsdp_config}, save_folder=train_cfg.save_folder, save_filename=save_filename, save_latest_filename=save_latest_filename, diff --git a/tests/models/test_model.py b/tests/models/test_model.py index 78f7ee0d5c..5f77b72d00 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -13,9 +13,9 @@ import torch.nn as nn from accelerate import init_empty_weights from composer.core.precision import Precision, get_precision_context +from composer.distributed.dist_strategy import prepare_fsdp_module from composer.models.huggingface import maybe_get_underlying_model from composer.optim import DecoupledAdamW -from composer.distributed.dist_strategy import prepare_fsdp_module from composer.utils import dist, get_device, reproducibility from omegaconf import DictConfig, ListConfig from omegaconf import OmegaConf as om