Skip to content

Commit

Permalink
Updates backend_wrapper to handle exhausting returned iterators, fixe…
Browse files Browse the repository at this point in the history
…s bug in legacy_backend_wrapper.
  • Loading branch information
rlratzel committed Jan 7, 2024
1 parent bf2b73a commit 7ac6e31
Showing 1 changed file with 29 additions and 30 deletions.
59 changes: 29 additions & 30 deletions benchmarks/nx-cugraph/pytest-based/bench_algos.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand Down Expand Up @@ -128,32 +128,32 @@ def nx_graph_from_dataset(dataset_obj):
def get_legacy_backend_wrapper(backend_name):
"""
Returns a callable that wraps an algo function with either the default
dispatch decorator, or the "testing" decorator which unconditionally
dispatches.
dispatcher (which dispatches based on input graph type), or the "testing"
dispatcher (which autoconverts and unconditionally dispatches).
This is only supported for NetworkX <3.2
"""
backends.plugin_name = "cugraph"
orig_dispatch = backends._dispatch
testing_dispatch = backends.test_override_dispatch

# Testing with the networkx <3.2 dispatch mechanism is based on decorating
# networkx APIs. The decorator is either one that only uses a backend if
# the input graph type is for that backend (the default decorator), or the
# "testing" decorator, which unconditionally converts a graph type to the
# type needed by the backend then calls the backend. If the cugraph backend
# is specified, create a callable that decorates the benchmarked function
# with the testing decorator.
#
# Because both the default and testing decorators assume they are only
# applied once and do bookkeeping to ensure algos are not registered
# multiple times, the callable also clears bookkeeping so the decorators
# can be reapplied multiple times. This is obviously a hack and networkx
# >=3.2 makes this use case properly supported.
backends._registered_algorithms = {}
if backend_name == "cugraph":
wrap_callable_for_dispatch = testing_dispatch(*args, **kwargs)
dispatch = testing_dispatch
else:
wrap_callable_for_dispatch = orig_dispatch(*args, **kwargs)
dispatch = orig_dispatch

def wrap_callable_for_dispatch(func, exhaust_returned_iterator=False):
# Networkx <3.2 registers functions when the dispatch decorator is
# applied (called) and errors if re-registered, so clear bookkeeping to
# allow it to be called repeatedly.
backends._registered_algorithms = {}
actual_func = dispatch(func) # returns the func the dispatcher picks
def wrapper(*args, **kwargs):
retval = actual_func(*args, **kwargs)
if exhaust_returned_iterator:
retval = list(retval)
return retval

return wrapper

return wrap_callable_for_dispatch

Expand All @@ -164,11 +164,13 @@ def get_backend_wrapper(backend_name):
"backend" kwarg on it.
This is only supported for NetworkX >= 3.2
"""

def wrap_callable_for_dispatch(func):
def wrap_callable_for_dispatch(func, exhaust_returned_iterator=False):
def wrapper(*args, **kwargs):
kwargs["backend"] = backend_name
return func(*args, **kwargs)
retval = func(*args, **kwargs)
if exhaust_returned_iterator:
retval = list(retval)
return retval

return wrapper

Expand Down Expand Up @@ -436,15 +438,9 @@ def bench_single_target_shortest_path_length(benchmark, graph_obj, backend_wrapp
node = max(degrees, key=lambda t: t[1])[0]
G = get_graph_obj_for_benchmark(graph_obj, backend_wrapper)

# Ensure the benchmark time includes computation of each result from the
# returned generator
def run_generator(*args, **kwargs):
func = backend_wrapper(nx.single_target_shortest_path_length)
results = func(*args, **kwargs)
return list(results)

result = benchmark.pedantic(
target=run_generator,
target=backend_wrapper(nx.single_target_shortest_path_length,
exhaust_returned_iterator=True),
args=(G,),
kwargs=dict(
target=node,
Expand All @@ -453,4 +449,7 @@ def run_generator(*args, **kwargs):
iterations=iterations,
warmup_rounds=warmup_rounds,
)
# exhaust_returned_iterator=True forces the result to a list, but is not
# needed for this algo in NX 3.3+ which returns a dict instead of an
# iterator. Forcing to a list does not change the benchmark timing.
assert type(result) is list

0 comments on commit 7ac6e31

Please sign in to comment.