diff --git a/python/raft-ann-bench/src/raft-ann-bench/plot/__main__.py b/python/raft-ann-bench/src/raft-ann-bench/plot/__main__.py index 842e08aee8..233607c281 100644 --- a/python/raft-ann-bench/src/raft-ann-bench/plot/__main__.py +++ b/python/raft-ann-bench/src/raft-ann-bench/plot/__main__.py @@ -15,21 +15,21 @@ # This script is inspired by # 1: https://github.com/erikbern/ann-benchmarks/blob/main/plot.py -# 2: https://github.com/erikbern/ann-benchmarks/blob/main/ann_benchmarks/plotting/utils.py -# 3: https://github.com/erikbern/ann-benchmarks/blob/main/ann_benchmarks/plotting/metrics.py +# 2: https://github.com/erikbern/ann-benchmarks/blob/main/ann_benchmarks/plotting/utils.py # noqa: E501 +# 3: https://github.com/erikbern/ann-benchmarks/blob/main/ann_benchmarks/plotting/metrics.py # noqa: E501 # Licence: https://github.com/erikbern/ann-benchmarks/blob/main/LICENSE -import matplotlib as mpl - import argparse -from collections import OrderedDict import itertools +import os +from collections import OrderedDict + +import matplotlib as mpl import matplotlib.pyplot as plt import numpy as np import pandas as pd -import os -mpl.use("Agg") # noqa +mpl.use("Agg") metrics = { "k-nn": { @@ -40,16 +40,19 @@ "qps": { "description": "Queries per second (1/s)", "worst": float("-inf"), - } + }, } + def positive_int(input_str: str) -> int: try: i = int(input_str) if i < 1: raise ValueError except ValueError: - raise argparse.ArgumentTypeError(f"{input_str} is not a positive integer") + raise argparse.ArgumentTypeError( + f"{input_str} is not a positive integer" + ) return i @@ -62,17 +65,36 @@ def euclidean(a, b): return sum((x - y) ** 2 for x, y in zip(a, b)) while len(colors) < n: - new_color = max(itertools.product(vs, vs, vs), key=lambda a: min(euclidean(a, b) for b in colors)) + new_color = max( + itertools.product(vs, vs, vs), + key=lambda a: min(euclidean(a, b) for b in colors), + ) colors.append(new_color + (1.0,)) return colors def create_linestyles(unique_algorithms): - colors = dict(zip(unique_algorithms, generate_n_colors(len(unique_algorithms)))) - linestyles = dict((algo, ["--", "-.", "-", ":"][i % 4]) for i, algo in enumerate(unique_algorithms)) - markerstyles = dict((algo, ["+", "<", "o", "*", "x"][i % 5]) for i, algo in enumerate(unique_algorithms)) - faded = dict((algo, (r, g, b, 0.3)) for algo, (r, g, b, a) in colors.items()) - return dict((algo, (colors[algo], faded[algo], linestyles[algo], markerstyles[algo])) for algo in unique_algorithms) + colors = dict( + zip(unique_algorithms, generate_n_colors(len(unique_algorithms))) + ) + linestyles = dict( + (algo, ["--", "-.", "-", ":"][i % 4]) + for i, algo in enumerate(unique_algorithms) + ) + markerstyles = dict( + (algo, ["+", "<", "o", "*", "x"][i % 5]) + for i, algo in enumerate(unique_algorithms) + ) + faded = dict( + (algo, (r, g, b, 0.3)) for algo, (r, g, b, a) in colors.items() + ) + return dict( + ( + algo, + (colors[algo], faded[algo], linestyles[algo], markerstyles[algo]), + ) + for algo in unique_algorithms + ) def get_up_down(metric): @@ -97,7 +119,9 @@ def create_pointset(data, xn, yn): # Generate Pareto frontier xs, ys, ls, idxs = [], [], [], [] last_x = xm["worst"] - comparator = (lambda xv, lx: xv > lx) if last_x < 0 else (lambda xv, lx: xv < lx) + comparator = ( + (lambda xv, lx: xv > lx) if last_x < 0 else (lambda xv, lx: xv < lx) + ) for algo_name, index_name, xv, yv in data: if not xv or not yv: continue @@ -114,8 +138,9 @@ def create_pointset(data, xn, yn): return xs, ys, ls, idxs, axs, ays, als, aidxs -def create_plot_search(all_data, raw, x_scale, y_scale, fn_out, linestyles, - dataset, k, batch_size): +def create_plot_search( + all_data, raw, x_scale, y_scale, fn_out, linestyles, dataset, k, batch_size +): xn = "k-nn" yn = "qps" xm, ym = (metrics[xn], metrics[yn]) @@ -126,23 +151,43 @@ def create_plot_search(all_data, raw, x_scale, y_scale, fn_out, linestyles, # Sorting by mean y-value helps aligning plots with labels def mean_y(algo): - xs, ys, ls, idxs, axs, ays, als, aidxs = create_pointset(all_data[algo], xn, yn) + xs, ys, ls, idxs, axs, ays, als, aidxs = create_pointset( + all_data[algo], xn, yn + ) return -np.log(np.array(ys)).mean() # Find range for logit x-scale min_x, max_x = 1, 0 for algo in sorted(all_data.keys(), key=mean_y): - xs, ys, ls, idxs, axs, ays, als, aidxs = create_pointset(all_data[algo], xn, yn) + xs, ys, ls, idxs, axs, ays, als, aidxs = create_pointset( + all_data[algo], xn, yn + ) min_x = min([min_x] + [x for x in xs if x > 0]) max_x = max([max_x] + [x for x in xs if x < 1]) color, faded, linestyle, marker = linestyles[algo] (handle,) = plt.plot( - xs, ys, "-", label=algo, color=color, ms=7, mew=3, lw=3, marker=marker + xs, + ys, + "-", + label=algo, + color=color, + ms=7, + mew=3, + lw=3, + marker=marker, ) handles.append(handle) if raw: (handle2,) = plt.plot( - axs, ays, "-", label=algo, color=faded, ms=5, mew=2, lw=2, marker=marker + axs, + ays, + "-", + label=algo, + color=faded, + ms=5, + mew=2, + lw=2, + marker=marker, ) labels.append(algo) @@ -176,7 +221,13 @@ def inv_fun(x): ax.set_title(f"{dataset} k={k} batch_size={batch_size}") plt.gca().get_position() # plt.gca().set_position([box.x0, box.y0, box.width * 0.8, box.height]) - ax.legend(handles, labels, loc="center left", bbox_to_anchor=(1, 0.5), prop={"size": 9}) + ax.legend( + handles, + labels, + loc="center left", + bbox_to_anchor=(1, 0.5), + prop={"size": 9}, + ) plt.grid(visible=True, which="major", color="0.65", linestyle="-") plt.setp(ax.get_xminorticklabels(), visible=True) @@ -197,8 +248,9 @@ def inv_fun(x): plt.close() -def create_plot_build(build_results, search_results, linestyles, fn_out, - dataset, k, batch_size): +def create_plot_build( + build_results, search_results, linestyles, fn_out, dataset, k, batch_size +): xn = "k-nn" yn = "qps" @@ -219,11 +271,15 @@ def create_plot_build(build_results, search_results, linestyles, fn_out, # Sorting by mean y-value helps aligning plots with labels def mean_y(algo): - xs, ys, ls, idxs, axs, ays, als, aidxs = create_pointset(search_results[algo], xn, yn) + xs, ys, ls, idxs, axs, ays, als, aidxs = create_pointset( + search_results[algo], xn, yn + ) return -np.log(np.array(ys)).mean() for pos, algo in enumerate(sorted(search_results.keys(), key=mean_y)): - xs, ys, ls, idxs, axs, ays, als, aidxs = create_pointset(search_results[algo], xn, yn) + xs, ys, ls, idxs, axs, ays, als, aidxs = create_pointset( + search_results[algo], xn, yn + ) # x is recall, y is qps, ls is algo_name, idxs is index_name for i in range(len(xs)): if xs[i] >= 0.85 and xs[i] < 0.9 and ys[i] > qps_85[pos]: @@ -241,7 +297,7 @@ def mean_y(algo): data[algo] = [bt_85[pos], bt_90[pos], bt_95[pos]] colors[algo] = linestyles[algo][0] - index = ['@85% Recall', '@90% Recall', '@95% Recall'] + index = ["@85% Recall", "@90% Recall", "@95% Recall"] df = pd.DataFrame(data, index=index) plt.figure(figsize=(12, 9)) @@ -258,8 +314,8 @@ def load_lines(results_path, result_files, method, index_key): results = dict() for result_filename in result_files: - if result_filename.endswith('.csv'): - with open(os.path.join(results_path, result_filename), 'r') as f: + if result_filename.endswith(".csv"): + with open(os.path.join(results_path, result_filename), "r") as f: lines = f.readlines() lines = lines[:-1] if lines[-1] == "\n" else lines @@ -269,7 +325,7 @@ def load_lines(results_path, result_files, method, index_key): key_idx = [2, 3] for line in lines[1:]: - split_lines = line.split(',') + split_lines = line.split(",") algo_name = split_lines[0] index_name = split_lines[1] @@ -288,14 +344,22 @@ def load_lines(results_path, result_files, method, index_key): return results -def load_all_results(dataset_path, algorithms, k, batch_size, method, index_key): +def load_all_results( + dataset_path, algorithms, k, batch_size, method, index_key +): results_path = os.path.join(dataset_path, "result", method) result_files = os.listdir(results_path) - result_files = [result_filename for result_filename in result_files \ - if f"{k}-{batch_size}" in result_filename] + result_files = [ + result_filename + for result_filename in result_files + if f"{k}-{batch_size}" in result_filename + ] if len(algorithms) > 0: - result_files = [result_filename for result_filename in result_files if \ - result_filename.split('-')[0] in algorithms] + result_files = [ + result_filename + for result_filename in result_files + if result_filename.split("-")[0] in algorithms + ] results = load_lines(results_path, result_files, method, index_key) @@ -310,37 +374,48 @@ def main(): default_dataset_path = os.path.join(call_path, "datasets/") parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument("--dataset", help="dataset to download", - default="glove-100-inner") - parser.add_argument("--dataset-path", help="path to dataset folder", - default=default_dataset_path) - parser.add_argument("--output-filepath", - help="directory for PNG to be saved", - default=os.getcwd()) - parser.add_argument("--algorithms", - help="plot only comma separated list of named \ - algorithms", - default=None) + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) parser.add_argument( - "-k", "--count", default=10, type=positive_int, help="the number of nearest neighbors to search for" + "--dataset", help="dataset to download", default="glove-100-inner" ) parser.add_argument( - "-bs", "--batch-size", default=10000, type=positive_int, help="number of query vectors to use in each query trial" + "--dataset-path", + help="path to dataset folder", + default=default_dataset_path, ) parser.add_argument( - "--build", - action="store_true" + "--output-filepath", + help="directory for PNG to be saved", + default=os.getcwd(), + ) + parser.add_argument( + "--algorithms", + help="plot only comma separated list of named \ + algorithms", + default=None, ) parser.add_argument( - "--search", - action="store_true" + "-k", + "--count", + default=10, + type=positive_int, + help="the number of nearest neighbors to search for", ) + parser.add_argument( + "-bs", + "--batch-size", + default=10000, + type=positive_int, + help="number of query vectors to use in each query trial", + ) + parser.add_argument("--build", action="store_true") + parser.add_argument("--search", action="store_true") parser.add_argument( "--x-scale", help="Scale to use when drawing the X-axis. \ Typically linear, logit or a2", - default="linear" + default="linear", ) parser.add_argument( "--y-scale", @@ -349,13 +424,15 @@ def main(): default="linear", ) parser.add_argument( - "--raw", help="Show raw results (not just Pareto frontier) in faded colours", action="store_true" + "--raw", + help="Show raw results (not just Pareto frontier) in faded colours", + action="store_true", ) args = parser.parse_args() if args.algorithms: - algorithms = args.algorithms.split(',') + algorithms = args.algorithms.split(",") else: algorithms = [] k = args.count @@ -367,22 +444,54 @@ def main(): build = args.build search = args.search - search_output_filepath = os.path.join(args.output_filepath, f"search-{args.dataset}-k{k}-batch_size{batch_size}.png") - build_output_filepath = os.path.join(args.output_filepath, f"build-{args.dataset}-k{k}-batch_size{batch_size}.png") + search_output_filepath = os.path.join( + args.output_filepath, + f"search-{args.dataset}-k{k}-batch_size{batch_size}.png", + ) + build_output_filepath = os.path.join( + args.output_filepath, + f"build-{args.dataset}-k{k}-batch_size{batch_size}.png", + ) search_results = load_all_results( - os.path.join(args.dataset_path, args.dataset), - algorithms, k, batch_size, "search", "algo") + os.path.join(args.dataset_path, args.dataset), + algorithms, + k, + batch_size, + "search", + "algo", + ) linestyles = create_linestyles(sorted(search_results.keys())) if search: - create_plot_search(search_results, args.raw, args.x_scale, args.y_scale, - search_output_filepath, linestyles, args.dataset, k, batch_size) + create_plot_search( + search_results, + args.raw, + args.x_scale, + args.y_scale, + search_output_filepath, + linestyles, + args.dataset, + k, + batch_size, + ) if build: build_results = load_all_results( os.path.join(args.dataset_path, args.dataset), - algorithms, k, batch_size, "build", "index") - create_plot_build(build_results, search_results, linestyles, build_output_filepath, - args.dataset, k, batch_size) + algorithms, + k, + batch_size, + "build", + "index", + ) + create_plot_build( + build_results, + search_results, + linestyles, + build_output_filepath, + args.dataset, + k, + batch_size, + ) if __name__ == "__main__":