From 9bc944b489c91dbade9202069a47d0c860f4b059 Mon Sep 17 00:00:00 2001
From: Alexandria Barghi
Date: Tue, 2 Apr 2024 14:34:00 -0700
Subject: [PATCH] cleanup

---
 .../cugraph_pyg/examples/cugraph_dist_sampling.py    | 11 ++++++++++-
 .../cugraph/cugraph/gnn/data_loading/dist_sampler.py |  3 +--
 2 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/python/cugraph-pyg/cugraph_pyg/examples/cugraph_dist_sampling.py b/python/cugraph-pyg/cugraph_pyg/examples/cugraph_dist_sampling.py
index 10033304e91..29366c404df 100644
--- a/python/cugraph-pyg/cugraph_pyg/examples/cugraph_dist_sampling.py
+++ b/python/cugraph-pyg/cugraph_pyg/examples/cugraph_dist_sampling.py
@@ -17,8 +17,9 @@
 # is intented for users who want to extend cuGraph within a DDP workflow.
 
 import os
-
+import re
 import tempfile
+
 import numpy as np
 import torch
 import torch.multiprocessing as tmp
@@ -100,6 +101,14 @@ def main():
         nprocs=world_size,
     )
 
+    print("Printing samples...")
+    for file in os.listdir(directory):
+        m = re.match(r"batch=([0-9]+)\.([0-9]+)-([0-9]+)\.([0-9]+)\.parquet", file)
+        rank, start, _, end = int(m[1]), int(m[2]), int(m[3]), int(m[4])
+        print(f"File: {file} (batches {start} to {end} for rank {rank})")
+        print(cudf.read_parquet(os.path.join(directory, file)))
+        print("\n")
+
 
 if __name__ == "__main__":
     main()
diff --git a/python/cugraph/cugraph/gnn/data_loading/dist_sampler.py b/python/cugraph/cugraph/gnn/data_loading/dist_sampler.py
index 537e38cf780..c3948e73da8 100644
--- a/python/cugraph/cugraph/gnn/data_loading/dist_sampler.py
+++ b/python/cugraph/cugraph/gnn/data_loading/dist_sampler.py
@@ -58,7 +58,6 @@ def __write_minibatches_coo(self, minibatch_dict):
         has_edge_types = minibatch_dict['edge_type'] is not None
         has_weights = minibatch_dict['weight'] is not None
 
-        print(minibatch_dict)
         if minibatch_dict['renumber_map'] is None:
             raise ValueError("Distributed sampling without renumbering is not supported")
 
@@ -120,7 +119,7 @@ def __write_minibatches_coo(self, minibatch_dict):
         if 'rank' in minibatch_dict:
             rank = minibatch_dict['rank']
             full_output_path = os.path.join(
-                self.__directory, f"batch={rank:05d}{start_batch_id:08d}-{rank:05d}{end_batch_id:08d}.parquet"
+                self.__directory, f"batch={rank:05d}.{start_batch_id:08d}-{rank:05d}.{end_batch_id:08d}.parquet"
             )
         else:
             full_output_path = os.path.join(