vecscan: A Linear-scan-based High-speed Dense Vector Search Engine
- 2023.09.11 - Release v2.1.0
  - new features:
    - use float16 as the default dtype and mps as the default device in Apple's MPS environment
  - bug fixes:
    - fix import problem in vectorizer
vecscan is a dense vector search engine that performs similarity search over embedding databases in a linear, greedy way using SIMD instructions (such as AVX2, AVX512, or AMX), CUDA, or MPS through PyTorch. (Note that a GPU is extremely fast but not required; modern CPUs are more cost-effective.)
vecscan employs a simple linear-scan algorithm, so it does not suffer from the quantization errors that tend to be a problem with approximate nearest-neighbor search libraries such as faiss. vecscan makes it very easy to build vector search applications.
In vecscan, the default dtype of embedding vectors is "float16" for MPS and "bfloat16" for other devices (you can use torch.float32 for all devices instead), and the file format of the embedding database is safetensors. An embedding database holding 1 million records of 768-dimensional bfloat16 or float16 vectors occupies 1.5GB of main memory (or GPU memory). With 8 Sapphire Rapids vCPUs, VectorScanner.search() takes only 0.1 sec for similarity score calculation and sorting (1M records, 768 dims, bfloat16). Benchmarks for major CPUs and GPUs can be found in the Benchmarks section.
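As a quick sanity check, the footprint above follows directly from the record count, the dimension, and the element size:

records, dims = 1_000_000, 768
bytes_per_element = 2  # bfloat16 and float16 are 2 bytes per element
print(records * dims * bytes_per_element / 10**9)  # => 1.536, i.e. about 1.5GB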
- Intel Architecture
  - Best
    - AMX on bfloat16: Sapphire Rapids
  - Better
    - AVX512_BF16 on bfloat16: Cooper Lake, 4th gen EPYC
      - The CPUs of GCP N2 instances are "Cascade Lake or later" in the specification, but Cooper Lake actually appeared in our benchmark experiments
  - Limited
    - AVX512F/AVX2 on float32 (bfloat16 is too slow): consumer CPUs or older Xeon CPUs
- Apple Silicon
  - Best
    - MPS on float16
  - Limited
    - Without MPS, only float32 is available
- GPUs (Optional)
  - Best
    - Ampere or later on bfloat16: L4 (24GB) and L40 (48GB) are the best choice for cost performance in 2023
    - Volta or later on float16: T4 (16GB)
Install from PyPI:
$ pip install vecscan
Or install from source:
$ git clone https://github.com/megagonlabs/vecscan.git
$ cd vecscan
$ pip install -e .
If you're using GPUs with CUDA, install the CUDA build of torch by using the "--index-url" option. Example for CUDA 11.8:
$ pip install -U torch --index-url https://download.pytorch.org/whl/cu118
The latency and throughput of VectorScanner.search() depend entirely on the total FLOPS of the processor.
We recommend using the latest Xeon platform (such as a GCP C3 instance, which supports AMX), an Apple MPS device, or a CUDA GPU device (such as an NVIDIA L4) with enough memory to load the entire safetensors vector file.
If you're using the OpenAI API for embedding, you need to install the openai package and set your API key in an environment variable beforehand. See the Embedding Examples section for details.
from vecscan import VectorScanner, Vectorizer
# load safetensors file
scanner = VectorScanner.load_file("path_to_safetensors")
# for Apple MPS or CUDA devices:
# scanner = VectorScanner.load_file("path_to_safetensors", device="mps")
# scanner = VectorScanner.load_file("path_to_safetensors", device="cuda")
# use OpenAI's text-embedding-ada-002 with the environment variable "OPENAI_API_KEY"
vectorizer = Vectorizer.create(vectorizer_type="openai_api", model_path="text-embedding-ada-002")
# for float16 or float32:
# vectorizer = Vectorizer.create(vectorizer_type="openai_api", model_path="text-embedding-ada-002", vec_dtype="float16")
# vectorizer = Vectorizer.create(vectorizer_type="openai_api", model_path="text-embedding-ada-002", vec_dtype="float32")
# get query embedding
query_vec = vectorizer.vectorize(["some query text"])[0]
# execute search and get similarity scores and corresponding document ids in descending order
sim_scores, doc_ids = scanner.search(query_vec)
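The returned doc_ids are row indices into the vector file, so mapping hits back to documents is a plain lookup. A minimal sketch, assuming a hypothetical documents list whose i-th entry corresponds to row i of the safetensors file:

# documents is a hypothetical list aligned with the rows of the vector file
for score, doc_id in zip(sim_scores[:10].tolist(), doc_ids[:10].tolist()):
    print(f"{score:.4f}\t{documents[doc_id]}")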
Under conditions that have not been fully determined, calling VectorScanner.score() or VectorScanner.search() with a query_vec taken from a row near the end of a 2d matrix on mps may forcibly terminate the process with an error like error: the range subRange.start + subRange.length does not fit in dimension[1].
In such cases, you can avoid the error by cloning the Tensor as follows:
sim_scores, doc_ids = scanner.search(query_vec.clone().detach())
vecscan -+- scanner.py -----+- VectorScanner
         |                  |
         |                  +- Similarity (Enum)
         |                     +- Dot    # dot product similarity (assuming all the vectors normalized to have |v|=1)
         |                     +- Cosine # cosine similarity (general purpose)
         |                     +- L1     # Manhattan distance
         |                     +- L2     # Euclidean distance
         |
         +- vector_loader --+- VectorLoader (`convert_to_safetensors` command)
         |                  +- CsvVectorLoader
         |                  +- JsonlVectorLoader
         |                  +- BinaryVectorLoader
         |
         +- vectorizer -----+- Vectorizer (`vectorize` command)
                            +- VectorizerOpenAIAPI
                            +- VectorizerBertCLS
                            +- VectorizerSBert
An implementation of a dense vector linear search engine
from vecscan import VectorScanner
VectorScanner.load_file(cls, path, device, normalize=False, break_in=True)
- Create a VectorScanner instance and load 2d tensors into self.shards from a safetensors file
- Args:
  - path (str): path of the safetensors file to load
  - device (str): device to load the vectors onto (typically cpu, cuda, or mps)
    - default: "mps" for Apple environments, "cuda" for CUDA environments, or "cpu" for others
  - normalize (bool): normalize the norm of each vector if True
  - break_in (bool): execute a break-in run after loading the entire vectors
- Returns:
  - VectorScanner: new VectorScanner instance
VectorScanner(shards)
- Args:
  - shards (List[torch.Tensor]): a List of 2d Tensor instances which store the search target dense vectors
    - shape: all elements must have the same dim (the size of each element must be < 2**32 bytes)
    - dtype: all elements must have the same dtype
    - device: all elements must be on the same device
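For illustration, a sketch of constructing a scanner directly from an in-memory Tensor that satisfies the constraints above:

import torch
from vecscan import VectorScanner

# a single shard of 10,000 random 768-dim bfloat16 vectors
shard = torch.randn(10_000, 768).to(torch.bfloat16)
scanner = VectorScanner([shard])
print(len(scanner), scanner.shape)  # 10000 rows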
score(self, query_vector, similarity_func=Similarity.Dot)
- Calculate the similarity scores between the stored vectors and query_vector by using similarity_func
- Args:
  - query_vector (torch.Tensor): a dense 1d Tensor instance which stores the embedding vector of the query text
    - shape: [dim]
    - dtype: same as the elements of self.shards
    - device: same as the elements of self.shards
  - similarity_func (Callable[[Tensor, Tensor], Tensor]): a Callable which calculates the similarities between the target dense matrix and the query vector
    - shapes: arg1=[records, dim], arg2=[dim], return=[records]
    - dtype: same as the elements of self.shards
    - device: same as the elements of self.shards
    - default: Similarity.Dot - dot product similarity (assuming all the vectors are normalized to have |v|=1)
- Returns:
  - torch.Tensor: a dense 1d Tensor instance which stores the similarity scores
    - shape: [records]
    - dtype: same as the elements of self.shards
    - device: same as the elements of self.shards
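For example, if your vectors are not normalized, you can score with the general-purpose cosine similarity instead of the default (a sketch based on the signature above):

from vecscan import Similarity

# unsorted similarity scores for every record; shape: [records]
scores = scanner.score(query_vec, similarity_func=Similarity.Cosine)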
search(self, query_vector, target_ids=None, n_best=1000, similarity_func=Similarity.Dot)
- Sort the result of score() and then apply the target_ids filter
- Args:
  - query_vector (torch.Tensor): a dense 1d Tensor instance which stores the embedding vector of the query text
    - shape: [dim]
    - dtype: same as the elements of self.shards
    - device: same as the elements of self.shards
  - target_ids (Optional[Union[List[int], Tensor]]): if specified, the search target is limited to the records included in target_ids
    - default: None
  - n_best (Optional[int]): if specified, the search result list is limited to the n_best entries
    - default: 1000
  - similarity_func (Callable[[Tensor, Tensor], Tensor]): a Callable which calculates the similarities between the target dense matrix and the query vector
    - shapes: arg1=[records, dim], arg2=[dim], return=[records]
    - dtype: same as the elements of self.shards
    - device: same as the elements of self.shards
    - default: Similarity.Dot - dot product similarity (assuming all the vectors are normalized to have |v|=1)
- Returns:
  - Tuple[Tensor, Tensor]: a Tuple which contains the search results (sorted_scores, doc_ids)
    - shapes: [n_best] or [records]
    - dtypes: sorted_scores=self.shards[0].dtype, doc_ids=torch.int64
    - device: same as the elements of self.shards
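A sketch of a filtered search, assuming a hypothetical candidate_ids subset produced by some metadata pre-filter:

candidate_ids = [0, 5, 42, 123]  # hypothetical row ids from a metadata filter
sim_scores, doc_ids = scanner.search(query_vec, target_ids=candidate_ids, n_best=100)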
save_file(self, path)
- Save the vectors to a new safetensors file
- Args:
- path (str): path for new safetensors file
__len__(self)
- Returns:
  - int: number of total rows in self.shards
__getitem__(self, index: int)
- Args:
  - index (int): index over all rows in self.shards
- Returns:
  - Tensor: a row vector in self.shards
    - shape: [dim]
    - dtype: same as the elements of self.shards
    - device: same as the elements of self.shards
dtype(self)
- Returns:
  - dtype of the Tensor instances in self.shards
device(self)
- Returns:
  - device of the Tensor instances in self.shards
shape(self)
- Returns:
  - shape of the entire vectors in self.shards
to(self, dst: Any)
- Apply to(dst) to all the Tensor instances in self.shards
- Args:
  - dst: dtype or device
- Returns: self
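For example (a sketch): to() accepts either a device or a dtype, mirroring torch.Tensor.to():

import torch

scanner.to("cuda")         # move every shard to a CUDA device
scanner.to(torch.float32)  # or cast every shard to another dtype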
Enum for similarity functions.
from vecscan import Similarity
Similarity.Dot
- Dot product similarity (assuming all the vectors are normalized to have |v|=1)
- This implementation avoids the accuracy degradation in torch.mv() that occurs when the number of elements in a bfloat16 or float16 matrix on CPUs is 4096 or less
Similarity.Cosine
- Cosine similarity using torch.nn.functional.cosine_similarity()
Similarity.L1
- L1 norm (Manhattan distance) using (m - v).abs().sum(dim=1) (norm(ord=1) has an accuracy degradation problem for bfloat16 and float16)
Similarity.L2
- L2 norm (Euclidean distance) using torch.linalg.norm(m - v, dim=1, ord=2)
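Since similarity_func is just a Callable[[Tensor, Tensor], Tensor], you can also pass your own function. A minimal sketch of a negated L2 distance (negated so that larger still means more similar when search() sorts in descending order):

import torch

def neg_l2(m: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # m: [records, dim], v: [dim] -> [records]
    return -torch.linalg.norm(m - v, dim=1, ord=2)

sim_scores, doc_ids = scanner.search(query_vec, similarity_func=neg_l2)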
An abstract class for loading embedding vectors from file
from vecscan import VectorLoader
create(cls, input_format, vec_dim, normalize, safetensors_dtype, shard_size, **kwargs)
- Create an instance of a VectorLoader implementation class
- Args:
  - input_format (str): a value in [csv, jsonl, binary]
  - vec_dim (Optional[int]): vector dimension
    - default: None (determined from the inputs)
  - normalize (bool): normalize all the vectors to have |v|=1 if True
    - default: False
  - safetensors_dtype (Any): dtype of the output Tensor instances
    - default: "bfloat16"
  - shard_size (int): maximum size of each shard in the safetensors file
    - default: 2**32 (in bytes)
  - kwargs: keyword arguments for the VectorLoader implementation class
- Returns:
  - VectorLoader: new instance of a VectorLoader implementation class
VectorLoader(input_format, vec_dim, normalize, safetensors_dtype, shard_size, **kwargs)
- Constructor
- Args:
  - vec_dim (Optional[int]): vector dimension
    - default: None (determined from the inputs)
  - normalize (bool): normalize all the vectors to have |v|=1 if True
    - default: False
  - safetensors_dtype (Any): dtype of the output Tensor instances
    - default: "bfloat16"
  - shard_size (int): maximum size of each shard in the safetensors file
    - default: 2**32 (in bytes)
  - kwargs: keyword arguments for the VectorLoader implementation class
create_vector_scanner(self, fin)
- Creates a VectorScanner instance from the given input file
- Args:
  - fin (IO): input file
- Returns:
  - VectorScanner: new VectorScanner instance
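A sketch of the programmatic counterpart of the convert_to_safetensors command, using only the arguments documented above (the exact file mode expected by each loader is an assumption here):

from vecscan import VectorLoader

loader = VectorLoader.create(
    input_format="jsonl",          # one of: csv, jsonl, binary
    vec_dim=None,                  # determined from the inputs
    normalize=False,
    safetensors_dtype="bfloat16",
    shard_size=2**32,
)
with open("sample_list.jsonl") as fin:  # binary inputs would need mode "rb"
    scanner = loader.create_vector_scanner(fin)
scanner.save_file("sample_list.jsonl.safetensors")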
load_shard(self, fin)
- Prototype method for loading single shard from input file
- Args:
- fin (IO): input file
- Returns:
- Optional[Tensor]: a Tensor instance if one or more records exist, or None at the end of the file
CsvVectorLoader
- Converts CSV lines to safetensors
- Skips the first line if skip_first_line is True
JsonlVectorLoader
- Converts JSONL lines to safetensors
- default: each line has a list element consisting of float values
- specify target_field if each line has a dict element and the embedding field is directly under the root dict
BinaryVectorLoader
- Converts binary float values to safetensors
- vec_dim and input_dtype must be specified to determine the byte length of a row
You can convert existing embedding data to vecscan's safetensors format by running the convert_to_safetensors command.
$ convert_to_safetensors -h
usage: convert_to_safetensors [-h] -f {csv,jsonl,binary} -o OUTPUT_SAFETENSORS_PATH [-d VEC_DIM] [-n] [-s] [-t TARGET_FIELD] [--input_dtype INPUT_DTYPE] [--safetensors_dtype SAFETENSORS_DTYPE] [--shard_size SHARD_SIZE] [-v]
A tool converting vector data to a safetensors file.
optional arguments:
-h, --help show this help message and exit
-f {csv,jsonl,binary}, --input_format {csv,jsonl,binary}
-o OUTPUT_SAFETENSORS_PATH, --output_safetensors_path OUTPUT_SAFETENSORS_PATH
-d VEC_DIM, --vec_dim VEC_DIM
-n, --normalize
-s, --skip_first_line
-t TARGET_FIELD, --target_field TARGET_FIELD
--input_dtype INPUT_DTYPE
--safetensors_dtype SAFETENSORS_DTYPE
--shard_size SHARD_SIZE
-v, --verbose
Output bfloat16 safetensors (for cpu or cuda):
$ cat << EOS > sample.csv
1.0, 0.0, 0.0, 0.0
0.0, 1.0, 0.0, 0.0
EOS
$ convert_to_safetensors -f csv -o sample.csv.safetensors < sample.csv
2023-09-08 01:29:51,613 INFO:vecscan: convert sys.stdin to sample.csv.safetensors
2023-09-08 01:29:51,613 INFO:vecscan: input_format=csv, input_dtype=float32, vec_dim=None, target_field=None, normalize=False, safetensors_dtype=bfloat16, shard_size=4294967296
2023-09-08 01:29:51,613 INFO:vecscan: 4 records converted
Output float16 safetensors (for mps or cuda):
$ convert_to_safetensors -f csv -o sample.csv.safetensors --safetensors_dtype float16 < sample.csv
If CSV has a title row:
$ convert_to_safetensors -f csv -o sample.csv.safetensors --skip_first_line < sample.csv
$ cat << EOS > sample_list.jsonl
[1.0, 0.0, 0.0, 0.0]
[0.0, 1.0, 0.0, 0.0]
EOS
$ convert_to_safetensors -f jsonl -o sample_list.jsonl.safetensors < sample_list.jsonl
2023-09-08 01:38:47,905 INFO:vecscan: convert sys.stdin to sample_list.jsonl.safetensors
2023-09-08 01:38:47,905 INFO:vecscan: input_format=jsonl, input_dtype=float32, vec_dim=None, target_field=None, normalize=False, safetensors_dtype=bfloat16, shard_size=4294967296
2023-09-08 01:38:47,906 INFO:vecscan: 2 records converted
$ cat << EOS > sample_dict.jsonl
{"vec": [1.0, 0.0, 0.0, 0.0]}
{"vec": [0.0, 1.0, 0.0, 0.0]}
EOS
$ convert_to_safetensors -f jsonl -t vec -o sample_dict.jsonl.safetensors < sample_dict.jsonl
2023-09-08 01:41:21,840 INFO:vecscan: convert sys.stdin to sample_dict.jsonl.safetensors
2023-09-08 01:41:21,840 INFO:vecscan: input_format=jsonl, input_dtype=float32, vec_dim=None, target_field=vec, normalize=False, safetensors_dtype=bfloat16, shard_size=4294967296
2023-09-08 01:41:21,840 INFO:vecscan: 2 records converted
$ python -c 'import sys; import numpy; m = numpy.array([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]], numpy.float32); m.tofile(sys.stdout.buffer)' > sample.vec
$ convert_to_safetensors -f binary -d 4 --input_dtype float32 -o sample.vec.safetensors < sample.vec
2023-09-08 01:50:24,489 INFO:vecscan: convert sys.stdin to sample.vec.safetensors
2023-09-08 01:50:24,489 INFO:vecscan: input_format=binary, input_dtype=float32, vec_dim=4, target_field=None, normalize=False, safetensors_dtype=bfloat16, shard_size=4294967296
2023-09-08 01:50:24,489 INFO:vecscan: 2 records converted
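You can verify any converted file by loading it back with the scanner (a quick sketch):

from vecscan import VectorScanner

scanner = VectorScanner.load_file("sample.vec.safetensors", device="cpu")
print(scanner.shape, scanner.dtype)  # expect 2 rows x 4 dims, bfloat16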
An abstract class for extracting embedding vectors from transformer models
from vecscan import Vectorizer
Vectorizer.create(cls, vectorizer_type, model_path, vec_dtype, **kwargs)
- Creates an instance of a Vectorizer implementation class
- Args:
  - vectorizer_type (str): a value in [openai_api, bert_cls, sbert]
  - model_path (str): path of the model directory or the model name in the API
  - vec_dtype (str): dtype string of the output Tensor
    - default: bfloat16
  - kwargs: keyword arguments for each Vectorizer implementation class
- Returns:
  - Vectorizer: new instance of a Vectorizer implementation class
Vectorizer(vectorizer_type, model_path, batch_size, vec_dim, vec_dtype, device, **kwargs)
- Constructor
- Args:
  - vectorizer_type (str): a value in [openai_api, bert_cls, sbert]
  - model_path (str): path of the model directory or the model name in the API
  - batch_size (int): batch size
  - vec_dim (int): vector dimension
  - vec_dtype (str): dtype string of the output Tensor
    - default: bfloat16
  - kwargs: not used
vectorize(self, batch)
- Prototype method for extracting embedding vectors
- Args:
  - batch (List[str]): list of texts to embed
- Returns:
  - torch.Tensor
    - shape: [len(batch), dim]
    - dtype: vec_dtype
    - device: depends on the Vectorizer implementation class
vectorize_file(self, fin, fout)
- Extracts embedding vectors for an entire file
- Args:
  - fin (IO): text input
  - fout (IO): binary output
- Returns:
  - int: number of embedded lines
VectorizerOpenAIAPI
- Obtains text embeddings using the OpenAI API
VectorizerBertCLS
- Extracts the [CLS] embedding of the specified hidden_layer from a BERT-style transformer
  - default: hidden_layer=3
  - In Self-Guided Contrastive Learning for BERT Sentence Representations, the authors reported that the performance of the [CLS] token embedding varies greatly depending on the layer used.
VectorizerSBert
- Extracts embeddings using sentence-transformers
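These implementations can also be created programmatically via Vectorizer.create(), as in the quickstart; a sketch for the BERT [CLS] vectorizer (hidden_layer is the VectorizerBertCLS-specific keyword argument described above):

from vecscan import Vectorizer

vectorizer = Vectorizer.create(
    vectorizer_type="bert_cls",
    model_path="cl-tohoku/bert-base-japanese-v3",
    vec_dtype="bfloat16",
    hidden_layer=3,  # which hidden layer's [CLS] embedding to extract
)
vecs = vectorizer.vectorize(["some text", "another text"])  # shape: [2, dim]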
You can embed entire text lines in a file by running the vectorize command.
$ vectorize -h
usage: vectorize [-h] -o OUTPUT_BASE_PATH -t {openai_api,bert_cls,sbert} -m MODEL_PATH [--vec_dtype VEC_DTYPE] [--safetensors_dtype SAFETENSORS_DTYPE] [-r] [--batch_size BATCH_SIZE] [--vec_dim VEC_DIM] [--device DEVICE] [--max_retry MAX_RETRY] [--max_length MAX_LENGTH] [--hidden_layer HIDDEN_LAYER] [--normalize] [-v]
Vectorize text lines from stdin.
optional arguments:
-h, --help show this help message and exit
-o OUTPUT_BASE_PATH, --output_base_path OUTPUT_BASE_PATH
-t {openai_api,bert_cls,sbert}, --vectorizer_type {openai_api,bert_cls,sbert}
-m MODEL_PATH, --model_path MODEL_PATH
--vec_dtype VEC_DTYPE
--safetensors_dtype SAFETENSORS_DTYPE
-r, --remove_vec_file
--batch_size BATCH_SIZE
--vec_dim VEC_DIM
--device DEVICE
--max_retry MAX_RETRY
--max_length MAX_LENGTH
--hidden_layer HIDDEN_LAYER
--normalize
-v, --verbose
Install the openai package:
$ pip install openai
Then set your API key and run the `vectorize` command:
$ export HISTCONTROL=ignorespace # do not save blankspace-started commands to history
$ export OPENAI_API_KEY=xxxx # get secret key from https://platform.openai.com/account/api-keys
$ vectorize -t openai_api -m text-embedding-ada-002 -o ada-002 < input.txt
2023-08-29 07:28:44,514 INFO:__main__: Will create following files:
2023-08-29 07:28:44,514 INFO:__main__: ada-002.vec
2023-08-29 07:28:44,514 INFO:__main__: ada-002.vec.info
2023-08-29 07:28:44,514 INFO:__main__: ada-002.safetensors
2023-08-29 07:28:44,514 INFO:__main__: embedding started
1000it [00:03, 293.99it/s]
2023-08-29 07:28:48,702 INFO:__main__: {
"vec_dim": 1536,
"vec_count": 1000,
"vec_dtype": "float32",
"vectorizer_type": "openai_api",
"model_path": "text-embedding-ada-002"
}
2023-08-29 07:28:48,702 INFO:__main__: embedding finished
2023-08-29 07:28:48,702 INFO:__main__: convert to safetensors
2023-08-29 07:28:48,718 INFO:__main__: convert finished
2023-08-29 07:28:48,719 INFO:__main__: ada-002.vec removed
For other dtypes, add the --safetensors_dtype option to vectorize:
$ vectorize -t openai_api -m text-embedding-ada-002 -o ada-002 --safetensors_dtype float16 < input.txt
You need to use GPUs to embed text with BERT-like transformer models.
Install transformers and the tokenizer packages required by cl-tohoku/bert-base-japanese-v3:
$ pip install transformers fugashi unidic-lite
Then run the `vectorize` command:
$ vectorize -t bert_cls -m cl-tohoku/bert-base-japanese-v3 -o bert-base-japanese-v3 < input.txt
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at cl-tohoku/bert-base-japanese-v3 and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
2023-08-29 07:26:04,673 INFO:__main__: Will create following files:
2023-08-29 07:26:04,673 INFO:__main__: bert-base-japanese-v3.vec
2023-08-29 07:26:04,673 INFO:__main__: bert-base-japanese-v3.vec.info
2023-08-29 07:26:04,673 INFO:__main__: bert-base-japanese-v3.safetensors
2023-08-29 07:26:04,673 INFO:__main__: embedding started
1000it [00:04, 240.00it/s]
2023-08-29 07:26:11,736 INFO:__main__: {
"vec_dim": 768,
"vec_count": 1000,
"vec_dtype": "float32",
"vectorizer_type": "bert_cls",
"model_path": "cl-tohoku/bert-base-japanese-v3"
}
2023-08-29 07:26:11,736 INFO:__main__: embedding finished
2023-08-29 07:26:11,739 INFO:__main__: convert to safetensors
2023-08-29 07:26:11,750 INFO:__main__: convert finished
2023-08-29 07:26:11,751 INFO:__main__: bert-base-japanese-v3.vec removed
Install the sentence-transformers package:
$ pip install transformers sentence-transformers
Then run the `vectorize` command:
$ vectorize -t sbert -m path_to_sbert_model -o sbert < input.txt
2023-08-29 07:26:53,544 INFO:__main__: Will create following files:
2023-08-29 07:26:53,544 INFO:__main__: sbert.vec
2023-08-29 07:26:53,544 INFO:__main__: sbert.vec.info
2023-08-29 07:26:53,544 INFO:__main__: sbert.safetensors
2023-08-29 07:26:53,544 INFO:__main__: embedding started
1000it [00:02, 342.23it/s]
2023-08-29 07:26:56,757 INFO:__main__: {
"vec_dim": 768,
"vec_count": 1000,
"vec_dtype": "float32",
"vectorizer_type": "sbert",
"model_path": "hysb_poor_mans_finetuned_posi/"
}
2023-08-29 07:26:56,757 INFO:__main__: embedding finished
2023-08-29 07:26:56,757 INFO:__main__: convert to safetensors
2023-08-29 07:26:56,768 INFO:__main__: convert finished
2023-08-29 07:26:56,769 INFO:__main__: sbert.vec removed
- GCP us-central1-b
- balanced persistent disk 100GB
- ubuntu 22.04
- cuda 11.8 (for GPUs)
- python 3.10.12
- torch 2.0.1
- vectors
  - 768 dimensions x 13,046,560 records = 10,019,758,080 elements
  - bfloat16 or float16: 20.04 GB
  - float32: 40.08 GB
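These sizes follow directly from the element count (a quick check in Python):

elements = 768 * 13_046_560  # = 10,019,758,080
print(elements * 2 / 10**9)  # bfloat16/float16: 20.04 GB
print(elements * 4 / 10**9)  # float32: 40.08 GB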
| | GCP Instance | RAM | Cost / Month | score() bf16/f16 [sec] | search() bf16/f16 [sec] | search() w/ target_ids bf16/f16 [sec] | score() f32 [sec] | search() f32 [sec] | search() w/ target_ids f32 [sec] |
|---|---|---|---|---|---|---|---|---|---|
| L4 GPU x 1 | g2-standard-8 | 24GB (CUDA) | $633 | 2.7e-4 | 5.3e-4 | 0.695 | - | - | - |
| A100 GPU x 1 | a2-highgpu-1g | 40GB (CUDA) | $2,692 | 2.6e-4 | 5.6e-4 | 0.696 | 2.5e-4 | 6.0e-4 | 0.697 |
| Apple M1 Max 64GB | apple-m1-max | 64GB (MPS) | - | 8.8e-4 | 1.3e-3 | 0.263 | 1.6e-3 | 2.0e-3 | 0.274 |
| Sapphire Rapids (SR) #1 | c3-highmem-4 | 32GB | $216 | 1.072 | 1.802 | 1.994 | - | - | - |
| SR #2 | c3-standard-8 | 32GB | $315 | 0.533 | 1.217 | 1.413 | - | - | - |
| SR #3 | c3-highmem-8 | 64GB | $421 | 0.531 | 1.209 | 1.398 | 0.852 | 2.386 | 2.117 |
| SR #4 | c3-highcpu-22 | 44GB | $702 | 0.231 | 0.887 | 1.077 | 0.392 | 1.948 | 1.695 |
| SR #5 | c3-highcpu-44 | 88GB | $1,394 | 0.174 | 0.829 | 1.033 | 0.356 | 1.900 | 1.644 |
| Cooper Lake (CL) #1 | n2-highmem-4 | 32GB | $163 | 1.250 | 2.029 | 2.217 | - | - | - |
| CL #2 | n2-standard-8 | 32GB | $237 | 0.643 | 1.388 | 1.671 | - | - | - |
| CL #3 | n2-highcpu-32 | 32GB | $702 | 0.259 | 0.969 | 1.196 | - | - | - |
| CL #4 | n2-highmem-8 | 64GB | $316 | 0.686 | 1.422 | 1.628 | 0.923 | 2.410 | 2.255 |
| CL #5 | n2-standard-16 | 64GB | $464 | 0.375 | 1.084 | 1.307 | 0.508 | 1.967 | 1.820 |
| CL #6 | n2-highcpu-48 | 48GB | $1,015 | 0.209 | 0.916 | 1.161 | 0.370 | 1.878 | 1.743 |
| Haswell (HW) #1 | n1-highmem-8 | 52GB | $251 | 62.317 | 63.095 | 63.461 | 0.876 | 2.760 | 2.727 |
| HW #2 | n1-standard-16 | 60GB | $398 | 62.218 | 63.048 | 63.397 | 0.530 | 2.365 | 2.298 |
| HW #3 | n1-highcpu-64 | 57GB | $1,169 | 62.141 | 63.026 | 63.818 | 0.530 | 2.325 | 2.280 |