
Commit

add algo-groups
divyegala committed Oct 25, 2023
1 parent 14e2c5d commit b00942d
Showing 2 changed files with 113 additions and 28 deletions.
python/raft-ann-bench/src/raft-ann-bench/plot/__main__.py (94 changes: 77 additions & 17 deletions)
@@ -345,32 +345,74 @@ def load_lines(results_path, result_files, method, index_key):


def load_all_results(
dataset_path, algorithms, k, batch_size, method, index_key
dataset_path, algorithms, groups, algo_groups, k, batch_size, method,
index_key
):
results_path = os.path.join(dataset_path, "result", method)
result_files = os.listdir(results_path)
print(result_files)
result_files = [result_file for result_file in result_files \
if ".csv" in result_file]
# print(result_files)
if method == "search":
result_files = [
result_filename
for result_filename in result_files
if f"{k}-{batch_size}" in result_filename
]
if len(algorithms) > 0:
result_files = [
result_filename
for result_filename in result_files
if result_filename.split("-")[0] in algorithms
]
elif method == "build":
if len(algorithms) > 0:
result_files = [
result_filename
for result_filename in result_files
if result_filename.split("-")[0] in algorithms
]

results = load_lines(results_path, result_files, method, index_key)
algo_group_files = [
result_filename.split("-")[0]
for result_filename in result_files
]
else:
algo_group_files = [
result_filename
for result_filename in result_files
]
for i in range(len(algo_group_files)):
algo_group = algo_group_files[i].replace(".csv", "").split("_")
if len(algo_group) == 2:
algo_group_files[i] = ("_".join(algo_group), "base")
else:
algo_group_files[i] = ("_".join(algo_group[:-1]), algo_group[-1])
algo_group_files = list(zip(*algo_group_files))
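Note: to make the parsing above concrete, here is a small standalone sketch (not part of the commit; the file stems are invented examples) of how a result-file name is split into an (algorithm, group) pair, with a two-token stem treated as the implicit "base" group:

def split_algo_and_group(stem):
    # Mirrors the loop above: strip ".csv", split on "_", and treat a
    # two-token stem as "<algo>" in the default "base" group.
    parts = stem.replace(".csv", "").split("_")
    if len(parts) == 2:
        return ("_".join(parts), "base")
    return ("_".join(parts[:-1]), parts[-1])

# Hypothetical stems, for illustration only.
assert split_algo_and_group("raft_cagra") == ("raft_cagra", "base")
assert split_algo_and_group("raft_cagra_fast.csv") == ("raft_cagra", "fast")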
# final_groups = [result_files[i] for i in range(len(result_files)) if \
# algo_group_files[i][1] in groups]
# if len(algorithms) > 0:
# final_algos = [final_groups[i] for i in range(len(result_files)) if \
# ("_".join(result_files[i].split("_")[:-1]) in algorithms)]
# final_results = []
if len(algorithms) > 0:
final_results = [result_files[i] for i in range(len(result_files)) if \
(algo_group_files[0][i] in algorithms) and \
(algo_group_files[1][i] in groups)]
else:
final_results = [result_files[i] for i in range(len(result_files)) if \
(algo_group_files[1][i] in groups)]

if len(algo_groups) > 0:
split_algo_groups = [algo_group.split(".") for algo_group in algo_groups]
split_algo_groups = list(zip(*split_algo_groups))
final_algo_groups = [result_files[i] for i in range(len(result_files)) if \
(algo_group_files[0][i] in split_algo_groups[0]) and \
(algo_group_files[1][i] in split_algo_groups[1])]
final_results = final_results + final_algo_groups
final_results = set(final_results)
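As a worked illustration of the --algo-groups handling above (the values are made up, not from the commit), each entry has the form "algo.group"; splitting on "." and transposing with zip(*...) puts all algorithm names at index 0 and all group names at index 1, which are then used for the membership tests:

algo_groups = ["raft_cagra.fast", "raft_ivf_pq.tuned"]   # illustrative values
split_algo_groups = [ag.split(".") for ag in algo_groups]
# [["raft_cagra", "fast"], ["raft_ivf_pq", "tuned"]]
split_algo_groups = list(zip(*split_algo_groups))
# [("raft_cagra", "raft_ivf_pq"), ("fast", "tuned")]
assert "raft_cagra" in split_algo_groups[0] and "fast" in split_algo_groups[1]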

# if len(algorithms) > 0:
# result_files = [
# result_filename
# for result_filename in result_files
# if result_filename.split("-")[0] in algorithms
# ]
# elif method == "build":
# if len(algorithms) > 0:
# result_files = [
# result_filename
# for result_filename in result_files
# if result_filename.split("-")[0] in algorithms
# ]

results = load_lines(results_path, final_results, method, index_key)

return results

@@ -404,6 +446,15 @@ def main():
algorithms",
default=None,
)
parser.add_argument(
"--groups",
help="plot only comma separated groups of parameters",
default="base"
)
parser.add_argument(
"--algo-groups",
help="add comma separated algorithm+groups to the plot",
)
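For context, a minimal sketch (separate from the commit; the flag values and group names are invented) of how the two new plot flags might be passed and what argparse hands back before the comma-splitting done later in main():

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--groups", default="base")
parser.add_argument("--algo-groups")
args = parser.parse_args(
    ["--groups", "base,fast", "--algo-groups", "raft_cagra.fast"]
)
print(args.groups)       # base,fast
print(args.algo_groups)  # raft_cagra.fast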
parser.add_argument(
"-k",
"--count",
Expand Down Expand Up @@ -444,6 +495,11 @@ def main():
algorithms = args.algorithms.split(",")
else:
algorithms = []
groups = args.groups.split(",")
if args.algo_groups:
algo_groups = args.algo_groups.split(",")
else:
algo_groups = []
k = args.count
batch_size = args.batch_size
if not args.build and not args.search:
@@ -465,6 +521,8 @@ def main():
search_results = load_all_results(
os.path.join(args.dataset_path, args.dataset),
algorithms,
groups,
algo_groups,
k,
batch_size,
"search",
@@ -487,6 +545,8 @@ def main():
build_results = load_all_results(
os.path.join(args.dataset_path, args.dataset),
algorithms,
groups,
algo_groups,
k,
batch_size,
"build",
python/raft-ann-bench/src/raft-ann-bench/run/__main__.py (47 changes: 36 additions & 11 deletions)
@@ -45,23 +45,28 @@ def validate_algorithm(algos_conf, algo, gpu_present):
)


def find_executable(algos_conf, algo, k, batch_size):
def find_executable(algos_conf, algo, group, k, batch_size):
executable = algos_conf[algo]["executable"]

if group != "base":
return_str = f"{algo}_{group}-{k}-{batch_size}"
else:
return_str = f"{algo}-{k}-{batch_size}"

build_path = os.getenv("RAFT_HOME")
if build_path is not None:
build_path = os.path.join(build_path, "cpp", "build", executable)
if os.path.exists(build_path):
print(f"-- Using RAFT bench from repository in {build_path}. ")
return (executable, build_path, f"{algo}-{k}-{batch_size}")
return (executable, build_path, return_str)

# if there is no build folder present, we look in the conda environment
conda_path = os.getenv("CONDA_PREFIX")
if conda_path is not None:
conda_path = os.path.join(conda_path, "bin", "ann", executable)
if os.path.exists(conda_path):
print("-- Using RAFT bench found in conda environment. ")
return (executable, conda_path, f"{algo}-{k}-{batch_size}")
return (executable, conda_path, return_str)

else:
raise FileNotFoundError(executable)
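To show the naming convention find_executable now produces, here is a short sketch (not part of the commit; the algorithm and group names are examples): the label is "<algo>_<group>-<k>-<batch_size>" for a non-default group and "<algo>-<k>-<batch_size>" for the "base" group. This is also the label the plot script later splits back into an (algorithm, group) pair.

def bench_label(algo, group, k, batch_size):
    # Same convention as the return_str built above.
    if group != "base":
        return f"{algo}_{group}-{k}-{batch_size}"
    return f"{algo}-{k}-{batch_size}"

assert bench_label("raft_cagra", "base", 10, 1) == "raft_cagra-10-1"
assert bench_label("raft_cagra", "fast", 10, 1) == "raft_cagra_fast-10-1"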
@@ -214,9 +219,13 @@ def main():
# )
parser.add_argument(
"--groups",
help="comma separated groups of parameters to run the benchmarks for",
help="run only comma separated groups of parameters",
default="base"
)
parser.add_argument(
"--algo-groups",
help="add comma separated algorithm+groups to run",
)
parser.add_argument(
"-f",
"--force",
Expand Down Expand Up @@ -280,30 +289,44 @@ def main():
if filter_algos:
allowed_algos = args.algorithms.split(",")
named_groups = args.groups.split(",")
filter_algo_groups = True if args.algo_groups else False
allowed_algo_groups = None
if filter_algo_groups:
allowed_algo_groups = [algo_group.split(".") for algo_group in args.algo_groups.split(",")]
allowed_algo_groups = list(zip(*allowed_algo_groups))
algos_conf = dict()
for algo_f in algos_conf_fs:
with open(algo_f, "r") as f:
if algo_f.split("/")[-1] == "raft_cagra.yaml":
algo = yaml.safe_load(f)
insert_algo = True
insert_algo_group = False
if filter_algos:
if algo["name"] not in allowed_algos:
insert_algo = False
if insert_algo:
if filter_algo_groups:
if algo["name"] in allowed_algo_groups[0]:
insert_algo_group = True
def add_algo_group(group_list):
if algo["name"] not in algos_conf:
algos_conf[algo["name"]] = dict()
for group in algo.keys():
if group != "name":
if group in named_groups:
if group in group_list:
algos_conf[algo["name"]][group] = algo[group]
if insert_algo:
add_algo_group(named_groups)
if insert_algo_group:
add_algo_group(allowed_algo_groups[1])
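As an illustration of what the filtering above builds (a sketch with invented parameter names; the real per-algorithm YAML layout may differ), each parsed config maps group names to build/search parameter sets, and only the requested groups are copied into algos_conf:

# Hypothetical parsed YAML for one algorithm (names and values invented).
algo = {
    "name": "raft_cagra",
    "base": {"build": {"graph_degree": [32, 64]}, "search": {"itopk": [32]}},
    "fast": {"build": {"graph_degree": [32]}, "search": {"itopk": [64]}},
}
named_groups = ["base"]
algos_conf = {
    algo["name"]: {g: algo[g] for g in algo if g != "name" and g in named_groups}
}
# -> {"raft_cagra": {"base": {"build": ..., "search": ...}}}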

print(algos_conf)
executables_to_run = dict()
for algo in algos_conf.keys():
validate_algorithm(algos_yaml, algo, gpu_present)
executable = find_executable(algos_yaml, algo, k, batch_size)
if executable not in executables_to_run:
executables_to_run[executable] = {"index": []}
for group in algos_conf[algo].keys():
executable = find_executable(algos_yaml, algo, group, k, batch_size)
if executable not in executables_to_run:
executables_to_run[executable] = {"index": []}
build_params = algos_conf[algo][group]["build"]
search_params = algos_conf[algo][group]["search"]

@@ -323,7 +346,10 @@ def main():

for params in all_build_params:
index = {"algo": algo, "build_param": {}}
index_name = f"{algo}"
if group != "base":
index_name = f"{algo}_{group}"
else:
index_name = f"{algo}"
for i in range(len(params)):
index["build_param"][param_names[i]] = params[i]
index_name += "." + f"{param_names[i]}{params[i]}"
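For reference, a short sketch (not from the commit; the parameter name and values are invented) of the index names this loop produces once groups are included:

algo, group = "raft_cagra", "fast"          # illustrative
param_names, params = ["graph_degree"], [32]
index_name = f"{algo}_{group}" if group != "base" else f"{algo}"
for i in range(len(params)):
    index_name += "." + f"{param_names[i]}{params[i]}"
print(index_name)  # raft_cagra_fast.graph_degree32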
@@ -408,7 +434,6 @@ def main():
# )
# executables_to_run[executable_path]["index"][pos] = index

print(conf_filedir)
run_build_and_search(
conf_file,
f"{args.dataset}.json",