diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index ccfdb826812..0f490283795 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -22,7 +22,7 @@ on: default: nightly concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.ref }}-${{ github.event_name }} cancel-in-progress: true jobs: diff --git a/README.md b/README.md index 560d1483242..8026e4feb64 100644 --- a/README.md +++ b/README.md @@ -35,8 +35,18 @@ ----- +## News -## Table of content +___NEW!___ _[nx-cugraph](./python/nx-cugraph/README.md)_, a NetworkX backend that provides GPU acceleration to NetworkX with zero code change. +``` +> pip install nx-cugraph-cu11 --extra-index-url https://pypi.nvidia.com +> export NETWORKX_AUTOMATIC_BACKENDS=cugraph +``` +That's it. NetworkX now leverages cuGraph for accelerated graph algorithms. + +----- + +## Table of contents - Installation - [Getting cuGraph Packages](./docs/cugraph/source/installation/getting_cugraph.md) - [Building from Source](./docs/cugraph/source/installation/source_build.md) @@ -52,6 +62,7 @@ - [External Data Types](./readme_pages/data_types.md) - [pylibcugraph](./readme_pages/pylibcugraph.md) - [libcugraph (C/C++/CUDA)](./readme_pages/libcugraph.md) + - [nx-cugraph](./python/nx-cugraph/README.md) - [cugraph-service](./readme_pages/cugraph_service.md) - [cugraph-dgl](./readme_pages/cugraph_dgl.md) - [cugraph-ops](./readme_pages/cugraph_ops.md) @@ -116,6 +127,7 @@ df_page.sort_values('pagerank', ascending=False).head(10) * ArangoDB - a free and open-source native multi-model database system - https://www.arangodb.com/ * CuPy - "NumPy/SciPy-compatible Array Library for GPU-accelerated Computing with Python" - https://cupy.dev/ * Memgraph - In-memory Graph database - https://memgraph.com/ +* NetworkX (via [nx-cugraph](./python/nx-cugraph/README.md) backend) - an extremely popular, free and open-source package for the creation, manipulation, and study of the structure, dynamics, and functions of complex networks - https://networkx.org/ * PyGraphistry - free and open-source GPU graph ETL, AI, and visualization, including native RAPIDS & cuGraph support - http://github.com/graphistry/pygraphistry * ScanPy - a scalable toolkit for analyzing single-cell gene expression data - https://scanpy.readthedocs.io/en/stable/ diff --git a/benchmarks/cugraph-dgl/scale-benchmarks/cugraph_dgl_benchmark.py b/benchmarks/cugraph-dgl/scale-benchmarks/cugraph_dgl_benchmark.py new file mode 100644 index 00000000000..85f43b97b90 --- /dev/null +++ b/benchmarks/cugraph-dgl/scale-benchmarks/cugraph_dgl_benchmark.py @@ -0,0 +1,152 @@ +# Copyright (c) 2018-2023, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os
+
+os.environ["LIBCUDF_CUFILE_POLICY"] = "KVIKIO"
+os.environ["KVIKIO_NTHREADS"] = "64"
+os.environ["RAPIDS_NO_INITIALIZE"] = "1"
+import json
+import pandas as pd
+import time
+from rmm.allocators.torch import rmm_torch_allocator
+import rmm
+import torch
+from cugraph_dgl.dataloading import HomogenousBulkSamplerDataset
+from model import run_1_epoch
+from argparse import ArgumentParser
+from load_graph_feats import load_node_labels, load_node_features
+
+
+def create_dataloader(sampled_dir, total_num_nodes, sparse_format, return_type):
+    print("Creating dataloader", flush=True)
+    st = time.time()
+    dataset = HomogenousBulkSamplerDataset(
+        total_num_nodes,
+        edge_dir="in",
+        sparse_format=sparse_format,
+        return_type=return_type,
+    )
+
+    dataset.set_input_files(sampled_dir)
+    dataloader = torch.utils.data.DataLoader(
+        dataset, collate_fn=lambda x: x, shuffle=False, num_workers=0, batch_size=None
+    )
+    et = time.time()
+    print(f"Time to create dataloader = {et - st:.2f} seconds", flush=True)
+    return dataloader
+
+
+def setup_common_pool():
+    rmm.reinitialize(initial_pool_size=5e9, pool_allocator=True)
+    torch.cuda.memory.change_current_allocator(rmm_torch_allocator)
+
+
+def main(args):
+    print(
+        f"Running cugraph-dgl dataloading benchmark with the following parameters:\n"
+        f"Dataset path = {args.dataset_path}\n"
+        f"Sampling path = {args.sampling_path}\n"
+    )
+    with open(os.path.join(args.dataset_path, "meta.json"), "r") as f:
+        input_meta = json.load(f)
+
+    sampled_dirs = [
+        os.path.join(args.sampling_path, f) for f in os.listdir(args.sampling_path)
+    ]
+
+    time_ls = []
+    for sampled_dir in sampled_dirs:
+        with open(os.path.join(sampled_dir, "output_meta.json"), "r") as f:
+            sampled_meta_d = json.load(f)
+
+        replication_factor = sampled_meta_d["replication_factor"]
+        feat_load_st = time.time()
+        label_data = load_node_labels(
+            args.dataset_path, replication_factor, input_meta
+        )["paper"]["y"]
+        feat_data = load_node_features(
+            args.dataset_path, replication_factor, node_type="paper"
+        )
+        print(
+            f"Feature and label data loading took = {time.time()-feat_load_st}",
+            flush=True,
+        )
+
+        r_time_ls = e2e_benchmark(sampled_dir, feat_data, label_data, sampled_meta_d)
+        for x in r_time_ls:
+            x.update({"replication_factor": replication_factor})
+            x.update({"num_edges": sampled_meta_d["total_num_edges"]})
+        time_ls.extend(r_time_ls)
+
+        print(
+            f"Benchmark completed for replication factor = {replication_factor}\n{'=' * 30}",
+            flush=True,
+        )
+
+    df = pd.DataFrame(time_ls)
+    df.to_csv("cugraph_dgl_e2e_benchmark.csv", index=False)
+    print(f"Benchmark completed for all replication factors\n{'=' * 30}", flush=True)
+
+
+def e2e_benchmark(
+    sampled_dir: str, feat: torch.Tensor, y: torch.Tensor, sampled_meta_d: dict
+):
+    """
+    Run the e2e_benchmark
+    Args:
+        sampled_dir: directory containing the sampled graph
+        feat: node features
+        y: node labels
+        sampled_meta_d: dictionary containing the sampled graph metadata
+    """
+    time_ls = []
+
+    # TODO: Make this a parameter in bulk sampling script
+    sampled_meta_d["sparse_format"] = "csc"
+    sampled_dir = os.path.join(sampled_dir, "samples")
+    dataloader = create_dataloader(
+        sampled_dir,
+        sampled_meta_d["total_num_nodes"],
+        sampled_meta_d["sparse_format"],
+        return_type="cugraph_dgl.nn.SparseGraph",
+    )
+    time_d = run_1_epoch(
+        dataloader,
+        feat,
+        y,
+        fanout=sampled_meta_d["fanout"],
+        batch_size=sampled_meta_d["batch_size"],
+        model_backend="cugraph_dgl",
+    )
+    time_ls.append(time_d)
+ print("=" * 30) + return time_ls + + +def parse_arguments(): + parser = ArgumentParser() + parser.add_argument( + "--dataset_path", type=str, default="/raid/vjawa/ogbn_papers100M/" + ) + parser.add_argument( + "--sampling_path", + type=str, + default="/raid/vjawa/nov_1_bulksampling_benchmarks/", + ) + return parser.parse_args() + + +if __name__ == "__main__": + setup_common_pool() + arguments = parse_arguments() + main(arguments) diff --git a/benchmarks/cugraph-dgl/scale-benchmarks/model.py b/benchmarks/cugraph-dgl/scale-benchmarks/model.py index 08ae0e8b1ee..9a9dfe58f96 100644 --- a/benchmarks/cugraph-dgl/scale-benchmarks/model.py +++ b/benchmarks/cugraph-dgl/scale-benchmarks/model.py @@ -57,11 +57,11 @@ def create_model(feat_size, num_classes, num_layers, model_backend="dgl"): def train_model(model, dataloader, opt, feat, y): - times = {key: 0 for key in ["mfg_creation", "feature", "m_fwd", "m_bkwd"]} + times_d = {key: 0 for key in ["mfg_creation", "feature", "m_fwd", "m_bkwd"]} epoch_st = time.time() mfg_st = time.time() for input_nodes, output_nodes, blocks in dataloader: - times["mfg_creation"] += time.time() - mfg_st + times_d["mfg_creation"] += time.time() - mfg_st if feat is not None: fst = time.time() input_nodes = input_nodes.to("cpu") @@ -71,23 +71,24 @@ def train_model(model, dataloader, opt, feat, y): output_nodes = output_nodes["paper"] output_nodes = output_nodes.to(y.device) y_batch = y[output_nodes].to("cuda") - times["feature"] += time.time() - fst + times_d["feature"] += time.time() - fst m_fwd_st = time.time() y_hat = model(blocks, input_feat) - times["m_fwd"] += time.time() - m_fwd_st + times_d["m_fwd"] += time.time() - m_fwd_st m_bkwd_st = time.time() loss = F.cross_entropy(y_hat, y_batch) opt.zero_grad() loss.backward() opt.step() - times["m_bkwd"] += time.time() - m_bkwd_st + times_d["m_bkwd"] += time.time() - m_bkwd_st mfg_st = time.time() print(f"Epoch time = {time.time() - epoch_st:.2f} seconds") + print(f"Time to create MFG = {times_d['mfg_creation']:.2f} seconds") - return times + return times_d def analyze_time(dataloader, times, epoch_time, fanout, batch_size): @@ -119,6 +120,10 @@ def run_1_epoch(dataloader, feat, y, fanout, batch_size, model_backend): else: model = None opt = None + + # Warmup RUN + times = train_model(model, dataloader, opt, feat, y) + epoch_st = time.time() times = train_model(model, dataloader, opt, feat, y) epoch_time = time.time() - epoch_st diff --git a/benchmarks/cugraph/standalone/bulk_sampling/cugraph_bulk_sampling.py b/benchmarks/cugraph/standalone/bulk_sampling/cugraph_bulk_sampling.py index a8c0658767d..1ca5d6db637 100644 --- a/benchmarks/cugraph/standalone/bulk_sampling/cugraph_bulk_sampling.py +++ b/benchmarks/cugraph/standalone/bulk_sampling/cugraph_bulk_sampling.py @@ -22,7 +22,6 @@ get_allocation_counts_dask_lazy, sizeof_fmt, get_peak_output_ratio_across_workers, - restart_client, start_dask_client, stop_dask_client, enable_spilling, @@ -187,10 +186,10 @@ def sample_graph( output_path, seed=42, batch_size=500, - seeds_per_call=200000, + seeds_per_call=400000, batches_per_partition=100, fanout=[5, 5, 5], - persist=False, + sampling_kwargs={}, ): cupy.random.seed(seed) @@ -204,6 +203,7 @@ def sample_graph( seeds_per_call=seeds_per_call, batches_per_partition=batches_per_partition, log_level=logging.INFO, + **sampling_kwargs, ) n_workers = len(default_client().scheduler_info()["workers"]) @@ -469,6 +469,7 @@ def benchmark_cugraph_bulk_sampling( batch_size, seeds_per_call, fanout, + sampling_target_framework, 
    reverse_edges=True,
    dataset_dir=".",
    replication_factor=1,
@@ -564,17 +565,39 @@ def benchmark_cugraph_bulk_sampling(
     output_sample_path = os.path.join(output_subdir, "samples")
     os.makedirs(output_sample_path)
 
-    batches_per_partition = 200_000 // batch_size
+    if sampling_target_framework == "cugraph_dgl_csr":
+        sampling_kwargs = {
+            "deduplicate_sources": True,
+            "prior_sources_behavior": "carryover",
+            "renumber": True,
+            "compression": "CSR",
+            "compress_per_hop": True,
+            "use_legacy_names": False,
+            "include_hop_column": False,
+        }
+    else:
+        # FIXME: Update these arguments when CSC mode is fixed in cuGraph-PyG (release 24.02)
+        sampling_kwargs = {
+            "deduplicate_sources": True,
+            "prior_sources_behavior": "exclude",
+            "renumber": True,
+            "compression": "COO",
+            "compress_per_hop": False,
+            "use_legacy_names": False,
+            "include_hop_column": True,
+        }
+
+    batches_per_partition = 400_000 // batch_size
     execution_time, allocation_counts = sample_graph(
-        G,
-        dask_label_df,
-        output_sample_path,
+        G=G,
+        label_df=dask_label_df,
+        output_path=output_sample_path,
         seed=seed,
         batch_size=batch_size,
         seeds_per_call=seeds_per_call,
         batches_per_partition=batches_per_partition,
         fanout=fanout,
-        persist=persist,
+        sampling_kwargs=sampling_kwargs,
     )
 
     output_meta = {
@@ -701,7 +724,13 @@ def get_args():
         required=False,
         default=False,
     )
-
+    parser.add_argument(
+        "--sampling_target_framework",
+        type=str,
+        help="The target framework for sampling (e.g. cugraph_dgl_csr, cugraph_pyg_csc, ...)",
+        required=False,
+        default=None,
+    )
     parser.add_argument(
         "--dask_worker_devices",
         type=str,
@@ -738,6 +767,12 @@
     logging.basicConfig()
 
     args = get_args()
+    if args.sampling_target_framework not in ["cugraph_dgl_csr", None]:
+        raise ValueError(
+            "sampling_target_framework must be one of cugraph_dgl_csr or None. "
+            "Other frameworks are not supported at this time."
+        )
+
     fanouts = [
         [int(f) for f in fanout.split("_")] for fanout in args.fanouts.split(",")
     ]
@@ -785,6 +820,7 @@
                         batch_size=batch_size,
                         seeds_per_call=seeds_per_call,
                         fanout=fanout,
+                        sampling_target_framework=args.sampling_target_framework,
                         dataset_dir=args.dataset_root,
                         reverse_edges=args.reverse_edges,
                         replication_factor=replication_factor,
@@ -809,7 +845,6 @@
                     warnings.warn("An Exception Occurred!")
                     print(e)
                     traceback.print_exc()
-                    restart_client(client)
                     sleep(10)
 
     stats_df = pd.DataFrame(
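Note: the two `sampling_kwargs` presets above differ only in compression format and hop bookkeeping; everything else is shared. A minimal sketch of the selection logic, assuming (as the call above suggests) the kwargs are forwarded unchanged to the bulk sampler via `sample_graph(..., sampling_kwargs=...)`; the helper name `select_sampling_kwargs` is hypothetical:

```python
# Hypothetical helper mirroring the preset selection in
# benchmark_cugraph_bulk_sampling(); the returned dict is passed through
# to cugraph's bulk sampler.
def select_sampling_kwargs(sampling_target_framework=None) -> dict:
    if sampling_target_framework == "cugraph_dgl_csr":
        # CSR output, compressed per hop, feeds the cugraph-dgl SparseGraph path.
        return {
            "deduplicate_sources": True,
            "prior_sources_behavior": "carryover",
            "renumber": True,
            "compression": "CSR",
            "compress_per_hop": True,
            "use_legacy_names": False,
            "include_hop_column": False,
        }
    # COO output with an explicit hop column is what cuGraph-PyG reads today.
    return {
        "deduplicate_sources": True,
        "prior_sources_behavior": "exclude",
        "renumber": True,
        "compression": "COO",
        "compress_per_hop": False,
        "use_legacy_names": False,
        "include_hop_column": True,
    }
```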
diff --git a/build.sh b/build.sh
index 99082fa96fb..1723e750978 100755
--- a/build.sh
+++ b/build.sh
@@ -31,6 +31,7 @@ VALIDARGS="
    cugraph-dgl
    nx-cugraph
    cpp-mgtests
+   cpp-mtmgtests
    docs
    all
    -v
@@ -59,6 +60,7 @@ HELP="$0 [<target> ...] [<flag> ...]
    cugraph-dgl          - build the cugraph-dgl extensions for DGL
    nx-cugraph           - build the nx-cugraph Python package
    cpp-mgtests          - build libcugraph and libcugraph_etl MG tests. Builds MPI communicator, adding MPI as a dependency.
+   cpp-mtmgtests        - build libcugraph MTMG tests. Adds UCX as a dependency (temporary).
    docs                 - build the docs
    all                  - build everything
 and <flag> is:
@@ -105,6 +107,7 @@ BUILD_TYPE=Release
 INSTALL_TARGET="--target install"
 BUILD_CPP_TESTS=ON
 BUILD_CPP_MG_TESTS=OFF
+BUILD_CPP_MTMG_TESTS=OFF
 BUILD_ALL_GPU_ARCH=0
 BUILD_WITH_CUGRAPHOPS=ON
 CMAKE_GENERATOR_OPTION="-G Ninja"
@@ -172,6 +175,9 @@ fi
 if hasArg --without_cugraphops; then
     BUILD_WITH_CUGRAPHOPS=OFF
 fi
+if hasArg cpp-mtmgtests; then
+    BUILD_CPP_MTMG_TESTS=ON
+fi
 if hasArg cpp-mgtests || hasArg all; then
     BUILD_CPP_MG_TESTS=ON
 fi
@@ -264,6 +270,7 @@ if buildDefault || hasArg libcugraph || hasArg all; then
           -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \
           -DBUILD_TESTS=${BUILD_CPP_TESTS} \
           -DBUILD_CUGRAPH_MG_TESTS=${BUILD_CPP_MG_TESTS} \
+          -DBUILD_CUGRAPH_MTMG_TESTS=${BUILD_CPP_MTMG_TESTS} \
          -DUSE_CUGRAPH_OPS=${BUILD_WITH_CUGRAPHOPS} \
          ${CMAKE_GENERATOR_OPTION} \
          ${CMAKE_VERBOSE_OPTION}
@@ -294,6 +301,7 @@ if buildDefault || hasArg libcugraph_etl || hasArg all; then
           -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \
           -DBUILD_TESTS=${BUILD_CPP_TESTS} \
           -DBUILD_CUGRAPH_MG_TESTS=${BUILD_CPP_MG_TESTS} \
+          -DBUILD_CUGRAPH_MTMG_TESTS=${BUILD_CPP_MTMG_TESTS} \
          -DCMAKE_PREFIX_PATH=${LIBCUGRAPH_BUILD_DIR} \
          ${CMAKE_GENERATOR_OPTION} \
          ${CMAKE_VERBOSE_OPTION} \
diff --git a/ci/build_wheel.sh b/ci/build_wheel.sh
index c888c908056..163520ea1da 100755
--- a/ci/build_wheel.sh
+++ b/ci/build_wheel.sh
@@ -40,8 +40,11 @@ for dep in rmm cudf raft-dask pylibcugraph pylibraft ucx-py; do
     sed -r -i "s/${dep}==(.*)\"/${dep}${PACKAGE_CUDA_SUFFIX}==\1${alpha_spec}\"/g" ${pyproject_file}
 done
 
-# dask-cuda doesn't get a suffix, but it does get an alpha spec.
-sed -r -i "s/dask-cuda==(.*)\"/dask-cuda==\1${alpha_spec}\"/g" ${pyproject_file}
+# dask-cuda & rapids-dask-dependency don't get a suffix, but they do get an alpha spec.
+for dep in dask-cuda rapids-dask-dependency; do
+    sed -r -i "s/${dep}==(.*)\"/${dep}==\1${alpha_spec}\"/g" ${pyproject_file}
+done
+
 
 if [[ $PACKAGE_CUDA_SUFFIX == "-cu12" ]]; then
     sed -i "s/cupy-cuda11x/cupy-cuda12x/g" ${pyproject_file}
diff --git a/ci/release/update-version.sh b/ci/release/update-version.sh
index d3dbed6ae46..69eb085e7ed 100755
--- a/ci/release/update-version.sh
+++ b/ci/release/update-version.sh
@@ -88,10 +88,12 @@ DEPENDENCIES=(
   raft-dask
   rmm
   ucx-py
+  rapids-dask-dependency
 )
 for DEP in "${DEPENDENCIES[@]}"; do
-  for FILE in dependencies.yaml conda/environments/*.yaml; do
+  for FILE in dependencies.yaml conda/environments/*.yaml python/cugraph-{pyg,dgl}/conda/*.yaml; do
     sed_runner "/-.* ${DEP}==/ s/==.*/==${NEXT_SHORT_TAG_PEP440}.*/g" ${FILE}
+    sed_runner "/-.* ${DEP}-cu[0-9][0-9]==/ s/==.*/==${NEXT_SHORT_TAG_PEP440}.*/g" ${FILE}
     sed_runner "/-.* ucx-py==/ s/==.*/==${NEXT_UCX_PY_VERSION}.*/g" ${FILE}
   done
   for FILE in python/**/pyproject.toml python/**/**/pyproject.toml; do
@@ -108,6 +110,11 @@
 sed_runner "/^ucx_py_version:$/ {n;s/.*/  - \"${NEXT_UCX_PY_VERSION}.*\"/}" conda/recipes/cugraph-service/conda_build_config.yaml
 sed_runner "/^ucx_py_version:$/ {n;s/.*/  - \"${NEXT_UCX_PY_VERSION}.*\"/}" conda/recipes/pylibcugraph/conda_build_config.yaml
 
+# nx-cugraph NetworkX entry-point meta-data
+sed_runner "s@branch-[0-9][0-9].[0-9][0-9]@branch-${NEXT_SHORT_TAG}@g" python/nx-cugraph/_nx_cugraph/__init__.py
+# FIXME: can this use the standard VERSION file and update mechanism?
+sed_runner "s/__version__ = .*/__version__ = \"${NEXT_FULL_TAG}\"/g" python/nx-cugraph/_nx_cugraph/__init__.py + # CI files for FILE in .github/workflows/*.yaml; do sed_runner "/shared-workflows/ s/@.*/@branch-${NEXT_SHORT_TAG}/g" "${FILE}" diff --git a/ci/test_python.sh b/ci/test_python.sh index 1690ce2f15b..273d3c93482 100755 --- a/ci/test_python.sh +++ b/ci/test_python.sh @@ -197,27 +197,26 @@ if [[ "${RAPIDS_CUDA_VERSION}" == "11.8.0" ]]; then conda activate test_cugraph_pyg set -u - # Install pytorch + # Will automatically install built dependencies of cuGraph-PyG rapids-mamba-retry install \ - --force-reinstall \ - --channel pyg \ + --channel "${CPP_CHANNEL}" \ + --channel "${PYTHON_CHANNEL}" \ --channel pytorch \ --channel nvidia \ - 'pyg=2.3' \ - 'pytorch=2.0.0' \ - 'pytorch-cuda=11.8' + --channel pyg \ + --channel rapidsai-nightly \ + "cugraph-pyg" \ + "pytorch>=2.0,<2.1" \ + "pytorch-cuda=11.8" # Install pyg dependencies (which requires pip) - pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.0.0+cu118.html - - rapids-mamba-retry install \ - --channel "${CPP_CHANNEL}" \ - --channel "${PYTHON_CHANNEL}" \ - libcugraph \ - pylibcugraph \ - pylibcugraphops \ - cugraph \ - cugraph-pyg + pip install \ + pyg_lib \ + torch_scatter \ + torch_sparse \ + torch_cluster \ + torch_spline_conv \ + -f https://data.pyg.org/whl/torch-2.0.0+cu118.html rapids-print-env diff --git a/ci/test_wheel_cugraph.sh b/ci/test_wheel_cugraph.sh index f9e2aa6d8da..d351ea21624 100755 --- a/ci/test_wheel_cugraph.sh +++ b/ci/test_wheel_cugraph.sh @@ -8,7 +8,4 @@ RAPIDS_PY_CUDA_SUFFIX="$(rapids-wheel-ctk-name-gen ${RAPIDS_CUDA_VERSION})" RAPIDS_PY_WHEEL_NAME="pylibcugraph_${RAPIDS_PY_CUDA_SUFFIX}" rapids-download-wheels-from-s3 ./local-pylibcugraph-dep python -m pip install --no-deps ./local-pylibcugraph-dep/pylibcugraph*.whl -# Always install latest dask for testing -python -m pip install git+https://github.com/dask/dask.git@main git+https://github.com/dask/distributed.git@main - ./ci/test_wheel.sh cugraph python/cugraph diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml index 2f3a9c988cf..aa38defcd7c 100644 --- a/conda/environments/all_cuda-118_arch-x86_64.yaml +++ b/conda/environments/all_cuda-118_arch-x86_64.yaml @@ -20,11 +20,8 @@ dependencies: - cupy>=12.0.0 - cxx-compiler - cython>=3.0.0 -- dask-core>=2023.9.2 - dask-cuda==23.12.* - dask-cudf==23.12.* -- dask>=2023.7.1 -- distributed>=2023.7.1 - doxygen - fsspec>=0.6.0 - gcc_linux-64=11.* @@ -62,6 +59,7 @@ dependencies: - pytest-xdist - python-louvain - raft-dask==23.12.* +- rapids-dask-dependency==23.12.* - recommonmark - requests - rmm==23.12.* diff --git a/conda/environments/all_cuda-120_arch-x86_64.yaml b/conda/environments/all_cuda-120_arch-x86_64.yaml index 31ff503e682..a9f793b15f5 100644 --- a/conda/environments/all_cuda-120_arch-x86_64.yaml +++ b/conda/environments/all_cuda-120_arch-x86_64.yaml @@ -20,11 +20,8 @@ dependencies: - cupy>=12.0.0 - cxx-compiler - cython>=3.0.0 -- dask-core>=2023.9.2 - dask-cuda==23.12.* - dask-cudf==23.12.* -- dask>=2023.7.1 -- distributed>=2023.7.1 - doxygen - fsspec>=0.6.0 - gcc_linux-64=11.* @@ -61,6 +58,7 @@ dependencies: - pytest-xdist - python-louvain - raft-dask==23.12.* +- rapids-dask-dependency==23.12.* - recommonmark - requests - rmm==23.12.* diff --git a/conda/recipes/cugraph-dgl/meta.yaml b/conda/recipes/cugraph-dgl/meta.yaml index bb85734098a..aaa1cd8a936 100644 --- 
a/conda/recipes/cugraph-dgl/meta.yaml +++ b/conda/recipes/cugraph-dgl/meta.yaml @@ -26,7 +26,7 @@ requirements: - dgl >=1.1.0.cu* - numba >=0.57 - numpy >=1.21 - - pylibcugraphops ={{ version }} + - pylibcugraphops ={{ minor_version }} - python - pytorch diff --git a/conda/recipes/cugraph-pyg/meta.yaml b/conda/recipes/cugraph-pyg/meta.yaml index 2714dcfa55a..a2a02a1d9f6 100644 --- a/conda/recipes/cugraph-pyg/meta.yaml +++ b/conda/recipes/cugraph-pyg/meta.yaml @@ -26,15 +26,15 @@ requirements: - python - scikit-build >=0.13.1 run: - - distributed >=2023.9.2 + - rapids-dask-dependency ={{ minor_version }} - numba >=0.57 - numpy >=1.21 - python - pytorch >=2.0 - cupy >=12.0.0 - cugraph ={{ version }} - - pylibcugraphops ={{ version }} - - pyg >=2.3,<2.4 + - pylibcugraphops ={{ minor_version }} + - pyg >=2.3,<2.5 tests: imports: diff --git a/conda/recipes/cugraph-service/meta.yaml b/conda/recipes/cugraph-service/meta.yaml index ae8074ba7d3..d52a004db05 100644 --- a/conda/recipes/cugraph-service/meta.yaml +++ b/conda/recipes/cugraph-service/meta.yaml @@ -59,10 +59,10 @@ outputs: - cupy >=12.0.0 - dask-cuda ={{ minor_version }} - dask-cudf ={{ minor_version }} - - distributed >=2023.9.2 - numba >=0.57 - numpy >=1.21 - python + - rapids-dask-dependency ={{ minor_version }} - thriftpy2 >=0.4.15 - ucx-py {{ ucx_py_version }} diff --git a/conda/recipes/cugraph/meta.yaml b/conda/recipes/cugraph/meta.yaml index 65403bc8d73..58b9ea220d4 100644 --- a/conda/recipes/cugraph/meta.yaml +++ b/conda/recipes/cugraph/meta.yaml @@ -76,15 +76,13 @@ requirements: - cupy >=12.0.0 - dask-cuda ={{ minor_version }} - dask-cudf ={{ minor_version }} - - dask >=2023.9.2 - - dask-core >=2023.9.2 - - distributed >=2023.9.2 - fsspec>=0.6.0 - libcugraph ={{ version }} - pylibcugraph ={{ version }} - pylibraft ={{ minor_version }} - python - raft-dask ={{ minor_version }} + - rapids-dask-dependency ={{ minor_version }} - requests - ucx-proc=*=gpu - ucx-py {{ ucx_py_version }} diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 41870cbc92b..3e867643041 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -153,6 +153,11 @@ rapids_cpm_init() # lags behind. ### +# Need to make sure rmm is found before cuco so that rmm patches the libcudacxx +# directory to be found by cuco. +include(${rapids-cmake-dir}/cpm/rmm.cmake) +rapids_cpm_rmm(BUILD_EXPORT_SET cugraph-exports + INSTALL_EXPORT_SET cugraph-exports) # Putting this before raft to override RAFT from pulling them in. 
include(cmake/thirdparty/get_libcudacxx.cmake) include(${rapids-cmake-dir}/cpm/cuco.cmake) @@ -166,7 +171,10 @@ endif() include(cmake/thirdparty/get_nccl.cmake) include(cmake/thirdparty/get_cuhornet.cmake) -include(cmake/thirdparty/get_ucp.cmake) + +if (BUILD_CUGRAPH_MTMG_TESTS) + include(cmake/thirdparty/get_ucp.cmake) +endif() if(BUILD_TESTS) include(cmake/thirdparty/get_gtest.cmake) diff --git a/cpp/src/centrality/eigenvector_centrality_impl.cuh b/cpp/src/centrality/eigenvector_centrality_impl.cuh index 291abf18455..8d1bea4004d 100644 --- a/cpp/src/centrality/eigenvector_centrality_impl.cuh +++ b/cpp/src/centrality/eigenvector_centrality_impl.cuh @@ -96,7 +96,8 @@ rmm::device_uvector eigenvector_centrality( centralities.end(), old_centralities.data()); - update_edge_src_property(handle, pull_graph_view, centralities.begin(), edge_src_centralities); + update_edge_src_property( + handle, pull_graph_view, old_centralities.begin(), edge_src_centralities); if (edge_weight_view) { per_v_transform_reduce_incoming_e( @@ -122,6 +123,13 @@ rmm::device_uvector eigenvector_centrality( centralities.begin()); } + thrust::transform(handle.get_thrust_policy(), + centralities.begin(), + centralities.end(), + old_centralities.begin(), + centralities.begin(), + thrust::plus()); + // Normalize the centralities auto hypotenuse = sqrt(transform_reduce_v( handle, diff --git a/cpp/src/link_analysis/hits_impl.cuh b/cpp/src/link_analysis/hits_impl.cuh index 9badb041218..674046745b1 100644 --- a/cpp/src/link_analysis/hits_impl.cuh +++ b/cpp/src/link_analysis/hits_impl.cuh @@ -112,7 +112,8 @@ std::tuple hits(raft::handle_t const& handle, prev_hubs + graph_view.local_vertex_partition_range_size(), result_t{1.0} / num_vertices); } - for (size_t iter = 0; iter < max_iterations; ++iter) { + size_t iter{0}; + while (true) { // Update current destination authorities property per_v_transform_reduce_incoming_e( handle, @@ -162,17 +163,19 @@ std::tuple hits(raft::handle_t const& handle, thrust::make_zip_iterator(thrust::make_tuple(curr_hubs, prev_hubs)), [] __device__(auto, auto val) { return std::abs(thrust::get<0>(val) - thrust::get<1>(val)); }, result_t{0}); - if (diff_sum < epsilon) { - final_iteration_count = iter; - std::swap(prev_hubs, curr_hubs); - break; - } update_edge_src_property(handle, graph_view, curr_hubs, prev_src_hubs); // Swap pointers for the next iteration // After this swap call, prev_hubs has the latest value of hubs std::swap(prev_hubs, curr_hubs); + iter++; + + if (diff_sum < epsilon) { + break; + } else if (iter >= max_iterations) { + CUGRAPH_FAIL("HITS failed to converge."); + } } if (normalize) { @@ -188,7 +191,7 @@ std::tuple hits(raft::handle_t const& handle, hubs); } - return std::make_tuple(diff_sum, final_iteration_count); + return std::make_tuple(diff_sum, iter); } } // namespace detail diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 2f69cf9cb0d..6530a25d178 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -415,13 +415,6 @@ ConfigureTest(K_HOP_NBRS_TEST traversal/k_hop_nbrs_test.cpp) # - install tests --------------------------------------------------------------------------------- rapids_test_install_relocatable(INSTALL_COMPONENT_SET testing DESTINATION bin/gtests/libcugraph) -################################################################################################### -# - MTMG tests ------------------------------------------------------------------------- -ConfigureTest(MTMG_TEST mtmg/threaded_test.cu) -target_link_libraries(MTMG_TEST 
- PRIVATE - UCP::UCP - ) ################################################################################################### # - MG tests -------------------------------------------------------------------------------------- @@ -681,15 +674,6 @@ if(BUILD_CUGRAPH_MG_TESTS) rapids_test_install_relocatable(INSTALL_COMPONENT_SET testing_mg DESTINATION bin/gtests/libcugraph_mg) - ############################################################################################### - # - Multi-node MTMG tests --------------------------------------------------------------------- - ConfigureTest(MTMG_MULTINODE_TEST mtmg/multi_node_threaded_test.cu utilities/mg_utilities.cpp) - target_link_libraries(MTMG_MULTINODE_TEST - PRIVATE - cugraphmgtestutil - UCP::UCP - ) - endif() ################################################################################################### @@ -749,4 +733,25 @@ ConfigureCTest(CAPI_EGONET_TEST c_api/egonet_test.c) ConfigureCTest(CAPI_TWO_HOP_NEIGHBORS_TEST c_api/two_hop_neighbors_test.c) ConfigureCTest(CAPI_LEGACY_K_TRUSS_TEST c_api/legacy_k_truss_test.c) +if (BUILD_CUGRAPH_MTMG_TESTS) + ################################################################################################### + # - MTMG tests ------------------------------------------------------------------------- + ConfigureTest(MTMG_TEST mtmg/threaded_test.cu) + target_link_libraries(MTMG_TEST + PRIVATE + UCP::UCP + ) + + if(BUILD_CUGRAPH_MG_TESTS) + ############################################################################################### + # - Multi-node MTMG tests --------------------------------------------------------------------- + ConfigureTest(MTMG_MULTINODE_TEST mtmg/multi_node_threaded_test.cu utilities/mg_utilities.cpp) + target_link_libraries(MTMG_MULTINODE_TEST + PRIVATE + cugraphmgtestutil + UCP::UCP + ) + endif(BUILD_CUGRAPH_MG_TESTS) +endif(BUILD_CUGRAPH_MTMG_TESTS) + rapids_test_install_relocatable(INSTALL_COMPONENT_SET testing_c DESTINATION bin/gtests/libcugraph_c) diff --git a/cpp/tests/c_api/eigenvector_centrality_test.c b/cpp/tests/c_api/eigenvector_centrality_test.c index 9fd2d2bee6f..8bc5971a70c 100644 --- a/cpp/tests/c_api/eigenvector_centrality_test.c +++ b/cpp/tests/c_api/eigenvector_centrality_test.c @@ -109,11 +109,30 @@ int test_eigenvector_centrality() h_src, h_dst, h_wgt, h_result, num_vertices, num_edges, TRUE, epsilon, max_iterations); } +int test_eigenvector_centrality_3971() +{ + size_t num_edges = 4; + size_t num_vertices = 3; + + vertex_t h_src[] = {0, 1, 1, 2}; + vertex_t h_dst[] = {1, 0, 2, 1}; + weight_t h_wgt[] = {1.0f, 1.0f, 1.0f, 1.0f}; + weight_t h_result[] = {0.5, 0.707107, 0.5}; + + double epsilon = 1e-6; + size_t max_iterations = 1000; + + // Eigenvector centrality wants store_transposed = TRUE + return generic_eigenvector_centrality_test( + h_src, h_dst, h_wgt, h_result, num_vertices, num_edges, TRUE, epsilon, max_iterations); +} + /******************************************************************************/ int main(int argc, char** argv) { int result = 0; result |= RUN_TEST(test_eigenvector_centrality); + result |= RUN_TEST(test_eigenvector_centrality_3971); return result; } diff --git a/cpp/tests/centrality/eigenvector_centrality_test.cpp b/cpp/tests/centrality/eigenvector_centrality_test.cpp index 7cafcfbde85..6c3bd510abd 100644 --- a/cpp/tests/centrality/eigenvector_centrality_test.cpp +++ b/cpp/tests/centrality/eigenvector_centrality_test.cpp @@ -60,7 +60,6 @@ void eigenvector_centrality_reference(vertex_t const* src, size_t iter{0}; while (true) { 
std::copy(tmp_centralities.begin(), tmp_centralities.end(), old_centralities.begin()); - std::fill(tmp_centralities.begin(), tmp_centralities.end(), double{0}); for (size_t e = 0; e < num_edges; ++e) { auto w = weights ? (*weights)[e] : weight_t{1.0}; diff --git a/cpp/tests/link_analysis/hits_test.cpp b/cpp/tests/link_analysis/hits_test.cpp index 44fa619b503..d0e77769034 100644 --- a/cpp/tests/link_analysis/hits_test.cpp +++ b/cpp/tests/link_analysis/hits_test.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -176,7 +176,7 @@ class Tests_Hits : public ::testing::TestWithParam d_hubs(graph_view.local_vertex_partition_range_size(), handle.get_stream()); diff --git a/datasets/README.md b/datasets/README.md index e42413fc996..a23dc644081 100644 --- a/datasets/README.md +++ b/datasets/README.md @@ -120,9 +120,13 @@ The benchmark datasets are described below: | soc-twitter-2010 | 21,297,772 | 265,025,809 | No | No | **cit-Patents** : A citation graph that includes all citations made by patents granted between 1975 and 1999, totaling 16,522,438 citations. + **soc-LiveJournal** : A graph of the LiveJournal social network. + **europe_osm** : A graph of OpenStreetMap data for Europe. + **hollywood** : A graph of movie actors where vertices are actors, and two actors are joined by an edge whenever they appeared in a movie together. + **soc-twitter-2010** : A network of follower relationships from a snapshot of Twitter in 2010, where an edge from i to j indicates that j is a follower of i. _NOTE: the benchmark datasets were converted to a CSV format from their original format described in the reference URL below, and in doing so had edge weights and isolated vertices discarded._ diff --git a/dependencies.yaml b/dependencies.yaml index b127d9bd29e..a89acd9288b 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -377,15 +377,13 @@ dependencies: common: - output_types: [conda, pyproject] packages: - - &dask dask>=2023.7.1 - - &distributed distributed>=2023.7.1 + - &dask rapids-dask-dependency==23.12.* - &dask_cuda dask-cuda==23.12.* - &numba numba>=0.57 - &ucx_py ucx-py==0.35.* - output_types: conda packages: - aiohttp - - &dask-core_conda dask-core>=2023.9.2 - fsspec>=0.6.0 - libcudf==23.12.* - requests @@ -431,14 +429,10 @@ dependencies: packages: - *dask - *dask_cuda - - *distributed - *numba - *numpy - *thrift - *ucx_py - - output_types: conda - packages: - - *dask-core_conda - output_types: pyproject packages: - *cugraph @@ -503,9 +497,9 @@ dependencies: - output_types: [conda] packages: - cugraph==23.12.* - - pytorch==2.0 + - pytorch>=2.0 - pytorch-cuda==11.8 - - pyg=2.3.1=*torch_2.0.0*cu118* + - pyg>=2.4.0 depends_on_rmm: common: diff --git a/img/Stack2.png b/img/Stack2.png index 132e85c9d15..97f1979c2d8 100644 Binary files a/img/Stack2.png and b/img/Stack2.png differ diff --git a/python/cugraph-dgl/cugraph_dgl/dataloading/utils/sampling_helpers.py b/python/cugraph-dgl/cugraph_dgl/dataloading/utils/sampling_helpers.py index a4f64668348..f674bece8be 100644 --- a/python/cugraph-dgl/cugraph_dgl/dataloading/utils/sampling_helpers.py +++ b/python/cugraph-dgl/cugraph_dgl/dataloading/utils/sampling_helpers.py @@ -14,7 +14,6 @@ from typing import List, Tuple, Dict, Optional from collections import defaultdict import cudf -import cupy from cugraph.utilities.utils import import_optional from 
cugraph_dgl.nn import SparseGraph
@@ -444,53 +443,58 @@ def _process_sampled_df_csc(
     destinations, respectively.
     """
     # dropna
-    major_offsets = df.major_offsets.dropna().values
-    label_hop_offsets = df.label_hop_offsets.dropna().values
-    renumber_map_offsets = df.renumber_map_offsets.dropna().values
-    renumber_map = df.map.dropna().values
-    minors = df.minors.dropna().values
+    major_offsets = cast_to_tensor(df.major_offsets.dropna())
+    label_hop_offsets = cast_to_tensor(df.label_hop_offsets.dropna())
+    renumber_map_offsets = cast_to_tensor(df.renumber_map_offsets.dropna())
+    renumber_map = cast_to_tensor(df.map.dropna())
+    minors = cast_to_tensor(df.minors.dropna())
 
-    n_batches = renumber_map_offsets.size - 1
-    n_hops = int((label_hop_offsets.size - 1) / n_batches)
+    n_batches = len(renumber_map_offsets) - 1
+    n_hops = int((len(label_hop_offsets) - 1) / n_batches)
 
     # make global offsets local
-    major_offsets -= major_offsets[0]
-    label_hop_offsets -= label_hop_offsets[0]
-    renumber_map_offsets -= renumber_map_offsets[0]
+    # Have to make a clone as pytorch does not allow
+    # in-place operations on tensors
+    major_offsets -= major_offsets[0].clone()
+    label_hop_offsets -= label_hop_offsets[0].clone()
+    renumber_map_offsets -= renumber_map_offsets[0].clone()
 
     # get the sizes of each adjacency matrix (for MFGs)
     mfg_sizes = (label_hop_offsets[1:] - label_hop_offsets[:-1]).reshape(
         (n_batches, n_hops)
     )
     n_nodes = renumber_map_offsets[1:] - renumber_map_offsets[:-1]
-    mfg_sizes = cupy.hstack((mfg_sizes, n_nodes.reshape(n_batches, -1)))
+    mfg_sizes = torch.hstack((mfg_sizes, n_nodes.reshape(n_batches, -1)))
     if reverse_hop_id:
-        mfg_sizes = mfg_sizes[:, ::-1]
+        mfg_sizes = mfg_sizes.flip(1)
 
     tensors_dict = {}
     renumber_map_list = []
+    # Note: minors and major_offsets from BulkSampler are of type int32
+    # and int64 respectively. Since pylibcugraphops binding code doesn't
+    # support distinct node and edge index types, we simply cast both
+    # to int32 for now.
+    minors = minors.int()
+    major_offsets = major_offsets.int()
+    # Note: We transfer tensors to CPU here to avoid the overhead of
+    # transferring them in each iteration of the for loop below.
+    major_offsets_cpu = major_offsets.to("cpu").numpy()
+    label_hop_offsets_cpu = label_hop_offsets.to("cpu").numpy()
+
     for batch_id in range(n_batches):
         batch_dict = {}
-
         for hop_id in range(n_hops):
             hop_dict = {}
             idx = batch_id * n_hops + hop_id  # idx in label_hop_offsets
-            major_offsets_start = label_hop_offsets[idx].item()
-            major_offsets_end = label_hop_offsets[idx + 1].item()
-            minors_start = major_offsets[major_offsets_start].item()
-            minors_end = major_offsets[major_offsets_end].item()
-            # Note: minors and major_offsets from BulkSampler are of type int32
-            # and int64 respectively. Since pylibcugraphops binding code doesn't
-            # support distinct node and edge index type, we simply casting both
-            # to int32 for now.
-            hop_dict["minors"] = torch.as_tensor(
-                minors[minors_start:minors_end], device="cuda"
-            ).int()
-            hop_dict["major_offsets"] = torch.as_tensor(
+            major_offsets_start = label_hop_offsets_cpu[idx]
+            major_offsets_end = label_hop_offsets_cpu[idx + 1]
+            minors_start = major_offsets_cpu[major_offsets_start]
+            minors_end = major_offsets_cpu[major_offsets_end]
+            hop_dict["minors"] = minors[minors_start:minors_end]
+            hop_dict["major_offsets"] = (
                 major_offsets[major_offsets_start : major_offsets_end + 1]
-                - major_offsets[major_offsets_start],
-                device="cuda",
-            ).int()
+                - major_offsets[major_offsets_start]
+            )
             if reverse_hop_id:
                 batch_dict[n_hops - 1 - hop_id] = hop_dict
             else:
@@ -499,12 +503,9 @@ def _process_sampled_df_csc(
         tensors_dict[batch_id] = batch_dict
 
         renumber_map_list.append(
-            torch.as_tensor(
-                renumber_map[
-                    renumber_map_offsets[batch_id] : renumber_map_offsets[batch_id + 1]
-                ],
-                device="cuda",
-            )
+            renumber_map[
+                renumber_map_offsets[batch_id] : renumber_map_offsets[batch_id + 1]
+            ],
         )
 
     return tensors_dict, renumber_map_list, mfg_sizes.tolist()
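Note: the rewrite above also batches the device-to-host traffic. The old loop called `.item()` on GPU tensors at every iteration, and each call forces a separate device-to-host synchronization; the new code copies the offset arrays to the CPU once before the loop. A minimal self-contained sketch of the pattern, on toy data and independent of cuGraph:

```python
import torch

offsets = torch.arange(0, 1010, 10, device="cuda")

# Before: every .item() is a separate device-to-host synchronization.
starts_slow = [offsets[i].item() for i in range(100)]

# After: one bulk transfer, then cheap host-side indexing inside the loop.
offsets_cpu = offsets.to("cpu").numpy()
starts_fast = [int(offsets_cpu[i]) for i in range(100)]

assert starts_slow == starts_fast
```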
diff --git a/python/cugraph-pyg/conda/cugraph_pyg_dev_cuda-118.yaml b/python/cugraph-pyg/conda/cugraph_pyg_dev_cuda-118.yaml
index f98eab430ba..71d1c7e389c 100644
--- a/python/cugraph-pyg/conda/cugraph_pyg_dev_cuda-118.yaml
+++ b/python/cugraph-pyg/conda/cugraph_pyg_dev_cuda-118.yaml
@@ -13,13 +13,13 @@ dependencies:
 - cugraph==23.12.*
 - pandas
 - pre-commit
-- pyg=2.3.1=*torch_2.0.0*cu118*
+- pyg>=2.4.0
 - pylibcugraphops==23.12.*
 - pytest
 - pytest-benchmark
 - pytest-cov
 - pytest-xdist
 - pytorch-cuda==11.8
-- pytorch==2.0
+- pytorch>=2.0
 - scipy
 name: cugraph_pyg_dev_cuda-118
diff --git a/python/cugraph-pyg/cugraph_pyg/data/cugraph_store.py b/python/cugraph-pyg/cugraph_pyg/data/cugraph_store.py
index 6192cd621d5..edeeface4c4 100644
--- a/python/cugraph-pyg/cugraph_pyg/data/cugraph_store.py
+++ b/python/cugraph-pyg/cugraph_pyg/data/cugraph_store.py
@@ -210,11 +210,14 @@ class EXPERIMENTAL__CuGraphStore:
     def __init__(
         self,
         F: cugraph.gnn.FeatureStore,
-        G: Union[Dict[str, Tuple[TensorType]], Dict[str, int]],
+        G: Union[
+            Dict[Tuple[str, str, str], Tuple[TensorType]],
+            Dict[Tuple[str, str, str], int],
+        ],
         num_nodes_dict: Dict[str, int],
         *,
         multi_gpu: bool = False,
-        order: str = "CSC",
+        order: str = "CSR",
     ):
         """
         Constructs a new CuGraphStore from the provided
@@ -260,11 +263,11 @@ def __init__(
             Whether the store should be backed by a multi-GPU graph.
             Requires dask to have been set up.
 
-        order: str (Optional ["CSR", "CSC"], default = CSC)
-            The order to use for sampling. Should nearly always be CSC
-            unless there is a specific expectation of "reverse" sampling.
-            It is also not uncommon to use CSR order for correctness
-            testing, which some cuGraph-PyG tests do.
+        order: str (Optional ["CSR", "CSC"], default = CSR)
+            The order to use for sampling. CSR corresponds to the
+            standard OGB dataset order that is usually used in PyG.
+            CSC order constructs the same graph as CSR, but with
+            edges in the opposite direction.
         """
 
         if None in G:
@@ -744,7 +747,7 @@ def _subgraph(self, edge_types: List[tuple] = None) -> cugraph.MultiGraph:
 
     def _get_vertex_groups_from_sample(
         self, nodes_of_interest: TensorType, is_sorted: bool = False
-    ) -> dict:
+    ) -> Dict[str, torch.Tensor]:
         """
         Given a tensor of nodes of interest, this method returns a single
         dictionary, noi_index.
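Note on the `order` default flip above: callers that depended on the old CSC default now pass it explicitly, as the test updates later in this diff do. A hedged usage sketch with toy inputs, assuming the public `cugraph_pyg.data.CuGraphStore` wrapper around the class shown here:

```python
import numpy as np
from cugraph.gnn import FeatureStore
from cugraph_pyg.data import CuGraphStore  # assumed public import path

F = FeatureStore(backend="torch")
# Toy heterogeneous graph: keys are (src_type, relation, dst_type) triples.
G = {("paper", "cites", "paper"): [np.array([0, 1, 2]), np.array([1, 2, 0])]}
N = {"paper": 3}

# CSR (the new default) follows the usual PyG/OGB edge direction; pass
# order="CSC" explicitly to keep the previous behavior.
store = CuGraphStore(F, G, N, order="CSC")
```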
@@ -808,7 +811,10 @@ def _get_sample_from_vertex_groups(
 
     def _get_renumbered_edge_groups_from_sample(
         self, sampling_results: cudf.DataFrame, noi_index: dict
-    ) -> Tuple[dict, dict]:
+    ) -> Tuple[
+        Dict[Tuple[str, str, str], torch.Tensor],
+        Dict[Tuple[str, str, str], torch.Tensor],
+    ]:
         """
         Given a cudf (NOT dask_cudf) DataFrame of sampling results and a dictionary
         of non-renumbered vertex ids grouped by vertex type, this method
diff --git a/python/cugraph-pyg/cugraph_pyg/loader/cugraph_node_loader.py b/python/cugraph-pyg/cugraph_pyg/loader/cugraph_node_loader.py
index 8552e7412e0..ad8d22e255e 100644
--- a/python/cugraph-pyg/cugraph_pyg/loader/cugraph_node_loader.py
+++ b/python/cugraph-pyg/cugraph_pyg/loader/cugraph_node_loader.py
@@ -15,6 +15,7 @@
 
 import os
 import re
+import warnings
 
 import cupy
 import cudf
@@ -159,23 +160,34 @@ def __init__(
         if batch_size is None or batch_size < 1:
             raise ValueError("Batch size must be >= 1")
 
-        self.__directory = tempfile.TemporaryDirectory(dir=directory)
+        self.__directory = (
+            tempfile.TemporaryDirectory() if directory is None else directory
+        )
 
         if isinstance(num_neighbors, dict):
             raise ValueError("num_neighbors dict is currently unsupported!")
 
-        renumber = (
-            True
-            if (
-                (len(self.__graph_store.node_types) == 1)
-                and (len(self.__graph_store.edge_types) == 1)
+        if "renumber" in kwargs:
+            warnings.warn(
+                "Setting renumbering manually could result in invalid output;"
+                " please ensure you intended to do this."
+            )
+            renumber = kwargs.pop("renumber")
+        else:
+            renumber = (
+                True
+                if (
+                    (len(self.__graph_store.node_types) == 1)
+                    and (len(self.__graph_store.edge_types) == 1)
+                )
+                else False
             )
-            else False
-        )
 
         bulk_sampler = BulkSampler(
             batch_size,
-            self.__directory.name,
+            self.__directory
+            if isinstance(self.__directory, str)
+            else self.__directory.name,
             self.__graph_store._subgraph(edge_types),
             fanout_vals=num_neighbors,
             with_replacement=replace,
@@ -219,7 +231,13 @@ def __init__(
             )
 
             bulk_sampler.flush()
-            self.__input_files = iter(os.listdir(self.__directory.name))
+            self.__input_files = iter(
+                os.listdir(
+                    self.__directory
+                    if isinstance(self.__directory, str)
+                    else self.__directory.name
+                )
+            )
 
     def __next__(self):
         from time import perf_counter
@@ -423,9 +441,6 @@ def __next__(self):
                     sampler_output.edge,
                 )
             else:
-                if self.__graph_store.order == "CSR":
-                    raise ValueError("CSR format incompatible with CSC output")
-
                 out = filter_cugraph_store_csc(
                     self.__feature_store,
                     self.__graph_store,
@@ -437,11 +452,8 @@ def __next__(self):
 
         # Account for CSR format in cuGraph vs.
CSC format in PyG if self.__coo and self.__graph_store.order == "CSC": - for node_type in out.edge_index_dict: - out[node_type].edge_index[0], out[node_type].edge_index[1] = ( - out[node_type].edge_index[1], - out[node_type].edge_index[0], - ) + for edge_type in out.edge_index_dict: + out[edge_type].edge_index = out[edge_type].edge_index.flip(dims=[0]) out.set_value_dict("num_sampled_nodes", sampler_output.num_sampled_nodes) out.set_value_dict("num_sampled_edges", sampler_output.num_sampled_edges) diff --git a/python/cugraph-pyg/cugraph_pyg/tests/conftest.py b/python/cugraph-pyg/cugraph_pyg/tests/conftest.py index 083c4a2b37b..1512901822a 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/conftest.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/conftest.py @@ -24,7 +24,7 @@ import torch import numpy as np from cugraph.gnn import FeatureStore -from cugraph.experimental.datasets import karate +from cugraph.datasets import karate import tempfile diff --git a/python/cugraph-pyg/cugraph_pyg/tests/mg/test_mg_cugraph_store.py b/python/cugraph-pyg/cugraph_pyg/tests/mg/test_mg_cugraph_store.py index ed7f70034e2..13c9c90c7c2 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/mg/test_mg_cugraph_store.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/mg/test_mg_cugraph_store.py @@ -120,7 +120,7 @@ def test_get_edge_index(graph, edge_index_type, dask_client): G[et][0] = dask_cudf.from_cudf(cudf.Series(G[et][0]), npartitions=1) G[et][1] = dask_cudf.from_cudf(cudf.Series(G[et][1]), npartitions=1) - cugraph_store = CuGraphStore(F, G, N, multi_gpu=True) + cugraph_store = CuGraphStore(F, G, N, order="CSC", multi_gpu=True) for pyg_can_edge_type in G: src, dst = cugraph_store.get_edge_index( diff --git a/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_loader.py b/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_loader.py index 853836dc2a6..27b73bf7d35 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_loader.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_loader.py @@ -18,6 +18,7 @@ import cudf import cupy +import numpy as np from cugraph_pyg.loader import CuGraphNeighborLoader from cugraph_pyg.loader import BulkSampleLoader @@ -27,6 +28,8 @@ from cugraph.gnn import FeatureStore from cugraph.utilities.utils import import_optional, MissingModule +from typing import Dict, Tuple + torch = import_optional("torch") torch_geometric = import_optional("torch_geometric") trim_to_layer = import_optional("torch_geometric.utils.trim_to_layer") @@ -40,7 +43,11 @@ @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") -def test_cugraph_loader_basic(karate_gnn): +def test_cugraph_loader_basic( + karate_gnn: Tuple[ + FeatureStore, Dict[Tuple[str, str, str], np.ndarray], Dict[str, int] + ] +): F, G, N = karate_gnn cugraph_store = CuGraphStore(F, G, N, order="CSR") loader = CuGraphNeighborLoader( @@ -66,7 +73,11 @@ def test_cugraph_loader_basic(karate_gnn): @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") -def test_cugraph_loader_hetero(karate_gnn): +def test_cugraph_loader_hetero( + karate_gnn: Tuple[ + FeatureStore, Dict[Tuple[str, str, str], np.ndarray], Dict[str, int] + ] +): F, G, N = karate_gnn cugraph_store = CuGraphStore(F, G, N, order="CSR") loader = CuGraphNeighborLoader( @@ -342,7 +353,7 @@ def test_cugraph_loader_e2e_coo(): @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") @pytest.mark.skipif(not HAS_TORCH_SPARSE, reason="torch-sparse not available") @pytest.mark.parametrize("framework", ["pyg", 
"cugraph-ops"]) -def test_cugraph_loader_e2e_csc(framework): +def test_cugraph_loader_e2e_csc(framework: str): m = [2, 9, 99, 82, 9, 3, 18, 1, 12] x = torch.randint(3000, (256, 256)).to(torch.float32) F = FeatureStore() @@ -442,3 +453,40 @@ def test_cugraph_loader_e2e_csc(framework): x = x.narrow(dim=0, start=0, length=s - num_sampled_nodes[1]) assert list(x.shape) == [1, 1] + + +@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.parametrize("directory", ["local", "temp"]) +def test_load_directory( + karate_gnn: Tuple[ + FeatureStore, Dict[Tuple[str, str, str], np.ndarray], Dict[str, int] + ], + directory: str, +): + if directory == "local": + local_dir = tempfile.TemporaryDirectory(dir=".") + + cugraph_store = CuGraphStore(*karate_gnn) + cugraph_loader = CuGraphNeighborLoader( + (cugraph_store, cugraph_store), + torch.arange(8, dtype=torch.int64), + 2, + num_neighbors=[8, 4, 2], + random_state=62, + replace=False, + directory=None if directory == "temp" else local_dir.name, + batches_per_partition=1, + ) + + it = iter(cugraph_loader) + next_batch = next(it) + assert next_batch is not None + + if directory == "local": + assert len(os.listdir(local_dir.name)) == 4 + + count = 1 + while next(it, None) is not None: + count += 1 + + assert count == 4 diff --git a/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_store.py b/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_store.py index da3043760d4..b39ebad8254 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_store.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_store.py @@ -113,7 +113,7 @@ def test_get_edge_index(graph, edge_index_type): G[et][0] = cudf.Series(G[et][0]) G[et][1] = cudf.Series(G[et][1]) - cugraph_store = CuGraphStore(F, G, N) + cugraph_store = CuGraphStore(F, G, N, order="CSC") for pyg_can_edge_type in G: src, dst = cugraph_store.get_edge_index( diff --git a/python/cugraph-service/server/cugraph_service_server/testing/benchmark_server_extension.py b/python/cugraph-service/server/cugraph_service_server/testing/benchmark_server_extension.py index 5f9eac6b2a3..361226c8071 100644 --- a/python/cugraph-service/server/cugraph_service_server/testing/benchmark_server_extension.py +++ b/python/cugraph-service/server/cugraph_service_server/testing/benchmark_server_extension.py @@ -17,7 +17,7 @@ import cugraph from cugraph.experimental import PropertyGraph, MGPropertyGraph -from cugraph.experimental import datasets +from cugraph import datasets from cugraph.generators import rmat diff --git a/python/cugraph-service/server/pyproject.toml b/python/cugraph-service/server/pyproject.toml index f50b33b3f15..d68f8055ded 100644 --- a/python/cugraph-service/server/pyproject.toml +++ b/python/cugraph-service/server/pyproject.toml @@ -25,10 +25,9 @@ dependencies = [ "cupy-cuda11x>=12.0.0", "dask-cuda==23.12.*", "dask-cudf==23.12.*", - "dask>=2023.7.1", - "distributed>=2023.7.1", "numba>=0.57", "numpy>=1.21", + "rapids-dask-dependency==23.12.*", "rmm==23.12.*", "thriftpy2", "ucx-py==0.35.*", diff --git a/python/cugraph/cugraph/dask/community/leiden.py b/python/cugraph/cugraph/dask/community/leiden.py index 75582fa48f7..67bd0876ce6 100644 --- a/python/cugraph/cugraph/dask/community/leiden.py +++ b/python/cugraph/cugraph/dask/community/leiden.py @@ -125,7 +125,7 @@ def leiden( Examples -------- - >>> from cugraph.experimental.datasets import karate + >>> from cugraph.datasets import karate >>> G = karate.get_graph(fetch=True) >>> parts, modularity_score = cugraph.leiden(G) diff 
diff --git a/python/cugraph/cugraph/dask/community/louvain.py b/python/cugraph/cugraph/dask/community/louvain.py
index 8efbbafaf7b..1b091817a1a 100644
--- a/python/cugraph/cugraph/dask/community/louvain.py
+++ b/python/cugraph/cugraph/dask/community/louvain.py
@@ -129,7 +129,7 @@ def louvain(
     Examples
     --------
-    >>> from cugraph.experimental.datasets import karate
+    >>> from cugraph.datasets import karate
     >>> G = karate.get_graph(download=True)
 
     >>> parts = cugraph.louvain(G)
diff --git a/python/cugraph/cugraph/datasets/__init__.py b/python/cugraph/cugraph/datasets/__init__.py
index 65a820f108b..ac18274d354 100644
--- a/python/cugraph/cugraph/datasets/__init__.py
+++ b/python/cugraph/cugraph/datasets/__init__.py
@@ -39,3 +39,13 @@
 small_tree = Dataset(meta_path / "small_tree.yaml")
 toy_graph = Dataset(meta_path / "toy_graph.yaml")
 toy_graph_undirected = Dataset(meta_path / "toy_graph_undirected.yaml")
+
+# Benchmarking datasets: be mindful of memory usage
+# 250 MB
+soc_livejournal = Dataset(meta_path / "soc-livejournal1.yaml")
+# 965 MB
+cit_patents = Dataset(meta_path / "cit-patents.yaml")
+# 1.8 GB
+europe_osm = Dataset(meta_path / "europe_osm.yaml")
+# 1.5 GB
+hollywood = Dataset(meta_path / "hollywood.yaml")
diff --git a/python/cugraph/cugraph/datasets/dataset.py b/python/cugraph/cugraph/datasets/dataset.py
index 877eade7708..dd7aa0df00a 100644
--- a/python/cugraph/cugraph/datasets/dataset.py
+++ b/python/cugraph/cugraph/datasets/dataset.py
@@ -14,44 +14,45 @@
 import cudf
 import yaml
 import os
+import pandas as pd
 from pathlib import Path
 from cugraph.structure.graph_classes import Graph
 
 
 class DefaultDownloadDir:
     """
-    Maintains the path to the download directory used by Dataset instances.
+    Maintains a path to be used as a default download directory.
+
+    All DefaultDownloadDir instances are based on RAPIDS_DATASET_ROOT_DIR if
+    set, or _default_base_dir if not set.
+
     Instances of this class are typically shared by several Dataset instances
     in order to allow for the download directory to be defined and updated by
     a single object.
     """
 
-    def __init__(self):
-        self._path = Path(
-            os.environ.get("RAPIDS_DATASET_ROOT_DIR", Path.home() / ".cugraph/datasets")
-        )
+    _default_base_dir = Path.home() / ".cugraph/datasets"
 
-    @property
-    def path(self):
+    def __init__(self, *, subdir=""):
         """
-        If `path` is not set, set it to the environment variable
-        RAPIDS_DATASET_ROOT_DIR. If the variable is not set, default to the
-        user's home directory.
+        subdir can be specified to provide a specialized dir under the base dir.
         """
-        if self._path is None:
-            self._path = Path(
-                os.environ.get(
-                    "RAPIDS_DATASET_ROOT_DIR", Path.home() / ".cugraph/datasets"
-                )
-            )
-        return self._path
+        self._subdir = Path(subdir)
+        self.reset()
+
+    @property
+    def path(self):
+        return self._path.absolute()
 
     @path.setter
     def path(self, new):
         self._path = Path(new)
 
-    def clear(self):
-        self._path = None
+    def reset(self):
+        self._basedir = Path(
+            os.environ.get("RAPIDS_DATASET_ROOT_DIR", self._default_base_dir)
+        )
+        self._path = self._basedir / self._subdir
 
 
 default_download_dir = DefaultDownloadDir()
@@ -159,7 +160,7 @@ def unload(self):
         """
         self._edgelist = None
 
-    def get_edgelist(self, download=False):
+    def get_edgelist(self, download=False, reader="cudf"):
         """
         Return an Edgelist
 
@@ -168,6 +169,9 @@ def get_edgelist(self, download=False):
         download : Boolean (default=False)
             Automatically download the dataset from the 'url' location within
             the YAML file.
+
+        reader : 'cudf' or 'pandas' (default='cudf')
+            The library used to read a CSV and return an edgelist DataFrame.
         """
         if self._edgelist is None:
             full_path = self.get_path()
@@ -180,14 +184,29 @@ def get_edgelist(self, download=False):
                     " exist. Try setting download=True"
                     " to download the datafile"
                 )
+
             header = None
             if isinstance(self.metadata["header"], int):
                 header = self.metadata["header"]
-            self._edgelist = cudf.read_csv(
-                full_path,
+
+            if reader == "cudf":
+                self.__reader = cudf.read_csv
+            elif reader == "pandas":
+                self.__reader = pd.read_csv
+            else:
+                raise ValueError(
+                    "reader must be 'cudf' or 'pandas', selecting a read_csv"
+                    " function compatible with cudf.read_csv"
+                )
+
+            self._edgelist = self.__reader(
+                filepath_or_buffer=full_path,
                 delimiter=self.metadata["delim"],
                 names=self.metadata["col_names"],
-                dtype=self.metadata["col_types"],
+                dtype={
+                    self.metadata["col_names"][i]: self.metadata["col_types"][i]
+                    for i in range(len(self.metadata["col_types"]))
+                },
                 header=header,
             )
 
@@ -219,6 +238,10 @@ def get_graph(
             dataset -if present- will be applied to the Graph. If the
             dataset does not contain weights, the Graph returned will
             be unweighted regardless of ignore_weights.
+
+        store_transposed: Boolean (default=False)
+            If True, stores the transpose of the adjacency matrix. Required
+            for certain algorithms, such as pagerank.
         """
         if self._edgelist is None:
             self.get_edgelist(download)
@@ -237,20 +260,19 @@ def get_graph(
                 "(or subclass) type or instance, got: "
                 f"{type(create_using)}"
             )
-
         if len(self.metadata["col_names"]) > 2 and not (ignore_weights):
             G.from_cudf_edgelist(
                 self._edgelist,
-                source="src",
-                destination="dst",
-                edge_attr="wgt",
+                source=self.metadata["col_names"][0],
+                destination=self.metadata["col_names"][1],
+                edge_attr=self.metadata["col_names"][2],
                 store_transposed=store_transposed,
             )
         else:
             G.from_cudf_edgelist(
                 self._edgelist,
-                source="src",
-                destination="dst",
+                source=self.metadata["col_names"][0],
+                destination=self.metadata["col_names"][1],
                 store_transposed=store_transposed,
             )
         return G
@@ -331,7 +353,7 @@ def download_all(force=False):
 
 def set_download_dir(path):
     """
-    Set the download location fors datasets
+    Set the download location for datasets
 
     Parameters
     ----------
     path : String
         Location used to store datafiles
     """
     if path is None:
-        default_download_dir.clear()
+        default_download_dir.reset()
     else:
         default_download_dir.path = path
 
 
 def get_download_dir():
-    return default_download_dir.path.absolute()
+    return default_download_dir.path
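Note on the new `reader` argument above: it selects the CSV backend, and the edgelist is cached until `unload()` is called, so switching readers requires unloading first. A small sketch:

```python
from cugraph.datasets import karate

gdf = karate.get_edgelist(download=True)    # default: cudf.DataFrame on GPU
karate.unload()                             # drop the cached edgelist
pdf = karate.get_edgelist(reader="pandas")  # pandas.DataFrame on the host
```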
diff --git a/python/cugraph/cugraph/datasets/metadata/cit-patents.yaml b/python/cugraph/cugraph/datasets/metadata/cit-patents.yaml
new file mode 100644
index 00000000000..d5c4cf195bd
--- /dev/null
+++ b/python/cugraph/cugraph/datasets/metadata/cit-patents.yaml
@@ -0,0 +1,22 @@
+name: cit-Patents
+file_type: .csv
+description: A citation graph that includes all citations made by patents granted between 1975 and 1999, totaling 16,522,438 citations.
+author: NBER
+refs:
+    J. Leskovec, J. Kleinberg and C. Faloutsos. Graphs over Time: Densification
+    Laws, Shrinking Diameters and Possible Explanations.
+    ACM SIGKDD International Conference on Knowledge Discovery and Data Mining (KDD), 2005.
+delim: " "
+header: None
+col_names:
+  - src
+  - dst
+col_types:
+  - int32
+  - int32
+has_loop: true
+is_directed: true
+is_multigraph: false
+is_symmetric: false
+number_of_edges: 16518948
+number_of_nodes: 3774768
+url: https://data.rapids.ai/cugraph/datasets/cit-Patents.csv
\ No newline at end of file
diff --git a/python/cugraph/cugraph/datasets/metadata/europe_osm.yaml b/python/cugraph/cugraph/datasets/metadata/europe_osm.yaml
new file mode 100644
index 00000000000..fe0e42a4b86
--- /dev/null
+++ b/python/cugraph/cugraph/datasets/metadata/europe_osm.yaml
@@ -0,0 +1,21 @@
+name: europe_osm
+file_type: .csv
+description: A graph of OpenStreetMap data for Europe.
+author: M. Kobitzsh / Geofabrik GmbH
+refs:
+    Rossi, Ryan. Ahmed, Nesreen. The Network Data Repository with Interactive
+    Graph Analytics and Visualization.
+delim: " "
+header: None
+col_names:
+  - src
+  - dst
+col_types:
+  - int32
+  - int32
+has_loop: false
+is_directed: false
+is_multigraph: false
+is_symmetric: true
+number_of_edges: 54054660
+number_of_nodes: 50912018
+url: https://data.rapids.ai/cugraph/datasets/europe_osm.csv
\ No newline at end of file
diff --git a/python/cugraph/cugraph/datasets/metadata/hollywood.yaml b/python/cugraph/cugraph/datasets/metadata/hollywood.yaml
new file mode 100644
index 00000000000..2f09cf7679b
--- /dev/null
+++ b/python/cugraph/cugraph/datasets/metadata/hollywood.yaml
@@ -0,0 +1,26 @@
+name: hollywood
+file_type: .csv
+description:
+    A graph of movie actors where vertices are actors, and two actors are
+    joined by an edge whenever they appeared in a movie together.
+author: Laboratory for Web Algorithmics (LAW)
+refs:
+    The WebGraph Framework I: Compression Techniques, Paolo Boldi
+    and Sebastiano Vigna, Proc. of the Thirteenth International
+    World Wide Web Conference (WWW 2004), 2004, Manhattan, USA,
+    pp. 595--601, ACM Press.
+delim: " "
+header: None
+col_names:
+  - src
+  - dst
+col_types:
+  - int32
+  - int32
+has_loop: false
+is_directed: false
+is_multigraph: false
+is_symmetric: true
+number_of_edges: 57515616
+number_of_nodes: 1139905
+url: https://data.rapids.ai/cugraph/datasets/hollywood.csv
\ No newline at end of file
diff --git a/python/cugraph/cugraph/datasets/metadata/soc-livejournal1.yaml b/python/cugraph/cugraph/datasets/metadata/soc-livejournal1.yaml
new file mode 100644
index 00000000000..fafc68acb9b
--- /dev/null
+++ b/python/cugraph/cugraph/datasets/metadata/soc-livejournal1.yaml
@@ -0,0 +1,22 @@
+name: soc-LiveJournal1
+file_type: .csv
+description: A graph of the LiveJournal social network.
+author: L. Backstrom, D. Huttenlocher, J. Kleinberg, X. Lan
+refs:
+    L. Backstrom, D. Huttenlocher, J. Kleinberg, X. Lan. Group Formation in
+    Large Social Networks: Membership, Growth, and Evolution. KDD, 2006.
+delim: " " +header: None +col_names: + - src + - dst +col_types: + - int32 + - int32 +has_loop: true +is_directed: true +is_multigraph: false +is_symmetric: false +number_of_edges: 68993773 +number_of_nodes: 4847571 +url: https://data.rapids.ai/cugraph/datasets/soc-LiveJournal1.csv \ No newline at end of file diff --git a/python/cugraph/cugraph/datasets/metadata/soc-twitter-2010.yaml b/python/cugraph/cugraph/datasets/metadata/soc-twitter-2010.yaml new file mode 100644 index 00000000000..df5df5735af --- /dev/null +++ b/python/cugraph/cugraph/datasets/metadata/soc-twitter-2010.yaml @@ -0,0 +1,22 @@ +name: soc-twitter-2010 +file_type: .csv +description: A network of follower relationships from a snapshot of Twitter in 2010, where an edge from i to j indicates that j is a follower of i. +author: H. Kwak, C. Lee, H. Park, S. Moon +refs: + J. Yang, J. Leskovec. Temporal Variation in Online Media. ACM Intl. + Conf. on Web Search and Data Mining (WSDM '11), 2011. +delim: " " +header: None +col_names: + - src + - dst +col_types: + - int32 + - int32 +has_loop: false +is_directed: false +is_multigraph: false +is_symmetric: false +number_of_edges: 530051354 +number_of_nodes: 21297772 +url: https://data.rapids.ai/cugraph/datasets/soc-twitter-2010.csv \ No newline at end of file diff --git a/python/cugraph/cugraph/experimental/datasets/__init__.py b/python/cugraph/cugraph/experimental/datasets/__init__.py deleted file mode 100644 index 18220243df1..00000000000 --- a/python/cugraph/cugraph/experimental/datasets/__init__.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -from cugraph.experimental.datasets.dataset import ( - Dataset, - load_all, - set_download_dir, - get_download_dir, - default_download_dir, -) -from cugraph.experimental.datasets import metadata -from pathlib import Path - -from cugraph.utilities.api_tools import promoted_experimental_warning_wrapper - - -Dataset = promoted_experimental_warning_wrapper(Dataset) -load_all = promoted_experimental_warning_wrapper(load_all) -set_download_dir = promoted_experimental_warning_wrapper(set_download_dir) -get_download_dir = promoted_experimental_warning_wrapper(get_download_dir) - -meta_path = Path(__file__).parent / "metadata" - - -# individual dataset objects -karate = Dataset(meta_path / "karate.yaml") -karate_data = Dataset(meta_path / "karate_data.yaml") -karate_undirected = Dataset(meta_path / "karate_undirected.yaml") -karate_asymmetric = Dataset(meta_path / "karate_asymmetric.yaml") -karate_disjoint = Dataset(meta_path / "karate-disjoint.yaml") -dolphins = Dataset(meta_path / "dolphins.yaml") -polbooks = Dataset(meta_path / "polbooks.yaml") -netscience = Dataset(meta_path / "netscience.yaml") -cyber = Dataset(meta_path / "cyber.yaml") -small_line = Dataset(meta_path / "small_line.yaml") -small_tree = Dataset(meta_path / "small_tree.yaml") -toy_graph = Dataset(meta_path / "toy_graph.yaml") -toy_graph_undirected = Dataset(meta_path / "toy_graph_undirected.yaml") -email_Eu_core = Dataset(meta_path / "email-Eu-core.yaml") -ktruss_polbooks = Dataset(meta_path / "ktruss_polbooks.yaml") - - -# batches of datasets -DATASETS_UNDIRECTED = [karate, dolphins] - -DATASETS_UNDIRECTED_WEIGHTS = [netscience] - -DATASETS_UNRENUMBERED = [karate_disjoint] - -DATASETS = [dolphins, netscience, karate_disjoint] - -DATASETS_SMALL = [karate, dolphins, polbooks] - -STRONGDATASETS = [dolphins, netscience, email_Eu_core] - -DATASETS_KTRUSS = [(polbooks, ktruss_polbooks)] - -MEDIUM_DATASETS = [polbooks] - -SMALL_DATASETS = [karate, dolphins, netscience] - -RLY_SMALL_DATASETS = [small_line, small_tree] - -ALL_DATASETS = [karate, dolphins, netscience, polbooks, small_line, small_tree] - -ALL_DATASETS_WGT = [karate, dolphins, netscience, polbooks, small_line, small_tree] - -TEST_GROUP = [dolphins, netscience] diff --git a/python/cugraph/cugraph/experimental/datasets/dataset.py b/python/cugraph/cugraph/experimental/datasets/dataset.py deleted file mode 100644 index 6b395d50fef..00000000000 --- a/python/cugraph/cugraph/experimental/datasets/dataset.py +++ /dev/null @@ -1,312 +0,0 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import cudf -import yaml -import os -from pathlib import Path -from cugraph.structure.graph_classes import Graph - - -class DefaultDownloadDir: - """ - Maintains the path to the download directory used by Dataset instances. - Instances of this class are typically shared by several Dataset instances - in order to allow for the download directory to be defined and updated by - a single object. 
- """ - - def __init__(self): - self._path = Path( - os.environ.get("RAPIDS_DATASET_ROOT_DIR", Path.home() / ".cugraph/datasets") - ) - - @property - def path(self): - """ - If `path` is not set, set it to the environment variable - RAPIDS_DATASET_ROOT_DIR. If the variable is not set, default to the - user's home directory. - """ - if self._path is None: - self._path = Path( - os.environ.get( - "RAPIDS_DATASET_ROOT_DIR", Path.home() / ".cugraph/datasets" - ) - ) - return self._path - - @path.setter - def path(self, new): - self._path = Path(new) - - def clear(self): - self._path = None - - -default_download_dir = DefaultDownloadDir() - - -class Dataset: - """ - A Dataset Object, used to easily import edgelist data and cuGraph.Graph - instances. - - Parameters - ---------- - meta_data_file_name : yaml file - The metadata file for the specific graph dataset, which includes - information on the name, type, url link, data loading format, graph - properties - """ - - def __init__( - self, - metadata_yaml_file=None, - csv_file=None, - csv_header=None, - csv_delim=" ", - csv_col_names=None, - csv_col_dtypes=None, - ): - self._metadata_file = None - self._dl_path = default_download_dir - self._edgelist = None - self._path = None - - if metadata_yaml_file is not None and csv_file is not None: - raise ValueError("cannot specify both metadata_yaml_file and csv_file") - - elif metadata_yaml_file is not None: - with open(metadata_yaml_file, "r") as file: - self.metadata = yaml.safe_load(file) - self._metadata_file = Path(metadata_yaml_file) - - elif csv_file is not None: - if csv_col_names is None or csv_col_dtypes is None: - raise ValueError( - "csv_col_names and csv_col_dtypes must both be " - "not None when csv_file is specified." - ) - self._path = Path(csv_file) - if self._path.exists() is False: - raise FileNotFoundError(csv_file) - self.metadata = { - "name": self._path.with_suffix("").name, - "file_type": ".csv", - "url": None, - "header": csv_header, - "delim": csv_delim, - "col_names": csv_col_names, - "col_types": csv_col_dtypes, - } - - else: - raise ValueError("must specify either metadata_yaml_file or csv_file") - - def __str__(self): - """ - Use the basename of the meta_data_file the instance was constructed with, - without any extension, as the string repr. - """ - # The metadata file is likely to have a more descriptive file name, so - # use that one first if present. - # FIXME: this may need to provide a more unique or descriptive string repr - if self._metadata_file is not None: - return self._metadata_file.with_suffix("").name - else: - return self.get_path().with_suffix("").name - - def __download_csv(self, url): - """ - Downloads the .csv file from url to the current download path - (self._dl_path), updates self._path with the full path to the - downloaded file, and returns the latest value of self._path. - """ - self._dl_path.path.mkdir(parents=True, exist_ok=True) - - filename = self.metadata["name"] + self.metadata["file_type"] - if self._dl_path.path.is_dir(): - df = cudf.read_csv(url) - self._path = self._dl_path.path / filename - df.to_csv(self._path, index=False) - - else: - raise RuntimeError( - f"The directory {self._dl_path.path.absolute()}" "does not exist" - ) - return self._path - - def unload(self): - - """ - Remove all saved internal objects, forcing them to be re-created when - accessed. - - NOTE: This will cause calls to get_*() to re-read the dataset file from - disk. The caller should ensure the file on disk has not moved/been - deleted/changed. 
- """ - self._edgelist = None - - def get_edgelist(self, fetch=False): - """ - Return an Edgelist - - Parameters - ---------- - fetch : Boolean (default=False) - Automatically fetch for the dataset from the 'url' location within - the YAML file. - """ - if self._edgelist is None: - full_path = self.get_path() - if not full_path.is_file(): - if fetch: - full_path = self.__download_csv(self.metadata["url"]) - else: - raise RuntimeError( - f"The datafile {full_path} does not" - " exist. Try get_edgelist(fetch=True)" - " to download the datafile" - ) - header = None - if isinstance(self.metadata["header"], int): - header = self.metadata["header"] - self._edgelist = cudf.read_csv( - full_path, - delimiter=self.metadata["delim"], - names=self.metadata["col_names"], - dtype=self.metadata["col_types"], - header=header, - ) - - return self._edgelist - - def get_graph( - self, - fetch=False, - create_using=Graph, - ignore_weights=False, - store_transposed=False, - ): - """ - Return a Graph object. - - Parameters - ---------- - fetch : Boolean (default=False) - Downloads the dataset from the web. - - create_using: cugraph.Graph (instance or class), optional - (default=Graph) - Specify the type of Graph to create. Can pass in an instance to - create a Graph instance with specified 'directed' attribute. - - ignore_weights : Boolean (default=False) - Ignores weights in the dataset if True, resulting in an - unweighted Graph. If False (the default), weights from the - dataset -if present- will be applied to the Graph. If the - dataset does not contain weights, the Graph returned will - be unweighted regardless of ignore_weights. - """ - if self._edgelist is None: - self.get_edgelist(fetch) - - if create_using is None: - G = Graph() - elif isinstance(create_using, Graph): - # what about BFS if trnaposed is True - attrs = {"directed": create_using.is_directed()} - G = type(create_using)(**attrs) - elif type(create_using) is type: - G = create_using() - else: - raise TypeError( - "create_using must be a cugraph.Graph " - "(or subclass) type or instance, got: " - f"{type(create_using)}" - ) - - if len(self.metadata["col_names"]) > 2 and not (ignore_weights): - G.from_cudf_edgelist( - self._edgelist, - source="src", - destination="dst", - edge_attr="wgt", - store_transposed=store_transposed, - ) - else: - G.from_cudf_edgelist( - self._edgelist, - source="src", - destination="dst", - store_transposed=store_transposed, - ) - return G - - def get_path(self): - """ - Returns the location of the stored dataset file - """ - if self._path is None: - self._path = self._dl_path.path / ( - self.metadata["name"] + self.metadata["file_type"] - ) - - return self._path.absolute() - - -def load_all(force=False): - """ - Looks in `metadata` directory and fetches all datafiles from the the URLs - provided in each YAML file. - - Parameters - force : Boolean (default=False) - Overwrite any existing copies of datafiles. 
- """ - default_download_dir.path.mkdir(parents=True, exist_ok=True) - - meta_path = Path(__file__).parent.absolute() / "metadata" - for file in meta_path.iterdir(): - meta = None - if file.suffix == ".yaml": - with open(meta_path / file, "r") as metafile: - meta = yaml.safe_load(metafile) - - if "url" in meta: - filename = meta["name"] + meta["file_type"] - save_to = default_download_dir.path / filename - if not save_to.is_file() or force: - df = cudf.read_csv(meta["url"]) - df.to_csv(save_to, index=False) - - -def set_download_dir(path): - """ - Set the download directory for fetching datasets - - Parameters - ---------- - path : String - Location used to store datafiles - """ - if path is None: - default_download_dir.clear() - else: - default_download_dir.path = path - - -def get_download_dir(): - return default_download_dir.path.absolute() diff --git a/python/cugraph/cugraph/experimental/datasets/datasets_config.yaml b/python/cugraph/cugraph/experimental/datasets/datasets_config.yaml deleted file mode 100644 index 69a79db9cd9..00000000000 --- a/python/cugraph/cugraph/experimental/datasets/datasets_config.yaml +++ /dev/null @@ -1,5 +0,0 @@ ---- -fetch: "False" -force: "False" -# path where datasets will be downloaded to and stored -download_dir: "datasets" diff --git a/python/cugraph/cugraph/experimental/datasets/metadata/__init__.py b/python/cugraph/cugraph/experimental/datasets/metadata/__init__.py deleted file mode 100644 index 081b2ae8260..00000000000 --- a/python/cugraph/cugraph/experimental/datasets/metadata/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/python/cugraph/cugraph/experimental/datasets/metadata/cyber.yaml b/python/cugraph/cugraph/experimental/datasets/metadata/cyber.yaml deleted file mode 100644 index 93ab5345442..00000000000 --- a/python/cugraph/cugraph/experimental/datasets/metadata/cyber.yaml +++ /dev/null @@ -1,22 +0,0 @@ -name: cyber -file_type: .csv -author: N/A -url: https://raw.githubusercontent.com/rapidsai/cugraph/branch-22.08/datasets/cyber.csv -refs: N/A -col_names: - - idx - - srcip - - dstip -col_types: - - int32 - - str - - str -delim: "," -header: 0 -has_loop: true -is_directed: true -is_multigraph: false -is_symmetric: false -number_of_edges: 2546575 -number_of_nodes: 706529 -number_of_lines: 2546576 diff --git a/python/cugraph/cugraph/experimental/datasets/metadata/dolphins.yaml b/python/cugraph/cugraph/experimental/datasets/metadata/dolphins.yaml deleted file mode 100644 index e4951375321..00000000000 --- a/python/cugraph/cugraph/experimental/datasets/metadata/dolphins.yaml +++ /dev/null @@ -1,25 +0,0 @@ -name: dolphins -file_type: .csv -author: D. Lusseau -url: https://raw.githubusercontent.com/rapidsai/cugraph/branch-22.08/datasets/dolphins.csv -refs: - D. Lusseau, K. Schneider, O. J. Boisseau, P. Haase, E. Slooten, and S. M. 
Dawson, - The bottlenose dolphin community of Doubtful Sound features a large proportion of - long-lasting associations, Behavioral Ecology and Sociobiology 54, 396-405 (2003). -col_names: - - src - - dst - - wgt -col_types: - - int32 - - int32 - - float32 -delim: " " -header: None -has_loop: false -is_directed: true -is_multigraph: false -is_symmetric: false -number_of_edges: 318 -number_of_nodes: 62 -number_of_lines: 318 diff --git a/python/cugraph/cugraph/experimental/datasets/metadata/email-Eu-core.yaml b/python/cugraph/cugraph/experimental/datasets/metadata/email-Eu-core.yaml deleted file mode 100644 index 97d0dc82ee3..00000000000 --- a/python/cugraph/cugraph/experimental/datasets/metadata/email-Eu-core.yaml +++ /dev/null @@ -1,22 +0,0 @@ -name: email-Eu-core -file_type: .csv -author: null -url: https://raw.githubusercontent.com/rapidsai/cugraph/branch-22.08/datasets/email-Eu-core.csv -refs: null -delim: " " -header: None -col_names: - - src - - dst - - wgt -col_types: - - int32 - - int32 - - float32 -has_loop: false -is_directed: false -is_multigraph: false -is_symmetric: true -number_of_edges: 25571 -number_of_nodes: 1005 -number_of_lines: 25571 diff --git a/python/cugraph/cugraph/experimental/datasets/metadata/karate-disjoint.yaml b/python/cugraph/cugraph/experimental/datasets/metadata/karate-disjoint.yaml deleted file mode 100644 index 0c0eaf78b63..00000000000 --- a/python/cugraph/cugraph/experimental/datasets/metadata/karate-disjoint.yaml +++ /dev/null @@ -1,22 +0,0 @@ -name: karate-disjoint -file_type: .csv -author: null -url: https://raw.githubusercontent.com/rapidsai/cugraph/branch-22.08/datasets/karate-disjoint.csv -refs: null -delim: " " -header: None -col_names: - - src - - dst - - wgt -col_types: - - int32 - - int32 - - float32 -has_loop: false -is_directed: True -is_multigraph: false -is_symmetric: true -number_of_edges: 312 -number_of_nodes: 68 -number_of_lines: 312 diff --git a/python/cugraph/cugraph/experimental/datasets/metadata/karate.yaml b/python/cugraph/cugraph/experimental/datasets/metadata/karate.yaml deleted file mode 100644 index 273381ed368..00000000000 --- a/python/cugraph/cugraph/experimental/datasets/metadata/karate.yaml +++ /dev/null @@ -1,24 +0,0 @@ -name: karate -file_type: .csv -author: Zachary W. -url: https://raw.githubusercontent.com/rapidsai/cugraph/branch-22.08/datasets/karate.csv -refs: - W. W. Zachary, An information flow model for conflict and fission in small groups, - Journal of Anthropological Research 33, 452-473 (1977). -delim: " " -header: None -col_names: - - src - - dst - - wgt -col_types: - - int32 - - int32 - - float32 -has_loop: true -is_directed: true -is_multigraph: false -is_symmetric: true -number_of_edges: 156 -number_of_nodes: 34 -number_of_lines: 156 diff --git a/python/cugraph/cugraph/experimental/datasets/metadata/karate_asymmetric.yaml b/python/cugraph/cugraph/experimental/datasets/metadata/karate_asymmetric.yaml deleted file mode 100644 index 3616b8fb3a5..00000000000 --- a/python/cugraph/cugraph/experimental/datasets/metadata/karate_asymmetric.yaml +++ /dev/null @@ -1,24 +0,0 @@ -name: karate-asymmetric -file_type: .csv -author: Zachary W. -url: https://raw.githubusercontent.com/rapidsai/cugraph/branch-22.08/datasets/karate-asymmetric.csv -delim: " " -header: None -refs: - W. W. Zachary, An information flow model for conflict and fission in small groups, - Journal of Anthropological Research 33, 452-473 (1977). 
-col_names: - - src - - dst - - wgt -col_types: - - int32 - - int32 - - float32 -has_loop: true -is_directed: false -is_multigraph: false -is_symmetric: false -number_of_edges: 78 -number_of_nodes: 34 -number_of_lines: 78 diff --git a/python/cugraph/cugraph/experimental/datasets/metadata/karate_data.yaml b/python/cugraph/cugraph/experimental/datasets/metadata/karate_data.yaml deleted file mode 100644 index 9a8b27f21ae..00000000000 --- a/python/cugraph/cugraph/experimental/datasets/metadata/karate_data.yaml +++ /dev/null @@ -1,22 +0,0 @@ -name: karate-data -file_type: .csv -author: Zachary W. -url: https://raw.githubusercontent.com/rapidsai/cugraph/branch-22.08/datasets/karate-data.csv -refs: - W. W. Zachary, An information flow model for conflict and fission in small groups, - Journal of Anthropological Research 33, 452-473 (1977). -delim: "\t" -header: None -col_names: - - src - - dst -col_types: - - int32 - - int32 -has_loop: true -is_directed: true -is_multigraph: false -is_symmetric: true -number_of_edges: 156 -number_of_nodes: 34 -number_of_lines: 156 diff --git a/python/cugraph/cugraph/experimental/datasets/metadata/karate_undirected.yaml b/python/cugraph/cugraph/experimental/datasets/metadata/karate_undirected.yaml deleted file mode 100644 index 1b45f86caee..00000000000 --- a/python/cugraph/cugraph/experimental/datasets/metadata/karate_undirected.yaml +++ /dev/null @@ -1,22 +0,0 @@ -name: karate_undirected -file_type: .csv -author: Zachary W. -url: https://raw.githubusercontent.com/rapidsai/cugraph/branch-22.08/datasets/karate_undirected.csv -refs: - W. W. Zachary, An information flow model for conflict and fission in small groups, - Journal of Anthropological Research 33, 452-473 (1977). -delim: "\t" -header: None -col_names: - - src - - dst -col_types: - - int32 - - int32 -has_loop: true -is_directed: false -is_multigraph: false -is_symmetric: true -number_of_edges: 78 -number_of_nodes: 34 -number_of_lines: 78 diff --git a/python/cugraph/cugraph/experimental/datasets/metadata/ktruss_polbooks.yaml b/python/cugraph/cugraph/experimental/datasets/metadata/ktruss_polbooks.yaml deleted file mode 100644 index 1ef29b3917e..00000000000 --- a/python/cugraph/cugraph/experimental/datasets/metadata/ktruss_polbooks.yaml +++ /dev/null @@ -1,23 +0,0 @@ -name: ktruss_polbooks -file_type: .csv -author: null -url: https://raw.githubusercontent.com/rapidsai/cugraph/branch-22.08/datasets/ref/ktruss/polbooks.csv -refs: null -delim: " " -header: None -col_names: - - src - - dst - - wgt -col_types: - - int32 - - int32 - - float32 -has_loop: false -is_directed: true -is_multigraph: false -is_symmetric: false -number_of_edges: 233 -number_of_nodes: 58 -number_of_lines: 233 - diff --git a/python/cugraph/cugraph/experimental/datasets/metadata/netscience.yaml b/python/cugraph/cugraph/experimental/datasets/metadata/netscience.yaml deleted file mode 100644 index 2dca702df3d..00000000000 --- a/python/cugraph/cugraph/experimental/datasets/metadata/netscience.yaml +++ /dev/null @@ -1,22 +0,0 @@ -name: netscience -file_type: .csv -author: Newman, Mark EJ -url: https://raw.githubusercontent.com/rapidsai/cugraph/branch-22.08/datasets/netscience.csv -refs: Finding community structure in networks using the eigenvectors of matrices. 
-delim: " " -header: None -col_names: - - src - - dst - - wgt -col_types: - - int32 - - int32 - - float32 -has_loop: false -is_directed: true -is_multigraph: false -is_symmetric: true -number_of_edges: 2742 -number_of_nodes: 1461 -number_of_lines: 5484 diff --git a/python/cugraph/cugraph/experimental/datasets/metadata/polbooks.yaml b/python/cugraph/cugraph/experimental/datasets/metadata/polbooks.yaml deleted file mode 100644 index 5816e5672fd..00000000000 --- a/python/cugraph/cugraph/experimental/datasets/metadata/polbooks.yaml +++ /dev/null @@ -1,22 +0,0 @@ -name: polbooks -file_type: .csv -author: V. Krebs -url: https://raw.githubusercontent.com/rapidsai/cugraph/branch-22.08/datasets/polbooks.csv -refs: null -delim: " " -header: None -col_names: - - src - - dst - - wgt -col_types: - - int32 - - int32 - - float32 -is_directed: true -has_loop: null -is_multigraph: null -is_symmetric: true -number_of_edges: 882 -number_of_nodes: 105 -number_of_lines: 882 diff --git a/python/cugraph/cugraph/experimental/datasets/metadata/small_line.yaml b/python/cugraph/cugraph/experimental/datasets/metadata/small_line.yaml deleted file mode 100644 index 5b724ac99fd..00000000000 --- a/python/cugraph/cugraph/experimental/datasets/metadata/small_line.yaml +++ /dev/null @@ -1,22 +0,0 @@ -name: small_line -file_type: .csv -author: null -url: https://raw.githubusercontent.com/rapidsai/cugraph/branch-22.08/datasets/small_line.csv -refs: null -delim: " " -header: None -col_names: - - src - - dst - - wgt -col_types: - - int32 - - int32 - - float32 -has_loop: false -is_directed: false -is_multigraph: false -is_symmetric: true -number_of_edges: 9 -number_of_nodes: 10 -number_of_lines: 8 diff --git a/python/cugraph/cugraph/experimental/datasets/metadata/small_tree.yaml b/python/cugraph/cugraph/experimental/datasets/metadata/small_tree.yaml deleted file mode 100644 index 8eeac346d2a..00000000000 --- a/python/cugraph/cugraph/experimental/datasets/metadata/small_tree.yaml +++ /dev/null @@ -1,22 +0,0 @@ -name: small_tree -file_type: .csv -author: null -url: https://raw.githubusercontent.com/rapidsai/cugraph/branch-22.08/datasets/small_tree.csv -refs: null -delim: " " -header: None -col_names: - - src - - dst - - wgt -col_types: - - int32 - - int32 - - float32 -has_loop: false -is_directed: true -is_multigraph: false -is_symmetric: true -number_of_edges: 11 -number_of_nodes: 9 -number_of_lines: 11 diff --git a/python/cugraph/cugraph/experimental/datasets/metadata/toy_graph.yaml b/python/cugraph/cugraph/experimental/datasets/metadata/toy_graph.yaml deleted file mode 100644 index 819aad06f6a..00000000000 --- a/python/cugraph/cugraph/experimental/datasets/metadata/toy_graph.yaml +++ /dev/null @@ -1,22 +0,0 @@ -name: toy_graph -file_type: .csv -author: null -url: https://raw.githubusercontent.com/rapidsai/cugraph/branch-22.08/datasets/toy_graph.csv -refs: null -delim: " " -header: None -col_names: - - src - - dst - - wgt -col_types: - - int32 - - int32 - - float32 -has_loop: false -is_directed: false -is_multigraph: false -is_symmetric: true -number_of_edges: 16 -number_of_nodes: 6 -number_of_lines: 16 diff --git a/python/cugraph/cugraph/experimental/datasets/metadata/toy_graph_undirected.yaml b/python/cugraph/cugraph/experimental/datasets/metadata/toy_graph_undirected.yaml deleted file mode 100644 index c6e86bdf334..00000000000 --- a/python/cugraph/cugraph/experimental/datasets/metadata/toy_graph_undirected.yaml +++ /dev/null @@ -1,22 +0,0 @@ -name: toy_graph_undirected -file_type: .csv -author: null -url: 
https://raw.githubusercontent.com/rapidsai/cugraph/branch-22.08/datasets/toy_graph_undirected.csv -refs: null -delim: " " -header: None -col_names: - - src - - dst - - wgt -col_types: - - int32 - - int32 - - float32 -has_loop: false -is_directed: false -is_multigraph: false -is_symmetric: true -number_of_edges: 8 -number_of_nodes: 6 -number_of_lines: 8 diff --git a/python/cugraph/cugraph/structure/graph_classes.py b/python/cugraph/cugraph/structure/graph_classes.py index 6f6c7e5a26c..03efcba0307 100644 --- a/python/cugraph/cugraph/structure/graph_classes.py +++ b/python/cugraph/cugraph/structure/graph_classes.py @@ -469,8 +469,7 @@ def from_numpy_array(self, np_array, nodes=None): nodes: array-like or None, optional (default=None) A list of column names, acting as labels for nodes """ - if not isinstance(np_array, np.ndarray): - raise TypeError("np_array input is not a Numpy array") + np_array = np.asarray(np_array) if len(np_array.shape) != 2: raise ValueError("np_array is not a 2D matrix") diff --git a/python/cugraph/cugraph/structure/graph_implementation/simpleGraph.py b/python/cugraph/cugraph/structure/graph_implementation/simpleGraph.py index 2b23d3a26b7..22d82eb1796 100644 --- a/python/cugraph/cugraph/structure/graph_implementation/simpleGraph.py +++ b/python/cugraph/cugraph/structure/graph_implementation/simpleGraph.py @@ -1286,9 +1286,13 @@ def nodes(self): else: return df[df.columns[0]] else: - return cudf.concat( - [df[simpleGraphImpl.srcCol], df[simpleGraphImpl.dstCol]] - ).unique() + return ( + cudf.concat( + [df[simpleGraphImpl.srcCol], df[simpleGraphImpl.dstCol]] + ) + .drop_duplicates() + .reset_index(drop=True) + ) if self.adjlist is not None: return cudf.Series(np.arange(0, self.number_of_nodes())) diff --git a/python/cugraph/cugraph/structure/property_graph.py b/python/cugraph/cugraph/structure/property_graph.py index 36ce5baa212..513798f35f9 100644 --- a/python/cugraph/cugraph/structure/property_graph.py +++ b/python/cugraph/cugraph/structure/property_graph.py @@ -800,15 +800,9 @@ def add_vertex_data( tmp_df.index = tmp_df.index.rename(self.vertex_col_name) # FIXME: handle case of a type_name column already being in tmp_df - if self.__series_type is cudf.Series: - # cudf does not yet support initialization with a scalar - tmp_df[TCN] = cudf.Series( - cudf.Series([type_name], dtype=cat_dtype).repeat(len(tmp_df)), - index=tmp_df.index, - ) - else: - # pandas is oddly slow if dtype is passed to the constructor here - tmp_df[TCN] = pd.Series(type_name, index=tmp_df.index).astype(cat_dtype) + tmp_df[TCN] = self.__series_type(type_name, index=tmp_df.index).astype( + cat_dtype + ) if property_columns: # all columns @@ -1207,15 +1201,9 @@ def add_edge_data( tmp_df[self.src_col_name] = tmp_df[vertex_col_names[0]] tmp_df[self.dst_col_name] = tmp_df[vertex_col_names[1]] - if self.__series_type is cudf.Series: - # cudf does not yet support initialization with a scalar - tmp_df[TCN] = cudf.Series( - cudf.Series([type_name], dtype=cat_dtype).repeat(len(tmp_df)), - index=tmp_df.index, - ) - else: - # pandas is oddly slow if dtype is passed to the constructor here - tmp_df[TCN] = pd.Series(type_name, index=tmp_df.index).astype(cat_dtype) + tmp_df[TCN] = self.__series_type(type_name, index=tmp_df.index).astype( + cat_dtype + ) # Add unique edge IDs to the new rows. This is just a count for each # row starting from the last edge ID value, with initial edge ID 0. 
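A brief aside on the PropertyGraph simplification above: recent cudf releases accept a scalar when constructing a Series, so the cudf-specific `repeat()` workaround collapses into a single call shared with pandas. A minimal sketch (toy values; the category name and index are illustrative stand-ins for PropertyGraph's type column):
```
import cudf

# Toy stand-ins: one categorical dtype, one index matching the rows being added.
cat_dtype = cudf.CategoricalDtype(categories=["paper"], ordered=False)
idx = cudf.RangeIndex(4)

# The scalar "paper" is broadcast across the index, then cast to the
# categorical dtype -- the same pattern now works for cudf and pandas alike.
type_col = cudf.Series("paper", index=idx).astype(cat_dtype)
assert len(type_col) == 4
```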
diff --git a/python/cugraph/cugraph/testing/__init__.py b/python/cugraph/cugraph/testing/__init__.py index f5f0bcb06eb..2b4a4fd3ebf 100644 --- a/python/cugraph/cugraph/testing/__init__.py +++ b/python/cugraph/cugraph/testing/__init__.py @@ -19,7 +19,7 @@ Resultset, load_resultset, get_resultset, - results_dir_path, + default_resultset_download_dir, ) from cugraph.datasets import ( cyber, @@ -34,6 +34,11 @@ email_Eu_core, toy_graph, toy_graph_undirected, + soc_livejournal, + cit_patents, + europe_osm, + hollywood, + # twitter, ) # @@ -66,3 +71,4 @@ toy_graph_undirected, ] DEFAULT_DATASETS = [dolphins, netscience, karate_disjoint] +BENCHMARKING_DATASETS = [soc_livejournal, cit_patents, europe_osm, hollywood] diff --git a/python/cugraph/cugraph/testing/generate_resultsets.py b/python/cugraph/cugraph/testing/generate_resultsets.py index 9724aca32dc..2ae0f52d88b 100644 --- a/python/cugraph/cugraph/testing/generate_resultsets.py +++ b/python/cugraph/cugraph/testing/generate_resultsets.py @@ -20,8 +20,14 @@ import cudf import cugraph from cugraph.datasets import dolphins, netscience, karate_disjoint, karate -from cugraph.testing import utils, Resultset, SMALL_DATASETS, results_dir_path +# from cugraph.testing import utils, Resultset, SMALL_DATASETS, results_dir_path +from cugraph.testing import ( + utils, + Resultset, + SMALL_DATASETS, + default_resultset_download_dir, +) _resultsets = {} @@ -224,6 +230,7 @@ def add_resultset(result_data_dictionary, **kwargs): ] ) # Generating ALL results files + results_dir_path = default_resultset_download_dir.path if not results_dir_path.exists(): results_dir_path.mkdir(parents=True, exist_ok=True) diff --git a/python/cugraph/cugraph/testing/resultset.py b/python/cugraph/cugraph/testing/resultset.py index 490e3a7c4ff..9570d7f3e04 100644 --- a/python/cugraph/cugraph/testing/resultset.py +++ b/python/cugraph/cugraph/testing/resultset.py @@ -16,10 +16,12 @@ import urllib.request import cudf -from cugraph.testing import utils +from cugraph.datasets.dataset import ( + DefaultDownloadDir, + default_download_dir, +) - -results_dir_path = utils.RAPIDS_DATASET_ROOT_DIR_PATH / "tests" / "resultsets" +# results_dir_path = utils.RAPIDS_DATASET_ROOT_DIR_PATH / "tests" / "resultsets" class Resultset: @@ -48,6 +50,42 @@ def get_cudf_dataframe(self): _resultsets = {} +def get_resultset(resultset_name, **kwargs): + """ + Returns the golden results for a specific test. 
+ + Parameters + ---------- + resultset_name : String + Name of the test's module (currently just 'traversal' is supported) + + kwargs : + All distinct test details regarding the choice of algorithm, dataset, + and graph + """ + arg_dict = dict(kwargs) + arg_dict["resultset_name"] = resultset_name + # Example: + # {'a': 1, 'z': 9, 'c': 5, 'b': 2} becomes 'a-1-b-2-c-5-z-9' + resultset_key = "-".join( + [ + str(val) + for arg_dict_pair in sorted(arg_dict.items()) + for val in arg_dict_pair + ] + ) + uuid = _resultsets.get(resultset_key) + if uuid is None: + raise KeyError(f"results for {arg_dict} not found") + + results_dir_path = default_resultset_download_dir.path + results_filename = results_dir_path / (uuid + ".csv") + return cudf.read_csv(results_filename) + + +default_resultset_download_dir = DefaultDownloadDir(subdir="tests/resultsets") + + def load_resultset(resultset_name, resultset_download_url): """ Read a mapping file (.csv) in the _results_dir and save the @@ -56,17 +94,21 @@ def load_resultset(resultset_name, resultset_download_url): _results_dir, use resultset_download_url to download a file to install/unpack/etc. to _results_dir first. """ - mapping_file_path = results_dir_path / (resultset_name + "_mappings.csv") + # curr_resultset_download_dir = get_resultset_download_dir() + curr_resultset_download_dir = default_resultset_download_dir.path + # curr_download_dir = path + curr_download_dir = default_download_dir.path + mapping_file_path = curr_resultset_download_dir / (resultset_name + "_mappings.csv") if not mapping_file_path.exists(): # Downloads a tar gz from s3 bucket, then unpacks the results files - compressed_file_dir = utils.RAPIDS_DATASET_ROOT_DIR_PATH / "tests" + compressed_file_dir = curr_download_dir / "tests" compressed_file_path = compressed_file_dir / "resultsets.tar.gz" - if not results_dir_path.exists(): - results_dir_path.mkdir(parents=True, exist_ok=True) + if not curr_resultset_download_dir.exists(): + curr_resultset_download_dir.mkdir(parents=True, exist_ok=True) if not compressed_file_path.exists(): urllib.request.urlretrieve(resultset_download_url, compressed_file_path) tar = tarfile.open(str(compressed_file_path), "r:gz") - tar.extractall(str(results_dir_path)) + tar.extractall(str(curr_resultset_download_dir)) tar.close() # FIXME: This assumes separator is " ", but should this be configurable? @@ -102,35 +144,3 @@ def load_resultset(resultset_name, resultset_download_url): ) _resultsets[resultset_key] = uuid - - -def get_resultset(resultset_name, **kwargs): - """ - Returns the golden results for a specific test. 
- - Parameters - ---------- - resultset_name : String - Name of the test's module (currently just 'traversal' is supported) - - kwargs : - All distinct test details regarding the choice of algorithm, dataset, - and graph - """ - arg_dict = dict(kwargs) - arg_dict["resultset_name"] = resultset_name - # Example: - # {'a': 1, 'z': 9, 'c': 5, 'b': 2} becomes 'a-1-b-2-c-5-z-9' - resultset_key = "-".join( - [ - str(val) - for arg_dict_pair in sorted(arg_dict.items()) - for val in arg_dict_pair - ] - ) - uuid = _resultsets.get(resultset_key) - if uuid is None: - raise KeyError(f"results for {arg_dict} not found") - - results_filename = results_dir_path / (uuid + ".csv") - return cudf.read_csv(results_filename) diff --git a/python/cugraph/cugraph/tests/centrality/test_edge_betweenness_centrality_mg.py b/python/cugraph/cugraph/tests/centrality/test_edge_betweenness_centrality_mg.py index 4277f94a396..478b7e655d5 100644 --- a/python/cugraph/cugraph/tests/centrality/test_edge_betweenness_centrality_mg.py +++ b/python/cugraph/cugraph/tests/centrality/test_edge_betweenness_centrality_mg.py @@ -16,7 +16,7 @@ import dask_cudf from pylibcugraph.testing.utils import gen_fixture_params_product -from cugraph.experimental.datasets import DATASETS_UNDIRECTED +from cugraph.datasets import karate, dolphins import cugraph import cugraph.dask as dcg @@ -41,7 +41,7 @@ def setup_function(): # email_Eu_core is too expensive to test -datasets = DATASETS_UNDIRECTED +datasets = [karate, dolphins] # ============================================================================= diff --git a/python/cugraph/cugraph/tests/link_analysis/test_hits.py b/python/cugraph/cugraph/tests/link_analysis/test_hits.py index 1c5a135e944..fcfd8cc5318 100644 --- a/python/cugraph/cugraph/tests/link_analysis/test_hits.py +++ b/python/cugraph/cugraph/tests/link_analysis/test_hits.py @@ -38,7 +38,11 @@ def setup_function(): fixture_params = gen_fixture_params_product( (datasets, "graph_file"), ([50], "max_iter"), - ([1.0e-6], "tol"), + # FIXME: Changed this from 1.0e-6 to 1.0e-5. NX defaults to + # FLOAT64 computation, cuGraph C++ defaults to whatever the edge weight + # is, cugraph python defaults that to FLOAT32. Does not converge at + # 1e-6 for larger graphs and FLOAT32. 
+ ([1.0e-5], "tol"), ) diff --git a/python/cugraph/cugraph/tests/nx/test_compat_pr.py b/python/cugraph/cugraph/tests/nx/test_compat_pr.py index 9be3912a33f..45cab7a5674 100644 --- a/python/cugraph/cugraph/tests/nx/test_compat_pr.py +++ b/python/cugraph/cugraph/tests/nx/test_compat_pr.py @@ -24,7 +24,7 @@ import numpy as np from cugraph.testing import utils -from cugraph.experimental.datasets import karate +from cugraph.datasets import karate from pylibcugraph.testing.utils import gen_fixture_params_product diff --git a/python/cugraph/cugraph/tests/utils/test_dataset.py b/python/cugraph/cugraph/tests/utils/test_dataset.py index c2a4f7c6072..60bc6dbb45a 100644 --- a/python/cugraph/cugraph/tests/utils/test_dataset.py +++ b/python/cugraph/cugraph/tests/utils/test_dataset.py @@ -13,11 +13,10 @@ import os import gc -import sys -import warnings from pathlib import Path from tempfile import TemporaryDirectory +import pandas import pytest import cudf @@ -27,6 +26,7 @@ ALL_DATASETS, WEIGHTED_DATASETS, SMALL_DATASETS, + BENCHMARKING_DATASETS, ) from cugraph import datasets @@ -74,27 +74,14 @@ def setup(tmpdir): gc.collect() -@pytest.fixture() -def setup_deprecation_warning_tests(): - """ - Fixture used to set warning filters to 'default' and reload - experimental.datasets module if it has been previously - imported. Tests that import this fixture are expected to - import cugraph.experimental.datasets - """ - warnings.filterwarnings("default") - - if "cugraph.experimental.datasets" in sys.modules: - del sys.modules["cugraph.experimental.datasets"] - - yield - - ############################################################################### # Helpers # check if there is a row where src == dst -def has_loop(df): +def has_selfloop(dataset): + if not dataset.metadata["is_directed"]: + return False + df = dataset.get_edgelist(download=True) df.rename(columns={df.columns[0]: "src", df.columns[1]: "dst"}, inplace=True) res = df.where(df["src"] == df["dst"]) @@ -109,7 +96,13 @@ def is_symmetric(dataset): else: df = dataset.get_edgelist(download=True) df_a = df.sort_values("src") - df_b = df_a[["dst", "src", "wgt"]] + + # create df with swapped src/dst columns + df_b = None + if "wgt" in df_a.columns: + df_b = df_a[["dst", "src", "wgt"]] + else: + df_b = df_a[["dst", "src"]] df_b.rename(columns={"dst": "src", "src": "dst"}, inplace=True) # created a df by appending the two res = cudf.concat([df_a, df_b]) @@ -157,6 +150,27 @@ def test_download(dataset): assert dataset.get_path().is_file() +@pytest.mark.parametrize("dataset", SMALL_DATASETS) +def test_reader(dataset): + # defaults to using cudf.read_csv + E = dataset.get_edgelist(download=True) + + assert E is not None + assert isinstance(E, cudf.core.dataframe.DataFrame) + dataset.unload() + + # using pandas + E_pd = dataset.get_edgelist(download=True, reader="pandas") + + assert E_pd is not None + assert isinstance(E_pd, pandas.core.frame.DataFrame) + dataset.unload() + + with pytest.raises(ValueError): + dataset.get_edgelist(reader="fail") + dataset.get_edgelist(reader=None) + + @pytest.mark.parametrize("dataset", ALL_DATASETS) def test_get_edgelist(dataset): E = dataset.get_edgelist(download=True) @@ -172,7 +186,6 @@ def test_get_graph(dataset): @pytest.mark.parametrize("dataset", ALL_DATASETS) def test_metadata(dataset): M = dataset.metadata - assert M is not None @@ -310,10 +323,8 @@ def test_is_directed(dataset): @pytest.mark.parametrize("dataset", ALL_DATASETS) -def test_has_loop(dataset): - df = dataset.get_edgelist(download=True) - - assert 
has_loop(df) == dataset.metadata["has_loop"] +def test_has_selfloop(dataset): + assert has_selfloop(dataset) == dataset.metadata["has_loop"] @pytest.mark.parametrize("dataset", ALL_DATASETS) @@ -328,6 +339,25 @@ def test_is_multigraph(dataset): assert G.is_multigraph() == dataset.metadata["is_multigraph"] +# The datasets used for benchmarks are in their own test, since downloading them +# repeatedly would increase testing overhead significantly +@pytest.mark.parametrize("dataset", BENCHMARKING_DATASETS) +def test_benchmarking_datasets(dataset): + dataset_is_directed = dataset.metadata["is_directed"] + G = dataset.get_graph( + download=True, create_using=Graph(directed=dataset_is_directed) + ) + + assert G.is_directed() == dataset.metadata["is_directed"] + assert G.number_of_nodes() == dataset.metadata["number_of_nodes"] + assert G.number_of_edges() == dataset.metadata["number_of_edges"] + assert has_selfloop(dataset) == dataset.metadata["has_loop"] + assert is_symmetric(dataset) == dataset.metadata["is_symmetric"] + assert G.is_multigraph() == dataset.metadata["is_multigraph"] + + dataset.unload() + + @pytest.mark.parametrize("dataset", ALL_DATASETS) def test_object_getters(dataset): assert dataset.is_directed() == dataset.metadata["is_directed"] @@ -336,32 +366,3 @@ def test_object_getters(dataset): assert dataset.number_of_nodes() == dataset.metadata["number_of_nodes"] assert dataset.number_of_vertices() == dataset.metadata["number_of_nodes"] assert dataset.number_of_edges() == dataset.metadata["number_of_edges"] - - -# -# Test experimental for DeprecationWarnings -# -def test_experimental_dataset_import(setup_deprecation_warning_tests): - with pytest.deprecated_call(): - from cugraph.experimental.datasets import karate - - # unload() is called to pass flake8 - karate.unload() - - -def test_experimental_method_warnings(setup_deprecation_warning_tests): - from cugraph.experimental.datasets import ( - load_all, - set_download_dir, - get_download_dir, - ) - - warnings.filterwarnings("default") - tmpd = TemporaryDirectory() - - with pytest.deprecated_call(): - set_download_dir(tmpd.name) - get_download_dir() - load_all() - - tmpd.cleanup() diff --git a/python/cugraph/cugraph/tests/utils/test_resultset.py b/python/cugraph/cugraph/tests/utils/test_resultset.py new file mode 100644 index 00000000000..5c2298bedb7 --- /dev/null +++ b/python/cugraph/cugraph/tests/utils/test_resultset.py @@ -0,0 +1,71 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os + +from pathlib import Path +from tempfile import TemporaryDirectory + +import cudf +from cugraph.datasets.dataset import ( + set_download_dir, + get_download_dir, +) +from cugraph.testing.resultset import load_resultset, default_resultset_download_dir + +############################################################################### + + +def test_load_resultset(): + with TemporaryDirectory() as tmpd: + + set_download_dir(Path(tmpd)) + default_resultset_download_dir.path = Path(tmpd) / "tests" / "resultsets" + default_resultset_download_dir.path.mkdir(parents=True, exist_ok=True) + + datasets_download_dir = get_download_dir() + resultsets_download_dir = default_resultset_download_dir.path + assert "tests" in os.listdir(datasets_download_dir) + assert "resultsets.tar.gz" not in os.listdir(datasets_download_dir / "tests") + assert "traversal_mappings.csv" not in os.listdir(resultsets_download_dir) + + load_resultset( + "traversal", "https://data.rapids.ai/cugraph/results/resultsets.tar.gz" + ) + + assert "resultsets.tar.gz" in os.listdir(datasets_download_dir / "tests") + assert "traversal_mappings.csv" in os.listdir(resultsets_download_dir) + + +def test_verify_resultset_load(): + # This test is more detailed than test_load_resultset, where for each module, + # we check that every single resultset file is included along with the + # corresponding mapping file. + with TemporaryDirectory() as tmpd: + set_download_dir(Path(tmpd)) + default_resultset_download_dir.path = Path(tmpd) / "tests" / "resultsets" + default_resultset_download_dir.path.mkdir(parents=True, exist_ok=True) + + resultsets_download_dir = default_resultset_download_dir.path + + load_resultset( + "traversal", "https://data.rapids.ai/cugraph/results/resultsets.tar.gz" + ) + + resultsets = os.listdir(resultsets_download_dir) + downloaded_results = cudf.read_csv( + resultsets_download_dir / "traversal_mappings.csv", sep=" " + ) + downloaded_uuids = downloaded_results["#UUID"].values + for resultset_uuid in downloaded_uuids: + assert str(resultset_uuid) + ".csv" in resultsets diff --git a/python/cugraph/pyproject.toml b/python/cugraph/pyproject.toml index aaa301fa05f..319900b3de3 100644 --- a/python/cugraph/pyproject.toml +++ b/python/cugraph/pyproject.toml @@ -33,12 +33,11 @@ dependencies = [ "cupy-cuda11x>=12.0.0", "dask-cuda==23.12.*", "dask-cudf==23.12.*", - "dask>=2023.7.1", - "distributed>=2023.7.1", "fsspec[http]>=0.6.0", "numba>=0.57", "pylibcugraph==23.12.*", "raft-dask==23.12.*", + "rapids-dask-dependency==23.12.*", "rmm==23.12.*", "ucx-py==0.35.*", ] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. diff --git a/python/nx-cugraph/README.md b/python/nx-cugraph/README.md index ab267e5a756..273a6112d77 100644 --- a/python/nx-cugraph/README.md +++ b/python/nx-cugraph/README.md @@ -1,32 +1,135 @@ # nx-cugraph ## Description -[RAPIDS](https://rapids.ai) nx-cugraph is a [backend to NetworkX](https://networkx.org/documentation/stable/reference/classes/index.html#backends) -with minimal dependencies (`networkx`, `cupy`, and `pylibcugraph`) to run graph algorithms on the GPU. +[RAPIDS](https://rapids.ai) nx-cugraph is a [backend to NetworkX](https://networkx.org/documentation/stable/reference/utils.html#backends) +to run supported algorithms with GPU acceleration. 
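To make the new description concrete, a minimal sketch of dispatching a supported algorithm to the GPU (hypothetical session; assumes nx-cugraph is installed and uses NetworkX's built-in karate club graph):
```
import networkx as nx

G = nx.karate_club_graph()

# With nx-cugraph installed, backend="cugraph" routes this supported call
# to the GPU; the rest of the script is unchanged NetworkX code.
bc = nx.betweenness_centrality(G, k=10, backend="cugraph")
```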
-### Contribute
+## System Requirements

-Follow instructions for [contributing to cugraph](https://github.com/rapidsai/cugraph/blob/branch-23.12
-and [building from source](https://docs.rapids.ai/api/cugraph/stable/installation/source_build/), then build nx-cugraph in develop (i.e., editable) mode:
-```
-$ ./build.sh nx-cugraph --pydevelop
-```
+nx-cugraph requires the following:
+
+ * NVIDIA GPU, Pascal architecture or later
+ * CUDA 11.2, 11.4, 11.5, 11.8, or 12.0
+ * Python versions 3.9, 3.10, or 3.11
+ * NetworkX >= version 3.2

-### Run tests
+More details about system requirements can be found in the [RAPIDS System Requirements documentation](https://docs.rapids.ai/install#system-req).

-Run nx-cugraph tests from `cugraph/python/nx-cugraph` directory:
+## Installation
+
+nx-cugraph can be installed using either conda or pip.
+
+### conda
 ```
-$ pytest
+conda install -c rapidsai-nightly -c conda-forge -c nvidia nx-cugraph
 ```
-Run nx-cugraph benchmarks:
+### pip
 ```
-$ pytest --bench
+python -m pip install nx-cugraph-cu11 --extra-index-url https://pypi.nvidia.com
 ```
-Run networkx tests (requires networkx version 3.2):
+Notes:
+
+ * Nightly wheel builds will not be available until the 23.12 release; until then, the index URL for the stable release version is used in the pip install command above.
+ * Additional information relevant to installing any RAPIDS package can be found [here](https://rapids.ai/#quick-start).
+
+## Enabling nx-cugraph
+
+NetworkX will use nx-cugraph as the graph analytics backend if any of the
+following are used:
+
+### `NETWORKX_AUTOMATIC_BACKENDS` environment variable
+The `NETWORKX_AUTOMATIC_BACKENDS` environment variable can be used to have NetworkX automatically dispatch to specified backends when an API is called that the backend supports.
+Set `NETWORKX_AUTOMATIC_BACKENDS=cugraph` to GPU-accelerate supported APIs via nx-cugraph with no code changes.
+Example:
 ```
-$ ./run_nx_tests.sh
+bash> NETWORKX_AUTOMATIC_BACKENDS=cugraph python my_networkx_script.py
 ```
-Additional arguments may be passed to pytest such as:
+
+### `backend=` keyword argument
+To explicitly specify a particular backend for an API, use the `backend=`
+keyword argument. This argument takes precedence over the
+`NETWORKX_AUTOMATIC_BACKENDS` environment variable. This requires anyone
+running code that uses the `backend=` keyword argument to have the specified
+backend installed.
+
+Example:
 ```
-$ ./run_nx_tests.sh -x --sw -k betweenness
+nx.betweenness_centrality(cit_patents_graph, k=k, backend="cugraph")
 ```
+
+### Type-based dispatching
+
+NetworkX also supports automatically dispatching to backends associated with
+specific graph types. Like the `backend=` keyword argument example above, this
+requires the user to write code for a specific backend, and therefore requires
+the backend to be installed, but has the advantage of ensuring a particular
+behavior without the potential for runtime conversions.
+
+To use type-based dispatching with nx-cugraph, the user must import the backend
+directly in their code to access the utilities provided to create a Graph
+instance specifically for the nx-cugraph backend.
+
+Example:
+```
+import networkx as nx
+import nx_cugraph as nxcg
+
+G = nx.Graph()
+...
+nxcg_G = nxcg.from_networkx(G) # conversion happens once here +nx.betweenness_centrality(nxcg_G, k=1000) # nxcg Graph type causes cugraph backend + # to be used, no conversion necessary +``` + +## Supported Algorithms + +The nx-cugraph backend to NetworkX connects +[pylibcugraph](../../readme_pages/pylibcugraph.md) (cuGraph's low-level python +interface to its CUDA-based graph analytics library) and +[CuPy](https://cupy.dev/) (a GPU-accelerated array library) to NetworkX's +familiar and easy-to-use API. + +Below is the list of algorithms (many listed using pylibcugraph names), +available today in pylibcugraph or implemented using CuPy, that are or will be +supported in nx-cugraph. + +| feature/algo | release/target version | +| ----- | ----- | +| analyze_clustering_edge_cut | ? | +| analyze_clustering_modularity | ? | +| analyze_clustering_ratio_cut | ? | +| balanced_cut_clustering | ? | +| betweenness_centrality | 23.10 | +| bfs | ? | +| core_number | ? | +| degree_centrality | 23.12 | +| ecg | ? | +| edge_betweenness_centrality | 23.10 | +| ego_graph | ? | +| eigenvector_centrality | 23.12 | +| get_two_hop_neighbors | ? | +| hits | 23.12 | +| in_degree_centrality | 23.12 | +| induced_subgraph | ? | +| jaccard_coefficients | ? | +| katz_centrality | 23.12 | +| k_core | ? | +| k_truss_subgraph | 23.12 | +| leiden | ? | +| louvain | 23.10 | +| node2vec | ? | +| out_degree_centrality | 23.12 | +| overlap_coefficients | ? | +| pagerank | 23.12 | +| personalized_pagerank | ? | +| sorensen_coefficients | ? | +| spectral_modularity_maximization | ? | +| sssp | 23.12 | +| strongly_connected_components | ? | +| triangle_count | ? | +| uniform_neighbor_sample | ? | +| uniform_random_walks | ? | +| weakly_connected_components | ? | + +To request nx-cugraph backend support for a NetworkX API that is not listed +above, visit the [cuGraph GitHub repo](https://github.com/rapidsai/cugraph). diff --git a/python/nx-cugraph/_nx_cugraph/__init__.py b/python/nx-cugraph/_nx_cugraph/__init__.py index bf554a11a8b..457350b9ef0 100644 --- a/python/nx-cugraph/_nx_cugraph/__init__.py +++ b/python/nx-cugraph/_nx_cugraph/__init__.py @@ -157,6 +157,7 @@ def get_info(): return d +# FIXME: can this use the standard VERSION file and update mechanism? __version__ = "23.12.00" if __name__ == "__main__":