diff --git a/align_data/finetuning/finetune_model.py b/align_data/finetuning/finetune_model.py
index 9e1e601c..aa761778 100644
--- a/align_data/finetuning/finetune_model.py
+++ b/align_data/finetuning/finetune_model.py
@@ -63,3 +63,67 @@ def train(model, dataloader, optimizer, criterion, device):
         total_loss += loss.item()
 
     return total_loss / len(dataloader)
+
+
+def validate(model, dataloader, criterion, device):
+    model.eval()
+    total_loss = 0.0
+
+    with torch.no_grad():
+        for text1_embedding, text2_embedding, target in dataloader:
+            target = target.float().to(device)
+
+            output1 = model(text1_embedding.to(device))  # move inputs to the model's device
+            output2 = model(text2_embedding.to(device))
+
+            loss = criterion(output1, output2, target)
+            total_loss += loss.item()
+
+    return total_loss / len(dataloader)
+
+
+def finetune_embeddings():
+    # Hyperparameters & configuration
+    EMBEDDING_DIM = 1536
+    HIDDEN_DIM = 512
+    EPOCHS = 10
+    LEARNING_RATE = 0.001
+    MARGIN = 2.0
+    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+    dataset = FinetuningDataset()
+    dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)  # num_workers > 0 loads batches in parallel
+
+    model = FineTuneModel(EMBEDDING_DIM, HIDDEN_DIM).to(DEVICE)
+    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
+    scheduler = ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.5, verbose=True)
+    criterion = ContrastiveLoss(MARGIN)
+
+    # Assumes the data has been split; this reuses the full dataset (see the split sketch below)
+    validation_dataset = FinetuningDataset()
+    validation_dataloader = DataLoader(validation_dataset, batch_size=32, num_workers=4)
+    best_val_loss = float('inf')
+
+    epochs_without_improvement = 0
+    max_epochs_without_improvement = 5  # stop after 5 epochs without improvement
+
+    for epoch in range(EPOCHS):
+        train_loss = train(model, dataloader, optimizer, criterion, DEVICE)
+        validate_loss = validate(model, validation_dataloader, criterion, DEVICE)
+
+        scheduler.step(validate_loss)
+        if validate_loss < best_val_loss:
+            best_val_loss = validate_loss
+            torch.save(model.state_dict(), 'best_finetuned_model.pth')
+            epochs_without_improvement = 0
+        else:
+            epochs_without_improvement += 1
+
+        print(f'Epoch: {epoch+1}/{EPOCHS} | Train Loss: {train_loss:.4f} | Validation Loss: {validate_loss:.4f}')
+
+        if epochs_without_improvement >= max_epochs_without_improvement:
+            print("Early stopping due to no improvement in validation loss.")
+            break
+
+    # Save the final model (the best checkpoint is saved separately above)
+    torch.save(model.state_dict(), 'finetuned_model.pth')
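
The hunk assumes the file already imports torch, torch.optim as optim, DataLoader, and ReduceLROnPlateau, and that ContrastiveLoss is defined elsewhere in the repo. The call signature criterion(output1, output2, target) matches the standard margin-based contrastive loss (Hadsell et al., 2006); a minimal sketch, assuming target is 1 for similar pairs and 0 for dissimilar ones:

import torch
import torch.nn as nn
import torch.nn.functional as F


class ContrastiveLoss(nn.Module):
    # Margin-based contrastive loss over pairs of projected embeddings.
    def __init__(self, margin=2.0):
        super().__init__()
        self.margin = margin

    def forward(self, output1, output2, target):
        # Euclidean distance between the two projections
        distance = F.pairwise_distance(output1, output2)
        # Similar pairs (target == 1) are pulled together; dissimilar pairs
        # (target == 0) are pushed apart until they are at least `margin` away.
        loss = (target * distance.pow(2)
                + (1 - target) * torch.clamp(self.margin - distance, min=0).pow(2))
        return loss.mean()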
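
FineTuneModel is likewise outside this hunk. Given the (EMBEDDING_DIM, HIDDEN_DIM) constructor and that it maps a single 1536-dimensional embedding to a projected vector, a plausible (hypothetical) sketch is a small MLP projection head:

import torch.nn as nn


class FineTuneModel(nn.Module):
    # Hypothetical architecture: projects raw embeddings into the space
    # where the contrastive loss is applied.
    def __init__(self, embedding_dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embedding_dim),
        )

    def forward(self, x):
        return self.net(x)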
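
As the in-code comment notes, finetune_embeddings() instantiates FinetuningDataset() twice, so training and validation currently see the same examples and the early-stopping signal is optimistic. A sketch of an actual split using torch.utils.data.random_split (the 90/10 ratio is an assumption):

from torch.utils.data import DataLoader, random_split

full_dataset = FinetuningDataset()
val_size = max(1, int(0.1 * len(full_dataset)))  # assumed 10% hold-out
train_size = len(full_dataset) - val_size
train_dataset, validation_dataset = random_split(full_dataset, [train_size, val_size])

dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
validation_dataloader = DataLoader(validation_dataset, batch_size=32, num_workers=4)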