Skip to content

Commit

Permalink
fsdp patch
Browse files · Browse the repository at this point in the history
  • Loading branch information
v-chen_data committed Jun 6, 2024
1 parent d549499 commit f6ca83a
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
7 changes: 4 additions & 3 deletions scripts/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def validate_config(train_config: TrainConfig):
train_config.model.get('fc_type', 'torch') == 'te' or 'te'
in train_config.model.get('ffn_config', {}).get('ffn_type', 'mptmlp')
):
fsdp_config = train_config.fsdp_config
fsdp_config = train_config.parallelism_config['fsdp_config']
act_ckpt = fsdp_config.get(
'activation_checkpointing',
False,
Expand Down Expand Up @@ -265,7 +265,8 @@ def main(cfg: DictConfig) -> Trainer:
train_loader_config = train_cfg.train_loader

# Optional fsdp data, fine-tuning, and eval configs
fsdp_config: Optional[Dict[str, Any]] = train_cfg.fsdp_config
fsdp_config: Optional[Dict[str, Any]
] = train_cfg.parallelism_config['fsdp_config']

eval_loader_config = train_cfg.eval_loader if train_cfg.eval_loader is not None else train_cfg.eval_loaders
icl_tasks_config = train_cfg.icl_tasks or train_cfg.icl_tasks_str
Expand Down Expand Up @@ -504,7 +505,7 @@ def main(cfg: DictConfig) -> Trainer:
precision=train_cfg.precision,
algorithms=algorithms,
device_train_microbatch_size=train_cfg.device_train_microbatch_size,
fsdp_config=fsdp_config,
parallelism_config={'fsdp': fsdp_config},
save_folder=train_cfg.save_folder,
save_filename=save_filename,
save_latest_filename=save_latest_filename,
Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
import torch.nn as nn
from accelerate import init_empty_weights
from composer.core.precision import Precision, get_precision_context
from composer.distributed.dist_strategy import prepare_fsdp_module
from composer.models.huggingface import maybe_get_underlying_model
from composer.optim import DecoupledAdamW
from composer.distributed.dist_strategy import prepare_fsdp_module
from composer.utils import dist, get_device, reproducibility
from omegaconf import DictConfig, ListConfig
from omegaconf import OmegaConf as om
Expand Down

0 comments on commit f6ca83a

Please sign in to comment.