Commit

Add expandable segments flag (#1075)
dakinggg authored Apr 1, 2024
1 parent 7a8a156 commit 349c2ff
Showing 3 changed files with 11 additions and 2 deletions.
11 changes: 9 additions & 2 deletions scripts/train/train.py
@@ -130,11 +130,18 @@ def main(cfg: DictConfig) -> Trainer:
     # Create copy of config for logging
     logged_cfg: DictConfig = copy.deepcopy(cfg)
 
+    cuda_alloc_conf = []
     # Get max split size mb
     max_split_size_mb: Optional[int] = cfg.pop('max_split_size_mb', None)
     if max_split_size_mb is not None:
-        os.environ[
-            'PYTORCH_CUDA_ALLOC_CONF'] = f'max_split_size_mb:{max_split_size_mb}'
+        cuda_alloc_conf.append(f'max_split_size_mb:{max_split_size_mb}')
+
+    # Expandable segments
+    if cfg.pop('expandable_segments', False):
+        cuda_alloc_conf.append('expandable_segments:True')
+
+    if len(cuda_alloc_conf) > 0:
+        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = ','.join(cuda_alloc_conf)
 
     # Set CUDA lazy loading
     # This can save a bit of memory if not all modules are needed
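The two allocator options are now collected into a list and joined into a single PYTORCH_CUDA_ALLOC_CONF assignment; PyTorch expects the variable to be a comma-separated list of key:value pairs. A minimal sketch of the resulting composition (the max_split_size_mb value of 512 is illustrative, not from this commit):

    import os

    # Illustrative values: max_split_size_mb comes from the run config,
    # expandable_segments from the new flag added in this commit.
    cuda_alloc_conf = ['max_split_size_mb:512', 'expandable_segments:True']
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = ','.join(cuda_alloc_conf)

    print(os.environ['PYTORCH_CUDA_ALLOC_CONF'])
    # max_split_size_mb:512,expandable_segments:True

The CUDA caching allocator reads this variable when it initializes, so the assignment has to happen before the first CUDA allocation; setting it near the top of main() in train.py satisfies that ordering.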
1 change: 1 addition & 0 deletions scripts/train/yamls/finetune/dbrx-full-ft.yaml
@@ -88,6 +88,7 @@ device_eval_batch_size: 1
 precision: amp_bf16
 autoresume: true
 dist_timeout: 3600
+expandable_segments: true
 
 # FSDP
 fsdp_config:
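Because train.py reads the flag with cfg.pop('expandable_segments', False), it is strictly opt-in and is removed from the config before the rest of main() runs, so it never reaches the Trainer. A small sketch of that pop-with-default behavior, assuming omegaconf (which provides DictConfig):

    from omegaconf import OmegaConf

    cfg = OmegaConf.create({'expandable_segments': True, 'precision': 'amp_bf16'})

    # pop() returns the value and removes the key, mirroring train.py;
    # a config without the key falls back to the False default.
    enabled = cfg.pop('expandable_segments', False)
    print(enabled)                         # True
    print('expandable_segments' in cfg)    # False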
1 change: 1 addition & 0 deletions scripts/train/yamls/finetune/dbrx-lora-ft.yaml
@@ -96,6 +96,7 @@ device_eval_batch_size: 1
 precision: amp_bf16
 autoresume: true
 dist_timeout: 3600
+expandable_segments: true
 
 # FSDP
 fsdp_config:
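For background, expandable_segments:True is an experimental option of PyTorch's CUDA caching allocator that lets memory segments grow in place rather than being allocated at fixed sizes, which can reduce fragmentation when allocation sizes vary between steps; the PyTorch CUDA memory-management docs describe the trade-offs. With this commit both DBRX finetuning recipes enable it by default, and the same one-line flag should work in any other recipe, since train.py pops it from the top-level config. A typical launch, following the repository's usual pattern (run from the scripts/ directory; the YAML path is from this commit, the invocation style is the repo's standard one):

    composer train/train.py train/yamls/finetune/dbrx-full-ft.yaml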
