embed.py
from datasets import load_dataset
import numpy as np
import argparse
from FlagEmbedding import FlagModel
import torch
from transformers import AutoTokenizer, AutoModel


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset_path",
        type=str
    )
    parser.add_argument(
        "--embedding_output_path",
        type=str,
        default="embeddings.npy",
        help="Path to save the embeddings; set to None to skip saving them"
    )
    parser.add_argument(
        "--dataset_output_path",
        type=str,
        default="embedded_dataset",
        help="Path to save the embedded dataset; set to None to skip saving it"
    )
    parser.add_argument(
        "--field_to_embed",
        type=str,
        default="text"
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="BAAI/bge-large-en",
        help="Currently only BAAI/bge-large-en and BAAI/bge-large-en-v1.5 are supported"
    )
    parser.add_argument(
        "--normalize",
        action="store_true",
        help="If set, L2-normalize the embeddings"
    )
    parser.add_argument(
        "--max_length",
        type=int,
        default=512
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=256
    )
    return parser.parse_args()


class BGE:
    """Embedding model wrapper around FlagEmbedding's FlagModel."""

    def __init__(self, model_name, normalize_embeddings, max_length=512):
        # BGE models accept at most 512 tokens, so clamp the requested length.
        if max_length > 512:
            print("Specified max length is greater than 512, setting to 512")
            max_length = 512
        self.model = FlagModel(model_name, normalize_embeddings=normalize_embeddings, max_length=max_length)

    def encode(self, texts, batch_size=256):
        return self.model.encode_queries(texts, batch_size=batch_size)


class BGE_Tokenizer:
    """Raw tokenizer + model encoder with CLS pooling, used for classification."""

    def __init__(self, model_name, normalize_embeddings, max_length=512):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.model.eval()
        self.normalize_embeddings = normalize_embeddings
        self.max_length = max_length

    def encode(self, texts):
        encoded_input = self.tokenizer(
            texts, padding=True, truncation=True, return_tensors="pt", max_length=self.max_length
        )
        with torch.no_grad():
            model_output = self.model(**encoded_input)
            embeddings = model_output[0][:, 0]  # CLS token pooling
        if self.normalize_embeddings:
            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
        return embeddings
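

# Usage sketch (illustrative, not part of the original script): BGE_Tokenizer does no
# internal batching, so longer text lists can be encoded in fixed-size chunks and the
# per-chunk CLS embeddings concatenated. The chunk size of 32 is an arbitrary choice.
def _example_tokenizer_usage(texts, chunk_size=32):
    encoder = BGE_Tokenizer("BAAI/bge-large-en-v1.5", normalize_embeddings=True)
    chunks = [texts[i:i + chunk_size] for i in range(0, len(texts), chunk_size)]
    # Each encoder.encode(...) call returns a (len(chunk), hidden_size) tensor.
    return torch.cat([encoder.encode(chunk) for chunk in chunks], dim=0)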


def embed_dataset(args):
    dataset = load_dataset(args.dataset_path)["train"]
    model = BGE(args.model_name, args.normalize, args.max_length)
    embeddings = model.encode(list(dataset[args.field_to_embed]), batch_size=args.batch_size)
    if args.embedding_output_path is not None:
        np.save(args.embedding_output_path, embeddings)
    if args.dataset_output_path is not None:
        # Replace any existing "embedding" column with the freshly computed embeddings.
        if "embedding" in dataset.features:
            dataset = dataset.remove_columns("embedding")
        dataset = dataset.map(lambda example, idx: {"embedding": embeddings[idx]}, with_indices=True)
        dataset.save_to_disk(args.dataset_output_path)
    return dataset, embeddings
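

# A minimal command-line entry point, added here as a sketch rather than taken from
# the code above; the example invocation in the comment is likewise illustrative.
#
#   python embed.py --dataset_path my_dataset --field_to_embed text --normalize
if __name__ == "__main__":
    embed_dataset(get_args())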