diff --git a/benchmarks/cugraph/standalone/bulk_sampling/README.md b/benchmarks/cugraph/standalone/bulk_sampling/README.md index bb01133c52f..a837f309139 100644 --- a/benchmarks/cugraph/standalone/bulk_sampling/README.md +++ b/benchmarks/cugraph/standalone/bulk_sampling/README.md @@ -152,7 +152,7 @@ Next are standard GNN training arguments such as `FANOUT`, `BATCH_SIZE`, etc. Y the number of training epochs here. These are followed by the `REPLICATION_FACTOR` argument, which can be used to create replications of the dataset for scale testing purposes. -The final two arguments are `FRAMEWORK` which can be either "cuGraphPyG" or "PyG", and `GPUS_PER_NODE` +The final two arguments are `FRAMEWORK` which can be "cuGraphDGL", "cuGraphPyG" or "PyG", and `GPUS_PER_NODE` which must be set to the correct value, even if this is provided by a SLURM argument. If `GPUS_PER_NODE` is not set to the correct number of GPUs, the script will hang indefinitely until it times out. Mismatched GPUs per node is currently unsupported by this script but should be possible in practice. diff --git a/benchmarks/cugraph/standalone/bulk_sampling/bench_cugraph_training.py b/benchmarks/cugraph/standalone/bulk_sampling/bench_cugraph_training.py index b9dbe641a4f..dff2d310058 100644 --- a/benchmarks/cugraph/standalone/bulk_sampling/bench_cugraph_training.py +++ b/benchmarks/cugraph/standalone/bulk_sampling/bench_cugraph_training.py @@ -235,6 +235,25 @@ def main(args): num_neighbors=fanout, batch_size=args.batch_size, ) + elif args.framework == "cuGraphDGL": + sample_dir = os.path.join( + args.sample_dir, + f"ogbn_papers100M[{args.replication_factor}]_b{args.batch_size}_f{fanout}", + ) + from trainers.dgl import DGLCuGraphTrainer + trainer = DGLCuGraphTrainer( + model=args.model, + dataset=dataset, + sample_dir=args.sample_dir, + device=local_rank, + rank=global_rank, + world_size=world_size, + num_epochs=args.num_epochs, + shuffle=True, + replace=False, + num_neighbors=[int(f) for f in args.fanout.split("_")], + batch_size=args.batch_size, + ) else: raise ValueError("unsupported framework") diff --git a/benchmarks/cugraph/standalone/bulk_sampling/models/dgl/__init__.py b/benchmarks/cugraph/standalone/bulk_sampling/models/dgl/__init__.py new file mode 100644 index 00000000000..66386d392f4 --- /dev/null +++ b/benchmarks/cugraph/standalone/bulk_sampling/models/dgl/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .models_dgl import GraphSAGE \ No newline at end of file diff --git a/benchmarks/cugraph/standalone/bulk_sampling/models/dgl/models_dgl.py b/benchmarks/cugraph/standalone/bulk_sampling/models/dgl/models_dgl.py new file mode 100644 index 00000000000..f56965be783 --- /dev/null +++ b/benchmarks/cugraph/standalone/bulk_sampling/models/dgl/models_dgl.py @@ -0,0 +1,58 @@ +# Copyright (c) 2018-2023, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn.functional as F + + +class GraphSAGE(torch.nn.Module): + def __init__( + self, + in_channels, + hidden_channels, + out_channels, + num_layers, + model_backend="dgl", + ): + if model_backend == "dgl": + from dgl.nn import SAGEConv + else: + from cugraph_dgl.nn import SAGEConv + + super(GraphSAGE, self).__init__() + self.convs = torch.nn.ModuleList() + for _ in range(num_layers - 1): + self.convs.append( + SAGEConv(in_channels, hidden_channels, aggregator_type="mean") + ) + in_channels = hidden_channels + self.convs.append( + SAGEConv(hidden_channels, out_channels, aggregator_type="mean") + ) + + def forward(self, blocks, x): + for i, conv in enumerate(self.convs): + x = conv(blocks[i], x) + if i != len(self.convs) - 1: + x = F.relu(x) + x = F.dropout(x, p=0.5) + return x + + +def create_model(feat_size, num_classes, num_layers, model_backend="dgl"): + model = GraphSAGE( + feat_size, 64, num_classes, num_layers, model_backend=model_backend + ) + model = model.to("cuda") + model.train() + return model \ No newline at end of file diff --git a/benchmarks/cugraph/standalone/bulk_sampling/trainers/dgl/__init__.py b/benchmarks/cugraph/standalone/bulk_sampling/trainers/dgl/__init__.py new file mode 100644 index 00000000000..a7c6712f438 --- /dev/null +++ b/benchmarks/cugraph/standalone/bulk_sampling/trainers/dgl/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .trainers_dgl import DGLTrainer +from .trainers_cugraph_dgl import DGLCuGraphTrainer \ No newline at end of file diff --git a/benchmarks/cugraph/standalone/bulk_sampling/trainers/dgl/trainers_cugraph_dgl.py b/benchmarks/cugraph/standalone/bulk_sampling/trainers/dgl/trainers_cugraph_dgl.py new file mode 100644 index 00000000000..8d892f0b06e --- /dev/null +++ b/benchmarks/cugraph/standalone/bulk_sampling/trainers/dgl/trainers_cugraph_dgl.py @@ -0,0 +1,183 @@ +# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import time + +os.environ["LIBCUDF_CUFILE_POLICY"] = "KVIKIO" +os.environ["KVIKIO_NTHREADS"] = "32" +os.environ["RAPIDS_NO_INITIALIZE"] = "1" + +from .trainers_dgl import DGLTrainer +from models.dgl import GraphSAGE + +import torch +import numpy as np +import warnings + +from torch.nn.parallel import DistributedDataParallel as ddp +from cugraph_dgl.dataloading import HomogenousBulkSamplerDataset +from cugraph.gnn import FeatureStore + + +def get_dataloader(input_file_paths, total_num_nodes, sparse_format, return_type): + print("Creating dataloader", flush=True) + st = time.time() + dataset = HomogenousBulkSamplerDataset( + total_num_nodes, + edge_dir="in", + sparse_format=sparse_format, + return_type=return_type, + ) + dataset.set_input_files(input_file_paths=input_file_paths) + dataloader = torch.utils.data.DataLoader( + dataset, collate_fn=lambda x: x, shuffle=False, num_workers=0, batch_size=None + ) + et = time.time() + print(f"Time to create dataloader = {et - st:.2f} seconds", flush=True) + return dataloader + + +class DGLCuGraphTrainer(DGLTrainer): + def __init__( + self, + dataset, + model="GraphSAGE", + device=0, + rank=0, + world_size=1, + num_epochs=1, + sample_dir=".", + **kwargs, + ): + self.__data = None + self.__device = device + self.__rank = rank + self.__world_size = world_size + self.__num_epochs = num_epochs + self.__dataset = dataset + self.__sample_dir = sample_dir + self.__loader_kwargs = kwargs + self.__model = self.get_model(model) + + @property + def rank(self): + return self.__rank + + @property + def model(self): + return self.__model + + @property + def dataset(self): + return self.__dataset + + @property + def optimizer(self): + if self.__optimizer is None: + self.__optimizer = torch.optim.Adam( + self.model.parameters(), lr=0.01, weight_decay=0.0005 + ) + return self.__optimizer + + @property + def num_epochs(self) -> int: + return self.__num_epochs + + def get_loader(self, epoch: int = 0, stage="train") -> int: + # TODO support online sampling + if stage in ["val", "test"]: + path = os.path.join(self.__sample_dir, stage, "samples") + else: + path = os.path.join(self.__sample_dir, f"epoch={epoch}", stage, "samples") + + dataloader = get_dataloader( + input_file_paths=self.get_input_files(path), + total_num_nodes=None, + sparse_format="csc", + return_type="cugraph_dgl.nn.SparseGraph", + ) + return dataloader + + @property + def data(self): + import logging + + logger = logging.getLogger("DGLCuGraphTrainer") + logger.info("getting data") + + if self.__data is None: + # FIXME wholegraph + fs = FeatureStore(backend="torch") + num_nodes_dict = {} + + for node_type, x in self.__dataset.x_dict.items(): + logger.debug(f"getting x for {node_type}") + fs.add_data(x, node_type, "x") + num_nodes_dict[node_type] = self.__dataset.num_nodes(node_type) + + for node_type, y in self.__dataset.y_dict.items(): + logger.debug(f"getting y for {node_type}") + fs.add_data(y, node_type, "y") + + for node_type, train in self.__dataset.train_dict.items(): + logger.debug(f"getting train for {node_type}") + fs.add_data(train, node_type, "train") + + for node_type, test in self.__dataset.test_dict.items(): + logger.debug(f"getting test for {node_type}") + fs.add_data(test, node_type, "test") + + for node_type, val in self.__dataset.val_dict.items(): + logger.debug(f"getting val for {node_type}") + fs.add_data(val, node_type, "val") + + # # TODO support online sampling if the edge index is provided + # num_edges_dict = self.__dataset.edge_index_dict + # if not isinstance(list(num_edges_dict.values())[0], int): + # num_edges_dict = {k: len(v) for k, v in num_edges_dict} + + self.__data = fs + return self.__data + + def get_model(self, name="GraphSAGE"): + if name != "GraphSAGE": + raise ValueError("only GraphSAGE is currently supported") + + num_input_features = self.__dataset.num_input_features + num_output_features = self.__dataset.num_labels + num_layers = len(self.__loader_kwargs["num_neighbors"]) + + with torch.cuda.device(self.__device): + model = ( + GraphSAGE( + in_channels=num_input_features, + hidden_channels=64, + out_channels=num_output_features, + num_layers=num_layers, + model_backend="cugraph_dgl", + ) + .to(torch.float32) + .to(self.__device) + ) + # TODO: Fix for distributed models + if torch.distributed.is_initialized(): + model = ddp(model, device_ids=[self.__device]) + else: + warnings.warn("Distributed training is not available") + print("done creating model") + + return model + + def get_input_files(self, path): + file_list = [entry.path for entry in os.scandir(path)] + return np.array_split(file_list, self.__world_size)[self.__rank] \ No newline at end of file diff --git a/benchmarks/cugraph/standalone/bulk_sampling/trainers/dgl/trainers_dgl.py b/benchmarks/cugraph/standalone/bulk_sampling/trainers/dgl/trainers_dgl.py new file mode 100644 index 00000000000..3beeccba22d --- /dev/null +++ b/benchmarks/cugraph/standalone/bulk_sampling/trainers/dgl/trainers_dgl.py @@ -0,0 +1,264 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.distributed as td +import torch.nn.functional as F +from torch.nn.parallel import DistributedDataParallel as ddp + +from trainers import Trainer + +from models.dgl import GraphSAGE + +import time +import warnings + + +def get_features(input_nodes, output_nodes, feature_store, key="paper"): + if isinstance(input_nodes, dict): + input_nodes = input_nodes[key] + if isinstance(output_nodes, dict): + output_nodes = output_nodes[key] + + # TODO: Fix below + # Adding based on assumption that cpu features + # and gpu index is not supported yet + + if feature_store.backend == "torch": + input_nodes = input_nodes.to("cpu") + output_nodes = output_nodes.to("cpu") + + x = feature_store.get_data(indices=input_nodes, type_name=key, feat_name="x") + y = feature_store.get_data(indices=output_nodes, type_name=key, feat_name="y") + return x, y + + +def log_batch( + logger, + iter_i, + num_batches, + time_forward, + time_backward, + time_start, + loader_time_iter, + epoch, + rank, +): + time_forward_iter = time_forward / num_batches + time_backward_iter = time_backward / num_batches + total_time_iter = (time.perf_counter() - time_start) / num_batches + logger.info(f"epoch {epoch}, iteration {iter_i}, rank {rank}") + logger.info(f"time forward: {time_forward_iter}") + logger.info(f"time backward: {time_backward_iter}") + logger.info(f"loader time: {loader_time_iter}") + logger.info(f"total time: {total_time_iter}") + + +def train_epoch( + model, optimizer, loader, feature_store, epoch, num_classes, time_d, logger, rank +): + """ + Train the model for one epoch. + model: The model to train. + optimizer: The optimizer to use. + loader: The loader to use. + data: cuGraph.gnn.FeatueStore + epoch: The epoch number. + num_classes: The number of classes. + time_d: A dictionary of times. + logger: The logger to use. + rank: Total rank + """ + model = model.train() + time_feature_indexing = time_d["time_feature_indexing"] + time_feature_transfer = time_d["time_feature_transfer"] + time_forward = time_d["time_forward"] + time_backward = time_d["time_backward"] + time_loader = time_d["time_loader"] + + time_start = time.perf_counter() + end_time_backward = time.perf_counter() + + num_batches = 0 + total_loss = 0.0 + + for iter_i, (input_nodes, output_nodes, blocks) in enumerate(loader): + loader_time_iter = time.perf_counter() - end_time_backward + time_loader += loader_time_iter + feature_indexing_time_start = time.perf_counter() + x, y_true = get_features(input_nodes, output_nodes, feature_store=feature_store) + additional_feature_time_end = time.perf_counter() + time_feature_indexing += ( + additional_feature_time_end - feature_indexing_time_start + ) + feature_trasfer_time_start = time.perf_counter() + x = x.to("cuda") + y_true = y_true.to("cuda") + time_feature_transfer += time.perf_counter() - feature_trasfer_time_start + num_batches += 1 + + start_time_forward = time.perf_counter() + y_pred = model( + blocks, + x, + ) + end_time_forward = time.perf_counter() + time_forward += end_time_forward - start_time_forward + + if y_pred.shape[0] > len(y_true): + raise ValueError(f"illegal shape: {y_pred.shape}; {y_true.shape}") + + y_true = y_true[: y_pred.shape[0]] + y_true = F.one_hot( + y_true.to(torch.int64), + num_classes=num_classes, + ).to(torch.float32) + + if y_true.shape != y_pred.shape: + raise ValueError( + f"y_true shape was {y_true.shape} " + f"but y_pred shape was {y_pred.shape} " + f"in iteration {iter_i} " + f"on rank {y_pred.device.index}" + ) + + start_time_backward = time.perf_counter() + loss = F.cross_entropy(y_pred, y_true) + optimizer.zero_grad() + loss.backward() + optimizer.step() + total_loss += loss.item() + end_time_backward = time.perf_counter() + time_backward += end_time_backward - start_time_backward + + if iter_i % 50 == 0: + log_batch( + logger=logger, + iter_i=iter_i, + num_batches=num_batches, + time_forward=time_forward, + time_backward=time_backward, + time_start=time_start, + loader_time_iter=loader_time_iter, + epoch=epoch, + rank=rank, + ) + + time_d["time_loader"] += time_loader + time_d["time_feature_indexing"] += time_feature_indexing + time_d["time_feature_transfer"] += time_feature_transfer + time_d["time_forward"] += time_forward + time_d["time_backward"] += time_backward + + return num_batches, total_loss + + +def get_accuracy(model, loader, feature_store, num_classes): + from torchmetrics import Accuracy + + acc = Accuracy(task="multiclass", num_classes=num_classes).cuda() + model = model.eval() + acc_sum = 0.0 + with torch.no_grad(): + for iter_i, (input_nodes, output_nodes, blocks) in enumerate(loader): + x, y_true = get_features( + input_nodes, output_nodes, feature_store=feature_store + ) + x = x.to("cuda") + y_true = y_true.to("cuda") + + out = model(blocks, x) + batch_size = out.shape[0] + acc_sum += acc(out[:batch_size].softmax(dim=-1), y_true[:batch_size]) + # TODO: Dont know why we are averaging on batch size + num_batches = iter_i + print( + f"Accuracy: {acc_sum/(num_batches) * 100.0:.4f}%", + ) + return acc_sum / (num_batches) * 100.0 + + +class DGLTrainer(Trainer): + def train(self): + import logging + + logger = logging.getLogger("DGLTrainer") + time_d = { + "time_loader": 0.0, + "time_feature_indexing": 0.0, + "time_feature_transfer": 0.0, + "time_forward": 0.0, + "time_backward": 0.0, + } + total_batches = 0 + for epoch in range(self.num_epochs): + start_time = time.perf_counter() + with td.algorithms.join.Join([self.model, self.optimizer]): + num_batches, total_loss = train_epoch( + model=self.model, + optimizer=self.optimizer, + loader=self.get_loader(epoch=epoch, stage="train"), + feature_store=self.data, + num_classes=self.dataset.num_labels, + epoch=epoch, + time_d=time_d, + logger=logger, + rank=self.rank, + ) + total_batches = total_batches + num_batches + end_time = time.perf_counter() + epoch_time_taken = end_time - start_time + print( + f"RANK: {self.rank} Total time taken for training epoch {epoch} = {epoch_time_taken}", + flush=True, + ) + print("---" * 30) + td.barrier() + with td.algorithms.join.Join([self.model, self.optimizer]): + # val + if self.rank == 0: + validation_acc = get_accuracy( + model=self.model, + loader=self.get_loader(epoch=epoch, stage="val"), + feature_store=self.data, + num_classes=self.dataset.num_labels, + ) + print(f"Validation Accuracy: {validation_acc:.4f}%") + else: + validation_acc = 0.0 + + # test: + with td.algorithms.join.Join([self.model, self.optimizer]): + if self.rank == 0: + test_acc = get_accuracy( + model=self.model, + loader=self.get_loader(epoch=epoch, stage="test"), + feature_store=self.data, + num_classes=self.dataset.num_labels, + ) + print(f"Test Accuracy: {validation_acc:.4f}%") + else: + test_acc = 0.0 + test_acc = float(test_acc) + stats = { + "Accuracy": test_acc if self.rank == 0 else 0.0, + "# Batches": total_batches, + "Loader Time": time_d["time_loader"], + "Feature Time": time_d["time_feature_indexing"] + + time_d["time_feature_transfer"], + "Forward Time": time_d["time_forward"], + "Backward Time": time_d["time_backward"], + } + return stats + +# For native DGL training, see benchmarks/cugraph-dgl/scale-benchmarks \ No newline at end of file