diff --git a/benchmarks/nx-cugraph/pytest-based/bench_algos.py b/benchmarks/nx-cugraph/pytest-based/bench_algos.py
index 8fbce66d9ef..0a63164d7ca 100644
--- a/benchmarks/nx-cugraph/pytest-based/bench_algos.py
+++ b/benchmarks/nx-cugraph/pytest-based/bench_algos.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2023, NVIDIA CORPORATION.
+# Copyright (c) 2023-2024, NVIDIA CORPORATION.
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -128,32 +128,32 @@ def nx_graph_from_dataset(dataset_obj):
 def get_legacy_backend_wrapper(backend_name):
     """
     Returns a callable that wraps an algo function with either the default
-    dispatch decorator, or the "testing" decorator which unconditionally
-    dispatches.
+    dispatcher (which dispatches based on input graph type), or the "testing"
+    dispatcher (which autoconverts and unconditionally dispatches).
 
     This is only supported for NetworkX <3.2
     """
     backends.plugin_name = "cugraph"
     orig_dispatch = backends._dispatch
     testing_dispatch = backends.test_override_dispatch
 
-    # Testing with the networkx <3.2 dispatch mechanism is based on decorating
-    # networkx APIs. The decorator is either one that only uses a backend if
-    # the input graph type is for that backend (the default decorator), or the
-    # "testing" decorator, which unconditionally converts a graph type to the
-    # type needed by the backend then calls the backend. If the cugraph backend
-    # is specified, create a callable that decorates the benchmarked function
-    # with the testing decorator.
-    #
-    # Because both the default and testing decorators assume they are only
-    # applied once and do bookkeeping to ensure algos are not registered
-    # multiple times, the callable also clears bookkeeping so the decorators
-    # can be reapplied multiple times. This is obviously a hack and networkx
-    # >=3.2 makes this use case properly supported.
-    backends._registered_algorithms = {}
     if backend_name == "cugraph":
-        wrap_callable_for_dispatch = testing_dispatch(*args, **kwargs)
+        dispatch = testing_dispatch
     else:
-        wrap_callable_for_dispatch = orig_dispatch(*args, **kwargs)
+        dispatch = orig_dispatch
+
+    def wrap_callable_for_dispatch(func, exhaust_returned_iterator=False):
+        # Networkx <3.2 registers functions when the dispatch decorator is
+        # applied (called) and errors if re-registered, so clear bookkeeping to
+        # allow it to be called repeatedly.
+        backends._registered_algorithms = {}
+        actual_func = dispatch(func)  # returns the func the dispatcher picks
+        def wrapper(*args, **kwargs):
+            retval = actual_func(*args, **kwargs)
+            if exhaust_returned_iterator:
+                retval = list(retval)
+            return retval
+
+        return wrapper
 
     return wrap_callable_for_dispatch
@@ -164,11 +164,13 @@ def get_backend_wrapper(backend_name):
     "backend" kwarg on it.
 
     This is only supported for NetworkX >= 3.2
     """
-
-    def wrap_callable_for_dispatch(func):
+    def wrap_callable_for_dispatch(func, exhaust_returned_iterator=False):
         def wrapper(*args, **kwargs):
             kwargs["backend"] = backend_name
-            return func(*args, **kwargs)
+            retval = func(*args, **kwargs)
+            if exhaust_returned_iterator:
+                retval = list(retval)
+            return retval
 
         return wrapper
@@ -436,15 +438,9 @@ def bench_single_target_shortest_path_length(benchmark, graph_obj, backend_wrapp
     node = max(degrees, key=lambda t: t[1])[0]
     G = get_graph_obj_for_benchmark(graph_obj, backend_wrapper)
 
-    # Ensure the benchmark time includes computation of each result from the
-    # returned generator
-    def run_generator(*args, **kwargs):
-        func = backend_wrapper(nx.single_target_shortest_path_length)
-        results = func(*args, **kwargs)
-        return list(results)
-
     result = benchmark.pedantic(
-        target=run_generator,
+        target=backend_wrapper(nx.single_target_shortest_path_length,
+                               exhaust_returned_iterator=True),
         args=(G,),
         kwargs=dict(
             target=node,
@@ -453,4 +449,7 @@ def run_generator(*args, **kwargs):
         iterations=iterations,
         warmup_rounds=warmup_rounds,
     )
+    # exhaust_returned_iterator=True forces the result to a list, but is not
+    # needed for this algo in NX 3.3+ which returns a dict instead of an
+    # iterator. Forcing to a list does not change the benchmark timing.
     assert type(result) is list
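For context on the change above, here is a minimal, standalone sketch (not part of the patch) of the exhaust_returned_iterator pattern: when a benchmarked NetworkX call returns a generator, the timed callable must consume it, otherwise only generator creation is measured. The wrap helper below is a hypothetical stand-in for the backend_wrapper fixture in bench_algos.py.

import networkx as nx


def wrap(func, exhaust_returned_iterator=False):
    # Same idea as the wrapper added in the diff: optionally force the
    # returned iterator to a list so all results are computed inside the
    # wrapped call (the callable a benchmark would time).
    def wrapper(*args, **kwargs):
        retval = func(*args, **kwargs)
        if exhaust_returned_iterator:
            retval = list(retval)
        return retval

    return wrapper


G = nx.karate_club_graph()

# Without exhausting: on NetworkX <3.3 this returns a generator, so calling
# the wrapper does almost no work by itself.
lazy = wrap(nx.single_target_shortest_path_length)(G, 0)

# With exhausting: the iterator is consumed inside the wrapper, so the
# shortest-path computation happens within the call being measured.
eager = wrap(
    nx.single_target_shortest_path_length, exhaust_returned_iterator=True
)(G, 0)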