From 804b544d9b4d229d4e6791da2ba2649de0e27083 Mon Sep 17 00:00:00 2001
From: ahoho
Date: Wed, 27 Mar 2024 21:13:07 -0400
Subject: [PATCH] cast cos similarities to cpu to speedup sort

---
 script/refinement.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/script/refinement.py b/script/refinement.py
index 9568ac8..1661a11 100644
--- a/script/refinement.py
+++ b/script/refinement.py
@@ -1,4 +1,5 @@
 import pandas as pd
+import torch
 from utils import *
 import regex
 import traceback
@@ -18,18 +19,17 @@ def topic_pairs(topic_sent, all_pairs, threshold=0.5, num_pair=2):
     - num_pair: Number of pairs to return
     """
     # Calculate cosine similarity between all pairs of sentences
-    sbert = SentenceTransformer("all-MiniLM-L6-v2")
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    sbert = SentenceTransformer("all-MiniLM-L6-v2", device=device)
     embeddings = sbert.encode(topic_sent, convert_to_tensor=True)
-    cosine_scores = util.cos_sim(embeddings, embeddings)
-    over, pairs, dups = [], [], []
+    cosine_scores = util.cos_sim(embeddings, embeddings).cpu()
+    over, pairs = [], []
     for i in range(len(cosine_scores)):
-        for j in range(len(cosine_scores)):
-            if i != j and sorted([i, j]) not in dups:
-                pairs.append({"index": [i, j], "score": cosine_scores[i][j]})
-                dups.append(sorted([i, j]))
+        for j in range(i+1, len(cosine_scores)):
+            pairs.append({"index": [i, j], "score": cosine_scores[i][j].item()})
     # Sort and choose num_pair pairs with scores higher than a certain threshold
-    pairs = sorted(pairs, key=lambda x: x["score"], reverse=True)
+    pairs = sorted(pairs, key=lambda x: x["score"], reverse=True)  # TODO: do w/ numpy argsort
     count, idx = 0, 0
     while count < num_pair and idx < len(pairs):
         i, j = pairs[idx]["index"]