Skip to content

Commit

Permalink
Fix style
Browse files Browse the repository at this point in the history
  • Loading branch information
tfeher committed Nov 7, 2023
1 parent 1d9bbb3 commit 1b73fc7
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,39 +14,34 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import rmm

pool = rmm.mr.PoolMemoryResource(rmm.mr.CudaMemoryResource(), initial_pool_size=2**30)
rmm.mr.set_current_device_resource(pool)
from rmm.allocators.cupy import rmm_cupy_allocator
import cupy as cp

cp.cuda.set_allocator(rmm_cupy_allocator)

import argparse
import os

import cupy as cp
import numpy as np
import math
from timeit import default_timer as timer
from pylibraft.neighbors.brute_force import knn
import rmm
from pylibraft.common import DeviceResources

from utils import dtype_from_filename, suffix_from_dtype, memmap_bin_file, write_bin
from pylibraft.neighbors.brute_force import knn
from rmm.allocators.cupy import rmm_cupy_allocator
from utils import memmap_bin_file, suffix_from_dtype, write_bin


def generate_random_queries(n_queries, n_features, dtype=np.float32):
    """Create a random query matrix with CuPy.

    Parameters
    ----------
    n_queries : int
        Number of query vectors (rows) to generate.
    n_features : int
        Number of features (columns) per query vector.
    dtype : numpy dtype, optional
        Element type of the result. Integer dtypes are filled with random
        integers starting at 0; any other dtype is filled with uniform
        random numbers that are then cast to ``dtype``.

    Returns
    -------
    cupy.ndarray
        Array of shape ``(n_queries, n_features)`` with the requested dtype.
    """
    print("Generating random queries")
    shape = (n_queries, n_features)
    if np.issubdtype(dtype, np.integer):
        # The exclusive upper bound was hard-coded to 255, which is out of
        # range for narrower integer types such as int8 and makes randint
        # reject the call. Clamp it to the dtype's maximum; dtypes that can
        # hold 255 keep exactly the same [0, 255) range as before.
        upper = min(255, int(np.iinfo(dtype).max))
        return cp.random.randint(0, upper, size=shape, dtype=dtype)
    return cp.random.uniform(size=shape).astype(dtype)


def choose_random_queries(dataset, n_queries):
    """Pick ``n_queries`` distinct rows of ``dataset`` to use as queries.

    Row indices are drawn without replacement using NumPy's global random
    state, and the selected rows are returned in the order they were drawn.

    Parameters
    ----------
    dataset : array-like with ``shape`` and 2-D fancy indexing
        Source vectors, one per row.
    n_queries : int
        Number of rows to select.

    Returns
    -------
    The rows of ``dataset`` at the sampled indices.
    """
    print("Choosing random vector from dataset as query vectors")
    n_rows = dataset.shape[0]
    # replace=False guarantees no row is selected twice.
    picked_rows = np.random.choice(n_rows, size=(n_queries,), replace=False)
    return dataset[picked_rows, :]


Expand All @@ -65,7 +60,7 @@ def calc_truth(dataset, queries, k, metric="sqeuclidean"):

X = cp.asarray(dataset[i : i + n_batch, :], cp.float32)

D, I = knn(
D, Ind = knn(
X,
queries,
k,
Expand All @@ -75,13 +70,13 @@ def calc_truth(dataset, queries, k, metric="sqeuclidean"):
)
handle.sync()

D, I = cp.asarray(D), cp.asarray(I)
D, Ind = cp.asarray(D), cp.asarray(Ind)
if distances is None:
distances = D
indices = I
indices = Ind
else:
distances = cp.concatenate([distances, D], axis=1)
indices = cp.concatenate([indices, I], axis=1)
indices = cp.concatenate([indices, Ind], axis=1)
idx = cp.argsort(distances, axis=1)[:, :k]
distances = cp.take_along_axis(distances, idx, axis=1)
indices = cp.take_along_axis(indices, idx, axis=1)
Expand All @@ -92,19 +87,30 @@ def calc_truth(dataset, queries, k, metric="sqeuclidean"):


if __name__ == "__main__":
pool = rmm.mr.PoolMemoryResource(
rmm.mr.CudaMemoryResource(), initial_pool_size=2**30
)
rmm.mr.set_current_device_resource(pool)
cp.cuda.set_allocator(rmm_cupy_allocator)

parser = argparse.ArgumentParser(
prog="generate_groundtruth",
description="Generate true neighbors using exact NN search. "
"The input and output files are in big-ann-benchmark's binary format.",
epilog="""Example usage
# With existing query file
python generate_groundtruth.py --dataset /dataset/base.1B.fbin --output=groundtruth_dir --queries=/dataset/query.public.10K.fbin
python generate_groundtruth.py --dataset /dataset/base.1B.fbin \
--output=groundtruth_dir --queries=/dataset/query.public.10K.fbin
# With randomly generated queries
python generate_groundtruth.py --dataset /dataset/base.1B.fbin --output=groundtruth_dir --queries=random --n_queries=10000
# Using only a subset of the dataset. Define queries by randomly selecting vectors from the (subset of the) dataset.
python generate_groundtruth.py --dataset /dataset/base.1B.fbin --nrows=2000000 --cols=128 --output=groundtruth_dir --queries=random-choice --n_queries=10000
python generate_groundtruth.py --dataset /dataset/base.1B.fbin \
--output=groundtruth_dir --queries=random --n_queries=10000
# Using only a subset of the dataset. Define queries by randomly
# selecting vectors from the (subset of the) dataset.
python generate_groundtruth.py --dataset /dataset/base.1B.fbin \
--nrows=2000000 --cols=128 --output=groundtruth_dir \
--queries=random-choice --n_queries=10000
""",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
Expand All @@ -114,9 +120,9 @@ def calc_truth(dataset, queries, k, metric="sqeuclidean"):
"--queries",
type=str,
default="random",
help="Queries file name, or one of 'random-choice' or 'random' (default). "
"'random-choice': select n_queries vectors from the input dataset. "
"'random': generate n_queries as uniform random numbers.",
help="Queries file name, or one of 'random-choice' or 'random' "
"(default). 'random-choice': select n_queries vectors from the input "
"dataset. 'random': generate n_queries as uniform random numbers.",
)
parser.add_argument(
"--output",
Expand All @@ -129,31 +135,37 @@ def calc_truth(dataset, queries, k, metric="sqeuclidean"):
"--n_queries",
type=int,
default=10000,
help="Number of quries to generate (if no query file is given). Default: 10000.",
help="Number of quries to generate (if no query file is given). "
"Default: 10000.",
)

parser.add_argument(
"-N",
"--rows",
default=0,
type=int,
help="use only first N rows from dataset, by default the whole dataset is used",
help="use only first N rows from dataset, by default the whole "
"dataset is used",
)
parser.add_argument(
"-D",
"--cols",
default=0,
type=int,
help="number of features (dataset columns). Must be specified if --rows is used. Default: read from dataset file.",
help="number of features (dataset columns). Must be specified if "
"--rows is used. Default: read from dataset file.",
)
parser.add_argument(
"--dtype",
type=str,
help="Dataset dtype. If not given, then derived from filename extension.",
help="Dataset dtype. When not specified, then derived from extension.",
)

parser.add_argument(
"-k", type=int, default=100, help="Number of neighbors (per query) to calculate"
"-k",
type=int,
default=100,
help="Number of neighbors (per query) to calculate",
)
parser.add_argument(
"--metric",
Expand Down Expand Up @@ -195,9 +207,13 @@ def calc_truth(dataset, queries, k, metric="sqeuclidean"):

if args.queries == "random" or args.queries == "random-choice":
if args.n_queries is None:
raise RuntimeError("n_queries must be given to generate random queries")
raise RuntimeError(
"n_queries must be given to generate random queries"
)
if args.queries == "random":
queries = generate_random_queries(args.n_queries, n_features, dtype)
queries = generate_random_queries(
args.n_queries, n_features, dtype
)
elif args.queries == "random-choice":
queries = choose_random_queries(dataset, args.n_queries)

Expand Down
23 changes: 16 additions & 7 deletions python/raft-ann-bench/src/raft-ann-bench/scripts/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
# limitations under the License.
#

import numpy as np
import cupy as cp
import time
import os
import time

import cupy as cp
import numpy as np


def dtype_from_filename(filename):
Expand Down Expand Up @@ -47,7 +48,9 @@ def suffix_from_dtype(dtype):
raise RuntimeError("Not supported dtype extension" + dtype)


def memmap_bin_file(bin_file, dtype, shape=None, mode="r", size_dtype=np.uint32):
def memmap_bin_file(
bin_file, dtype, shape=None, mode="r", size_dtype=np.uint32
):
extent_itemsize = np.dtype(size_dtype).itemsize
offset = int(extent_itemsize) * 2
if bin_file is None:
Expand All @@ -60,7 +63,9 @@ def memmap_bin_file(bin_file, dtype, shape=None, mode="r", size_dtype=np.uint32)
if shape is None:
shape = (a[0], a[1])
print("Read shape from file", shape)
return np.memmap(bin_file, mode=mode, dtype=dtype, offset=offset, shape=shape)
return np.memmap(
bin_file, mode=mode, dtype=dtype, offset=offset, shape=shape
)
elif mode[0] == "w":
if shape is None:
raise ValueError("Need to specify shape to map file in write mode")
Expand All @@ -74,7 +79,9 @@ def memmap_bin_file(bin_file, dtype, shape=None, mode="r", size_dtype=np.uint32)
a[1] = shape[1]
a.flush()
del a
fp = np.memmap(bin_file, mode="r+", dtype=dtype, offset=offset, shape=shape)
fp = np.memmap(
bin_file, mode="r+", dtype=dtype, offset=offset, shape=shape
)
return fp

# print('# {}: shape: {}, dtype: {}'.format(bin_file, shape, dtype))
Expand All @@ -92,7 +99,9 @@ def calc_recall(ann_idx, true_nn_idx):
ann_idx = cp.asnumpy(ann_idx)
if ann_idx.shape != true_nn_idx.shape:
raise RuntimeError(
"Incompatible shapes {} vs {}".format(ann_idx.shape, true_nn_idx.shape)
"Incompatible shapes {} vs {}".format(
ann_idx.shape, true_nn_idx.shape
)
)
n = 0
for i in range(ann_idx.shape[0]):
Expand Down

0 comments on commit 1b73fc7

Please sign in to comment.