-
Notifications
You must be signed in to change notification settings - Fork 80
/
Copy pathsearch.py
37 lines (33 loc) · 1.16 KB
/
search.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import pandas as pd
import lotus
from lotus.models import LM, CrossEncoderReranker, SentenceTransformersRM
# Instantiate the three models (constructed in this order: LM, retrieval model,
# reranker) and register them globally with lotus so the semantic DataFrame
# operators below can find them.
# NOTE(review): presumably `rm` backs sem_index/sem_search and `reranker` backs
# the n_rerank step of sem_search — confirm against the lotus docs.
lm, rm, reranker = (
    LM(model="gpt-4o-mini"),
    SentenceTransformersRM(model="intfloat/e5-base-v2"),
    CrossEncoderReranker(model="mixedbread-ai/mxbai-rerank-large-v1"),
)
lotus.settings.configure(lm=lm, rm=rm, reranker=reranker)
# Corpus to search: fourteen course titles, stored as a single-column DataFrame
# under the "Course Name" column.
course_names = [
    "Probability and Random Processes",
    "Optimization Methods in Engineering",
    "Digital Design and Integrated Circuits",
    "Computer Security",
    "Introduction to Computer Science",
    "Introduction to Data Science",
    "Introduction to Machine Learning",
    "Introduction to Artificial Intelligence",
    "Introduction to Robotics",
    "Introduction to Computer Vision",
    "Introduction to Natural Language Processing",
    "Introduction to Reinforcement Learning",
    "Introduction to Deep Learning",
    "Introduction to Computer Networks",
]
data = {"Course Name": course_names}
df = pd.DataFrame(data)
# Step 1: index the "Course Name" column. The second argument, "index_dir", is
# presumably the directory where the index is written — TODO confirm.
indexed = df.sem_index("Course Name", "index_dir")

# Step 2: query the index. K=8 and n_rerank=4 are passed through to lotus.
# NOTE(review): presumably K candidates are retrieved and the top n_rerank are
# kept after cross-encoder reranking — verify against the sem_search docs.
df = indexed.sem_search(
    "Course Name",
    "Which course name is most related to computer security?",
    K=8,
    n_rerank=4,
)

print(df)