From 349c2ff0d8a5ebf86f071edbde53017ab86e58a3 Mon Sep 17 00:00:00 2001
From: Daniel King <43149077+dakinggg@users.noreply.github.com>
Date: Mon, 1 Apr 2024 10:45:10 -0700
Subject: [PATCH] Add expandable segments flag (#1075)

---
 scripts/train/train.py                         | 11 +++++++++--
 scripts/train/yamls/finetune/dbrx-full-ft.yaml |  1 +
 scripts/train/yamls/finetune/dbrx-lora-ft.yaml |  1 +
 3 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/scripts/train/train.py b/scripts/train/train.py
index 44cfc053f4..c12ca5c54b 100644
--- a/scripts/train/train.py
+++ b/scripts/train/train.py
@@ -130,11 +130,18 @@ def main(cfg: DictConfig) -> Trainer:
     # Create copy of config for logging
     logged_cfg: DictConfig = copy.deepcopy(cfg)
 
+    cuda_alloc_conf = []
     # Get max split size mb
     max_split_size_mb: Optional[int] = cfg.pop('max_split_size_mb', None)
     if max_split_size_mb is not None:
-        os.environ[
-            'PYTORCH_CUDA_ALLOC_CONF'] = f'max_split_size_mb:{max_split_size_mb}'
+        cuda_alloc_conf.append(f'max_split_size_mb:{max_split_size_mb}')
+
+    # Expandable segments
+    if cfg.pop('expandable_segments', False):
+        cuda_alloc_conf.append('expandable_segments:True')
+
+    if len(cuda_alloc_conf) > 0:
+        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = ','.join(cuda_alloc_conf)
 
     # Set CUDA lazy loading
     # This can save a bit of memory if not all modules are needed
diff --git a/scripts/train/yamls/finetune/dbrx-full-ft.yaml b/scripts/train/yamls/finetune/dbrx-full-ft.yaml
index 859f59138c..8933a2d3c8 100644
--- a/scripts/train/yamls/finetune/dbrx-full-ft.yaml
+++ b/scripts/train/yamls/finetune/dbrx-full-ft.yaml
@@ -88,6 +88,7 @@ device_eval_batch_size: 1
 precision: amp_bf16
 autoresume: true
 dist_timeout: 3600
+expandable_segments: true
 
 # FSDP
 fsdp_config:
diff --git a/scripts/train/yamls/finetune/dbrx-lora-ft.yaml b/scripts/train/yamls/finetune/dbrx-lora-ft.yaml
index 3fb4bb99e7..58ca8fae61 100644
--- a/scripts/train/yamls/finetune/dbrx-lora-ft.yaml
+++ b/scripts/train/yamls/finetune/dbrx-lora-ft.yaml
@@ -96,6 +96,7 @@ device_eval_batch_size: 1
 precision: amp_bf16
 autoresume: true
 dist_timeout: 3600
+expandable_segments: true
 
 # FSDP
 fsdp_config:
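
For reference, a minimal standalone sketch of the allocator-config assembly this patch introduces. The helper name set_cuda_alloc_conf and the plain dict cfg are illustrative stand-ins for the inline logic and the OmegaConf DictConfig used in train.py:

    import os
    from typing import Any, Dict, List, Optional

    def set_cuda_alloc_conf(cfg: Dict[str, Any]) -> None:
        """Compose PYTORCH_CUDA_ALLOC_CONF from optional config keys.

        Mirrors the patched train.py logic: recognized options are
        collected, and the env var is set only when at least one is
        present, so an existing value is never clobbered with an
        empty string.
        """
        cuda_alloc_conf: List[str] = []

        # Cap the size of cached blocks the allocator is willing to split.
        max_split_size_mb: Optional[int] = cfg.pop('max_split_size_mb', None)
        if max_split_size_mb is not None:
            cuda_alloc_conf.append(f'max_split_size_mb:{max_split_size_mb}')

        # expandable_segments lets the allocator grow existing segments
        # instead of always mapping new ones, reducing fragmentation.
        if cfg.pop('expandable_segments', False):
            cuda_alloc_conf.append('expandable_segments:True')

        if len(cuda_alloc_conf) > 0:
            os.environ['PYTORCH_CUDA_ALLOC_CONF'] = ','.join(cuda_alloc_conf)

    # With both options enabled, the options are comma-joined:
    set_cuda_alloc_conf({'max_split_size_mb': 512, 'expandable_segments': True})
    assert os.environ['PYTORCH_CUDA_ALLOC_CONF'] == (
        'max_split_size_mb:512,expandable_segments:True')

Note that PyTorch parses PYTORCH_CUDA_ALLOC_CONF when the CUDA caching allocator initializes, so the variable must be set before the first GPU allocation; that is why the patch applies it at the top of main().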