Revert bulk sampling changes
VibhuJawa committed Nov 21, 2023
1 parent a32fd3d commit 5f7d4e5
Showing 1 changed file with 45 additions and 10 deletions.
@@ -22,7 +22,6 @@
     get_allocation_counts_dask_lazy,
     sizeof_fmt,
     get_peak_output_ratio_across_workers,
-    restart_client,
     start_dask_client,
     stop_dask_client,
     enable_spilling,
@@ -187,10 +186,10 @@ def sample_graph(
     output_path,
     seed=42,
     batch_size=500,
-    seeds_per_call=200000,
+    seeds_per_call=400000,
     batches_per_partition=100,
     fanout=[5, 5, 5],
-    persist=False,
+    sampling_kwargs={},
 ):
     cupy.random.seed(seed)
 
@@ -204,6 +203,7 @@
         seeds_per_call=seeds_per_call,
         batches_per_partition=batches_per_partition,
         log_level=logging.INFO,
+        **sampling_kwargs,
     )
 
     n_workers = len(default_client().scheduler_info()["workers"])
@@ -469,6 +469,7 @@ def benchmark_cugraph_bulk_sampling(
     batch_size,
     seeds_per_call,
     fanout,
+    sampling_target_framework,
     reverse_edges=True,
     dataset_dir=".",
     replication_factor=1,
@@ -564,17 +565,39 @@ def benchmark_cugraph_bulk_sampling(
     output_sample_path = os.path.join(output_subdir, "samples")
     os.makedirs(output_sample_path)
 
-    batches_per_partition = 200_000 // batch_size
+    if sampling_target_framework == "cugraph_dgl_csr":
+        sampling_kwargs = {
+            "deduplicate_sources": True,
+            "prior_sources_behavior": "carryover",
+            "renumber": True,
+            "compression": "CSR",
+            "compress_per_hop": True,
+            "use_legacy_names": False,
+            "include_hop_column": False,
+        }
+    else:
+        # FIXME: Update these arguments when CSC mode is fixed in cuGraph-PyG (release 24.02)
+        sampling_kwargs = {
+            "deduplicate_sources": True,
+            "prior_sources_behavior": "exclude",
+            "renumber": True,
+            "compression": "COO",
+            "compress_per_hop": False,
+            "use_legacy_names": False,
+            "include_hop_column": True,
+        }
+
+    batches_per_partition = 400_000 // batch_size
     execution_time, allocation_counts = sample_graph(
-        G,
-        dask_label_df,
-        output_sample_path,
+        G=G,
+        label_df=dask_label_df,
+        output_path=output_sample_path,
         seed=seed,
         batch_size=batch_size,
         seeds_per_call=seeds_per_call,
         batches_per_partition=batches_per_partition,
         fanout=fanout,
-        persist=persist,
+        sampling_kwargs=sampling_kwargs,
     )
 
     output_meta = {
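Condensed, the branch above only toggles the sampler output format per target framework: CSR compressed per hop for cuGraph-DGL, and COO with an explicit hop column for cuGraph-PyG until CSC mode is fixed. A minimal standalone sketch of that selection (the helper name is hypothetical, not part of this file):

    # Sketch of the format-selection logic in the hunk above;
    # select_sampling_kwargs is a hypothetical helper, not part of the commit.
    def select_sampling_kwargs(sampling_target_framework):
        if sampling_target_framework == "cugraph_dgl_csr":
            # cuGraph-DGL consumes CSR-compressed samples, compressed per hop.
            return {"compression": "CSR", "compress_per_hop": True}
        # cuGraph-PyG still reads COO with a hop column (see FIXME above).
        return {"compression": "COO", "include_hop_column": True}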
@@ -701,7 +724,13 @@ def get_args():
         required=False,
         default=False,
     )
-
+    parser.add_argument(
+        "--sampling_target_framework",
+        type=str,
+        help="The target framework for sampling (i.e. cugraph_dgl_csr, cugraph_pyg_csc, ...)",
+        required=False,
+        default=None,
+    )
     parser.add_argument(
         "--dask_worker_devices",
         type=str,
@@ -738,6 +767,12 @@ def get_args():
     logging.basicConfig()
 
     args = get_args()
+    if args.sampling_target_framework not in ["cugraph_dgl_csr", None]:
+        raise ValueError(
+            "sampling_target_framework must be one of cugraph_dgl_csr or None",
+            "Other frameworks are not supported at this time.",
+        )
+
     fanouts = [
         [int(f) for f in fanout.split("_")] for fanout in args.fanouts.split(",")
     ]
@@ -785,6 +820,7 @@ def get_args():
                 batch_size=batch_size,
                 seeds_per_call=seeds_per_call,
                 fanout=fanout,
+                sampling_target_framework=args.sampling_target_framework,
                 dataset_dir=args.dataset_root,
                 reverse_edges=args.reverse_edges,
                 replication_factor=replication_factor,
@@ -809,7 +845,6 @@ def get_args():
             warnings.warn("An Exception Occurred!")
             print(e)
             traceback.print_exc()
-            restart_client(client)
             sleep(10)
 
     stats_df = pd.DataFrame(
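For context on how the per-framework dictionaries are consumed: sample_graph splats them into the sampler constructor via **sampling_kwargs (see the hunk at line 203 above). A minimal sketch of that forwarding pattern, assuming cugraph's BulkSampler API; the wrapper function and its fixed values are illustrative, not the file's actual code:

    import logging

    from cugraph.gnn import BulkSampler  # bulk sampling API assumed from the diff

    def make_sampler(graph, output_path, sampling_kwargs):
        # Extra keyword arguments pass straight through, so the per-framework
        # dicts above translate directly into sampler options.
        return BulkSampler(
            batch_size=500,
            output_path=output_path,
            graph=graph,
            seeds_per_call=400_000,
            batches_per_partition=100,
            log_level=logging.INFO,
            **sampling_kwargs,  # e.g. compression="CSR", renumber=True
        )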