From b5fc0fadcbb2342b4633586ca3d6cd450b9e9ec9 Mon Sep 17 00:00:00 2001
From: Cheng Li
Date: Wed, 10 Apr 2024 10:04:55 -0700
Subject: [PATCH] GRT-2819 fix overwriting in script (#1107)

---
 scripts/train/train.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/scripts/train/train.py b/scripts/train/train.py
index 01a351f1e7..96066d5a1d 100644
--- a/scripts/train/train.py
+++ b/scripts/train/train.py
@@ -80,14 +80,14 @@ def validate_config(cfg: DictConfig):
         fsdp_config = cfg.get('fsdp_config', None)
         act_ckpt = fsdp_config.get('activation_checkpointing', False)
         act_ckpt_reentrant = fsdp_config.get(
-            'activation_checkpointing_reentrant', True)
-        if fsdp_config is not None and act_ckpt == True and act_ckpt_reentrant == False:
+            'activation_checkpointing_reentrant', False)
+        if fsdp_config is not None and act_ckpt == True and act_ckpt_reentrant == True:
             warnings.warn(
                 '`te.Linear` layers do not support activation_checkpointing with '
-                + '`activation_checkpointing_reentrant = False`. ' +
-                'Setting cfg.fsdp_config.activation_checkpointing_reentrant=True.'
+                + '`activation_checkpointing_reentrant = True`. ' +
+                'Setting cfg.fsdp_config.activation_checkpointing_reentrant=False.'
             )
-            cfg.fsdp_config.activation_checkpointing_reentrant = True
+            cfg.fsdp_config.activation_checkpointing_reentrant = False
 
     if cfg.model.get('ffn_config', {}).get('ffn_type', 'mptmlp') == 'te_ln_mlp':
         warnings.warn(