diff --git a/benchmarks/cugraph/standalone/bulk_sampling/bench_cugraph_training.py b/benchmarks/cugraph/standalone/bulk_sampling/bench_cugraph_training.py index 9cc3f3cb0df..3f1b61176e0 100644 --- a/benchmarks/cugraph/standalone/bulk_sampling/bench_cugraph_training.py +++ b/benchmarks/cugraph/standalone/bulk_sampling/bench_cugraph_training.py @@ -207,7 +207,7 @@ def main(args): from pylibwholegraph.torch.initialize import ( get_global_communicator, get_local_node_communicator, - init + init, ) logger.info("initializing WG comms...") diff --git a/benchmarks/cugraph/standalone/bulk_sampling/datasets/ogbn_papers100M.py b/benchmarks/cugraph/standalone/bulk_sampling/datasets/ogbn_papers100M.py index b971cc476fb..a6e2b6a83bf 100644 --- a/benchmarks/cugraph/standalone/bulk_sampling/datasets/ogbn_papers100M.py +++ b/benchmarks/cugraph/standalone/bulk_sampling/datasets/ogbn_papers100M.py @@ -270,7 +270,7 @@ def __load_x_torch(self) -> None: def __load_x_wg(self) -> None: import logging - + logger = logging.getLogger("OGBNPapers100MDataset") logger.info("Loading x into WG embedding...") diff --git a/benchmarks/cugraph/standalone/bulk_sampling/trainers/dgl/trainers_cugraph_dgl.py b/benchmarks/cugraph/standalone/bulk_sampling/trainers/dgl/trainers_cugraph_dgl.py index a090156bc1a..5a5cad1fd36 100644 --- a/benchmarks/cugraph/standalone/bulk_sampling/trainers/dgl/trainers_cugraph_dgl.py +++ b/benchmarks/cugraph/standalone/bulk_sampling/trainers/dgl/trainers_cugraph_dgl.py @@ -288,12 +288,22 @@ def get_input_files(self, path, epoch=0, stage="train"): np.random.seed(epoch) np.random.shuffle(splits) - ex = re.compile(r'batch=([0-9]+)\-([0-9]+).parquet') - num_batches = min([ - sum([int(ex.match(fname.split('/')[-1])[2]) - int(ex.match(fname.split('/')[-1])[1]) for fname in s]) - for s in splits - ]) + ex = re.compile(r"batch=([0-9]+)\-([0-9]+).parquet") + num_batches = min( + [ + sum( + [ + int(ex.match(fname.split("/")[-1])[2]) + - int(ex.match(fname.split("/")[-1])[1]) + for fname in s + ] + ) + for s in splits + ] + ) if num_batches == 0: - raise ValueError(f"Too few batches for training with world size {self.__world_size}") + raise ValueError( + f"Too few batches for training with world size {self.__world_size}" + ) return splits[self.rank], num_batches diff --git a/benchmarks/cugraph/standalone/bulk_sampling/trainers/dgl/trainers_dgl.py b/benchmarks/cugraph/standalone/bulk_sampling/trainers/dgl/trainers_dgl.py index d9f9130024d..2529295e4a4 100644 --- a/benchmarks/cugraph/standalone/bulk_sampling/trainers/dgl/trainers_dgl.py +++ b/benchmarks/cugraph/standalone/bulk_sampling/trainers/dgl/trainers_dgl.py @@ -98,7 +98,16 @@ def log_batch( def train_epoch( - model, optimizer, loader, feature_store, epoch, num_classes, time_d, logger, rank, max_num_batches + model, + optimizer, + loader, + feature_store, + epoch, + num_classes, + time_d, + logger, + rank, + max_num_batches, ): """ Train the model for one epoch. @@ -185,7 +194,7 @@ def train_epoch( epoch=epoch, rank=rank, ) - + if max_num_batches is not None and iter_i >= max_num_batches: break @@ -319,7 +328,7 @@ def train(self): loader=loader, feature_store=self.data, num_classes=self.dataset.num_labels, - max_num_batches = max_num_batches, + max_num_batches=max_num_batches, ) print(f"Accuracy: {test_acc:.4f}%") diff --git a/benchmarks/cugraph/standalone/bulk_sampling/trainers/pyg/trainers_cugraph_pyg.py b/benchmarks/cugraph/standalone/bulk_sampling/trainers/pyg/trainers_cugraph_pyg.py index a59eec17c82..82daba431c8 100644 --- a/benchmarks/cugraph/standalone/bulk_sampling/trainers/pyg/trainers_cugraph_pyg.py +++ b/benchmarks/cugraph/standalone/bulk_sampling/trainers/pyg/trainers_cugraph_pyg.py @@ -258,19 +258,24 @@ def get_input_files(self, path, epoch=0, stage="train"): np.random.shuffle(file_list) splits = np.array_split(file_list, self.__world_size) - + import logging - logger = logging.getLogger('PyGCuGraphTrainer') + + logger = logging.getLogger("PyGCuGraphTrainer") logger.info(f"rank {self.rank} input files: {str(splits[self.rank])}") split = splits[self.rank] - ex = re.compile(r'batch=([0-9]+)\-([0-9]+).parquet') - num_batches = min([ - sum([int(ex.match(fname)[2]) - int(ex.match(fname)[1]) for fname in s]) - for s in splits - ]) + ex = re.compile(r"batch=([0-9]+)\-([0-9]+).parquet") + num_batches = min( + [ + sum([int(ex.match(fname)[2]) - int(ex.match(fname)[1]) for fname in s]) + for s in splits + ] + ) if num_batches == 0: - raise ValueError(f"Too few batches for training with world size {self.__world_size}") + raise ValueError( + f"Too few batches for training with world size {self.__world_size}" + ) return split, num_batches diff --git a/benchmarks/cugraph/standalone/bulk_sampling/trainers/pyg/trainers_pyg.py b/benchmarks/cugraph/standalone/bulk_sampling/trainers/pyg/trainers_pyg.py index 21a315da678..d8ca9f8bca7 100644 --- a/benchmarks/cugraph/standalone/bulk_sampling/trainers/pyg/trainers_pyg.py +++ b/benchmarks/cugraph/standalone/bulk_sampling/trainers/pyg/trainers_pyg.py @@ -51,7 +51,10 @@ def pyg_num_workers(world_size: int) -> int: def calc_accuracy( - loader: NeighborLoader, max_num_batches: int, model: torch.nn.Module, num_classes: int + loader: NeighborLoader, + max_num_batches: int, + model: torch.nn.Module, + num_classes: int, ) -> float: """ Evaluates the accuracy of a model given a loader over evaluation samples. @@ -131,16 +134,14 @@ def train(self): end_time_backward = start_time num_layers = len(self.model.module.convs) - + for epoch in range(self.num_epochs): with td.algorithms.join.Join( [self.model, self.optimizer], divide_by_initial_world_size=False ): self.model.train() loader, max_num_batches = self.get_loader(epoch=epoch, stage="train") - for iter_i, data in enumerate( - loader - ): + for iter_i, data in enumerate(loader): loader_time_iter = time.perf_counter() - end_time_backward time_loader += loader_time_iter @@ -247,7 +248,9 @@ def train(self): loader, max_num_batches = self.get_loader(epoch=epoch, stage="test") num_classes = self.dataset.num_labels - acc = calc_accuracy(loader, max_num_batches, self.model.module, num_classes) + acc = calc_accuracy( + loader, max_num_batches, self.model.module, num_classes + ) if self.rank == 0: print(