diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7b5b16e3ea1..5b351478fa9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -17,7 +17,7 @@ repos: hooks: - id: black language_version: python3 - args: [--target-version=py39] + args: [--target-version=py310] files: ^(python/.*|benchmarks/.*)$ exclude: ^python/nx-cugraph/ - repo: https://github.com/PyCQA/flake8 @@ -42,7 +42,7 @@ repos: types_or: [c, c++, cuda] args: ["-fallback-style=none", "-style=file", "-i"] - repo: https://github.com/rapidsai/pre-commit-hooks - rev: v0.3.1 + rev: v0.4.0 hooks: - id: verify-copyright files: | diff --git a/benchmarks/cugraph/pytest-based/bench_cugraph_uniform_neighbor_sample.py b/benchmarks/cugraph/pytest-based/bench_cugraph_uniform_neighbor_sample.py index 8c46095a7da..083acdde2f4 100644 --- a/benchmarks/cugraph/pytest-based/bench_cugraph_uniform_neighbor_sample.py +++ b/benchmarks/cugraph/pytest-based/bench_cugraph_uniform_neighbor_sample.py @@ -266,7 +266,7 @@ def uns_func(*args, **kwargs): @pytest.mark.managedmem_off @pytest.mark.poolallocator_on @pytest.mark.parametrize("batch_size", params.batch_sizes.values()) -@pytest.mark.parametrize("fanout", [params.fanout_10_25, params.fanout_5_10_15]) +@pytest.mark.parametrize("fanout", [params.fanout_10_25]) @pytest.mark.parametrize( "with_replacement", [False], ids=lambda v: f"with_replacement={v}" ) @@ -287,6 +287,8 @@ def bench_cugraph_uniform_neighbor_sample( start_list=uns_args["start_list"], fanout_vals=uns_args["fanout"], with_replacement=uns_args["with_replacement"], + use_legacy_names=False, + with_edge_properties=True, ) """ dtmap = {"int32": 32 // 8, "int64": 64 // 8} diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml index 8f9c5ec7a9e..18cca40c320 100644 --- a/conda/environments/all_cuda-118_arch-x86_64.yaml +++ b/conda/environments/all_cuda-118_arch-x86_64.yaml @@ -42,6 +42,7 @@ dependencies: - numpy>=1.23,<2.0a0 - numpydoc - nvcc_linux-64=11.8 +- ogb - openmpi - packaging>=21 - pandas @@ -74,6 +75,7 @@ dependencies: - sphinxcontrib-websupport - thriftpy2!=0.5.0,!=0.5.1 - torchdata +- torchmetrics - ucx-proc=*=gpu - ucx-py==0.40.*,>=0.0.0a0 - wget diff --git a/conda/environments/all_cuda-125_arch-x86_64.yaml b/conda/environments/all_cuda-125_arch-x86_64.yaml index dec67ba4fe4..ef20371e0f5 100644 --- a/conda/environments/all_cuda-125_arch-x86_64.yaml +++ b/conda/environments/all_cuda-125_arch-x86_64.yaml @@ -47,6 +47,7 @@ dependencies: - numba>=0.57 - numpy>=1.23,<2.0a0 - numpydoc +- ogb - openmpi - packaging>=21 - pandas @@ -79,6 +80,7 @@ dependencies: - sphinxcontrib-websupport - thriftpy2!=0.5.0,!=0.5.1 - torchdata +- torchmetrics - ucx-proc=*=gpu - ucx-py==0.40.*,>=0.0.0a0 - wget diff --git a/conda/recipes/nx-cugraph/meta.yaml b/conda/recipes/nx-cugraph/meta.yaml index d67287be757..263f53d9a8f 100644 --- a/conda/recipes/nx-cugraph/meta.yaml +++ b/conda/recipes/nx-cugraph/meta.yaml @@ -14,9 +14,7 @@ source: build: number: {{ GIT_DESCRIBE_NUMBER }} - build: - number: {{ GIT_DESCRIBE_NUMBER }} - string: py{{ py_version }}_{{ date_string }}_{{ GIT_DESCRIBE_HASH }}_{{ GIT_DESCRIBE_NUMBER }} + string: py{{ py_version }}_{{ date_string }}_{{ GIT_DESCRIBE_HASH }}_{{ GIT_DESCRIBE_NUMBER }} requirements: host: diff --git a/cpp/examples/developers/graph_operations/graph_operations.cu b/cpp/examples/developers/graph_operations/graph_operations.cu index 014cedcab7e..912f9f1fd46 100644 --- a/cpp/examples/developers/graph_operations/graph_operations.cu +++ 
b/cpp/examples/developers/graph_operations/graph_operations.cu @@ -131,7 +131,7 @@ create_graph(raft::handle_t const& handle, // if (multi_gpu) { - std::tie(d_edge_srcs, d_edge_dsts, d_edge_wgts, std::ignore, std::ignore) = + std::tie(d_edge_srcs, d_edge_dsts, d_edge_wgts, std::ignore, std::ignore, std::ignore) = cugraph::shuffle_external_edges(handle, std::move(d_edge_srcs), std::move(d_edge_dsts), @@ -215,10 +215,10 @@ void perform_example_graph_operations( graph_view); cugraph::update_edge_src_property( - handle, graph_view, vertex_weights.begin(), src_vertex_weights_cache); + handle, graph_view, vertex_weights.begin(), src_vertex_weights_cache.mutable_view()); cugraph::update_edge_dst_property( - handle, graph_view, vertex_weights.begin(), dst_vertex_weights_cache); + handle, graph_view, vertex_weights.begin(), dst_vertex_weights_cache.mutable_view()); rmm::device_uvector weighted_averages( size_of_the_vertex_partition_assigned_to_this_process, handle.get_stream()); @@ -259,10 +259,10 @@ void perform_example_graph_operations( graph_view); cugraph::update_edge_src_property( - handle, graph_view, vertex_weights.begin(), src_vertex_weights_cache); + handle, graph_view, vertex_weights.begin(), src_vertex_weights_cache.mutable_view()); cugraph::update_edge_dst_property( - handle, graph_view, vertex_weights.begin(), dst_vertex_weights_cache); + handle, graph_view, vertex_weights.begin(), dst_vertex_weights_cache.mutable_view()); rmm::device_uvector weighted_averages( size_of_the_vertex_partition_assigned_to_this_process, handle.get_stream()); diff --git a/cpp/examples/developers/vertex_and_edge_partition/vertex_and_edge_partition.cu b/cpp/examples/developers/vertex_and_edge_partition/vertex_and_edge_partition.cu index ce02e3b2639..c261ff6d843 100644 --- a/cpp/examples/developers/vertex_and_edge_partition/vertex_and_edge_partition.cu +++ b/cpp/examples/developers/vertex_and_edge_partition/vertex_and_edge_partition.cu @@ -127,7 +127,7 @@ create_graph(raft::handle_t const& handle, // if (multi_gpu) { - std::tie(d_edge_srcs, d_edge_dsts, d_edge_wgts, std::ignore, std::ignore) = + std::tie(d_edge_srcs, d_edge_dsts, d_edge_wgts, std::ignore, std::ignore, std::ignore) = cugraph::shuffle_external_edges(handle, std::move(d_edge_srcs), std::move(d_edge_dsts), diff --git a/cpp/examples/users/multi_gpu_application/mg_graph_algorithms.cpp b/cpp/examples/users/multi_gpu_application/mg_graph_algorithms.cpp index a9e2a170208..db629117604 100644 --- a/cpp/examples/users/multi_gpu_application/mg_graph_algorithms.cpp +++ b/cpp/examples/users/multi_gpu_application/mg_graph_algorithms.cpp @@ -123,7 +123,7 @@ create_graph(raft::handle_t const& handle, // if (multi_gpu) { - std::tie(d_edge_srcs, d_edge_dsts, d_edge_wgts, std::ignore, std::ignore) = + std::tie(d_edge_srcs, d_edge_dsts, d_edge_wgts, std::ignore, std::ignore, std::ignore) = cugraph::shuffle_external_edges(handle, std::move(d_edge_srcs), std::move(d_edge_dsts), @@ -248,9 +248,8 @@ void run_graph_algorithms( std::cout); } -int main(int argc, char** argv) +void run_tests() { - initialize_mpi_and_set_device(argc, argv); std::unique_ptr handle = initialize_mg_handle(); // @@ -279,6 +278,7 @@ int main(int argc, char** argv) std::move(std::make_optional(edge_wgts)), renumber, is_symmetric); + // Non-owning view of the graph object auto graph_view = graph.view(); @@ -292,5 +292,14 @@ int main(int argc, char** argv) run_graph_algorithms( *handle, graph_view, edge_weight_view); + handle.release(); +} + +int main(int argc, char** argv) +{ + 
initialize_mpi_and_set_device(argc, argv); + + run_tests(); + RAFT_MPI_TRY(MPI_Finalize()); } diff --git a/cpp/include/cugraph/graph_functions.hpp b/cpp/include/cugraph/graph_functions.hpp index e1364f69991..7f6543ccab8 100644 --- a/cpp/include/cugraph/graph_functions.hpp +++ b/cpp/include/cugraph/graph_functions.hpp @@ -1178,7 +1178,8 @@ std::tuple, rmm::device_uvector, std::optional>, std::optional>, - std::optional>> + std::optional>, + std::vector> shuffle_external_edges(raft::handle_t const& handle, rmm::device_uvector&& edge_srcs, rmm::device_uvector&& edge_dsts, diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index da31f498de1..3c3a3650491 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -490,7 +490,7 @@ ConfigureTest(SAMPLING_POST_PROCESSING_TEST sampling/sampling_post_processing_te ################################################################################################### # - NEGATIVE SAMPLING tests -------------------------------------------------------------------- -ConfigureTest(NEGATIVE_SAMPLING_TEST sampling/negative_sampling.cpp) +ConfigureTest(NEGATIVE_SAMPLING_TEST sampling/negative_sampling.cpp PERCENT 100) ################################################################################################### # - Renumber tests -------------------------------------------------------------------------------- diff --git a/dependencies.yaml b/dependencies.yaml index 6506dd10284..8619b32e929 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -713,6 +713,8 @@ dependencies: - &pytorch_unsuffixed pytorch>=2.0,<2.2.0a0 - torchdata - pydantic + - ogb + - torchmetrics specific: - output_types: [requirements] diff --git a/docs/cugraph/source/installation/source_build.md b/docs/cugraph/source/installation/source_build.md index 89e63badef8..80f2d97d40d 100644 --- a/docs/cugraph/source/installation/source_build.md +++ b/docs/cugraph/source/installation/source_build.md @@ -12,8 +12,7 @@ __Compilers:__ * `nvcc` version 11.5+ __CUDA:__ -* CUDA 11.2+ -* NVIDIA driver 470.42.01 or newer +* CUDA 11.8+ * NVIDIA GPU, Volta architecture or later, with [compute capability](https://developer.nvidia.com/cuda-gpus) 7.0+ Further details and download links for these prerequisites are available on the @@ -178,7 +177,7 @@ Run either the C++ or the Python tests with datasets make test ``` -Note: This conda installation only applies to Linux and Python versions 3.8/3.11. +Note: This conda installation only applies to Linux and Python versions 3.10 and 3.11. ### (OPTIONAL) Set environment variable on activation diff --git a/docs/cugraph/source/tutorials/cugraph_notebooks.md b/docs/cugraph/source/tutorials/cugraph_notebooks.md index 559ba36e97e..6d7840dc3c4 100644 --- a/docs/cugraph/source/tutorials/cugraph_notebooks.md +++ b/docs/cugraph/source/tutorials/cugraph_notebooks.md @@ -55,10 +55,9 @@ Running the example in these notebooks requires: * Download via Docker, Conda (See [__Getting Started__](https://rapids.ai/start.html)) * cuGraph is dependent on the latest version of cuDF. 
Please install all components of RAPIDS -* Python 3.8+ -* A system with an NVIDIA GPU: Pascal architecture or better +* Python 3.10+ +* A system with an NVIDIA GPU: Volta architecture or newer * CUDA 11.4+ -* NVIDIA driver 450.51+ ## Copyright diff --git a/docs/cugraph/source/wholegraph/installation/source_build.md b/docs/cugraph/source/wholegraph/installation/source_build.md index a7727ac4052..33d7c98b28e 100644 --- a/docs/cugraph/source/wholegraph/installation/source_build.md +++ b/docs/cugraph/source/wholegraph/installation/source_build.md @@ -16,8 +16,7 @@ __Compiler__: __CUDA__: * CUDA 11.8+ -* NVIDIA driver 450.80.02+ -* Pascal architecture or better +* Volta architecture or better You can obtain CUDA from [https://developer.nvidia.com/cuda-downloads](https://developer.nvidia.com/cuda-downloads). @@ -177,7 +176,7 @@ Run either the C++ or the Python tests with datasets ``` -Note: This conda installation only applies to Linux and Python versions 3.8/3.10. +Note: This conda installation only applies to Linux and Python versions 3.10 and 3.11. ## Creating documentation diff --git a/notebooks/README.md b/notebooks/README.md index 06ab93688ec..a8f094c340b 100644 --- a/notebooks/README.md +++ b/notebooks/README.md @@ -56,10 +56,9 @@ Running the example in these notebooks requires: * Download via Docker, Conda (See [__Getting Started__](https://rapids.ai/start.html)) * cuGraph is dependent on the latest version of cuDF. Please install all components of RAPIDS -* Python 3.8+ -* A system with an NVIDIA GPU: Pascal architecture or better +* Python 3.10+ +* A system with an NVIDIA GPU: Volta architecture or newer * CUDA 11.4+ -* NVIDIA driver 450.51+ ### QuickStart diff --git a/python/cugraph-dgl/cugraph_dgl/dataloading/dataloader.py b/python/cugraph-dgl/cugraph_dgl/dataloading/dataloader.py index 21b70b05f3a..4f36353cb18 100644 --- a/python/cugraph-dgl/cugraph_dgl/dataloading/dataloader.py +++ b/python/cugraph-dgl/cugraph_dgl/dataloading/dataloader.py @@ -140,6 +140,10 @@ def __init__( self.__graph = graph self.__device = device + @property + def _batch_size(self): + return self.__batch_size + @property def dataset( self, diff --git a/python/cugraph-dgl/cugraph_dgl/dataloading/neighbor_sampler.py b/python/cugraph-dgl/cugraph_dgl/dataloading/neighbor_sampler.py index 1a35c3ea027..87d111adcba 100644 --- a/python/cugraph-dgl/cugraph_dgl/dataloading/neighbor_sampler.py +++ b/python/cugraph-dgl/cugraph_dgl/dataloading/neighbor_sampler.py @@ -194,7 +194,7 @@ def sample( if g.is_homogeneous: indices = torch.concat(list(indices)) - ds.sample_from_nodes(indices, batch_size=batch_size) + ds.sample_from_nodes(indices.long(), batch_size=batch_size) return HomogeneousSampleReader( ds.get_reader(), self.output_format, self.edge_dir ) diff --git a/python/cugraph-dgl/cugraph_dgl/graph.py b/python/cugraph-dgl/cugraph_dgl/graph.py index 2eba13c6958..011ab736d00 100644 --- a/python/cugraph-dgl/cugraph_dgl/graph.py +++ b/python/cugraph-dgl/cugraph_dgl/graph.py @@ -29,6 +29,7 @@ HeteroNodeDataView, HeteroEdgeView, HeteroEdgeDataView, + EmbeddingView, ) @@ -567,8 +568,8 @@ def _has_n_emb(self, ntype: str, emb_name: str) -> bool: return (ntype, emb_name) in self.__ndata_storage def _get_n_emb( - self, ntype: str, emb_name: str, u: Union[str, TensorType] - ) -> "torch.Tensor": + self, ntype: Union[str, None], emb_name: str, u: Union[str, TensorType] + ) -> Union["torch.Tensor", "EmbeddingView"]: """ Gets the embedding of a single node type. 
Unlike DGL, this function takes the string node @@ -583,11 +584,11 @@ def _get_n_emb( u: Union[str, TensorType] Nodes to get the representation of, or ALL to get the representation of all nodes of - the given type. + the given type (returns embedding view). Returns ------- - torch.Tensor + Union[torch.Tensor, cugraph_dgl.view.EmbeddingView] The embedding of the given edge type with the given embedding name. """ @@ -598,9 +599,14 @@ def _get_n_emb( raise ValueError("Must provide the node type for a heterogeneous graph") if dgl.base.is_all(u): - u = torch.arange(self.num_nodes(ntype), dtype=self.idtype, device="cpu") + return EmbeddingView( + self.__ndata_storage[ntype, emb_name], self.num_nodes(ntype) + ) try: + print( + u, + ) return self.__ndata_storage[ntype, emb_name].fetch( _cast_to_torch_tensor(u), "cuda" ) @@ -644,7 +650,9 @@ def _get_e_emb( etype = self.to_canonical_etype(etype) if dgl.base.is_all(u): - u = torch.arange(self.num_edges(etype), dtype=self.idtype, device="cpu") + return EmbeddingView( + self.__edata_storage[etype, emb_name], self.num_edges(etype) + ) try: return self.__edata_storage[etype, emb_name].fetch( diff --git a/python/cugraph-dgl/cugraph_dgl/view.py b/python/cugraph-dgl/cugraph_dgl/view.py index dbc53e73b6a..4de9406be07 100644 --- a/python/cugraph-dgl/cugraph_dgl/view.py +++ b/python/cugraph-dgl/cugraph_dgl/view.py @@ -12,6 +12,8 @@ # limitations under the License. +import warnings + from collections import defaultdict from collections.abc import MutableMapping from typing import Union, Dict, List, Tuple @@ -20,11 +22,45 @@ import cugraph_dgl from cugraph_dgl.typing import TensorType +from cugraph_dgl.utils.cugraph_conversion_utils import _cast_to_torch_tensor torch = import_optional("torch") dgl = import_optional("dgl") +class EmbeddingView: + def __init__(self, storage: "dgl.storages.base.FeatureStorage", ld: int): + self.__ld = ld + self.__storage = storage + + def __getitem__(self, u: TensorType) -> "torch.Tensor": + u = _cast_to_torch_tensor(u) + try: + return self.__storage.fetch( + u, + "cuda", + ) + except RuntimeError as ex: + warnings.warn( + "Got error accessing data, trying again with index on device: " + + str(ex) + ) + return self.__storage.fetch( + u.cuda(), + "cuda", + ) + + @property + def shape(self) -> "torch.Size": + try: + f = self.__storage.fetch(torch.tensor([0]), "cpu") + except RuntimeError: + f = self.__storage.fetch(torch.tensor([0], device="cuda"), "cuda") + sz = [s for s in f.shape] + sz[0] = self.__ld + return torch.Size(tuple(sz)) + + class HeteroEdgeDataView(MutableMapping): """ Duck-typed version of DGL's HeteroEdgeDataView. diff --git a/python/cugraph-dgl/examples/graphsage/node-classification-dask.py b/python/cugraph-dgl/examples/graphsage/node-classification-dask.py new file mode 100644 index 00000000000..992669e4284 --- /dev/null +++ b/python/cugraph-dgl/examples/graphsage/node-classification-dask.py @@ -0,0 +1,270 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +# Example modified from: +# https://github.com/dmlc/dgl/blob/master/examples/pytorch/graphsage/node_classification.py + +# Ignore Warning +import warnings +import time +import cugraph_dgl +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchmetrics.functional as MF +import dgl +import dgl.nn as dglnn +from dgl.data import AsNodePredDataset +from dgl.dataloading import ( + DataLoader, + NeighborSampler, + MultiLayerFullNeighborSampler, +) +from ogb.nodeproppred import DglNodePropPredDataset +import tqdm +import argparse + +warnings.filterwarnings("ignore") + + +def set_allocators(): + import rmm + import cudf + import cupy + from rmm.allocators.torch import rmm_torch_allocator + from rmm.allocators.cupy import rmm_cupy_allocator + + mr = rmm.mr.CudaAsyncMemoryResource() + rmm.mr.set_current_device_resource(mr) + torch.cuda.memory.change_current_allocator(rmm_torch_allocator) + cupy.cuda.set_allocator(rmm_cupy_allocator) + cudf.set_option("spill", True) + + +class SAGE(nn.Module): + def __init__(self, in_size, hid_size, out_size): + super().__init__() + self.layers = nn.ModuleList() + # three-layer GraphSAGE-mean + self.layers.append(dglnn.SAGEConv(in_size, hid_size, "mean")) + self.layers.append(dglnn.SAGEConv(hid_size, hid_size, "mean")) + self.layers.append(dglnn.SAGEConv(hid_size, out_size, "mean")) + self.dropout = nn.Dropout(0.5) + self.hid_size = hid_size + self.out_size = out_size + + def forward(self, blocks, x): + h = x + for l_id, (layer, block) in enumerate(zip(self.layers, blocks)): + h = layer(block, h) + if l_id != len(self.layers) - 1: + h = F.relu(h) + h = self.dropout(h) + return h + + def inference(self, g, device, batch_size): + """Conduct layer-wise inference to get all the node embeddings.""" + all_node_ids = torch.arange(0, g.num_nodes()).to(device) + feat = g.get_node_storage(key="feat", ntype="_N").fetch( + all_node_ids, device=device + ) + + sampler = MultiLayerFullNeighborSampler(1, prefetch_node_feats=["feat"]) + dataloader = DataLoader( + g, + torch.arange(g.num_nodes()).to(g.device), + sampler, + device=device, + batch_size=batch_size, + shuffle=False, + drop_last=False, + num_workers=0, + ) + buffer_device = torch.device("cpu") + pin_memory = buffer_device != device + + for l_id, layer in enumerate(self.layers): + y = torch.empty( + g.num_nodes(), + self.hid_size if l_id != len(self.layers) - 1 else self.out_size, + device=buffer_device, + pin_memory=pin_memory, + ) + feat = feat.to(device) + for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader): + x = feat[input_nodes] + h = layer(blocks[0], x) # len(blocks) = 1 + if l_id != len(self.layers) - 1: + h = F.relu(h) + h = self.dropout(h) + # by design, our output nodes are contiguous + y[output_nodes[0] : output_nodes[-1] + 1] = h.to(buffer_device) + feat = y + return y + + +def evaluate(model, graph, dataloader): + model.eval() + ys = [] + y_hats = [] + for it, (input_nodes, output_nodes, blocks) in enumerate(dataloader): + with torch.no_grad(): + if isinstance(graph.ndata["feat"], dict): + x = graph.ndata["feat"]["_N"][input_nodes] + label = graph.ndata["label"]["_N"][output_nodes] + else: + x = graph.ndata["feat"][input_nodes] + label = graph.ndata["label"][output_nodes] + ys.append(label) + y_hats.append(model(blocks, x)) + num_classes = y_hats[0].shape[1] + return MF.accuracy( + torch.cat(y_hats), + torch.cat(ys), + task="multiclass", + num_classes=num_classes, + ) + + +def layerwise_infer(device, graph, nid, model, batch_size): + model.eval() + with 
torch.no_grad(): + pred = model.inference(graph, device, batch_size) # pred in buffer_device + pred = pred[nid] + label = graph.ndata["label"] + if isinstance(label, dict): + label = label["_N"] + label = label[nid].to(device).to(pred.device) + num_classes = pred.shape[1] + return MF.accuracy(pred, label, task="multiclass", num_classes=num_classes) + + +def train(args, device, g, dataset, model): + # create sampler & dataloader + train_idx = dataset.train_idx.to(device) + val_idx = dataset.val_idx.to(device) + + use_uva = args.mode == "mixed" + batch_size = 1024 + fanouts = [5, 10, 15] + sampler = NeighborSampler(fanouts) + train_dataloader = DataLoader( + g, + train_idx, + sampler, + device=device, + batch_size=batch_size, + shuffle=True, + drop_last=False, + num_workers=0, + use_uva=use_uva, + ) + val_dataloader = DataLoader( + g, + val_idx, + sampler, + device=device, + batch_size=batch_size, + shuffle=True, + drop_last=False, + num_workers=0, + use_uva=use_uva, + ) + + opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4) + + for epoch in range(10): + model.train() + total_loss = 0 + st = time.time() + for it, (input_nodes, output_nodes, blocks) in enumerate(train_dataloader): + if isinstance(g.ndata["feat"], dict): + x = g.ndata["feat"]["_N"][input_nodes] + y = g.ndata["label"]["_N"][output_nodes] + else: + x = g.ndata["feat"][input_nodes] + y = g.ndata["label"][output_nodes] + + y_hat = model(blocks, x) + loss = F.cross_entropy(y_hat, y) + opt.zero_grad() + loss.backward() + opt.step() + total_loss += loss.item() + + et = time.time() + + print(f"Time taken for epoch {epoch} with batch_size {batch_size} = {et-st} s") + acc = evaluate(model, g, val_dataloader) + print( + "Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} ".format( + epoch, total_loss / (it + 1), acc.item() + ) + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--mode", + default="gpu_cugraph_dgl", + choices=["cpu", "mixed", "gpu_dgl", "gpu_cugraph_dgl"], + help="Training mode." 
+ " 'cpu' for CPU training," + " 'mixed' for CPU-GPU mixed training, " + " 'gpu_dgl' for pure-GPU training, " + " 'gpu_cugraph_dgl' for pure-GPU training.", + ) + args = parser.parse_args() + if not torch.cuda.is_available(): + args.mode = "cpu" + if args.mode == "gpu_cugraph_dgl": + set_allocators() + print(f"Training in {args.mode} mode.") + + # load and preprocess dataset + print("Loading data") + dataset = AsNodePredDataset(DglNodePropPredDataset("ogbn-products")) + g = dataset[0] + g = dgl.add_self_loop(g) + if args.mode == "gpu_cugraph_dgl": + g = cugraph_dgl.cugraph_storage_from_heterograph(g.to("cuda")) + del dataset.g + + else: + g = g.to("cuda" if args.mode == "gpu_dgl" else "cpu") + device = torch.device( + "cpu" if args.mode == "cpu" or args.mode == "mixed" else "cuda" + ) + + # create GraphSAGE model + feat_shape = ( + g.get_node_storage(key="feat", ntype="_N") + .fetch(torch.LongTensor([0]).to(device), device=device) + .shape[1] + ) + print(feat_shape) + # no ndata in cugraph storage object + in_size = feat_shape + out_size = dataset.num_classes + model = SAGE(in_size, 256, out_size).to(device) + + # model training + print("Training...") + train(args, device, g, dataset, model) + + # test the model + print("Testing...") + acc = layerwise_infer(device, g, dataset.test_idx, model, batch_size=4096) + print("Test Accuracy {:.4f}".format(acc.item())) diff --git a/python/cugraph-dgl/examples/graphsage/node-classification.py b/python/cugraph-dgl/examples/graphsage/node-classification.py index 539fd86d136..2b8b687efab 100644 --- a/python/cugraph-dgl/examples/graphsage/node-classification.py +++ b/python/cugraph-dgl/examples/graphsage/node-classification.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -17,8 +17,10 @@ # Ignore Warning import warnings +import tempfile import time import cugraph_dgl +import cugraph_dgl.dataloading import torch import torch.nn as nn import torch.nn.functional as F @@ -76,14 +78,17 @@ def forward(self, blocks, x): def inference(self, g, device, batch_size): """Conduct layer-wise inference to get all the node embeddings.""" all_node_ids = torch.arange(0, g.num_nodes()).to(device) - feat = g.get_node_storage(key="feat", ntype="_N").fetch( - all_node_ids, device=device - ) + feat = g.ndata["feat"][all_node_ids].to(device) - sampler = MultiLayerFullNeighborSampler(1, prefetch_node_feats=["feat"]) - dataloader = DataLoader( + if isinstance(g, cugraph_dgl.Graph): + sampler = cugraph_dgl.dataloading.NeighborSampler([-1]) + loader_cls = cugraph_dgl.dataloading.FutureDataLoader + else: + sampler = MultiLayerFullNeighborSampler(1, prefetch_node_feats=["feat"]) + loader_cls = DataLoader + dataloader = loader_cls( g, - torch.arange(g.num_nodes()).to(g.device), + torch.arange(g.num_nodes()).to(device), sampler, device=device, batch_size=batch_size, @@ -150,7 +155,7 @@ def layerwise_infer(device, graph, nid, model, batch_size): return MF.accuracy(pred, label, task="multiclass", num_classes=num_classes) -def train(args, device, g, dataset, model): +def train(args, device, g, dataset, model, directory): # create sampler & dataloader train_idx = dataset.train_idx.to(device) val_idx = dataset.val_idx.to(device) @@ -158,8 +163,13 @@ def train(args, device, g, dataset, model): use_uva = args.mode == "mixed" batch_size = 1024 fanouts = [5, 10, 15] - sampler = NeighborSampler(fanouts) - train_dataloader = DataLoader( + if isinstance(g, cugraph_dgl.Graph): + sampler = cugraph_dgl.dataloading.NeighborSampler(fanouts, directory=directory) + loader_cls = cugraph_dgl.dataloading.FutureDataLoader + else: + sampler = NeighborSampler(fanouts) + loader_cls = DataLoader + train_dataloader = loader_cls( g, train_idx, sampler, @@ -170,7 +180,7 @@ def train(args, device, g, dataset, model): num_workers=0, use_uva=use_uva, ) - val_dataloader = DataLoader( + val_dataloader = loader_cls( g, val_idx, sampler, @@ -195,6 +205,7 @@ def train(args, device, g, dataset, model): else: x = g.ndata["feat"][input_nodes] y = g.ndata["label"][output_nodes] + y_hat = model(blocks, x) loss = F.cross_entropy(y_hat, y) opt.zero_grad() @@ -225,6 +236,8 @@ def train(args, device, g, dataset, model): " 'gpu_dgl' for pure-GPU training, " " 'gpu_cugraph_dgl' for pure-GPU training.", ) + parser.add_argument("--dataset_root", type=str, default="dataset") + parser.add_argument("--tempdir_root", type=str, default=None) args = parser.parse_args() if not torch.cuda.is_available(): args.mode = "cpu" @@ -234,11 +247,13 @@ def train(args, device, g, dataset, model): # load and preprocess dataset print("Loading data") - dataset = AsNodePredDataset(DglNodePropPredDataset("ogbn-products")) + dataset = AsNodePredDataset( + DglNodePropPredDataset("ogbn-products", root=args.dataset_root) + ) g = dataset[0] g = dgl.add_self_loop(g) if args.mode == "gpu_cugraph_dgl": - g = cugraph_dgl.cugraph_storage_from_heterograph(g.to("cuda")) + g = cugraph_dgl.cugraph_dgl_graph_from_heterograph(g.to("cuda")) del dataset.g else: @@ -248,19 +263,17 @@ def train(args, device, g, dataset, model): ) # create GraphSAGE model - feat_shape = ( - g.get_node_storage(key="feat", ntype="_N") - .fetch(torch.LongTensor([0]).to(device), device=device) - .shape[1] - ) - # no ndata in cugraph storage object + 
feat_shape = g.ndata["feat"].shape[1] + print(feat_shape) + in_size = feat_shape out_size = dataset.num_classes model = SAGE(in_size, 256, out_size).to(device) # model training print("Training...") - train(args, device, g, dataset, model) + with tempfile.TemporaryDirectory(dir=args.tempdir_root) as directory: + train(args, device, g, dataset, model, directory) # test the model print("Testing...") diff --git a/python/cugraph-dgl/examples/multi_trainer_MG_example/model.py b/python/cugraph-dgl/examples/multi_trainer_MG_example/model.py index a6f771e4b51..d3aad2ab309 100644 --- a/python/cugraph-dgl/examples/multi_trainer_MG_example/model.py +++ b/python/cugraph-dgl/examples/multi_trainer_MG_example/model.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -58,9 +58,8 @@ def inference(self, g, batch_size, device): # The nodes on each layer are of course splitted in batches. all_node_ids = torch.arange(0, g.num_nodes()).to(device) - feat = g.get_node_storage(key="feat", ntype="_N").fetch( - all_node_ids, device=device - ) + feat = g.ndata["feat"][all_node_ids].to(device) + sampler = dgl.dataloading.MultiLayerFullNeighborSampler( 1, prefetch_node_feats=["feat"] ) @@ -114,15 +113,13 @@ def layerwise_infer(graph, nid, model, batch_size, device): def train_model(model, g, opt, train_dataloader, num_epochs, rank, val_nid): - g.ndata["feat"]["_N"] = g.ndata["feat"]["_N"].to("cuda") - g.ndata["label"]["_N"] = g.ndata["label"]["_N"].to("cuda") st = time.time() model.train() for epoch in range(num_epochs): total_loss = 0 for _, (input_nodes, output_nodes, blocks) in enumerate(train_dataloader): - x = g.ndata["feat"]["_N"][input_nodes] - y = g.ndata["label"]["_N"][output_nodes] + x = g.ndata["feat"][input_nodes].to(torch.float32) + y = g.ndata["label"][output_nodes].to(torch.int64) y_hat = model(blocks, x) y = y.squeeze(1) loss = F.cross_entropy(y_hat, y) diff --git a/python/cugraph-dgl/examples/multi_trainer_MG_example/workflow.py b/python/cugraph-dgl/examples/multi_trainer_MG_example/workflow.py deleted file mode 100644 index 474f17dc2bb..00000000000 --- a/python/cugraph-dgl/examples/multi_trainer_MG_example/workflow.py +++ /dev/null @@ -1,244 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import dgl -import torch -import time -from distributed import Client, Event as Dask_Event -import tempfile -from cugraph.dask.comms import comms as Comms - - -def enable_spilling(): - import cudf - - cudf.set_option("spill", True) - - -def setup_cluster(dask_worker_devices): - dask_worker_devices_str = ",".join([str(i) for i in dask_worker_devices]) - from dask_cuda import LocalCUDACluster - - cluster = LocalCUDACluster( - protocol="tcp", - CUDA_VISIBLE_DEVICES=dask_worker_devices_str, - rmm_pool_size="25GB", - ) - - client = Client(cluster) - client.wait_for_workers(n_workers=len(dask_worker_devices)) - client.run(enable_spilling) - print("Dask Cluster Setup Complete") - del client - return cluster - - -def create_dask_client(scheduler_address): - from cugraph.dask.comms import comms as Comms - - client = Client(scheduler_address) - Comms.initialize(p2p=True) - return client - - -def initalize_pytorch_worker(dev_id): - import cupy as cp - import rmm - from rmm.allocators.torch import rmm_torch_allocator - from rmm.allocators.cupy import rmm_cupy_allocator - - dev = cp.cuda.Device( - dev_id - ) # Create cuda context on the right gpu, defaults to gpu-0 - dev.use() - rmm.reinitialize( - pool_allocator=True, - initial_pool_size=10e9, - maximum_pool_size=15e9, - devices=[dev_id], - ) - - if dev_id == 0: - torch.cuda.memory.change_current_allocator(rmm_torch_allocator) - - torch.cuda.set_device(dev_id) - cp.cuda.set_allocator(rmm_cupy_allocator) - enable_spilling() - print("device_id", dev_id, flush=True) - - -def load_dgl_dataset(dataset_name="ogbn-products"): - from ogb.nodeproppred import DglNodePropPredDataset - - dataset = DglNodePropPredDataset(name=dataset_name) - split_idx = dataset.get_idx_split() - train_idx, valid_idx, test_idx = ( - split_idx["train"], - split_idx["valid"], - split_idx["test"], - ) - g, label = dataset[0] - g.ndata["label"] = label - if len(g.etypes) <= 1: - g = dgl.add_self_loop(g) - else: - for etype in g.etypes: - if etype[0] == etype[2]: - # only add self loops for src->dst - g = dgl.add_self_loop(g, etype=etype) - - g = g.int() - train_idx = train_idx.int() - valid_idx = valid_idx.int() - test_idx = test_idx.int() - return g, train_idx, valid_idx, test_idx, dataset.num_classes - - -def create_cugraph_graphstore_from_dgl_dataset( - dataset_name="ogbn-products", single_gpu=False -): - from cugraph_dgl import cugraph_storage_from_heterograph - - dgl_g, train_idx, valid_idx, test_idx, num_classes = load_dgl_dataset(dataset_name) - cugraph_gs = cugraph_storage_from_heterograph(dgl_g, single_gpu=single_gpu) - return cugraph_gs, train_idx, valid_idx, test_idx, num_classes - - -def create_dataloader(gs, train_idx, device): - import cugraph_dgl - - temp_dir = tempfile.TemporaryDirectory() - sampler = cugraph_dgl.dataloading.NeighborSampler([10, 20]) - dataloader = cugraph_dgl.dataloading.DataLoader( - gs, - train_idx, - sampler, - sampling_output_dir=temp_dir.name, - batches_per_partition=10, - device=device, # Put the sampled MFGs on CPU or GPU - use_ddp=True, # Make it work with distributed data parallel - batch_size=1024, - shuffle=False, # Whether to shuffle the nodes for every epoch - drop_last=False, - num_workers=0, - ) - return dataloader - - -def run_workflow(rank, devices, scheduler_address): - from model import Sage, train_model - - # Below sets gpu_number - dev_id = devices[rank] - initalize_pytorch_worker(dev_id) - device = torch.device(f"cuda:{dev_id}") - # cugraph dask client initialization - client = create_dask_client(scheduler_address) - - # 
Pytorch training worker initialization - dist_init_method = "tcp://{master_ip}:{master_port}".format( - master_ip="127.0.0.1", master_port="12346" - ) - - torch.distributed.init_process_group( - backend="nccl", - init_method=dist_init_method, - world_size=len(devices), - rank=rank, - ) - - print(f"rank {rank}.", flush=True) - print("Initalized across GPUs.") - - event = Dask_Event("cugraph_gs_creation_event") - if rank == 0: - ( - gs, - train_idx, - valid_idx, - test_idx, - num_classes, - ) = create_cugraph_graphstore_from_dgl_dataset( - "ogbn-products", single_gpu=False - ) - client.publish_dataset(cugraph_gs=gs) - client.publish_dataset(train_idx=train_idx) - client.publish_dataset(valid_idx=valid_idx) - client.publish_dataset(test_idx=test_idx) - client.publish_dataset(num_classes=num_classes) - event.set() - else: - if event.wait(timeout=1000): - gs = client.get_dataset("cugraph_gs") - train_idx = client.get_dataset("train_idx") - valid_idx = client.get_dataset("valid_idx") - test_idx = client.get_dataset("test_idx") - num_classes = client.get_dataset("num_classes") - else: - raise RuntimeError(f"Fetch cugraph_gs to worker_id {rank} failed") - - torch.distributed.barrier() - print(f"Loading cugraph_store to worker {rank} is complete", flush=True) - dataloader = create_dataloader(gs, train_idx, device) - print("Data Loading Complete", flush=True) - num_feats = gs.ndata["feat"]["_N"].shape[1] - hid_size = 256 - # Load Training example - model = Sage(num_feats, hid_size, num_classes).to(device) - model = torch.nn.parallel.DistributedDataParallel( - model, - device_ids=[device], - output_device=device, - ) - torch.distributed.barrier() - n_epochs = 10 - total_st = time.time() - opt = torch.optim.Adam(model.parameters(), lr=0.01) - train_model(model, gs, opt, dataloader, n_epochs, rank, valid_idx) - torch.distributed.barrier() - total_et = time.time() - print( - f"Total time taken on n_epochs {n_epochs} = {total_et-total_st} s", - f"measured by worker = {rank}", - ) - - # cleanup dask cluster - if rank == 0: - client.unpublish_dataset("cugraph_gs") - client.unpublish_dataset("train_idx") - client.unpublish_dataset("valid_idx") - client.unpublish_dataset("test_idx") - event.clear() - print("Workflow completed") - print("---" * 10) - Comms.destroy() - - -if __name__ == "__main__": - # Load dummy first - # because new environments - # require dataset download - load_dgl_dataset() - dask_worker_devices = [5, 6] - cluster = setup_cluster(dask_worker_devices) - - trainer_devices = [0, 1, 2] - import torch.multiprocessing as mp - - mp.spawn( - run_workflow, - args=(trainer_devices, cluster.scheduler_address), - nprocs=len(trainer_devices), - ) - Comms.destroy() - cluster.close() diff --git a/python/cugraph-dgl/examples/multi_trainer_MG_example/workflow_mnmg.py b/python/cugraph-dgl/examples/multi_trainer_MG_example/workflow_mnmg.py new file mode 100644 index 00000000000..b1878b37d4e --- /dev/null +++ b/python/cugraph-dgl/examples/multi_trainer_MG_example/workflow_mnmg.py @@ -0,0 +1,311 @@ +# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import dgl +import torch +import time +import tempfile +import argparse +import json +import os +import warnings + +from datetime import timedelta + +import cugraph_dgl + +from cugraph.gnn import ( + cugraph_comms_init, + cugraph_comms_shutdown, + cugraph_comms_create_unique_id, +) + +from pylibwholegraph.torch.initialize import ( + init as wm_init, + finalize as wm_finalize, +) + +# Allow computation on objects that are larger than GPU memory +# https://docs.rapids.ai/api/cudf/stable/developer_guide/library_design/#spilling-to-host-memory +os.environ["CUDF_SPILL"] = "1" + + +def init_ddp_worker(global_rank, local_rank, world_size, cugraph_id): + import rmm + + rmm.reinitialize( + devices=local_rank, + managed_memory=True, + pool_allocator=True, + ) + + import cupy + + cupy.cuda.Device(local_rank).use() + from rmm.allocators.cupy import rmm_cupy_allocator + + cupy.cuda.set_allocator(rmm_cupy_allocator) + + from cugraph.testing.mg_utils import enable_spilling + + enable_spilling() + + torch.cuda.set_device(local_rank) + + cugraph_comms_init( + rank=global_rank, world_size=world_size, uid=cugraph_id, device=local_rank + ) + + wm_init(global_rank, world_size, local_rank, torch.cuda.device_count()) + + +def load_dgl_dataset(dataset_root="dataset", dataset_name="ogbn-products"): + from ogb.nodeproppred import DglNodePropPredDataset + + dataset = DglNodePropPredDataset(root=dataset_root, name=dataset_name) + split_idx = dataset.get_idx_split() + train_idx, valid_idx, test_idx = ( + split_idx["train"], + split_idx["valid"], + split_idx["test"], + ) + g, label = dataset[0] + g.ndata["label"] = label + if len(g.etypes) <= 1: + g = dgl.add_self_loop(g) + else: + for etype in g.etypes: + if etype[0] == etype[2]: + # only add self loops for src->dst + g = dgl.add_self_loop(g, etype=etype) + + g = g.int() + idx = { + "train": train_idx.int(), + "valid": valid_idx.int(), + "test": test_idx.int(), + } + + return g, idx, dataset.num_classes + + +def partition_data( + g, split_idx, num_classes, edge_path, feature_path, label_path, meta_path +): + # Split and save edge index + os.makedirs( + edge_path, + exist_ok=True, + ) + src, dst = g.all_edges(form="uv", order="eid") + edge_index = torch.stack([src, dst]) + for (r, e) in enumerate(torch.tensor_split(edge_index, world_size, dim=1)): + rank_path = os.path.join(edge_path, f"rank={r}.pt") + torch.save( + e.clone(), + rank_path, + ) + + # Split and save features + os.makedirs( + feature_path, + exist_ok=True, + ) + + nix = torch.arange(g.num_nodes()) + for (r, f) in enumerate(torch.tensor_split(nix, world_size)): + feat_path = os.path.join(feature_path, f"rank={r}_feat.pt") + torch.save(g.ndata["feat"][f], feat_path) + + label_f_path = os.path.join(feature_path, f"rank={r}_label.pt") + torch.save(g.ndata["label"][f], label_f_path) + + # Split and save labels + os.makedirs( + label_path, + exist_ok=True, + ) + for (d, i) in split_idx.items(): + i_parts = torch.tensor_split(i, world_size) + for r, i_part in enumerate(i_parts): + rank_path = os.path.join(label_path, f"rank={r}") + os.makedirs(rank_path, exist_ok=True) + torch.save(i_part, os.path.join(rank_path, f"{d}.pt")) + + # Save metadata + meta = { + "num_classes": int(num_classes), + "num_nodes": int(g.num_nodes()), + } + with open(meta_path, "w") as f: + json.dump(meta, f) + + +def load_partitioned_data(rank, edge_path, feature_path, label_path, meta_path): + g = cugraph_dgl.Graph( + 
is_multi_gpu=True, ndata_storage="wholegraph", edata_storage="wholegraph" + ) + + # Load metadata + with open(meta_path, "r") as f: + meta = json.load(f) + + # Load labels + split_idx = {} + for split in ["train", "test", "valid"]: + split_idx[split] = torch.load( + os.path.join(label_path, f"rank={rank}", f"{split}.pt") + ) + + # Load features + feat_t = torch.load(os.path.join(feature_path, f"rank={rank}_feat.pt")) + label_f_t = torch.load(os.path.join(feature_path, f"rank={rank}_label.pt")) + ndata = {"feat": feat_t, "label": label_f_t} + g.add_nodes(meta["num_nodes"], data=ndata) + + # Load edge index + src, dst = torch.load(os.path.join(edge_path, f"rank={rank}.pt")) + g.add_edges(src.cuda(), dst.cuda(), data=None) + + return g, split_idx, meta["num_classes"] + + +def create_dataloader(gs, train_idx, device, temp_dir, stage): + import cugraph_dgl + + temp_path = os.path.join(temp_dir, f"{stage}_{device}") + os.mkdir(temp_path) + + sampler = cugraph_dgl.dataloading.NeighborSampler( + [10, 20], + directory=temp_path, + batches_per_partition=10, + ) + + dataloader = cugraph_dgl.dataloading.FutureDataLoader( + gs, + train_idx, + sampler, + device=device, # Put the sampled MFGs on CPU or GPU + use_ddp=True, # Make it work with distributed data parallel + batch_size=1024, + shuffle=False, # Whether to shuffle the nodes for every epoch + drop_last=False, + num_workers=0, + ) + return dataloader + + +def run_workflow( + global_rank, local_rank, world_size, g, split_idx, num_classes, temp_dir +): + from model import Sage, train_model + + # Below sets gpu_number + dev_id = local_rank + device = torch.device(f"cuda:{dev_id}") + + dataloader = create_dataloader(g, split_idx["train"], device, temp_dir, "train") + print("Dataloader Creation Complete", flush=True) + num_feats = g.ndata["feat"].shape[1] + hid_size = 256 + # Load Training example + model = Sage(num_feats, hid_size, num_classes).to(device) + model = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[device], + output_device=device, + ) + torch.distributed.barrier() + n_epochs = 10 + total_st = time.time() + opt = torch.optim.Adam(model.parameters(), lr=0.01) + train_model(model, g, opt, dataloader, n_epochs, global_rank, split_idx["valid"]) + torch.distributed.barrier() + total_et = time.time() + print( + f"Total time taken on n_epochs {n_epochs} = {total_et-total_st} s", + f"measured by worker = {global_rank}", + ) + + wm_finalize() + cugraph_comms_shutdown() + + +if __name__ == "__main__": + if "LOCAL_RANK" in os.environ: + parser = argparse.ArgumentParser() + parser.add_argument("--dataset_root", type=str, default="dataset") + parser.add_argument("--tempdir_root", type=str, default=None) + parser.add_argument("--dataset", type=str, default="ogbn-products") + parser.add_argument("--skip_partition", action="store_true") + args = parser.parse_args() + + torch.distributed.init_process_group( + "nccl", + timeout=timedelta(minutes=60), + ) + world_size = torch.distributed.get_world_size() + global_rank = torch.distributed.get_rank() + local_rank = int(os.environ["LOCAL_RANK"]) + device = torch.device(local_rank) + + # Create the uid needed for cuGraph comms + if global_rank == 0: + cugraph_id = [cugraph_comms_create_unique_id()] + else: + cugraph_id = [None] + torch.distributed.broadcast_object_list(cugraph_id, src=0, device=device) + cugraph_id = cugraph_id[0] + + init_ddp_worker(global_rank, local_rank, world_size, cugraph_id) + + # Split the data + edge_path = os.path.join(args.dataset_root, args.dataset + "_eix_part") + 
feature_path = os.path.join(args.dataset_root, args.dataset + "_fea_part") + label_path = os.path.join(args.dataset_root, args.dataset + "_label_part") + meta_path = os.path.join(args.dataset_root, args.dataset + "_meta.json") + + if not args.skip_partition and global_rank == 0: + partition_data( + *load_dgl_dataset(args.dataset_root, args.dataset), + edge_path, + feature_path, + label_path, + meta_path, + ) + torch.distributed.barrier() + + print("loading partitions...") + g, split_idx, num_classes = load_partitioned_data( + rank=global_rank, + edge_path=edge_path, + feature_path=feature_path, + label_path=label_path, + meta_path=meta_path, + ) + print(f"rank {global_rank} has loaded its partition") + torch.distributed.barrier() + + with tempfile.TemporaryDirectory(dir=args.tempdir_root) as directory: + run_workflow( + global_rank, + local_rank, + world_size, + g, + split_idx, + num_classes, + directory, + ) + else: + warnings.warn("This script should be run with 'torchrun`. Exiting.") diff --git a/python/cugraph-dgl/examples/multi_trainer_MG_example/workflow_snmg.py b/python/cugraph-dgl/examples/multi_trainer_MG_example/workflow_snmg.py new file mode 100644 index 00000000000..da5c2b4d64e --- /dev/null +++ b/python/cugraph-dgl/examples/multi_trainer_MG_example/workflow_snmg.py @@ -0,0 +1,242 @@ +# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import dgl +import torch +import time +import tempfile +import argparse +import os + +import cugraph_dgl + +from cugraph.gnn import ( + cugraph_comms_init, + cugraph_comms_shutdown, + cugraph_comms_create_unique_id, +) + +from pylibwholegraph.torch.initialize import ( + init as wm_init, + finalize as wm_finalize, +) + +# Allow computation on objects that are larger than GPU memory +# https://docs.rapids.ai/api/cudf/stable/developer_guide/library_design/#spilling-to-host-memory +os.environ["CUDF_SPILL"] = "1" + + +def initalize_pytorch_worker(dev_id): + import cupy as cp + import rmm + from rmm.allocators.cupy import rmm_cupy_allocator + + dev = cp.cuda.Device( + dev_id + ) # Create cuda context on the right gpu, defaults to gpu-0 + dev.use() + rmm.reinitialize( + pool_allocator=True, + initial_pool_size=10e9, + maximum_pool_size=15e9, + devices=[dev_id], + ) + + from cugraph.testing.mg_utils import enable_spilling + + enable_spilling() + + torch.cuda.set_device(dev_id) + cp.cuda.set_allocator(rmm_cupy_allocator) + print("device_id", dev_id, flush=True) + + +def load_dgl_dataset( + dataset_name="ogbn-products", + dataset_root=None, +): + from ogb.nodeproppred import DglNodePropPredDataset + + dataset = DglNodePropPredDataset(name=dataset_name, root=dataset_root) + split_idx = dataset.get_idx_split() + train_idx, valid_idx, test_idx = ( + split_idx["train"], + split_idx["valid"], + split_idx["test"], + ) + g, label = dataset[0] + g.ndata["label"] = label + if len(g.etypes) <= 1: + g = dgl.add_self_loop(g) + else: + for etype in g.etypes: + if etype[0] == etype[2]: + # only add self loops for src->dst + g = dgl.add_self_loop(g, etype=etype) + + g = g.int() + train_idx = train_idx.int() + valid_idx = valid_idx.int() + test_idx = test_idx.int() + return g, train_idx, valid_idx, test_idx, dataset.num_classes + + +def create_cugraph_graphstore_from_dgl_dataset(dataset, rank, world_size): + (g, train_idx, valid_idx, test_idx, num_classes) = dataset + # Partition the data + cg = cugraph_dgl.Graph( + is_multi_gpu=True, ndata_storage="wholegraph", edata_storage="wholegraph" + ) + + nix = torch.tensor_split(torch.arange(g.num_nodes()), world_size)[rank] + ndata = {k: g.ndata[k][nix].cuda() for k in g.ndata.keys()} + + eix = torch.tensor_split(torch.arange(g.num_edges()), world_size)[rank] + src, dst = g.all_edges(form="uv", order="eid") + edata = {k: g.edata[k][eix].cuda() for k in g.edata.keys()} + + cg.add_nodes(g.num_nodes(), data=ndata) + cg.add_edges( + torch.tensor_split(src, world_size)[rank].cuda(), + torch.tensor_split(dst, world_size)[rank].cuda(), + data=edata, + ) + + return ( + cg, + torch.tensor_split(train_idx, world_size)[rank].to(torch.int64), + torch.tensor_split(valid_idx, world_size)[rank].to(torch.int64), + torch.tensor_split(test_idx, world_size)[rank].to(torch.int64), + num_classes, + ) + + +def create_dataloader(gs, train_idx, device, temp_dir, stage): + import cugraph_dgl + + temp_path = os.path.join(temp_dir, f"{stage}_{device}") + os.mkdir(temp_path) + + sampler = cugraph_dgl.dataloading.NeighborSampler( + [10, 20], + directory=temp_path, + batches_per_partition=10, + ) + dataloader = cugraph_dgl.dataloading.FutureDataLoader( + gs, + train_idx, + sampler, + device=device, # Put the sampled MFGs on CPU or GPU + use_ddp=True, # Make it work with distributed data parallel + batch_size=1024, + shuffle=False, # Whether to shuffle the nodes for every epoch + drop_last=False, + num_workers=0, + ) + return dataloader + + +def run_workflow(rank, world_size, cugraph_id, dataset, 
temp_dir): + from model import Sage, train_model + + # Below sets gpu_number + dev_id = rank + initalize_pytorch_worker(dev_id) + device = torch.device(f"cuda:{dev_id}") + + # Pytorch training worker initialization + dist_init_method = "tcp://{master_ip}:{master_port}".format( + master_ip="127.0.0.1", master_port="12346" + ) + + torch.distributed.init_process_group( + backend="nccl", + init_method=dist_init_method, + world_size=world_size, + rank=rank, + ) + + cugraph_comms_init(rank=rank, world_size=world_size, uid=cugraph_id, device=rank) + wm_init(rank, world_size, rank, world_size) + + print(f"rank {rank}.", flush=True) + print("Initalized across GPUs.") + + ( + gs, + train_idx, + valid_idx, + test_idx, + num_classes, + ) = create_cugraph_graphstore_from_dgl_dataset( + dataset, + rank, + world_size, + ) + del dataset + + torch.distributed.barrier() + print(f"Loading graph to worker {rank} is complete", flush=True) + + dataloader = create_dataloader(gs, train_idx, device, temp_dir, "train") + print("Dataloader Creation Complete", flush=True) + num_feats = gs.ndata["feat"].shape[1] + hid_size = 256 + # Load Training example + model = Sage(num_feats, hid_size, num_classes).to(device) + model = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[device], + output_device=device, + ) + torch.distributed.barrier() + n_epochs = 10 + total_st = time.time() + opt = torch.optim.Adam(model.parameters(), lr=0.01) + train_model(model, gs, opt, dataloader, n_epochs, rank, valid_idx) + torch.distributed.barrier() + total_et = time.time() + print( + f"Total time taken on n_epochs {n_epochs} = {total_et-total_st} s", + f"measured by worker = {rank}", + ) + + torch.cuda.synchronize() + wm_finalize() + cugraph_comms_shutdown() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--dataset_root", type=str, default="dataset") + parser.add_argument("--tempdir_root", type=str, default=None) + parser.add_argument("--dataset", type=str, default="ogbn-products") + args = parser.parse_args() + + from rmm.allocators.torch import rmm_torch_allocator + + torch.cuda.memory.change_current_allocator(rmm_torch_allocator) + + # Create the uid needed for cuGraph comms + cugraph_id = cugraph_comms_create_unique_id() + + ds = load_dgl_dataset(args.dataset, args.dataset_root) + + world_size = torch.cuda.device_count() + + with tempfile.TemporaryDirectory(dir=args.tempdir_root) as directory: + torch.multiprocessing.spawn( + run_workflow, + args=(world_size, cugraph_id, ds, directory), + nprocs=world_size, + ) diff --git a/python/cugraph/cugraph/structure/hypergraph.py b/python/cugraph/cugraph/structure/hypergraph.py index b52fef4dcfc..bdc98333da0 100644 --- a/python/cugraph/cugraph/structure/hypergraph.py +++ b/python/cugraph/cugraph/structure/hypergraph.py @@ -580,14 +580,16 @@ def _create_direct_edges( def _str_scalar_to_category(size, val): - return cudf.core.column.build_categorical_column( - categories=cudf.core.column.as_column([val], dtype="str"), - codes=cudf.core.column.as_column(0, length=size, dtype=np.int32), - mask=None, + return cudf.core.column.CategoricalColumn( + data=None, size=size, + dtype=cudf.CategoricalDtype( + categories=cudf.core.column.as_column([val], dtype="str"), ordered=False + ), + mask=None, offset=0, null_count=0, - ordered=False, + children=(cudf.core.column.as_column(0, length=size, dtype=np.int32),), ) diff --git a/python/nx-cugraph/lint.yaml b/python/nx-cugraph/lint.yaml index ce46360e234..b2184a185c4 100644 --- 
a/python/nx-cugraph/lint.yaml +++ b/python/nx-cugraph/lint.yaml @@ -43,7 +43,7 @@ repos: rev: v3.16.0 hooks: - id: pyupgrade - args: [--py39-plus] + args: [--py310-plus] - repo: https://github.com/psf/black rev: 24.4.2 hooks:
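
A hedged illustration of the new `EmbeddingView` introduced in `python/cugraph-dgl/cugraph_dgl/view.py` above: with this change, `_get_n_emb`/`_get_e_emb` return a lazy view when all nodes or edges are requested, instead of materializing the full embedding tensor, and rows are fetched from the underlying feature storage only when indexed. The snippet below is a minimal single-GPU sketch, not part of this PR; the toy graph, feature values, and the assumption that the default `cugraph_dgl.Graph()` constructor and `add_nodes(..., data=...)`/`add_edges(...)` accept this input are illustrative only.

```python
# Minimal sketch (not part of this PR): exercising the lazy EmbeddingView that
# cugraph-dgl now returns when all node embeddings are requested.
# Assumes a working single-GPU cugraph-dgl install; the toy graph is illustrative.
import torch
import cugraph_dgl

g = cugraph_dgl.Graph()  # single-GPU graph (assumed default constructor)
g.add_nodes(4, data={"feat": torch.eye(4, device="cuda")})
g.add_edges(
    torch.tensor([0, 1, 2, 3], device="cuda"),
    torch.tensor([1, 2, 3, 0], device="cuda"),
    data=None,
)

emb = g.ndata["feat"]             # an EmbeddingView, not a materialized tensor
print(emb.shape)                  # torch.Size([4, 4]); leading dim comes from num_nodes()
rows = emb[torch.tensor([0, 2])]  # only the requested rows are fetched, onto "cuda"
print(rows.shape)                 # torch.Size([2, 4])
```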
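
Usage note on the reworked multi-GPU examples above: the dask-based `workflow.py` is removed in favor of two entry points, `workflow_snmg.py`, which spawns one process per visible GPU via `torch.multiprocessing.spawn`, and `workflow_mnmg.py`, which expects to be launched under `torchrun` (it reads `LOCAL_RANK` and warns and exits otherwise). A typical invocation for the multi-node script would be something like `torchrun --nproc-per-node=<gpus per node> workflow_mnmg.py --dataset_root <path>`; the exact launch flags are an assumption here, as the PR itself only documents the torchrun requirement through that warning.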