diff --git a/scripts/train/train.py b/scripts/train/train.py
index 01a351f1e7..96066d5a1d 100644
--- a/scripts/train/train.py
+++ b/scripts/train/train.py
@@ -80,14 +80,14 @@ def validate_config(cfg: DictConfig):
         fsdp_config = cfg.get('fsdp_config', None)
         act_ckpt = fsdp_config.get('activation_checkpointing', False)
         act_ckpt_reentrant = fsdp_config.get(
-            'activation_checkpointing_reentrant', True)
-        if fsdp_config is not None and act_ckpt == True and act_ckpt_reentrant == False:
+            'activation_checkpointing_reentrant', False)
+        if fsdp_config is not None and act_ckpt == True and act_ckpt_reentrant == True:
             warnings.warn(
                 '`te.Linear` layers do not support activation_checkpointing with '
-                + '`activation_checkpointing_reentrant = False`. ' +
-                'Setting cfg.fsdp_config.activation_checkpointing_reentrant=True.'
+                + '`activation_checkpointing_reentrant = True`. ' +
+                'Setting cfg.fsdp_config.activation_checkpointing_reentrant=False.'
             )
-            cfg.fsdp_config.activation_checkpointing_reentrant = True
+            cfg.fsdp_config.activation_checkpointing_reentrant = False
         if cfg.model.get('ffn_config', {}).get('ffn_type', 'mptmlp') == 'te_ln_mlp':
             warnings.warn(