Commit

add oom handling
hjc-puro committed Sep 29, 2023
1 parent 18e6bed commit 58be7ef
Showing 2 changed files with 15 additions and 4 deletions.
4 changes: 2 additions & 2 deletions configs/initial_run.yaml
@@ -1,8 +1,8 @@
 toy: false
 base_run_name: llama2-7b-qlora
 
-batch_size: 12
-save_data_points: 2000
+batch_size: 10
+save_data_points: 5000
 gradient_accumulation_steps: 4
 model_id: meta-llama/Llama-2-7b-chat-hf
 fp16: false
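Note: with gradient_accumulation_steps fixed at 4, lowering the per-device batch size from 12 to 10 shrinks the effective batch size from 12 × 4 = 48 to 10 × 4 = 40 samples per optimizer step, presumably trading a little throughput for headroom against the OOMs this commit handles.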
15 changes: 13 additions & 2 deletions train/trainer.py
@@ -20,11 +20,22 @@ def __init__(self, *args, **kwargs):
         self.model_id = kwargs.pop('model_id')
         self.eval_args = kwargs.pop('eval_args')
         self.hf_api_token = kwargs.pop('hf_api_token')
+        self.ooms = 0
         super().__init__(*args, **kwargs)
 
     def compute_loss(self, model, inputs):
-        outputs = model(**inputs)
-        return outputs.loss
+        try:
+            outputs = model(**inputs)
+            return outputs.loss
+        except RuntimeError as e:
+            # https://github.com/facebookresearch/fairseq/blob/50a671f78d0c8de0392f924180db72ac9b41b801/fairseq/trainer.py#L188
+            if 'out of memory' in str(e).lower():
+                self.ooms += 1
+                print(f'| WARNING: ran out of memory, skipping batch. OOM Count: {self.ooms}')
+                torch.cuda.empty_cache()
+                return torch.tensor(0.0, requires_grad=True)
+            else:
+                raise e
 
     def evaluation_loop(self, dataloader, description, prediction_loss_only=False, **kwargs) -> EvalLoopOutput:
         '''
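Note: the dummy loss works because Hugging Face's Trainer calls backward() on whatever compute_loss returns. A scalar with requires_grad=True that shares no graph with the model backpropagates without error and produces no parameter gradients, so the OOM batch contributes nothing to the update. A minimal sketch of that behaviour, using a hypothetical stand-in model (torch.nn.Linear here is illustrative, not part of this repo):

import torch

# Hypothetical stand-in for the fine-tuned model; illustrative only.
model = torch.nn.Linear(4, 2)

# The kind of dummy loss returned on OOM: it requires grad but is
# disconnected from the model's computation graph.
dummy_loss = torch.tensor(0.0, requires_grad=True)

# backward() runs without error...
dummy_loss.backward()

# ...but no parameter gradients are produced, so the subsequent
# optimizer step (which skips params whose grad is None) leaves the
# weights untouched for this batch.
assert all(p.grad is None for p in model.parameters())

The torch.cuda.empty_cache() call releases cached, unused blocks held by PyTorch's caching allocator, giving the next batch a better chance of fitting in GPU memory.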
