From 5f7d4e53ff49c30745ae79a62b5588f6921c3247 Mon Sep 17 00:00:00 2001 From: Vibhu Jawa Date: Tue, 21 Nov 2023 10:26:22 -0800 Subject: [PATCH] Revert bulk sampling changes --- .../bulk_sampling/cugraph_bulk_sampling.py | 55 +++++++++++++++---- 1 file changed, 45 insertions(+), 10 deletions(-) diff --git a/benchmarks/cugraph/standalone/bulk_sampling/cugraph_bulk_sampling.py b/benchmarks/cugraph/standalone/bulk_sampling/cugraph_bulk_sampling.py index a8c0658767d..1ca5d6db637 100644 --- a/benchmarks/cugraph/standalone/bulk_sampling/cugraph_bulk_sampling.py +++ b/benchmarks/cugraph/standalone/bulk_sampling/cugraph_bulk_sampling.py @@ -22,7 +22,6 @@ get_allocation_counts_dask_lazy, sizeof_fmt, get_peak_output_ratio_across_workers, - restart_client, start_dask_client, stop_dask_client, enable_spilling, @@ -187,10 +186,10 @@ def sample_graph( output_path, seed=42, batch_size=500, - seeds_per_call=200000, + seeds_per_call=400000, batches_per_partition=100, fanout=[5, 5, 5], - persist=False, + sampling_kwargs={}, ): cupy.random.seed(seed) @@ -204,6 +203,7 @@ def sample_graph( seeds_per_call=seeds_per_call, batches_per_partition=batches_per_partition, log_level=logging.INFO, + **sampling_kwargs, ) n_workers = len(default_client().scheduler_info()["workers"]) @@ -469,6 +469,7 @@ def benchmark_cugraph_bulk_sampling( batch_size, seeds_per_call, fanout, + sampling_target_framework, reverse_edges=True, dataset_dir=".", replication_factor=1, @@ -564,17 +565,39 @@ def benchmark_cugraph_bulk_sampling( output_sample_path = os.path.join(output_subdir, "samples") os.makedirs(output_sample_path) - batches_per_partition = 200_000 // batch_size + if sampling_target_framework == "cugraph_dgl_csr": + sampling_kwargs = { + "deduplicate_sources": True, + "prior_sources_behavior": "carryover", + "renumber": True, + "compression": "CSR", + "compress_per_hop": True, + "use_legacy_names": False, + "include_hop_column": False, + } + else: + # FIXME: Update these arguments when CSC mode is fixed in cuGraph-PyG (release 24.02) + sampling_kwargs = { + "deduplicate_sources": True, + "prior_sources_behavior": "exclude", + "renumber": True, + "compression": "COO", + "compress_per_hop": False, + "use_legacy_names": False, + "include_hop_column": True, + } + + batches_per_partition = 400_000 // batch_size execution_time, allocation_counts = sample_graph( - G, - dask_label_df, - output_sample_path, + G=G, + label_df=dask_label_df, + output_path=output_sample_path, seed=seed, batch_size=batch_size, seeds_per_call=seeds_per_call, batches_per_partition=batches_per_partition, fanout=fanout, - persist=persist, + sampling_kwargs=sampling_kwargs, ) output_meta = { @@ -701,7 +724,13 @@ def get_args(): required=False, default=False, ) - + parser.add_argument( + "--sampling_target_framework", + type=str, + help="The target framework for sampling (i.e. cugraph_dgl_csr, cugraph_pyg_csc, ...)", + required=False, + default=None, + ) parser.add_argument( "--dask_worker_devices", type=str, @@ -738,6 +767,12 @@ def get_args(): logging.basicConfig() args = get_args() + if args.sampling_target_framework not in ["cugraph_dgl_csr", None]: + raise ValueError( + "sampling_target_framework must be one of cugraph_dgl_csr or None", + "Other frameworks are not supported at this time.", + ) + fanouts = [ [int(f) for f in fanout.split("_")] for fanout in args.fanouts.split(",") ] @@ -785,6 +820,7 @@ def get_args(): batch_size=batch_size, seeds_per_call=seeds_per_call, fanout=fanout, + sampling_target_framework=args.sampling_target_framework, dataset_dir=args.dataset_root, reverse_edges=args.reverse_edges, replication_factor=replication_factor, @@ -809,7 +845,6 @@ def get_args(): warnings.warn("An Exception Occurred!") print(e) traceback.print_exc() - restart_client(client) sleep(10) stats_df = pd.DataFrame(