diff --git a/benchmarks/cugraph/standalone/bulk_sampling/README.md b/benchmarks/cugraph/standalone/bulk_sampling/README.md index bb01133c52f..a837f309139 100644 --- a/benchmarks/cugraph/standalone/bulk_sampling/README.md +++ b/benchmarks/cugraph/standalone/bulk_sampling/README.md @@ -152,7 +152,7 @@ Next are standard GNN training arguments such as `FANOUT`, `BATCH_SIZE`, etc. Y the number of training epochs here. These are followed by the `REPLICATION_FACTOR` argument, which can be used to create replications of the dataset for scale testing purposes. -The final two arguments are `FRAMEWORK` which can be either "cuGraphPyG" or "PyG", and `GPUS_PER_NODE` +The final two arguments are `FRAMEWORK` which can be "cuGraphDGL", "cuGraphPyG" or "PyG", and `GPUS_PER_NODE` which must be set to the correct value, even if this is provided by a SLURM argument. If `GPUS_PER_NODE` is not set to the correct number of GPUs, the script will hang indefinitely until it times out. Mismatched GPUs per node is currently unsupported by this script but should be possible in practice. diff --git a/benchmarks/cugraph/standalone/bulk_sampling/bench_cugraph_training.py b/benchmarks/cugraph/standalone/bulk_sampling/bench_cugraph_training.py index c9e347b261d..728902e3981 100644 --- a/benchmarks/cugraph/standalone/bulk_sampling/bench_cugraph_training.py +++ b/benchmarks/cugraph/standalone/bulk_sampling/bench_cugraph_training.py @@ -123,6 +123,13 @@ def parse_args(): required=True, ) + parser.add_argument( + "--use_wholegraph", + action="store_true", + help="Whether to use WholeGraph feature storage", + required=False, + ) + parser.add_argument( "--model", type=str, @@ -162,6 +169,13 @@ def parse_args(): required=False, ) + parser.add_argument( + "--skip_download", + action="store_true", + help="Whether to skip downloading", + required=False, + ) + return parser.parse_args() @@ -186,16 +200,36 @@ def main(args): world_size = int(os.environ["SLURM_JOB_NUM_NODES"]) * args.gpus_per_node + if args.use_wholegraph: + # TODO support WG without cuGraph + if args.framework not in ["cuGraphPyG", "cuGraphDGL"]: + raise ValueError("WG feature store only supported with cuGraph backends") + from pylibwholegraph.torch.initialize import ( + get_global_communicator, + get_local_node_communicator, + ) + + logger.info("initializing WG comms...") + wm_comm = get_global_communicator() + get_local_node_communicator() + + wm_comm = wm_comm.wmb_comm + logger.info(f"rank {global_rank} successfully initialized WG comms") + wm_comm.barrier() + dataset = OGBNPapers100MDataset( replication_factor=args.replication_factor, dataset_dir=args.dataset_dir, train_split=args.train_split, val_split=args.val_split, load_edge_index=(args.framework == "PyG"), + backend="wholegraph" if args.use_wholegraph else "torch", ) - if global_rank == 0: + # Note: this does not generate WG files + if global_rank == 0 and not args.skip_download: dataset.download() + dist.barrier() fanout = [int(f) for f in args.fanout.split("_")] @@ -234,6 +268,28 @@ def main(args): replace=False, num_neighbors=fanout, batch_size=args.batch_size, + backend="wholegraph" if args.use_wholegraph else "torch", + ) + elif args.framework == "cuGraphDGL": + sample_dir = os.path.join( + args.sample_dir, + f"ogbn_papers100M[{args.replication_factor}]_b{args.batch_size}_f{fanout}", + ) + from trainers.dgl import DGLCuGraphTrainer + + trainer = DGLCuGraphTrainer( + model=args.model, + dataset=dataset, + sample_dir=sample_dir, + device=local_rank, + rank=global_rank, + world_size=world_size, + num_epochs=args.num_epochs, + shuffle=True, + replace=False, + num_neighbors=[int(f) for f in args.fanout.split("_")], + batch_size=args.batch_size, + backend="wholegraph" if args.use_wholegraph else "torch", ) else: raise ValueError("unsupported framework") diff --git a/benchmarks/cugraph/standalone/bulk_sampling/cugraph_bulk_sampling.py b/benchmarks/cugraph/standalone/bulk_sampling/cugraph_bulk_sampling.py index e3a5bba3162..b1f54c924fe 100644 --- a/benchmarks/cugraph/standalone/bulk_sampling/cugraph_bulk_sampling.py +++ b/benchmarks/cugraph/standalone/bulk_sampling/cugraph_bulk_sampling.py @@ -200,19 +200,20 @@ def sample_graph( total_time = 0.0 for epoch in range(num_epochs): - steps = [("train", train_df), ("test", test_df)] + steps = [("train", train_df)] if epoch == num_epochs - 1: steps.append(("val", val_df)) + steps.append(("test", test_df)) for step, batch_df in steps: batch_df = batch_df.sample(frac=1.0, random_state=seed) - if step == "val": - output_sample_path = os.path.join(output_path, "val", "samples") - else: + if step == "train": output_sample_path = os.path.join( output_path, f"epoch={epoch}", f"{step}", "samples" ) + else: + output_sample_path = os.path.join(output_path, step, "samples") os.makedirs(output_sample_path) sampler = BulkSampler( @@ -372,7 +373,7 @@ def load_disk_dataset( can_edge_type = tuple(edge_type.split("__")) edge_index_dict[can_edge_type] = dask_cudf.read_parquet( Path(parquet_path) / edge_type / "edge_index.parquet" - ).repartition(n_workers * 2) + ).repartition(npartitions=n_workers * 2) edge_index_dict[can_edge_type]["src"] += node_offsets_replicated[ can_edge_type[0] @@ -431,7 +432,7 @@ def load_disk_dataset( if os.path.exists(node_label_path): node_labels[node_type] = ( dask_cudf.read_parquet(node_label_path) - .repartition(n_workers) + .repartition(npartitions=n_workers) .drop("label", axis=1) .persist() ) @@ -574,7 +575,7 @@ def benchmark_cugraph_bulk_sampling( "use_legacy_names": False, "include_hop_column": False, } - else: + elif sampling_target_framework == "cugraph_pyg": # FIXME: Update these arguments when CSC mode is fixed in cuGraph-PyG (release 24.02) sampling_kwargs = { "deduplicate_sources": True, @@ -585,6 +586,8 @@ def benchmark_cugraph_bulk_sampling( "use_legacy_names": False, "include_hop_column": True, } + else: + raise ValueError("Only cugraph_dgl_csr or cugraph_pyg are valid frameworks") batches_per_partition = 600_000 // batch_size execution_time, allocation_counts = sample_graph( @@ -761,9 +764,9 @@ def get_args(): logger.setLevel(logging.INFO) args = get_args() - if args.sampling_target_framework not in ["cugraph_dgl_csr", None]: + if args.sampling_target_framework not in ["cugraph_dgl_csr", "cugraph_pyg"]: raise ValueError( - "sampling_target_framework must be one of cugraph_dgl_csr or None", + "sampling_target_framework must be one of cugraph_dgl_csr or cugraph_pyg", "Other frameworks are not supported at this time.", ) diff --git a/benchmarks/cugraph/standalone/bulk_sampling/datasets/ogbn_papers100M.py b/benchmarks/cugraph/standalone/bulk_sampling/datasets/ogbn_papers100M.py index a50e40f6d55..0299d9cb8ba 100644 --- a/benchmarks/cugraph/standalone/bulk_sampling/datasets/ogbn_papers100M.py +++ b/benchmarks/cugraph/standalone/bulk_sampling/datasets/ogbn_papers100M.py @@ -24,6 +24,10 @@ import os import json +from cugraph.utilities.utils import import_optional + +wgth = import_optional("pylibwholegraph.torch") + class OGBNPapers100MDataset(Dataset): def __init__( @@ -34,6 +38,7 @@ def __init__( train_split=0.8, val_split=0.5, load_edge_index=True, + backend="torch", ): self.__replication_factor = replication_factor self.__disk_x = None @@ -43,6 +48,7 @@ def __init__( self.__train_split = train_split self.__val_split = val_split self.__load_edge_index = load_edge_index + self.__backend = backend def download(self): import logging @@ -152,6 +158,27 @@ def download(self): ) ldf.to_parquet(node_label_file_path) + # WholeGraph + wg_bin_file_path = os.path.join(dataset_path, "wgb", "paper") + if self.__replication_factor == 1: + wg_bin_rep_path = os.path.join(wg_bin_file_path, "node_feat.d") + else: + wg_bin_rep_path = os.path.join( + wg_bin_file_path, f"node_feat_{self.__replication_factor}x.d" + ) + + if not os.path.exists(wg_bin_rep_path): + os.makedirs(wg_bin_rep_path) + if dataset is None: + from ogb.nodeproppred import NodePropPredDataset + + dataset = NodePropPredDataset( + name="ogbn-papers100M", root=self.__dataset_dir + ) + node_feat = dataset[0][0]["node_feat"] + for k in range(self.__replication_factor): + node_feat.tofile(os.path.join(wg_bin_rep_path, f"{k:04d}.bin")) + @property def edge_index_dict( self, @@ -224,21 +251,52 @@ def edge_index_dict( @property def x_dict(self) -> Dict[str, torch.Tensor]: + if self.__disk_x is None: + if self.__backend == "wholegraph": + self.__load_x_wg() + else: + self.__load_x_torch() + + return self.__disk_x + + def __load_x_torch(self) -> None: node_type_path = os.path.join( self.__dataset_dir, "ogbn_papers100M", "npy", "paper" ) + if self.__replication_factor == 1: + full_path = os.path.join(node_type_path, "node_feat.npy") + else: + full_path = os.path.join( + node_type_path, f"node_feat_{self.__replication_factor}x.npy" + ) - if self.__disk_x is None: - if self.__replication_factor == 1: - full_path = os.path.join(node_type_path, "node_feat.npy") - else: - full_path = os.path.join( - node_type_path, f"node_feat_{self.__replication_factor}x.npy" - ) + self.__disk_x = {"paper": torch.as_tensor(np.load(full_path, mmap_mode="r"))} - self.__disk_x = {"paper": np.load(full_path, mmap_mode="r")} + def __load_x_wg(self) -> None: + node_type_path = os.path.join( + self.__dataset_dir, "ogbn_papers100M", "wgb", "paper" + ) + if self.__replication_factor == 1: + full_path = os.path.join(node_type_path, "node_feat.d") + else: + full_path = os.path.join( + node_type_path, f"node_feat_{self.__replication_factor}x.d" + ) - return self.__disk_x + file_list = [os.path.join(full_path, f) for f in os.listdir(full_path)] + + x = wgth.create_embedding_from_filelist( + wgth.get_global_communicator(), + "chunked", # TODO support other options + "cpu", # TODO support GPU + file_list, + torch.float32, + 128, + ) + + print("created x wg embedding", x) + + self.__disk_x = {"paper": x} @property def y_dict(self) -> Dict[str, torch.Tensor]: diff --git a/benchmarks/cugraph/standalone/bulk_sampling/models/dgl/__init__.py b/benchmarks/cugraph/standalone/bulk_sampling/models/dgl/__init__.py new file mode 100644 index 00000000000..610a7648801 --- /dev/null +++ b/benchmarks/cugraph/standalone/bulk_sampling/models/dgl/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .models_dgl import GraphSAGE diff --git a/benchmarks/cugraph/standalone/bulk_sampling/models/dgl/models_dgl.py b/benchmarks/cugraph/standalone/bulk_sampling/models/dgl/models_dgl.py new file mode 100644 index 00000000000..38558439516 --- /dev/null +++ b/benchmarks/cugraph/standalone/bulk_sampling/models/dgl/models_dgl.py @@ -0,0 +1,58 @@ +# Copyright (c) 2018-2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn.functional as F + + +class GraphSAGE(torch.nn.Module): + def __init__( + self, + in_channels, + hidden_channels, + out_channels, + num_layers, + model_backend="dgl", + ): + if model_backend == "dgl": + from dgl.nn import SAGEConv + else: + from cugraph_dgl.nn import SAGEConv + + super(GraphSAGE, self).__init__() + self.convs = torch.nn.ModuleList() + for _ in range(num_layers - 1): + self.convs.append( + SAGEConv(in_channels, hidden_channels, aggregator_type="mean") + ) + in_channels = hidden_channels + self.convs.append( + SAGEConv(hidden_channels, out_channels, aggregator_type="mean") + ) + + def forward(self, blocks, x): + for i, conv in enumerate(self.convs): + x = conv(blocks[i], x) + if i != len(self.convs) - 1: + x = F.relu(x) + x = F.dropout(x, p=0.5) + return x + + +def create_model(feat_size, num_classes, num_layers, model_backend="dgl"): + model = GraphSAGE( + feat_size, 64, num_classes, num_layers, model_backend=model_backend + ) + model = model.to("cuda") + model.train() + return model diff --git a/benchmarks/cugraph/standalone/bulk_sampling/run_sampling.sh b/benchmarks/cugraph/standalone/bulk_sampling/run_sampling.sh index 41792c0b63a..6dfeebec0ae 100644 --- a/benchmarks/cugraph/standalone/bulk_sampling/run_sampling.sh +++ b/benchmarks/cugraph/standalone/bulk_sampling/run_sampling.sh @@ -21,6 +21,7 @@ FANOUT=$2 REPLICATION_FACTOR=$3 SCRIPTS_DIR=$4 NUM_EPOCHS=$5 +SAMPLING_FRAMEWORK=$6 SAMPLES_DIR=/samples DATASET_DIR=/datasets @@ -78,7 +79,8 @@ if [[ $SLURM_NODEID == 0 ]]; then --batch_sizes $BATCH_SIZE \ --seeds_per_call_opts "524288" \ --num_epochs $NUM_EPOCHS \ - --random_seed 42 + --random_seed 42 \ + --sampling_target_framework $SAMPLING_FRAMEWORK echo "DONE" > ${SAMPLES_DIR}/status.txt fi diff --git a/benchmarks/cugraph/standalone/bulk_sampling/run_train_job.sh b/benchmarks/cugraph/standalone/bulk_sampling/run_train_job.sh index 977745a9593..44c2f9407b8 100755 --- a/benchmarks/cugraph/standalone/bulk_sampling/run_train_job.sh +++ b/benchmarks/cugraph/standalone/bulk_sampling/run_train_job.sh @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -#SBATCH -A datascience_rapids_cugraphgnn -#SBATCH -p luna #SBATCH -J datascience_rapids_cugraphgnn-papers:bulkSamplingPyG #SBATCH -N 1 #SBATCH -t 00:25:00 @@ -28,13 +26,13 @@ mkdir -p $LOGS_DIR mkdir -p $SAMPLES_DIR mkdir -p $DATASETS_DIR -BATCH_SIZE=512 +BATCH_SIZE=555 FANOUT="10_10_10" NUM_EPOCHS=1 REPLICATION_FACTOR=1 -# options: PyG or cuGraphPyG -FRAMEWORK="cuGraphPyG" +# options: PyG, cuGraphPyG, or cuGraphDGL +FRAMEWORK="cuGraphDGL" GPUS_PER_NODE=8 nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) ) @@ -59,10 +57,16 @@ if [[ "$FRAMEWORK" == "cuGraphPyG" ]]; then srun \ --container-image $CONTAINER_IMAGE \ --container-mounts=${LOGS_DIR}":/logs",${SAMPLES_DIR}":/samples",${SCRIPTS_DIR}":/scripts",${DATASETS_DIR}":/datasets" \ - bash /scripts/run_sampling.sh $BATCH_SIZE $FANOUT $REPLICATION_FACTOR "/scripts" $NUM_EPOCHS + bash /scripts/run_sampling.sh $BATCH_SIZE $FANOUT $REPLICATION_FACTOR "/scripts" $NUM_EPOCHS "cugraph_pyg" +elif [[ "$FRAMEWORK" == "cuGraphDGL" ]]; then + srun \ + --container-image $CONTAINER_IMAGE \ + --container-mounts=${LOGS_DIR}":/logs",${SAMPLES_DIR}":/samples",${SCRIPTS_DIR}":/scripts",${DATASETS_DIR}":/datasets" \ + bash /scripts/run_sampling.sh $BATCH_SIZE $FANOUT $REPLICATION_FACTOR "/scripts" $NUM_EPOCHS "cugraph_dgl_csr" fi # Train +# Should always use WholeGraph for benchmarks since it will eventually become the default feature store backend srun \ --container-image $CONTAINER_IMAGE \ --container-mounts=${LOGS_DIR}":/logs",${SAMPLES_DIR}":/samples",${SCRIPTS_DIR}":/scripts",${DATASETS_DIR}":/datasets" \ @@ -80,5 +84,6 @@ srun \ --batch_size $BATCH_SIZE \ --fanout $FANOUT \ --replication_factor $REPLICATION_FACTOR \ - --num_epochs $NUM_EPOCHS + --num_epochs $NUM_EPOCHS \ + --use_wholegraph diff --git a/benchmarks/cugraph/standalone/bulk_sampling/trainers/dgl/__init__.py b/benchmarks/cugraph/standalone/bulk_sampling/trainers/dgl/__init__.py new file mode 100644 index 00000000000..03d2a51e538 --- /dev/null +++ b/benchmarks/cugraph/standalone/bulk_sampling/trainers/dgl/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .trainers_dgl import DGLTrainer +from .trainers_cugraph_dgl import DGLCuGraphTrainer diff --git a/benchmarks/cugraph/standalone/bulk_sampling/trainers/dgl/trainers_cugraph_dgl.py b/benchmarks/cugraph/standalone/bulk_sampling/trainers/dgl/trainers_cugraph_dgl.py new file mode 100644 index 00000000000..3a5e5b28fb8 --- /dev/null +++ b/benchmarks/cugraph/standalone/bulk_sampling/trainers/dgl/trainers_cugraph_dgl.py @@ -0,0 +1,208 @@ +# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import time + +from .trainers_dgl import DGLTrainer +from models.dgl import GraphSAGE + +import torch +import numpy as np +import warnings + +from torch.nn.parallel import DistributedDataParallel as ddp +from cugraph_dgl.dataloading import HomogenousBulkSamplerDataset +from cugraph.gnn import FeatureStore + + +def get_dataloader(input_file_paths, total_num_nodes, sparse_format, return_type): + print("Creating dataloader", flush=True) + st = time.time() + if len(input_file_paths) > 0: + dataset = HomogenousBulkSamplerDataset( + total_num_nodes, + edge_dir="in", + sparse_format=sparse_format, + return_type=return_type, + ) + dataset.set_input_files(input_file_paths=input_file_paths) + dataloader = torch.utils.data.DataLoader( + dataset, + collate_fn=lambda x: x, + shuffle=False, + num_workers=0, + batch_size=None, + ) + et = time.time() + print(f"Time to create dataloader = {et - st:.2f} seconds", flush=True) + return dataloader + else: + return [] + + +class DGLCuGraphTrainer(DGLTrainer): + def __init__( + self, + dataset, + model="GraphSAGE", + device=0, + rank=0, + world_size=1, + num_epochs=1, + sample_dir=".", + backend="torch", + **kwargs, + ): + self.__data = None + self.__device = device + self.__rank = rank + self.__world_size = world_size + self.__num_epochs = num_epochs + self.__dataset = dataset + self.__sample_dir = sample_dir + self.__loader_kwargs = kwargs + self.__model = self.get_model(model) + self.__optimizer = None + self.__backend = backend + + @property + def rank(self): + return self.__rank + + @property + def model(self): + return self.__model + + @property + def dataset(self): + return self.__dataset + + @property + def optimizer(self): + if self.__optimizer is None: + self.__optimizer = torch.optim.Adam( + self.model.parameters(), lr=0.01, weight_decay=0.0005 + ) + return self.__optimizer + + @property + def num_epochs(self) -> int: + return self.__num_epochs + + def get_loader(self, epoch: int = 0, stage="train") -> int: + # TODO support online sampling + if stage == "train": + path = os.path.join(self.__sample_dir, f"epoch={epoch}", stage, "samples") + else: + path = os.path.join(self.__sample_dir, stage, "samples") + + dataloader = get_dataloader( + input_file_paths=self.get_input_files( + path, epoch=epoch, stage=stage + ).tolist(), + total_num_nodes=None, + sparse_format="csc", + return_type="cugraph_dgl.nn.SparseGraph", + ) + return dataloader + + @property + def data(self): + import logging + + logger = logging.getLogger("DGLCuGraphTrainer") + logger.info("getting data") + + if self.__data is None: + # FIXME wholegraph + fs = FeatureStore(backend=self.__backend) + num_nodes_dict = {} + + if self.__backend == "wholegraph": + from pylibwholegraph.torch.initialize import get_global_communicator + + wm_comm = get_global_communicator() + wm_comm.barrier() + + for node_type, x in self.__dataset.x_dict.items(): + logger.debug(f"getting x for {node_type}") + fs.add_data(x, node_type, "x") + num_nodes_dict[node_type] = self.__dataset.num_nodes(node_type) + if self.__backend == "wholegraph": + wm_comm.barrier() + + for node_type, y in self.__dataset.y_dict.items(): + logger.debug(f"getting y for {node_type}") + fs.add_data_no_cast(y, node_type, "y") + + for node_type, train in self.__dataset.train_dict.items(): + logger.debug(f"getting train for {node_type}") + fs.add_data_no_cast(train, node_type, "train") + + for node_type, test in self.__dataset.test_dict.items(): + logger.debug(f"getting test for {node_type}") + fs.add_data_no_cast(test, node_type, "test") + + for node_type, val in self.__dataset.val_dict.items(): + logger.debug(f"getting val for {node_type}") + fs.add_data_no_cast(val, node_type, "val") + + # # TODO support online sampling if the edge index is provided + # num_edges_dict = self.__dataset.edge_index_dict + # if not isinstance(list(num_edges_dict.values())[0], int): + # num_edges_dict = {k: len(v) for k, v in num_edges_dict} + + if self.__backend == "wholegraph": + wm_comm.barrier() + + self.__data = fs + return self.__data + + def get_model(self, name="GraphSAGE"): + if name != "GraphSAGE": + raise ValueError("only GraphSAGE is currently supported") + + num_input_features = self.__dataset.num_input_features + num_output_features = self.__dataset.num_labels + num_layers = len(self.__loader_kwargs["num_neighbors"]) + + with torch.cuda.device(self.__device): + model = ( + GraphSAGE( + in_channels=num_input_features, + hidden_channels=64, + out_channels=num_output_features, + num_layers=num_layers, + model_backend="cugraph_dgl", + ) + .to(torch.float32) + .to(self.__device) + ) + # TODO: Fix for distributed models + if torch.distributed.is_initialized(): + model = ddp(model, device_ids=[self.__device]) + else: + warnings.warn("Distributed training is not available") + print("done creating model") + + return model + + def get_input_files(self, path, epoch=0, stage="train"): + file_list = np.array([f.path for f in os.scandir(path)]) + file_list.sort() + + splits = np.array_split(file_list, self.__world_size) + np.random.seed(epoch) + np.random.shuffle(splits) + + return splits[self.rank] diff --git a/benchmarks/cugraph/standalone/bulk_sampling/trainers/dgl/trainers_dgl.py b/benchmarks/cugraph/standalone/bulk_sampling/trainers/dgl/trainers_dgl.py new file mode 100644 index 00000000000..ab475066f0a --- /dev/null +++ b/benchmarks/cugraph/standalone/bulk_sampling/trainers/dgl/trainers_dgl.py @@ -0,0 +1,268 @@ +# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import torch +import torch.distributed as td +import torch.nn.functional as F +from torchmetrics import Accuracy +from trainers import Trainer +import time + + +def get_features(input_nodes, output_nodes, feature_store, key="paper"): + if isinstance(input_nodes, dict): + input_nodes = input_nodes[key] + if isinstance(output_nodes, dict): + output_nodes = output_nodes[key] + + # TODO: Fix below + # Adding based on assumption that cpu features + # and gpu index is not supported yet + + if feature_store.backend == "torch": + input_nodes = input_nodes.to("cpu") + output_nodes = output_nodes.to("cpu") + + x = feature_store.get_data(indices=input_nodes, type_name=key, feat_name="x") + y = feature_store.get_data(indices=output_nodes, type_name=key, feat_name="y") + return x, y + + +def log_batch( + logger, + iter_i, + num_batches, + time_forward, + time_backward, + time_start, + loader_time_iter, + epoch, + rank, +): + time_forward_iter = time_forward / num_batches + time_backward_iter = time_backward / num_batches + total_time_iter = (time.perf_counter() - time_start) / num_batches + logger.info(f"epoch {epoch}, iteration {iter_i}, rank {rank}") + logger.info(f"time forward: {time_forward_iter}") + logger.info(f"time backward: {time_backward_iter}") + logger.info(f"loader time: {loader_time_iter}") + logger.info(f"total time: {total_time_iter}") + + +def train_epoch( + model, optimizer, loader, feature_store, epoch, num_classes, time_d, logger, rank +): + """ + Train the model for one epoch. + model: The model to train. + optimizer: The optimizer to use. + loader: The loader to use. + data: cuGraph.gnn.FeatueStore + epoch: The epoch number. + num_classes: The number of classes. + time_d: A dictionary of times. + logger: The logger to use. + rank: Total rank + """ + model = model.train() + time_feature_indexing = time_d["time_feature_indexing"] + time_feature_transfer = time_d["time_feature_transfer"] + time_forward = time_d["time_forward"] + time_backward = time_d["time_backward"] + time_loader = time_d["time_loader"] + + time_start = time.perf_counter() + end_time_backward = time.perf_counter() + + num_batches = 0 + total_loss = 0.0 + + for iter_i, (input_nodes, output_nodes, blocks) in enumerate(loader): + loader_time_iter = time.perf_counter() - end_time_backward + time_loader += loader_time_iter + feature_indexing_time_start = time.perf_counter() + x, y_true = get_features(input_nodes, output_nodes, feature_store=feature_store) + additional_feature_time_end = time.perf_counter() + time_feature_indexing += ( + additional_feature_time_end - feature_indexing_time_start + ) + feature_trasfer_time_start = time.perf_counter() + x = x.to("cuda") + y_true = y_true.to("cuda") + time_feature_transfer += time.perf_counter() - feature_trasfer_time_start + num_batches += 1 + + start_time_forward = time.perf_counter() + y_pred = model( + blocks, + x, + ) + end_time_forward = time.perf_counter() + time_forward += end_time_forward - start_time_forward + + if y_pred.shape[0] > len(y_true): + raise ValueError(f"illegal shape: {y_pred.shape}; {y_true.shape}") + + y_true = y_true[: y_pred.shape[0]] + y_true = F.one_hot( + y_true.to(torch.int64), + num_classes=num_classes, + ).to(torch.float32) + + if y_true.shape != y_pred.shape: + raise ValueError( + f"y_true shape was {y_true.shape} " + f"but y_pred shape was {y_pred.shape} " + f"in iteration {iter_i} " + f"on rank {y_pred.device.index}" + ) + + start_time_backward = time.perf_counter() + loss = F.cross_entropy(y_pred, y_true) + optimizer.zero_grad() + loss.backward() + optimizer.step() + total_loss += loss.item() + end_time_backward = time.perf_counter() + time_backward += end_time_backward - start_time_backward + + if iter_i % 50 == 0: + log_batch( + logger=logger, + iter_i=iter_i, + num_batches=num_batches, + time_forward=time_forward, + time_backward=time_backward, + time_start=time_start, + loader_time_iter=loader_time_iter, + epoch=epoch, + rank=rank, + ) + + time_d["time_loader"] += time_loader + time_d["time_feature_indexing"] += time_feature_indexing + time_d["time_feature_transfer"] += time_feature_transfer + time_d["time_forward"] += time_forward + time_d["time_backward"] += time_backward + + return num_batches, total_loss + + +def get_accuracy(model, loader, feature_store, num_classes): + print("Computing accuracy...", flush=True) + acc = Accuracy(task="multiclass", num_classes=num_classes).cuda() + acc_sum = 0.0 + num_batches = 0 + with torch.no_grad(): + for iter_i, (input_nodes, output_nodes, blocks) in enumerate(loader): + x, y_true = get_features( + input_nodes, output_nodes, feature_store=feature_store + ) + x = x.to("cuda") + y_true = y_true.to("cuda") + + out = model(blocks, x) + batch_size = out.shape[0] + acc_sum += acc(out[:batch_size].softmax(dim=-1), y_true[:batch_size]) + num_batches += 1 + + num_batches = num_batches + + acc_sum = torch.tensor(float(acc_sum), dtype=torch.float32, device="cuda") + td.all_reduce(acc_sum, op=td.ReduceOp.SUM) + nb = torch.tensor(float(num_batches), dtype=torch.float32, device=acc_sum.device) + td.all_reduce(nb, op=td.ReduceOp.SUM) + + acc = acc_sum / nb + + print( + f"Accuracy: {acc * 100.0:.4f}%", + ) + return acc * 100.0 + + +class DGLTrainer(Trainer): + def train(self): + logger = logging.getLogger("DGLTrainer") + time_d = { + "time_loader": 0.0, + "time_feature_indexing": 0.0, + "time_feature_transfer": 0.0, + "time_forward": 0.0, + "time_backward": 0.0, + } + total_batches = 0 + for epoch in range(self.num_epochs): + start_time = time.perf_counter() + self.model.train() + with td.algorithms.join.Join( + [self.model], divide_by_initial_world_size=False + ): + num_batches, total_loss = train_epoch( + model=self.model, + optimizer=self.optimizer, + loader=self.get_loader(epoch=epoch, stage="train"), + feature_store=self.data, + num_classes=self.dataset.num_labels, + epoch=epoch, + time_d=time_d, + logger=logger, + rank=self.rank, + ) + total_batches = total_batches + num_batches + end_time = time.perf_counter() + epoch_time_taken = end_time - start_time + print( + f"RANK: {self.rank} Total time taken for training epoch {epoch} = {epoch_time_taken}", + flush=True, + ) + print("---" * 30) + td.barrier() + self.model.eval() + with td.algorithms.join.Join( + [self.model], divide_by_initial_world_size=False + ): + # test + test_acc = get_accuracy( + model=self.model.module, + loader=self.get_loader(epoch=epoch, stage="test"), + feature_store=self.data, + num_classes=self.dataset.num_labels, + ) + print(f"Accuracy: {test_acc:.4f}%") + + # val: + self.model.eval() + with td.algorithms.join.Join([self.model], divide_by_initial_world_size=False): + val_acc = get_accuracy( + model=self.model.module, + loader=self.get_loader(epoch=epoch, stage="val"), + feature_store=self.data, + num_classes=self.dataset.num_labels, + ) + print(f"Validation Accuracy: {val_acc:.4f}%") + + val_acc = float(val_acc) + stats = { + "Accuracy": val_acc, + "# Batches": total_batches, + "Loader Time": time_d["time_loader"], + "Feature Time": time_d["time_feature_indexing"] + + time_d["time_feature_transfer"], + "Forward Time": time_d["time_forward"], + "Backward Time": time_d["time_backward"], + } + return stats + + +# For native DGL training, see benchmarks/cugraph-dgl/scale-benchmarks diff --git a/benchmarks/cugraph/standalone/bulk_sampling/trainers/pyg/trainers_cugraph_pyg.py b/benchmarks/cugraph/standalone/bulk_sampling/trainers/pyg/trainers_cugraph_pyg.py index 71151e9ba59..22d1f29558c 100644 --- a/benchmarks/cugraph/standalone/bulk_sampling/trainers/pyg/trainers_cugraph_pyg.py +++ b/benchmarks/cugraph/standalone/bulk_sampling/trainers/pyg/trainers_cugraph_pyg.py @@ -36,8 +36,13 @@ def __init__( world_size=1, num_epochs=1, sample_dir=".", + backend="torch", **kwargs, ): + import logging + + logger = logging.getLogger("PyGCuGraphTrainer") + logger.info("creating trainer") self.__data = None self.__device = device self.__rank = rank @@ -47,7 +52,9 @@ def __init__( self.__sample_dir = sample_dir self.__loader_kwargs = kwargs self.__model = self.get_model(model) + self.__backend = backend self.__optimizer = None + logger.info("created trainer") @property def rank(self): @@ -81,10 +88,10 @@ def get_loader(self, epoch: int = 0, stage="train") -> int: logger.info(f"getting loader for epoch {epoch}, {stage} stage") # TODO support online sampling - if stage == "val": - path = os.path.join(self.__sample_dir, "val", "samples") - else: + if stage == "train": path = os.path.join(self.__sample_dir, f"epoch={epoch}", stage, "samples") + else: + path = os.path.join(self.__sample_dir, stage, "samples") loader = BulkSampleLoader( self.data, @@ -106,36 +113,46 @@ def data(self): logger.info("getting data") if self.__data is None: - # FIXME wholegraph - fs = FeatureStore(backend="torch") + fs = FeatureStore(backend=self.__backend) num_nodes_dict = {} + if self.__backend == "wholegraph": + from pylibwholegraph.torch.initialize import get_global_communicator + + wm_comm = get_global_communicator() + wm_comm.barrier() + for node_type, x in self.__dataset.x_dict.items(): logger.debug(f"getting x for {node_type}") fs.add_data(x, node_type, "x") num_nodes_dict[node_type] = self.__dataset.num_nodes(node_type) + if self.__backend == "wholegraph": + wm_comm.barrier() for node_type, y in self.__dataset.y_dict.items(): logger.debug(f"getting y for {node_type}") - fs.add_data(y, node_type, "y") + fs.add_data_no_cast(y.cuda(), node_type, "y") for node_type, train in self.__dataset.train_dict.items(): logger.debug(f"getting train for {node_type}") - fs.add_data(train, node_type, "train") + fs.add_data_no_cast(train.cuda(), node_type, "train") for node_type, test in self.__dataset.test_dict.items(): logger.debug(f"getting test for {node_type}") - fs.add_data(test, node_type, "test") + fs.add_data_no_cast(test.cuda(), node_type, "test") for node_type, val in self.__dataset.val_dict.items(): logger.debug(f"getting val for {node_type}") - fs.add_data(val, node_type, "val") + fs.add_data_no_cast(val.cuda(), node_type, "val") # TODO support online sampling if the edge index is provided num_edges_dict = self.__dataset.edge_index_dict if not isinstance(list(num_edges_dict.values())[0], int): num_edges_dict = {k: len(v) for k, v in num_edges_dict} + if self.__backend == "wholegraph": + wm_comm.barrier() + self.__data = CuGraphStore( fs, num_edges_dict, @@ -175,10 +192,7 @@ def get_input_files(self, path, epoch=0, stage="train"): file_list = np.array(os.listdir(path)) file_list.sort() - if stage == "train": - splits = np.array_split(file_list, self.__world_size) - np.random.seed(epoch) - np.random.shuffle(splits) - return splits[self.rank] - else: - return file_list + splits = np.array_split(file_list, self.__world_size) + np.random.seed(epoch) + np.random.shuffle(splits) + return splits[self.rank] diff --git a/benchmarks/cugraph/standalone/bulk_sampling/trainers/pyg/trainers_pyg.py b/benchmarks/cugraph/standalone/bulk_sampling/trainers/pyg/trainers_pyg.py index bddd6ae2644..ca5a3fa50fd 100644 --- a/benchmarks/cugraph/standalone/bulk_sampling/trainers/pyg/trainers_pyg.py +++ b/benchmarks/cugraph/standalone/bulk_sampling/trainers/pyg/trainers_pyg.py @@ -45,6 +45,43 @@ def pyg_num_workers(world_size): return int(num_workers) +def calc_accuracy(loader, model, num_classes): + from torchmetrics import Accuracy + + acc = Accuracy(task="multiclass", num_classes=num_classes).cuda() + + acc_sum = 0.0 + num_batches = 0 + with torch.no_grad(): + for i, batch in enumerate(loader): + num_sampled_nodes = sum( + [torch.as_tensor(n) for n in batch.num_sampled_nodes_dict.values()] + ) + num_sampled_edges = sum( + [torch.as_tensor(e) for e in batch.num_sampled_edges_dict.values()] + ) + batch_size = num_sampled_nodes[0] + + batch = batch.to_homogeneous().cuda() + + batch.y = batch.y.to(torch.long) + out = model( + batch.x, + batch.edge_index, + num_sampled_nodes, + num_sampled_edges, + ) + acc_sum += acc(out[:batch_size].softmax(dim=-1), batch.y[:batch_size]) + num_batches += 1 + + acc_sum = torch.tensor(float(acc_sum), dtype=torch.float32, device="cuda") + td.all_reduce(acc_sum, op=td.ReduceOp.SUM) + nb = torch.tensor(float(num_batches), dtype=torch.float32, device=acc_sum.device) + td.all_reduce(nb, op=td.ReduceOp.SUM) + + return acc_sum / nb + + class PyGTrainer(Trainer): def train(self): import logging @@ -166,95 +203,33 @@ def train(self): end_time = time.perf_counter() - # test - from torchmetrics import Accuracy - - acc = Accuracy( - task="multiclass", num_classes=self.dataset.num_labels - ).cuda() - with td.algorithms.join.Join( [self.model], divide_by_initial_world_size=False ): self.model.eval() - if self.rank == 0: - acc_sum = 0.0 - with torch.no_grad(): - for i, batch in enumerate( - self.get_loader(epoch=epoch, stage="test") - ): - num_sampled_nodes = sum( - [ - torch.as_tensor(n) - for n in batch.num_sampled_nodes_dict.values() - ] - ) - num_sampled_edges = sum( - [ - torch.as_tensor(e) - for e in batch.num_sampled_edges_dict.values() - ] - ) - batch_size = num_sampled_nodes[0] - - batch = batch.to_homogeneous().cuda() - - batch.y = batch.y.to(torch.long) - out = self.model.module( - batch.x, - batch.edge_index, - num_sampled_nodes, - num_sampled_edges, - ) - acc_sum += acc( - out[:batch_size].softmax(dim=-1), batch.y[:batch_size] - ) - print( - f"Accuracy: {acc_sum/(i) * 100.0:.4f}%", - ) + loader = self.get_loader(epoch=epoch, stage="test") + num_classes = self.dataset.num_labels - td.barrier() + acc = calc_accuracy(loader, self.model.module, num_classes) - with td.algorithms.join.Join([self.model], divide_by_initial_world_size=False): - self.model.eval() if self.rank == 0: - acc_sum = 0.0 - with torch.no_grad(): - for i, batch in enumerate( - self.get_loader(epoch=epoch, stage="val") - ): - num_sampled_nodes = sum( - [ - torch.as_tensor(n) - for n in batch.num_sampled_nodes_dict.values() - ] - ) - num_sampled_edges = sum( - [ - torch.as_tensor(e) - for e in batch.num_sampled_edges_dict.values() - ] - ) - batch_size = num_sampled_nodes[0] - - batch = batch.to_homogeneous().cuda() - - batch.y = batch.y.to(torch.long) - out = self.model.module( - batch.x, - batch.edge_index, - num_sampled_nodes, - num_sampled_edges, - ) - acc_sum += acc( - out[:batch_size].softmax(dim=-1), batch.y[:batch_size] - ) print( - f"Validation Accuracy: {acc_sum/(i) * 100.0:.4f}%", + f"Accuracy: {acc * 100.0:.4f}%", ) + with td.algorithms.join.Join([self.model], divide_by_initial_world_size=False): + self.model.eval() + loader = self.get_loader(epoch=epoch, stage="val") + num_classes = self.dataset.num_labels + acc = calc_accuracy(loader, self.model.module, num_classes) + + if self.rank == 0: + print( + f"Validation Accuracy: {acc * 100.0:.4f}%", + ) + stats = { - "Accuracy": float(acc_sum / (i) * 100.0) if self.rank == 0 else 0.0, + "Accuracy": float(acc * 100.0), "# Batches": num_batches, "Loader Time": time_loader, "Feature Transfer Time": time_feature_transfer, diff --git a/cugraph_sampling_stats.csv b/cugraph_sampling_stats.csv new file mode 100644 index 00000000000..906f3e6be4e --- /dev/null +++ b/cugraph_sampling_stats.csv @@ -0,0 +1,2 @@ +,dataset,num_input_edges,directed,renumber,input_memory_per_worker,peak_allocation_across_workers,input_to_peak_ratio,output_to_peak_ratio +0,ogbn_papers100M,3231371744,,,6.0GB,12.0GB,1.9942090995767523,227565.63197257713 diff --git a/python/cugraph-dgl/cugraph_dgl/dataloading/dataset.py b/python/cugraph-dgl/cugraph_dgl/dataloading/dataset.py index 815fd30d8eb..f6fe38fe9f8 100644 --- a/python/cugraph-dgl/cugraph_dgl/dataloading/dataset.py +++ b/python/cugraph-dgl/cugraph_dgl/dataloading/dataset.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -63,6 +63,10 @@ def __getitem__(self, idx: int): fn, batch_offset = self._batch_to_fn_d[idx] if fn != self._current_batch_fn: + # Remove current batches to free up memory + # before loading new batches + if hasattr(self, "_current_batches"): + del self._current_batches if self.sparse_format == "csc": df = _load_sampled_file(dataset_obj=self, fn=fn, skip_rename=True) self._current_batches = ( diff --git a/python/cugraph-pyg/cugraph_pyg/data/cugraph_store.py b/python/cugraph-pyg/cugraph_pyg/data/cugraph_store.py index 05d540b7c45..df16fc9fd6c 100644 --- a/python/cugraph-pyg/cugraph_pyg/data/cugraph_store.py +++ b/python/cugraph-pyg/cugraph_pyg/data/cugraph_store.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2023, NVIDIA CORPORATION. +# Copyright (c) 2019-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -1083,13 +1083,12 @@ def _get_tensor(self, attr: CuGraphTensorAttr) -> TensorType: idx = attr.index if idx is not None: - if feature_backend == "torch": + if feature_backend in ["torch", "wholegraph"]: if not isinstance(idx, torch.Tensor): raise TypeError( f"Type {type(idx)} invalid" f" for feature store backend {feature_backend}" ) - idx = idx.cpu() elif feature_backend == "numpy": # allow feature indexing through cupy arrays if isinstance(idx, cupy.ndarray): @@ -1244,5 +1243,77 @@ def _infer_unspecified_attr(self, attr: CuGraphTensorAttr) -> CuGraphTensorAttr: return attr + def filter( + self, + format: str, + node_dict: Dict[str, torch.Tensor], + row_dict: Dict[str, torch.Tensor], + col_dict: Dict[str, torch.Tensor], + edge_dict: Dict[str, Tuple[torch.Tensor]], + ) -> torch_geometric.data.HeteroData: + """ + Parameters + ---------- + format: str + COO or CSC + node_dict: Dict[str, torch.Tensor] + IDs of nodes in original store being outputted + row_dict: Dict[str, torch.Tensor] + Renumbered output edge index row + col_dict: Dict[str, torch.Tensor] + Renumbered output edge index column + edge_dict: Dict[str, Tuple[torch.Tensor]] + Currently unused original edge mapping + """ + data = torch_geometric.data.HeteroData() + + # TODO use torch_geometric.EdgeIndex in release 24.04 (Issue #4051) + for attr in self.get_all_edge_attrs(): + key = attr.edge_type + if key in row_dict and key in col_dict: + if format == "CSC": + data.put_edge_index( + (row_dict[key], col_dict[key]), + edge_type=key, + layout="csc", + is_sorted=True, + ) + else: + data[key].edge_index = torch.stack( + [ + row_dict[key], + col_dict[key], + ], + dim=0, + ) + + required_attrs = [] + # To prevent copying multiple times, we use a cache; + # the original node_dict serves as the gpu cache if needed + node_dict_cpu = {} + for attr in self.get_all_tensor_attrs(): + if attr.group_name in node_dict: + device = self.__features.get_storage(attr.group_name, attr.attr_name) + attr.index = node_dict[attr.group_name] + if not isinstance(attr.index, torch.Tensor): + raise ValueError("Node index must be a tensor!") + if attr.index.is_cuda and device == "cpu": + if attr.group_name not in node_dict_cpu: + node_dict_cpu[attr.group_name] = attr.index.cpu() + attr.index = node_dict_cpu[attr.group_name] + elif attr.index.is_cpu and device == "cuda": + node_dict_cpu[attr.group_name] = attr.index + node_dict[attr.group_name] = attr.index.cuda() + attr.index = node_dict[attr.group_name] + + required_attrs.append(attr) + data[attr.group_name].num_nodes = attr.index.size(0) + + tensors = self.multi_get_tensor(required_attrs) + for i, attr in enumerate(required_attrs): + data[attr.group_name][attr.attr_name] = tensors[i] + + return data + def __len__(self): return len(self.get_all_tensor_attrs()) diff --git a/python/cugraph-pyg/cugraph_pyg/loader/cugraph_node_loader.py b/python/cugraph-pyg/cugraph_pyg/loader/cugraph_node_loader.py index bcfaf579820..55c9e9b3329 100644 --- a/python/cugraph-pyg/cugraph_pyg/loader/cugraph_node_loader.py +++ b/python/cugraph-pyg/cugraph_pyg/loader/cugraph_node_loader.py @@ -28,7 +28,6 @@ _sampler_output_from_sampling_results_heterogeneous, _sampler_output_from_sampling_results_homogeneous_csr, _sampler_output_from_sampling_results_homogeneous_coo, - filter_cugraph_store_csc, ) from typing import Union, Tuple, Sequence, List, Dict @@ -454,31 +453,20 @@ def __next__(self): start_time_feature = perf_counter() # Create a PyG HeteroData object, loading the required features - if self.__coo: - pyg_filter_fn = ( - torch_geometric.loader.utils.filter_custom_hetero_store - if hasattr(torch_geometric.loader.utils, "filter_custom_hetero_store") - else torch_geometric.loader.utils.filter_custom_store - ) - out = pyg_filter_fn( - self.__feature_store, - self.__graph_store, - sampler_output.node, - sampler_output.row, - sampler_output.col, - sampler_output.edge, - ) - else: - out = filter_cugraph_store_csc( - self.__feature_store, - self.__graph_store, - sampler_output.node, - sampler_output.row, - sampler_output.col, - sampler_output.edge, - ) + if self.__graph_store != self.__feature_store: + # TODO Possibly support this if there is an actual use case + raise ValueError("Separate graph and feature stores currently unsupported") + + out = self.__graph_store.filter( + "COO" if self.__coo else "CSC", + sampler_output.node, + sampler_output.row, + sampler_output.col, + sampler_output.edge, + ) # Account for CSR format in cuGraph vs. CSC format in PyG + # TODO deprecate and remove this functionality if self.__coo and self.__graph_store.order == "CSC": for edge_type in out.edge_index_dict: out[edge_type].edge_index = out[edge_type].edge_index.flip(dims=[0]) diff --git a/python/cugraph-pyg/cugraph_pyg/sampler/cugraph_sampler.py b/python/cugraph-pyg/cugraph_pyg/sampler/cugraph_sampler.py index 65cb63d25e0..ffab54efe08 100644 --- a/python/cugraph-pyg/cugraph_pyg/sampler/cugraph_sampler.py +++ b/python/cugraph-pyg/cugraph_pyg/sampler/cugraph_sampler.py @@ -411,6 +411,10 @@ def filter_cugraph_store_csc( col_dict: Dict[str, torch.Tensor], edge_dict: Dict[str, Tuple[torch.Tensor]], ) -> torch_geometric.data.HeteroData: + """ + Deprecated + """ + data = torch_geometric.data.HeteroData() for attr in graph_store.get_all_edge_attrs(): diff --git a/python/cugraph/cugraph/gnn/feature_storage/feat_storage.py b/python/cugraph/cugraph/gnn/feature_storage/feat_storage.py index 77a53882fc4..2981e6edbf9 100644 --- a/python/cugraph/cugraph/gnn/feature_storage/feat_storage.py +++ b/python/cugraph/cugraph/gnn/feature_storage/feat_storage.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -168,19 +168,53 @@ def get_data( feat, wgth.WholeMemoryEmbedding ): indices_tensor = ( - indices + indices.cuda() if isinstance(indices, torch.Tensor) else torch.as_tensor(indices, device="cuda") ) return feat.gather(indices_tensor) - else: - return feat[indices] + elif isinstance(feat, torch.Tensor): + if not isinstance(indices, torch.Tensor): + indices = torch.as_tensor(indices) + + if feat.is_cpu and indices.is_cuda: + # TODO maybe add a warning here + indices = indices.cpu() + return feat[indices] def get_feature_list(self) -> list[str]: return {feat_name: feats.keys() for feat_name, feats in self.fd.items()} + def get_storage(self, type_name: str, feat_name: str) -> str: + """ + Returns where the data is stored (cuda, cpu). + Note: will return "cuda" for data managed by CUDA, even if + it is in host memory. + + Parameters + ---------- + type_name : str + The node-type/edge-type to store data + feat_name: + The feature name to retrieve data for + + Returns + ------- + "cuda" for data managed by CUDA, otherwise "CPU". + """ + feat = self.fd[feat_name][type_name] + if not isinstance(wgth, MissingModule) and isinstance( + feat, wgth.WholeMemoryEmbedding + ): + return "cuda" + elif isinstance(feat, torch.Tensor): + return "cpu" if feat.is_cpu else "cuda" + else: + return "cpu" + @staticmethod def _cast_feat_obj_to_backend(feat_obj, backend: str, **kwargs): + # TODO (Issue #4078) support casting WG tensors to numpy and torch if backend == "numpy": if isinstance(feat_obj, (cudf.DataFrame, pd.DataFrame)): return _cast_to_numpy_ar(feat_obj.values, **kwargs) @@ -192,6 +226,8 @@ def _cast_feat_obj_to_backend(feat_obj, backend: str, **kwargs): else: return _cast_to_torch_tensor(feat_obj, **kwargs) elif backend == "wholegraph": + if isinstance(feat_obj, wgth.WholeMemoryEmbedding): + return feat_obj return _get_wg_embedding(feat_obj, **kwargs)