oom error handling context manager
hjc-puro committed Sep 29, 2023
1 parent 75d4f69 commit df58653
Showing 3 changed files with 28 additions and 13 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -5,3 +5,4 @@ toy/
 llama2_qlora.ipynb
 __pycache__
 outputs/
+*.swp
2 changes: 1 addition & 1 deletion configs/initial_run.yaml
@@ -1,7 +1,7 @@
 toy: false
 base_run_name: llama2-7b-qlora
 
-batch_size: 12
+batch_size: 16
 save_data_points: 5000
 gradient_accumulation_steps: 4
 model_id: meta-llama/Llama-2-7b-chat-hf
38 changes: 26 additions & 12 deletions train/trainer.py
@@ -1,12 +1,15 @@
 import os
 import sys
 import shutil
 import torch
+import torch.nn as nn
 
 from transformers import Trainer, TrainingArguments
 from transformers.utils import logging
 from transformers.trainer_utils import EvalLoopOutput
 from peft import PeftModelForCausalLM
 from train.utils import print_trainable_parameters
+from typing import Any, Dict, Union
+from contextlib import contextmanager
 
 from eval_args import EvaluationArguments

@@ -20,22 +23,19 @@ def __init__(self, *args, **kwargs):
         self.model_id = kwargs.pop('model_id')
         self.eval_args = kwargs.pop('eval_args')
         self.hf_api_token = kwargs.pop('hf_api_token')
-        self.ooms = 0
+        self.oom_count = 0
         super().__init__(*args, **kwargs)
 
     def compute_loss(self, model, inputs):
-        try:
+        with self.handle_oom():
             outputs = model(**inputs)
             return outputs.loss
-        except RuntimeError as e:
-            # https://github.com/facebookresearch/fairseq/blob/50a671f78d0c8de0392f924180db72ac9b41b801/fairseq/trainer.py#L188
-            if 'out of memory' in str(e).lower():
-                self.ooms += 1
-                print(f'| WARNING: ran out of memory, skipping batch. OOM Count: {self.ooms}')
-                torch.cuda.empty_cache()
-                return torch.tensor(0.0, requires_grad=True)
-            else:
-                raise e
+        return torch.tensor(0.0, requires_grad=True)
 
+    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
+        with self.handle_oom():
+            return super().training_step(model, inputs)
+        return torch.tensor(0.0, requires_grad=True)
+
     def evaluation_loop(self, dataloader, description, prediction_loss_only=False, **kwargs) -> EvalLoopOutput:
         '''
@@ -75,3 +75,17 @@ def evaluation_loop(self, dataloader, description, prediction_loss_only=False, **kwargs) -> EvalLoopOutput:
         # max_new_tokens=self.eval_args.max_new_tokens, top_p=self.eval_args.top_p,
         # batch_size=self.args.per_device_eval_batch_size,)
         return EvalLoopOutput(predictions=None, label_ids=None, metrics={'fake_metric': 0.0}, num_samples=0)
+
+    @contextmanager
+    def handle_oom(self):
+        # https://github.com/facebookresearch/fairseq/blob/50a671f78d0c8de0392f924180db72ac9b41b801/fairseq/trainer.py#L188
+        try:
+            yield
+        except RuntimeError as e:
+            if "out of memory" in str(e):
+                print("WARNING: Catching Out of Memory Error")
+                torch.cuda.empty_cache()
+                self.oom_count += 1
+            else:
+                raise e
+
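For reference, the pattern this commit introduces is self-contained enough to extract. Below is a minimal sketch of the same OOM-skipping context manager outside the Trainer subclass; the OOMGuard class name and the commented usage are illustrative additions, not part of the commit.

import torch
from contextlib import contextmanager

class OOMGuard:
    # Sketch of the commit's pattern: catch CUDA out-of-memory
    # RuntimeErrors so the caller can skip the batch instead of crashing.
    def __init__(self):
        self.oom_count = 0

    @contextmanager
    def handle_oom(self):
        try:
            yield
        except RuntimeError as e:
            # CUDA OOM surfaces as a RuntimeError whose message contains
            # "out of memory"; any other RuntimeError still propagates.
            if "out of memory" in str(e):
                self.oom_count += 1
                print(f"WARNING: OOM caught, skipping batch (count: {self.oom_count})")
                torch.cuda.empty_cache()
            else:
                raise

# Usage sketch (hypothetical model/inputs):
# guard = OOMGuard()
# with guard.handle_oom():
#     loss = model(**inputs).loss

Because handle_oom swallows the exception rather than re-raising it, control in compute_loss and training_step falls through to the trailing return torch.tensor(0.0, requires_grad=True), so an OOM batch contributes a zero loss instead of aborting the run.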