From f6ca83a627783716d9e10830268feb5c3564d57b Mon Sep 17 00:00:00 2001 From: v-chen_data Date: Thu, 6 Jun 2024 13:07:29 -0700 Subject: [PATCH] fsdp patch --- scripts/train/train.py | 7 ++++--- tests/models/test_model.py | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/scripts/train/train.py b/scripts/train/train.py index c9e2d67bf4..fa98674cda 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -105,7 +105,7 @@ def validate_config(train_config: TrainConfig): train_config.model.get('fc_type', 'torch') == 'te' or 'te' in train_config.model.get('ffn_config', {}).get('ffn_type', 'mptmlp') ): - fsdp_config = train_config.fsdp_config + fsdp_config = train_config.parallelism_config['fsdp_config'] act_ckpt = fsdp_config.get( 'activation_checkpointing', False, @@ -265,7 +265,8 @@ def main(cfg: DictConfig) -> Trainer: train_loader_config = train_cfg.train_loader # Optional fsdp data, fine-tuning, and eval configs - fsdp_config: Optional[Dict[str, Any]] = train_cfg.fsdp_config + fsdp_config: Optional[Dict[str, Any] ] = train_cfg.parallelism_config['fsdp_config'] eval_loader_config = train_cfg.eval_loader if train_cfg.eval_loader is not None else train_cfg.eval_loaders icl_tasks_config = train_cfg.icl_tasks or train_cfg.icl_tasks_str @@ -504,7 +505,7 @@ def main(cfg: DictConfig) -> Trainer: precision=train_cfg.precision, algorithms=algorithms, device_train_microbatch_size=train_cfg.device_train_microbatch_size, - fsdp_config=fsdp_config, + parallelism_config={'fsdp': fsdp_config}, save_folder=train_cfg.save_folder, save_filename=save_filename, save_latest_filename=save_latest_filename, diff --git a/tests/models/test_model.py b/tests/models/test_model.py index 78f7ee0d5c..5f77b72d00 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -13,9 +13,9 @@ import torch.nn as nn from accelerate import init_empty_weights from composer.core.precision import Precision, get_precision_context +from composer.distributed.dist_strategy import prepare_fsdp_module from composer.models.huggingface import maybe_get_underlying_model from composer.optim import DecoupledAdamW -from composer.distributed.dist_strategy import prepare_fsdp_module from composer.utils import dist, get_device, reproducibility from omegaconf import DictConfig, ListConfig from omegaconf import OmegaConf as om