Skip to content

Commit

Permalink
updated finetune_model.py
Browse files Browse the repository at this point in the history
  • Loading branch information
henri123lemoine committed Aug 16, 2023
1 parent c3b3539 commit 7bd986a
Showing 1 changed file with 64 additions and 0 deletions.
64 changes: 64 additions & 0 deletions align_data/finetuning/finetune_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,67 @@ def train(model, dataloader, optimizer, criterion, device):
total_loss += loss.item()

return total_loss / len(dataloader)


def validate(model, dataloader, criterion, device):
    """Run one evaluation pass and return the mean batch loss.

    Args:
        model: module mapping an embedding batch to an output embedding.
        dataloader: yields (text1_embedding, text2_embedding, target) triples.
        criterion: loss taking (output1, output2, target) — e.g. a
            contrastive loss; exact semantics defined by the caller.
        device: torch.device the model lives on.

    Returns:
        float: total loss divided by the number of batches (0.0 for an
        empty loader).
    """
    model.eval()
    total_loss = 0.0

    with torch.no_grad():
        for text1_embedding, text2_embedding, target in dataloader:
            # Bug fix: previously only `target` was moved to `device`, so the
            # forward pass raised a device-mismatch error on CUDA. Move the
            # input embeddings as well.
            text1_embedding = text1_embedding.to(device)
            text2_embedding = text2_embedding.to(device)
            target = target.float().to(device)

            output1 = model(text1_embedding)
            output2 = model(text2_embedding)

            loss = criterion(output1, output2, target)
            total_loss += loss.item()

    # Guard against ZeroDivisionError on an empty loader.
    return total_loss / max(len(dataloader), 1)


def finetune_embeddings():
    """Fine-tune the embedding projection model with a contrastive objective.

    Trains for up to EPOCHS epochs, reduces the learning rate on a
    validation-loss plateau, checkpoints the best model seen so far to
    ``best_finetuned_model.pth``, stops early after 5 epochs without
    improvement, and finally saves the last-epoch weights to
    ``finetuned_model.pth``.
    """
    # Hyperparameters & configuration
    EMBEDDING_DIM = 1536  # input embedding width — presumably OpenAI ada-002 size; confirm
    HIDDEN_DIM = 512
    EPOCHS = 10
    LEARNING_RATE = 0.001
    MARGIN = 2.0  # margin passed to ContrastiveLoss
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataset = FinetuningDataset()
    # Bug fix: shuffle the training data each epoch; a fixed sample order
    # biases SGD and was almost certainly unintended.
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

    model = FineTuneModel(EMBEDDING_DIM, HIDDEN_DIM).to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    # Halve the LR after 2 epochs without validation-loss improvement.
    scheduler = ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.5, verbose=True)
    criterion = ContrastiveLoss(MARGIN)

    # WARNING(review): this "validation set" is a second instance of the same
    # dataset, so it overlaps the training data and the early-stopping signal
    # is unreliable. Replace with a held-out split when one is available.
    validation_dataset = FinetuningDataset()
    validation_dataloader = DataLoader(validation_dataset, batch_size=32, num_workers=4)
    best_val_loss = float('inf')

    epochs_without_improvement = 0
    max_epochs_without_improvement = 5  # early-stopping patience

    for epoch in range(EPOCHS):
        train_loss = train(model, dataloader, optimizer, criterion, DEVICE)
        validate_loss = validate(model, validation_dataloader, criterion, DEVICE)

        scheduler.step(validate_loss)
        if validate_loss < best_val_loss:
            best_val_loss = validate_loss
            # Checkpoint the best-so-far weights separately from the final save.
            torch.save(model.state_dict(), 'best_finetuned_model.pth')
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1

        print(f'Epoch: {epoch+1}/{EPOCHS} | Train Loss: {train_loss:.4f} | Validation Loss: {validate_loss:.4f}')

        if epochs_without_improvement >= max_epochs_without_improvement:
            print("Early stopping due to no improvement in validation loss.")
            break

    # Save the last-epoch model (the best checkpoint was saved above).
    torch.save(model.state_dict(), 'finetuned_model.pth')

0 comments on commit 7bd986a

Please sign in to comment.