embeddings.py (forked from YenRaven/annoy_ltm)
# ./embeddings.py
import torch

from modules import shared
from extensions.annoy_ltm.helpers import get_device


def generate_embeddings(text, logger):
    """
    Generates an embedding vector for a given text.

    Parameters:
        text (str): The input text to generate embeddings for.
        logger (callable): A logging function taking a message and a verbosity level.

    Returns:
        np.ndarray: The generated embedding vector, mean-pooled over tokens.
    """
    input_ids = shared.tokenizer.encode(text, return_tensors="pt", add_special_tokens=False)
    input_ids = input_ids.to(get_device())  # Move input_ids to the model's device
    input_ids = input_ids.long()  # Ensure the token IDs are integers, not floats

    with torch.no_grad():
        if hasattr(shared.model.model, 'embed_tokens'):
            input_embeds = shared.model.model.embed_tokens(input_ids)
        elif hasattr(shared.model.model, 'get_input_embeddings'):
            input_embeds = shared.model.model.get_input_embeddings()(input_ids)
        elif hasattr(shared.model.model, 'model') and hasattr(shared.model.model.model, 'embed_tokens'):
            # Some wrappers nest the base model one level deeper. Reported in issue #17.
            input_embeds = shared.model.model.model.embed_tokens(input_ids)
        else:
            raise AttributeError("The model doesn't have an 'embed_tokens' or 'get_input_embeddings' method.")

    input_embeds = input_embeds.mean(dim=1).squeeze(0)  # Mean-pool over the token dimension
    result = input_embeds.cpu().numpy().flatten()  # Convert to a flat NumPy array
    logger(f"generating embeddings for text: {text}\n{result}", 5)
    return result
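

# Example usage (a minimal sketch, not part of the original file): it assumes
# a model and tokenizer have already been loaded into `shared` by the
# text-generation-webui host, and `example_logger` is a hypothetical stand-in
# for the extension's own logger, accepting a message and a verbosity level.
if __name__ == "__main__":
    def example_logger(message, level):
        # Print only messages at or below a chosen verbosity threshold.
        if level <= 5:
            print(message)

    vector = generate_embeddings("Hello, long-term memory!", example_logger)
    print(vector.shape)  # A 1-D vector sized to the model's hidden dimension.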