
Commit

Revert bulk sampling changes
VibhuJawa committed Nov 21, 2023
1 parent 64ec881 commit a32fd3d
Showing 1 changed file with 11 additions and 46 deletions.
@@ -22,6 +22,7 @@
get_allocation_counts_dask_lazy,
sizeof_fmt,
get_peak_output_ratio_across_workers,
restart_client,
start_dask_client,
stop_dask_client,
enable_spilling,
@@ -186,10 +187,10 @@ def sample_graph(
output_path,
seed=42,
batch_size=500,
seeds_per_call=400000,
seeds_per_call=200000,
batches_per_partition=100,
fanout=[5, 5, 5],
sampling_kwargs={},
persist=False,
):
cupy.random.seed(seed)

@@ -203,7 +204,6 @@
seeds_per_call=seeds_per_call,
batches_per_partition=batches_per_partition,
log_level=logging.INFO,
**sampling_kwargs,
)

n_workers = len(default_client().scheduler_info()["workers"])
@@ -469,7 +469,6 @@ def benchmark_cugraph_bulk_sampling(
batch_size,
seeds_per_call,
fanout,
sampling_target_framework,
reverse_edges=True,
dataset_dir=".",
replication_factor=1,
@@ -565,39 +564,17 @@ def benchmark_cugraph_bulk_sampling(
output_sample_path = os.path.join(output_subdir, "samples")
os.makedirs(output_sample_path)

if sampling_target_framework == "cugraph_dgl_csr":
sampling_kwargs = {
"deduplicate_sources": True,
"prior_sources_behavior": "carryover",
"renumber": True,
"compression": "CSR",
"compress_per_hop": True,
"use_legacy_names": False,
"include_hop_column": False,
}
else:
# FIXME: Update these arguments when CSC mode is fixed in cuGraph-PyG (release 24.02)
sampling_kwargs = {
"deduplicate_sources": True,
"prior_sources_behavior": "exclude",
"renumber": True,
"compression": "COO",
"compress_per_hop": False,
"use_legacy_names": False,
"include_hop_column": True,
}

batches_per_partition = 400_000 // batch_size
batches_per_partition = 200_000 // batch_size
execution_time, allocation_counts = sample_graph(
G=G,
label_df=dask_label_df,
output_path=output_sample_path,
G,
dask_label_df,
output_sample_path,
seed=seed,
batch_size=batch_size,
seeds_per_call=seeds_per_call,
batches_per_partition=batches_per_partition,
fanout=fanout,
sampling_kwargs=sampling_kwargs,
persist=persist,
)
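
For orientation, the partition sizing in this hunk works out as follows (taking the 200_000-based lines as the post-commit state): with the default batch_size of 500, each output partition holds 200_000 // 500 = 400 batches of 500 seeds, i.e. 200,000 seeds, which matches the 200,000-seed default of sample_graph's seeds_per_call, so one sampling call fills exactly one output partition. A small illustrative calculation, using only values visible in this diff:

batch_size = 500                                   # script default shown above
seeds_per_call = 200_000                           # sample_graph default shown above
batches_per_partition = 200_000 // batch_size      # 400, as computed in the line above

seeds_per_partition = batch_size * batches_per_partition     # 200_000 seeds per partition
calls_per_partition = seeds_per_partition // seeds_per_call  # 1 -> one call per partition
print(batches_per_partition, seeds_per_partition, calls_per_partition)  # 400 200000 1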

output_meta = {
Expand Down Expand Up @@ -724,13 +701,7 @@ def get_args():
required=False,
default=False,
)
parser.add_argument(
"--sampling_target_framework",
type=str,
help="The target framework for sampling (i.e. cugraph_dgl_csr, cugraph_pyg_csc, ...)",
required=False,
default=None,
)

parser.add_argument(
"--dask_worker_devices",
type=str,
@@ -767,12 +738,6 @@ def get_args():
logging.basicConfig()

args = get_args()
if args.sampling_target_framework not in ["cugraph_dgl_csr", None]:
raise ValueError(
"sampling_target_framework must be one of cugraph_dgl_csr or None",
"Other frameworks are not supported at this time.",
)

fanouts = [
[int(f) for f in fanout.split("_")] for fanout in args.fanouts.split(",")
]
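
For reference, the --fanouts argument is split on commas into separate configurations and on underscores into per-hop fanouts. A minimal standalone sketch of that parsing (the argument value here is hypothetical):

fanouts_arg = "25_25,10_10_10"  # hypothetical CLI value: two fanout configurations
fanouts = [
    [int(f) for f in fanout.split("_")] for fanout in fanouts_arg.split(",")
]
print(fanouts)  # [[25, 25], [10, 10, 10]]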
@@ -784,7 +749,7 @@ def get_args():
client, cluster = start_dask_client(
dask_worker_devices=dask_worker_devices,
jit_unspill=False,
rmm_pool_size=20e9,
rmm_pool_size=28e9,
rmm_async=True,
)
enable_spilling()
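
start_dask_client and enable_spilling are helpers from this benchmark's shared utilities, and their bodies are not part of this diff. As a rough sketch only, an approximately equivalent standalone setup with dask_cuda and cuDF spilling could look like the following; every name and value beyond those shown in the diff is an assumption:

from dask_cuda import LocalCUDACluster
from distributed import Client
import cudf

# Assumed approximation of start_dask_client, not the helper's actual body.
cluster = LocalCUDACluster(
    CUDA_VISIBLE_DEVICES="0,1",  # hypothetical stand-in for --dask_worker_devices
    jit_unspill=False,
    rmm_pool_size=28e9,          # 28 GB RMM pool per worker, the value set above
    rmm_async=True,              # asynchronous RMM allocator
)
client = Client(cluster)

def _enable_spilling():
    # cuDF's built-in spilling; assumed to approximate the repo's enable_spilling helper
    cudf.set_option("spill", True)

client.run(_enable_spilling)  # enable spilling on every worker
_enable_spilling()            # and on the client process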
@@ -820,7 +785,6 @@ def get_args():
batch_size=batch_size,
seeds_per_call=seeds_per_call,
fanout=fanout,
sampling_target_framework=args.sampling_target_framework,
dataset_dir=args.dataset_root,
reverse_edges=args.reverse_edges,
replication_factor=replication_factor,
@@ -845,6 +809,7 @@ def get_args():
warnings.warn("An Exception Occurred!")
print(e)
traceback.print_exc()
restart_client(client)
sleep(10)
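
restart_client comes from the same shared utilities as the import shown in the first hunk; its implementation is not part of this diff. A minimal, purely hypothetical sketch of such a helper using dask.distributed:

from distributed import Client

def restart_client(client: Client) -> None:
    # Hypothetical sketch: bounce the workers so the next benchmark run starts clean.
    try:
        client.restart()  # distributed.Client.restart() restarts all workers
    except Exception as exc:
        print(f"Failed to restart Dask client: {exc}")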

stats_df = pd.DataFrame(
