From 79889e59b4cbb4f7c4c36c6a0579e9f193a0af20 Mon Sep 17 00:00:00 2001 From: v-chen_data Date: Thu, 6 Jun 2024 13:46:10 -0700 Subject: [PATCH] fsdp debug --- scripts/train/train.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/scripts/train/train.py b/scripts/train/train.py index fa98674cda..3cf3d9551d 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -105,7 +105,7 @@ def validate_config(train_config: TrainConfig): train_config.model.get('fc_type', 'torch') == 'te' or 'te' in train_config.model.get('ffn_config', {}).get('ffn_type', 'mptmlp') ): - fsdp_config = train_config.parallelism_config['fsdp_config'] + fsdp_config = train_config.fsdp_config act_ckpt = fsdp_config.get( 'activation_checkpointing', False, @@ -265,8 +265,7 @@ def main(cfg: DictConfig) -> Trainer: train_loader_config = train_cfg.train_loader # Optional fsdp data, fine-tuning, and eval configs - fsdp_config: Optional[Dict[str, Any] - ] = train_cfg.parallelism_config['fsdp_config'] + fsdp_config: Optional[Dict[str, Any]] = train_cfg.fsdp_config eval_loader_config = train_cfg.eval_loader if train_cfg.eval_loader is not None else train_cfg.eval_loaders icl_tasks_config = train_cfg.icl_tasks or train_cfg.icl_tasks_str