Skip to content

Commit

Permalink
cast cos similarities to cpu to speedup sort
Browse files Browse the repository at this point in the history
  • Loading branch information
ahoho committed Mar 28, 2024
1 parent 00a3ac6 commit 804b544
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions script/refinement.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pandas as pd
import torch
from utils import *
import regex
import traceback
Expand All @@ -18,18 +19,17 @@ def topic_pairs(topic_sent, all_pairs, threshold=0.5, num_pair=2):
- num_pair: Number of pairs to return
"""
# Calculate cosine similarity between all pairs of sentences
sbert = SentenceTransformer("all-MiniLM-L6-v2")
device = "cuda" if torch.cuda.is_available() else "cpu"
sbert = SentenceTransformer("all-MiniLM-L6-v2", device=device)
embeddings = sbert.encode(topic_sent, convert_to_tensor=True)
cosine_scores = util.cos_sim(embeddings, embeddings)
over, pairs, dups = [], [], []
cosine_scores = util.cos_sim(embeddings, embeddings).cpu()
over, pairs = [], []
for i in range(len(cosine_scores)):
for j in range(len(cosine_scores)):
if i != j and sorted([i, j]) not in dups:
pairs.append({"index": [i, j], "score": cosine_scores[i][j]})
dups.append(sorted([i, j]))
for j in range(i+1, len(cosine_scores)):
pairs.append({"index": [i, j], "score": cosine_scores[i][j].item()})

# Sort and choose num_pair pairs with scores higher than a certain threshold
pairs = sorted(pairs, key=lambda x: x["score"], reverse=True)
pairs = sorted(pairs, key=lambda x: x["score"], reverse=True) # TODO: do w/ numpy argsort
count, idx = 0, 0
while count < num_pair and idx < len(pairs):
i, j = pairs[idx]["index"]
Expand Down

0 comments on commit 804b544

Please sign in to comment.