diff --git a/src/llama_recipes/utils/train_utils.py b/src/llama_recipes/utils/train_utils.py
index 66dd720e5..c3790651a 100644
--- a/src/llama_recipes/utils/train_utils.py
+++ b/src/llama_recipes/utils/train_utils.py
@@ -79,6 +79,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                 loss = model(**batch).loss
                 loss = loss / gradient_accumulation_steps
                 total_loss += loss.detach().float()
+                loss = torch.autograd.Variable(loss, requires_grad=True)
                 if train_config.use_fp16:
                     # if fp16 is enabled, use gradient scaler to handle gradient update
                     scaler.scale(loss).backward()
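
For context, here is a minimal sketch of the gradient-accumulation / fp16 training-step pattern that this hunk modifies. The names gradient_accumulation_steps, train_config.use_fp16, and scaler come from the hunk itself; the function signature and the optimizer-step handling outside the hunk are assumptions for illustration, not the repository's exact code.

```python
import torch

def accumulation_step(model, batch, optimizer, scaler, train_config,
                      gradient_accumulation_steps, step, total_loss):
    # Sketch only: mirrors the pattern visible in the hunk, details assumed.
    loss = model(**batch).loss
    # Divide so gradients accumulated over N micro-batches average out.
    loss = loss / gradient_accumulation_steps
    total_loss += loss.detach().float()

    if train_config.use_fp16:
        # GradScaler scales the loss before backward to avoid fp16 gradient
        # underflow, and unscales again when stepping the optimizer.
        scaler.scale(loss).backward()
        if (step + 1) % gradient_accumulation_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
    else:
        loss.backward()
        if (step + 1) % gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
    return total_loss
```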