From 02109a662d6b6f3d0d8f78649b6e2bd94043f095 Mon Sep 17 00:00:00 2001 From: Alex Barghi <105237337+alexbarghi-nv@users.noreply.github.com> Date: Mon, 11 Mar 2024 12:35:19 -0400 Subject: [PATCH 1/7] Update cuGraph-PyG GraphSAGE Examples (#4224) Fixes failing tests of the examples. Updates the examples to use the cugraph-ops models within cugraph-pyg instead of the deprecated ones within pyg. Authors: - Alex Barghi (https://github.com/alexbarghi-nv) - Ralph Liu (https://github.com/nv-rliu) Approvers: - Don Acosta (https://github.com/acostadon) - Rick Ratzel (https://github.com/rlratzel) URL: https://github.com/rapidsai/cugraph/pull/4224 --- python/cugraph-pyg/cugraph_pyg/examples/graph_sage_mg.py | 4 ++-- python/cugraph-pyg/cugraph_pyg/examples/graph_sage_sg.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/cugraph-pyg/cugraph_pyg/examples/graph_sage_mg.py b/python/cugraph-pyg/cugraph_pyg/examples/graph_sage_mg.py index 9c0adaad879..4ca573504a1 100644 --- a/python/cugraph-pyg/cugraph_pyg/examples/graph_sage_mg.py +++ b/python/cugraph-pyg/cugraph_pyg/examples/graph_sage_mg.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -21,7 +21,7 @@ import torch import numpy as np -from torch_geometric.nn import CuGraphSAGEConv +from cugraph_pyg.nn import SAGEConv as CuGraphSAGEConv import torch.nn as nn import torch.nn.functional as F diff --git a/python/cugraph-pyg/cugraph_pyg/examples/graph_sage_sg.py b/python/cugraph-pyg/cugraph_pyg/examples/graph_sage_sg.py index 82f5e7ea67d..9c96a707e4d 100644 --- a/python/cugraph-pyg/cugraph_pyg/examples/graph_sage_sg.py +++ b/python/cugraph-pyg/cugraph_pyg/examples/graph_sage_sg.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -18,7 +18,7 @@ import torch -from torch_geometric.nn import CuGraphSAGEConv +from cugraph_pyg.nn import SAGEConv as CuGraphSAGEConv import torch.nn as nn import torch.nn.functional as F From 6c4f881d063f5079bea5e9e330e12871a3697c55 Mon Sep 17 00:00:00 2001 From: Tingyu Wang Date: Mon, 11 Mar 2024 12:37:45 -0400 Subject: [PATCH 2/7] Add additional `kwargs` to GATConv (#4210) Support deterministic/high-precision flags in mha primitives, introduced in https://github.com/rapidsai/cugraph-ops/pull/607 Authors: - Tingyu Wang (https://github.com/tingyu66) Approvers: - Maximilian Stadler (https://github.com/stadlmax) - Alex Barghi (https://github.com/alexbarghi-nv) URL: https://github.com/rapidsai/cugraph/pull/4210 --- .../cugraph_dgl/nn/conv/gatconv.py | 24 +++++++++++++++++ .../cugraph_dgl/nn/conv/gatv2conv.py | 12 +++++++++ .../cugraph_pyg/nn/conv/gat_conv.py | 26 ++++++++++++++++++- .../cugraph_pyg/nn/conv/gatv2_conv.py | 14 +++++++++- 4 files changed, 74 insertions(+), 2 deletions(-) diff --git a/python/cugraph-dgl/cugraph_dgl/nn/conv/gatconv.py b/python/cugraph-dgl/cugraph_dgl/nn/conv/gatconv.py index cc4ce474f2d..e8813271fd8 100644 --- a/python/cugraph-dgl/cugraph_dgl/nn/conv/gatconv.py +++ b/python/cugraph-dgl/cugraph_dgl/nn/conv/gatconv.py @@ -186,6 +186,10 @@ def forward( nfeat: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], efeat: Optional[torch.Tensor] = None, max_in_degree: Optional[int] = None, + deterministic_dgrad: bool = False, + deterministic_wgrad: bool = False, + high_precision_dgrad: bool = False, + high_precision_wgrad: bool = False, ) -> torch.Tensor: r"""Forward computation. @@ -204,6 +208,20 @@ def forward( from a neighbor sampler, the value should be set to the corresponding :attr:`fanout`. This option is used to invoke the MFG-variant of cugraph-ops kernel. + deterministic_dgrad : bool, default=False + Optional flag indicating whether the feature gradients + are computed deterministically using a dedicated workspace buffer. + deterministic_wgrad: bool, default=False + Optional flag indicating whether the weight gradients + are computed deterministically using a dedicated workspace buffer. + high_precision_dgrad: bool, default=False + Optional flag indicating whether gradients for inputs in half precision + are kept in single precision as long as possible and only casted to + the corresponding input type at the very end. + high_precision_wgrad: bool, default=False + Optional flag indicating whether gradients for weights in half precision + are kept in single precision as long as possible and only casted to + the corresponding input type at the very end. Returns ------- @@ -232,6 +250,8 @@ def forward( _graph = self.get_cugraph_ops_CSC( g, is_bipartite=bipartite, max_in_degree=max_in_degree ) + if deterministic_dgrad: + _graph.add_reverse_graph() if bipartite: nfeat = (self.feat_drop(nfeat[0]), self.feat_drop(nfeat[1])) @@ -273,6 +293,10 @@ def forward( negative_slope=self.negative_slope, concat_heads=self.concat, edge_feat=efeat, + deterministic_dgrad=deterministic_dgrad, + deterministic_wgrad=deterministic_wgrad, + high_precision_dgrad=high_precision_dgrad, + high_precision_wgrad=high_precision_wgrad, )[: g.num_dst_nodes()] if self.concat: diff --git a/python/cugraph-dgl/cugraph_dgl/nn/conv/gatv2conv.py b/python/cugraph-dgl/cugraph_dgl/nn/conv/gatv2conv.py index 6c78b4df0b8..4f47005f8ee 100644 --- a/python/cugraph-dgl/cugraph_dgl/nn/conv/gatv2conv.py +++ b/python/cugraph-dgl/cugraph_dgl/nn/conv/gatv2conv.py @@ -150,6 +150,8 @@ def forward( nfeat: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], efeat: Optional[torch.Tensor] = None, max_in_degree: Optional[int] = None, + deterministic_dgrad: bool = False, + deterministic_wgrad: bool = False, ) -> torch.Tensor: r"""Forward computation. @@ -166,6 +168,12 @@ def forward( from a neighbor sampler, the value should be set to the corresponding :attr:`fanout`. This option is used to invoke the MFG-variant of cugraph-ops kernel. + deterministic_dgrad : bool, default=False + Optional flag indicating whether the feature gradients + are computed deterministically using a dedicated workspace buffer. + deterministic_wgrad: bool, default=False + Optional flag indicating whether the weight gradients + are computed deterministically using a dedicated workspace buffer. Returns ------- @@ -196,6 +204,8 @@ def forward( _graph = self.get_cugraph_ops_CSC( g, is_bipartite=graph_bipartite, max_in_degree=max_in_degree ) + if deterministic_dgrad: + _graph.add_reverse_graph() if nfeat_bipartite: nfeat = (self.feat_drop(nfeat[0]), self.feat_drop(nfeat[1])) @@ -228,6 +238,8 @@ def forward( negative_slope=self.negative_slope, concat_heads=self.concat, edge_feat=efeat, + deterministic_dgrad=deterministic_dgrad, + deterministic_wgrad=deterministic_wgrad, )[: g.num_dst_nodes()] if self.concat: diff --git a/python/cugraph-pyg/cugraph_pyg/nn/conv/gat_conv.py b/python/cugraph-pyg/cugraph_pyg/nn/conv/gat_conv.py index 309bee4e228..d1785f2bef8 100644 --- a/python/cugraph-pyg/cugraph_pyg/nn/conv/gat_conv.py +++ b/python/cugraph-pyg/cugraph_pyg/nn/conv/gat_conv.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -162,6 +162,10 @@ def forward( csc: Tuple[torch.Tensor, torch.Tensor, int], edge_attr: Optional[torch.Tensor] = None, max_num_neighbors: Optional[int] = None, + deterministic_dgrad: bool = False, + deterministic_wgrad: bool = False, + high_precision_dgrad: bool = False, + high_precision_wgrad: bool = False, ) -> torch.Tensor: r"""Runs the forward pass of the module. @@ -178,11 +182,27 @@ def forward( of a destination node. When enabled, it allows models to use the message-flow-graph primitives in cugraph-ops. (default: :obj:`None`) + deterministic_dgrad : bool, default=False + Optional flag indicating whether the feature gradients + are computed deterministically using a dedicated workspace buffer. + deterministic_wgrad: bool, default=False + Optional flag indicating whether the weight gradients + are computed deterministically using a dedicated workspace buffer. + high_precision_dgrad: bool, default=False + Optional flag indicating whether gradients for inputs in half precision + are kept in single precision as long as possible and only casted to + the corresponding input type at the very end. + high_precision_wgrad: bool, default=False + Optional flag indicating whether gradients for weights in half precision + are kept in single precision as long as possible and only casted to + the corresponding input type at the very end. """ bipartite = not isinstance(x, torch.Tensor) graph = self.get_cugraph( csc, bipartite=bipartite, max_num_neighbors=max_num_neighbors ) + if deterministic_dgrad: + graph.add_reverse_graph() if edge_attr is not None: if self.lin_edge is None: @@ -220,6 +240,10 @@ def forward( negative_slope=self.negative_slope, concat_heads=self.concat, edge_feat=edge_attr, + deterministic_dgrad=deterministic_dgrad, + deterministic_wgrad=deterministic_wgrad, + high_precision_dgrad=high_precision_dgrad, + high_precision_wgrad=high_precision_wgrad, ) if self.bias is not None: diff --git a/python/cugraph-pyg/cugraph_pyg/nn/conv/gatv2_conv.py b/python/cugraph-pyg/cugraph_pyg/nn/conv/gatv2_conv.py index 32956dcb400..33865898816 100644 --- a/python/cugraph-pyg/cugraph_pyg/nn/conv/gatv2_conv.py +++ b/python/cugraph-pyg/cugraph_pyg/nn/conv/gatv2_conv.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -174,6 +174,8 @@ def forward( x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], csc: Tuple[torch.Tensor, torch.Tensor, int], edge_attr: Optional[torch.Tensor] = None, + deterministic_dgrad: bool = False, + deterministic_wgrad: bool = False, ) -> torch.Tensor: r"""Runs the forward pass of the module. @@ -186,9 +188,17 @@ def forward( :meth:`to_csc` method to convert an :obj:`edge_index` representation to the desired format. edge_attr: (torch.Tensor, optional) The edge features. + deterministic_dgrad : bool, default=False + Optional flag indicating whether the feature gradients + are computed deterministically using a dedicated workspace buffer. + deterministic_wgrad: bool, default=False + Optional flag indicating whether the weight gradients + are computed deterministically using a dedicated workspace buffer. """ bipartite = not isinstance(x, torch.Tensor) or not self.share_weights graph = self.get_cugraph(csc, bipartite=bipartite) + if deterministic_dgrad: + graph.add_reverse_graph() if edge_attr is not None: if self.lin_edge is None: @@ -217,6 +227,8 @@ def forward( negative_slope=self.negative_slope, concat_heads=self.concat, edge_feat=edge_attr, + deterministic_dgrad=deterministic_dgrad, + deterministic_wgrad=deterministic_wgrad, ) if self.bias is not None: From 4f4be6ef78bdf5a874e3cad849b956465d1f6443 Mon Sep 17 00:00:00 2001 From: Alex Barghi <105237337+alexbarghi-nv@users.noreply.github.com> Date: Mon, 11 Mar 2024 12:58:37 -0400 Subject: [PATCH 3/7] cuGraph-DGL and WholeGraph Performance Testing with Feature Store Performance Improvements (#4081) Large-scale cuGraph-DGL performance testing scripts. Also changes the DGL and PyG scripts to evaluate on all ranks and reuse the test samples, and adds support for benchmarking cuGraph-DGL/cuGraph-PyG with WholeGraph. Updates `cuGraph.gnn.FeatureStore` and `cuGraph-PyG` for increased performance: * Supporting passing in a WG embedding directly to cugraph.gnn.FeatureStore * Simplifying how cuGraph-PyG handles filtering and using a cache to prevent repeatedly copying data between the device and host * Fix bug in cugraph.gnn.FeatureStore where indexing with a gpu tensor would raise an exception, especially with WG * Add a function to cugraph.gnn.FeatureStore to check where data is stored, which is used by cuGraph-PyG to prevent unnecessary d2h and h2d copies Merge after #3584 Authors: - Alex Barghi (https://github.com/alexbarghi-nv) - Seunghwa Kang (https://github.com/seunghwak) - Vibhu Jawa (https://github.com/VibhuJawa) - Brad Rees (https://github.com/BradReesWork) Approvers: - Vibhu Jawa (https://github.com/VibhuJawa) - Don Acosta (https://github.com/acostadon) - Brad Rees (https://github.com/BradReesWork) - Naim (https://github.com/naimnv) - Joseph Nke (https://github.com/jnke2016) URL: https://github.com/rapidsai/cugraph/pull/4081 --- .../standalone/bulk_sampling/README.md | 2 +- .../bulk_sampling/bench_cugraph_training.py | 75 +++- .../bulk_sampling/cugraph_bulk_sampling.py | 124 ++++-- .../bulk_sampling/datasets/ogbn_papers100M.py | 152 ++++++-- .../bulk_sampling/models/dgl/__init__.py | 15 + .../bulk_sampling/models/dgl/models_dgl.py | 69 ++++ .../models/pyg/models_cugraph_pyg.py | 2 +- .../standalone/bulk_sampling/run_train_job.sh | 41 +- .../{run_sampling.sh => train.sh} | 52 ++- .../bulk_sampling/trainers/dgl/__init__.py | 16 + .../trainers/dgl/trainers_cugraph_dgl.py | 315 +++++++++++++++ .../trainers/dgl/trainers_dgl.py | 361 ++++++++++++++++++ .../trainers/pyg/trainers_cugraph_pyg.py | 172 +++++++-- .../trainers/pyg/trainers_pyg.py | 215 ++++++----- .../cugraph_dgl/dataloading/dataset.py | 6 +- .../cugraph_pyg/data/cugraph_store.py | 77 +++- .../cugraph_pyg/loader/cugraph_node_loader.py | 36 +- .../cugraph_pyg/sampler/cugraph_sampler.py | 4 + .../gnn/feature_storage/feat_storage.py | 45 ++- 19 files changed, 1514 insertions(+), 265 deletions(-) create mode 100644 benchmarks/cugraph/standalone/bulk_sampling/models/dgl/__init__.py create mode 100644 benchmarks/cugraph/standalone/bulk_sampling/models/dgl/models_dgl.py rename benchmarks/cugraph/standalone/bulk_sampling/{run_sampling.sh => train.sh} (66%) create mode 100644 benchmarks/cugraph/standalone/bulk_sampling/trainers/dgl/__init__.py create mode 100644 benchmarks/cugraph/standalone/bulk_sampling/trainers/dgl/trainers_cugraph_dgl.py create mode 100644 benchmarks/cugraph/standalone/bulk_sampling/trainers/dgl/trainers_dgl.py diff --git a/benchmarks/cugraph/standalone/bulk_sampling/README.md b/benchmarks/cugraph/standalone/bulk_sampling/README.md index 2d09466fb2f..56e9f4f5f64 100644 --- a/benchmarks/cugraph/standalone/bulk_sampling/README.md +++ b/benchmarks/cugraph/standalone/bulk_sampling/README.md @@ -152,7 +152,7 @@ Next are standard GNN training arguments such as `FANOUT`, `BATCH_SIZE`, etc. Y the number of training epochs here. These are followed by the `REPLICATION_FACTOR` argument, which can be used to create replications of the dataset for scale testing purposes. -The final two arguments are `FRAMEWORK` which can be either "cuGraphPyG" or "PyG", and `GPUS_PER_NODE` +The final two arguments are `FRAMEWORK` which can be "cugraph_dgl_csr", "cugraph_pyg" or "pyg", and `GPUS_PER_NODE` which must be set to the correct value, even if this is provided by a SLURM argument. If `GPUS_PER_NODE` is not set to the correct number of GPUs, the script will hang indefinitely until it times out. Mismatched GPUs per node is currently unsupported by this script but should be possible in practice. diff --git a/benchmarks/cugraph/standalone/bulk_sampling/bench_cugraph_training.py b/benchmarks/cugraph/standalone/bulk_sampling/bench_cugraph_training.py index c9e347b261d..2604642b748 100644 --- a/benchmarks/cugraph/standalone/bulk_sampling/bench_cugraph_training.py +++ b/benchmarks/cugraph/standalone/bulk_sampling/bench_cugraph_training.py @@ -43,8 +43,9 @@ def init_pytorch_worker(rank: int, use_rmm_torch_allocator: bool = False) -> Non rmm.reinitialize( devices=[rank], - pool_allocator=True, - initial_pool_size=pool_size, + pool_allocator=False, + # pool_allocator=True, + # initial_pool_size=pool_size, ) if use_rmm_torch_allocator: @@ -119,10 +120,17 @@ def parse_args(): parser.add_argument( "--framework", type=str, - help="The framework to test (PyG, cuGraphPyG)", + help="The framework to test (PyG, cugraph_pyg, cugraph_dgl_csr)", required=True, ) + parser.add_argument( + "--use_wholegraph", + action="store_true", + help="Whether to use WholeGraph feature storage", + required=False, + ) + parser.add_argument( "--model", type=str, @@ -162,6 +170,13 @@ def parse_args(): required=False, ) + parser.add_argument( + "--skip_download", + action="store_true", + help="Whether to skip downloading", + required=False, + ) + return parser.parse_args() @@ -186,21 +201,43 @@ def main(args): world_size = int(os.environ["SLURM_JOB_NUM_NODES"]) * args.gpus_per_node + if args.use_wholegraph: + # TODO support WG without cuGraph + if args.framework.lower() not in ["cugraph_pyg", "cugraph_dgl_csr"]: + raise ValueError("WG feature store only supported with cuGraph backends") + from pylibwholegraph.torch.initialize import ( + get_global_communicator, + get_local_node_communicator, + init, + ) + + logger.info("initializing WG comms...") + init(global_rank, world_size, local_rank, args.gpus_per_node) + wm_comm = get_global_communicator() + get_local_node_communicator() + + wm_comm = wm_comm.wmb_comm + logger.info(f"rank {global_rank} successfully initialized WG comms") + wm_comm.barrier() + dataset = OGBNPapers100MDataset( replication_factor=args.replication_factor, dataset_dir=args.dataset_dir, train_split=args.train_split, val_split=args.val_split, - load_edge_index=(args.framework == "PyG"), + load_edge_index=(args.framework.lower() == "pyg"), + backend="wholegraph" if args.use_wholegraph else "torch", ) - if global_rank == 0: + # Note: this does not generate WG files + if global_rank == 0 and not args.skip_download: dataset.download() + dist.barrier() fanout = [int(f) for f in args.fanout.split("_")] - if args.framework == "PyG": + if args.framework.lower() == "pyg": from trainers.pyg import PyGNativeTrainer trainer = PyGNativeTrainer( @@ -215,7 +252,7 @@ def main(args): num_neighbors=fanout, batch_size=args.batch_size, ) - elif args.framework == "cuGraphPyG": + elif args.framework.lower() == "cugraph_pyg": sample_dir = os.path.join( args.sample_dir, f"ogbn_papers100M[{args.replication_factor}]_b{args.batch_size}_f{fanout}", @@ -229,11 +266,35 @@ def main(args): device=local_rank, rank=global_rank, world_size=world_size, + gpus_per_node=args.gpus_per_node, num_epochs=args.num_epochs, shuffle=True, replace=False, num_neighbors=fanout, batch_size=args.batch_size, + backend="wholegraph" if args.use_wholegraph else "torch", + ) + elif args.framework.lower() == "cugraph_dgl_csr": + sample_dir = os.path.join( + args.sample_dir, + f"ogbn_papers100M[{args.replication_factor}]_b{args.batch_size}_f{fanout}", + ) + from trainers.dgl import DGLCuGraphTrainer + + trainer = DGLCuGraphTrainer( + model=args.model, + dataset=dataset, + sample_dir=sample_dir, + device=local_rank, + rank=global_rank, + world_size=world_size, + gpus_per_node=args.gpus_per_node, + num_epochs=args.num_epochs, + shuffle=True, + replace=False, + num_neighbors=[int(f) for f in args.fanout.split("_")], + batch_size=args.batch_size, + backend="wholegraph" if args.use_wholegraph else "torch", ) else: raise ValueError("unsupported framework") diff --git a/benchmarks/cugraph/standalone/bulk_sampling/cugraph_bulk_sampling.py b/benchmarks/cugraph/standalone/bulk_sampling/cugraph_bulk_sampling.py index e3a5bba3162..95e1afcb28b 100644 --- a/benchmarks/cugraph/standalone/bulk_sampling/cugraph_bulk_sampling.py +++ b/benchmarks/cugraph/standalone/bulk_sampling/cugraph_bulk_sampling.py @@ -190,6 +190,10 @@ def sample_graph( val_perc=0.5, sampling_kwargs={}, ): + logger = logging.getLogger("__main__") + logger.info("Starting sampling phase...") + + logger.info("Calculating random splits...") cupy.random.seed(seed) train_df, test_df = label_df.random_split( [train_perc, 1 - train_perc], random_state=seed, shuffle=True @@ -197,24 +201,35 @@ def sample_graph( val_df, test_df = label_df.random_split( [val_perc, 1 - val_perc], random_state=seed, shuffle=True ) + logger.info("Calculated random splits") total_time = 0.0 for epoch in range(num_epochs): - steps = [("train", train_df), ("test", test_df)] + steps = [("train", train_df)] if epoch == num_epochs - 1: steps.append(("val", val_df)) + steps.append(("test", test_df)) for step, batch_df in steps: - batch_df = batch_df.sample(frac=1.0, random_state=seed) + logger.info("Shuffling batch dataframe...") + batch_df = batch_df.sample(frac=1.0, random_state=seed).persist() + logger.info("Shuffled and persisted batch dataframe...") - if step == "val": - output_sample_path = os.path.join(output_path, "val", "samples") - else: + if step == "train": output_sample_path = os.path.join( output_path, f"epoch={epoch}", f"{step}", "samples" ) - os.makedirs(output_sample_path) + else: + output_sample_path = os.path.join(output_path, step, "samples") + + client = default_client() + + def func(): + os.makedirs(output_sample_path, exist_ok=True) + + client.run(func) + logger.info("Creating bulk sampler...") sampler = BulkSampler( batch_size=batch_size, output_path=output_sample_path, @@ -227,6 +242,7 @@ def sample_graph( log_level=logging.INFO, **sampling_kwargs, ) + logger.info("Bulk sampler created and ready for input") n_workers = len(default_client().scheduler_info()["workers"]) @@ -244,13 +260,13 @@ def sample_graph( # should always persist the batch dataframe or performance may be suboptimal batch_df = batch_df.persist() - print("created batches") + logger.info("created and persisted batches") start_time = perf_counter() sampler.add_batches(batch_df, start_col_name="node", batch_col_name="batch") sampler.flush() end_time = perf_counter() - print("flushed all batches") + logger.info("flushed all batches") total_time += end_time - start_time return total_time @@ -356,23 +372,29 @@ def load_disk_dataset( path = Path(dataset_dir) / dataset parquet_path = path / "parquet" + logger = logging.getLogger("__main__") + + logger.info("getting n workers...") n_workers = get_n_workers() + logger.info(f"there are {n_workers} workers") with open(os.path.join(path, "meta.json")) as meta_file: meta = json.load(meta_file) + logger.info("assigning offsets...") node_offsets, node_offsets_replicated, total_num_nodes = assign_offsets_pyg( meta["num_nodes"], replication_factor=replication_factor ) + logger.info("offsets assigned") edge_index_dict = {} for edge_type in meta["num_edges"].keys(): - print(f"Loading edge index for edge type {edge_type}") + logger.info(f"Loading edge index for edge type {edge_type}") can_edge_type = tuple(edge_type.split("__")) edge_index_dict[can_edge_type] = dask_cudf.read_parquet( Path(parquet_path) / edge_type / "edge_index.parquet" - ).repartition(n_workers * 2) + ).repartition(npartitions=n_workers * 2) edge_index_dict[can_edge_type]["src"] += node_offsets_replicated[ can_edge_type[0] @@ -384,6 +406,7 @@ def load_disk_dataset( edge_index_dict[can_edge_type] = edge_index_dict[can_edge_type] if replication_factor > 1: + logger.info("processing replications") edge_index_dict[can_edge_type] = edge_index_dict[ can_edge_type ].map_partitions( @@ -400,6 +423,7 @@ def load_disk_dataset( } ), ) + logger.info("replications processed") gc.collect() @@ -407,48 +431,63 @@ def load_disk_dataset( edge_index_dict[can_edge_type] = edge_index_dict[can_edge_type].rename( columns={"src": "dst", "dst": "src"} ) + logger.info("edge index loaded") # Assign numeric edge type ids based on lexicographic order edge_offsets = {} edge_count = 0 - for num_edge_type, can_edge_type in enumerate(sorted(edge_index_dict.keys())): - if add_edge_types: - edge_index_dict[can_edge_type]["etp"] = cupy.int32(num_edge_type) - edge_offsets[can_edge_type] = edge_count - edge_count += len(edge_index_dict[can_edge_type]) + # for num_edge_type, can_edge_type in enumerate(sorted(edge_index_dict.keys())): + # if add_edge_types: + # edge_index_dict[can_edge_type]["etp"] = cupy.int32(num_edge_type) + # edge_offsets[can_edge_type] = edge_count + # edge_count += len(edge_index_dict[can_edge_type]) + + if len(edge_index_dict) != 1: + raise ValueError("should only be 1 edge index") + + logger.info("setting edge type") + + all_edges_df = list(edge_index_dict.values())[0] + if add_edge_types: + all_edges_df["etp"] = cupy.int32(0) - all_edges_df = dask_cudf.concat(list(edge_index_dict.values())) + # all_edges_df = dask_cudf.concat(list(edge_index_dict.values())) del edge_index_dict gc.collect() node_labels = {} for node_type, offset in node_offsets_replicated.items(): - print(f"Loading node labels for node type {node_type} (offset={offset})") + logger.info(f"Loading node labels for node type {node_type} (offset={offset})") node_label_path = os.path.join( os.path.join(parquet_path, node_type), "node_label.parquet" ) if os.path.exists(node_label_path): node_labels[node_type] = ( dask_cudf.read_parquet(node_label_path) - .repartition(n_workers) + .repartition(npartitions=n_workers) .drop("label", axis=1) .persist() ) + logger.info(f"Loaded and persisted initial labels") node_labels[node_type]["node"] += offset node_labels[node_type] = node_labels[node_type].persist() + logger.info(f"Set and persisted node offsets") if replication_factor > 1: + logger.info(f"Replicating labels...") node_labels[node_type] = node_labels[node_type].map_partitions( _replicate_df, replication_factor, {"node": meta["num_nodes"][node_type]}, meta=cudf.DataFrame({"node": cudf.Series(dtype="int64")}), ) + logger.info(f"Replicated labels (will likely evaluate later)") gc.collect() node_labels_df = dask_cudf.concat(list(node_labels.values())).reset_index(drop=True) + logger.info("Dataset successfully loaded") del node_labels gc.collect() @@ -459,6 +498,7 @@ def load_disk_dataset( node_offsets_replicated, edge_offsets, total_num_nodes, + sum(meta["num_edges"].values()) * replication_factor, ) @@ -540,6 +580,7 @@ def benchmark_cugraph_bulk_sampling( node_offsets, edge_offsets, total_num_nodes, + num_input_edges, ) = load_disk_dataset( dataset, dataset_dir=dataset_dir, @@ -548,7 +589,6 @@ def benchmark_cugraph_bulk_sampling( add_edge_types=add_edge_types, ) - num_input_edges = len(dask_edgelist_df) logger.info(f"Number of input edges = {num_input_edges:,}") G = construct_graph(dask_edgelist_df) @@ -562,7 +602,13 @@ def benchmark_cugraph_bulk_sampling( output_path, f"{dataset}[{replication_factor}]_b{batch_size}_f{fanout}", ) - os.makedirs(output_subdir) + + client = default_client() + + def func(): + os.makedirs(output_subdir, exist_ok=True) + + client.run(func) if sampling_target_framework == "cugraph_dgl_csr": sampling_kwargs = { @@ -574,8 +620,8 @@ def benchmark_cugraph_bulk_sampling( "use_legacy_names": False, "include_hop_column": False, } - else: - # FIXME: Update these arguments when CSC mode is fixed in cuGraph-PyG (release 24.02) + elif sampling_target_framework == "cugraph_pyg": + # FIXME: Update these arguments when CSC mode is fixed in cuGraph-PyG (release 24.04) sampling_kwargs = { "deduplicate_sources": True, "prior_sources_behavior": "exclude", @@ -585,8 +631,10 @@ def benchmark_cugraph_bulk_sampling( "use_legacy_names": False, "include_hop_column": True, } + else: + raise ValueError("Only cugraph_dgl_csr or cugraph_pyg are valid frameworks") - batches_per_partition = 600_000 // batch_size + batches_per_partition = 256 execution_time, allocation_counts = sample_graph( G=G, label_df=dask_label_df, @@ -761,9 +809,9 @@ def get_args(): logger.setLevel(logging.INFO) args = get_args() - if args.sampling_target_framework not in ["cugraph_dgl_csr", None]: + if args.sampling_target_framework not in ["cugraph_dgl_csr", "cugraph_pyg"]: raise ValueError( - "sampling_target_framework must be one of cugraph_dgl_csr or None", + "sampling_target_framework must be one of cugraph_dgl_csr or cugraph_pyg", "Other frameworks are not supported at this time.", ) @@ -775,12 +823,30 @@ def get_args(): seeds_per_call_opts = [int(s) for s in args.seeds_per_call_opts.split(",")] dask_worker_devices = [int(d) for d in args.dask_worker_devices.split(",")] - logger.info("starting dask client") - client, cluster = start_dask_client() + import time + + time_dask_start = time.localtime() + + logger.info(f"{time.asctime(time_dask_start)}: starting dask client") + from dask_cuda.initialize import initialize + from dask.distributed import Client + from cugraph.dask.comms import comms as Comms + import os, time + + client = Client(scheduler_file=os.environ["SCHEDULER_FILE"], timeout=360) + time.sleep(30) + cluster = Comms.initialize(p2p=True) + # client, cluster = start_dask_client() + time_dask_end = time.localtime() + logger.info(f"{time.asctime(time_dask_end)}: dask client started") + + logger.info("enabling spilling") enable_spilling() - stats_ls = [] client.run(enable_spilling) - logger.info("dask client started") + logger.info("enabled spilling") + + stats_ls = [] + for dataset in datasets: m = re.match(r"(\w+)\[([0-9]+)\]", dataset) if m: diff --git a/benchmarks/cugraph/standalone/bulk_sampling/datasets/ogbn_papers100M.py b/benchmarks/cugraph/standalone/bulk_sampling/datasets/ogbn_papers100M.py index a50e40f6d55..e3151e37a25 100644 --- a/benchmarks/cugraph/standalone/bulk_sampling/datasets/ogbn_papers100M.py +++ b/benchmarks/cugraph/standalone/bulk_sampling/datasets/ogbn_papers100M.py @@ -34,6 +34,7 @@ def __init__( train_split=0.8, val_split=0.5, load_edge_index=True, + backend="torch", ): self.__replication_factor = replication_factor self.__disk_x = None @@ -43,6 +44,7 @@ def __init__( self.__train_split = train_split self.__val_split = val_split self.__load_edge_index = load_edge_index + self.__backend = backend def download(self): import logging @@ -152,6 +154,27 @@ def download(self): ) ldf.to_parquet(node_label_file_path) + # WholeGraph + wg_bin_file_path = os.path.join(dataset_path, "wgb", "paper") + if self.__replication_factor == 1: + wg_bin_rep_path = os.path.join(wg_bin_file_path, "node_feat.d") + else: + wg_bin_rep_path = os.path.join( + wg_bin_file_path, f"node_feat_{self.__replication_factor}x.d" + ) + + if not os.path.exists(wg_bin_rep_path): + os.makedirs(wg_bin_rep_path) + if dataset is None: + from ogb.nodeproppred import NodePropPredDataset + + dataset = NodePropPredDataset( + name="ogbn-papers100M", root=self.__dataset_dir + ) + node_feat = dataset[0][0]["node_feat"] + for k in range(self.__replication_factor): + node_feat.tofile(os.path.join(wg_bin_rep_path, f"{k:04d}.bin")) + @property def edge_index_dict( self, @@ -224,45 +247,87 @@ def edge_index_dict( @property def x_dict(self) -> Dict[str, torch.Tensor]: + if self.__disk_x is None: + if self.__backend == "wholegraph": + self.__load_x_wg() + else: + self.__load_x_torch() + + return self.__disk_x + + def __load_x_torch(self) -> None: node_type_path = os.path.join( self.__dataset_dir, "ogbn_papers100M", "npy", "paper" ) + if self.__replication_factor == 1: + full_path = os.path.join(node_type_path, "node_feat.npy") + else: + full_path = os.path.join( + node_type_path, f"node_feat_{self.__replication_factor}x.npy" + ) - if self.__disk_x is None: - if self.__replication_factor == 1: - full_path = os.path.join(node_type_path, "node_feat.npy") - else: - full_path = os.path.join( - node_type_path, f"node_feat_{self.__replication_factor}x.npy" - ) + self.__disk_x = {"paper": torch.as_tensor(np.load(full_path, mmap_mode="r"))} - self.__disk_x = {"paper": np.load(full_path, mmap_mode="r")} + def __load_x_wg(self) -> None: + import logging - return self.__disk_x + logger = logging.getLogger("OGBNPapers100MDataset") + logger.info("Loading x into WG embedding...") + + import pylibwholegraph.torch as wgth + + node_type_path = os.path.join( + self.__dataset_dir, "ogbn_papers100M", "wgb", "paper" + ) + if self.__replication_factor == 1: + full_path = os.path.join(node_type_path, "node_feat.d") + else: + full_path = os.path.join( + node_type_path, f"node_feat_{self.__replication_factor}x.d" + ) + + file_list = [os.path.join(full_path, f) for f in os.listdir(full_path)] + + x = wgth.create_embedding_from_filelist( + wgth.get_global_communicator(), + "distributed", # TODO support other options + "cpu", # TODO support GPU + file_list, + torch.float32, + 128, + ) + from pylibwholegraph.torch.initialize import get_global_communicator + + wm_comm = get_global_communicator() + wm_comm.barrier() + + logger.info("created x wg embedding") + + self.__disk_x = {"paper": x} @property def y_dict(self) -> Dict[str, torch.Tensor]: if self.__y is None: - self.__get_labels() + self.__get_y() return self.__y @property def train_dict(self) -> Dict[str, torch.Tensor]: if self.__train is None: - self.__get_labels() + self.__get_split() return self.__train @property def test_dict(self) -> Dict[str, torch.Tensor]: if self.__test is None: - self.__get_labels() + self.__get_split() return self.__test @property def val_dict(self) -> Dict[str, torch.Tensor]: if self.__val is None: - self.__get_labels() + self.__get_split() return self.__val @property @@ -271,7 +336,7 @@ def num_input_features(self) -> int: @property def num_labels(self) -> int: - return int(self.y_dict["paper"].max()) + 1 + return 172 def num_nodes(self, node_type: str) -> int: if node_type != "paper": @@ -285,46 +350,49 @@ def num_edges(self, edge_type: Tuple[str, str, str]) -> int: return 1_615_685_872 * self.__replication_factor - def __get_labels(self): + def __get_y(self): label_path = os.path.join( self.__dataset_dir, "ogbn_papers100M", - "parquet", + "wgb", "paper", - "node_label.parquet", + "node_label.d", + "0.bin", ) - node_label = pandas.read_parquet(label_path) - - if self.__replication_factor > 1: - orig_num_nodes = self.num_nodes("paper") // self.__replication_factor - dfr = pandas.DataFrame( - { - "node": pandas.concat( - [ - node_label.node + (r * orig_num_nodes) - for r in range(1, self.__replication_factor) - ] - ), - "label": pandas.concat( - [node_label.label for r in range(1, self.__replication_factor)] - ), - } + if self.__backend == "wholegraph": + import pylibwholegraph.torch as wgth + + node_label = wgth.create_embedding_from_filelist( + wgth.get_global_communicator(), + "distributed", # TODO support other options + "cpu", # TODO support GPU + [label_path] * self.__replication_factor, + torch.int16, + 1, + ) + + else: + node_label_1x = torch.as_tensor( + np.fromfile(label_path, dtype="int16"), device="cpu" ) - node_label = pandas.concat([node_label, dfr]).reset_index(drop=True) + if self.__replication_factor > 1: + node_label = torch.concatenate( + [node_label_1x] * self.__replication_factor + ) + else: + node_label = node_label_1x + + self.__y = {"paper": node_label} + + def __get_split(self): num_nodes = self.num_nodes("paper") - node_label_tensor = torch.full( - (num_nodes,), -1, dtype=torch.float32, device="cpu" - ) - node_label_tensor[ - torch.as_tensor(node_label.node.values, device="cpu") - ] = torch.as_tensor(node_label.label.values, device="cpu") - self.__y = {"paper": node_label_tensor.contiguous()} + node = self.y_dict["paper"][self.y_dict["paper"] > 0] train_ix, test_val_ix = train_test_split( - torch.as_tensor(node_label.node.values), + node, train_size=self.__train_split, random_state=num_nodes, ) diff --git a/benchmarks/cugraph/standalone/bulk_sampling/models/dgl/__init__.py b/benchmarks/cugraph/standalone/bulk_sampling/models/dgl/__init__.py new file mode 100644 index 00000000000..610a7648801 --- /dev/null +++ b/benchmarks/cugraph/standalone/bulk_sampling/models/dgl/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .models_dgl import GraphSAGE diff --git a/benchmarks/cugraph/standalone/bulk_sampling/models/dgl/models_dgl.py b/benchmarks/cugraph/standalone/bulk_sampling/models/dgl/models_dgl.py new file mode 100644 index 00000000000..2cfdda2d2e7 --- /dev/null +++ b/benchmarks/cugraph/standalone/bulk_sampling/models/dgl/models_dgl.py @@ -0,0 +1,69 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn.functional as F + + +class GraphSAGE(torch.nn.Module): + """ + GraphSAGE model implementation for DGL + supporting both native DGL and cuGraph-ops + backends. + """ + + def __init__( + self, + in_channels, + hidden_channels, + out_channels, + num_layers, + model_backend="dgl", + ): + if model_backend == "dgl": + from dgl.nn import SAGEConv + else: + from cugraph_dgl.nn import SAGEConv + + super(GraphSAGE, self).__init__() + self.convs = torch.nn.ModuleList() + for _ in range(num_layers - 1): + self.convs.append( + SAGEConv(in_channels, hidden_channels, aggregator_type="mean") + ) + in_channels = hidden_channels + self.convs.append( + SAGEConv(hidden_channels, out_channels, aggregator_type="mean") + ) + + def forward(self, blocks, x): + """ + Runs the model forward pass given a list of blocks + and feature tensor. + """ + + for i, conv in enumerate(self.convs): + x = conv(blocks[i], x) + if i != len(self.convs) - 1: + x = F.relu(x) + x = F.dropout(x, p=0.5) + return x + + +def create_model(feat_size, num_classes, num_layers, model_backend="dgl"): + model = GraphSAGE( + feat_size, 64, num_classes, num_layers, model_backend=model_backend + ) + model = model.to("cuda") + model.train() + return model diff --git a/benchmarks/cugraph/standalone/bulk_sampling/models/pyg/models_cugraph_pyg.py b/benchmarks/cugraph/standalone/bulk_sampling/models/pyg/models_cugraph_pyg.py index 1de791bf588..7ee400b004f 100644 --- a/benchmarks/cugraph/standalone/bulk_sampling/models/pyg/models_cugraph_pyg.py +++ b/benchmarks/cugraph/standalone/bulk_sampling/models/pyg/models_cugraph_pyg.py @@ -57,7 +57,7 @@ def forward(self, x, edge, num_sampled_nodes, num_sampled_edges): for i, conv in enumerate(self.convs): if i > 0: - new_num_edges = edge[1][-2] + new_num_edges = int(edge[1][-2]) edge[0] = edge[0].narrow( dim=0, start=0, diff --git a/benchmarks/cugraph/standalone/bulk_sampling/run_train_job.sh b/benchmarks/cugraph/standalone/bulk_sampling/run_train_job.sh index 27ae0dc7788..8136018c877 100755 --- a/benchmarks/cugraph/standalone/bulk_sampling/run_train_job.sh +++ b/benchmarks/cugraph/standalone/bulk_sampling/run_train_job.sh @@ -12,12 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -#SBATCH -A datascience_rapids_cugraphgnn -#SBATCH -p luna -#SBATCH -J datascience_rapids_cugraphgnn-papers:bulkSamplingPyG -#SBATCH -N 1 -#SBATCH -t 00:25:00 - CONTAINER_IMAGE=${CONTAINER_IMAGE:="please_specify_container"} SCRIPTS_DIR=$(pwd) LOGS_DIR=${LOGS_DIR:=$(pwd)"/logs"} @@ -31,10 +25,11 @@ mkdir -p $DATASETS_DIR BATCH_SIZE=512 FANOUT="10_10_10" NUM_EPOCHS=1 -REPLICATION_FACTOR=1 +REPLICATION_FACTOR=2 +JOB_ID=$RANDOM -# options: PyG or cuGraphPyG -FRAMEWORK="cuGraphPyG" +# options: PyG, cuGraphPyG, or cuGraphDGL +FRAMEWORK="cuGraphDGL" GPUS_PER_NODE=8 nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) ) @@ -52,6 +47,7 @@ echo Num GPUs Per Node: $gpus_per_node set -e + # First run without cuGraph to get data if [[ "$FRAMEWORK" == "cuGraphPyG" ]]; then @@ -59,25 +55,10 @@ if [[ "$FRAMEWORK" == "cuGraphPyG" ]]; then srun \ --container-image $CONTAINER_IMAGE \ --container-mounts=${LOGS_DIR}":/logs",${SAMPLES_DIR}":/samples",${SCRIPTS_DIR}":/scripts",${DATASETS_DIR}":/datasets" \ - bash /scripts/run_sampling.sh $BATCH_SIZE $FANOUT $REPLICATION_FACTOR "/scripts" $NUM_EPOCHS + bash /scripts/train.sh $BATCH_SIZE $FANOUT $REPLICATION_FACTOR "/scripts" $NUM_EPOCHS "cugraph_pyg" $nnodes $head_node_ip $JOB_ID +elif [[ "$FRAMEWORK" == "cuGraphDGL" ]]; then + srun \ + --container-image $CONTAINER_IMAGE \ + --container-mounts=${LOGS_DIR}":/logs",${SAMPLES_DIR}":/samples",${SCRIPTS_DIR}":/scripts",${DATASETS_DIR}":/datasets" \ + bash /scripts/train.sh $BATCH_SIZE $FANOUT $REPLICATION_FACTOR "/scripts" $NUM_EPOCHS "cugraph_dgl_csr" $nnodes $head_node_ip $JOB_ID fi - -# Train -srun \ - --container-image $CONTAINER_IMAGE \ - --container-mounts=${LOGS_DIR}":/logs",${SAMPLES_DIR}":/samples",${SCRIPTS_DIR}":/scripts",${DATASETS_DIR}":/datasets" \ - torchrun \ - --nnodes $nnodes \ - --nproc-per-node $gpus_per_node \ - --rdzv-id $RANDOM \ - --rdzv-backend c10d \ - --rdzv-endpoint $head_node_ip:29500 \ - /scripts/bench_cugraph_training.py \ - --output_file "/logs/output.txt" \ - --framework $FRAMEWORK \ - --dataset_dir "/datasets" \ - --sample_dir "/samples" \ - --batch_size $BATCH_SIZE \ - --fanout $FANOUT \ - --replication_factor $REPLICATION_FACTOR \ - --num_epochs $NUM_EPOCHS diff --git a/benchmarks/cugraph/standalone/bulk_sampling/run_sampling.sh b/benchmarks/cugraph/standalone/bulk_sampling/train.sh similarity index 66% rename from benchmarks/cugraph/standalone/bulk_sampling/run_sampling.sh rename to benchmarks/cugraph/standalone/bulk_sampling/train.sh index 1b3085dcc9a..a3b85e281f1 100644 --- a/benchmarks/cugraph/standalone/bulk_sampling/run_sampling.sh +++ b/benchmarks/cugraph/standalone/bulk_sampling/train.sh @@ -21,6 +21,10 @@ FANOUT=$2 REPLICATION_FACTOR=$3 SCRIPTS_DIR=$4 NUM_EPOCHS=$5 +SAMPLING_FRAMEWORK=$6 +N_NODES=$7 +HEAD_NODE_IP=$8 +JOB_ID=$9 SAMPLES_DIR=/samples DATASET_DIR=/datasets @@ -29,12 +33,19 @@ LOGS_DIR=/logs MG_UTILS_DIR=${SCRIPTS_DIR}/mg_utils SCHEDULER_FILE=${MG_UTILS_DIR}/dask_scheduler.json -export WORKER_RMM_POOL_SIZE=28G -export UCX_MAX_RNDV_RAILS=1 +echo $SAMPLES_DIR +ls $SAMPLES_DIR + +export WORKER_RMM_POOL_SIZE=75G +#export UCX_MAX_RNDV_RAILS=1 export RAPIDS_NO_INITIALIZE=1 export CUDF_SPILL=1 -export LIBCUDF_CUFILE_POLICY="OFF" +export LIBCUDF_CUFILE_POLICY="KVIKIO" +export KVIKIO_NTHREADS=64 export GPUS_PER_NODE=8 +#export NCCL_CUMEM_ENABLE=0 +#export NCCL_DEBUG="TRACE" +export NCCL_DEBUG_FILE=/logs/nccl_debug.%h.%p export SCHEDULER_FILE=$SCHEDULER_FILE export LOGS_DIR=$LOGS_DIR @@ -59,8 +70,9 @@ else fi echo "properly waiting for workers to connect" -NUM_GPUS=$(python -c "import os; print(int(os.environ['SLURM_JOB_NUM_NODES'])*int(os.environ['GPUS_PER_NODE']))") -handleTimeout 120 python ${MG_UTILS_DIR}/wait_for_workers.py \ +export NUM_GPUS=$(python -c "import os; print(int(os.environ['SLURM_JOB_NUM_NODES'])*int(os.environ['GPUS_PER_NODE']))") +SEEDS_PER_CALL=$(python -c "import os; print(int(os.environ['NUM_GPUS'])*65536)") +handleTimeout 630 python ${MG_UTILS_DIR}/wait_for_workers.py \ --num-expected-workers ${NUM_GPUS} \ --scheduler-file-path ${SCHEDULER_FILE} @@ -76,14 +88,15 @@ if [[ $SLURM_NODEID == 0 ]]; then --datasets "ogbn_papers100M["$REPLICATION_FACTOR"]" \ --fanouts $FANOUT \ --batch_sizes $BATCH_SIZE \ - --seeds_per_call_opts "524288" \ + --seeds_per_call_opts $SEEDS_PER_CALL \ --num_epochs $NUM_EPOCHS \ - --random_seed 42 + --random_seed 42 \ + --sampling_target_framework $SAMPLING_FRAMEWORK - echo "DONE" > ${SAMPLES_DIR}/status.txt + echo "DONE" > ${LOGS_DIR}/status.txt fi -while [ ! -f "${SAMPLES_DIR}"/status.txt ] +while [ ! -f "${LOGS_DIR}"/status.txt ] do sleep 1 done @@ -106,6 +119,25 @@ if [[ ${#python_processes[@]} -gt 1 || $dask_processes ]]; then fi sleep 2 +torchrun \ + --nnodes $N_NODES \ + --nproc-per-node $GPUS_PER_NODE \ + --rdzv-id $JOB_ID \ + --rdzv-backend c10d \ + --rdzv-endpoint $HEAD_NODE_IP:29500 \ + /scripts/bench_cugraph_training.py \ + --output_file "/logs/output.txt" \ + --framework $SAMPLING_FRAMEWORK \ + --dataset_dir "/datasets" \ + --sample_dir "/samples" \ + --batch_size $BATCH_SIZE \ + --fanout $FANOUT \ + --replication_factor $REPLICATION_FACTOR \ + --num_epochs $NUM_EPOCHS \ + --use_wholegraph \ + --skip_download + + if [[ $SLURM_NODEID == 0 ]]; then - rm ${SAMPLES_DIR}/status.txt + rm ${LOGS_DIR}/status.txt fi diff --git a/benchmarks/cugraph/standalone/bulk_sampling/trainers/dgl/__init__.py b/benchmarks/cugraph/standalone/bulk_sampling/trainers/dgl/__init__.py new file mode 100644 index 00000000000..03d2a51e538 --- /dev/null +++ b/benchmarks/cugraph/standalone/bulk_sampling/trainers/dgl/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .trainers_dgl import DGLTrainer +from .trainers_cugraph_dgl import DGLCuGraphTrainer diff --git a/benchmarks/cugraph/standalone/bulk_sampling/trainers/dgl/trainers_cugraph_dgl.py b/benchmarks/cugraph/standalone/bulk_sampling/trainers/dgl/trainers_cugraph_dgl.py new file mode 100644 index 00000000000..37745e645fd --- /dev/null +++ b/benchmarks/cugraph/standalone/bulk_sampling/trainers/dgl/trainers_cugraph_dgl.py @@ -0,0 +1,315 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import time +import re + +from .trainers_dgl import DGLTrainer +from models.dgl import GraphSAGE +from datasets import Dataset + +import torch +import numpy as np +import warnings + +from torch.nn.parallel import DistributedDataParallel as ddp +from cugraph_dgl.dataloading import HomogenousBulkSamplerDataset +from cugraph.gnn import FeatureStore + +from typing import List + + +def get_dataloader( + input_file_paths: List[str], + total_num_nodes: int, + sparse_format: str, + return_type: str, +) -> torch.utils.data.DataLoader: + """ + Returns a dataloader that reads bulk samples from the given input paths. + + Parameters + ---------- + input_file_paths: List[str] + List of input parquet files containing samples. + total_num_nodes: int + Total number of nodes in the graph. + sparse_format: str + The sparse format to read (i.e. coo) + return_type: str + The type of object to be returned by the dataloader (i.e. dgl.Block) + + Returns + ------- + torch.utils.data.DataLoader + """ + + print("Creating dataloader", flush=True) + st = time.time() + if len(input_file_paths) > 0: + dataset = HomogenousBulkSamplerDataset( + total_num_nodes, + edge_dir="in", + sparse_format=sparse_format, + return_type=return_type, + ) + dataset.set_input_files(input_file_paths=input_file_paths) + dataloader = torch.utils.data.DataLoader( + dataset, + collate_fn=lambda x: x, + shuffle=False, + num_workers=0, + batch_size=None, + ) + et = time.time() + print(f"Time to create dataloader = {et - st:.2f} seconds", flush=True) + return dataloader + else: + return [] + + +class DGLCuGraphTrainer(DGLTrainer): + """ + Trainer implementation for cuGraph-DGL that supports + WholeGraph as a feature store. + """ + + def __init__( + self, + dataset: Dataset, + model: str = "GraphSAGE", + device: int = 0, + rank: int = 0, + world_size: int = 1, + gpus_per_node: int = 1, + num_epochs: int = 1, + sample_dir: str = ".", + backend: str = "torch", + **kwargs, + ): + """ + Parameters + ---------- + dataset: Dataset + The dataset to train on. + model: str + The model to use for training. + Currently only "GraphSAGE" is supported. + device: int, default=0 + The CUDA device to use. + rank: int, default=0 + The global rank of the worker this trainer is assigned to. + world_size: int, default=1 + The number of workers in the world. + num_epochs: int, default=1 + The number of training epochs to run. + sample_dir: str, default="." + The directory where samples generated by the bulk sampler + are stored. + backend: str, default="torch" + The feature store backend to be used by the cuGraph Feature Store. + Defaults to "torch". Options are "torch" and "wholegraph" + kwargs + Keyword arguments to pass to the loader + """ + self.__data = None + self.__device = device + self.__rank = rank + self.__world_size = world_size + self.__gpus_per_node = gpus_per_node + self.__num_epochs = num_epochs + self.__dataset = dataset + self.__sample_dir = sample_dir + self.__loader_kwargs = kwargs + self.__model = self.get_model(model) + self.__optimizer = None + self.__backend = backend + + @property + def rank(self): + return self.__rank + + @property + def model(self): + return self.__model + + @property + def dataset(self): + return self.__dataset + + @property + def optimizer(self): + if self.__optimizer is None: + self.__optimizer = torch.optim.Adam( + self.model.parameters(), lr=0.01, weight_decay=0.0005 + ) + return self.__optimizer + + @property + def num_epochs(self) -> int: + return self.__num_epochs + + def get_loader(self, epoch: int = 0, stage="train") -> int: + # TODO support online sampling + if stage == "train": + path = os.path.join(self.__sample_dir, f"epoch={epoch}", stage, "samples") + elif stage in ["test", "val"]: + path = os.path.join(self.__sample_dir, stage, "samples") + else: + raise ValueError(f"Invalid stage {stage}") + + input_file_paths, num_batches = self.get_input_files( + path, epoch=epoch, stage=stage + ) + + dataloader = get_dataloader( + input_file_paths=input_file_paths.tolist(), + total_num_nodes=None, + sparse_format="csc", + return_type="cugraph_dgl.nn.SparseGraph", + ) + return dataloader, num_batches + + @property + def data(self): + import logging + + logger = logging.getLogger("DGLCuGraphTrainer") + logger.info("getting data") + + if self.__data is None: + logger.info("using wholegraph backend") + if self.__backend == "wholegraph": + fs = FeatureStore( + backend="wholegraph", + wg_type="chunked", + wg_location="cpu", + ) + else: + fs = FeatureStore(backend=self.__backend) + num_nodes_dict = {} + + if self.__backend == "wholegraph": + from pylibwholegraph.torch.initialize import get_global_communicator + + wm_comm = get_global_communicator() + wm_comm.barrier() + + for node_type, x in self.__dataset.x_dict.items(): + logger.debug(f"getting x for {node_type}") + fs.add_data(x, node_type, "x") + num_nodes_dict[node_type] = self.__dataset.num_nodes(node_type) + if self.__backend == "wholegraph": + wm_comm.barrier() + + for node_type, y in self.__dataset.y_dict.items(): + logger.debug(f"getting y for {node_type}") + if self.__backend == "wholegraph": + logger.info("using wholegraph backend") + fs.add_data(y, node_type, "y") + wm_comm.barrier() + else: + y = y.cuda() + y = y.reshape((y.shape[0], 1)) + fs.add_data(y, node_type, "y") + + """ + for node_type, train in self.__dataset.train_dict.items(): + logger.debug(f"getting train for {node_type}") + train = train.reshape((train.shape[0], 1)) + if self.__backend != "wholegraph": + train = train.cuda() + fs.add_data(train, node_type, "train") + + for node_type, test in self.__dataset.test_dict.items(): + logger.debug(f"getting test for {node_type}") + test = test.reshape((test.shape[0], 1)) + if self.__backend != "wholegraph": + test = test.cuda() + fs.add_data(test, node_type, "test") + + for node_type, val in self.__dataset.val_dict.items(): + logger.debug(f"getting val for {node_type}") + val = val.reshape((val.shape[0], 1)) + if self.__backend != "wholegraph": + val = val.cuda() + fs.add_data(val, node_type, "val") + """ + + # # TODO support online sampling if the edge index is provided + # num_edges_dict = self.__dataset.edge_index_dict + # if not isinstance(list(num_edges_dict.values())[0], int): + # num_edges_dict = {k: len(v) for k, v in num_edges_dict} + + if self.__backend == "wholegraph": + wm_comm.barrier() + + self.__data = fs + return self.__data + + def get_model(self, name="GraphSAGE"): + if name != "GraphSAGE": + raise ValueError("only GraphSAGE is currently supported") + + num_input_features = self.__dataset.num_input_features + num_output_features = self.__dataset.num_labels + num_layers = len(self.__loader_kwargs["num_neighbors"]) + + with torch.cuda.device(self.__device): + model = ( + GraphSAGE( + in_channels=num_input_features, + hidden_channels=64, + out_channels=num_output_features, + num_layers=num_layers, + model_backend="cugraph_dgl", + ) + .to(torch.float32) + .to(self.__device) + ) + # TODO: Fix for distributed models + if torch.distributed.is_initialized(): + model = ddp(model, device_ids=[self.__device]) + else: + warnings.warn("Distributed training is not available") + print("done creating model") + + return model + + def get_input_files(self, path, epoch=0, stage="train"): + file_list = np.array([f.path for f in os.scandir(path)]) + file_list.sort() + np.random.seed(epoch) + np.random.shuffle(file_list) + + splits = np.array_split(file_list, self.__gpus_per_node) + + ex = re.compile(r"batch=([0-9]+)\-([0-9]+).parquet") + num_batches = min( + [ + sum( + [ + int(ex.match(fname.split("/")[-1])[2]) + - int(ex.match(fname.split("/")[-1])[1]) + for fname in s + ] + ) + for s in splits + ] + ) + if num_batches == 0: + raise ValueError( + f"Too few batches for training with world size {self.__world_size}" + ) + + return splits[self.__device], num_batches diff --git a/benchmarks/cugraph/standalone/bulk_sampling/trainers/dgl/trainers_dgl.py b/benchmarks/cugraph/standalone/bulk_sampling/trainers/dgl/trainers_dgl.py new file mode 100644 index 00000000000..fad986257b2 --- /dev/null +++ b/benchmarks/cugraph/standalone/bulk_sampling/trainers/dgl/trainers_dgl.py @@ -0,0 +1,361 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import logging +import torch +import torch.distributed as td +import torch.nn.functional as F +from torchmetrics import Accuracy +from trainers import Trainer +import time + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from cugraph.gnn import FeatureStore + + +def get_features(input_nodes, output_nodes, feature_store, key="paper"): + if isinstance(input_nodes, dict): + input_nodes = input_nodes[key] + if isinstance(output_nodes, dict): + output_nodes = output_nodes[key] + + # TODO: Fix below + # Adding based on assumption that cpu features + # and gpu index is not supported yet + + if feature_store.backend == "torch": + input_nodes = input_nodes.to("cpu") + output_nodes = output_nodes.to("cpu") + + x = feature_store.get_data(indices=input_nodes, type_name=key, feat_name="x") + y = feature_store.get_data(indices=output_nodes, type_name=key, feat_name="y") + y = y.reshape((y.shape[0],)) + return x, y + + +def log_batch( + logger: logging.Logger, + iter_i: int, + num_batches: int, + time_forward: int, + time_backward: int, + time_start: int, + loader_time_iter: int, + epoch: int, + rank: int, +): + """ + Logs the current performance of the trainer. + + Parameters + ---------- + logger: logging.Logger + The logger to use for logging the performance details. + iter_i: int + The current training iteration. + num_batches: int + The number of batches processed so far + time_forward: int + The total amount of time for the model forward pass so far + time_backward: int + The total amount of the for the model backwards pass so far + time_start: int + The time at which training was started + loader_time_iter: int + The time taken by the loader in the current iteraiton + epoch: int + The current training epoch + rank: int + The global rank of this worker + + Returns + ------- + None + """ + + time_forward_iter = time_forward / num_batches + time_backward_iter = time_backward / num_batches + total_time_iter = (time.perf_counter() - time_start) / num_batches + logger.info(f"epoch {epoch}, iteration {iter_i}, rank {rank}") + logger.info(f"time forward: {time_forward_iter}") + logger.info(f"time backward: {time_backward_iter}") + logger.info(f"loader time: {loader_time_iter}") + logger.info(f"total time: {total_time_iter}") + + +def train_epoch( + model, + optimizer, + loader, + feature_store, + epoch, + num_classes, + time_d, + logger, + rank, + max_num_batches, +): + """ + Train the model for one epoch. + model: The model to train. + optimizer: The optimizer to use. + loader: The loader to use. + data: cuGraph.gnn.FeatueStore + epoch: The epoch number. + num_classes: The number of classes. + time_d: A dictionary of times. + logger: The logger to use. + rank: Global rank + max_num_batches: Number of batches after which to quit (to avoid hang due to asymmetry) + """ + model = model.train() + time_feature_indexing = time_d["time_feature_indexing"] + time_feature_transfer = time_d["time_feature_transfer"] + time_forward = time_d["time_forward"] + time_backward = time_d["time_backward"] + time_loader = time_d["time_loader"] + + time_start = time.perf_counter() + end_time_backward = time.perf_counter() + + num_batches = 0 + + for iter_i, (input_nodes, output_nodes, blocks) in enumerate(loader): + loader_time_iter = time.perf_counter() - end_time_backward + time_loader += loader_time_iter + feature_indexing_time_start = time.perf_counter() + x, y_true = get_features(input_nodes, output_nodes, feature_store=feature_store) + additional_feature_time_end = time.perf_counter() + time_feature_indexing += ( + additional_feature_time_end - feature_indexing_time_start + ) + feature_trasfer_time_start = time.perf_counter() + x = x.to("cuda") + y_true = y_true.to("cuda") + time_feature_transfer += time.perf_counter() - feature_trasfer_time_start + num_batches += 1 + + start_time_forward = time.perf_counter() + y_pred = model( + blocks, + x, + ) + end_time_forward = time.perf_counter() + time_forward += end_time_forward - start_time_forward + + if y_pred.shape[0] > len(y_true): + raise ValueError(f"illegal shape: {y_pred.shape}; {y_true.shape}") + + y_true = y_true[: y_pred.shape[0]] + y_true = F.one_hot( + y_true.to(torch.int64), + num_classes=num_classes, + ).to(torch.float32) + + if y_true.shape != y_pred.shape: + raise ValueError( + f"y_true shape was {y_true.shape} " + f"but y_pred shape was {y_pred.shape} " + f"in iteration {iter_i} " + f"on rank {y_pred.device.index}" + ) + + start_time_backward = time.perf_counter() + loss = F.cross_entropy(y_pred, y_true) + optimizer.zero_grad() + loss.backward() + optimizer.step() + end_time_backward = time.perf_counter() + time_backward += end_time_backward - start_time_backward + + if iter_i % 50 == 0: + log_batch( + logger=logger, + iter_i=iter_i, + num_batches=num_batches, + time_forward=time_forward, + time_backward=time_backward, + time_start=time_start, + loader_time_iter=loader_time_iter, + epoch=epoch, + rank=rank, + ) + + if max_num_batches is not None and iter_i >= max_num_batches: + break + + time_d["time_loader"] += time_loader + time_d["time_feature_indexing"] += time_feature_indexing + time_d["time_feature_transfer"] += time_feature_transfer + time_d["time_forward"] += time_forward + time_d["time_backward"] += time_backward + + return num_batches + + +def get_accuracy( + model: torch.nn.Module, + loader: torch.utils.DataLoader, + feature_store: FeatureStore, + num_classes: int, + max_num_batches: int, +) -> float: + """ + Computes the accuracy given a loader that ouputs evaluation data, the model being evaluated, + the feature store where node features are stored, and the number of output classes. + + Parameters + ---------- + model: torch.nn.Module + The model being evaluated + loader: torch.utils.DataLoader + The loader over evaluation samples + feature_store: cugraph.gnn.FeatureStore + The feature store containing node features + num_classes: int + The number of output classes of the model + max_num_batches: int + The number of batches to iterate for, will quit after reaching this number. + Used to avoid hang due to asymmetric input. + + Returns + ------- + float + The calcuated accuracy, as a percentage. + + """ + + print("Computing accuracy...", flush=True) + acc = Accuracy(task="multiclass", num_classes=num_classes).cuda() + acc_sum = 0.0 + num_batches = 0 + with torch.no_grad(): + for iter_i, (input_nodes, output_nodes, blocks) in enumerate(loader): + x, y_true = get_features( + input_nodes, output_nodes, feature_store=feature_store + ) + x = x.to("cuda") + y_true = y_true.to("cuda") + + out = model(blocks, x) + batch_size = out.shape[0] + acc_sum += acc(out[:batch_size].softmax(dim=-1), y_true[:batch_size]) + num_batches += 1 + + if max_num_batches is not None and iter_i >= max_num_batches: + break + + num_batches = num_batches + + acc_sum = torch.tensor(float(acc_sum), dtype=torch.float32, device="cuda") + td.all_reduce(acc_sum, op=td.ReduceOp.SUM) + nb = torch.tensor(float(num_batches), dtype=torch.float32, device=acc_sum.device) + td.all_reduce(nb, op=td.ReduceOp.SUM) + + acc = acc_sum / nb + + print( + f"Accuracy: {acc * 100.0:.4f}%", + ) + return acc * 100.0 + + +class DGLTrainer(Trainer): + """ + Trainer implementation for node classification in DGL. + """ + + def train(self): + logger = logging.getLogger("DGLTrainer") + time_d = { + "time_loader": 0.0, + "time_feature_indexing": 0.0, + "time_feature_transfer": 0.0, + "time_forward": 0.0, + "time_backward": 0.0, + } + total_batches = 0 + for epoch in range(self.num_epochs): + start_time = time.perf_counter() + self.model.train() + with td.algorithms.join.Join( + [self.model], divide_by_initial_world_size=False + ): + loader, max_num_batches = self.get_loader(epoch=epoch, stage="train") + num_batches = train_epoch( + model=self.model, + optimizer=self.optimizer, + loader=loader, + feature_store=self.data, + num_classes=self.dataset.num_labels, + epoch=epoch, + time_d=time_d, + logger=logger, + rank=self.rank, + max_num_batches=max_num_batches, + ) + total_batches = total_batches + num_batches + end_time = time.perf_counter() + epoch_time_taken = end_time - start_time + print( + f"RANK: {self.rank} Total time taken for training epoch {epoch} = {epoch_time_taken}", + flush=True, + ) + print("---" * 30) + td.barrier() + self.model.eval() + with td.algorithms.join.Join( + [self.model], divide_by_initial_world_size=False + ): + # test + loader, max_num_batches = self.get_loader(epoch=epoch, stage="test") + test_acc = get_accuracy( + model=self.model.module, + loader=loader, + feature_store=self.data, + num_classes=self.dataset.num_labels, + max_num_batches=max_num_batches, + ) + print(f"Accuracy: {test_acc:.4f}%") + + # val: + self.model.eval() + with td.algorithms.join.Join([self.model], divide_by_initial_world_size=False): + loader, max_num_batches = self.get_loader(epoch=epoch, stage="val") + val_acc = get_accuracy( + model=self.model.module, + loader=loader, + feature_store=self.data, + num_classes=self.dataset.num_labels, + max_num_batches=max_num_batches, + ) + print(f"Validation Accuracy: {val_acc:.4f}%") + + val_acc = float(val_acc) + stats = { + "Accuracy": val_acc, + "# Batches": total_batches, + "Loader Time": time_d["time_loader"], + "Feature Time": time_d["time_feature_indexing"] + + time_d["time_feature_transfer"], + "Forward Time": time_d["time_forward"], + "Backward Time": time_d["time_backward"], + } + return stats + + +# For native DGL training, see benchmarks/cugraph-dgl/scale-benchmarks diff --git a/benchmarks/cugraph/standalone/bulk_sampling/trainers/pyg/trainers_cugraph_pyg.py b/benchmarks/cugraph/standalone/bulk_sampling/trainers/pyg/trainers_cugraph_pyg.py index 71151e9ba59..833322deffe 100644 --- a/benchmarks/cugraph/standalone/bulk_sampling/trainers/pyg/trainers_cugraph_pyg.py +++ b/benchmarks/cugraph/standalone/bulk_sampling/trainers/pyg/trainers_cugraph_pyg.py @@ -13,41 +13,84 @@ from .trainers_pyg import PyGTrainer from models.pyg import CuGraphSAGE +from datasets import Dataset import torch import numpy as np from torch.nn.parallel import DistributedDataParallel as ddp +from torch.distributed.optim import ZeroRedundancyOptimizer from cugraph.gnn import FeatureStore from cugraph_pyg.data import CuGraphStore from cugraph_pyg.loader import BulkSampleLoader import os +import re class PyGCuGraphTrainer(PyGTrainer): + """ + Trainer implementation for cuGraph-PyG that supports + WholeGraph as a feature store. + """ + def __init__( self, - dataset, - model="GraphSAGE", - device=0, - rank=0, - world_size=1, - num_epochs=1, - sample_dir=".", + dataset: Dataset, + model: str = "GraphSAGE", + device: int = 0, + rank: int = 0, + world_size: int = 1, + gpus_per_node: int = 1, + num_epochs: int = 1, + sample_dir: str = ".", + backend: str = "torch", **kwargs, ): + """ + Parameters + ---------- + dataset: Dataset + The dataset to train on. + model: str + The model to use for training. + Currently only "GraphSAGE" is supported. + device: int, default=0 + The CUDA device to use. + rank: int, default=0 + The global rank of the worker this trainer is assigned to. + world_size: int, default=1 + The number of workers in the world. + num_epochs: int, default=1 + The number of training epochs to run. + sample_dir: str, default="." + The directory where samples generated by the bulk sampler + are stored. + backend: str, default="torch" + The feature store backend to be used by the cuGraph Feature Store. + Defaults to "torch". Options are "torch" and "wholegraph" + kwargs + Keyword arguments to pass to the loader. + """ + + import logging + + logger = logging.getLogger("PyGCuGraphTrainer") + logger.info("creating trainer") self.__data = None self.__device = device self.__rank = rank self.__world_size = world_size + self.__gpus_per_node = gpus_per_node self.__num_epochs = num_epochs self.__dataset = dataset self.__sample_dir = sample_dir self.__loader_kwargs = kwargs self.__model = self.get_model(model) + self.__backend = backend self.__optimizer = None + logger.info("created trainer") @property def rank(self): @@ -64,8 +107,11 @@ def dataset(self): @property def optimizer(self): if self.__optimizer is None: - self.__optimizer = torch.optim.Adam( - self.model.parameters(), lr=0.01, weight_decay=0.0005 + self.__optimizer = ZeroRedundancyOptimizer( + self.model.parameters(), + lr=0.01, + weight_decay=0.0005, + optimizer_class=torch.optim.Adam, ) return self.__optimizer @@ -73,7 +119,7 @@ def optimizer(self): def num_epochs(self) -> int: return self.__num_epochs - def get_loader(self, epoch: int = 0, stage="train") -> int: + def get_loader(self, epoch: int = 0, stage="train"): import logging logger = logging.getLogger("PyGCuGraphTrainer") @@ -81,22 +127,25 @@ def get_loader(self, epoch: int = 0, stage="train") -> int: logger.info(f"getting loader for epoch {epoch}, {stage} stage") # TODO support online sampling - if stage == "val": - path = os.path.join(self.__sample_dir, "val", "samples") - else: + if stage == "train": path = os.path.join(self.__sample_dir, f"epoch={epoch}", stage, "samples") + elif stage in ["test", "val"]: + path = os.path.join(self.__sample_dir, stage, "samples") + else: + raise ValueError(f"invalid stage {stage}") + input_files, num_batches = self.get_input_files(path, epoch=epoch, stage=stage) loader = BulkSampleLoader( self.data, self.data, None, # FIXME get input nodes properly directory=path, - input_files=self.get_input_files(path, epoch=epoch, stage=stage), + input_files=input_files, **self.__loader_kwargs, ) logger.info(f"got loader successfully on rank {self.rank}") - return loader + return loader, num_batches @property def data(self): @@ -106,36 +155,73 @@ def data(self): logger.info("getting data") if self.__data is None: - # FIXME wholegraph - fs = FeatureStore(backend="torch") + if self.__backend == "wholegraph": + logger.info("using wholegraph backend") + fs = FeatureStore( + backend="wholegraph", + wg_type="chunked", + wg_location="cpu", + ) + else: + fs = FeatureStore(backend=self.__backend) num_nodes_dict = {} + if self.__backend == "wholegraph": + from pylibwholegraph.torch.initialize import get_global_communicator + + wm_comm = get_global_communicator() + wm_comm.barrier() + for node_type, x in self.__dataset.x_dict.items(): logger.debug(f"getting x for {node_type}") fs.add_data(x, node_type, "x") num_nodes_dict[node_type] = self.__dataset.num_nodes(node_type) + if self.__backend == "wholegraph": + wm_comm.barrier() for node_type, y in self.__dataset.y_dict.items(): logger.debug(f"getting y for {node_type}") - fs.add_data(y, node_type, "y") + if self.__backend == "wholegraph": + logger.info("using wholegraph backend") + fs.add_data(y, node_type, "y") + wm_comm.barrier() + else: + y = y.cuda() + y = y.reshape((y.shape[0], 1)) + fs.add_data(y, node_type, "y") + + """ for node_type, train in self.__dataset.train_dict.items(): logger.debug(f"getting train for {node_type}") + train = train.reshape((train.shape[0], 1)) + if self.__backend != "wholegraph": + train = train.cuda() fs.add_data(train, node_type, "train") for node_type, test in self.__dataset.test_dict.items(): logger.debug(f"getting test for {node_type}") + test = test.reshape((test.shape[0], 1)) + if self.__backend != "wholegraph": + test = test.cuda() fs.add_data(test, node_type, "test") for node_type, val in self.__dataset.val_dict.items(): logger.debug(f"getting val for {node_type}") + val = val.reshape((val.shape[0], 1)) + if self.__backend != "wholegraph": + val = val.cuda() fs.add_data(val, node_type, "val") + """ # TODO support online sampling if the edge index is provided num_edges_dict = self.__dataset.edge_index_dict if not isinstance(list(num_edges_dict.values())[0], int): num_edges_dict = {k: len(v) for k, v in num_edges_dict} + if self.__backend == "wholegraph": + wm_comm.barrier() + self.__data = CuGraphStore( fs, num_edges_dict, @@ -147,14 +233,28 @@ def data(self): return self.__data def get_model(self, name="GraphSAGE"): + import logging + + logger = logging.getLogger("PyGCuGraphTrainer") + + logger.info("Creating model...") + if name != "GraphSAGE": raise ValueError("only GraphSAGE is currently supported") + logger.info("getting input features...") num_input_features = self.__dataset.num_input_features + + logger.info("getting output features...") num_output_features = self.__dataset.num_labels + + logger.info("getting num neighbors...") num_layers = len(self.__loader_kwargs["num_neighbors"]) + logger.info("Got input features, output features, num neighbors") + with torch.cuda.device(self.__device): + logger.info("Constructing CuGraphSAGE model...") model = ( CuGraphSAGE( in_channels=num_input_features, @@ -166,8 +266,10 @@ def get_model(self, name="GraphSAGE"): .to(self.__device) ) + logger.info("Parallelizing model with ddp...") model = ddp(model, device_ids=[self.__device]) - print("done creating model") + + logger.info("done creating model") return model @@ -175,10 +277,28 @@ def get_input_files(self, path, epoch=0, stage="train"): file_list = np.array(os.listdir(path)) file_list.sort() - if stage == "train": - splits = np.array_split(file_list, self.__world_size) - np.random.seed(epoch) - np.random.shuffle(splits) - return splits[self.rank] - else: - return file_list + np.random.seed(epoch) + np.random.shuffle(file_list) + + splits = np.array_split(file_list, self.__gpus_per_node) + + import logging + + logger = logging.getLogger("PyGCuGraphTrainer") + + split = splits[self.__device] + logger.info(f"rank {self.__rank} input files: {str(split)}") + + ex = re.compile(r"batch=([0-9]+)\-([0-9]+).parquet") + num_batches = min( + [ + sum([int(ex.match(fname)[2]) - int(ex.match(fname)[1]) for fname in s]) + for s in splits + ] + ) + if num_batches == 0: + raise ValueError( + f"Too few batches for training with world size {self.__world_size}" + ) + + return split, num_batches diff --git a/benchmarks/cugraph/standalone/bulk_sampling/trainers/pyg/trainers_pyg.py b/benchmarks/cugraph/standalone/bulk_sampling/trainers/pyg/trainers_pyg.py index bddd6ae2644..d6205901b68 100644 --- a/benchmarks/cugraph/standalone/bulk_sampling/trainers/pyg/trainers_pyg.py +++ b/benchmarks/cugraph/standalone/bulk_sampling/trainers/pyg/trainers_pyg.py @@ -33,7 +33,12 @@ import time -def pyg_num_workers(world_size): +def pyg_num_workers(world_size: int) -> int: + """ + Calculates the number of workers for the + loader in PyG by calling sched_getaffinity. + """ + num_workers = None if hasattr(os, "sched_getaffinity"): try: @@ -45,14 +50,80 @@ def pyg_num_workers(world_size): return int(num_workers) +def calc_accuracy( + loader: NeighborLoader, + max_num_batches: int, + model: torch.nn.Module, + num_classes: int, +) -> float: + """ + Evaluates the accuracy of a model given a loader over evaluation samples. + + Parameters + ---------- + loader: NeighborLoader + The loader over evaluation samples. + model: torch.nn.Module + The model being evaluated. + num_classes: int + The number of output classes of the model. + + Returns + ------- + The calculated accuracy as a fraction. + """ + + from torchmetrics import Accuracy + + acc = Accuracy(task="multiclass", num_classes=num_classes).cuda() + + acc_sum = 0.0 + num_batches = 0 + with torch.no_grad(): + for i, batch in enumerate(loader): + num_sampled_nodes = sum( + [torch.as_tensor(n) for n in batch.num_sampled_nodes_dict.values()] + ) + num_sampled_edges = sum( + [torch.as_tensor(e) for e in batch.num_sampled_edges_dict.values()] + ) + batch_size = num_sampled_nodes[0] + + batch = batch.to_homogeneous().cuda() + + batch.y = batch.y.to(torch.long).reshape((batch.y.shape[0],)) + + out = model( + batch.x, + batch.edge_index, + num_sampled_nodes, + num_sampled_edges, + ) + acc_sum += acc(out[:batch_size].softmax(dim=-1), batch.y[:batch_size]) + num_batches += 1 + + if max_num_batches is not None and i >= max_num_batches: + break + + acc_sum = torch.tensor(float(acc_sum), dtype=torch.float32, device="cuda") + td.all_reduce(acc_sum, op=td.ReduceOp.SUM) + nb = torch.tensor(float(num_batches), dtype=torch.float32, device=acc_sum.device) + td.all_reduce(nb, op=td.ReduceOp.SUM) + + return acc_sum / nb + + class PyGTrainer(Trainer): + """ + Trainer implementation for node classification in PyG. + """ + def train(self): import logging logger = logging.getLogger("PyGTrainer") logger.info("Entered train loop") - total_loss = 0.0 num_batches = 0 time_forward = 0.0 @@ -62,19 +133,32 @@ def train(self): start_time = time.perf_counter() end_time_backward = start_time + num_layers = len(self.model.module.convs) + for epoch in range(self.num_epochs): with td.algorithms.join.Join( - [self.model], divide_by_initial_world_size=False + [self.model, self.optimizer], divide_by_initial_world_size=False ): self.model.train() - for iter_i, data in enumerate( - self.get_loader(epoch=epoch, stage="train") - ): + loader, max_num_batches = self.get_loader(epoch=epoch, stage="train") + + max_num_batches = torch.tensor([max_num_batches], device="cuda") + torch.distributed.all_reduce( + max_num_batches, op=torch.distributed.ReduceOp.MIN + ) + max_num_batches = int(max_num_batches[0]) + + for iter_i, data in enumerate(loader): loader_time_iter = time.perf_counter() - end_time_backward time_loader += loader_time_iter time_feature_transfer_start = time.perf_counter() + if len(data.edge_index_dict[("paper", "cites", "paper")][0]) < 3: + logger.error(f"Invalid edge index in iteration {iter_i}") + data = old_data + + old_data = data num_sampled_nodes = sum( [ torch.as_tensor(n) @@ -89,7 +173,6 @@ def train(self): ) # FIXME find a way to get around this and not have to call extend_tensor - num_layers = len(self.model.module.convs) num_sampled_nodes = extend_tensor(num_sampled_nodes, num_layers + 1) num_sampled_edges = extend_tensor(num_sampled_edges, num_layers) @@ -118,7 +201,12 @@ def train(self): ) logger.info(f"total time: {total_time_iter}") + # from pynvml.smi import nvidia_smi + # mem_info = nvidia_smi.getInstance().DeviceQuery('memory.free, memory.total')['gpu'][self.rank % 8]['fb_memory_usage'] + # logger.info(f"rank {self.rank} memory: {mem_info}") + y_true = data.y + y_true = y_true.reshape((y_true.shape[0],)) x = data.x.to(torch.float32) start_time_forward = time.perf_counter() @@ -160,101 +248,48 @@ def train(self): self.optimizer.zero_grad() loss.backward() self.optimizer.step() - total_loss += loss.item() end_time_backward = time.perf_counter() time_backward += end_time_backward - start_time_backward - end_time = time.perf_counter() - - # test - from torchmetrics import Accuracy + if max_num_batches is not None and iter_i >= max_num_batches: + break - acc = Accuracy( - task="multiclass", num_classes=self.dataset.num_labels - ).cuda() + end_time = time.perf_counter() + """ + logger.info("Entering test stage...") with td.algorithms.join.Join( [self.model], divide_by_initial_world_size=False ): self.model.eval() - if self.rank == 0: - acc_sum = 0.0 - with torch.no_grad(): - for i, batch in enumerate( - self.get_loader(epoch=epoch, stage="test") - ): - num_sampled_nodes = sum( - [ - torch.as_tensor(n) - for n in batch.num_sampled_nodes_dict.values() - ] - ) - num_sampled_edges = sum( - [ - torch.as_tensor(e) - for e in batch.num_sampled_edges_dict.values() - ] - ) - batch_size = num_sampled_nodes[0] - - batch = batch.to_homogeneous().cuda() - - batch.y = batch.y.to(torch.long) - out = self.model.module( - batch.x, - batch.edge_index, - num_sampled_nodes, - num_sampled_edges, - ) - acc_sum += acc( - out[:batch_size].softmax(dim=-1), batch.y[:batch_size] - ) - print( - f"Accuracy: {acc_sum/(i) * 100.0:.4f}%", - ) + loader, max_num_batches = self.get_loader(epoch=epoch, stage="test") + num_classes = self.dataset.num_labels - td.barrier() + acc = calc_accuracy( + loader, max_num_batches, self.model.module, num_classes + ) - with td.algorithms.join.Join([self.model], divide_by_initial_world_size=False): - self.model.eval() if self.rank == 0: - acc_sum = 0.0 - with torch.no_grad(): - for i, batch in enumerate( - self.get_loader(epoch=epoch, stage="val") - ): - num_sampled_nodes = sum( - [ - torch.as_tensor(n) - for n in batch.num_sampled_nodes_dict.values() - ] - ) - num_sampled_edges = sum( - [ - torch.as_tensor(e) - for e in batch.num_sampled_edges_dict.values() - ] - ) - batch_size = num_sampled_nodes[0] - - batch = batch.to_homogeneous().cuda() - - batch.y = batch.y.to(torch.long) - out = self.model.module( - batch.x, - batch.edge_index, - num_sampled_nodes, - num_sampled_edges, - ) - acc_sum += acc( - out[:batch_size].softmax(dim=-1), batch.y[:batch_size] - ) print( - f"Validation Accuracy: {acc_sum/(i) * 100.0:.4f}%", + f"Accuracy: {acc * 100.0:.4f}%", ) + """ + + """ + logger.info("Entering validation stage") + with td.algorithms.join.Join([self.model], divide_by_initial_world_size=False): + self.model.eval() + loader, max_num_batches = self.get_loader(epoch=epoch, stage="val") + num_classes = self.dataset.num_labels + acc = calc_accuracy(loader, max_num_batches, self.model.module, num_classes) + + if self.rank == 0: + print( + f"Validation Accuracy: {acc * 100.0:.4f}%", + ) + """ stats = { - "Accuracy": float(acc_sum / (i) * 100.0) if self.rank == 0 else 0.0, "# Batches": num_batches, "Loader Time": time_loader, "Feature Transfer Time": time_feature_transfer, @@ -265,6 +300,12 @@ def train(self): class PyGNativeTrainer(PyGTrainer): + """ + Trainer implementation for native PyG + training using HeteroData as the graph and feature + store and NeighborLoader as the loader. + """ + def __init__( self, dataset, @@ -403,7 +444,7 @@ def get_loader(self, epoch: int = 0, stage="train"): ) logger.info("done creating loader") - return loader + return loader, None def get_model(self, name="GraphSAGE"): if name != "GraphSAGE": diff --git a/python/cugraph-dgl/cugraph_dgl/dataloading/dataset.py b/python/cugraph-dgl/cugraph_dgl/dataloading/dataset.py index 815fd30d8eb..f6fe38fe9f8 100644 --- a/python/cugraph-dgl/cugraph_dgl/dataloading/dataset.py +++ b/python/cugraph-dgl/cugraph_dgl/dataloading/dataset.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -63,6 +63,10 @@ def __getitem__(self, idx: int): fn, batch_offset = self._batch_to_fn_d[idx] if fn != self._current_batch_fn: + # Remove current batches to free up memory + # before loading new batches + if hasattr(self, "_current_batches"): + del self._current_batches if self.sparse_format == "csc": df = _load_sampled_file(dataset_obj=self, fn=fn, skip_rename=True) self._current_batches = ( diff --git a/python/cugraph-pyg/cugraph_pyg/data/cugraph_store.py b/python/cugraph-pyg/cugraph_pyg/data/cugraph_store.py index 05d540b7c45..df16fc9fd6c 100644 --- a/python/cugraph-pyg/cugraph_pyg/data/cugraph_store.py +++ b/python/cugraph-pyg/cugraph_pyg/data/cugraph_store.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2023, NVIDIA CORPORATION. +# Copyright (c) 2019-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -1083,13 +1083,12 @@ def _get_tensor(self, attr: CuGraphTensorAttr) -> TensorType: idx = attr.index if idx is not None: - if feature_backend == "torch": + if feature_backend in ["torch", "wholegraph"]: if not isinstance(idx, torch.Tensor): raise TypeError( f"Type {type(idx)} invalid" f" for feature store backend {feature_backend}" ) - idx = idx.cpu() elif feature_backend == "numpy": # allow feature indexing through cupy arrays if isinstance(idx, cupy.ndarray): @@ -1244,5 +1243,77 @@ def _infer_unspecified_attr(self, attr: CuGraphTensorAttr) -> CuGraphTensorAttr: return attr + def filter( + self, + format: str, + node_dict: Dict[str, torch.Tensor], + row_dict: Dict[str, torch.Tensor], + col_dict: Dict[str, torch.Tensor], + edge_dict: Dict[str, Tuple[torch.Tensor]], + ) -> torch_geometric.data.HeteroData: + """ + Parameters + ---------- + format: str + COO or CSC + node_dict: Dict[str, torch.Tensor] + IDs of nodes in original store being outputted + row_dict: Dict[str, torch.Tensor] + Renumbered output edge index row + col_dict: Dict[str, torch.Tensor] + Renumbered output edge index column + edge_dict: Dict[str, Tuple[torch.Tensor]] + Currently unused original edge mapping + """ + data = torch_geometric.data.HeteroData() + + # TODO use torch_geometric.EdgeIndex in release 24.04 (Issue #4051) + for attr in self.get_all_edge_attrs(): + key = attr.edge_type + if key in row_dict and key in col_dict: + if format == "CSC": + data.put_edge_index( + (row_dict[key], col_dict[key]), + edge_type=key, + layout="csc", + is_sorted=True, + ) + else: + data[key].edge_index = torch.stack( + [ + row_dict[key], + col_dict[key], + ], + dim=0, + ) + + required_attrs = [] + # To prevent copying multiple times, we use a cache; + # the original node_dict serves as the gpu cache if needed + node_dict_cpu = {} + for attr in self.get_all_tensor_attrs(): + if attr.group_name in node_dict: + device = self.__features.get_storage(attr.group_name, attr.attr_name) + attr.index = node_dict[attr.group_name] + if not isinstance(attr.index, torch.Tensor): + raise ValueError("Node index must be a tensor!") + if attr.index.is_cuda and device == "cpu": + if attr.group_name not in node_dict_cpu: + node_dict_cpu[attr.group_name] = attr.index.cpu() + attr.index = node_dict_cpu[attr.group_name] + elif attr.index.is_cpu and device == "cuda": + node_dict_cpu[attr.group_name] = attr.index + node_dict[attr.group_name] = attr.index.cuda() + attr.index = node_dict[attr.group_name] + + required_attrs.append(attr) + data[attr.group_name].num_nodes = attr.index.size(0) + + tensors = self.multi_get_tensor(required_attrs) + for i, attr in enumerate(required_attrs): + data[attr.group_name][attr.attr_name] = tensors[i] + + return data + def __len__(self): return len(self.get_all_tensor_attrs()) diff --git a/python/cugraph-pyg/cugraph_pyg/loader/cugraph_node_loader.py b/python/cugraph-pyg/cugraph_pyg/loader/cugraph_node_loader.py index bcfaf579820..55c9e9b3329 100644 --- a/python/cugraph-pyg/cugraph_pyg/loader/cugraph_node_loader.py +++ b/python/cugraph-pyg/cugraph_pyg/loader/cugraph_node_loader.py @@ -28,7 +28,6 @@ _sampler_output_from_sampling_results_heterogeneous, _sampler_output_from_sampling_results_homogeneous_csr, _sampler_output_from_sampling_results_homogeneous_coo, - filter_cugraph_store_csc, ) from typing import Union, Tuple, Sequence, List, Dict @@ -454,31 +453,20 @@ def __next__(self): start_time_feature = perf_counter() # Create a PyG HeteroData object, loading the required features - if self.__coo: - pyg_filter_fn = ( - torch_geometric.loader.utils.filter_custom_hetero_store - if hasattr(torch_geometric.loader.utils, "filter_custom_hetero_store") - else torch_geometric.loader.utils.filter_custom_store - ) - out = pyg_filter_fn( - self.__feature_store, - self.__graph_store, - sampler_output.node, - sampler_output.row, - sampler_output.col, - sampler_output.edge, - ) - else: - out = filter_cugraph_store_csc( - self.__feature_store, - self.__graph_store, - sampler_output.node, - sampler_output.row, - sampler_output.col, - sampler_output.edge, - ) + if self.__graph_store != self.__feature_store: + # TODO Possibly support this if there is an actual use case + raise ValueError("Separate graph and feature stores currently unsupported") + + out = self.__graph_store.filter( + "COO" if self.__coo else "CSC", + sampler_output.node, + sampler_output.row, + sampler_output.col, + sampler_output.edge, + ) # Account for CSR format in cuGraph vs. CSC format in PyG + # TODO deprecate and remove this functionality if self.__coo and self.__graph_store.order == "CSC": for edge_type in out.edge_index_dict: out[edge_type].edge_index = out[edge_type].edge_index.flip(dims=[0]) diff --git a/python/cugraph-pyg/cugraph_pyg/sampler/cugraph_sampler.py b/python/cugraph-pyg/cugraph_pyg/sampler/cugraph_sampler.py index 65cb63d25e0..ffab54efe08 100644 --- a/python/cugraph-pyg/cugraph_pyg/sampler/cugraph_sampler.py +++ b/python/cugraph-pyg/cugraph_pyg/sampler/cugraph_sampler.py @@ -411,6 +411,10 @@ def filter_cugraph_store_csc( col_dict: Dict[str, torch.Tensor], edge_dict: Dict[str, Tuple[torch.Tensor]], ) -> torch_geometric.data.HeteroData: + """ + Deprecated + """ + data = torch_geometric.data.HeteroData() for attr in graph_store.get_all_edge_attrs(): diff --git a/python/cugraph/cugraph/gnn/feature_storage/feat_storage.py b/python/cugraph/cugraph/gnn/feature_storage/feat_storage.py index 77a53882fc4..f0186220114 100644 --- a/python/cugraph/cugraph/gnn/feature_storage/feat_storage.py +++ b/python/cugraph/cugraph/gnn/feature_storage/feat_storage.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -168,19 +168,54 @@ def get_data( feat, wgth.WholeMemoryEmbedding ): indices_tensor = ( - indices + indices.cuda() if isinstance(indices, torch.Tensor) else torch.as_tensor(indices, device="cuda") ) return feat.gather(indices_tensor) - else: - return feat[indices] + elif not isinstance(torch, MissingModule) and isinstance(feat, torch.Tensor): + if indices is not None: + if not isinstance(indices, torch.Tensor): + indices = torch.as_tensor(indices) + + if feat.is_cpu and indices.is_cuda: + # TODO maybe add a warning here + indices = indices.cpu() + return feat[indices] def get_feature_list(self) -> list[str]: return {feat_name: feats.keys() for feat_name, feats in self.fd.items()} + def get_storage(self, type_name: str, feat_name: str) -> str: + """ + Returns where the data is stored (cuda, cpu). + Note: will return "cuda" for data managed by CUDA, even if + it is in host memory. + + Parameters + ---------- + type_name : str + The node-type/edge-type to store data + feat_name: + The feature name to retrieve data for + + Returns + ------- + "cuda" for data managed by CUDA, otherwise "CPU". + """ + feat = self.fd[feat_name][type_name] + if not isinstance(wgth, MissingModule) and isinstance( + feat, wgth.WholeMemoryEmbedding + ): + return "cuda" + elif isinstance(feat, torch.Tensor): + return "cpu" if feat.is_cpu else "cuda" + else: + return "cpu" + @staticmethod def _cast_feat_obj_to_backend(feat_obj, backend: str, **kwargs): + # TODO (Issue #4078) support casting WG tensors to numpy and torch if backend == "numpy": if isinstance(feat_obj, (cudf.DataFrame, pd.DataFrame)): return _cast_to_numpy_ar(feat_obj.values, **kwargs) @@ -192,6 +227,8 @@ def _cast_feat_obj_to_backend(feat_obj, backend: str, **kwargs): else: return _cast_to_torch_tensor(feat_obj, **kwargs) elif backend == "wholegraph": + if isinstance(feat_obj, wgth.WholeMemoryEmbedding): + return feat_obj return _get_wg_embedding(feat_obj, **kwargs) From c4a531d7b8daa26cec8bde715a8fd1347b764336 Mon Sep 17 00:00:00 2001 From: Robert Maynard Date: Mon, 11 Mar 2024 13:52:34 -0400 Subject: [PATCH 4/7] Mark kernels as internal (#4098) Downstream consumers of static built versions of RAPIDS C++ projects have encountered runtime issues due to multiple instances of the same kernel existing in different DSOs. To resolve this issue we need to ensure that all CUDA kernels in all RAPIDS libraries need to be have internal linkage ( static for projects using whole compilation, __attribute__((visibility("hidden"))) for header libraries / separable compilation ). This updates all cugraph kernels to have internal linkage, and adds a CI job to verify that no new kernels are added with external linkage. Authors: - Robert Maynard (https://github.com/robertmaynard) Approvers: - Brad Rees (https://github.com/BradReesWork) - Naim (https://github.com/naimnv) - Seunghwa Kang (https://github.com/seunghwak) - Jake Awe (https://github.com/AyodeAwe) URL: https://github.com/rapidsai/cugraph/pull/4098 --- .github/workflows/pr.yaml | 9 ++ .github/workflows/test.yaml | 10 ++ .../detail/decompress_edge_partition.cuh | 4 +- .../include/hash/helper_functions.cuh | 4 +- cpp/libcugraph_etl/src/renumbering.cu | 20 +-- cpp/src/community/legacy/ecg.cu | 2 +- cpp/src/components/legacy/weak_cc.cuh | 30 ++-- cpp/src/layout/legacy/bh_kernels.cuh | 136 +++++++++--------- cpp/src/layout/legacy/exact_repulsion.cuh | 16 +-- cpp/src/layout/legacy/fa2_kernels.cuh | 98 ++++++------- .../detail/extract_transform_v_frontier_e.cuh | 6 +- ...r_v_random_select_transform_outgoing_e.cuh | 4 +- ...v_transform_reduce_incoming_outgoing_e.cuh | 8 +- cpp/src/prims/transform_e.cuh | 2 +- cpp/src/prims/transform_reduce_e.cuh | 8 +- .../transform_reduce_e_by_src_dst_key.cuh | 8 +- cpp/src/structure/graph_view_impl.cuh | 4 +- .../traversal/od_shortest_distances_impl.cuh | 2 +- cpp/src/utilities/eidecl_graph_utils.hpp | 11 +- cpp/src/utilities/eidir_graph_utils.hpp | 17 +-- cpp/src/utilities/graph_utils.cuh | 49 ++++--- cpp/src/utilities/path_retrieval.cu | 14 +- 22 files changed, 243 insertions(+), 219 deletions(-) diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index 7f0b95e3573..7c8c9973462 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -15,6 +15,7 @@ jobs: - checks - conda-cpp-build - conda-cpp-tests + - conda-cpp-checks - conda-notebook-tests - conda-python-build - conda-python-tests @@ -52,6 +53,14 @@ jobs: uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-tests.yaml@branch-24.04 with: build_type: pull-request + conda-cpp-checks: + needs: conda-cpp-build + secrets: inherit + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-post-build-checks.yaml@branch-24.04 + with: + build_type: pull-request + enable_check_symbols: true + symbol_exclusions: (cugraph::ops|hornet|void writeEdgeCountsKernel|void markUniqueOffsetsKernel) conda-python-build: needs: conda-cpp-build secrets: inherit diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 32fb2d62b29..0bd095bfa94 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -14,6 +14,16 @@ on: type: string jobs: + conda-cpp-checks: + secrets: inherit + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-post-build-checks.yaml@branch-24.04 + with: + build_type: nightly + branch: ${{ inputs.branch }} + date: ${{ inputs.date }} + sha: ${{ inputs.sha }} + enable_check_symbols: true + symbol_exclusions: (cugraph::ops|hornet|void writeEdgeCountsKernel|void markUniqueOffsetsKernel) conda-cpp-tests: secrets: inherit uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-tests.yaml@branch-24.04 diff --git a/cpp/include/cugraph/detail/decompress_edge_partition.cuh b/cpp/include/cugraph/detail/decompress_edge_partition.cuh index dad5ce77e45..6b974a326dd 100644 --- a/cpp/include/cugraph/detail/decompress_edge_partition.cuh +++ b/cpp/include/cugraph/detail/decompress_edge_partition.cuh @@ -44,7 +44,7 @@ namespace detail { int32_t constexpr decompress_edge_partition_block_size = 1024; template -__global__ void decompress_to_edgelist_mid_degree( +__global__ static void decompress_to_edgelist_mid_degree( edge_partition_device_view_t edge_partition, vertex_t major_range_first, vertex_t major_range_last, @@ -74,7 +74,7 @@ __global__ void decompress_to_edgelist_mid_degree( } template -__global__ void decompress_to_edgelist_high_degree( +__global__ static void decompress_to_edgelist_high_degree( edge_partition_device_view_t edge_partition, vertex_t major_range_first, vertex_t major_range_last, diff --git a/cpp/libcugraph_etl/include/hash/helper_functions.cuh b/cpp/libcugraph_etl/include/hash/helper_functions.cuh index db377f938d2..8a11867f7e2 100644 --- a/cpp/libcugraph_etl/include/hash/helper_functions.cuh +++ b/cpp/libcugraph_etl/include/hash/helper_functions.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2022, NVIDIA CORPORATION. + * Copyright (c) 2017-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -131,7 +131,7 @@ __forceinline__ __device__ void store_pair_vectorized(pair_type* __restrict__ co } template -__global__ void init_hashtbl(value_type* __restrict__ const hashtbl_values, +__global__ static void init_hashtbl(value_type* __restrict__ const hashtbl_values, const size_type n, const key_type key_val, const elem_type elem_val) diff --git a/cpp/libcugraph_etl/src/renumbering.cu b/cpp/libcugraph_etl/src/renumbering.cu index 08759702ab4..1cbeeeeea05 100644 --- a/cpp/libcugraph_etl/src/renumbering.cu +++ b/cpp/libcugraph_etl/src/renumbering.cu @@ -270,7 +270,7 @@ __device__ __inline__ int32_t validate_ht_col_insert(volatile int32_t* ptr_col) return col; } -__global__ void concat_and_create_histogram(int8_t* col_1, +__global__ static void concat_and_create_histogram(int8_t* col_1, int32_t* offset_1, int8_t* col_2, int32_t* offset_2, @@ -349,7 +349,7 @@ __global__ void concat_and_create_histogram(int8_t* col_1, } } -__global__ void concat_and_create_histogram_2(int8_t* col_1, +__global__ static void concat_and_create_histogram_2(int8_t* col_1, int32_t* offset_1, int8_t* col_2, int32_t* offset_2, @@ -452,7 +452,7 @@ __global__ void concat_and_create_histogram_2(int8_t* col_1, } template -__global__ void set_src_vertex_idx(int8_t* col_1, +__global__ static void set_src_vertex_idx(int8_t* col_1, int32_t* offset_1, int8_t* col_2, int32_t* offset_2, @@ -509,7 +509,7 @@ __global__ void set_src_vertex_idx(int8_t* col_1, } template -__global__ void set_dst_vertex_idx(int8_t* col_1, +__global__ static void set_dst_vertex_idx(int8_t* col_1, int32_t* offset_1, int8_t* col_2, int32_t* offset_2, @@ -585,7 +585,7 @@ __global__ void set_dst_vertex_idx(int8_t* col_1, } } -__global__ void create_mapping_histogram(uint32_t* hash_value, +__global__ static void create_mapping_histogram(uint32_t* hash_value, str_hash_value* payload, cudf_map_type hash_map, accum_type count) @@ -595,7 +595,7 @@ __global__ void create_mapping_histogram(uint32_t* hash_value, if (idx < count) { auto it = hash_map.insert(thrust::make_pair(hash_value[idx], payload[idx])); } } -__global__ void assign_histogram_idx(cudf_map_type cuda_map_obj, +__global__ static void assign_histogram_idx(cudf_map_type cuda_map_obj, size_t slot_count, str_hash_value* key, uint32_t* value, @@ -621,7 +621,7 @@ __global__ void assign_histogram_idx(cudf_map_type cuda_map_obj, } } -__global__ void set_vertex_indices(str_hash_value* ht_value_payload, accum_type count) +__global__ static void set_vertex_indices(str_hash_value* ht_value_payload, accum_type count) { accum_type tid = threadIdx.x + blockIdx.x * blockDim.x; // change count_ to renumber_idx @@ -630,7 +630,7 @@ __global__ void set_vertex_indices(str_hash_value* ht_value_payload, accum_type } } -__global__ void set_output_col_offsets(str_hash_value* row_col_pair, +__global__ static void set_output_col_offsets(str_hash_value* row_col_pair, int32_t* out_col1_offset, int32_t* out_col2_offset, int dst_pair_match, @@ -653,7 +653,7 @@ __global__ void set_output_col_offsets(str_hash_value* row_col_pair, } } -__global__ void offset_buffer_size_comp(int32_t* out_col1_length, +__global__ static void offset_buffer_size_comp(int32_t* out_col1_length, int32_t* out_col2_length, int32_t* out_col1_offsets, int32_t* out_col2_offsets, @@ -673,7 +673,7 @@ __global__ void offset_buffer_size_comp(int32_t* out_col1_length, } } -__global__ void select_unrenumber_string(str_hash_value* idx_to_col_row, +__global__ static void select_unrenumber_string(str_hash_value* idx_to_col_row, int32_t total_elements, int8_t* src_col1, int8_t* src_col2, diff --git a/cpp/src/community/legacy/ecg.cu b/cpp/src/community/legacy/ecg.cu index d93a4446faa..b2ad79204ed 100644 --- a/cpp/src/community/legacy/ecg.cu +++ b/cpp/src/community/legacy/ecg.cu @@ -52,7 +52,7 @@ binsearch_maxle(const IndexType* vec, const IndexType val, IndexType low, IndexT // FIXME: This shouldn't need to be a custom kernel, this // seems like it should just be a thrust::transform template -__global__ void match_check_kernel( +__global__ static void match_check_kernel( IdxT size, IdxT num_verts, IdxT* offsets, IdxT* indices, IdxT* parts, ValT* weights) { IdxT tid = blockIdx.x * blockDim.x + threadIdx.x; diff --git a/cpp/src/components/legacy/weak_cc.cuh b/cpp/src/components/legacy/weak_cc.cuh index ad9aa773590..f4254e2d55d 100644 --- a/cpp/src/components/legacy/weak_cc.cuh +++ b/cpp/src/components/legacy/weak_cc.cuh @@ -59,15 +59,15 @@ class WeakCCState { }; template -__global__ void weak_cc_label_device(vertex_t* labels, - edge_t const* offsets, - vertex_t const* indices, - edge_t nnz, - bool* fa, - bool* xa, - bool* m, - vertex_t startVertexId, - vertex_t batchSize) +__global__ static void weak_cc_label_device(vertex_t* labels, + edge_t const* offsets, + vertex_t const* indices, + edge_t nnz, + bool* fa, + bool* xa, + bool* m, + vertex_t startVertexId, + vertex_t batchSize) { vertex_t tid = threadIdx.x + blockIdx.x * TPB_X; if (tid < batchSize) { @@ -118,11 +118,11 @@ __global__ void weak_cc_label_device(vertex_t* labels, } template -__global__ void weak_cc_init_label_kernel(vertex_t* labels, - vertex_t startVertexId, - vertex_t batchSize, - vertex_t MAX_LABEL, - Lambda filter_op) +__global__ static void weak_cc_init_label_kernel(vertex_t* labels, + vertex_t startVertexId, + vertex_t batchSize, + vertex_t MAX_LABEL, + Lambda filter_op) { /** F1 and F2 in the paper correspond to fa and xa */ /** Cd in paper corresponds to db_cluster */ @@ -134,7 +134,7 @@ __global__ void weak_cc_init_label_kernel(vertex_t* labels, } template -__global__ void weak_cc_init_all_kernel( +__global__ static void weak_cc_init_all_kernel( vertex_t* labels, bool* fa, bool* xa, vertex_t N, vertex_t MAX_LABEL) { vertex_t tid = threadIdx.x + blockIdx.x * TPB_X; diff --git a/cpp/src/layout/legacy/bh_kernels.cuh b/cpp/src/layout/legacy/bh_kernels.cuh index 5b101363314..f6e163ab306 100644 --- a/cpp/src/layout/legacy/bh_kernels.cuh +++ b/cpp/src/layout/legacy/bh_kernels.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -42,9 +42,9 @@ namespace detail { /** * Intializes the states of objects. This speeds the overall kernel up. */ -__global__ void InitializationKernel(unsigned* restrict limiter, - int* restrict maxdepthd, - float* restrict radiusd) +__global__ static void InitializationKernel(unsigned* restrict limiter, + int* restrict maxdepthd, + float* restrict radiusd) { maxdepthd[0] = 1; limiter[0] = 0; @@ -54,10 +54,10 @@ __global__ void InitializationKernel(unsigned* restrict limiter, /** * Reset root. */ -__global__ void ResetKernel(float* restrict radiusd_squared, - int* restrict bottomd, - const int NNODES, - const float* restrict radiusd) +__global__ static void ResetKernel(float* restrict radiusd_squared, + int* restrict bottomd, + const int NNODES, + const float* restrict radiusd) { radiusd_squared[0] = radiusd[0] * radiusd[0]; // create root node @@ -67,20 +67,21 @@ __global__ void ResetKernel(float* restrict radiusd_squared, /** * Figures the bounding boxes for every point in the embedding. */ -__global__ __launch_bounds__(THREADS1, FACTOR1) void BoundingBoxKernel(int* restrict startd, - int* restrict childd, - int* restrict massd, - float* restrict posxd, - float* restrict posyd, - float* restrict maxxd, - float* restrict maxyd, - float* restrict minxd, - float* restrict minyd, - const int FOUR_NNODES, - const int NNODES, - const int N, - unsigned* restrict limiter, - float* restrict radiusd) +__global__ static __launch_bounds__(THREADS1, + FACTOR1) void BoundingBoxKernel(int* restrict startd, + int* restrict childd, + int* restrict massd, + float* restrict posxd, + float* restrict posyd, + float* restrict maxxd, + float* restrict maxyd, + float* restrict minxd, + float* restrict minyd, + const int FOUR_NNODES, + const int NNODES, + const int N, + unsigned* restrict limiter, + float* restrict radiusd) { float val, minx, maxx, miny, maxy; __shared__ float sminx[THREADS1], smaxx[THREADS1], sminy[THREADS1], smaxy[THREADS1]; @@ -158,9 +159,9 @@ __global__ __launch_bounds__(THREADS1, FACTOR1) void BoundingBoxKernel(int* rest /** * Clear some of the state vectors up. */ -__global__ __launch_bounds__(1024, 1) void ClearKernel1(int* restrict childd, - const int FOUR_NNODES, - const int FOUR_N) +__global__ static __launch_bounds__(1024, 1) void ClearKernel1(int* restrict childd, + const int FOUR_NNODES, + const int FOUR_N) { const int inc = blockDim.x * gridDim.x; int k = (FOUR_N & -32) + threadIdx.x + blockIdx.x * blockDim.x; @@ -175,15 +176,15 @@ __global__ __launch_bounds__(1024, 1) void ClearKernel1(int* restrict childd, /** * Build the actual KD Tree. */ -__global__ __launch_bounds__(THREADS2, - FACTOR2) void TreeBuildingKernel(int* restrict childd, - const float* restrict posxd, - const float* restrict posyd, - const int NNODES, - const int N, - int* restrict maxdepthd, - int* restrict bottomd, - const float* restrict radiusd) +__global__ static __launch_bounds__(THREADS2, + FACTOR2) void TreeBuildingKernel(int* restrict childd, + const float* restrict posxd, + const float* restrict posyd, + const int NNODES, + const int N, + int* restrict maxdepthd, + int* restrict bottomd, + const float* restrict radiusd) { int j, depth; float x, y, r; @@ -296,10 +297,10 @@ __global__ __launch_bounds__(THREADS2, /** * Clean more state vectors. */ -__global__ __launch_bounds__(1024, 1) void ClearKernel2(int* restrict startd, - int* restrict massd, - const int NNODES, - const int* restrict bottomd) +__global__ static __launch_bounds__(1024, 1) void ClearKernel2(int* restrict startd, + int* restrict massd, + const int NNODES, + const int* restrict bottomd) { const int bottom = bottomd[0]; const int inc = blockDim.x * gridDim.x; @@ -317,15 +318,15 @@ __global__ __launch_bounds__(1024, 1) void ClearKernel2(int* restrict startd, /** * Summarize the KD Tree via cell gathering */ -__global__ __launch_bounds__(THREADS3, - FACTOR3) void SummarizationKernel(int* restrict countd, - const int* restrict childd, - volatile int* restrict massd, - float* restrict posxd, - float* restrict posyd, - const int NNODES, - const int N, - const int* restrict bottomd) +__global__ static __launch_bounds__(THREADS3, + FACTOR3) void SummarizationKernel(int* restrict countd, + const int* restrict childd, + volatile int* restrict massd, + float* restrict posxd, + float* restrict posyd, + const int NNODES, + const int N, + const int* restrict bottomd) { bool flag = 0; float cm, px, py; @@ -453,13 +454,14 @@ __global__ __launch_bounds__(THREADS3, /** * Sort the cells */ -__global__ __launch_bounds__(THREADS4, FACTOR4) void SortKernel(int* restrict sortd, - const int* restrict countd, - volatile int* restrict startd, - int* restrict childd, - const int NNODES, - const int N, - const int* restrict bottomd) +__global__ static __launch_bounds__(THREADS4, + FACTOR4) void SortKernel(int* restrict sortd, + const int* restrict countd, + volatile int* restrict startd, + int* restrict childd, + const int NNODES, + const int N, + const int* restrict bottomd) { const int bottom = bottomd[0]; const int dec = blockDim.x * gridDim.x; @@ -502,7 +504,7 @@ __global__ __launch_bounds__(THREADS4, FACTOR4) void SortKernel(int* restrict so /** * Calculate the repulsive forces using the KD Tree */ -__global__ __launch_bounds__( +__global__ static __launch_bounds__( THREADS5, FACTOR5) void RepulsionKernel(/* int *restrict errd, */ const float scaling_ratio, const float theta, @@ -612,18 +614,18 @@ __global__ __launch_bounds__( } } -__global__ __launch_bounds__(THREADS6, - FACTOR6) void apply_forces_bh(float* restrict Y_x, - float* restrict Y_y, - const float* restrict attract_x, - const float* restrict attract_y, - const float* restrict repel_x, - const float* restrict repel_y, - float* restrict old_dx, - float* restrict old_dy, - const float* restrict swinging, - const float speed, - const int n) +__global__ static __launch_bounds__(THREADS6, + FACTOR6) void apply_forces_bh(float* restrict Y_x, + float* restrict Y_y, + const float* restrict attract_x, + const float* restrict attract_y, + const float* restrict repel_x, + const float* restrict repel_y, + float* restrict old_dx, + float* restrict old_dy, + const float* restrict swinging, + const float speed, + const int n) { // For evrery vertex for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < n; i += gridDim.x * blockDim.x) { diff --git a/cpp/src/layout/legacy/exact_repulsion.cuh b/cpp/src/layout/legacy/exact_repulsion.cuh index fe895bae6a0..8530202afd5 100644 --- a/cpp/src/layout/legacy/exact_repulsion.cuh +++ b/cpp/src/layout/legacy/exact_repulsion.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,13 +22,13 @@ namespace cugraph { namespace detail { template -__global__ void repulsion_kernel(const float* restrict x_pos, - const float* restrict y_pos, - float* restrict repel_x, - float* restrict repel_y, - const int* restrict mass, - const float scaling_ratio, - const vertex_t n) +__global__ static void repulsion_kernel(const float* restrict x_pos, + const float* restrict y_pos, + float* restrict repel_x, + float* restrict repel_y, + const int* restrict mass, + const float scaling_ratio, + const vertex_t n) { int j = (blockIdx.x * blockDim.x) + threadIdx.x; // for every item in row int i = (blockIdx.y * blockDim.y) + threadIdx.y; // for every row diff --git a/cpp/src/layout/legacy/fa2_kernels.cuh b/cpp/src/layout/legacy/fa2_kernels.cuh index 4f1ce520387..33e7841a380 100644 --- a/cpp/src/layout/legacy/fa2_kernels.cuh +++ b/cpp/src/layout/legacy/fa2_kernels.cuh @@ -23,19 +23,19 @@ namespace cugraph { namespace detail { template -__global__ void attraction_kernel(const vertex_t* restrict row, - const vertex_t* restrict col, - const weight_t* restrict v, - const edge_t e, - const float* restrict x_pos, - const float* restrict y_pos, - float* restrict attract_x, - float* restrict attract_y, - const int* restrict mass, - bool outbound_attraction_distribution, - bool lin_log_mode, - const float edge_weight_influence, - const float coef) +__global__ static void attraction_kernel(const vertex_t* restrict row, + const vertex_t* restrict col, + const weight_t* restrict v, + const edge_t e, + const float* restrict x_pos, + const float* restrict y_pos, + float* restrict attract_x, + float* restrict attract_y, + const int* restrict mass, + bool outbound_attraction_distribution, + bool lin_log_mode, + const float edge_weight_influence, + const float coef) { vertex_t i, src, dst; weight_t weight = 1; @@ -116,13 +116,13 @@ void apply_attraction(const vertex_t* restrict row, } template -__global__ void linear_gravity_kernel(const float* restrict x_pos, - const float* restrict y_pos, - float* restrict attract_x, - float* restrict attract_y, - const int* restrict mass, - const float gravity, - const vertex_t n) +__global__ static void linear_gravity_kernel(const float* restrict x_pos, + const float* restrict y_pos, + float* restrict attract_x, + float* restrict attract_y, + const int* restrict mass, + const float gravity, + const vertex_t n) { // For every node. for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < n; i += gridDim.x * blockDim.x) { @@ -136,14 +136,14 @@ __global__ void linear_gravity_kernel(const float* restrict x_pos, } template -__global__ void strong_gravity_kernel(const float* restrict x_pos, - const float* restrict y_pos, - float* restrict attract_x, - float* restrict attract_y, - const int* restrict mass, - const float gravity, - const float scaling_ratio, - const vertex_t n) +__global__ static void strong_gravity_kernel(const float* restrict x_pos, + const float* restrict y_pos, + float* restrict attract_x, + float* restrict attract_y, + const int* restrict mass, + const float gravity, + const float scaling_ratio, + const vertex_t n) { // For every node. for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < n; i += gridDim.x * blockDim.x) { @@ -187,16 +187,16 @@ void apply_gravity(const float* restrict x_pos, } template -__global__ void local_speed_kernel(const float* restrict repel_x, - const float* restrict repel_y, - const float* restrict attract_x, - const float* restrict attract_y, - const float* restrict old_dx, - const float* restrict old_dy, - const int* restrict mass, - float* restrict swinging, - float* restrict traction, - const vertex_t n) +__global__ static void local_speed_kernel(const float* restrict repel_x, + const float* restrict repel_y, + const float* restrict attract_x, + const float* restrict attract_y, + const float* restrict old_dx, + const float* restrict old_dy, + const int* restrict mass, + float* restrict swinging, + float* restrict traction, + const vertex_t n) { // For every node. for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < n; i += gridDim.x * blockDim.x) { @@ -272,17 +272,17 @@ void adapt_speed(const float jitter_tolerance, } template -__global__ void update_positions_kernel(float* restrict x_pos, - float* restrict y_pos, - const float* restrict repel_x, - const float* restrict repel_y, - const float* restrict attract_x, - const float* restrict attract_y, - float* restrict old_dx, - float* restrict old_dy, - const float* restrict swinging, - const float speed, - const vertex_t n) +__global__ static void update_positions_kernel(float* restrict x_pos, + float* restrict y_pos, + const float* restrict repel_x, + const float* restrict repel_y, + const float* restrict attract_x, + const float* restrict attract_y, + float* restrict old_dx, + float* restrict old_dy, + const float* restrict swinging, + const float speed, + const vertex_t n) { // For every node. for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < n; i += gridDim.x * blockDim.x) { diff --git a/cpp/src/prims/detail/extract_transform_v_frontier_e.cuh b/cpp/src/prims/detail/extract_transform_v_frontier_e.cuh index fc3da3cac07..0b6447f50d9 100644 --- a/cpp/src/prims/detail/extract_transform_v_frontier_e.cuh +++ b/cpp/src/prims/detail/extract_transform_v_frontier_e.cuh @@ -127,7 +127,7 @@ template -__global__ void extract_transform_v_frontier_e_hypersparse_or_low_degree( +__global__ static void extract_transform_v_frontier_e_hypersparse_or_low_degree( edge_partition_device_view_t edge_partition, @@ -295,7 +295,7 @@ template -__global__ void extract_transform_v_frontier_e_mid_degree( +__global__ static void extract_transform_v_frontier_e_mid_degree( edge_partition_device_view_t edge_partition, @@ -396,7 +396,7 @@ template -__global__ void extract_transform_v_frontier_e_high_degree( +__global__ static void extract_transform_v_frontier_e_high_degree( edge_partition_device_view_t edge_partition, diff --git a/cpp/src/prims/per_v_random_select_transform_outgoing_e.cuh b/cpp/src/prims/per_v_random_select_transform_outgoing_e.cuh index 9cb3365116e..5240c49cb80 100644 --- a/cpp/src/prims/per_v_random_select_transform_outgoing_e.cuh +++ b/cpp/src/prims/per_v_random_select_transform_outgoing_e.cuh @@ -328,7 +328,7 @@ struct return_value_compute_offset_t { }; template -__global__ void compute_valid_local_nbr_inclusive_sums_mid_local_degree( +__global__ static void compute_valid_local_nbr_inclusive_sums_mid_local_degree( edge_partition_device_view_t edge_partition, edge_partition_edge_property_device_view_t edge_partition_e_mask, raft::device_span edge_partition_frontier_majors, @@ -382,7 +382,7 @@ __global__ void compute_valid_local_nbr_inclusive_sums_mid_local_degree( } template -__global__ void compute_valid_local_nbr_inclusive_sums_high_local_degree( +__global__ static void compute_valid_local_nbr_inclusive_sums_high_local_degree( edge_partition_device_view_t edge_partition, edge_partition_edge_property_device_view_t edge_partition_e_mask, raft::device_span edge_partition_frontier_majors, diff --git a/cpp/src/prims/per_v_transform_reduce_incoming_outgoing_e.cuh b/cpp/src/prims/per_v_transform_reduce_incoming_outgoing_e.cuh index 083487fa5b4..509ab56d3fe 100644 --- a/cpp/src/prims/per_v_transform_reduce_incoming_outgoing_e.cuh +++ b/cpp/src/prims/per_v_transform_reduce_incoming_outgoing_e.cuh @@ -149,7 +149,7 @@ template -__global__ void per_v_transform_reduce_e_hypersparse( +__global__ static void per_v_transform_reduce_e_hypersparse( edge_partition_device_view_t edge_partition, @@ -251,7 +251,7 @@ template -__global__ void per_v_transform_reduce_e_low_degree( +__global__ static void per_v_transform_reduce_e_low_degree( edge_partition_device_view_t edge_partition, @@ -350,7 +350,7 @@ template -__global__ void per_v_transform_reduce_e_mid_degree( +__global__ static void per_v_transform_reduce_e_mid_degree( edge_partition_device_view_t edge_partition, @@ -466,7 +466,7 @@ template -__global__ void per_v_transform_reduce_e_high_degree( +__global__ static void per_v_transform_reduce_e_high_degree( edge_partition_device_view_t edge_partition, diff --git a/cpp/src/prims/transform_e.cuh b/cpp/src/prims/transform_e.cuh index 2cb1a5358b0..9c7670f68d2 100644 --- a/cpp/src/prims/transform_e.cuh +++ b/cpp/src/prims/transform_e.cuh @@ -51,7 +51,7 @@ template -__global__ void transform_e_packed_bool( +__global__ static void transform_e_packed_bool( edge_partition_device_view_t edge_partition, diff --git a/cpp/src/prims/transform_reduce_e.cuh b/cpp/src/prims/transform_reduce_e.cuh index e5855b105ee..43722550c58 100644 --- a/cpp/src/prims/transform_reduce_e.cuh +++ b/cpp/src/prims/transform_reduce_e.cuh @@ -61,7 +61,7 @@ template -__global__ void transform_reduce_e_hypersparse( +__global__ static void transform_reduce_e_hypersparse( edge_partition_device_view_t edge_partition, @@ -153,7 +153,7 @@ template -__global__ void transform_reduce_e_low_degree( +__global__ static void transform_reduce_e_low_degree( edge_partition_device_view_t edge_partition, @@ -242,7 +242,7 @@ template -__global__ void transform_reduce_e_mid_degree( +__global__ static void transform_reduce_e_mid_degree( edge_partition_device_view_t edge_partition, @@ -320,7 +320,7 @@ template -__global__ void transform_reduce_e_high_degree( +__global__ static void transform_reduce_e_high_degree( edge_partition_device_view_t edge_partition, diff --git a/cpp/src/prims/transform_reduce_e_by_src_dst_key.cuh b/cpp/src/prims/transform_reduce_e_by_src_dst_key.cuh index 42203085077..eee0ed03d1c 100644 --- a/cpp/src/prims/transform_reduce_e_by_src_dst_key.cuh +++ b/cpp/src/prims/transform_reduce_e_by_src_dst_key.cuh @@ -97,7 +97,7 @@ template -__global__ void transform_reduce_by_src_dst_key_hypersparse( +__global__ static void transform_reduce_by_src_dst_key_hypersparse( edge_partition_device_view_t edge_partition, @@ -156,7 +156,7 @@ template -__global__ void transform_reduce_by_src_dst_key_low_degree( +__global__ static void transform_reduce_by_src_dst_key_low_degree( edge_partition_device_view_t edge_partition, @@ -214,7 +214,7 @@ template -__global__ void transform_reduce_by_src_dst_key_mid_degree( +__global__ static void transform_reduce_by_src_dst_key_mid_degree( edge_partition_device_view_t edge_partition, @@ -274,7 +274,7 @@ template -__global__ void transform_reduce_by_src_dst_key_high_degree( +__global__ static void transform_reduce_by_src_dst_key_high_degree( edge_partition_device_view_t edge_partition, diff --git a/cpp/src/structure/graph_view_impl.cuh b/cpp/src/structure/graph_view_impl.cuh index 4ee5ad5ca02..29dca6ef409 100644 --- a/cpp/src/structure/graph_view_impl.cuh +++ b/cpp/src/structure/graph_view_impl.cuh @@ -241,7 +241,7 @@ rmm::device_uvector compute_minor_degrees( int32_t constexpr count_edge_partition_multi_edges_block_size = 1024; template -__global__ void for_all_major_for_all_nbr_mid_degree( +__global__ static void for_all_major_for_all_nbr_mid_degree( edge_partition_device_view_t edge_partition, vertex_t major_range_first, vertex_t major_range_last, @@ -275,7 +275,7 @@ __global__ void for_all_major_for_all_nbr_mid_degree( } template -__global__ void for_all_major_for_all_nbr_high_degree( +__global__ static void for_all_major_for_all_nbr_high_degree( edge_partition_device_view_t edge_partition, vertex_t major_range_first, vertex_t major_range_last, diff --git a/cpp/src/traversal/od_shortest_distances_impl.cuh b/cpp/src/traversal/od_shortest_distances_impl.cuh index c2a3f1160ca..612eb0c48f2 100644 --- a/cpp/src/traversal/od_shortest_distances_impl.cuh +++ b/cpp/src/traversal/od_shortest_distances_impl.cuh @@ -215,7 +215,7 @@ template -__global__ void multi_partition_copy( +__global__ static void multi_partition_copy( InputIterator input_first, InputIterator input_last, raft::device_span output_buffer_ptrs, diff --git a/cpp/src/utilities/eidecl_graph_utils.hpp b/cpp/src/utilities/eidecl_graph_utils.hpp index 84240ba2845..abf026cbbfe 100644 --- a/cpp/src/utilities/eidecl_graph_utils.hpp +++ b/cpp/src/utilities/eidecl_graph_utils.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,9 +29,12 @@ extern template void offsets_to_indices(int const*, int, int*); extern template void offsets_to_indices(long const*, int, int*); extern template void offsets_to_indices(long const*, long, long*); -extern template __global__ void offsets_to_indices_kernel(int const*, int, int*); -extern template __global__ void offsets_to_indices_kernel(long const*, int, int*); -extern template __global__ void offsets_to_indices_kernel(long const*, long, long*); +extern template __attribute__((visibility("hidden"))) __global__ void +offsets_to_indices_kernel(int const*, int, int*); +extern template __attribute__((visibility("hidden"))) __global__ void +offsets_to_indices_kernel(long const*, int, int*); +extern template __attribute__((visibility("hidden"))) __global__ void +offsets_to_indices_kernel(long const*, long, long*); } // namespace detail } // namespace cugraph diff --git a/cpp/src/utilities/eidir_graph_utils.hpp b/cpp/src/utilities/eidir_graph_utils.hpp index 033bb197ce8..ba06c6f56ea 100644 --- a/cpp/src/utilities/eidir_graph_utils.hpp +++ b/cpp/src/utilities/eidir_graph_utils.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,15 +29,12 @@ template void offsets_to_indices(int32_t const*, int32_t, int3 template void offsets_to_indices(int64_t const*, int32_t, int32_t*); template void offsets_to_indices(int64_t const*, int64_t, int64_t*); -template __global__ void offsets_to_indices_kernel(int32_t const*, - int32_t, - int32_t*); -template __global__ void offsets_to_indices_kernel(int64_t const*, - int32_t, - int32_t*); -template __global__ void offsets_to_indices_kernel(int64_t const*, - int64_t, - int64_t*); +template __global__ __attribute__((visibility("hidden"))) void +offsets_to_indices_kernel(int32_t const*, int32_t, int32_t*); +template __global__ __attribute__((visibility("hidden"))) void +offsets_to_indices_kernel(int64_t const*, int32_t, int32_t*); +template __global__ __attribute__((visibility("hidden"))) void +offsets_to_indices_kernel(int64_t const*, int64_t, int64_t*); } // namespace detail } // namespace cugraph diff --git a/cpp/src/utilities/graph_utils.cuh b/cpp/src/utilities/graph_utils.cuh index 2d542956531..0b257e7abde 100644 --- a/cpp/src/utilities/graph_utils.cuh +++ b/cpp/src/utilities/graph_utils.cuh @@ -247,34 +247,36 @@ void update_dangling_nodes(size_t n, T* dangling_nodes, T damping_factor) // google matrix kernels template -__global__ void degree_coo(const IndexType n, - const IndexType e, - const IndexType* ind, - ValueType* degree) +__global__ static void degree_coo(const IndexType n, + const IndexType e, + const IndexType* ind, + ValueType* degree) { for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < e; i += gridDim.x * blockDim.x) atomicAdd(°ree[ind[i]], (ValueType)1.0); } template -__global__ void flag_leafs_kernel(const size_t n, const IndexType* degree, ValueType* bookmark) +__global__ static void flag_leafs_kernel(const size_t n, + const IndexType* degree, + ValueType* bookmark) { for (auto i = threadIdx.x + blockIdx.x * blockDim.x; i < n; i += gridDim.x * blockDim.x) if (degree[i] == 0) bookmark[i] = 1.0; } template -__global__ void degree_offsets(const IndexType n, - const IndexType e, - const IndexType* ind, - ValueType* degree) +__global__ static void degree_offsets(const IndexType n, + const IndexType e, + const IndexType* ind, + ValueType* degree) { for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < n; i += gridDim.x * blockDim.x) degree[i] += ind[i + 1] - ind[i]; } template -__global__ void type_convert(FromType* array, int n) +__global__ static void type_convert(FromType* array, int n) { for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < n; i += gridDim.x * blockDim.x) { ToType val = array[i]; @@ -284,12 +286,12 @@ __global__ void type_convert(FromType* array, int n) } template -__global__ void equi_prob3(const IndexType n, - const IndexType e, - const IndexType* csrPtr, - const IndexType* csrInd, - ValueType* val, - IndexType* degree) +__global__ static void equi_prob3(const IndexType n, + const IndexType e, + const IndexType* csrPtr, + const IndexType* csrInd, + ValueType* val, + IndexType* degree) { int j, row, col; for (row = threadIdx.z + blockIdx.z * blockDim.z; row < n; row += gridDim.z * blockDim.z) { @@ -303,12 +305,12 @@ __global__ void equi_prob3(const IndexType n, } template -__global__ void equi_prob2(const IndexType n, - const IndexType e, - const IndexType* csrPtr, - const IndexType* csrInd, - ValueType* val, - IndexType* degree) +__global__ static void equi_prob2(const IndexType n, + const IndexType e, + const IndexType* csrPtr, + const IndexType* csrInd, + ValueType* val, + IndexType* degree) { int row = blockIdx.x * blockDim.x + threadIdx.x; if (row < n) { @@ -372,7 +374,8 @@ void HT_matrix_csc_coo(const IndexType n, } template -__global__ void offsets_to_indices_kernel(const offsets_t* offsets, index_t v, index_t* indices) +__attribute__((visibility("hidden"))) __global__ void offsets_to_indices_kernel( + const offsets_t* offsets, index_t v, index_t* indices) { auto tid{threadIdx.x}; auto ctaStart{blockIdx.x}; diff --git a/cpp/src/utilities/path_retrieval.cu b/cpp/src/utilities/path_retrieval.cu index e37ce3a3ced..eda60941c23 100644 --- a/cpp/src/utilities/path_retrieval.cu +++ b/cpp/src/utilities/path_retrieval.cu @@ -29,13 +29,13 @@ namespace cugraph { namespace detail { template -__global__ void get_traversed_cost_kernel(vertex_t const* vertices, - vertex_t const* preds, - vertex_t const* vtx_map, - weight_t const* info_weights, - weight_t* out, - vertex_t stop_vertex, - vertex_t num_vertices) +__global__ static void get_traversed_cost_kernel(vertex_t const* vertices, + vertex_t const* preds, + vertex_t const* vtx_map, + weight_t const* info_weights, + weight_t* out, + vertex_t stop_vertex, + vertex_t num_vertices) { for (vertex_t i = threadIdx.x + blockIdx.x * blockDim.x; i < num_vertices; i += gridDim.x * blockDim.x) { From 69aef7fb4d1974cde83d3b0e9e7c88cc44b44b0c Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Mon, 11 Mar 2024 17:00:45 -0400 Subject: [PATCH 5/7] Replace local copyright check with pre-commit-hooks verify-copyright (#4130) The local `copyright.py` script is bug-prone. Replace it with a more robust centralized script from `pre-commit-hooks`. Authors: - Kyle Edwards (https://github.com/KyleFromNVIDIA) - Brad Rees (https://github.com/BradReesWork) Approvers: - Rick Ratzel (https://github.com/rlratzel) - Brad Rees (https://github.com/BradReesWork) - Jake Awe (https://github.com/AyodeAwe) URL: https://github.com/rapidsai/cugraph/pull/4130 --- .pre-commit-config.yaml | 20 +-- ci/checks/copyright.py | 271 ---------------------------------------- 2 files changed, 12 insertions(+), 279 deletions(-) delete mode 100644 ci/checks/copyright.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3d893e0e562..ddb84d8a0f0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -46,16 +46,20 @@ repos: ) types_or: [c, c++, cuda] args: ["-fallback-style=none", "-style=file", "-i"] - - repo: local - hooks: - - id: copyright-check - name: copyright-check - entry: python ./ci/checks/copyright.py --git-modified-only --update-current-year - language: python - pass_filenames: false - additional_dependencies: [gitpython] - repo: https://github.com/rapidsai/dependency-file-generator rev: v1.8.0 hooks: - id: rapids-dependency-file-generator args: ["--clean"] + - repo: https://github.com/rapidsai/pre-commit-hooks + rev: v0.0.1 + hooks: + - id: verify-copyright + files: | + (?x) + [.](cmake|cpp|cu|cuh|h|hpp|sh|pxd|py|pyx)$| + CMakeLists[.]txt$| + CMakeLists_standalone[.]txt$| + [.]flake8[.]cython$| + meta[.]yaml$| + setup[.]cfg$ diff --git a/ci/checks/copyright.py b/ci/checks/copyright.py deleted file mode 100644 index ba8b73898e2..00000000000 --- a/ci/checks/copyright.py +++ /dev/null @@ -1,271 +0,0 @@ -# Copyright (c) 2019-2023, NVIDIA CORPORATION. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import argparse -import datetime -import os -import re -import sys - -import git - -FilesToCheck = [ - re.compile(r"[.](cmake|cpp|cu|cuh|h|hpp|sh|pxd|py|pyx)$"), - re.compile(r"CMakeLists[.]txt$"), - re.compile(r"setup[.]cfg$"), - re.compile(r"[.]flake8[.]cython$"), - re.compile(r"meta[.]yaml$"), -] - -# this will break starting at year 10000, which is probably OK :) -CheckSimple = re.compile( - r"Copyright *(?:\(c\))? *(\d{4}),? *NVIDIA C(?:ORPORATION|orporation)" -) -CheckDouble = re.compile( - r"Copyright *(?:\(c\))? *(\d{4})-(\d{4}),? *NVIDIA C(?:ORPORATION|orporation)" # noqa: E501 -) - - -def checkThisFile(f): - if isinstance(f, git.Diff): - if f.deleted_file or f.b_blob.size == 0: - return False - f = f.b_path - elif not os.path.exists(f) or os.stat(f).st_size == 0: - # This check covers things like symlinks which point to files that DNE - return False - for checker in FilesToCheck: - if checker.search(f): - return True - return False - - -def modifiedFiles(): - """Get a set of all modified files, as Diff objects. - - The files returned have been modified in git since the merge base of HEAD - and the upstream of the target branch. We return the Diff objects so that - we can read only the staged changes. - """ - repo = git.Repo() - # Use the environment variable TARGET_BRANCH or RAPIDS_BASE_BRANCH (defined in CI) if possible - target_branch = os.environ.get("TARGET_BRANCH", os.environ.get("RAPIDS_BASE_BRANCH")) - if target_branch is None: - # Fall back to the closest branch if not on CI - target_branch = repo.git.describe( - all=True, tags=True, match="branch-*", abbrev=0 - ).lstrip("heads/") - - upstream_target_branch = None - if target_branch in repo.heads: - # Use the tracking branch of the local reference if it exists. This - # returns None if no tracking branch is set. - upstream_target_branch = repo.heads[target_branch].tracking_branch() - if upstream_target_branch is None: - # Fall back to the remote with the newest target_branch. This code - # path is used on CI because the only local branch reference is - # current-pr-branch, and thus target_branch is not in repo.heads. - # This also happens if no tracking branch is defined for the local - # target_branch. We use the remote with the latest commit if - # multiple remotes are defined. - candidate_branches = [ - remote.refs[target_branch] for remote in repo.remotes - if target_branch in remote.refs - ] - if len(candidate_branches) > 0: - upstream_target_branch = sorted( - candidate_branches, - key=lambda branch: branch.commit.committed_datetime, - )[-1] - else: - # If no remotes are defined, try to use the local version of the - # target_branch. If this fails, the repo configuration must be very - # strange and we can fix this script on a case-by-case basis. - upstream_target_branch = repo.heads[target_branch] - merge_base = repo.merge_base("HEAD", upstream_target_branch.commit)[0] - diff = merge_base.diff() - changed_files = {f for f in diff if f.b_path is not None} - return changed_files - - -def getCopyrightYears(line): - res = CheckSimple.search(line) - if res: - return int(res.group(1)), int(res.group(1)) - res = CheckDouble.search(line) - if res: - return int(res.group(1)), int(res.group(2)) - return None, None - - -def replaceCurrentYear(line, start, end): - # first turn a simple regex into double (if applicable). then update years - res = CheckSimple.sub(r"Copyright (c) \1-\1, NVIDIA CORPORATION", line) - res = CheckDouble.sub( - rf"Copyright (c) {start:04d}-{end:04d}, NVIDIA CORPORATION", - res, - ) - return res - - -def checkCopyright(f, update_current_year): - """Checks for copyright headers and their years.""" - errs = [] - thisYear = datetime.datetime.now().year - lineNum = 0 - crFound = False - yearMatched = False - - if isinstance(f, git.Diff): - path = f.b_path - lines = f.b_blob.data_stream.read().decode().splitlines(keepends=True) - else: - path = f - with open(f, encoding="utf-8") as fp: - lines = fp.readlines() - - for line in lines: - lineNum += 1 - start, end = getCopyrightYears(line) - if start is None: - continue - crFound = True - if start > end: - e = [ - path, - lineNum, - "First year after second year in the copyright " - "header (manual fix required)", - None, - ] - errs.append(e) - elif thisYear < start or thisYear > end: - e = [ - path, - lineNum, - "Current year not included in the copyright header", - None, - ] - if thisYear < start: - e[-1] = replaceCurrentYear(line, thisYear, end) - if thisYear > end: - e[-1] = replaceCurrentYear(line, start, thisYear) - errs.append(e) - else: - yearMatched = True - # copyright header itself not found - if not crFound: - e = [ - path, - 0, - "Copyright header missing or formatted incorrectly " - "(manual fix required)", - None, - ] - errs.append(e) - # even if the year matches a copyright header, make the check pass - if yearMatched: - errs = [] - - if update_current_year: - errs_update = [x for x in errs if x[-1] is not None] - if len(errs_update) > 0: - lines_changed = ", ".join(str(x[1]) for x in errs_update) - print(f"File: {path}. Changing line(s) {lines_changed}") - for _, lineNum, __, replacement in errs_update: - lines[lineNum - 1] = replacement - with open(path, "w", encoding="utf-8") as out_file: - out_file.writelines(lines) - - return errs - - -def getAllFilesUnderDir(root, pathFilter=None): - retList = [] - for dirpath, dirnames, filenames in os.walk(root): - for fn in filenames: - filePath = os.path.join(dirpath, fn) - if pathFilter(filePath): - retList.append(filePath) - return retList - - -def checkCopyright_main(): - """ - Checks for copyright headers in all the modified files. In case of local - repo, this script will just look for uncommitted files and in case of CI - it compares between branches "$PR_TARGET_BRANCH" and "current-pr-branch" - """ - retVal = 0 - - argparser = argparse.ArgumentParser( - "Checks for a consistent copyright header in git's modified files" - ) - argparser.add_argument( - "--update-current-year", - dest="update_current_year", - action="store_true", - required=False, - help="If set, " - "update the current year if a header is already " - "present and well formatted.", - ) - argparser.add_argument( - "--git-modified-only", - dest="git_modified_only", - action="store_true", - required=False, - help="If set, " - "only files seen as modified by git will be " - "processed.", - ) - - args, dirs = argparser.parse_known_args() - - if args.git_modified_only: - files = [f for f in modifiedFiles() if checkThisFile(f)] - else: - files = [] - for d in [os.path.abspath(d) for d in dirs]: - if not os.path.isdir(d): - raise ValueError(f"{d} is not a directory.") - files += getAllFilesUnderDir(d, pathFilter=checkThisFile) - - errors = [] - for f in files: - errors += checkCopyright(f, args.update_current_year) - - if len(errors) > 0: - if any(e[-1] is None for e in errors): - print("Copyright headers incomplete in some of the files!") - for e in errors: - print(" %s:%d Issue: %s" % (e[0], e[1], e[2])) - print("") - n_fixable = sum(1 for e in errors if e[-1] is not None) - path_parts = os.path.abspath(__file__).split(os.sep) - file_from_repo = os.sep.join(path_parts[path_parts.index("ci") :]) - if n_fixable > 0 and not args.update_current_year: - print( - f"You can run `python {file_from_repo} --git-modified-only " - "--update-current-year` and stage the results in git to " - f"fix {n_fixable} of these errors.\n" - ) - retVal = 1 - - return retVal - - -if __name__ == "__main__": - sys.exit(checkCopyright_main()) From af84356bea6d256b5a8844537697a050984deb9c Mon Sep 17 00:00:00 2001 From: Don Acosta <97529984+acostadon@users.noreply.github.com> Date: Tue, 12 Mar 2024 09:18:49 -0400 Subject: [PATCH 6/7] Starting work on blog links and nx_cugraph docs (#4160) This PR starts to address updating cuGraph docs to reflect recent blogs and work on nx_cugraph as a backend for NetworkX Authors: - Don Acosta (https://github.com/acostadon) Approvers: - Brad Rees (https://github.com/BradReesWork) URL: https://github.com/rapidsai/cugraph/pull/4160 --- docs/cugraph/source/api_docs/index.rst | 2 +- docs/cugraph/source/basics/nx_transition.rst | 42 ++--- .../source/graph_support/property_graph.md | 2 +- docs/cugraph/source/images/ancestors.png | Bin 0 -> 14156 bytes docs/cugraph/source/images/bfs_tree.png | Bin 0 -> 13763 bytes docs/cugraph/source/images/conn_component.png | Bin 0 -> 13053 bytes docs/cugraph/source/images/descendents.png | Bin 0 -> 13639 bytes docs/cugraph/source/images/k_truss.png | Bin 0 -> 13389 bytes docs/cugraph/source/images/katz.png | Bin 0 -> 13558 bytes docs/cugraph/source/images/pagerank.png | Bin 0 -> 13243 bytes docs/cugraph/source/images/sssp.png | Bin 0 -> 14542 bytes docs/cugraph/source/images/wcc.png | Bin 0 -> 14519 bytes docs/cugraph/source/index.rst | 1 + docs/cugraph/source/nx_cugraph/index.rst | 9 + docs/cugraph/source/nx_cugraph/nx_cugraph.md | 165 ++++++++++++++++++ .../source/tutorials/community_resources.md | 2 + .../source/tutorials/cugraph_blogs.rst | 15 ++ 17 files changed, 208 insertions(+), 30 deletions(-) create mode 100644 docs/cugraph/source/images/ancestors.png create mode 100644 docs/cugraph/source/images/bfs_tree.png create mode 100644 docs/cugraph/source/images/conn_component.png create mode 100644 docs/cugraph/source/images/descendents.png create mode 100644 docs/cugraph/source/images/k_truss.png create mode 100644 docs/cugraph/source/images/katz.png create mode 100644 docs/cugraph/source/images/pagerank.png create mode 100644 docs/cugraph/source/images/sssp.png create mode 100644 docs/cugraph/source/images/wcc.png create mode 100644 docs/cugraph/source/nx_cugraph/index.rst create mode 100644 docs/cugraph/source/nx_cugraph/nx_cugraph.md diff --git a/docs/cugraph/source/api_docs/index.rst b/docs/cugraph/source/api_docs/index.rst index 1b907165a39..d76902772fb 100644 --- a/docs/cugraph/source/api_docs/index.rst +++ b/docs/cugraph/source/api_docs/index.rst @@ -15,7 +15,7 @@ Core Graph API Documentation cugraph_c/index.rst cugraph_cpp/index.rst -Graph Nerual Networks API Documentation +Graph Neural Networks API Documentation --------------------------------------- .. toctree:: diff --git a/docs/cugraph/source/basics/nx_transition.rst b/docs/cugraph/source/basics/nx_transition.rst index 9849865814d..07c2ad26ffa 100644 --- a/docs/cugraph/source/basics/nx_transition.rst +++ b/docs/cugraph/source/basics/nx_transition.rst @@ -1,30 +1,20 @@ ************************************** -NetworkX Compatibility and Transition +NetworkX by calling cuGraph Algorithms ************************************** + *Note: this is a work in progress and will be updatred and changed as we better flesh out compatibility issues* -One of the goals of RAPIDS cuGraph is to mimic the NetworkX API to simplify -the transition to accelerated GPU data science. However, graph analysis, -also called network science, like most other data science workflow, is more -than just running an algorithm. Graph data requires cleaning and prep (ETL) -and then the construction of a graph object; that is all before the execution -of a graph algorithm. RAPIDS and cuGraph allow a portion or the complete -analytic workflow to be accelerated. To achieve the maximum amount of -acceleration, we encourage fully replacing existing code with cuGraph. -But sometimes it is easier to replace just a portion. - -Last Update -########### +Latest Update +############# -Last Update: Oct 14th, 2020 -Release: 0.16 +Last Update: March 7th, 2024 +Release: 24.04 -Information on `NetworkX `_ - -This transition guide in an expansion of the Medium Blog on `NetworkX Compatibility -`_ +**CuGraph is now a registered backend for networkX. This is described in the following blog: +`Accelerating NetworkX on NVIDIA GPUs for High Performance Graph Analytics +`_ Easy Path – Use NetworkX Graph Objects, Accelerated Algorithms @@ -33,12 +23,11 @@ Easy Path – Use NetworkX Graph Objects, Accelerated Algorithms Rather than updating all of your existing code, simply update the calls to graph algorithms by replacing the module name. This allows all the complicated ETL code to be unchanged while still seeing significate performance +improvements. Again this will be deprecated since networkX dispatching to nx_cugraph +has many advantages. + improvements. -In the following example, the cuGraph module is being imported as “cnx”. -While module can be assigned any name can be used, we picked cnx to reduce -the amount of text to be changed. The text highlighted in yellow indicates -changes. .. image:: ../images/Nx_Cg_1.png :width: 600 @@ -49,9 +38,6 @@ input and match the NetworkX API list of arguments. Currently, cuGraph accepts both NetworkX Graph and DiGraph objects. We will be adding support for Bipartite graph and Multigraph over the next few releases. -| - - Differences in Algorithms ########################## @@ -169,8 +155,8 @@ Graph Building ############## The biggest difference between NetworkX and cuGraph is with how Graph objects -are built. NetworkX, for the most part, stores graph data in a dictionary. -That structure allows easy insertion of new records. Consider the following +are built. NetworkX, for the most part, stores graph data in a dictionary. +That structure allows easy insertion of new records. Consider the following code for building a NetworkX Graph:: # Read the node data diff --git a/docs/cugraph/source/graph_support/property_graph.md b/docs/cugraph/source/graph_support/property_graph.md index ef07be79ba0..94d170c18df 100644 --- a/docs/cugraph/source/graph_support/property_graph.md +++ b/docs/cugraph/source/graph_support/property_graph.md @@ -21,7 +21,7 @@ import cugraph from cugraph.experimental import PropertyGraph # Import a built-in dataset -from cugraph.experimental.datasets import karate +from cugraph.datasets import karate # Read edgelist data into a DataFrame, load into PropertyGraph as edge data. # Create a graph using the imported Dataset object diff --git a/docs/cugraph/source/images/ancestors.png b/docs/cugraph/source/images/ancestors.png new file mode 100644 index 0000000000000000000000000000000000000000..37b8e7933a8832f4adfd51314611b5149a139b3a GIT binary patch literal 14156 zcmeHudpOi(H-S57B@3r?od;hlUn#=XfJm2SVKhJZ2KKJLj zzerna3(=i2I|T#;L@h6xUlkD8q9Gt4c>1U9z&9eZI0EobFz~8{sX%#`+ze3I>UqxU zoPa<@qR8saZ9sX)y^BtP0s_06H$Q@{{)HX_0y<-s=I5@3xUn*|5)W7eW0zEyRgdrj zKJIc*lvI8qedLhzky~#bBPF(M-)F1R-?V+}t}}byAw5(Vckc>O6dK+Z7hV0}^7&iA zgxOb<9kcJi2+z9w*~_Cvbiv(z_k#wCYOEcD4sbK~uQTp-<8Zi;`=?bCM%mqQ6Zcn7 zo$BCOfq(4(SIUteO|-50Ph0hOy{8di%M%J4A-u#5p4e~SVO*taEjw0PqIVtw{=#xn zQE9lRPoK7FXu-tA#h+KbZ#i#nZr-(BQVJ;Q%U%WprEk6`{u_RDSeWWWHRoC+?zYJ` zGHgJsa9w9c*nAGNxj9Bm4DE)8J67(3_pYrjFWgtATM?VrzBb0q%};j7Ruj~E@TvMo z4%NgUV@vPYAUi1LN|w7(G{o2Zr9*7$xSPgEtTu8|X6}B%j?|~ayChTZmyQT{J`C}# zD9Ze`+F#_$_b=u{mu*BVs4K+u7SC97zjkubt%%~s7tOg=y0BUtPVa__MkT%Fwnx7! zYC0}JW5K(lGjg+1);jiLo%y%$gSwVk64dYUAe|IpF?q1_3c4IjKiAr<%7M$?)_J51 zX;OL4yuKm;<#^cTxIts8*Ih=-hz2PntIV@E`nmaH&Xft%iX>qXC0$MFi;NeRsVFW> zs2v!tY7v8{zgu*gaHK{vhe!AEzYmlI7U8W&%ewL!ixo=aeNNkHPE4%U5Y|}2tao*= znuT|6LYQ#wkY{QI$dEsMF0+^r?_N@?%tBf~$;`yM+Js3XEW+*fE6E-rj+JWkj>P;3 zJF0Uiz<6d-1*N2M_G0$g(D`-6Pt%G=h(xulrhpcxv{`O*vuyt-xa0?Gr5)WgsVJa^)a|UsS>H zzNDb7(%Tn%g{8V)Javh3P1K~m?s>maP~Q}LqIGt~dA>Q6)8C&~@fDpQ4A&Cw@;NKf z5&WHvT2D1mt?MB#eB|UwZy{SIy;~e0?xBQuga-}VgPy9C+h28Ks_OKj(j;TBhKr}4 zwyn`UkRUPni=44kVY?OE@pB}j9*VH{?HPlJg*`Tq&?DI6hBs2f={fr&*$`dP*KQ*} zV;}GqbvSW31vz)O5uMz3fjG{Q`0E^VTqca&Cg;rzbkl~>>Iv#)T(^eA=UgbHEIY1h(7UxS_V96U#>av()WV?W^?CI} zODGo`Fhl5s7@^InP=CG)U}B z7qhv%X2&%gjVoRv_Tj8tVA;~w(%bJy_gx~rk3aL7sc}>l8@s;p4&Bq3pk{7YqPvu1 z^fhea622Jih(2&;@ib@k9A`fy==%Vr(ac5)v(89E=fW&c^>{)fUp|q^c5sYqoId4I zzR34S4jK?&jMeK}9K&;;)wX$Z9Lo+YC86CuPY|kQG^eY}?qwRyxR_6k0c$C5%@G7E zJbkzL%unGW>q&3GfKa(Gw78LWM+k(l2xDl*=SB0iw;w(puj`KM%72uQ2M^kXhXV_^ zIUr^gxmBxgqe7?1PMgD5R0%5Q+}g%Y)vBStq=EGu2$iJBM29fthy+dzDN6@Q4dTcx z1lfO1ovs!|())tMYY7?0VMStuCLtdsn0L2J`O3jkjI3ptd`ftr^|HOYcdZ}Zy$2f) zy-qno@hEcCbT6jBIdEXdGnS(U=@sn7i@SU_4Chc`HKcr2m}Q`Ba~`}W_P*vF!{Vuu zCi2zcK1kRsU4Dt&GMRA@RV(cp0$f!ZQDT;U0RFNZx31V`f z2OMqex)m@8)!+ezdo*mc8{*U4t97bDZyyn12~r2`D%tmmf_qGwYTz(}y9rm|Q!3l_| z$j3>Gz49llaie0c$osOUNMT3v{5WJI%-zWAD@`r_cpBGbrEr~TW;DomlM%txrN|S> z(MHp1V#_HMg}RLu!p2Hv1S{1b=+qZ4}dMh7^L9E#Dq&f+&Lq6OM@6` zJ!Bs<11seN^>=1UQk_mh)W_I%pbIuy%Bj zb{g3RPQ$f1eIQ_K)|(?LBSSjwd^mPeS49KB$^f0;k#gw9_$G9EMj>MX;LH&HF zalOCIUC(|zU_=kPUt~3dPkx@d$ep^h&C07JvbH8VG?hSd*?4-O@hEe-yV^aFuaa>L z);`+M9+FE?w<(n*&=i*Ht$oqWZi9Ilx+5)<r#-ZK)}%f=~cm3Q?=N787*JeByWK5R1MFsiY6<;R>=6;wp@@?q<$1e0%#`UIh~U8e zS%}s3)p6C^(m{QVt|nzzFnaA~=$CVxY@^iIe)2!*UkF2_&A01MLJ7V7*dm#`XL;Wg zTW}A!z1FhjSF*J>coT`K*=KoSYgaTUgI4jv;<9@c{(jYK zt}H}Z5EVXT@WW)<%sT?WDjA>pwdEthbGF5^b^ExnU8==dUh*Bu1+Y7w!T#k3Q*@O zI%|Crg5n}uic*9>rA(PG-ucXtiIBX18Ww-j1mZukFtjibRfjj=~MFAmUGnzpSUJo%4gqOSj9t9ct4<|1!=Qei|_W7f$31@4=oz^m7IB=PO% zmTv}r9ckN)(d|ujChQ~xnV?X2!mv>tZy85s(PyZa1BHT)9)s>k@ z>j^Iy>s4^Ecn(fWk)$mI!DP*8pGHptO#*y~b$ zxor(?qWI`FbKdmVxtQi#@0oT%Gp!5vy8Hhbz}e9=SAfKNZolZ#)-hx69NeL5woU7XHM}tU>i&XI{l(ey~TnUMcsyNE4thT0*7^ z*DD+XF;|}`P&7Yc>4|XTjn##XU(b#A%Ufu!^u5jT7ab z6&_nWZ;^c-Vx8K+ko2DFMCJo@-Szpoo$Q)HO-=<@4?!XigT!_eroTCCo5H&_9Me)t zEaZwoBdekDh%m@p>O{}oV!8gGoz>Zn+T;oOUi%4gMo4i)*C~G&(H`zAEc9~mMsUMR zOo?;xh8KN>8obL^d(db?7tWAY1GyK(`54|ceinRIjc2e;Jv<-_b}FBNW-vd-)G3*8 zSm&qXqbq}Rnb0{F4wn^l4Sn7pKgz-NrC30IxDJ@;Pc$Bb6_K4K&Xf{M^K=Jwx@=ec zf?gY^>qGhmxqFKV&v*i`jHnZl)WS@Zdz9L{^SvgVtFIc% zx|Il0(t1;Lc;NJs7s?61)LCiMY7xay$l}AI@Bt{i%L|W%*)s%Oq`+=xgKWIPG!z`PEw|%nQ31eo^bAPeKX&qup^Ln^8D8B8R#Fl!R)M}tyCa)Vd@Z5+Rv%iti ztTmA^R==cqwX0&-$7^Nc>_i!?kYDjxq^+0Mo?-Vg-bUQ|X_2f(uj&#s z@-girVz16>rc?}8^k9Do*;|z7=ES>9fewfng&BD41~^dL%nfmJ%jB?^0{0L($4cXE zHZN0%b*j`V-8|vuc)*=+kOTAvb~dxw;m5wDlvN+s^-Bo5KtGH>c(pYSEcGu1*!mx8 zSMlONuwrh0XY4IfME35wfl>Hx7g#-L!m`8v`p-sLn{73STFSJ~}i)^PF(SF_agRz|2GsVK0BmdZe{#im2+4TQo|L3R|O&ME6raCq61YGj9NFEN&f-LaLv>OB`%lqob|5iiuEvtD>VQ$OSYLxFD^ zXjS|=Qw?0QRC0Hp_;VOw(RW7|br4><@xE?nNz~5j5ey~%7)Ey6`j^#t&2e4pB&-7_Z+bW;P6y&PeMGw??!7l z`<44wy+-{X-evd762wL=<0uhA%mh<2&oR2X$*NS*$?LfUl;=}6$9-oZNv9OD(Auyc*nIBm=pg&x?n#8P}meYdQ# zhVL#;nTd{sJik>(^b=E?F<2u50=A8jmk7t%C9#|4cAla62c%eD z3Zo!rHPEEVesZCm_Qa{B*HB3Jelx3G#RLN*6Ca`*Q*iB0oNepg!enj}hqhMMs9z4t z^P{<8STm<&7(77;I9G#QJXmtQZ~xulhmDkwQ+}&}il(Ou0zZiPWsMvtPifVOKl_zQ zz%lPsU>_p1oc>iPYY!( z4$10r#uF!AoZ|?4#~@Z%HL390!IGNOKwR0|ATsZ{dzAWvjKhITsMl+RlU=?h9KveCW{1zasJh+QHn?iTTzZ ziZLbVy9^=4d80{s&Zl@aKAzwD5#F;el0$Z$5Zk~T4|N6Y8ePuyJ(_4#y4HN(+i|R6 zZNlRvR;&Qy@dV#Cd&%3DN(`-@-6C>i(MEih*3IgMKCYnl zK53*tWDa&s1Rtjevos_o!SZY<$JfY04}G+exzKP;j!-yu$KKmB0B2kpR|n!b$xWBT zzAeOj|0JzbhH=4Jx3_7^0-@6!e5$O63ow5L-){pDW<4L8yput+cn#&)qbFj020TGj zVN8WUg3owgmKpJ?yVMc?YDbC=JE&9vlUpUVWHPcyl`R2wgB3tD0y0HoMJxRVeS?R- zS6y*PP!@XWn-Sz*zD^OX_nRyG!7ILYfE5%DTl>PRy|BU$+UY>>6HM^g@;)?Psz(=b zHD>*;F9Vli;={KIPvG>|51Nh`l($?mH}?>+PQtyb7_h4GyED{Q*nB5*z7FP@TUu%b zK8U){OmttZHee$%Ta+6aLE!^dc-AQl!i2id8HzWUi845h6bYJJ9=Sj2N{edwcHJ1- zm%7rbWn04qg_a|qJ{^7m(CDT+1AZCW%hS+jIV!?Aj)S>=8AD#nwdoWcgD>xVmV+cx z4a(t}-M-@$T=kv9O0XjOv-MM56#2}!K%`C3=cg3bjI33lzOCt}mb&4d!#rIEo z0Qlvw(p+mPU49Ad066JvB7*Y0Tqd7HU3*+`V-MLg){#ssoo@{yK>WI(X?`vd^ijX?M0 z5`V?Yq6$SUDkxgClDwPJiRtspD5CmhWq9W9hF*P`U7S;v-B>u<$QCHCUUPeiQ`%9q zrtqVEDS zd83J6(phEfkgQZ~7;ia?3k3v%w%R_##Qcx`-(>aR=)czt{!>BWzwyW6G{Wkn=4=AD zJ&RKW-Blf5;R(I*otS3j@>lIU-~iXihFfX_buWa*8Ttsh>k@q_bFa=SPsvI-eWcOb zePiKjB%LP|qE@i`=ln7{Fe8VGrO=bqH?*DA@i|yn+=z6QH**;%q*Fs^lY$X4GYyO1II@f>whY(frKeENPg_f4v zzzXGHwYa6Cf3oQb8^3g1_(G>{OrjBgy039{zTBSAMIT34ywQcaoK^*Sga`xS(%8J`dcnpJInA) zL&O^2wcQ|1Ro4=H@tD!tVt+?xT1!-X?FcyaQ}XB4DP2PC<;y7&QWEjBCss}pd3lfD zs)*P#_m#T3+jsmly}>8}+C1R-?|Sqe@0$beiaOxFb4$^wnh-bVgWEabqMwt2y9?-t z9?-#g;@i};rV1uuV{KLweWf6;@!>Q__DI2xK3Ne7WJ9}R1uuJJIlH5-1%rqZ0)s;N zrA7?dyutfcZy0I0*$a7%-uFjZJh3rj4u9LAkwGvRtf{8 z5WKDF))0cuO>MCaeu-|VZW=GpZRB}%FnK3+ra#so%uJ~*NIg5vpoeg9}aw8u(LMU=Vc{O}7dHpaoF{;wPEy59QpG&fH)mJu%^>!ahWW1`b4 zSLGCtd9q8*_gSsm%_=LW@HwaCZ1`ma10(A(<9gE zKdP~#J12vxY{vUkP75NG&RYKDmHK$ASCp|4J9Fz%ZvE)mk9oLyu4aTvOhtY-gMNLy zlBAo4%h*>P@BW*!>ZUVVlu_CqRae;$GB~$CpJ)zy%S>#V&rF_b#D|!j69__7Ek4I zUpjlXJJz!`^2Ki+B+|0rtMIaBKHGqY6zN|K>5n(zb zw#;{0RpR8%nvoVYJo?w9$H%Xr$&E~y7{B0v|E1|uMm?YO$xN85JBg)y{)J~LxZ6eH z_{4`JHk+2b7TK!zQOFZ;;>4`#hZVB!I#9!G;6x1_XID&>ys68Xf29;1S+t}xuqdp} zZkn&&Ik*B>ZN+n67JL6~sBA!kmFu@CgwHH(Ug)9w?`sae7!4k;+@~`fem%ip=d2H5yK8dxF)-xl9z)z%j!Y)9YRSpIdJg7j@( zo2=CFb>3y;CVeuyBi~vlE=r+=?Uxemjx*d5c2j7x>_2V8;+9Dr zujEI0TiuZ;-%qwH^+jmB>Ac-&Q|ZId%B7ZQQB>zy|0~v=Zjz~7R<9Pfay{%%s0uf_ zpNbQZ^`R-<@X;AkPl|7t!6hO=DCzR(&8Kgp3;zzHyXhK|QbU@bB_&g}egohyRXQjB zhQaU8YQgrK{T*!o4z~a7f&YD=y*lvm<3~Yi;0X`oe`Fs2$6@|=FtFd@nJc$zNbDbI zq)xbs4RV?I&k5=8oXff&v*B3+uhcQ`2kvPm9=_Juw0p~UG>RVtNPIx-5 z?O2*BUKp3(>GgNjJCrulR_1^#iDo7lJ^C{4{n;7ACT-{i6jX9Z;;%$k1&FWvu{9-G zDC5<+=CO%HW25lJUW*tTfT+y@gbiOm@7TCJh3a6EnC!_^h#U4}>fUFsmBbI8Yh$By{WTJgh5=k?GBnztRzB$>1__e* zxyvV$XJQ7TT#YeGdO?bya+MqnA&!4ByvxFHDe|JkPwOh@t(1DIMfG-=p)Ov-xg{r2j23mg!p8zb`BCvqQi+FVM!9@1<+4OS9LUU|7}ZiJ?l^uu18mf!bJ*B9{wI`oMGkC zU4ZP^S)PoW8|ni|Bq&R0D_TnQ4Wf5kg8GcV@viYCBmC|}J*iu+tk6)_0Mp~QBF&G> z1{R{=ZTX1rp5VETz#Pu2mOQH`2iwm{?DUd+99LD4863I3kS*2WOwDl)+;N1x2hb~(5|oAVeZ!&K>2=E{ zUZ=?&f4_YgbZ<~a!%%(bYoKT(KD%HCslz3OBA;{_Gntb@{!-0h)*?G%evctxx>-YR5vR*tlzY%W#n9u-w4@U~K7J^8g*$A%DjvNhEpt|#A@M&mR zJ)^dL$~iJ2H#+{MXi7X++P95kWl@L<0w$Z;vJJroN$i1eaKbEH0wE9m>l|v{w!`N; zMenO>-JFzqxJnjpuH+2@wiDn$DjH{Ut&UtYha-C=e3HHft1^l{yEk(gP!~~nS2%TpoMOZ?<(s;qNYT~^p=Don|iOeyGBWQ4UAGw#UkOF zTPtcb?gJn^{AF`I3fA=xp_edl*K%G{6I#e{$gTsQ!-BvntzTebq za^jUd7CVVnoPoDxA0aV~ver4jhueSI^wj3SeyKnmN= zDa%>Ga_>_VUsPMl9}%4Q_@R=&Y|Kt9(iiy$`TI|*e_1*E3VOvqnnJXt>iyvs*MHqm z6Z$`t{ZmZ-k8I;8j0i^fU8W$8oM;7LLx$W-;!;18v|m=A{u(8Ac;f4d~HaA~k$(3xcJrA*NsiUegx;YxVXW?G_SgG{^HCwYNJr zFL9vM8g;;6W=b$YTHP}YsYTiU)_m>0KTKgyLf;1>W>blN&}FX;CS-#jF#FZh74Gk~ zX;6VdTJ{&d|R;~ISTB1FLdDlatU5(+ohzMQG7Atfax{JG_PIgp-}?@d|v z71f5lpRZE8zLx%UtBU326{3o9p$njC0ZQ$rvB zw6LzXeqTeX5a2dJ00vq0b3MbFTFwWz1M29%*HruOZ~`p>j%P2Jyf*nyS$j}r9$H9dXT!e27e#pKWGorNz>Z%UH)9e-qP3X+Up zH~IC*%lZixwI-&(CR;wsIC8!G8~UDn=5(V7oPx_9CVNd7aQsktJW|*&aS?mN?YvpM z&A?}wn`Of~AZ8VfjQ1{?{Ng&|_%y|Uo-#EKO%6-reTS6|)>i*QumHBAVr8$g;tlc9 zNwfTIhWD3mc;6h00Fu5peL(mg)$!DrU9rpU$*QwbE6|F!qK01eGiFmNuUVS{WaQ=W z9=k+pZHP;cCY#~?dKJsTtTr33Hm`&^l7%P%VEQPu@@@PI9|&b!$o-}p(@pRS(-qg! zfhNN^LHG~g_Sv6S;}aEfVd==GTUdE8)FGrwiW$qg!5%zM)mWr`AgDV-!$PH>)Hz|6 z?u#vkHXMPYH(fKOy-jJ@j5r!~KHL#y0Yb^$11mh?oC~F5D91uiIF!}wSSnZ94$VmL zCWt1C=>9U2r(+^3LkXe~IT}x~`%bg>PCAqQc-kbq&?8G6UAD9naI{aNpU>u^;US|c$A|fd zS$WLB^^XiwzYf1-W$0NjrIPx@sAT(bF?{iCzSswGO6U_U@;Z><7QUyOHWeUC;=MkF z@*QyP@CT{~sqbm7r-k=A(Fa}813T9|{RTp4FR@(xtMs+Ou=9Vtcmj@UrD~g1Qb9}j z#OATBf4f4Q90B|rkX-v67}J0KHI#2cJh2|0Sl^Fv*X-STKD<68AI%fVAhCfeiD!$>ObJ<+w{yutPA# zhvLsyVu3q#K>t30S^}b>;N@4}6ggk7pV5i@^q?^SOtr093f|JmyL8BO z4I_prg0zzi!m&uR1E4O;gy|^v#lebe5dCsa3dT$p0^~vi-Z0~-&W;Dc?zi2`QNan! zsn2YCBq(8~q)}^3(s(m*FfXwo@2P#Z$(qt$eKc}UPd+EgpFgeqGYQRNhP**rJ%YDe zo|)nYtL&jtfPors4zz_Pk?MA}x37N}&*n`OCn*X$+oK}sqNK6M_&TMAz4!S_%%;_* zl?8(&J?;DmD($d(SPd%J5Zwx$x0MuP+-$vm$>^Y^AQ{NU-MdiXTdtM<1V@To z!;?xXc0%LA)?(Y%-k@+`Fj(#cusAS^)n2_Wv5bItg<_WSLFl-tKD2U-T*pcC0U<>$ zWG}-Dmh45YMJfpvWj89qZ($x97#+7XFN3U<@gKkqA=(Zg1V&5p`O1#Xe3mqSu(?XU ziWuD-wH4X6_Ltq*_IU0>F2e4Y-E&hP(Ma+YKy|>@@`gMQ(_3QqFb4d0+Icu-I z_gcU8TWeo)cXd$uTH|X41qCI?6URLj6uyKjD12eKWfSntizIjlaQg!7>2OS;imf#P zeE2HtsLN3Wh3Xu|$g} zd_nXso5B~q5^$DH)IVREt7Jj+nXMf}8N%%IB`rX*_S06_KIN~k94mI0E4uwta~dmg z<-_wG__dXhwG#Kw9qC~=Xg?2|!jWGA!&aDelX|O}0RLRjI|yh@K|%4N_)FmCyCq8% zxIH;`hQ9vo_h0hYwe-g~pZi{UlK2Ji(da9tB5*si>)(Ae(1R}(bmBWn4mlPk2nM|> zpNZ17O8nyh4c^oM-QXlymPtCFr|+<&A(|OgVD0D;#*h_O9*2LuD}0FHG?VajDfRnE zo~t$X-@Fqr=Ee72wJLWFi-9D|cGrr%jP9KGM9uff zqa~xw%_We+oejS<4!2mI9^nS0b!-~dOkjG8nZ>>knk*@6(28rkgtNF^FcNU~yB=pg zcr%N)#C6TG@Zn7Bt}U|W{R^MxTKL5hJX{Ug`{TPNU%u477dCoFYVOlaIXnZCaAQhR z&y(mpH#=|{?$9D}O-95+Vc4PJy}~oo-a@U)puNJN^PSDHSvsp(LwYuA5MsdIpxF$A z+&s}o5$8uC$rskIpiprop(rt+dAREr`rOqioPPAPD$SvXO%asNnlR3ApVq=VeA2YA z^M2GO(z3EG3caB+o|KCIK%Y!JCg(0j5m63kSOD*rJd3`B+?$qCh-;_V-FiQ%zJUn8 z65CbKrS=TXMYNsAQ!u?^5WG%8(6vcTi46FfST36_h!j-BC@1tT+jU}<_X`m*Yg`iZ zn<_N4C_N)6N4`}`A>KT^r!>wx%B{!>LtZPhR(PTTz9yzxxRnj8Icj98S_~>3TE#)$g{V$kkBls@nOMGJLUG zEyo`>P{K?{nrql*&3PELJTy=r|FKkbC5S8jnA-Nx8O0UdL|4qZXusanDmlR|6-438 zEQPr!GLx27wN+}(edsdbA<-J=*Pm7_o`+RGC_KDJaDo;*f>F3od+sE&-ICX>w|mgl za;o85t&FHe63(?+`E<~pwx@jF==|_~2b34Gs0dNl_vl@gfC>c(J(ju$2aQ?QUifac zz}%eL>R-UeIc;9>ANe8B^s&Ps_Lp%NB{~t|$~DY$Yp1vt>Qe3MNN9EdXD@7JMOK#j z(Ab+gALS$Tw26!f@9-GeBl$|f*Ss7v>Mg!%9#1MdHa+Qr2%qX^LcEF2os-6IR8DJf z2koaF(B~Sti|#18@wVFI4|Yh^$8UvCorSSf;sd5ng@N2^Jx)shscFdMa43wp7S2v*4w8F!MFN&aiSQzkhd(!@yuIGfZD@JHH2xAzv$F5e zuHS+zx`O&Z91nJn*+5P|I7*OGW&r~m^msv6&+h21G0QmZF?eXStuh-XnHQGc3*w*e zDP_3XCuQ0Sz`TSP;ZVoIv=DAHd#x*eiazP9K}vN6UMp4awfaJ#r^7KT&tl8SZw`Rg zJ5O0kQ<|em`GLq9O_WBUs5I>aG@~S~7*y-0^DC$vSc;;wyekzSDq>}Iptu=BWHV{zeSQEthvDlvt>kaJ@v{Y+BKu{aZ&o{_?yO;va%_B zn%zpR@dKB_Mm1hU6xe3mx?!ICD0MXu;Vh(`-_3PfDQ7AuM2_uYI$53_3XYdl;V<54 z_%@8#jaze%+mNs=r?{cJ^cT?^kmWn2XDCsq`Q}KA+KDlLu4^-I@AQL}@Tlz5!W-h! zhxu)e)02mLf*(326SN{<>|3M8R~1vFBrSVlE5pB3obL&?0gHie3)!;O<>)oqHqRhgrqA7?&1n@3ue{%c(V{d)D8TN zVL28fjf+Z_zPd^MwUZG#3QzhUt1|aDi}DQz>AjiF)KLm6qECxY zzunX@S&2Npwg}i@-CDX^hsM@`&8M zmhkJ+wj~9Bjj(8X9D`5(A^7~ERf;3OBt;;&ulaB_`5t=lMz_3Pw89{6JfY2fB)ZF( zT{;a1(O|{kICYwBe1p@0hN|vh!;$V&j=48u2c#S5rxI5C^8Lvfl?@&#L`gwYWJ6Zo zq3hol(Ihl!Prn>jHDdeF`ZBuS_>0@4+Wh$V_}bQA#UeXQ+z{REijB`mNAs$DvT!Xd zWGI;Xh+hNF)1(V9#G=Ib08*H5ww-ePzRIL$DiH>yc* z=F1kYKQ+5*!T(Tc=TcR@>jc_J7#r}V(UlZ$c}0&tyy5T)*|15j_f~Nvt$&VqH;x@| z$S{y&R}GAUa#p$2(Z?Blh)QkDTGEEHdRa;~c@$s!a#08Ub*)sXlhtNzxuL25F*l~r zZTcqozD^R)#Y^&?Xm*h2AmqZ)s}p2m9j*VdJ)crIiC&f37sztT>cd>^g}IsVcy_^7 zZJ`$7W`A| zabA|Kwq&YWt~_uP2;&Rm4Ofk{PJ_pG9_AehoyKDiZLZwPu?!NuF};7IU+#Y|HADP- zhs|KPXN)Ehz51-4S~gsk1GSmRmD~U*cjKC%Es7NVBXeo%Z3>#Bm;B|<$zK? z&BquC%jO|H{wwKZu6sin9Ocj-8U;7`QFzZMXRqf5Rblb@s+t+~_Q;KI2Qk8Hz0ygg zga~@d&Yg@WDNpw{l%@PKvHU)CR~%dXAg5cc{wZ%BI&3#@%GjD{)$DT`=NIkX_>zt# zh%ckvn_Mn*sUgX|T0YYfd-6ihIiem(Ggu%z?AYlfCN#Pe=+!>w$AguhqdR3|Z)%Cs zL-~2xHI>;z76T#6)by@ukOZ^fV)h7vxWV5e$|GJIPM*oH@i;P>%UuMsj!Uy?brvol zA;C?sCl^+fyu2vx^ksvC6q}M(M2`W_NP}w+Ec&-dVFhZn;;Ti~*6_kcK2pa!t+;^# zDrSJMX7k|Hw@7?S6Yes5JWMEizr&scGQSd-TomDDRfks2&trX{Ul}eRUNcDJ^$08T z<(T0lX?6d2SYc53@)Td*;fShGe{#yV1F+QTFEbr!7}$r6I!-wXq_GeqX6%Jb3p z7^_yD@6Edh*JjsNXA8l<6_5f;b=U_a&ZUB?hAh&k>^&jJj_u}_`qAq+4OGN5gOJsz zoK>de#DOd}Po8BXAN;QNwJe$;3q5jG{xz$t-X$Qcv}gZSxnLP5QWW5?iO(*gM7Yud9MA4hA%N`Dg{&|?@2@oV?3iT zpoKrN!ZV$oBrOD2{2&^Muu*N1RKFBEOtrOg{jFFX^#Vi;oRks^&2Jt~ z9_rCLuKgNYuBYD2wv1`+*vmhXu+ntlMrnRf_QYehzVME7XFxo}@<-OUGvwrSMmEei ziI>bdjc!ZKA3330d+TC5*&+0hDsNn`(#&jMJRD*YP)X!~N0#Z5!xSl-LktkyP3lZ0 zirdgCk>uYCyzSAIkW8&cLGf~7&9K|}NtT&9$5(gP&@x80gG_>|NRMzYd`T#1$;#{6 zD9H3s{moc8;2cIQ`tfy4ousoEegSFT(6@N(WFsZ*9VkBN&6VgfcDJR^__Wf|AQ=<8 zCEG9O${hOJ_`y9aZyiR&*4js|eQdE-b2iNaGJ9Aw$@x(VYqD`FPSRDB(llLv2VcuV z+1I~s%Ha24^ot?|?@g&y)P?H)rK@oCH{Rx)08?@gRaC+MO^eKS3 zt}90!P=mu7U*jyPy^&R4G?f-p`}JfpYGM^wRoaRDLGl4J?jA2m%5T$rOYq|+9@TBw zS|UEB8DI0|;mgSEGiUvKe`-qie!ck|;z{DrhGp^_kY-+O)c5C31ve}|{A$$>YH(Up zdQCM=?b6bg|41zR_b`~Y+pm4rMJ(9{#J@t;DAVcclV*0s=*U}2bL3S3eWu}rb?jp>5cYDnC((k_7xBt zK|YE1caqfyTS;MQuq8ipX{*}O%fsw?=6R38)u2f^M%!Emh~tQMKoDHAJgfKcVjCeG z#ubE-0&R<3B7!%itomo=xf>6^RdEo{1JUH3Xiva4}$l)1%(BCOP0^jzWvmN(8;>~hY$T`%;B&y8X%!h zN_2LHP+vDbt6^7YRHH?h^k$4$tqPw};-jmmeWWhipX$z8((hI&qlXaan9?? z`tmj0Fp()7SjbEwz>c^Av9lZowzj(Ura*_BU7T77%z|<+29QHoNxPofmTagC%v#z@pArrPPt$N2 zs(4WrbsV&%eplGGLK-tNH}g(EL4|`)VxtXDYQxGZEk98FZw5fhT!mcw`orm3j|A7n zamKaniH+>>>X+4a>~=4l@^0MkN~Fw=wj7()#QmFbfBWs1DPj|O+8khG}GrA?$7ZU zqi6L$ZtjwG3-Q(Qga&FQQHQsx@ftj~d?j##&|m7Be()BVB0hAIcUp2%>u)Rh7~UD$ z^r4j0J9u|zw7cxZZp%iBlKbpytFcmqiwHybJC}p$(JdT|cU`NnN&5lzo+#_pWoZ>NGV*K`-zH|*$W0J&aJ%Xi(!HvL zR6}X{=^>NuB{{#qQQLN3=*r5z@zZ>+?zl6KS_Gnf3vSkm>^*#Jnma6nN#D&v|`Y%8}2ulG%v>{`*&WoI2-*6lB&y>E{gf5QQDlVKcYPgaqc>bQkW^k^^zwi z+11dG$TdBVEdFsf3Go}P`8o~ktNQ4*zm-W21qxe>ztLTPs8}@l2FBmgbkepAl2hrU zoAVKP+hgT*lk>Rgi=3{jQ8l|>ZMnboU_YG)$74VnuE`mR;?djRwkCqkGb>9^ptPTt zMKjwZ86{}-8c)_Y-ipQiLvU%+2Y=5aBzZxy=hESxs+ioyIjxB(V3%bJ0F9qBFbbt_ z?>krfl-$=9imDmKLHj~Nzl&auFfxOKv?AP*!`Zijb)@vbiRCKUBtb6hjWLwP7k0&H z{LtjqbeY+62T2lp;t9RVZ9M*IVgNsj?aj3ihExLcXN&`&jD_HkQoVZe?vtANO1)(9 z2YaGcB>i%y1CAWjQjRg4-^FaNo}&kr7*K)=AXm$9u?7Vq(%@6F-I@Tj51*H$=-#G3+UmdJjL8BMy|>~kAin7+-S zM>mij3&+4|7Q2k42DcM+`L|}TuokRpn&-Gy!*T4_^%Hp+`EOO1dy>){m$;@~NeEDY6_Y+ON#Ak3nh`_sa7oRobU3ueAcqx)Zgors?7cc2e)>MMI+!zhsUfaa-T#~^>nubtC4w66CbpFMt+d7?22tCYB-{?nYs8|E zr>DqN-2RTZwk(u(LwsS3@PWlpDr)Jc0l2Rs z_AW|yNY=BcW^!kv>Pg*y^TZ=2EAgq(<3Hp-4M4l_;XO_>7$i_|Gbvx#OK>(r+aM!hXrDrR% zJiRvH5}-;M&>ees8iicWJOt+uZAR!xQjvV9ENK&|SgXPQ2y7d9!^BSMp2{QJFV3$yU`Kx(YOCCXB^ENnaQdmAWGB zoh(R@%@0Id>S8~33zlY^o2%+^BAc~k32K{4%My^hd1NOZvNuIG!JFlx^6BcX7zRRP zoFppF48Pyn4rEy# zt#C&+vFuOGs+^`7>e$+?a1qnqUJH56%OL~9C_-jGRyLkzIc;^@eJE4aBx+7A{~&KV zbDV5%m~CD_4C2ZOM4_I&yhqgv{7)xSyn7N;zu{T z1(!)`eTK-nMwyS;Bl69cc&+Uj zaBiV8Ly)N?`Ki2WGoN?08r{*h_7&z&R}Sm3I>w>c_x zX<%gdMUB|gaYI;eFxAiL(oI>bR^MPr>wvdS#v}%p<%29v0XYb9Vh8~=Zus)iCunKA z+s{*`z5}l>eO-?{{_#aSLdj6@KJoIX0fRls(=%F%FzVHB6sb7v%Wg(`8P%@MG|tws zk~G5jB}ogE81H{Ib7AVVy=A4LZ9lk(p{2*h5@$6h{cm8}_0O>G#>ZBOALPHJ4Zrez zW7T>h7>|U<`4bgF{55m4B>0hKK`9nw2xP^a=R4^Sph3adR4kafkm61D>1?S!bXyd{ z^oqvDRg9v+2`wKI$!Ai9ThCY)9zWfP=yi3=OYWG24;{zJ|QkJG4SxSl6`947} zBCQFE#!s!K+bq9pm>)@5TUo$!090}^jR)Wq2TQ(zeRd^3W#nnq5`M0GS@;2pO#pC+ zxM0rilY;}cRph31*mQgI4xO{@on<2Af z0L!etr$wq}=thJN0W7)6fW(lxw9eP#6AjsRG!r=e&3p`^<)Iiz-jvU3?JSKq>;Gs# zHd(srWrtxbajh##=_$v8c%ateMB&vw%jQy-XE9 zc=~GTEMKjyc$9%8nVzM_4)!Ai=aeILPwA6fh_Uk6ZDtTNE3@s_aIbbrrA4MlKH1OM zo@HhQD|L`z0TA@4=t=9_@FCGezR0|2bkT9pRH|VszIAEtkZbrS!v@-*u7cOcAHW+8<*tl5#wSt7^<9%1N*HD~_AJ1ho;6X11&!T2 zfFoZ08AEZ#k>x>M5OtJ1_jO)SnVzkEV6eNh#x7NBYZvO)?(%J!%s;M~W{1JNJqD03 zmD{IpH4}3wg+Y^qPDBp?UecNYSP(qsRl?UIy z`({o1mXhKJ*QnF4&x9bdddF|uI-<>v*+0A5;sx7F5;u;l%rH}aL;AnE{^o_6rilv6 znd>GnR&rssgZD*g99vlT_>Jv}C2()WtlozIe)+(d?RxeaI&yE)%S1>~5=4(0^Yf(= ze#$6F>b;venVM&-u3mmz?-=TLq{T-={`kU=xXg@xA3h(b{b2aV4ZG#VoxOB3bl>F7 z3eM7L)a-&9O;%xmSdUu(BoRapJ!l3u`#FO8F|PIMz8l@G&HmBP1(5t7RIt3V#MIFe zGc#AwEX~|z<>PD3oh5amSu}eiKYpzpdiE-{`^LpZ32`lDWd>n?uPz3mxvj8Q*X(!X ziSO`gyEgD?TP2B?nch$c%?JO-6q(>p48{(Rc+q^1#VmW10>No2);JiusBwua1`taR~lv&Fn+YtC`bgdq=Pf zLv_PtSg88XuM)jU_mc1$H)X-i6&y7`A_9>uvNqvOc!Q7JW_mUWC`nBVbgvXP+PUU@ z-X9k~MRh=U1H4UrH-C&Az*?Q!z?_V2f0Z?(Zf0ola{|tpJE-q8i%uH|+2_sTI9BZc zOLb?y|J#~>aEO$AHS0I$jE=SdbT)Gm8{6N~Kg0T4kosp>)f)bqOkfyw7R;zeHa6e* zJC*#V|Wx^nDSw1&7l}5plmC=UE=pFaql|k!psD!F867NL>z?w@%2IYdKKoxKBr6K zUyr<#NrZTZ939j!9aDo{o1Y$7udTBH>KljYDYDq);dL^hXUXWVk*vj zrA}ex=tjHTKEc1mG7AL?;QDlQH~j4-9LqbyFW77)f24j-Q^E(-?A5!@R{h zS0Lc$*bLXn-ji}H(pDxBG|tBh;)ZX6=?tKxL=$!^)akuJ?>Z*{_|eCvVXO<{d%#$J zcIZprc5B=1#nq)*%k|j1$9dw-QvPt-d|%i=?-+mqhI+_;lk?VUG{vJ#CqWxx|4Dwy zf*D*EKW2J{zFtlG%42@l>k$4SGtXmsQxcBNNRgYpRL$h^&a?Do?|abR8{(wJ5AV7W z=3i&2uGgqS+MsjS@=oj)&@`kf_nc4KO)q=j&tSwZ1F=jMXMSDiv-HbFNv&}u;<=j# zn{^cbRI;ut;i121r%7$M&vjl3B>;pX)WQe_lkV3uY$McSB|XC zKeE(wmPV1th#|&i{>Fqa{}lY6R)_bW+V|7QQmy#%PsO}CLqDTZfhzo)hx*SlYr6gc z0jLKRE(Z@Q{UaNPXmKO9?EHV9EEIFnRweRq?SIr{=6{~%6=#bZ8XEe8pMENP9#mSd z`b`|&=pY6dz)z*iJ9MnxSgqq(pwM~JT=;z8QMBbcmbFS-U2YH{-Z#yi#K`8xjA<2pb2mR3y`4A zp@p~rM(&@^br>SBy$1}2*d7DN>%)hB?nU!XYCv|3zn?Ucda;G)cHrA6B1~J;v7r5` z6L*NSUNhPdfC#q0>{M&A(T@|4T`96O_ZT5=XmhmO8Zg!HL24*kLjNH3W)GjW!=@QN zG_HllyS*;&nXE@1xy1r54XGlnQbF9$b49hAP`76D5UKUnRRuZFOC^*71N|l&R#PA` zHMN;Gq;(K}5Z~Q0W@B2U$+*z*J`j?Quc&ssUP@A$(prrx*Sh7iBDy=A5>MzSevKV& z3sLGf%vq*(Cl!(E>bIMXpm`RpOHY7N#vVA~HWF=~P`r@8X5bFZCxLijsZq%-Cz@{* zeIAW8kZReFKBk&&tOhgD8~K5KrGqRoY(im_BC_AHpgRt zn&10uJyqJ%-@{qes#SUAlol{LNoyWnZQ3JD%w9N%Wxje?aIW(xv5B2YT`?$=yG-;I zn*-cD%%oJlN+`4J#;Oh(c<3}3xD4HAzbp0js2cW~#7!7SdCg|dwKdY%1!2`R^uXj2?1#!ZgI|z|?1h7Tf@s)%g!n5~E?Z%o2QrIK&&GAX*L!PM zwh<%JQWDW06=saq95oR-Q@rr}lOR0pTjmgK-zFfF^T|-d)bCj0w!b2C*x$J&D z%DZ|bUPDkDU)7)e@Q4E{QYW8ZKm0svPJ%eI>DD~XL8>MvYUHJfi4n4bSSwL_#siLNDdo0v;UPjt`}JDatfibj#Aw70vx$N z*4%~dae7=t@QOV^7_c^)-71x^(d5DXqS{*+*9lB7u)=^G&i~SpgnvlTB(f4G(sDV! zelie`HOI^gDRMJJhkHsNFZoMUpQ*lt$ur(-n->fjpQ0 zS7zUseP^L~tn>%m#jIl*`|=g@wWUC4r8y|G8lzDiVYxgl%wr!o|Pq}>I@ z;)1Vt!OoEGN~Dn%x;}gl7Y7E1k4!b&Tnp=K$;EtU0fGo=NB3tSa5=H zJZRjSXl5Ek6tv17m*u>aQQmyz(n8TNnAvv{Spsbv5)&eM+Z0we0)<3xvm7Ngdemzh zE~KrWr#T$-@qWa|1v*>X+M6F~e<9(sBl(Y=pZVv-8nEu;KOfAQ*(9T`lX&ZAcEU#s z_2Iy>`?KXgae}A1_tSx)0sGHqiuUC$?4R>)J;IEvd>^*>=SiebXOjNqaU@pa(9SzH ylIGygm$e2#j<`YpV45JU_48#7)R|TzTO0WBZB@kWRF(C8>}cnDyz1DQU;i6@s6EF3 literal 0 HcmV?d00001 diff --git a/docs/cugraph/source/images/conn_component.png b/docs/cugraph/source/images/conn_component.png new file mode 100644 index 0000000000000000000000000000000000000000..b7db09657c8dbe3f0ab8daa9e46b16ba1d52d836 GIT binary patch literal 13053 zcmeHuc{H0{*EiKxahDou3_%T@P*sFhLaCuzs#T??Dy21rnukcWb*I$qE?Prq+M0^0 zSq-HmR7unjr367@7K9`uzUY15>wTVo-sgGO_pSA=^{mfYS=M!~>pJJ!=j^?I`~3Dk zvDYk24)aR#a&T}QHZwK4&cU&dpM!%_`OrajOY3U)ZT3ITfa@j~Im&tgbL@-#9tKwp zI5;Yj_;y?lu&;UiOzi?VIQYNp{WvkuLN^YMv#DlA2G*fYOF0!u(zhO#Z&|J(^Cn^r z{_!eW}wm>Nm+S$`H~91BZ7h{gSMb&&z^O9o(8aQxoL{> zvu}Nr{PX#Qm?!Le@6rG5YZiZqsW~}}x4?JT=8*_>oC_FwIo;1`ZQ=36}2q+Y|eoD64Ka_hM^Z275 zZPgFrVRofJrhBNC+pKCMg0T29nIIOiVt^vezR-H+Wh4Z6@b>A$ur6Yx!$#Y zUn?5TFX|34LV;e}un!*Bo#W`Rp0Gt3)S*9MAkWFqPyAQLo4n=(cm>>=dcU+A-kfwX zUi&a+BfzMGwSG>ufSSyM?4F{$Ug&Cja|~a~z?QlR7j4ehzSUJu0rRsaw6O+*2{M#| zQmoh}e0!AO1hZ+fnwnp)>~O+A(Oka@dAiwdL-mY!(LPQJnZ`(pBkD4VwD{6?9O+cl zU<9?S!n%l+6j`hrpHstfI1q?>vf)wkrP2n*AN1ZA;mQbvZ0GI^@{xl%X%NUAA-(Y3 zZ9+?*K#@9O}((ac$N7gW8+io|3jb)}r)SkcI zPFeSst-5JklRou6AdS*6E_v3SdO%7i$p`%#s)lAq;*+ABYWVTNlTdZDsn6lPPs!$4 zAj;naCTx;}X}6-@FR9q)bCVnNPGy}3aBJI3qT_%BDoL8TCDGlf4Av6L)>86Dg^cUwAb$!UMwxc1IRM4MuyOuX45ye{S(Nl8c6i$X4rln`K%-W|SZ z=etEvhI*8Ai7->-B#w0CdpgKu%YQ(C+)ol?06cup8qa2jiFe?rGaEXWCzH87=2~qV zauE$+Z&Fy9{n5QZhnFnnIrftw@cM$ z6D<@TQ?$-&-W=T2S9fNO{*Hby@sks}aD8maa2Y?eefI8y+^GOyL>DF`;odoIRK3$$ z0XC!s>T)hP4@=eob{w6uTtPsBb`;^p;lsu2M;cTCuab(NH)(p+joN>NKZ}nMki&XT z&6A?nb@BFhdw&j!Z-a(g$FEij1d$+kj;j2UBnZ_jhQT&Ag zJ8Qe+F4J5Co0{PFic#_Cr-l_cv}WK#Cz5nJV+(|D5A4K)dsFq$O>;jUjf1~Sb!dGj zX4Yu(Oe&>MeHAwc+c5R+PTm$zsF)VeC4c8O#{PD?nhRf?pnpEV0me7n>ePY`L))!M zFfz@-pYM-9`$X9NW*@10E=rj)Ve*;Zb)ZHu9X-E(E>5FjqawrmX46~$IcJPS&#jqv z3lGBqrGB=zGyTZA0-jo2;(7yiRoZ$xj%osnseXDee)hG#U0!lcilahv2f`QLQeWMs@HR(r~u=NgpItUo377Rf)>HfE+@){ zo!2Vn2}j=$sO>p9@?44Rh<+P#q!=~%Z;_J*g?5Z4p zhsMV-JIRt>gIckoErImAGW5(E-%6|^QEd!95x!+;Aw>XHb>Jj^ET2}XNfF3V1P;-L zVu9waGI4L!oO-wKI?oe`4#NF%&f6gKP4TUF9kj4t8sA&4 zo|Zs(xtv3%LdbdMz&nkGJg+O=YM;P^vcalnlw6f(tpGUJWzq{@jriG*mFHN`tS-p0 zq+j9RR;->67U;0Eiwm13j))8N+}de)0G_D=?Wy#8NW9COi?rS%3gPDlPZ#CGwiS`So;9nx-skEbpEVoVesv3^WxTFd&~3%bJ<(bu z5F#^uA&{wix|(|Fd^l}rK0#C(BLD<9Iw_FjwHCon&bH>y-3MwJpF#23BUQfg>O<-k zDfQwV9|vlxtgH#q74r@z0$F|1Jsq8RLm!wbt_bY=awP<&HSf??NVuIowQpozuPdy@ zvNkpbyiv*)`b0Pk5Ln!!sE1zO+Hbt|Jc{aUAC#rDve}?k;Rx#j428`xPxg^2FYb=_ z`cq?xV1Z-d7ic~7tJuI8FWLS ztZL(2dt@@t#b=kC!phsH#d!Sfi``P`;FHtohQwgP2j?#NsoRc+;*)^YP#|#rXWs4} zjNjT!7yT(GQ^Cf_|oZX|iUsTB!mQm(daxn)WXnUjcaZ+6d} zu3`hDEXhDCt=a5(psA!%;`-Bz9z})8{p?U77}I7jba8D$N|g{e6Di*GAm!7Om@m?I zHgx?I3Q`MI0B$N5PEpGtU&B7zEAgQ}J^$B1#`sBqD*G>G%MP?Uy2%<{`o!15)5BbV z-WywAi*Sz_~D)3owQjj zc%@0OQSB4Ss0Gd}@1epazdDeRX+!(t`F&6xa`>U0^DCNCC-Akb_41}vL5fhwthRb6@s^GqrXc^;7B%ixb(jO!~A6=0!viJm2SIUR0m?-cvm)n(vhxlCsrh4+S{E zZ-YAJgNR?1Hm3p8h5=P=(Fna4 z=P{;G%NM<`i(i2GmXhE*#6fD+TgV6fh=|(Oi&8(BGa1fhU=u0HEDro@V2-!A-45-W zaWoj?aBNaR2Az0V;W;>XIFzAT=%0^ZW);HbYjfufz4N<=UkcBb&EIt12gbX_^>%7| zyQ&3%K3qKY@A=p4{%`!3jQ zTu1YHE=-u%&czFN|2Rv}OOmFr0QNQsa4OxPVjSx}btnB~ZPfGvrN z%XZbKOmX4wKY5$!6;Z7zzm?Z}OVQt0b}M;G9LaF#Z*~jl@&IdSM}7x77Q&L>)4@!4LX54q{+vMft|wI%Y<{+nG>(~{vZxuJz?LHs|KCd@_2`RaQh0;rVj zPg-XjPQZE}KeV2D$+3K}{!xRDiWNzRL^JcZCLOz@#-JR4O_@3;o+1wpS5yS5^&|aX zrp-$!7?*jlWY?QTH1Hbk-Y7y5Gh-@R=Eg)L_*9U?1ZYzS?1mh!jKzoLIY$;iVQAcQ zSeZ);vH&XG;nIFDc&N#VN26*d7 z7VvTuu{#9wA3A&jryOo7Dgc{z|0#AM_3Ayh5TivKsY9-JQL0hi_xkXzg%uYU6BqqT zRae2m)f%s{0SPafUp~1;Epf)d_%<-^BXf{7zVjQ{3miV}De4;9L5rR(46%8@0q^#e z_&!qn=Da03-U%sHxUMZ(JMtbJm_94y*`jN_qVds+)SQH@N0s(+3n5WrQsxtuafIQ+ zv^1T&r@im_>J%_>%%|hIw){QxBnXR!jjc?iORUY|+hl`vSm>jUU1Bo?t4v$6kjVLSI=7 zx>6W5RPp-O5}tS&YI!>7r~Y~3m^BCxC*jDD50e>slr(%Vo^emPtvX780s#jkTDN>j z#SS`k?o5M-%DGlZ{)$E~PQ{0RuxOz;#C+@&PgOtZ9l^ft zfz<(YcjgQ4L&sW(T$s;cOsh*dd{2oqyYfL|^#*u2!8zKiiqb`mPvBpA*-Zv~T6E5ZTVpaX4b||MNZNiN|lCt!KU5gg_j;FDr{Bd zkhv~;L_n`awq4DTQ^5D6yQl-<=MA$0@rmriv91hvxejB|eb0HDsyv8KQpQg2T-JI% zOwPJ9606LYxuAKlP~ArAhtbf^LxrcME8Xhm2Cj+-Njm7Ev zhYPc$Fw{oIts>&&v6etF_`L0)Q&R@nT1tf5p#;pOTkS{28g4c9n!Uveq}F}NA&E&{ z;H&B@z`|>9kbZCWiqiS$bS(GP^hF>EM!=3+AivGz>o~tUFW%934qk=d2GouF#22!* zbY?0=_s0*ec-b~s4ni5TVJiVQwl}=$5e}A8_&(tuU;B8u!?FysI#2WqKaC8n1De%% z$)#+_ZvpB{);}}Lx^~N!l^@S{Cy?7mn|5HAL-k3Nwan4_r17&6~Q_6}$t4dEF& zzwlw1%O1&@64LtLZq5H`M~`^%^96ii3GzZjkd04$ya*OFah?0Ash>IgZ6e%td>40Z z7o}|eo~b>LQa3(*lz*d?*+(qLHLbMwTyX16{?;|pmXB!)mf3tq*j+Zv&0>i03Q*_p z^BR6+2?q_2{G`?-7F^ukACUQ-D4GvQHT=U_j6v z7y5GqGjZArkWLigjPXv{`Y0(Nw%Wf1x+=HDgWN{4W+dTx1?Z}Qf?h6MI%I3KoD+lc zSXKHiKwG?|-=v*iJ1A~sv@8TY$}c7+meH9T!OPAv!SD)T7$R75+QLyDhS!f9WCWj=;GC7XuQpIYuo&_`Nra-Z0cDb!b8}-NPddX0 z5`JXtCn`mN(zI#ceg9h&G_Te_J4{5d-uy#tfsOunw-mS5a&yY>%@0snhoc5ouI($t ziFK))C95AvtbDMgKL5Lu@Uf59Al&bk)f|Ped~__6#-;hru-WGRbAkl2IVk(`RsqIw zNqWuYLDXv#=ZUn(oMW_%s22|A6@t66uGGwQ#_Eq1L9Nxd82h=++A|-YV^LZ=32srB zjFu&#P5zo z%YybGB1>^LKFSBZ2Ng^s2BkUn{tf5YmKPDHo^sRUz>`Q2C8bN3OGX2U1*^zP(dfO8i( zVhkz%oUv^AcDwmtU;tJ(zN7BN_@=6B+_UU`CMN*HfCVq_iOil2+`+H!9Cp>V6raO13i z>5{NXFNPN$PlZFUkmn0<<4odP$d-K{DDyb=P z{OrDnwt1`67bR`)5J6v?kD~iCV1s4O zH%{c4SB7j!?AemY)&nk|Y4Ve8U8JI;jc!sOn15X1sdRcWAL4zp7|IOSgNQ(us5he& zzz4i9xWcQQ(|QmA z%YBv)iy+G7T{;S3(wA@J(zG#EgHw6)){uOqo?mx|rQNukB$5Te#hmC*EQN-jXZf-z*UvPt}&&~40!!Ofzcd)yW z$MV12kH5@N3ZDHvfK4+Xu**!tXlN?ia~ld-S?YW7A~s5S;)n_b zhRJ#Jnb~+~J7Af$Er#(Y z7~1D>O@Qzblh;=2pFR6|_1C6vWhQ|u+0;~?Ditd9+e0#G#>}k=*zRlDtS?m3rQ1*q z$wQ32;^54xkjgP-`t!QIkMHxzH~)S`=rDVMb2dmL-|MK7>o$1nd20v=4tme9JLS!)b(B&aqs zA;j0=%X(!{)1W`lcrZ#^DB@#fUtH>z4|zv zB11^@#kLt}wFaEMJdH_roqDICjXn<He;F!5RLw^MdfFq&cI*6w7hrX1@TeRT7h?e@gx;mp?m%Dh8{gB#@ArXF5JMg1cac`j?xvMhixy87yEFb_#`!<@u%Lqe7`z*qk zRQxS4&V~@T+7$-hZY7abF8g3Gw64rO^6on)`t=fS*x}QID9e`t!`mlGqH3W@Q0`53 z?j~y-l`< z7e`mvIgV3Ia=X|Dn-9tIYKlKlr3Yhdj3ikJ8jN18Su2%}&9REC=o|Hlj6H0;K1w)( z(_U}CfD&`mPg?M14bw639gJ5_d=kGF;Vo$DFF8F^jZB|P3k8%NGh0orjVo6r*XPbM>rJLiLGb+6*D2TgP$KoeBK&EK}F>b@02tC&_STOu9=d}_JUA?R^!1`=kE&gNQ zpL#HUG<8%N8M$VYT5`$2^1Ypu+3RT6<}i&~5f19@&+c)8Z$uTy{lAY1Rsk9AmhiwS zs5Bx9&ge&$fxXU$!&&Rd@@<3a=RI%RT+V6X+$a4)K((Q*5s~K`ntEakQxz5LIJ1-q zo&mj@y=HeCM<%kL-YjC`Mfo-59%e)P&FV~HDw!(kZecC7G;ePD zq^CNl0P=D@d=z5$F~hPx%ml%%)`UqJjDBCS;RsiNN6xvn%jk;irzoDL4~#D8XsG~Z zXMau|SfMPDH-psxo{xD^ZU?L3SEn^k%u;nv&*6=RF~wjtYPofeqP-qPwL{x;#b?ll zdUVP>5R`~*50?NFtbr+VMg z0P7eI#i1|oA+6m$$>fYplt{$}@o_j{^#JlUF#L4^G_qhRq-A@I>W@BuF1R)XKOB*! z7u2ibw7TbHpc!p)UA$8EQPx(%lE@Q6v?0)CPYD6fDU-F^KnZ5390tXDG2fqBnj4xi zB~`tp1=cCZaYY{*bVf6$o~xCWYi!;xdjLVSx$r5}Hx?_?67#bHzUU^>{@`blySnCR55VGIF!M!}h^80!~9{k1MS@ z^d&aYg6?~l4YEQKG$ROCvh5ouez~GjJ_MYe0S`MSn3`8?VJcJx6-;iN!W_IPrL~6U zjMW2c@cZ=!Ev8yaaJBMhIWH>34zc3XC?_@1iuR~>d!k`#p;yFInNK9T_kI)+${3(6 zT=1tk*<}SQ87iL<7kGuB)iQzb7 zz1i&Q|F<0~-vV*Osu13!sF?XCxgntp4N?B4Go!qJnoweFR_{;ol^wQkGN}zojd_x- z`llx7DEaCJxHhg6_oibH6CeLIUi(!K-}sQIMocYgw9+s6>ffT&&3#Lt^d|lCvZDMS zd9-ur>f`nPhq2gkN^r#gZp&X@pW>3b#Qw}AN^gQ!Xw+t+Sj4Y5p;PH+;4CJTgx0(O ziNK8h!oC0B@_#oQ%-N}_DIpyU!mQ>?z*-2Lx)t~Or)9vV?`RBZWUJ+u zmv<)fj$>YA(WrQ5wLL)ZP69K?mz#H4+fSBqkWq}eHFk8;^qo#4CaJ+`lL-7R7G#4t zUsCRbWVWpD`CXa$wpv0|gq|(i%e1(voeKKOy; z*p(R%6}OhhKBbbdNatEe#O5M9+yUG&9o||Yw9Ieq;rX8XPcZ6LCMkq2y8kxahcd}% z*;(rz40>7vXB5C2Y7p&WV0IcId{7xy@v@CgU-oY~#vS70zGSDYZE)SfN?-n1e#&tA z$f*;muWQ@!&Tv?c!U@?A)K3Jaue$$rkgPP?jI)h;J06fVD#EXS8YS-0);9SxPv2Bt z$$&)LD@4(Kw>0?$af@^`@S&bU%2u1)fqkT#%37=WY0cy_#&M5zeOD~~H9iIji}J^P zw5Yiz9B{7)R+Lv5FLau*avk*MnhU}9@(Asf_R--yf#&{#=2i|y(s=s&2Oc{#90}n% zLN@7Z*MYHSp{&^vl4d{%zJfhFrd4LF6Qb7@K7iJGkXqv>r==iJF|zp@KLDx%C}x*or^qR8SMDA=p;Vz#eakb|>ARGoX(E&zQ+T_7d)xL{!I3I# zB6OaqQ_$4RQxv6I?Rt#3CPCmo$}b`V;BqE^2SyC{fHdw-`dxf~0zg}gPu<8ja^ia| z!s*x|dZ@ZQ=E>{q-@T(1&?|!Rsk?=+Kh8v4T+!A4e!&`K{UO6hU>M~`*5X=f;wa{M z6}-r#W8uxJ@t!FU3#CUZq|Qes&PHpY-(M1jyiQ>2TJULrZ|n8RWQA{lGN#o5>j;2a zQ_iDYnpgB^e;Dz@28x$0kxs6+%ZZmN)mu+>P6`?O4BOj-2BytZ-#uz|XI|>!4_AnY zOI(f+Vj)=@1q442jkUWN=Q;hTu!ma(;!&_wM;!f;0 zj_DM#r`GOuyfCel`g83z*lhXQHapmQn6qm7v`<3up^z8zU82QC%Byz44bx6$c&(un ziNfdU9Ergw+<#LO0#jhARpWK?k4to1YN5q|F(S7dwntl-qo?4HyqfY}9p*!5JZ(qT z9=6~8@#fb$W$#lMd^Is?8v)}*i11pPpRnm*Q|%d*+eI}ulkBK z^w$f1{V##U94Y+j&f@>WmaVM4{pd6RsPrQ-t9M6y9p?Yu z&UqaAD_jsqS9j>^+P(NanEL9N@~4R_AudilS8J|2DU>J5Pt9zb+H=+z`wxECL}toN zZPN(NjQNsoO9*pgjlI8a&?Tj<@1a`4=2rp8jvt+Z0Xomg65zYcg;PLLJ~oKPeA!E} zJBke4B1fIMWg}Rn!N=w?MB`vhc`Gj@rTdn@^ zKsq83PN2xWy)XvIE}{Hk#EaND{&-vWBw(4)(n80GJZkFOcT?n&jwTOcO$sW^W@By1 z?A@iDAPAeJQ5VGvmAbPMInEBct{AiHVM@6!8L9x*l;n~26@N{~S11+3*96MD-?^eH ziy{Q%?e4juYUJ7{ei6BXN9@q_wJ3jM(1eLl*r#i7x$n-2)*|t#0 z@+y(9Y$kHUX3Fpm+eH+{a9!gm@S70lAyZuWq9QBUbGWy}Pc`X2CnzM0DAVVn8h{31 z%CsXsw;5Rcr2$Vv*a*{sX^bhn#I-~h6J;h`ic+33WLK6d$FTj^@sLYx^)uRh+&U*t z14J&>pZa$csIwG=A^l@Sz1_1{5c)7tAk7=|k0jKvdFdk>+CXdaodyJbUb}3P{^R`k zFXrYyqlLl_l4huVbU$BpslR=t%D-El8&=}SW_9X}esT713}MN?d@Q?+_w?G36atU@ zRp7fOD$2iR(HKgZT4OP{GxnmmU-i0Jwp}p3IwvBZc8v75rYULvYfEE>fLcgpw=!!#or}9#J zr$2niLv~Ap(%YCDtB4J}H>BBE1>SP{!Zu0?8Bpp2(c2K@VXc?m#itv*>rm@HLfj#Kiz7g{39T-Z+H3gJG3?hQBza36JBa2IC0>&JS?`BSe_YuF5u zzBhIBcw4O>OPZzgp42|bhj=-(`n~eg*00Ikm~Jx!eL*+0-|tf#KmAqW4pAXx<8VyC zvi4R^Pe?j`xc!A~tTGhf8D9-kdUmR7&0#3Ov81wWvcYEIuRcN^Bh+@Ml?A4~Kjd#f zJx*&qC5~uu={dboN*fQx7i|fe$w%voRl`J|>BVq;tIOSHmiiBRX_UGYdR2oz{gPm> zviyHaQvX$8-eVub1+NACS&7b1HTZ>2dj;s7f7MUzm8RSOSB1I~Vwaw@E3xti3u71= zwnL+|?v^pKpv^j=jz-|$*s5*vyHW10GCZ(Zdk5c8u<2}A&h}H~@Drg%(^Vh*j_90S z`E;%Dy%`SpSTjo9GD$|SfgzI8&vxr>ijiK*KgQTq>?iY6c4)@mv|S6k&X^_Pm-dW9 zXyZPdUC$n5&#rHZs;U%yK)gV_SbG$5a`=s@*;*;=tkMzt#R2uYL?`JdUX+3bS8jp^ zI#nfdO44V{DaUfe#6PwZ@nvrmg-b$rQ_rwh8`qO`h;~J2<~EGzccVq2|COm1Fb{|? zhJFs(z;Q2j1b=p7rsxOGhaL?q0b?qwvF4QKD}KO zTdK64+v&BhVCmZAwQZNIfIwNWU(TCs=$t9+$&qk)Z(pgupTm-D`yecO-mB2Hy(O7G zS5qv#pJ39#w{HP9&r%|~le0kY5+VDl3vaUTJw|)Vu z?(PXXUj6_OMOb<#-6gIGek02Hz+8}X5MnyXF=&%hq8pr$kBxbjeUQEOY5i$*$LPl> zQAPQOl5Cy1!cUlA7)|_i5oRwZU-0B#r#b&BGh}B+*(yDE4fw`JKa#k2RN{Ff6+KBU z_|-!r8{hu*p@@IT$NckI3ik@XP_|-x HoyY$N)qS`{l2sDjEI zAxH*>m@*Fm6(JQ(m=c17K}8ZsWK6=45V)^Xdrr^&?sx8UpZh)M`-Ar-d3W|+dkuT7 z^;_$`?&@r>@}=gNN=iy94o8o;DJgyCsHC*UV8eRg8_Gn*H^6NT%FX_eQUyzU0{F1@ z+(D;A;}Fk%Jzw0h9TAslE(OwIX8} z|5ER(hv6Pi?7jp!e|_{P|7}Ko(pzFdTg4+Phx$ztd0p*ceP4}Vn)hWmpYT&4Aj27dZ;kZ z|K^rs%N?fVy}-5}TI`!UsG$%1_tPd*fFnvus;9Ev-Si)tt^sb2KLvgU+_IEf*8(@w zM_r~9jC{nSch4Tyf3L@;-~7_${di*yo4|%x`r7y71Yag^1csmbOj#AUJzV=?fsG#* z(8dK~6f&u7MIjfR>n!kR=WR~cvkkCKSiGR%GjnffW?k-U&$o=}GYM#_Bh{arudnD> zni^^p-8F#+um_O`77p`KzP8gj)AaZSXFlY!hf}ArD%L3X?$gc)9N;R2s)S&}@<*bp z(y9x#WT#7fg+6H&ZJr5J33x~{XoC`WQ}Tpn-Lcj;%;Q0MUk)r*tgR89NPSz*O_ z$kvSAWA&Mr9BCED#O}*YYaSv$OrEIS7&s79*}|&E-VO_Qxbu8}TZ0i(5!a}q`T4`& z2D1Wtf=TvsM#I{@Dy&>0uSDy%jP0#(;Yyo2nuErhR)>E^vXFNjQzF<%0+{LAA!cUC=Of+HSd3Bg0Q%R=DZ{>+>JyAP z?S|yXlZgB@k=0;{pv(`E)?VL0uO@T-g2sKUV`pAo)Xo_A{9)XM$>gWtJdCQEq8-d( zfzMUz4K+cvetCgT?RGk;5cWgh5&XwNsTU8>si^6CX*9utG{nx5zf5p*#0?21=RNo9 zEI9417Rhp(H#S7z<2%l=YtwdN8au{)C@vsnov~h=>-7E(hELt_bT3+CtgspDY+WIb z-lOV2QoLYhbpn4Fx4);M^XvgZW(H^qX+fRh7S!Wz4oUa82E zVD8?gddHzmsm<743N@q*lc77-$6H!I489cJ@H{XxeH`)&JAgrm2Cd#4#gXn$yGJB7Q4+3O>r z$@ARPzQ>6oqnvlpjobxmUPjeNx3a-XqgppbL(bqz4@A<4#6ze8lPcDqe1TT3Ed@uw zB4!9U7|EYXwPg~yaVJVnhxaBt$qeX-;KeDK`vW4wA^jR)99%G0%a1rb zBXg>&6lqMX-_52XgRv;pVKykOU7VE|A;a}6FjMrQ3RKp>h7ygj?LAEy^W^-|zGZvN zf<@{@A=gflMQG-^@(fkv32u0s`4A_Gw*uQH7KvYsHKfb*OAd6!@*KxP%5=R@Z!u9& zt|>E)+j*U}tK@VHl)F)hbvX>|cnDmFahEz(Y3koZHaFayfKy575vQaHEw3i9^tgE& z_n-{|FWxUUcX&%BZ)s{R!)M}c1+HRtebqJz&Q22BIR2XnA>oQ(5|ubYC8ou1P77lm zp{V9`8f)IKhY%^(W#~QVycKlj(MeA)#&me*ar*rCV@1)`c)&YHs9 zG~_Qv0~!*ZRl`YX?#&1{BrhU%*c(yS{9PW+cb}Je=E=ne`uICE{ftNmdsOgBo_1HiH03L31y|DG9O%*H-aE{m?Uaj`llN* zt#(px3=`qIZbKp#N(Vd+^ch3eJPh4;z_FBr@PF2myzymCA@1&KtkjCOS3BDfFzPp# z8^CV$b2}ZCYE^769d$YUvZ}lHKACt<>B4|A-qgP-;OD^47SE1)ToMNgOI?-BxBM%< z7pJAKc&4iGnXy|j$p*RutMGG09oUmcjz7hrfc@q`4=)i+NwZp09; z_AxA$OC9Ie)Ts=9*;=HELV`3v%Cp(; z@3xg3wkXl6ZMpD1zVE45i>5A$Nq4Ya+zD*Z3 zfG+dQw|xStv3+ZuJV3%KRPVD&^I1!EZm3d$C|2r%ux=i}x%#%T!>lX7k%d8>N$0sbfYqj-y=@z)8k1nEDCf1W0!jlb|waw}kBAP_#Dd&v5m{;~SulrAGAh z0Zce4BsGAjEg4zI?(cYHQR*yEt6Z22dsZ*1<^x)};{>Fyc`Lrxx{U7uOAvW-Zw*jf z_Q{0=Ok0H%D(KDms^cl@1X=Xsq#t#3fEplzjQGHLJEMkC<)n)F@*H7*J-fMCKj&PE zAp%Lr+e!$~!i}P;;T&d|KU)S_|L~SyW1tWW!J77b#{J@;6vMI5;pQH*)3drgKrGrj z=tYw}VlN^DTBBwNe9GzhE#-oqh}}F=9CuWc%D;zc0~fTAMul}XA^|if#2FN$6JKC6 z-F!Re?rDzJJy-vcC?n9$!pJb)$LN$@;e*3L-1-f;v(|RZ;|%nkzGZ|ZOmISK3K4{H zMso^$@zn99FQxAJ)id$Zatp6AcOr|>*-Z+_Z9=)nFOV&`OYFQWbF=4;3)ocQR9@Km z!XQC;P?%93b;`~kjz{?t7FdyND9Je}v~_ZQx8ioXK;XUNQqs6m#HY@)>UB5KeYUhM z3|w^Xvlk@nPQY=ZGY=D4m0HSe3$*dL(DNKmmGY;?BBwfk?vLSC(K?5N;&r4n9p7mz zPO0J+aOf&N_pn`SByD}GbyvjUj@KoQD;~ib8?vQY#zMGls=GlV{bovH#)YONg$V9z z`%pM)6;+nl0i!C+hjt^9)WaxV`Cg3$m-%J=5Mp#+Wlj00R-60YYN~9}rOPM{E(TG= zDs_Ztp{TAXeHS7>K-7eFLMuwmqz#-V_OQZ^Stx+>N}4G8LicKJeNH-$W7BTI23a-e zPfa%&f$YahYXE&{|Q+!4iEs&9c=G6}h!g{aVQv`$aL z)u1FP$18U>|1gs#D5;wm$jV7dC!nALwZl>n9LA$Ghfz^e_^1gsO&H%cABm$bP-1ke zlQ>O+aHs`0?m8x1NINapCD{OBntPed;)}X3(>hEiTE|e*%Wk&Ciae6y3QNQ@uK60j!=Gtguf(ut9 zu$Cb9bg$V+A|E~~3FBLF`nXIID_mZmc5x=>1(bO_^>rUDUKdiKfbNA#K?Njkr+7ws zadns&q~jT-{NMl=nI(CoVVzR0XN)x$;x*XL=%V4``A z22w^Pk#-$zc9OX0%orxEqKy*9qt5vBi#^`D@HRz$1dcc>w2E)axy#~&-HR^Z{?eUL zT2RCX49vuX=rm{%pXQ$g6H_BSprQ(0%Pr}ioj469t;mACEE;s8a>J(Gr#n+lQm<_r z8w9zYhucaWXc+vJ_HF)KemV?kEh%NH*jb=*fFhoF&H(!zG@a2HZund@*YgLSBVTcxid2Gl{Lf8`a=%_&H?)Z|QNKqL?+p%90W00zwk5<3?>M9{Q!U{ox?JvE^JgYMoxld-5FqtO+1Zfdw{IeCOnRs#cX_^!n3yiHHF__=hUM9#J2YfB_;Dbt^JMzzbh=Yq+V!|Nf#xyk{f}N ztbTTgLCKXpYVc25bezn4W(G}mu&p^EpUxR6G?=EGWCp6YvIX?g&0Yun)h8CbmPHTI z&IFqdxOk?%c|F3m)9dVr-odg+D`}le#E|t+rP7ZN%}w2S=^8Qu&rCKgjwqCeZq5;qsN!a(V!|IXzO^jexw)Mz7=VT>MH5Cl3Y7)@h2Smy3_7kD$i zCc891b!&l0y`etY0o$WF85fcuD18mu3^4pz_cr}}G1LuzB-M;yk^D6HM>YynTX5sX zK(W(87!6z$`iO!$oP0d%j?HqSVl3m_e$0dZ;8I zjzoPelf}@(-ulZo%wB9f!j+ufYNc2}(*Y&Fi9Di*Y0^y~^;) zV_Ghh<7|~G&`w8Ag3hQ_m zU6r?v(qxf(jf?}<1!5ULH)41n#u$m0HDRQ}^1WAi6eLS?Fq!WwHq|`5#HLY_zpsAB z?$-k<8D=}eH=|p-))z5@dBJ0eX!jnPP}a^PYk7>6)#kfnnZ2oI<^odIp{4p&gFEeI z@TcSmh33V~u(KW}fGk1h1QRycJXxaeW^^Kg+tMJWwGXzLlG8`8cT3A@wNkPm^YZGU z8Wv}16Q`8QG6V}C0B{mfBt#xKp=Zei(6Y2^Ky2|iHx;iITbVUb)-)kae7v`}-4je| zKFLL;y&|eh=^ZuH_$x9tu4@UA3fo;MCMWt<+=kIz>8wp-f=G{P8|*N8Fx-MIDqN8r zAJ#2RD~)c{DA9DekAO|k?6I=?1~XxvXH;230JVtbF<4S08#vJ>+c0)75d)&9Un~=> zQF^&p_jT*q3%p-nP?9$(DX}&wt3F(dkr9Yc#RPzr=56^$NbespL(>nb(WZ|l$kM$j zoJZ{AcL3srM>89H+eAMc*D8^1Pcf1*R-gh&XQMnR=9QkNWXrZRSGeNGQ;OL~nl3?- zkr&-@tK;`xbFGA1zNfI8tMQR|NT;(-s{s;M?k62hrTcf1=eud5n5COffL(efMeS8Q zC?1eCrHGH(Za9eoRx}SrJ~0h-FK@TAiko|LQOtn~ne5zJGlY)b=`6O_s>*dXyPmM=<5VB)ob~dg74JC`Wcor`2PRgi$>9pC!x@Iw(lyb1P9 zz>Uz1vi8HlziW}e(bW!GVMF(vKt?=Qc&h$zHvf=%MjJ)OJy zYAmuK$z-fkXE{GSYIt%U|C_>ZG^sfn;WM&$cTSV}mTgycQ5~k7IF8teqZX~|ZSlku z?SvklMI=>E@+r3@QK>e_d}4b~B(4-A%6$z1fb0(whxC@`mIg=@J4`o9%X4dFi_<&8 z3CO8LYN7&;3H4L>7bA@**%P$9BnUngz&g;%8C6~K0oG_^K7ap$3AWT9JLew6tviG+ zxW?o3r-1f3BBvt(rl)0V?R7k)6^4E721bWJ8D6#=4so}IfoD?&NX@mtmKYS#sMmgi zhGs1v&g0V9M1~s{``g(Pa$)|8o%k4~u_Pv^>+gMcv-@;9cd0GhIJi zHr^KXbay#m{NG+h-}=WdpBd;TZw!6}d3x{odjEanAD-UnJFB*(b#3LT{7=7r|Ma+M zU&^(>pHDt;{xtC2+}Zch+uB-9AMN(u2`ux+G|zh)QeOvVo%*&c<4+>)9S;umht8Th znr`ZI_;bCkA1bSi39%abS?feQ-c$Q=nv*_vVOGkj&wu`BvWG3-VLhAF*G6CZ>%$wb zKXCKEG24H^gf7U^=<4#c@4Ch1)E_@F^p8zfIh#7RBzL;}xtUMQ>w-KoRcPAS{>hYg zyJ@OC;P|ez(h@uS`t?wjKW_*++!R-tC>igTjXk@o4w4t5e-4y^2D|>`@dxIVtf9hN zhO07Y=35J_<7*6|n(PJnJ@h_?DONT?6vd2SW89mU=bqgds9l=pMR2F-3&SN$`E*Om z3IN%2ZC97YGD5md!&|wO3x`)FYH`9+HCyczTkXUBh-Qe=y*7(4FEfE;%f)8t&77iQ zw{)qyP7@yY_SxVd^%km*`h3j?1$i%7=03-P5IviOy5^+S6&WkvQY+{?iBP33f(3Os zFf~kXoFJ~4>6@8isQnnvSq&p0aRoswYOn>WUb*wgQ9t(7(~O;WC#~+d4)qdUOvFd| z4rq7UawfDbvS(SFz~wH*A-|vC;Wv<}=vyp=pytqI)XJmFU*wu6$$Ig|yx4?l*mH{# zSJ4@G>kh){-X4{dZGit)If3h4S{8Q_V6O1nNd5BJ5N`~iHL2|04?8>)B`(sV3*e+ozC+UZ5|uQyJ& z0HN4#bQsUaD~RKF6-AD#bH1NO{5~JUDnb|tv*bnS->w7-T@FeDa$AmLZXMX?xjmzc z=M~g5+W*yk;(jv|vmYMQtW2_f7s_24^LVR@kyZNE)_TJlzR;W^9z?&DA+v=22my`b zG8(s>o2-TXXfp5B$v-=gqE+dTTSroqULprSgm1qNlOvR5-XH;N^#h%ZkVJAbm*Wlc zdgV6=Ww7dV|CZ7IHeB?Mw2P#`9{w=Sp4zsX)%&>B=#_5*TmEZ8B$gq_mQx1*O9#^`#HEQ|qxoW490S;~J z7woEUxGCcBkHSIsnS-O(5wRZ2(%Z*mnoI^c=|2Aez{n2-PIyjxQiE`Th-S@&(RABp zD=ICck75fHaYk++i%OI|?rW~^b-FstrHX}doQb^G#Vz({+0G^QaneOIM^3*7CcalB zl2hDqJUM#_fE?L;4R2X;w9(%a)u!AF9!U`2S<}_-$J5A7m$8qw!Ot2drwG+rU9vp1 zU8uVD<8;=pisx6_U18eU8yVGcBn2pqQOdSRJN-i|Oh7z+*-d0U?L+ipILILI1+th5 zu}C`xcAWAB$_deS@9e5r#XpyBh6;Z(G5sCK403L?C@5P^>)H|USBKUYX+E$)>bOjv zHy1`8id??U8{buLQ`2RxN$or=<-8l zPp?0sGu`FQNOAPSzOLHhvs13AU+6Rhw~HfiOPziqg;QV`@{facgeDoU%rQS5@w%z7l8r;h^6iPx1D)`X~P%xVe94U;kft=l>7Mp3dhWsw=K; z`mFKcU&AiDS)fYwYU6(z@cdgb;cDMgy;VVyVJ5+__UozI_iFy9B4<>E!VJKk$nAp> z``?9>A7epQDmY@eOwSrqmAEuFX2-Oua_%XyD;)Hw$6aV)zBj24b^_UN>X_-i$2~zl zSu0u?j*n1GHBMKIB&{l>>9#9%uZIM3S7x&l`b*I|LanHZR`Pm|TH4UKFtMo#JL+ zx1Ci~p6fJYV5E_D&@rILB{y+ldIbK?J?{CO;6=YngZZTw&dt1Pl>#ePSpY#z%kr4X z-C2*k8v%G5Y{SH8b4n;s*Q0oiSnwT0JHma`I-G7B+5qtvQ{Uvw`FAlIla`8eV*S4M z?di>DFlt`m9#Lfwoo8xs%0T^0&nospT!N}L0z)Y|{@5n-5B3HBE$GbW-v_-OJNT{U>gk-e!ZHwz@u2b$+fp&g^~ep;qlI)XCslqH4g(&%yS^Wq?GDPAb0Kyc;S zm?VzQ+!e%HxhoJ)d{Cd9N5;tx)ciBfHQTZfSI#G5H&u%n{iYR4HR;mF=& zB<6JDLL)QmzGm97eQsusc_0NFDDS<;B024I#K3Xs7dAr{d}qz1-6a2hkKG~(CyYU|)wg1QAjT#>F21FlE|Z6~#4k+w=jc^?CgD-W8u}gN zJBC%WMCzdAB&Xd(((p{Rgk7S!v2YZ5=gRYIJSWBdh_N!UCQ5QUul)l3#R-$kqf16< zC=V+^kyqG#jf5X!wEsxb?2m!>?fdcRPq_YCMRQ1gvD1x1_2l1FA|%pXMp~cce$C}6 zk8{G$>`oZw7ZTfb!JmoB2W@NBodDMLqd_W^tmYd39N@^@v67#opij`p>-Ws95mL z#+Sz4s@%lc5WRTvl4#urX)mx4)jr(E>n{0sK^$_wy?W7oOygJDuaEj~`%l+0Lb59M z?ff|TA7|^KPd3Thu)uqtn|{hZuYS_Ye?HUxr|Igy)#?4OXW9QgWKWl!C`Iq_`>KE4 zKZzoR`@ZM#e+_E*_hABnkN}?7GHE&KUEiq1y|kY|oiGUij)A;xN*Zlmzs6ikR}h_rkIlmFIHB{I(nF5%iID%Bny& zs&>HqC56T1z@IPwx<uj_+aC!<-!(tvxY|Gp}S5y~MjE2t89r zU_TdQU#!a}0O-RR#j9(20XYTKmK9mgj-{ql7(?J;>evW$tn5FNbtYw5>I-(nn6!Gd z)bReG<{;F5tg_QaOH*HG6(BlyVe10g!tULwW{xWzPF&EolE#?nDFr(3p()Hjt>e|% zFSGxa{atnuBI}A!iq30B9(TD5c!^%1bCsp6rC&&N+;^N~y7V*yzOb(KNTI92wY!FA zVkHM0JX&|C7PquY+SK9a13W#v$k@G(jr%_V>=Oq_$qp{r5>75&pVKnVgv&|*QRMF8 zJusYuzm2{*dvhCQJt=29t$dmsxiy*mzQ~tvz%kuu?eJS*4=n{|iM1wDC!#+siY2*Z zQ38Lnpd^zCHN=oRPdqhB)&iK_rB5>UXf1BROIGC5a2tzTm%iEOM43)nqup9%siE(! z_Ky-JtpB(Z?5QNysjQczM<02i*N@+?p>K}(B#1BO1@4zt&bS%d(3mY-j(;w+;tzDx&JkM_b*-GH)iI> zo=JxqW4JL>gC{?k`HlU8r#IA1%rHmQQB#`MiOaG3V{X6i?JjzDR9&YrAe{9zRV8Za zs_f*1Sg)Lb3XIY_>87r%s{K3{9MQK$&t@=dZJ}LU1vs{_-WDSXahH7et)Sv&RP5ZF zr(?61zu001J`fdGpSZHT#P`M=-9$(Bj{9{rDJ5+xclE|pL2Ol_mrBGPBwBXirgStoX5JKm* zm0w}Fprj;WNX5h!pW~vRfiuY=w_sH*w&Lif(C_!G)l}T2tZP`lsLJK8lR*xO)t_n< ztGx2)mU0uSjVS=RH~Q=q$5^>Z;Lk8{Gz5LXE3D`>@z{r%CUOgq(LN6}N@#%)FI?Bz zRQ}tY5e;OnRwGg-IHN3tdR3dp_#64`T$mc-HAM$?0TT!l&?ep1Pr`y6-tPt5E1K@S zX&7{&LUUsMi3soUIkV!*uI2$D*a2hH7?E+^#dH-2QWD2>dI72xjRrR6a<-pH= z$~XiMeL7AWtVf-zd?5V+qY%g2yzHw5r0>9k&Qrhpd%zNXV*hR zK%hTmyzWoOhP`}c0K&+j4tSP3 za?NpDE+7DXcv(E2Ti-=n{l1blVG_1;v{xlV$YHb-YRVZ9@IFT)&0aX+KDFuyZm0Ow zykh#h`T#rZk$ij#^O=D&)5QP@Gy`ekx2RtB^&6HxGxvPo zw+kR5N)JIrJIxhO`@?^9$yu)2_%d_bzx;N~|442BPw(CQrI5EWo1S!D)BWT7Jia%# z2#1f3dAVdcZ*-dZ^fu7@_k#YjSAx<5A%R)8V%u{cE~|55hH1X(bNmv1|A$M`#1fUe bs<-3DBZhhW7Z=}E86E7Lk5n8w_49uKV)9rc literal 0 HcmV?d00001 diff --git a/docs/cugraph/source/images/k_truss.png b/docs/cugraph/source/images/k_truss.png new file mode 100644 index 0000000000000000000000000000000000000000..78a1978d103909ee9b4467729d75ba55f6c35f88 GIT binary patch literal 13389 zcmdUWcUV)~wr_B&Shk`lf>Z$&0Z|bI0ZA+rQF;{+qM{N*iGUykLa-tzAflkONRbvu zqy<4jRHT>CLJU=EB%uZpAoT^^=bU@?KKH)wd*^=dz555>Vy%@m=Nxm6IeugO#+Xso zmZsabNN)iE0Nc#Y8`%N?>ns2O5!Fo_g->Gd?uim!L;`G0&jK**vQxsF^`2)e&Hw<# zaa*}pHwf=H-#+gg003;SUHcPhfaP5W0JKKUjLz5xyU(N*$GYhzmh;XpZVLacVav32 zRKmU8Cy(zwX*+5ZwM*nT?eDEwBKkXy{Sc`ikxIzRJ$72g_w;Obr|+ld%IbRu1mKGD z=8LZ5MYRRC5+O7WU^KnR@JunM5e zsmF@pqlA_npoQTZPmsKooff*kFYFSyJzj|t9686OH^f1zu0BDn%pL#rAsECD=v*3$ z1IhE3azQqSc5iufHp7~ajMur)Zk!;$PAKLlP+s4r&h^6oIsiW}??sGutv%nkLpWam zfXr!a5#h!AkC*rE8YS?*z1)(yS@`g~Zgxyzf_-rJ&(FW7=~}>mJAZy9^7H+b-<}*4 z{w!kKzGUGApc4UE>+##CJDF2c-MNtsaS96Vgm7NYM&FdAqhR^jVFI$X?f2^7h0z~z z+^S7Q%0;=(VFWd&wz}p;Mbz>OQ^2KNM!POInrBXt&`yz&VhZxeTfc8IfbTk#A_yETvz+ccqVe?y`$K-<6%ozFfhc;a*w$kd z^ZmZslF145hp1bC4?a6AZnsZo=HES!8b!!`f7P-QBy*V>wiI5Ep?iMFTv$rzcdjEfxc+B{_b2vPvF*uF{!lv^qFW5>`mNdb+vX6xkdYY zT&T{YfPDSceCl2_*&OBz01WBcJv|6(uk`b+6SX${y?Oz9uDy*c6C{WB?O`>g;QN3f zjPWp861`F^i23oyzK>6OMUH3cn|B+sF(WJ)?G zhxp2|aMn}dA(wH9#aSeRK(GL{S=r&S8wN2&`jmj0nDlb46O!r2ag>EBYqEZHX8Ce6 z+3s@^rkJ|uerniBR%&f;SkEaOdvjd;LX?J%p)Lcbmf0zWx?Gzmh_pp`;%ujPxH8-7AOcOXCwh-ZAr`kP1xInlit zjvQ8f-r{8&fxakyB?eV-t@RW+VQYZrac=b0LFhx8d4ufgm`y0oeV*7lmdp-AHvT!+YQ=osPm z5>oP0_z$T;>SqYbb>{eX;u$8MM3}InTcdSCYnEs+Y)|tq2zVuAbqq*CNtZTi%9*N* z;o@tS`&w?YL!sR0(Z-r~0;2j!Z`hN|q8doqi?L(HuDAk_t51%C#u4~-zY zEw8FpW39!cy8EX?BQK-zVj^;<)j}?98RR%0-ADoa{0jn9kpdKdWNSr_BtQoSQbwv=qStY?aUfF>d7mmN~@ znwBG}4uW0%UcSayJ7#Scu*xcllEk)5ym)j!PrfVvLq9{kPB3 zCo<5Xs|kVlTJvS*F*M5vjcmzH-8!ac9oh}e)%2fz;dq3!{IbnG9X;%ugwwWUVkNMh zr9-JV6cCew!QSK)`4jT1uNO1u$t_=bt^qH{Drmc(a_@w(KDI}4a=_P~TC8U4K>IUK zGynj)ZA+O=GD|mf%oW1bhTh}vk!ySM(m0UINQ+Bi))h8Qo;Z{a<4Wc4ftKQJQQXDg zXlMR-SfzzwtBa^l-(snwAkj+NiRTYY1rZ7Pi}t+pa6YQ}0S4XH!tSDbmLS_h%h4}% z;C|(K3*S`wi-1l&U=QBPSX_hEvR6zIE_}!<$;V&rx@8H`SSK(oASqbhwPqb)VE^2c zoYK|Wn-R@8IOLTp-+aHA_l4G_^k z`4c~KcU!=_t-%ZAFUa`WT}Zr<@7#7tgvJ)Y?e4<|e!__F*8dWq@{JStajJaco4`&1 zFPq!OKvp*;Ei(xiPMq%Sf?&Q|P+pZJm;ien7pQU4Lr`?N@0gNmNS6qSL)FFHuWu*Z z`F*U1yNXmKf_BEIR0$4RNhme&s*aIvR)vIxmApiy%=>Tvo=>cwv_yVTiOAodpb#cckML*Ji7;>y)E8p;iRnV zqp&_iO3jmB`zU;t1!v7v49jEFw!S)L)MG08wmI=!P>bSdu?ZvG&C5Cn`b_J4h;LnP z2&_$Lu!V3_9oHA^j#Wn`jo2Sfn#inPT_uB&2zhe;rO-t!AwGX+zp>Pw5*-xO<#@ev znG@)DKYk&2lxs1`nInLJchfjt30(o?P_29&d+cbagb#FfxUbb!=V4e+PF~GfCd-s~ z-72zZEPu?=K&=2P2vkE2nCck1$LUK_R1mlE(_XcApG7P7Q;0?ZFV3Fzaj9v-O;d6c zHAwj}7upHb2}-G_?t#h+L_Rl$biupGh#_x6ZI^v+9n^5vGPeAeYuS^^~NV#44%_= zKCp%;I8M&h?e5b-Tuf2Yos@^75O<1zC-!7{wm*&5y&gB5oSQ1DD5ICt#E!-?6g5EaAYd2)4^i^^n&o=vZYFJZ3_(+X=7|lW%0lJJYu9 zbIR$#*Gs>-8z-7Kn zN@I-e^+KALz+&WiBISZ; z!!CED!seyDV^VawBQL;?7_(Y5p{T@|RE~ZP)iO{|;Y4p)=Q;kRm3~Y-$U(aaSKooB zSxwmWz#IIb?MU`iZpEZ>e|oetj!yjS*h5};2cAMQl^YyV9OaFPRB-s}M4H}@Hmg`a zsGu9st_wJVUQ`B}Eg^w}IPP0obQACwFQyzP($=je* zeL=gXLVZ#^di6eTVm7UcpKy$usghOZiQ`wi5HO(EfC+zfM_Wh$XGs+sQ$HoARPrX^ zN{SM%u$Xo2d?*Rm7|zOg#4bn!s#A`7pb_@e5Ussnx#56Z{fb*7LAf2TV+|X2PkF{p z1lJ$x2zueD6PDq;%Pcbwi(GO41KlKzZ7N}TE95v z6*m_Q_t+mjQ8K#oN}TzD^qu zGo~+fR7nQ-Y2XOrBBzRorTQPSPx6BtZa~f8loX~-iAaha9>{DZ4i42^|7f%v@_foY zeDdVp()1P;V&t^vGZs0>U)f6V;+YoJT4%a+tSPAd=>^!}fs%yc$qzV~z<#b0oeh>? zmTqI-<(6B?SuvB3#=X%T*P=8tothRx8)M1W%5!Vvon}sfSXVCsy%6p1Fg{w;x9yRA zp@lS$t*g0w_P|-WqhAJn)4+Z`<3!wVfN^l6%-acpVQnG`9c4P)a;75R3{rXOqTG@4Eum zF*+*m1KsCk+E^$!VP@Rg>Y1>UYutUi7Tm_+2h3S}idwQYPhF zr3s)aR1-u6S1A~1cPiQ--~jF9oB^ub9dRltcd}adZl**4jNjwv(Zi*Oly@reumJ_c z-RBq?lcXia%WyQ79GaHmt{_1*Bf62J_kPD5PlMdQ+iB9LU+8Hn6Ep|7*MDZ!+TG+* zAkR%#74DcBjkml1JHYY*#3iCd)@2nO?d;1SNK;a&W5+XZITYVcTbXHz{$c~9Q;#j* z>pTGf2>ubcbyFymGaG&edGj(|CcS2A)wk9HGRS;;N{%k-SFHAzAduXuhU|U}5*7aN zSpA*Me7?m_fuE7w>}rzNB36tJYy6OwkL)OF=p~%#md-M^($0smoD2lJ)aQ1s&gh{R zD~+{DA1c~w1ML>Ui214ou=bu@c8_P}M$4kl8UZ(HQZ5PQN1mpW!+e~Dq4sSpp`M3b znh?qHQqT#xl0PIGjiD>*rwHRMRf23TW>|MGm~BlhZRq_ltjkawYeeFVe!Ss_JbRW; z&c$evR#ETshz~KQldT_69l4W0P64L9h#fqga#=;#xTa90}}dZ~-`IW<>X~aE zcFt<`txk<)imxYXNwqi|m#(mUk=uHq*9 z3%ib_!0t>tlhAQ9! zkQo!P6-rNa&`zwkfO2VMS1X~2r9ut6^J&~k_bgfevZdT5{Y%`Z)n`SY_2s0=hn$}% zpFg9 zdKi@6!A^#g&j~TyU`lBjqQGkAD4{xh;HVGXC8&$Pr^|OMrL|H*8XfD_FBE#_*ei%Q zu#DULNwnTrG;#aeu>R+waXxH-1>|b9z!HKoaPd!zt1&O)%YPqu7{LpjdDj2 z$=NuqlCjmYC-YKOvvozIsi}-JerAXAQ8eb$;n)+4d914CFecZ*fcW47OH-pC51c|p z^Y-j5jUFpv?zzt(`-qb&i!R1y(fh1$B?xs(0N{gM&Mpf8T5%-d+9m*iwOQyqZuw!6 zFqA0*_{2Z+PyElH$f00U?XUxj!kqV6-Oum04@cR879q=(2l4#FtSi##KA@u%4+n_W+2<$8}@QpYd=t+wzKi8Midupnwc( zxp|=BSc!6I6s;r;Ll2fV9j@zLfc8gIJ3J<t9Ag zx<=q`u;e@EJwr*(KkSR26OGtDu4ZgU6qT31;(!8!E5E5IyWYy`yYo!kLe>6ZpTn>^ zbPn4^OL@0crzu$Kwt`I*8uv}t3AxRm?J_=*73sri3v+}^)hk!phMz7drYW5!?>Up3 z_oZ&#zDG7VGteUp0yXojet8Z=qRL+^yLL?PaZ}CYOW2nHf9KBZ1+GKN!`OZ&uxtyM z&q%89f2%1BB(Y)V!96+zTOlX-`d+?K!uG-4YF{F&w_B5zhg^CdIF->{AeR;R@^iuf z0k%2{=e*Xjd!4P3aBV+Ief!qh!%({>O>+sgvw@T&4SUiP)NUGiJ^PA~`#5Q`9IMJ{ z#yq-)IM~y$=sFYa=q<#gil((Szil{P@#NrgTdgx0_}XgZ-0%5RVUSu*P!;?7JByYT*Y zA-{=xb3uFaQ&iaEXsyFU&s@XG4lM8CjN^#nod`~Zq;CrEQD2j*rzc-L zLY_e{`W%otyw5!L>h%U5vryL;c=4(owS^U06Qyt2o`P`$Q!Q(HDxmUMc1D-9jIa|O zVVw4PMtP@lr`;pr{U^7Lo$Ke`>YlEi5~_5t%zqW^{%JS<3#yhGD)FlU*m)UemHhdf zHQx5^X9HOBP?*xsaC+VUQ3p<=LOPO@g65qzUpyB(87W&NLihb$NX|wXVs>3iR`Ayb zjYGCwy6a)jd2&$dzQy%ztUbH*gwM-17Jk$$Kp5@%{qw<+jiFDYWZT5_Tz=~vlp+B$ zH$T3z%lQGm&ggF~K7%TBu0Quas_^sBH6>GFPrVU=x+djuko$_?jC!xYG~=^Y6-~W9 zRmo)anQDLD;jQHLzT$ObC--^Qi7SIKHZ#h}S2pRypZMl#J`M?Q?38{#Tc5|;X>n|Q zN~83&#=-wEOoh}nL&o}W@;m$LKM&$}N`;xbGADvCc*RH7exne!LPJM_FFWQOIz+MI8`#zIG%l4yYftJa`ud9hwxz{)S--b&?H1 zRs88v;&z4bm675`+{$DMtet+@-ich(M-9lZ>fgywo`XrRdJMayvO{`hZ(+2p(C-^~ zx7KFz*NhFTb}lVVb=N+XJUTSMr1be33psj3Fk`%FWTGVr2CeVVuk`KQFQOEm)g^8r zE(*_iP@|H(e|-&`h8_#tlDyl_te4tH<}Ff*ab!jv@J8t1S`;_;t=ddzhaXPTba)Ck z;yB%FlKySsB;4`vKyMYlV#)pM4zGVWxo^%?*o$T7^{WgjNKmPs{OqB%mG-+Qd6{Np z5@!8_Vcb%0IBw^jy9CbQT=IS6nWf9HJZ`T!qGyQ7E;FRe%M^uq1=`?;J~SO2Y)Pz3 zeNPlDf&>A(O`1#l&|r2Uy&9EfIBT`G2y2S_42nFjJ@p>JU0E0zsf~;bU7nSzd(@FD zHUd8_gQt8z3i!n+wrlF8uRq*z%=}SJ`yMME$=x`RAmGi#%?#Gp4Eh^=)JEyV^}(<^Cr zP(N<77L7#lbnkO0)oO;z3@JyRrFz(YJZq-F3N0Lv_e8LvYpm(=+2yLB0Vp=KA^>Sa z?-FyXnAAm-_>eAk>?KV(Go7y^Gm34W#JmmkkB6`c0uDjfDYxl?Yvk$;V>^lMRY@9z zk>cCO0t<)w$ctR*_Mp(S=7I3UFJ}B^SeECWXz`oQM+&Sy;_N)gp3w)1PFT10me%UI zm<64md-j(Jsm*LUduHYB9{rb0o!Vu1PN>{>cK28n`D)sAK~bh^T?B@=r|#Oh1H+kl zG3=NGMCz5ZO0n#3IAG|H zkZ{8!A4A;58(#RiofhrvHe@FlZmu`<#;dfjEhew*^v9Rw}ixjTk1fob7>m>ORZpHl=R2?@AS%5%9_dh zHGO2IMhhK4FmRq8# zobZ?vkIj%LX4ec3|FTELqzhA5Bl(E(3pM>#JMF1YPCLa#CGjzg6v?YLBP^Wq{v(I? zw<5U>NLe6%Wi0D?9>Pa!=4qmwA=aZR%_rw51xR4fNoXwX`o+vv{*TJOsW%UwM;j+_ z@-AK4%nTGr1P{xM`TR8Af6U?GWIjizPZ~Sl*6Q_3ZG9dVnmvjwzkbd*uw+EnLJgN| z6{ma$OD$19lc8yU-SqyKAdP(`g!+YGv-)qPByqXDsA%I0y#}*|-MKSmlVpCCRXu9) zx`CR5b|owp-0y8n&yLR7`=L1{VR`kOO2&(`Q#sb=(z=gqM|H2F=YautT-N~N7hdrc zJMZkWKhw2K*ASz$$td=&gKMeM$SYX7Nr(O$8*Q`0X4-Bho2rTBpe#sL^rvc#n-ml# zVGYQB0Tw6?BCg-V+h}}aJc6)2V@&r{=xz&#m$8|F$Auqk3GFxf`cIZ0^bT@Ar1L*R zgdG-(rGG)WqFqMa+&@9u4>7BPX+OVD%-TO;Z3*nBpZ-4rnE!|P5C7gU1rL5EOyng+ zFN6N%#3*upGByAG9?!ow;x)$%lEz&G@mj=hsQ+t^=GNwyds`e?A6U!(?EH1hjWFrt zNENGeS@)W-)tpIxNLtzY=4bz^ovGLe2(8Jri1ww{b_vT`EMZMxI&na*W!&c_! zyRAU@1J{=FL!)FUOewWE1rU%c^wAoqY6YBSq+mH9RE=efBBmP}u>8Q3GDw#e2|agX z{mWjni*H*m;ksNn`d{51er?M|g~NB-i5rL;F`ql(`%qBj{ob}KEBbZznJf*Ez4l{Z zsrx`ZF`)i`FV}}OtE%M8`|1d0Z5b4vSjY>wNU&&Z^yuLenwlmALfn(&u9k}$> zzVo_YHpkmThLJH^1I^Mr{KS8t5BvplgRWF&>9X!deaWjTR^YrV;8NkZtE|nOYaYx; z`ON;S`{Vc7LY5w$!{uKyXE9IpZ)(`T^PTR^)yaz&sy$b(@a%i_{#8LezrcMpip%kL?^lXk84 z{sE5uL8gV~fu7-+3DmH4ewv55z>1VB7j=v@Ri%>LGgN#X8Sm%7`tu6A8YKYLHSYOe z_?j(M0w|cVlf3~p>W1A=7xwvStK4_s_fq%q(mmz){QGB@6_DHoO;d5XBA{xa3k zL|ih>lvFmv4@q5W-7l`X`tMZ($_{l-Vz+1w4&yA7hgp0PR>wR;Q6GPYBB#*@=mQJ>9VmMrSDG}9*TxC7!Q4IvJk4d=583w*59CQsQ{FWbrMv@Q;LmUhLgYPQeeuE$ zUQ{PI;*FovmhwM%w*(wed3_xNZ~hzTq=A7!-cRJr3hWX;q?K4Su+*1y`g!oY zOxE(}{MR>kTS%^!{fD^e??(D>8_N*G`%4`AH`V@E#`6DN;QQ~x!2gx_{NEeq|4e-S zHz&3zL5#l&QWaMHmDH9NcC)vdP@|d@}cKO z(K*$Ld2nGcy8vJRM;DjBVhMcs%HO$&HKL+{jY9HQ=bKbesbq9aU-)91E;aB-0eTJ5 zsl>s%!qz{Ci(3qbb7m^{UC7wK6g0Ik))47QOi$a2-nqkKO^{*#o#JMDF8m$dx5!2v zjFLmJtPRSY2>WL?Z9Wv|xARv!;wF4GO!UFQi`o-T!OHuRh2@Vo2shSmmf0*@Y^?3z zcOA-e0wb8Q?`dl@aJ6$BJN%NI2&ocA`}0l#>;9M?NyV#USn5(>s__fpBE53R={8>X z{OXCJ-U{<|o$H;xWSaf(Nmwr|6eHYYP!IhvKYP!>9sqeu-FJtMIG=vEVl$YpYa?!d zM793NXua_tdAKq$VNsV)Le1kv*LU0rSoqFVVQuE2@@E|!xdS%CYEu@r^qB-}d|!&4 zkGig6^3OucGA&!}&3Ky6o#WdBGT2EAzXbK~LnyOBZ4`wZ3qZb1uab>SV(aKoSUBL9rG zuxhk9yal>1`3m%p75J+v-Hf9@d@Upa|39Jf{ykIiZ!(U^+d)A=)|9c)2zun7XwU1z z`eu)GEslL17VwvKojCXY&XUW4jkHOXqRHgEX@c(u2En@ z-@Jjd14<7g9JgGq6YnNqo z$KprdL(Fi;vc16p9glO{zR2Xv2^dKPZ8`$9>jcOU?Bp89AV&I3ny=J2fz%m9kg^8Y zYZhLax`BFniS@vpS9skWUeh~J?jO>BQuj^Pko-Jvc^J$Yih$g|>aw*b^3#Pz7y;)n z<#VWtB}Q!>-$>TaeyPku4-~=M>b!db4}zH+n^*l~9bMxpDW2^QXYoTWOs&iiGBYxT4$M@XrMoSK~_EHF+3r8Z<|B|$kw zedtxM2U8>9U^NTQ)belEt;?_-OAi^`ygE@^UQgw{U6Z|h0!4czGhsyZbiX@g)qWQy zmY122scl)4l(w#DSzT!uA4QJI93o!u4Gs*L=QV?g4zhDsx&06-Xm8zD;#}>9xl%CxMV5zBaFaf zajB=Mr?3vy!Z)RFtA8MKZimAI(wXXo2Fl=>3a-zL2%e1=N^mhl;Ska(yt7_Uoh`yZ z79C0)m|uD@;pxOv3?0tKs`2GENVs{X{U8y~vZZa+%>~ua=zxB1S+m7b= zNjJd_6hS3Vs&$%vg#7$Md&r~}ImZdKbqYnTp2YC0TdQZqR}<;z3en5ULIwF%_->GC zMmYKI{?EIWZEK~F5<@@JtJ=X-i|!PwpQ)JBfBgC{n_HN<`FjohfBy3Idfsp7iiNmr zn!*VLW87?6L5l=p6d+{-amY=3zvrywbMHKV(u6d_4}n$053-%U!N$M>!g=k@b%u)s zxY+GoYZgeq+A7{YeqUvNv6`z5LzJ?}| z9SZa8d+NuU)7kXy#>8vZhn4!rgQL8w+3tZsmkUs2;Dj(K{ocSfVpG%u zJ7EgloaKPn^28lPdG+=-VYd`?tU~9=ExR6mO~s*I>>fm4r_BDk#G3tkpL)C8xgo1W zH6@q0B6?Qf-Pu&Y-L)jzbnNb}O2+rjcUo^S6}A-}=}u6C7eFp>3=4=~&YBUsZWm^5fbAl)YCor4bDjaCj*9fN%Vl)7Qb9 z*CP+)=t~$%Qs!Pl#GzDdkG=n?1l$2@ZcD$sYc@ztwO$q#y}`QvMW$J{RkB#QLrS3) zXG-Cd+k{-_oi&hJQ-uLp2S}UAmSYhgqs!t0Uu=O1*WNeE`BBsK{q5g#{`=HYc<$pr z*V+Hw%TETGQ`6JaZy9U(g@(vq$5;-%1j&^PhuplHl9O`w!O)F=I`|T9)V=K-_bW2# z*NGVcdqFwbSL)YMnx7|T$Z-mtcWQqfrTKYch5+YA)(U>>w?LiJNoAI;?KU%GOC!wL HEBF5wI1Br@ literal 0 HcmV?d00001 diff --git a/docs/cugraph/source/images/katz.png b/docs/cugraph/source/images/katz.png new file mode 100644 index 0000000000000000000000000000000000000000..9f2303a21e3a38cfe573aad4952e18dfcbd0fa9a GIT binary patch literal 13558 zcmdse2UL^k)-DcnWIQU2l@6mQ%@L#*AsIzMQ0a&Th!Bt#YA8~Y=%AuPKm`O05Fsc% znSiLY1PAE^q=p_fl0X7P2q7jR$qh4eW={Ff`TzUhbMIaEvKDK9CF?Ewec!#G{p|hS zb9S`b`Ddj+OG!!Xv_EsoRZ8kBQz@w}THkDyJZT)+_E;jeKwa%lN@4p|79=-c2c9^8 zLP`psxqZX;Pm=p>k!LPJrKDurHvhMDMwI(YNg1cvpE}`w-Dj!LI#b2&Neve_jfh12 zSk3J5R?-0Pu)GXDpj7*{Qtf&7oPMqcdzGD*INd2Vv?-77wrhNUkIlnX(8PY`@@6}#; zyQy}^$8tsE%}MR!lIzy)oA@nKQa4C9{&NxO4rZ^;^|I|U)e4CTY?`-h^_#lDfqt}A zLbPaoX{uKssBRjYNs*E32el+}^yqO5tO32f?6y?70T2-`)T6%`OkVSuw{DQV+;+K_ z9EP}JzT8y1b<47m0&36KU^XDWdt4>N{pI@XQJ6j?J!8eEzePS(VF2a=Vz5VcD7Nlz zfOd2!-!etJ*(q0=j)5UV-2i2_5{esE_qeF@XJ9=m>GjN~%bw@TKKsurB;8fy)IX

2FcNET8=>#w-9q9xnJYYv6B2096SBSEFO z$VEp+puAcg87eAQ?;ND6AV(==m+6qQPn?=4$u1EKYe`;L$nirh4}FHtCuA545 ziLhWoSs?_yWSmJb^(MzgJHqqQ?N+- z3biwri?M8ISUNdJo;bA^9R8ycYhk!7m#aIl^Q97PuHz7 z1k6ND%=MZ6$;vQLXoyA@6K^Xl|DJwugdWrnN?vJ~VLj`uZTXypW^*rQbT(E{W(m|r ztQq4h8$Eoib5*7B&d(e}C$dGRTAIuEJ)Q|4FM3lZKP6!e)_mpyjTa20T9;g(tTcC>Jvmz-Ijv6FZ~Hy8`nu>-y!`w?S~I z!E%~l5xk9)FN0ESr_isTsrJ@*wZvx-n8xIn#8)joB?&fGrFBl*nw zNCHbPL4D-qKn0ii(z|yTtl;5nQ=|Zpx)|ul1mvovDYp9~1uAybxRysY^zBm2$QUq& zbSju4cb#&h`R3YrO_xiQ@pa$uI-}WS@n%a;z6 z7fuSkJ7%+vpEoR@%Gf3!EJH^B6Baa{~1M@yJVq@ zZK!X-5i)d`{YYR*CMKzUBsL!$UY$&_Wxv>O(I3=b1j-CF#&){oD=g&nwVNbgV&Op)hi%W0~r^9F;cP}UDC7c^5?tGy9Q(l!>u0D3iE~sGo2Rqbt z9C##CfEu3US%{_uA9hk?es}~vEO-mFJoEcNT|uyd*)nj^Sr`P*6!wXt`vq_3xd)~c zxl98)r6<@|U9SzF&H9bRWw?0HHLwPi>}{7~Ypm#ji2TNN48itElXw<{)xtBa_b*?e ze$~2Edmo=K(CtA5hkxJCzViWj?Zt^j-KkW-1$rf^w{O(V%YxpJ2<0^r9n=TS!k9&u zxt5w{?_u<&?c7`gl*^(Y`iqfc-C47_DzCowQUc#xqQ{_;#-_2HNYzebr7|4Lo7oaN zW?}+`1x=4Y-eOmddyaaDuq(95D_YFwtAkcqRs6|kJQdG{Ij)&t>y0VjeUh~uL!Aj% zXk3;#OZhe~%$R{7)1vpmvKISP?-mYj)!}c@7)w(=F|b>VIa*O8%13L(d;kp;KbHB-yi0?iL3M&-_V-4Ouza|dlDDg7jwhx#_7DPqWaxgn1}sb1m~zU&=u z@FcYxp37LU2x-o!kF6$rO0FuNM2Uhay@P%+gWiEn}xM=*?KL*lOOA{-J?m z!R5BbAa64oTc`sDyNO!RIjP9HU)+@H(mmt?>dtb#LBbvZ4Hcp11I4V`K2$C=Ybe6j z<@@h&fJrZcm!JxN=1eiKfo8|)lWDLumqM#Uwv9Mc;yu!zS^5aU;~nO!EuFCva-ExBa@I(-0%AliIv zO_hMeO}aH%m9bs`DN{J^ompkfE@;+EM2UQK=Deo$>g}_MXAu$`3%KsbUglix*y<=9 z@yS!w^e1)8cN;RwAK<>j!z4|qX*C5@(p}lB!dn*HJ=!XFixZ)5;qOmn*PtlozZk89Q*?UmQ6X; zh?+MmS-WD@4BXRsxlVM{8h$oJbq%mUl%6=q-+5$=^fmn(uhc6hrn}-RZ)kP{J$sPTY72Up_#adybhX)Kd`4tEPE2a zl-9BKT0lsJ6)@AAL*7w>X?8gH_MHOEMx2y@(iW`C) zTUgk3LZXgiF}Wd5Z>oO4>-NWv@z&z2koETuY-TtrZCE)D{-b*mr|vBC6emt~hlcFn z`v(`7cXy7T%Cr{4VkzyeDR@@Fn4zIYY43d9Gdzd#;G6ck0d>c|ffQU$qIm4YLUuZm zR-6ZFzP@sT(x|?WR|+V!I3Krg{Tn4Wm{{>N=|eb922i!#a*= zuP6NqbZt4x($+>K|@D05#^PYlC9#VErvb zzwqUgoB|vZlFK^ZZ^o_@>5XVIm#Z>O@ay+zAM!?tF5dMI)$$`o%4k#z|H1+k`8EXw z!F=NRqn=s5S%4aYF}DlEpp5-4ddDzHh%?R3Y@4DW7u^8VDe&v1147M8qx(jJ3p)hK z+U^vNO3I{W)5ku;;+=MLi5^9!=g6bxYE=Vo8cD?j;|wNu{TgVed!XK9m9qf=^V4G< zVm-H^`c6lhU(dSg+$B417;G=vjO(qWWep-P>(fVEtaEi15{>ORbE37$&S+)M2hI@v znP=uRLy`-}(n5R)JZ_>@@0RZ)FjqvoF3VY9vGkY$^WF8`CI#g+vbDx1q2(}woFX)2 zX+;9Ph~E9LF#Iuljy0uueClKyYqYF0* zENn?S6uG$y$g&XHaE8=$p)&Bi_2YRZOVIbD+Z@ptMHJz+?e-SLjg6ZtU&%N2Coa3; z8eAH}#yp-d<1=7@>$(FC+h7;hwvh8gPY2rdVhRH==5Sy4t*|^lMjt{odVzr&N8N5rT}(>;9d)kx$BgQ9jDquy%VMl1n;}#NmgpZE7Y&N zF-`4ROJIa$4NeRxmv>Cd~IzzBJZ34?8+2QG~LAsV-aQP1^Z{grhIMM1vRAPmw+O zJUUw6EW?o%FV7$%KHplMA#@s~GBO!Dpp``<#7rjORz+DCzE*rHs!Y+i1!rS2A4UIS z{aB?cGmG_VweI^)NB2&YXZXdI7OZMS#<0l+FCRN^c_a&PZapB51YCed_q7b9TjPwz z9MPWr0ekDZMSD6m`31G-*JnDruM5m|-O!BV;C($zg9;dSZuVLRApE0(mD59mBtQ{y zImeZLiuy%Jg2$1(11#YEExECKQvlLXHgnCq@lF%%@kP0M<_g1-+8>%V8B+^%-ms<*rn)ah?_t2MU3w%`9wpy7y7P9A+~5zKXs|0`ihFU@2Nq-WF%AY*YOhJ` zi5HZ<8c5<$ZCt|k**vy*u9nnq;Znr*FrygND+cR~bOdDBaAIqcTN;5}3b+fOwP^Fn zD|=PwkewU)rWEd-`@LK8*vrf;u0sQ`d?q%=Gyfg)Okn$*rLLQ5=j7mP*fLBugsCM} z!YuGl7W7-=RB@d>wH70RA+NTR4##>BgHL$?hF)L*hPI$8U^|oZ-3a}Rcu`~FazzH% z6*y*Q@zODpuR$(h)j(z}<~**y@FX%b7Wn(|*%1UY2o63S7*fLTAApv28jx-6xlqeD>lT&xo5Q~0Z4JFa zm>%>`f)9APCfVopgIh_o#sX$Di+l4ZPG7mWfTFYP8)=k$71A&HlLYQlwy7sY*~z*= z>}6oS>Q^1KHH*>uCe5Kk=rzbLSlntI+9^Y7H%~rE@o>?v!Tjo*J*aq`g0ciMm|Og7 z8Ug9X%kdK3k(O`(sVXz$pD2ne6vfZ#o;}19SJ_XA#CPKA1}xhNHc$*Op7ki7B_P;W z{UQF})d~qKM^95ffn1gl!E9adFdcSRwo{XQT6EhD)9|VB+kPgL#jL>HY%sY#uO8bA zx3Z-LxWW;{Z2sl1q;AAYlrczKmi{5gjr?ImZr_|B(rwn;%oK1Y?2tQM*DY#CX-Y@E zsnRg}9mm5!#gJvP5EUa_1BAk~F9`^G?i5zN53DA-=J2at!{X%n(eYFT>sErPh9S9gAysX@F?X5oKOcDB2vFf<)gpfOEZo>ODV!fo z5WXWx_nAB91T<&hxcdC@WzVqDYir_thp}|L>k%Xszqd}1k9w6ww=B!;BAgB>Np4@# z+M@>*V-mD+p^Ks#G~<)Gg=aBem)t_W6Ll#fE2zLLCG~-U67nGnIiFCpM@nj1#m>l7 z_}PJ9T+-MH_ro+=7Qnt}Qv$4CLbPIz@t=AS1oP$9>a&-C^one8LXnTP(_s+mZFT44)aooI|AtDj z;V9?LG<~jg-GUJ5J;$}dcZTy%1lt_gn?iL?3T7l0g`(p#>dxX8uas(Ddj7H~Yo2Y0 zErXAYu@E>k^_!Wh4JQPD_!$;sFj6*E*#~vVj#rgzjv(8pH7jHeZNHS%7(J-N78P%% zeT`gPJFvwdRaRLw=p<}Fbwt?O+tx5F=Sc;W-lI2!`T!v3)gHj$)^T-bWd;ri_F0gd zC2_1?OCzAd{xJtn|F|Tyz;044L+JsoGwF8PJ`18;;>MixFjM3mSeolsn1qWfRxpe+ zdcxG!ua}lg-r@;<@5KR3zld)3=!L+){#&|CfNMuEz3%D+d$4`EWe3@kw_ z4G}GT%fUR}aAL-UT&o&-fwQ~)*w6+1#irSaS#Isf5qeo7FDIvSUkIwPuYJnIHK{YN z+%Y37)|I&Ev>DD(yYbmLg8O4j@7sb^uj818b*q-RJre4T8Ep*KNLKiw72$n^RcUit z=7wz9tmC5P)tDgOw^OmsO>)4U-zIXd=i^L6Q03pto$W-WRtUOr6YhW>mdw@6nYF}w ze$jmv8T=>^Y<5#4R=|2UfxNztSI8r3BO;g7juPrT+ixF&#vh+EpFmT0wG?7=%*g0R z%tJi`jSjG#<=*}EozFhyZIOC^&Abx7#oT15e2Y};LJIz7Qd-eZ^E&4cmnAl#Yxt7wzype-&gdTj$s-bt0+*+^?i-|YmET)+3{y}w6l^86(}DNy;v^o;MYGc)C* zzT?UCtY$tT`T7GR8AnDVLBvb2Q9p8KF=FJ^D#W=FA4Op8qk1*GOJ0*tHnF8?c`7QmKmS7H(NJ7I zerrpS&uHKiMGmENzTArQO0TIS>e70R)l8L~cc`fH$Xess?GkR(M6!1#=s)zfd3||G zvKjmGL;WQe>zN}!mn(CghGYF69UU<{1udVw z_yxAH$0uw#)q~imy@E6GMuw(5kcng-(}yHx0n^rmNimcg16$+q6Q6yj1(Kt{MMVXI zG@0C0tG9LNjGoo1Ck;5r2uLaYq5WqS4#VQ^Y}_q+37#zuoixc9R+*qrH%poezmHK= zR)2Xiwd!1&oTYkerIDqShghMbBUK~^7`;yWygIjjx6^(-pf*h3bEAoNHF`%@ z*!j%!+)iWDN(~!!8Bn<3r{=&AR7RQ3>;v$I$mv3@vnm=peN=hPKrlrk*t$? zG^WjKt<5EU)i$WvSlRL}hn3ja-DP75e%&+vskzomhjhg(w0k#r*u(?4tUpy1Zl=|J z#)*p$RXMEu9VT%seuNbizPXrh2Gf!6eDhMD`o!YeTox@qk6h?3M~X>(dhSKnPT@jw zLG|BwGbue1;!vJ7H+AUlY%WzjCsTIArKroG@67i0oULutrXu=NGS-QV!zA(!9Tv=* zWc1cXa7+RrUKHxDnl>>#DRUL@rmHPJCQn)YE++2ek~=o>9z&Z*Uh3!&asfjz<#I^W zN?^0a&y+8fS2Mqol8SVb z_|Mj*tmoS##{qB?KcN7BGJ+o4>{jMq`(VjWerNk7bbVl8Al>0t=WX%{v>1?m-EZ@_ zY4yrDL1tV0emQQRT7U8M*M1kRej*W!FYaKL{*d~Ws^+_rO>w_IwspMK@1oY9 zEcebIySaZ;I9Jz5Y=Et_t}o;&=)+l9-+SZI(LkihoGm_CkA8_=Tw+GyO?n z%>JhApDfQkJ-sRRhB9y9Z?^yBD4E;5tisFk`WYTwje<#-Uzp0K@eaSN9N(`OM zKMY`5xCa$DP*g}v2|-A+JZjHp`+U4U2{w4CYmF!gcR% zK>))}7)#(Ys(MWuoYRHDmhG$gz_(ZeT@2h(e0*Uz!kg)n`V-RCUms7KPeN5^k4xlYK5VCgb z{fD+x`K6KBtv}n^a2Z!{4r(^`OCvsfoZ;fG)c6<3$V^SoX!2%p4Mw_ zt7(_u$(butb65^E266}4PI<4KOY=*tQ<4G!lvQM|`izEolW zIPmPccd4P0se<7y4fcbMLvlf=RBst9qr8PtH5&nEqtcSSt7>|ZfEo5-F*VWNM_w9x z_+e|Dy%&yTm(@<4T=6E4e?vNqd|&)ig46F?B62U7nsQO5)HEK%Jm%4`8iSsficKhK z8&6oa!3AfmUgmg|sdpuOIih3Btun=|>#B2ds3guDZ{BC#hg|fE!JOqx`>Zvm+mJP2 zMw;B3rSaUKg1~T zC;kUQRm~7kI1Ka-vV4;3jDR03eOZL=aP( z*7kdWW)#ZhfB{ji6rkt}H2{Wb%+G}dRR4_kSLt`_S;;PS%e+*2=U}H*Hf+pZt&*T_ zeba>?dr9?Wu=&v5HA9RgMr&t+GdIdE^|b^n+z<1-V$E}dU=&6!XddizNy^`)EQit= z9ju@HCZ}o1i?G%Oug=8M(@8;);y4VCwH?L%!@PS@};e`(uz3j*t4XAxaN_AwCfm{GO_EBnO-f^o~aTQ3& zIB2Tc41jQ+$V8}p(v*ubQnC{+8roZ22=u(7tlT~r)prSG@XWwnnaM)OHRG4gv-2$z@}{gmEQll*1~@m`c>r{J?q-!IN&(hxXc=Q&bOv1V2^W@a z3sUDLjsUvzXnaP;)o?gwqo~Yt<_Jhap7=K_antk(ufHdcYt-Z*$m!>a62?dC;H5GY z@*%IwW?CCZKD+vZd?Psz@Hzfr*gh%_A!MdC4`yj&l7Tw;M)KB_MHHG_Vb8EId+A)% z<4(8y>pd$;XYrPsShL5=V{cqj1!p~Q^3C4Zm%(>-E#iOPqkNcK=9#GxgNx==*?_>> z73!+|WpZlrOR~9|ez!$G$HZ`}1@N%)y*_TUkrQq?uoOp)MHL_zpWxXN2sEWaBv0hR zbXAv>jb4XJx}4Qjab% zu?L~qPLZ;{lvqEbr#R>2)<3IOnto*&#IAa5_12@42e`Fow2;{u29YhGi;PUQMg?dw z51mG?n$#5oMO2{PvbR3?)kEhdWBE$0EZWM*OS0!`GJO?FAPScv^wAxknn#F305E;H zz{dBlPthR>qz7hYo_g=B`dI1i;YFn@n!)}m z+}ot2-@Qj_mVs$xe1n96c7m+hcb|_gmyI7BpZDSE67?39|2O*Af4|S`B1%dZyfxEz zs{Dh4m!uW`7kuP@6*Bz$r}S?}HUAff_%D!#1o-`jKrMZ&8z^Rpn>;5ClYRak*$rAM z=I6Qm;~QvZf?$1#H(ciQTNy@TP+|JHKJ73?fd1cNsfNyi5#C49;#Vgk9_FUk{FvWmeiJ{g-Sy2 zhsD>weEx8{X|v8`vkR8u&K8Ex)>Fd79nZ9!iV7sqXFz_=p+`VQ2+6&oY?U-0NxCO$ zup?}=42kT?#=RsPSX|c9*kj~y_5uEjWc>Q1zWG)U)zG8hgTuav2iZ$po0MNi{^3N9nQni06Yuii z=4TGdEJ6>P!`%cnf>W_x;Sl3SJpSe&Y%EIselhiBCB^`9IeL25=g!p;vd6DVKU%De zGS<>agk?R)fTiCt7em1Q-;U09?OlNPyO-+&chv?D6|4ciY;y(7xd7pYF=q%J&f@#U zp(T8WMo$=WiJ-y;K%K_T{u-=*f6BI%_V!wbrfupibb+>$Q^3aCQ%->+7id4v^PRz2 zoZI{n1)jexwanMW%K=K&5r6x_(N^hn;LI;YK>Yq~q+2;Ud$X&t*D ztdJXZn?h`KB9d93$@S-NcRCp70%VXZzEZDKqWi$}(*gHEb4{yt>LjOx#v&urN)|6L zsW^FVNm|hH2OdyB-ChJk)+QD6?|R#6>`XiFfJ1N9oC@zApT7~?gOD@w>rbD*vA$(? z@Bq%Re)P?bmaLTjibiOUNa^BsqFreoRnNYkD!W?>M~1|5h7B19W4YZ2o02=z}o z72ze)<~#4n&3UKG`p*Oc@5zZaQ3OkMQTSASWz!k4@}-d<^5WYoqJgU2*M}kg@Y)Ch z%fQrZ`yT-m>UMfMGX2bx{no1Z4`=)jt*XmSeQV*QX=4vMT>%8*U6d5})N{$+XxRG1 zGkZzWEj}OIuoGPOHR@{&Gh89&3twui*ecO4E467T0wI{<$|hz$6ms#%*(+SM6C~TL zQLPw1`l^Q$l9jb&tVP#qT2;mIS3F%9y?0LAJ7pQ0)q4`A*5Em~X4qBAUx$0puVUH` zUH+{o*6^cSdEmgr1E+7F6%iw%lEK|;rHu+ zjiy}+iW)Obg;aMRTA+G9dL@0;?MO>cSA+h5XVei8w-Oh6I5#-_r0|)K=Np2-@3EqE z2kq1YgO%YYHi)B1M>=SET{}vH&9`=bX|CQy28@r7_Z{k?l3Q>xe<)G;qQtm$y-gqT zK0!O7`7uw0%8r&tl8O!R5q}%*{WIfrv;CjH`EQ4aCq#1+-pkMz^8V`o*HQ2PqaOVM zLQ?F~bWTCp{2!Qs);<44UgLiX0R$Cpv;oDy&4SLQhs(zmlwY&2Le3o89W0KaycusT zlmXjFDm?{d0$B-z!CsuG`@LH7qiAXB72OF9uto^^9A0m(zJ3p>$`d!+M3E2{o9&@D z+q;;a-Xjl6j`H<8rG#pJyppi-Wu={S;I~d*$Wh>eC7zu+Eau@K6o9x2M`-NZbd>`{ zu}F}XTx@Es&y&yr{XpKQthASjqr$LFqNHm)nN2dweSVmf{?6^mrYj|0U->w+q4HJ7 z9Z7-ex|8RDJsQ#))Z$P@<^27{$snzsP4BNi)HS{aCy4mWG}BjGB=gb$Yp_4!3LEi* zE$Qds-m%{(C`%MF2~@qm<8Jz0oLz@)S971Fi9?FYz-KOn5}sr2LebJv()m7nBn*+Jr1308Wo7Rpilf71SAJp$=$NUbCiewmQpWn(L_o-` zbLRo13G#S(5x*#D_Y2d%Uh94Q#4h6Fr9&F4ATb9dPwZc6PQHv``hd3D^02wF;k_f@ zr0<+g_^rW4)>X*+{uKOyhi6KK!!Q}JbiCT{OK#@7neNlu)3H@{>t8lc|9J(}*Z&a% z^?!x}`+vh2^ijkjZoDT~`R~a;31#|E>ZN$Ob#-;Uw|*&g&bqDjFRQI1!+eNvnCSC| zHnxnMS^+R|ZGMA(Ocs?JT|hVCzr`Jp$>8^MM^BHZ(+B+Cw@GeWXT{ z@O-(UHr)YNM&Bp}`83EaK8b+COUe85BxzvlyMoe`JQNGR$+Fg8--PB_T26Yv(#z?W z3%GwGnV3m91H>^8J!I+V<98}(04(t#-`C`c;DX6ufwBC;gU}PaqKH}#)y?cjgKTsL zluZg%P0UmW!Xs1*b)lB?qvY-jv0sENgt(|ex#ng8OG&~8zrP&xA=rlOr!o_v8s^3S zL6h&ykY>om5-GbA(Zl0?B%wZn*Gn^52=$G5z)6U3ktt3%o9WG7N<;$AmSx0#js?km z1k68xK0*D~o&@w1tdE&8EY?1-F>YbwqF_t}?Fvt)x-JtxIL+zAzr@;8?#A9t2>asF zJOsO_JmyDEwv(Bde-8|Z&cJZKds&$>3r6gFW|7~(!xl^IF!uKT2qhHZ;f&VW33rCl zS{vf5R{+Fo7D?J~vY-qt*?E~QUg=af!=)HM*zXr^9YpZf4}Ni5dZtuX6(_TRk5#=z z10SW)n(jzg?brJ!pT2NekhW|jI$hUbq7o!C6<#eu31Az#x-0Y&WmL_=(*(x7zUhrs z0<-=s(GA0V*A0vAwakP4rG?AoljY4)^$NL-i*ASMMP!L3-M`Rb+ed;ip9BG*$}6yU zC!j2%@F0yQDH1*Y;ttmEtFpHRvfX2`@4S2@LCK2ldj7BKnhe^-Z*te~6a3KMgvhLx zD>Z#3g7F`}OSX3bO4mR;fhE}7Ka#E9`bR3@KfQZ^7zA0z|NLvMz|VU&%LK-E{Fbu&2XRBAMEurC zyz9O9NpVz_$ZUN>PGOJG?lwR38+y+c$8C~~A4EwJxdF-%2n}4&hTj_Qj=xuxj(mW( z=nz-$h@83#k-9ou9qex@P`EpF75%z2t8Zs(MGAh)jpO%ZFFli-Dz9b+bGPE&mit!K zu$of7oI0SnYuL9&Py~rd{5Aah+$JZFYaB=c2wWSeXV&(OcqwOYR^8Fb7aHRTozBr# z4iiM|X8317@p%kO$WS^isB=7tlbi;%{_XFxB`RtyJe3*)2p+uI)H^=^;PJz6Lzp%` z2@?YzOUdJ*DynOZ{iY{5!>1JTw<+%%^-!)qKhjSPP<smU{^eFj6vcWL;8YABiiw47aoi_af@%BJcl#|P+4p|z z3A@ODQz(<9-z)GLBi#l&w`_z2W<0dm6|CeAsKgL(d|ihCKG$HZeP!F|sMm;7rl+Jt z?9q}PRa)|SgtFA@tWU+EF|7h7DuTeHmGpN#Eidf3MY>6ZcQ>}o6ySFya>gAl4frEW z4`3spk5;tj+_DNN(?GR2;6bVz(bV*yz@cW1v0;3?La|jxvdmKEyJx%#b_YL_a#WOz zv~|~D)}!Rg=s(X_JSdv|Y_b=!<(H)Kzrj{Z_VMp|wEud>?YCb_KEL_(`Qqa6HQfuZR|i1SCKZb!7?&4_~rYapRfJpK_6m?Ct*hrahDlGducjek}*ix NpLRTjJ$d!V{{a@l%UJ*b literal 0 HcmV?d00001 diff --git a/docs/cugraph/source/images/pagerank.png b/docs/cugraph/source/images/pagerank.png new file mode 100644 index 0000000000000000000000000000000000000000..193c0a8bbd1519beecd28c9d61fa41ce8855c86e GIT binary patch literal 13243 zcmd^mc|4Tu|8J$VSyDY2dmCDW$u3Lcp;AUXLRn%mc4ERmZa}tFaD~Wf&QTbL&~Y-|z3dUg!5Z=XZYRcV6fGf!BTAx9h&{>;7Dy_i}yi7*i7? z{sW>1cI?=}f9p!p)4X>?o#)&j1Izo%L?&?bw0G z9^7!)103&vaMc>VV~0S~_Rmg|ci#OSJCsbX8R-3nuw$lyvEr?-yjDvHVGE%jtvVzu zm6c`EtfAoD{%?1p@zW@MBZ2U}S2NVU?Yg*6uIxkBsmnYLJKd!9KkoZx^%3)Erz?M* zUFzY6(1gxd8ZGTrd-34Zj>gX@RFzLd9*T^$9;R(Q?0K%c@l|lX;=2sPa?93`%V=oJ zv*+qX^d4F%+>*}NDeFxy+wA*qFUz}{G61*m987aL(E z)9R4xY+Gg&omOb+ygEOs8@#qOQJc(N2z{=rX(BO+TXjIEo2Lmb)nAy(i2mfM7ZJeY zvQQz*JJlK|JMYGhW80z{2Ufo&vSO|1!9Er|p=X*e{R%yI^1ZYuK`bMpkq$$qw12U{ z=?sqte;gY2=kX=4S^XX<)HLR>jA^}1yk#){vU6;BhnlNQ#Wz(k!Yq0IX42#ATw;JPo%|hgYsDIZ&TD$h)Xnu}tc{ZnBlv1-@Fhm|+e5N)Eqrt4uk$H%4`fXAke}Dk zn2%4fT8d5IY>S5Q(mQI7-MB`@!-zKYX|-KmSE}{8(?iQ8Ot%)kHXT%>_oK5X<7Axqdk#OJGaJ0846#cI$Nv%;%H!eX-PXVe%{U;EUhpnMs& zMf6wPf?)EVrO*!FDmmC>YGC3DZsPb!-`=(3yxxR5T8` zv>qHE!mzXPB&)B_6K87QIG?SE{$x8zF%@Enb2rf~uz#UJdcNKpOhnPmtm{^221JP1pBD zI;obkXn5z^pymijS8k&IAw9KRHX+;JZ^*h9WxhftT78aEEz5=s{5D*ULN>`MfatUk zfmR>Yz2V^5>Hup2ztBhA0?&AQSG5Whmu@{5YaMn^7uhYdXl6KUjQ?_LicbKoF?;u` zx=*z1$&*`SJ=%?J===&}TltOt61B+1T@4SjM6BHsQpBnwso59SVx2vC==`}ADJ^yK z#?@HXFEFO_l zbf7bnn(rUR<*QX*PATUVjMk@7UD#YY8*#}q2 zU~DXsrEtc#9uPvK};E&drVRFWZVS%LG#s~^`p^D&LG}cEeY>EKm)&+MZiVY zF-+ekG4CgqPG!1*Qs7k$k610h0&JN$N?x z_;xvGtIye!sTyyqQQUK`9qy5vQ0+k#*$UEHGMG_)W9up>4l?it6ADdDK2oyLr$;kE z4G6~js-KoAp|Aaxl$LPv3uQ! zBHlXdUdK^Mp=E)K-&%q_1>2!L@^p+^UEw>0fV?%OfVIuWfqIJj)uChsl4}B@`Km2V zvF$9A=;LBOwVw#&w2i0zKQ0dqG)RuFrlY3}2YRe5gvp%9TgJ@0qhi<7qNHt=U_tB| zscT^zYASK{z|zR~XSRdXz!MGDxHiWZy$H;c@p=djSc?-8AVPop!a2E?)^fnfmlm`-47sF1T3`mPpMg(5Dg|KfZ%kHmA%gdrWjx!o%`p#${Xn*;WlW> zDf#-1>i7CyyaBCaMb(XVub|{s{NyW;6z9B}GGY)!C@MbTpcQ~|T8LkMGF>z0o;%$U z_DH?q99hZ!f=bG~_@Eo5Lfx{}kNwITs_mreoYXNO&p853)&GQ_O|Sgo!2HZKhzQs( zH#93Z64ayzhsYS<%eT;HshCHmg1`NEhYb}_}x=%!W7jGMD z>Pu$y ztG@>xl;|lQG69elzi+7_OODUP?U*Tp)o_AgZJJLHvpxNJ%h23X*bSY=0rPrVWi(%x zAi0w89+r7iSud9JJ&8Y7KpNw_l9)DUx}HY1?-8~23-7LX8v}WghPrv#5`}Q{sk#?7 zzB}m%T705-t_~@5wRW_w7bAWcIyxluObjy8k?ifHrPe0}FZ6Mk)0fPo>54 z6fqe_AE-)URPgr{HSdwONc{OS2$^8yYOXN1J_%Lbdn)GF%$x0js!~Z!qR+sy_<_(T z6&6O^_pXe%Ep>lp6DSqlT2A?;L|isOG%MYHxiv&6iru=?P+}%!^~B2hq-Hf5=>ouy z^8i?83h624633b5@y#%`pDL~pIZ=Y-RZxV{_ECN6g&HPF{=WGu;hfpjT#Eoo_w=3Q zxI4O&x-ST;u`Gw2XVCHAI&jbXj`|)b%NAtan$;&hdpeNgR>%UgGYu@6@ZrE0wJm57 ztz-Ii`yZU_iBlfm)#`4-Zaa3LY3_hpmKcOP+>l?tw}?`E+I%|H*3`pUs25+VjWND^ zy+{@^;M{0V!p}Esh30}6XNT?w`=d6a$d^b;GG7uZo{KZY@4t<8`RqR_5S;*3!GcMO zJ-u25MC6xve?rgZFOF0VhlWX8EH;^;)t-kG!7(*;XKIzD{KODxt{Lem^7NgHu)t2N z52A?P8^kyq74qKeT)Bd0!;f#zR*928>W4@n;W}siwfb2SdqdvwnUM28r{RP zs1w}NTm4#Uu&fndqf-z+E~y^xPwiD#bd&n}W6lQTfYOw|KbLGyyJkE2){OOL_{7u< z;w{lCIZl%lbvZRr`~cX{MTz2*fq~BG-ldP*k7CKUd?j0bG9Z)eFKc6n$#*?XepA=J zA5()Z5Ra6}$NTt9TuI}|9X@V2a0|GX-z^N-J$T^VgZCfUV0k-YN22w+@F3#2)j(s_ zO#(h_;9bYo=2;J5`ueZPP@G6zx`#^b-t1z&2xfO3IFg)7)7Vt|>qW)XDXVXC>K#rX}!M}+=OS`|uhq;7NxF#I1XHFOG!Lu&7s z88yhUp9c;G$bqjAJE|PsI^$BT4G6ts!#L%Zcb&;;F@jotSk+BG1$94yo# zI>leHqDWIjw#M+{)~kMT3)mglZ^+EakNC1{0;M4tJ#m(kA>t9xq=4(%{?n;D27GVo zzSQ);w|-70njL8GzS(KbN_WkJE4}yD=}Wf_aEAU?tOU%Im=Yqj${3qH+uF82ZA*fO z>HDHLGB)x#G@K)MY=C0&1NqGRzKxX%P863oeQgkLTcM_$H(hal;^xb zz9z9Zy2T2hk4EtrAqCb85J)5)5t+S~A=j?wl#^L6*!ql9Q~yFt8a$IwzkXZ>YPh## zWh$0Ep)+}~HI5F>grY9%`r36_h)h+mp>n?8)CbFoG#dMIFEWlBrnLnT7*hEfyzCg2 z$s^H7z6c}n$>K@Xw!Oq^12OK|^^w^i6sXf;@Pb~$qD?qBeGRuUCk@ZKeF>TyeJdK~UZN)ToWnlxA1?lbUy z?C_<7Px@RTBEr&sehyBxSG1*{v&=SX65VTS_?uIPQ7#GMhii#%SqAxm%u7Un#w@|I zucJlulN$0&Yl6?GLOU0?5Jy>OM)izriDpAtKFf+!Ox0qfI)yR_PmxpR?XOhJZptw( zdYAg=L&U7yEU)Rz7|FZjSI?(lvo-QeEfu~mx30hQ*723|%%qTA zvl{QSs4|g?72GowT!+WPm@+4#+|tLT9t_X9Ib!^Vf7uV{ZCFLL_rakXOJ(63jICp{ z=y_9udh6NPa_f$D!E%$?J1@gNq*A5ar$2#()*{Hp!SSPWnoDANJX z^%t{h9?Sl*d)wU@Z?0b0jSW*qz;(t~9c^DxMunw;|5*p4qbKZw z&pn^ttGOK92ZNHWE+Upo)*hsKn~ZXdudgDVyH4>)G<@S>(ozhkR7$4NQ1+@1H_eUr#^;ZJ}Oev z3h^#HWw`A8a?__b3UV#4sj|id;^}vFFWn>Q`Le3jfJabiHg%tBOZSnR-x+rt`qFh!fQvH}19PV|bdgAgI${5;SCB8B(N~+J_@H?2+%#9QR|`#W)3N*7=j$ z77V(vjX&MYd#N;Ks^01sK+OJ#iLLTOffJ4n*u!>3>kXo+6>dvNb;Wa#)wlCXeB8w< zGz5&9#OfkCwK_u7dlX;Ou0_lBQDW_(veLKMl?}t~J0n%;A-#P`3XD(EQ)CZRS*Eum z7Pz^5Bce3-aVY`?kq=LeUjjJ4brV{IkS%uNK?FVLMoR2r&c4zb5n41#ZBsNtwSOd` zZ7xT`Q8KY|dPuksv%f4`*M~iEr(aEw3QyH4Xf2>)G^hyitS8M8ND_ms&dxgsgB$RD zJCZBf{Ssvl&$z%^ozmS}`xqFX3WCmvzwrWM!;xy4)N*RC@BUs9b6^b|@*NNh10uh-hUXmyF4Hez}RuB2!qdxUm#y6pJc9imfjV`6TUW&QNd8k}7eawjx%l z0Lu|ybt^*F*9~Q82~5yYAce?%=vGOmwx`Pg9_bKF8S&4lMgM=8cZmFU+%DS#_WDBc zn8nCaIxf*Mc&5X^!i6OKxksU}2q^;HNLRv48N03d4vfcIsgMRL*#~FL{V-JUE+6-~ zkihxA;?O`Z?tA+@R^X=u28P41S|{!tm{gap#)8i_9|r7MX4SogzW&c+4!ejA2NsrR zy5Xg!H*cbL!XDAx`6QV?O3lE>$`$;eCe)T8swc;Shvcr}P4~Gh7*^*%ZZIS5^OHwN z%ZGtqSDf17Oug?CxX}DI2NOyMof7XzTy7I*ibZ|&N8glKZ_c1qcL$y$1iS8gV6GQ5*Q<&Wvn@O21@dc; zK1M8-YWgnbtfhM8?g*)%aB@4plctD#k}?^h*+JbM{5VzGzGihY`I?HzdWP){%Wom8 z!vXUv6Do#pC->75X;$Gn%hE(ds#WLZVY7Nalv-V)Im#pK@L z_7Y`a7i?e)1;s(B7j>!&a?0Nm$7ozlot1LkKw>i`&rszwvEJuX8dEfZL-zI*i?(Rb z{W<%N(k{1sF`E4OI(d+g6UKZg&#>= zyeIBF7sv5QO(a;`51`@kL%>|z7X$p_)~LPk+*TlDXP)uSV;Z~lLwRojK+Jjlzi-pQ zx@B~0b2Z>@-7RMYYX3n$_LVf!*FVrd%GZ)|WV@lx6gU?s?zmtU@>tgLV`SgY%eGr9 zdIJC0Z5TJTiT-)nb{CHBX`))zkq7a@jJZh7;7p30N0E7iY~h0we|6UXtjN!8;r=i3 z&31XEe_VwxH{nlz`>&nnH^y019op{T=T$8|nktC*2ZSB}8fjhMom^hX?XOq$(3kgPAXD3|<;bo;_E+PtySf1|ekJu|n`954*sdVLoJ71HF%q|G(`)EVMtkD)IA0*F? zX+JLwHJvsJ8!*pXr<_BmDs*OxrsF|ER{fBRy>jtQ1VOJ$4?_cqT@vI{^em z(u6$~Q$`nK#E?o5vJ$iFmJ+iw~eKQ1}*QnMOS_DY+_%z?3u~)L=RgUAF6SN`z0}H&kZsl3r1P%B|6zAJ zx6C*UH0vhm(J@&$`Ld4+)#33I7lS_?kOCDZ6ARz5cej8X6hBt!KEBno+Yu%JRxNM; zQFadH^M0$hVRMBbnN}`=?}JIYj~>uu#C-%Aa^C3NnF*Ud(u_}EjN@4>=qHDC31+qF zAg5DOHHd`gIN_eZ16Kj0$DZ;!vs5M$jBo?p*m+CXnrfD_9TD z5u?p;CY?`1-znQ=S6Kh9B6`|!Jpb+ie8bk}I*T1%QC>_esJ`NxANU4iM#@rzKq;nE zhZu>LPR#w9qMB@$va6wrB5m9h#w`dG88nYJLaMkD98YesbhlW*1LR{{U(a_{f#EUx z!I7=F%{g;>+{K>M1jek^qG;3Ysl9zUn0k;wg*k4s`Pt&iKSIZu&+hXp+*mu`zI(?V zFB-yfc-*$!stN*INx#3N){DmbyiI1AZC&>)J?gG})tCcYf!d-i;F04mR6S&;I?E#a z6QP5elNMjK*0}Rw4Z%THZ6#sc*@J#Z2W~rF-jxG|CGOxcM_WZ|-Y`M$p)H`e8pO@f zO?Krf1Bx(=utXpFA{AlDZxPAiyz!XBgSH6t^q1yO&bqCXiXtJ+|`e z&t&yyT6(L30RPfl_1UP5hFB0)iUKDlVNrW&8Ml^vWSrP|w+d(Od6nF3t1P`cU226= zvj`_#P=&}TYKX|{GlhC}T}oeRbxx>zMCDV^bH(p_Z|swS`OA?xw@;a)O@~XQTI6Y2 zN-G32M6%IFe+NJBfWv3SB4b`}aozDBPAh&5j7V|}?>zGFtUAp*AJDKq9Q<_8sau0j z60nsH`ZP$j4_GT3`Heb8f~aQdKF- z0?i+ZKfJg5m$)YPba4u!(|R_c!udImR5@&;tY|=q@~NIva0+su(5`}o_ZWUEnouE1 zNdGI}T8a=-?kPiz4Lty>U=IV~4jsO&7_$;HVxdk9eBA%Z=sEZ}fmn9WfI)ics@9Y1 z9l04;vj?nM1_>E^#noq=Rhe`V&t$hL7&TwL<;Lwx|Cl;gSfZ(jRl$*Exi5s7{lwjH zl;_e;Ov}+A2j*XB!Risy`tWGH_2|So*f=Oh4nolsC3If4e`RhJ4N6a1t35#aVvTz_ ztW0Z0H_Ju*(>jBE9BXAx9HkuLyOADZn*g0Apn1K?tq0xzz>NQ2plLhZC?kB&{0qQR zK!l^XvDR1rio-cb6Vn{?F{>G3!3eKyL75 z(c(WcV=lnVme|a}QW91QR~NLxux(!%(@GK%YEyM`C%pTU@cG}gEAy_bvC-9)h0rf@jq3eB>XD^j|7YI$oi#e0Ow z?u4VhKL$rP{n<0QxbGOAYUJ8P!qj@NL;DKaM+B!;{2A=qa3Ekpvl!q^EcKt>&HMqa zz@__^+jxIo!M^(y0m&i3;o!{D=pISWA8rcf~Ka=xD0SsMGH`GR!K-uaLm|^q#n;J3=KUx$hH+y>R z!|KixaEVyl=5nqxprt#6`6ampGs${l9LrNr`kTsb!JfW*k@_gKHL@m~AmG%#;)!nL zDKqHp8G1EawM`c35B{K(0z9XLo1VmeW&?#Q0fSWSkRJ~!?K5v;qCmx+V=bTX_JgN+ zHBfI9JVB=>i!hhiq$d{1-lH=|fbWdpnQ13S}g#Id;zu>&KL)G|DM2h~(bJ)q{o zG4y*N5`Hq!^B)eiylJ7fkbO>lj#FkJt|!R4Xcc8C$ee$Nw1}`AEuH<+=iA6 zd|b5{CAYZ)=jKbnf~fm^`<{;C3BO^|7&!x}B1hX1=@!Z`FLZknpNHgGr=f1Y&_@YV!6$-1 ztFpRlt$QLqcW@Tpc;1FISPvynxh%k-Jbt1C`s<&8Fa}vh;f6=I0x|Vcg!@u;>n&^qAc-3P99NHc@Gxt4iH>XSHLt8jC@u&y zy8G|!-~uSHR9U4{j{P|v=zqe{?fgrLOA^9+m{$=0Z$(aFK1o1i{~Oc-XzK5fh4JhE zDjF5E8tGE}>EgftKd(stvyjtH5i#eoFm9?2*_^T8#I8)oN^HXY%N&+qh%K$>%Rfzg3f}HTI#`E&WhL8Kda>aZ5>Q%+Sbs6y1V()xBE7! z3)!a(1R>ubP`HLyf}8T&4HI|hY$E{j7i9!cpXsQv;(KN^gw@_pk_ro9Gb=~y3pFmS z|E;3)NufW!MwLPU!F&bz1#kou5ejw$bv17qbti-=sU7qXD2FJ8A3^-1yM3T)g@z4Z zOG98cLo9Q1wxnUi6TZRf(A8aEBuNFVX8>lzy#+;|HFbbl%|RC!q$*3jzn(klvEHTT zvBhNpebTIl6fpMRL4=UqZR?s!2moRmmPn-;jio?kX&}U40gs4J4bpg-)L3|reI~0hWu+Tcu*KGMD2VY*kq*FWr9n|j}vj&P~8mw+r20AsoZj8H&=Iz+-BG*n}XGA=%yscx9yNC@NLhiZFD2SoE+SbcN_Hvd5 z047=ZcOOxgBV;Q@|M^AzJ148pPsRLnfwo13mcajtHAFeciI)9W zm!-XjQD{NK;h5exNt7=$$z^X?GR4=xh_gD4amHT_CZ0^Z=1EOWeJgQU?2eU{g|lV* zw4SgNl~3X!MfqWOOgcb<=e-U+l3#wFz` zc4|+*PF+%Q#Zi);UU>MY1nJ)^+!yo_vz~S@9REerZE{!ho(Z5z;8H`PfA_VKaV+L~ zK-$phnzYI+qO@n&CJP6-!*W{lua;?FVj^yevL;?38Rzd#0|wx$q_k*3DV?DUL=$Rc z-z_Dz{iw_G1;)bN<~`R&HWW0aB2B*xEpiR?bw*mZHZIM)ofV2N^U=TL)Sgv*JuO0- zoBOtuuHs6@5xg16ZyMvyD_wFZ1zl=vdAe&@*;3oWSS<~@X>+w-3b&M6V{%zc4I`{z zCpV&OVl6BB<*|%am#QeiyjzP=@UTV(z1GVu$WEbvw=R@i6V=47>R;7QSU-P`+ds4G4kPlYrhBn)s+Xpbwu_#*TG|j_L8HDfXM_;gZ74f0qNY^_ z3(H|rM+f`SX&art&1y3E(u%KTQEmy|`BD}M_^vhwO$Igyg0d4HvFRsmOo=#bbUh`x z^9cOQO=ntD@SugP8c&(&mq?xIjJOorG7Aag0h3#-xq+tO0VWV5rB}q3Qf|E&ceXatA()MZ~WS93jv4tVzI*qxsWqBJo=uAwt!X z!1oie`?J>eCB6)KyY{W+8Ef9F(2(CH%3-JDIP@0d46uuIxt7!ECZRy>1IaJlShv=B z=_a0qJzV;01F(y9o3v5T7}0NohFbFb$s634*zV5|l{5R>hog=HP|08aXn|34qjEp~ zScVw%{D(i$_@DX;-R}+KcEB7A%Hz_*|*#H0l literal 0 HcmV?d00001 diff --git a/docs/cugraph/source/images/sssp.png b/docs/cugraph/source/images/sssp.png new file mode 100644 index 0000000000000000000000000000000000000000..2c9dfc3685293d638b46a2058dd3b4a1adaa926e GIT binary patch literal 14542 zcmd^m30PCfwl3Wc-HL#=6En89(oP_=K}b{-v_%jR0R^JW$Rva?1QOZ-P=Q25WC$c8 zNRT0j$Pk90q6C@36e43pNCF8EAcR0d!rOG)=bS$0-23i%-@EUg?|XbY|Lyw^1@O-|YQsjb!_6~}Fc}&7hSlG1D1rIzGBQWq&Yn7PDcof$s}!wj@uXaE zM*KZu4e6Yj?nIO0z1TyyVhl`mOc5C@OMK zvk!F^Y7pLb3&%X%?L5Lf%8qfxzs5HwBvfq)(mlWrxf~|2sEKU<@Zm#EG5_w8-U=-C zkIUX`i_@T1{1RU8kBjz5^GqE|f)J+GFURcq{x$h8fe{yzsD>>vO;Ux!RroQc+`Y zrY>qJSorzvoff(+pVfvP%CW$NYP-~i+kfl3JL`F6NSxrPpqkvmp6V;}s?oH@O%9Wh zV&a(e3qRDJ#plophZ^D(Fx3k;YLzS7DLZBc19kCYL|u1gWw82nB}MR2B}MYUNDo)J zA?(;;2!Mq@c z?(YY;C+mdm`SffA+r3Q8jDJ2T2+}rX>#FH$=>F*YrF@9K>^%CI_m<@3>w~(={$W~u zwZ%#w)!y3d;S_ASh<;Sm#_QpV45$uzeJKX7G_uPM#gfS}5DB9`mh zLwu11>%*{Db)sB4uDDBd5YS1oq7PIs=ovCQ8KH(sZgu`t`&msAi#Q}^?Tu;}rqc(p zC5$6I^qtVCpa_@DZW02boXpWI;Ep~RW|WSBCM4r&^C*c(z~qw2=4>L}7Tk6Y7CHo3 z9#PPv3fXi$z^vF1=l>9e_9*CCpxm9WQEOJ7q@Ie9l5a>EM^vQ`s- z3Q}PlN@`}DYS}Z)RsVVQPSxiG&azZkk8gGB=@p3`pk?g@drYTssEG!?lIW(cG6K%E zs7Q`gqX~jy*gayCZJ<;W`v;Ts2s;)miJv4@7H1m1bbjrUoU8(_KU=q}d>jWxBH-EM z@!E`@fYDb`{F))Maa~}xKXY25hPV@km;dh%~p+R{Vw^>!hRlJ_B4=V&@ zF6@jdI^C<5_z~8QGItAM^kMkSh+kNZg`rE}HsRfYC;`ph9+@~7jsoq()KDEgn%U9> zv`1vlY_9tBK{gTW{*wj-wd^9!#UQolnX&JXOL50K4Yn1edh|B%Hy3oXCv_qc8kvra z{_xsqGm$4Gj~>Ur4sRfQNDX{|-bvUAjn>OR{WvaumI_MUiDVdYYU*tig1lXFX73LJ zi^-Z#m3;{3RI4hB2WN3Biuweke@Ly6W`M7Gcptp4JiX>fVq^5=lHGvKisxgo12GbJ zw=v7t)9N6#ELFHb=^99!OK5yH_DVlj>As zSbAL-{Arbq*UU%LeFwy9v~&7G;;n)CS!Il-Sh-v&(Zc(XBT>>Bk){&Y{{o4(Wx)rC z%zH%M86L8#1RRI+%7x=w1~b`>xe_pigsO4z)Ym$y5HPg?3e9bdr^QnR3Wr0c#BBix zM=DLM3uB5u8I_QXN^e5dyMb{A_WD7L8f(WB6(@dTH48^d(X74ZpyFh0Fj9KI`63nR zHYvIXN<@(6rfYVggEd-8-4r|7B~k-4*vhxxU~&rcoNHpV;sP#UM|fVLCVkOl4^|>Y6q$U@}#E^xex@pn3`mT#I@50)C!lf!PT*17M^b@N;_Xte@fU{RUCs(?@VJg zCWEpLvT>baLzuBJN>Uf>od8xq4jn5IXLesW>ctu#4GyWZ@MnKA(qK+gjA6*<>(jM! zd?Bg6`sQq268CItmRrtf9P9gmvCEsF&pzax=ByjuxjeY;gt&kpa(%rhe#axRhMfqz z*fj0n`;^Z*5$Z@w$@ybO3ub}CoiU@^90poVIW2;y0MYW+6eeicXn}~SYF~n5TMe3# zhhT2w%#5oL9#YDIj=l_S%m|f>6uV6rYaClavfsGz9VF^@3GEY3c*=axAtcUI3tGY> zC2-(F7Nflt%^|G`jrvA79L%Z=G`K=GBqqBl3v%FC*t9s$mqii?;}n@Q;TOxcz$BPY z1T3Ypq@G0XX{?^YdcD%7|JY`s+SS5#V=P~*sr6da($WfjP`-(gEqfdi@ERZQEW6P0 zY>LDR#^1(WE$O$9=wiG^O5X+cM>SwL_lf7&Y4Pvj6C?{*g%;Retaau<3yA9TGs3x( z`nGf)Wlq9;V86he5cf7K&x2o~Z`KiMgC@Y(Y-%#28F;r`TYPWI7cFj7&X?3QQe~4ny}(H;huH7GqJY{&_MSX9215BzB^nfkjc~8HtMJYC&h# z28nE6WbLc={37VJS_b<|MSGwL#f8qFSURt<@?)l^*j`-g zu-~gbJR+(umGk_vE3t^JXWUE8QA{DIZ}b;0p1}h;`KjTJLekF|MW2rH?pj)&aDcmk z4MV|(f(o4%bd*avPzx_4WTHtly^Nf3)m(!;_-R!miA<7-WeW!FNm_M`&QBH|y(dX+ z-j*;v_e!`)y0v$G<@OxW0b45}gByNGQr%7Yf=IvW_ApkG(MctAmJ(wwES-T;gL1j5 zNHC%;6s57Sq%3@H9Ke31S_5w-aSZA|o=9K~Qg12w6jgH?rVOTucDfve_BbJKrY_QS zVk`KtZrushB~MP^aqR`v+b)!r`;-PErx|OJwH6!0s`u#QH5DAFUiiiPQez;r@1$`(k5r%}@I(T^Fs|$VOGe zsyRj)Z~Q;UbI906kd6?gIuw2#83L}J_?3U0jcniaF5;%P96tIB*XW0p`3CDFGudxu z&UZ$ZsvKK+)2QHCiE00?8n5A$dVKI)gQlNIG&+UV@|JAtfh~OMTGwsF=rqx5?>i-S z1$FH;^W_np9N=Vgx7Mm+{DOts$o}rQ6I>W$XJf;{Ox+`)>fqIe?^aPRqJK;3u}l+nE!v zT3H8n=W?$2OeC$1chgplizwhns4cWHt5uBy!P22HhI5-Q!-+lbxG`2ihS*Z92zJh0 zPJUIfy>1Y)o$HzgapZoxM6Z_$=Bio-UB=(_w?xW2@3M%9mI&xA^v=XHI|J2VNP4Ft z`~*EHdvMyG(qnG^mfsn~zOd*zWwJ3vNca|{%?ybaHyNgDQ$9Nt6k<~hI#plcDCuV! z^iLG%DPw~V;tf=p7t=XE7ZtzMqT(U=$5;~#0*(?gXSrz{PFl-0b*!*v`dLXzW_(NV zFfNX}ulad3Vk)0lZll$9)h$RxNHEUPQSv{hU2iex(9XzEq^bB6y&}!E7CX5pa4aZ@ zb~CoE2T9(Rj?6e%!YVBYh{}?4NFiULgDZjkr)K1dj()!2k zrtIw9RbuGD6*g61O^#Wx^EeJVU{eX0~Xf zzt_sJ)g|hcjm)%TsR6tQ)~CDd{LvQ#I38PuuAvgIKbm5*_L!j2jTyds#oPNsQT|kB zWd(UDPyPnGw6K_(JaDugy+4`}HZTdn!2QKewmGLf)Jxt;=0z`(U_!7d1&MUrN*$2A z@JHJZ6!1XCc+S9^);33RuTarshwnAy>4hYIGuR|>M6|L)z>kUex-kFrsQP2t3rK;7qDS+WlG`J693iRt}SxOY2E%RgfCUHB)9XWFhk4!g_u zl#dZz8X(2<$!_l=EF$zPYL5uXdenPFSi7Bn@vV;{kXDl0*&Y{=GjSm&Xg{VuGT@{~ zMx6Myq4QQYbykVihE2a9j~PIQv|YlD+K4-=&t0sysZZCVbfkA*bg!ArsZ?T~nTlNL zweomjf%j#lnT^gW9B%j3SNF{&KR!pI#Vy#8c~K8srb(0C%WQl!ZafJtY;z=u#h)E| zdigCu%ti6hS5TT#mW7THt5d-jCUVz+(ILWQ#F{rtGjBKHwH&JReHcvgan3WQ%2`P7 zzy|2p}EnkM$?EZz~!#xxoAYE>M$#8Ad=ertl<4#?R3y-bQu*MAfH~s-qF|YrkGt(K%iSy zw2%ab2g(BWvFYN<&;zinEyPL%!34<*znhYdB)PDtQ@N>;T74xpB+)VV#f7#$Hnfwv z1lBia+G#8`+%}_#Ks}45o>K>d2R#E?)2R*`fi=-oFCyj?j6|GW_p7*t6KlOw6LO3H zp#rKJ(;U^wq0HUR*L5z6>eM@6k2R}LrE5QhoUP_8_4G;#aDrJyGzed3Qr=Sm%UkcPKnsH+6ZYtX%jsO5dFdn7dJePyBB||Z z8aIrQX1hj>Dj)_6(kE4XpH=iK$+X`AHgCs=(>uY=Mf343qP_mvvgqiOv>4sCN7_%S z!~5R*K#hLzFD%ynwS7c3O)Iv59n5(7{%79!mvLDFY^=>U0BnqdZ3a9wJS2<0s4^)p(e{4odsDMEzsqWFrIa0+pqaVq-P~17% zM_3m<%Ld}JC_;?8S)YGI#JuBQ447a>U9mV|izP-=ll$|hbfRA;QS>$1j!{#_B}GVZ zB&pw_1(`wh(KLtQx~-WJxJ0y2bx-iJvu#~ok}XtHn?d3~9=&|EL`ZCNQi-7^qiZbv zQ`X1;&}0(>;F7F@erfoQmxF!NW221a4W^@p(O1}w!NMx^NQl2ohzi(L zFdC6>nUSv{EDXJogDOW|>V^HJUU>nn9|btwTzW|f>nB!!6`8w%E&wSfzICI9Qgk`C zNXK(II^nG^RAgk{_?K3XbSDxSGUx023W$NNIicc!!L!5H1PA!fz}xyOd?t|Ik?!8x z05N+g+bUQ39Sp`b^Jed5c9To@f$ED0{=($>+4+dVVxchcx|IfocR1O`5er)D8!9S0X?oF7x(OTDvINAe6Z@YWL9kRrb z91Dt{1p>kC!L zD8BW=ke`K7yG`G{+PCO{isn6+rlXYca+Y51`x0PdCc?IJaF|H}OsyeRV7Y)DwegPB z56lkc@l?zL2rfu+U!`KA(SQn5_^B)Zfcu$qdDOQSx*yyxcS-w z?EF&U&>-Sl85xIFT@!#_3c`0(1SE(9DysL3Qm11A&;{5CR+EZffGk#~GqAmX+Q?TY z?9Q|is@x7LZY%(clP$4Hz|aw<$q3es~XeLC~$tn3Ny#gY#Pn zCK25sR*dE3+PAEVGi#DZFQ%WrBQDFfN!SEc-j(F}lv~P>6r#t8$VyvU+DrHwKL^DZ8G$H97#NnwjK$J&n$b#pJIa9 z$|c;N4#1o(7ILPBtw^t=-Z-E`zQG4$6+V-H?0#Y{FC+6A3;=NN@eO#K3m`MUb(=JR zN8U6>P-mQ-oF=#Z&OiI_f_h;HpEvw|FPyi5QDgu6gXV8$z5y!useY|TGH9UJzwCv+ z$3izx&Is`nTpdw;OV`3+efcks_lUO~u!Yu3MO>!u3eJEzzw=OG)_`$p1GL>Yq&B*b zh%R(lV0;8yu1Jl9LtJKBu!%DjOx&VgIN%QL4~;;beVr!m`qHU$x?L7%Gynrku-qAC zg%lrX-tmE@6&sI@euplFZfzi}Ad;r+@=8B-!1Z!a-+jjMeG2{Dko2&G*ZFW!gX~PO}n_O6z)9dT4{2zzqngZ<& zY5^kC(XS{{udVTQY5?QOq_DUK!w&DgHAc7>Q0&1h+C$=emJ(7Jw)TtbEa=Y`Jm zscq*ns*93GxAiZedi6MrdLA89qpg#12bhQo199WNqK|MnyogUJ)Ab=TC|^j+5?(Yu zbB$nx?}apoRgLGcC>_86ZblXbEV?Upxo=R}HneO97QW8MuGcxBl8y3r8f5U9y zUP?ACQEZa~+UWIyH+Ze-lxxcY8W0e7cZ@C&TqGeA`M;~?7@hqNPZkY zm$|C`gkXVg5>k-{bIEyCM37nwEL2SGD{SU5;aN4lyIxNPt=A-;l0=-8xd;=aP?`%M04#)96G=(l{n7WP%Y|w3NQm zC7r%F8}&-SZAbAPGF}l=`P9Vo^QenP{x!9}gwlCJ`g94#Ooszl0hwO*tX*|msm-ZoE9-`cV6iscE6tK8uuQ8)KEXxet<9FPC@@W$6acerjuZ1eT* zgt;zW+ftOgfBfsy+q`wwnLaRhJr2@|dij0t$8GaJcRc&OPTFnt56V~OGZn!OYq{<~ z^*^_Ou-c??=r&w;o$2c9f2))C)S@(YhpA?)!j8Ab>9^I_UHMuE=!b(@JdU3eA0MCf z?Z}AYKj~U_=0f;J+|1kE|K#5PY`3=p>_Ug}5++)#?HW_XyD7K-ABLL^*VR4vzHOA; zV37L{(!1OHv8iaY5k@RefB5Igz@>~tE==}KGh0%53oKB@Fmh~apt6Qwjh}X@LkNiO zN>?=Zl${(jnLg%JEBm{cfvGFXK25zp;#(WRXT8EDbY2dh{qn&>(dVJ~T!*nz`t~oZ zpkGkSvy>rCeAFxP1OY#sGO1N*w=hf5;DN=HWkcHa%X8CbvUSx4?AetJ)BqCTERqKc zguE3guO^VZ#Bx5H1lkToEipah%U2awiKwFxohYi6BdYL z62A$b@Ge=?IuxUB+HV>@^SMiLLtX#z5ZLwkQxhW%aj8>1!;Q`BEyc~n_A%lM`C0vi zj7`ReBl{6q;>y9f<%UtO|)?1-HzV3lo3 z-#LG-JDi+7dY~J%+ihg@zU70KOtTUnW!P1L9YG~U|MbSh&Ip=kj|;_u^s1iZgu^Vv zER;`06p|O5sZ*!5TY|hJa|6*INp@O#H!kJ5AzOD3$O=;S1#kE|w^kp=_dD%&>uz&2 z^1Oky{%mPvu6wWfxkla-WpK(1g<@*dNl(9%{QEf~@emIcW=FKI?5?8 z)i2b9A_Sl01;v{OoBnpEYZ*ovg5K>j8^2vlI5SMwnQDo00)D1jynpJGIy9u3n-8E6 z9yz&91p9)VP*-C8Pux}}C-y98`BHoKAOx6qJ7i9^;}cSF(L1+~K~=U#l20x}$o^&f{DbuE<;zkaV*5|c2x5pU^jEY9tbE%hG*s5XAPAkV- ztDCkXnTD^`b3~l}2?IGsKFd#ebq_9^GwTk)m<}*;_ zpD>=cH|d0ul3uo(q2~i>WRz*irblvpnIq@?iu@kkm4h0Eo|H6QW0rzZm~6}tNSNZ~ z!k&Ox?lGcQHX;5jR^0oUKKGv~k1!$K_r_@~E2CFQNfGD$Y&+C%W=7V~<;BC!&^u}rpA$=~As`{VUM^k;>(}GbO zRf-M?Ys)N4+V#5Vqi}3yG-u4@LBQgmp_m6ppXR!|U%rQ>?4QUx-BdKM84Am<0~0OH z+`C@xrGQ70OZioo7|5}O?^kT@ig03}ugb!dfu|E`kgn|@3;#j|8kT$29sGBjN{Kx1 z6#D|?W=J{Tn$RQ-$lwxpcc)n|KY999hhA!WI=0m){ZX`p3r#I?4B!()4~9eAqh3UX z9FoJgQVDnzVyU3IkJq&8e!%+E74U#^x|6>2&`w82qaT#SE_Bj;j7voq;?O<`>g;Ut z5bnVbp1bpP^*jJw&s67yjsO6|CZ7}$sL18r|N+utQ{0m12Xh)py(CN z99htb&F{01hE-9meuJ(6uwm>jiX^Jb6csqu21AN>Kx+dx?lnwlbiZiWA_)bCHsj($hyQCRf z8y-X|whz?Qy`rE2uJsd+cjgRHAuA?2*AO&vLi${_!iQBD^qWT=6w@r7iGw&XZNm>C zkkaKP50y$gD;2PsEvesyvh0(Upf$Gk#OAm!`G9Yr5!EL9fED%{6|Jq^QP{Rwlmxob6 zav8ItSsj2#Bz!&^qoVRMV9{^#H6EU6<~Ckc%*j{4k|rxFQ}?SejRA~0e$s05+Mu6_ z@s@F|{S4_s<-zWKd-lTiy6qN`mCC1@4jp)8W?dCkOD`imzCHvYL?#t2! zEROY}*fOi!^?W||)!mUj+Bh>&uBU4+-sGSR!dALK5}hkP%xqa3WIt-03rgBj>eCXO zL}_&zI7R8&l#cYYh;7R&22#odd&+1qj_8H1+PQA8hJC>%2l%!oadMOSgOmjR0*FJa zZWSE91SY)Q|0e9K&$Z)-*PL=#fP(i zD;FVLjXXB4!j{6RO|-M-EOTqBbQPC3zrLrh<4N?-AO{4ug3EH|aCEpw_3J76PrNeY z>C*h+rjWk(a8|P(W(I|KyJ0}D{PfojphALzN^N&7UpJ4nUVpeP6PMR_Z?CSyl7))l z1FhtxivVNC+ZI7Kh^cRuV;J%_I!NeHKXV{|3}uX zx&*UE3ms8it^VnBi}5td-VjMo5F381?!YTZvViaQ#fHQt0^t^)EEu!)63%*&CAu6j zjaU9-OYr1(H?Rni=jy@hR@H@bV$FlrhjJYmtwUaP!^`3{?tte__lMT&R{Ix^``z7B za&FpidzbtN=hLix(P93FS<_`>8lKc70Dd|IdmJtz}-+-&lbrGW^RZ`XbRb~y zwRS_QhFEJAeylxlR6ZzSB>6r+xjwfNKpBDBEfq~yVyUZX|<4OQy+E#Ae=eVlwg z!cT_OYpp6Tps^>rY=r{WEwFy7To@zXB0U7V9)&RTtv0=f=~8RjZ<}|EXsN; zlR7I(;>FY&W^&6z=21r}TrZ+*T!~7=RdJ!S-$^70WJ-QZPo*-_SF+pM=pw9n0m#OXuRG(mrE^_0d!K6_OJ~pMq+q z+R3)j>9MC^kV>xCGwvN$5j}I|1rZNuZOh>%b20mr%X@yCaX;j{u21tSgX6K49{To- zTKHCeqUF4#*j4UgJgg3t&i)%BJZz}m&nz^Ma$xKg;-p1vJnp~)IWZ8FT0jHBY?x? z&&YlN-RLZdN8wjPMi4I73o7gD1Oe#(H1ZL#6qFR$j4B1k#aDKc=uZ4^Y1$6RXzrL? zA+*RkXqLu(jDax7Pt`y=9&B7L!zk?V z`(j2dOR;W$ag0lOJv`(@FC|stnu{#nTD;SFkUHX4qXDR8j`#cq~Kq8NbX_F_T~Y6 ztWPUO146#hEe^vY;C;a9BO+KzD_lUf!rDX|*|WBwm}YUffl*~=#YuN^W*nM82D$<0 z?l?0~T$z(Te@gv8nbs z(=|L#@p#5JGErxMw4s}Xi99a3qiv7Oo4B<=Hc^9BM|T|?gdzdAHUi>uY>ybQce=Im*!Q^b?kZvQX4 CoHy?P literal 0 HcmV?d00001 diff --git a/docs/cugraph/source/images/wcc.png b/docs/cugraph/source/images/wcc.png new file mode 100644 index 0000000000000000000000000000000000000000..2d27a3f675c32158a6f3f398274003e0835866f8 GIT binary patch literal 14519 zcmeHu30RW(+ILNjwwcp7Q>G?mQ%*UiWVoeLSvi$P6N=?VNohf9YG#VaG)_7*E@@8g zYnBQKWQseG<0S4{3gljippcRx5Fq$H)|_+Zyze>JyM5nxUGMvS&jl9`&-35^_y1mg z_kI6-=Cq5>I)im9R;u$-b}nSc)yh3 zLUidfh®V%9J&yJ=SZQpaf{IH%?yI$BU_msU~pV8|u`ZfK;;oZd3 zx?kRSaHe?vYhoJwMrb3`Y>Dn!@sia=8FcKNQL3HJy|h!krOSqx`HeI32|gEpz5jgz z{(VBE%UwOLSP8PCj7y;bK1>n5t4`C&r4oYcV!2CMp=!z!ctYvBgld%{8L@5=Ez(8} zD@hcozPt5fwWWC}hG_Vw+8?WBQjze{>wl{Kp&Bzr4c-3pMg8~>9|3Frgp`ZK<0N;N z-rxQBEdFBkQ?tm7(wrWXV)S-q zT7(oh;i4t{?%YW8`QlreMMZaDA?~I`HZvS)ZX(x)iV9=q+nd!F?lq2x-@@L)D_*0% zmwzvRGa_|xcRVYnjl`LR+2^3$q;W^!=GZW|(j0@Axi3pj@7TS70;4h+TjT_7JB0X| zv1(I&o1#x+x!q*Bbjmf{XTCp_DqF%jdOMw84RJ$z;2Z*O?~rffbR*0WQSTo+m_BPL zw@XX-xzr0zH_Q><)pe6@5WKa6Lk+QuKH_Gau(IPBwgp?w0`1G{?KMhEV7+ihJ;1)b zNR%{lY#y=WpC~`$WiiK2SyIwNr^8SB&G*>^NmuO)=V-3aYUPG{ODlh_$`;r^=e8R- zsTcZ(xdX8$pi`1y+)q!&7IxSo*{%aor=!FT*1Xd3deE;VFEe^6J14ua9M?~!HVS&r zDLL3D|_Nk;F*K6$y4UQ1bc{JUJ%|k!qpmxS`7c1mtZdI68aZ19yh2 zw~@>Y-{Ppmo*Q3c6gVH=7}i{z_Y^F@dZQG^PMa_fI?9V~Hj7+xkUg|FgpDS>JX+w0 zy(y_GRKGo#d&XK`eHmW5?a6D7kB_Jx3z|=}iHi0sH}nZx(4li9*O7d9?F!s@zlKPhWoXbHmS- z?U#1;@^1-ud!FxR>9>ZF@pXql{k_GU$h{(>3Etg9q^hrF3TXt_==N6>Z}sss6@MU zREj%a`C<*ipap>wP)bCfnB{v)Z;)YZ#yxpGrap&yPZZ(9FN^(lHZOr3OESQYKPFP; z$$VUpSv12#?7Zg2T+_nC0c9S0-?(I?PSDz1lYH!7Eku?>e-8 z$5>BX>rq@6qJBeQAB6;OZPBWwwS!B8uWwC{j>eTerYYi`xvE8P(X63_ptePJdk?pV zPC2ivYmMnLtU*uBlnq?p%R)QeP=67Gn5aJuRVKhWH?EpLpLA$Oa3}GP-o&iPdc4^CMh+ij;q#i_tSyZ6oodgJ- z$3#cbtO1z8sfH}AL;KN@-Xd|u!V2`s6c5MjG76phiJHE#j5qSa%&2jn3Ab~3Vuzm0 zRn0x#QMqZ7YHZgd|7wRJ7crm{#I6mliD;Sd5|%Hx(L$oTDu;S(oYsthmWQ(;XHMl;@kQ6J{yIWZqL)r_o1@6v5D>EPwY_>u zce?6JT(GUjxL@xc$W5Iie21FC6+ym*3$P$>h)gB_YqwAh^w^cNDr7^7$0#XhD)Ox()I+rn zcXv9YDcck?Em>%2)O!5lI$IZkyLhVeBvtX5w`eUZEGSt#bQ6cTL@B^T6h7T!huxNM zu0F<7)~|#6+O#)(_eAlIrfXfqxJb3$og!i)#JO;%H6mL-jAr1h&2he0P#90Fm_5R$ zU`kWE=nb*uWXu2;k7iyuBUsTRN?0%8?X);!MKi_oNvK!9%J-q*&S53;N?r1oEc6Wb zB8l7!OM$yV)I^S%`98uO(E=%F{Fpd=Hkq<(A#|BzUP|%ljADZJo!7RU{UU!lvI86Y zi|oD!erO}Tkd#x4CP2zy?8}gSJcH)YpE79q+(oyY?@lIrP|cOmDP(5LEzTh10a-9D z-oc8PAvC7MkpV{89?G@6HkxTNZ&|kvu5--<*WVeBuu-41?l>yW)V&>@U-mk}s>=wc z_YjYM9pT}u^nJ22K;wnVH_g39(ev1Oj=Zv*#6L&_-Izg0j*Qiko$>@;>^(!f5XzfD z`pyip?nBhbBZR+)o}F^>{3$j>$%X{J^+e*Nu^Txo`-1|`=iSPolxbWhwuIlY=u${vN+=x9n^|0B7`m9dj>@bnWW$zoVwwaZ1-!p$apOeTH~~--jCAi z4m~=*vTCT}n0-v@O}(9Gn_b!XU(N3(=u6F80N#4HRQbuvMWNo8Y0g)OllowU=CFG1 zd0}5-TZ=Osx7*>w8|w5t_ja*{_NsVlf;W5NtmhD?ZnqwZkO6CSj`C>QzVGv<%@Z)E z4s(B_ozj#fVfNpXIx<5}Kc8mMvg+h*%>91eGq)G#`zzJ{x+ z%OnH@Nx9#b#`0eW{#v7~GEgI+%(5%kb*wSxO=*~;aNjJ;Ps$E$!{RLQYu;Z^zSya4 zyX3oA71(P%jv1=oQfK}ghdPzSLgp8Sc#*g0GOB=!R zI7Fqweb~$87Um7G zvWO@${!&R0x2=dJ+ zw7MAGaNl&Mk9(-cYsq(P=eH`MTC26pF=D{_jVDvdjtmQT#I4LvuS%JbU!m)6^=5_U z$rHiSz(@BmNW5Aai6*92aaoH`X1Pl>!L!bZX zwZC-i6ph|f^?(Ywp%c~JYlb_I(OmF>JiIy7etx(G(cN@>mETqvvK;l($d-r=X&M@R z%-O^Q8|OKziF$hil53S4sD3&$w0b%{wuz%#u+A=woU&GGpO;hnB76me%k&pOK(-4| z*=xi!WYhCsd{PbOS7t_KTY#=aI4)2@{Kugbi397MRce@hnQMcpGs!rNZ!Rzc^Oj z_^76S45Q^2?g|x7g|Ebw*VkHk@_Eq*^fqN*Qan_SimYOiTCKinKYG(r;0DKQgh)Q^nC;b3qfcCasophCnNu9b$q4H~O_F-;6oLmGa^$ zWz?weLlu2*x%Uha6GCk3`g%!6IZ8V}@^g%oQMq*s7kC>Zf zCAL5{$`&5!qA=YFZ9capyw`PnSdX7eekK@@tcY%s%33*4|DwRIMn>;)hHgqbLE*Uf zBcS?&+p82@RBLKn@)yG0ps7Ng%n5RL4d3z3wXr5yC5;mES>8!t1Wg>H z1*!s#qall+cJl)^*T{D?>>oFoNd^VKFe>WD5*AK$PmiCwB|L~$#t@2*P8BMdvR1_Py?3~oOJpon;EMv>1mlGmYUzI z0N9fCP@%xF2zs5&r-;}0K|#OKEG9g~mtWnwyuz8hw#~EmqQ`o{ z(Ro>f=aimvf5b!cfaD1;vxvS-un!4xh`n;E&|x}P@ubwL6wcbedm?(YTJ8{WEnn~w zyPBTQh+^tq?W7H?SqL=RWd0V9o( zTY^2qj~X$c>-Ca&`BMcP?ikY@Tk~%9AgTEBh%2OitDQ}Wj;a)&-Yf}3CxNVTx z+XKxQNT=^`2Qi>L!pa(ZhKj;$Xa)!Qz`Qv-P4+hfM$QvZ%sGy3OTy8mSDWp=oK@%} z6nbwzWmiDIP=d5JZoRoFO*gqzx$`Snq_TfLq174k$^7fS8@NITI_@XfwW`d?cB|K* ztfM@yf4`RTHxvn{>+ma5&8L>@0VvX70?Cps#Saph-dk6l!$YsC3#;FczY>Zs^oXq- zWd~hhYR4LqUCp^{k5=i@y-cOtnwAht6%l(*Bjvs~>O#sPfc4%pGh!3rme4wYQH*nc zE5<(NLyzaj$#2U`YS^!owd1iWx?HqR@-qNuMZ9Zvg9MZs82L#+U%ERxA5>2<_eXOB zh_`F>&4e9u1*uF8>_?Arvb|a4d3k{)G=hhFFuPd1?urXcGt94;?#qfTe9nkS=CU1c zCE2GGG3;xYHhsfNgJlTp;=99c5lO6!O|$V@Ex*~tml!e&5@XHuI?oM7b%R^Kfvw#N zJGk2{sJ(?&?V-mJZ6HjOsEh_2S@aO!tBae}hrW9MEOARM?5v7*r~v( zwSQRP=Zl>F)F*JPB2G5I%1b&CJVeUy$UMYwF3p*!FvZv(*sXR}#!9+v#<3xLAxpl~ z;WD{av?0~Q1z6`O{g^XwU!Cz^&{zE5TS7^pS*gzauPrlTn|25CUqQ5Ld!Ki2l>Gjf zZGUhfF>OP{Htx!McC9`>+EElR`IX$F*x) zH+P+YoHo-x1@#kwQ!N6>z29|vnFPE`o$IJuxWm}@Nk3P#QrT*RKz&OI?%Qtvm6#)s zyV|K2Vj`7BaHFDe)yJsO8wRo)L`?(>(K)i6WdX)vy6kxCKy(p5o}E=K(AV#^m+xnH zaaq!#;@aX>_^EQ*F8~cl|4`$ygu##Pz2o^`gg-e32_5YNXBhDwkaLAuzxN>W+i%~D z$oE!fCu#`VQ<+gD@Q9aq$F}_zIJ9F^7{e>Fn?*CH3&HMmVYF(*F|d5%_Ts=Z`ne^` z)Ss+LU|y*s&rS6wtT?Kl@?<+1^ zZ_Hk=bf5TP4>79dRkW**W4yjOe>Nb2%jTM2Ae|mMyK}<-T@Cv3Y2k@|6mD4tuVbzw zsTpzbA%M^{b)CX85xu(8&q6sYJM|M;*ZTU5h=sh{8p7I|1EhN0F>wwG; zES8;-+-&2H=Z>{Tv0VU~HP-4hWsCjBtI7ykuBQrYgK@sadXluz(d%G}LB%1^P!ySU z#?9m*h1WcF$>~AP6Td)6cCi74clwfp`uwR7?bg<7-@$CIJZ5jC^)zf&hLCfk?}@Yg z23WRZEz4%P3}kvFKa;_0vybjcjU}X=+uH3hO9m6}79oZ471>kr3qxze&X|dVxLFd- z7xbPM3_b(_33OausOx}*Xg;4hh0pp7AYy>qjBF}+DlZ#YepmM&XFj~Wt`SQxDt$l91-946x_uKIvf!HGxg=OHpTydMDz7pa8!CwP*=ak=y^W%vMnpzHFBg1V zA*_MkC7)YU_~3`zc6K>zP4_y76^{H6N`tXuFvdl60fjl9ZDcmHjLjRHfCk(P{Y&dM zPiW~^=1B}tn}rdrlmaa)3;q|*2I+$sLe%%NtfTXxU6`5I?W&IZRcRIVg5f$z z?TcISR|=E`LEnW_IJ=fNF%^)g4TeXO&bcC7w`!4dsA!HYW%nlkw;P`)n?oKuCHcBS zQWGS>PNd?urFNIDF?-Ae;8Z~8$ltk8d1$zxG^f5Tb523*sf@byDTNC)Iy&yRUTt+8 z7&WhW4q(O`ydh7{xNYUubRld49-~D!N>1T>u;xbnGB$t+l*OaJ#_lU}44@%RVQS}3 z4)0TX2#szVl>u}5wPbB2EyaBQzN6_O3`IbAfuG(PtBD68y^UVEANKflTn_r_B`LoA zr$PH??&0Lvz#MK&rsA{Fsn&OFQ^Fyqt{vlzvoqil$1P_pxb+CCDRQO}9cSbqy+ZO6N7z@il%8>GKWr*`h=75FFge|rAyLDv(fx9Wr+QUi zan$$GVaeqZsXXY>vhSwN`So5`TFfot#VyMooyN)!<}Kjn$W@v;*i zT|wKa9r@q zLA&12b))6|Aa8d-JKfFceJ~~|R{*!2#?k3js=z|`&cR!ST_t?r={G`0pg;+~>TR)n zZ&|5VPpj$@(40$m;CEu6IpFrPF5S}(!Va?@GyT{~Vl+^Gxhi&w-a#N_RLwbMyL1EH zpC7<}tUBl{_Bag;Ao(8*fEtD-PF`z*@jncAv$Wu2RaqbQJ{)cyIDy;$aXjVnk7Iw5 zFhq&oiAI_`fw39eI1<5(%%`{sz@t*bR*2UyZy4_hkQLu0CkJ=2{Q1 zp8sn6v+4J5(_7!4dlwk~FZOFPrn*$AkO*w3e)Qj%R*U|s&j7C1f4wFwDUU37wO_G83E8Lre3Dtuo#o!|NrlLUXS=*7oM zT0ok2*0FPm$R><>U)h2}O2y-bEf)DU%z8x?R8=)B>mcTbLM&L)nE617F8+N>Hw>=? zTyxg9*uHL_;PX@p&@`#EJ)odueslFMZ7Ms9i&FNNW+xnbbRzteE!DfbpxpLYY=XF|VVG`$tZT*gQ7r!_$ zY+dk6pj%mWRZ7L3n66D{hb!wXz?0nHjB_immHh2!Q^PbX^hPIsZAVQVNiur3va`Q0 zFFZMXgR#Pn)#V|44|j`zBpM0!TzrqJr^k4&X;ia_knViIQRR3*_A=?R;h!0jWDGEDjO=FTrFYS?TgkWi7#*{gOhr+S zzX|gGb_8;;Z4Mpth*N!^l>34#B@bZ@seqRU9R8!1g+%i<5Q@d;*L!;PQwSP(-n!kE zF!bHZ{Yj%S@#f+6XqZkEnHQ+n|H~%6GNV#0z-t6z{gJ-pkf|AI6EG+Yagy7VY%+$6z_W!?)ZD?lTBBD4~*cy50E(JijtRGbV8VgFf}MoV#A_a89;M zI{f$nDQDsg-nzbrYkL;!YN4<#%dcVQY^0SLQv^TX^4-Ex zkas2U<8#EtPgnup4F4Sq49pd-ZMZnPpXu%6uV35~d6k0s6y;awGydsE#%Qb6N{!w2 z!&7(v?f8Wf&?v)yA2?eYo#2$!p?^_1kOwAs&y4=~!Xo}ljVxqIhSO2_*=TKUIW=%| z3S!DdU^iF=>B|6BysyW02u0NG_GLlc7vJUnowR5|!*#M44!3vB`mYeTtP!l8V_hMt zKwh)!ia5GSKwe`KvEEC;01^UD>Q z5OBHxv~PDPH1F>^TozA%<=oKReMGcOM+M`5kgoL4r&0&gEF#t3EtVTvEGgSBC6{_p0L6Mabjt_UX5dD9+d5nu65<2| zzvKcTQ%sB!Q)aQG zVA31smu9AsYUKDGkkvayj(+v z@yS8At>-m&mt->aNYj%P4|Bc4ZUhCT*{hGp(3Wp~5?}!zU8b&H9-yq%&b6Bsj(V!NC}REbgQ#Bc%N%rzWs+!PmtpJXRUP1fFVedqNv|`e-4-*K7Len) zmz9*Ruej5h6Z)^5kUvPWu*&Dq@UQAP$va!{aF=yIoV*u1ggD zgI+{_wbZEk;t%I7!oj9Nd4hKPwEB%S!8G;;ib6D1N0CdTA(7(D`dKr8*pwZ1&)g+L zfLiRlyEbQ1pEt(L5!fT(K>+=IMT~l8pW`fLoHzTp3-p4%S^~0dJkkskU_dci0KadhjFRLD1zBc7*V)i2#^Nv-C zC;YM(%x8y|C?6<=a?t-1GjMFy!6cA7D$(vMe=Re!xa8)F-I)`qH~IeGP`xFe>-P?1 z)&w7!9#4+UPanE^ic$fcq2s&6#W!0&skfYo@IAjCQrd$G4_w40R6KX@G|zhD@Q5Lz zGWwUPKp*R3L+5_3Ka37Ao#-3UoTNt`V24RR^uD3%Kxwo^$*o|Y62h2DcRN69=AL&; zhdyu^(|^rjd>sFWn6dcpf1Z7vYhL#19~jKMx0*rI|JN<{e=ddo+kp@Ncq)HUgKn#M z+IWo^VA>m6>IA7O)E{DM^4bCRatL%KtW8~+NtgvqYl;fT zYJk||++HB^lrI%=Z7h(CCb}^INJ#S~@DI#bztw4nJfYNUpH}+fa(pZ2kTZlW)$doo z9jqFjXe=UF15!N+_eLKuSwSJO%KpN!#ditNrex%jl3P0iKnkGXSl|^iq)ZoQ24E0b zeeRvlE)W5@yi*=uZ*cY1^uh-zxr_(2s`h<6->5=ct=-Bk%=gb!6Z)tPR;;Gr(36FC z*SgVsHk2An~{#v7MJyX#g5*$09UO19-$xSK;Sjw~C-36j;fMo?& z^V-u7;-#zpvN|K57Pv1h3ZP8K)~~tLo4l(xCx3MU>E)p7v$D{yu0rA)a)b`eJtE%^ ziU+U?p?ebw>{$*1702xKw4`sjad&ukKbpZ$D|;aYnF2nJ<;|FAc=N^l-7j^w`WIXP zS#^~yM7051;7=rI6@1Hb4|iRyob&jM@p;(ZWj?*3<)vv#s1SaIf$5H<)Y?Akb$0n1 zCzJh*dgnM5V`_i98NgOYv-lwOluIqu8|Q@h8z)2+rmCF(RA+5H^xu5>e;fwv0bO{2 zeIL-d=O6TBWA_1O{htIA{^OInjClStw3hQ9PAK>j2n+z&f67T{`w)f?(AZ~y?#uwv z4cfa+3uD40%i5?7aK;-=u-`MLlEyyjqRW{Eo|8`PbhVI8iZsW+0hC&?6t9?{Y(&)S z_8D|PcXFqzAa)>`j_mFgZy-WaR9(Ux(e*CO^ayz%whME*vej24n8abXG=<4~od_8R z)7nTx9rj}k2gN)@X zyPrquiXQ-6_PqKkwV+e;_{0^WyhMZrhB=W%aJqDyH(O*A@L;TE)LkB4zhQR98Fox< zGQhiPLh4w)dP&7H9y$fTxZz_m&3_?9=HH~~V@EID%d|j_HAPE>qk|EL9D56)_(&NJ0LO8X zXIqd>2ldietNrSieH)2CwaK6I5YB_W;&g!OM`8ewx|YiQ9Sm3hMZj@<65A&H|3XP7 zCQr(Ko;9+ff_{MukC+}gYP z!y_WA%5lP&+p&eQ%4wCR$q6faTLV2zC&Y2RLO9RtW+J*tLyZiRdD}i<{xLlFxOJlo z9wo!c{@YGrmCc;GtllJ2e5O;B(7RBBP`=M&jT}k3-ko0djFxc*bI0C=tV?uq^z3c{ zOJwzcV=Rpo2jn;Lc53VC$E^bmxpNo27jkaT>0hO=^R`jscD7X7cmiRg4;8sIrTnFB z;?Re=^y{E;_YA_fmb`Skv~5`DvLkh3V>Wt7s3$OQ^A@cER{$|7GeJEbApR>a2rY?( zTq=(kP!{F98KPo5rtPJQW4sAK=mDz&*FhGUBeUPncUp`u1y7<=ic_AQUZvHsoW#`|&ozV49sr9cpVHu?cdZU0S2AsVn zPqe5b%b>xBe6rc6Kz>>ySN_M@srgTM^Dn{GK!)4+&1Ge4`7_aY+jI6CHz|-#_g~DT r%Mr%^DGTpI^4`)T5#S%fuRJa3aff@YHwBUnSDZL@`Y7edg&+PGwGj%4 literal 0 HcmV?d00001 diff --git a/docs/cugraph/source/index.rst b/docs/cugraph/source/index.rst index b18a79d3396..9ea9e4d65cf 100644 --- a/docs/cugraph/source/index.rst +++ b/docs/cugraph/source/index.rst @@ -46,6 +46,7 @@ the docs and links :caption: Contents: basics/index + nx_cugraph/index installation/index tutorials/index graph_support/index diff --git a/docs/cugraph/source/nx_cugraph/index.rst b/docs/cugraph/source/nx_cugraph/index.rst new file mode 100644 index 00000000000..ef6f51601ab --- /dev/null +++ b/docs/cugraph/source/nx_cugraph/index.rst @@ -0,0 +1,9 @@ +=============================== +nxCugraph as a NetworkX Backend +=============================== + + +.. toctree:: + :maxdepth: 2 + + nx_cugraph.md diff --git a/docs/cugraph/source/nx_cugraph/nx_cugraph.md b/docs/cugraph/source/nx_cugraph/nx_cugraph.md new file mode 100644 index 00000000000..8d497e3a1d7 --- /dev/null +++ b/docs/cugraph/source/nx_cugraph/nx_cugraph.md @@ -0,0 +1,165 @@ +### nx_cugraph + + +Whereas previous versions of cuGraph have included mechanisms to make it +trivial to plug in cuGraph algorithm calls. Beginning with version 24.02, nx-cuGraph +is now a [networkX backend](). +The user now need only [install nx-cugraph]() +to experience GPU speedups. + +Lets look at some examples of algorithm speedups comparing CPU based NetworkX to dispatched versions run on GPU with nx_cugraph. + +Each chart has three measurements. +* NX - running the algorithm natively with networkX on CPU. +* nx-cugraph - running with GPU accelerated networkX achieved by simply calling the cugraph backend. This pays the overhead of building the GPU resident object for each algorithm called. This achieves significant improvement but stil isn't compleltely optimum. +* nx-cugraph (preconvert) - This is a bit more complicated since it involves building (precomputing) the GPU resident graph ahead and reusing it for each algorithm. + + +![Ancestors](../images/ancestors.png) +![BFS Tree](../images/bfs_tree.png) +![Connected Components](../images/conn_component.png) +![Descendents](../images/descendents.png) +![Katz](../images/katz.png) +![Pagerank](../images/pagerank.png) +![Single Source Shortest Path](../images/sssp.png) +![Weakly Connected Components](../images/wcc.png) + + +The following algorithms are supported and automatically dispatched to nx-cuGraph for acceleration. + +#### Algorithms +``` +bipartite + ├─ basic + │ └─ is_bipartite + └─ generators + └─ complete_bipartite_graph +centrality + ├─ betweenness + │ ├─ betweenness_centrality + │ └─ edge_betweenness_centrality + ├─ degree_alg + │ ├─ degree_centrality + │ ├─ in_degree_centrality + │ └─ out_degree_centrality + ├─ eigenvector + │ └─ eigenvector_centrality + └─ katz + └─ katz_centrality +cluster + ├─ average_clustering + ├─ clustering + ├─ transitivity + └─ triangles +community + └─ louvain + └─ louvain_communities +components + ├─ connected + │ ├─ connected_components + │ ├─ is_connected + │ ├─ node_connected_component + │ └─ number_connected_components + └─ weakly_connected + ├─ is_weakly_connected + ├─ number_weakly_connected_components + └─ weakly_connected_components +core + ├─ core_number + └─ k_truss +dag + ├─ ancestors + └─ descendants +isolate + ├─ is_isolate + ├─ isolates + └─ number_of_isolates +link_analysis + ├─ hits_alg + │ └─ hits + └─ pagerank_alg + └─ pagerank +operators + └─ unary + ├─ complement + └─ reverse +reciprocity + ├─ overall_reciprocity + └─ reciprocity +shortest_paths + └─ unweighted + ├─ single_source_shortest_path_length + └─ single_target_shortest_path_length +traversal + └─ breadth_first_search + ├─ bfs_edges + ├─ bfs_layers + ├─ bfs_predecessors + ├─ bfs_successors + ├─ bfs_tree + ├─ descendants_at_distance + └─ generic_bfs_edges +tree + └─ recognition + ├─ is_arborescence + ├─ is_branching + ├─ is_forest + └─ is_tree +``` + +#### Generators +``` +classic + ├─ barbell_graph + ├─ circular_ladder_graph + ├─ complete_graph + ├─ complete_multipartite_graph + ├─ cycle_graph + ├─ empty_graph + ├─ ladder_graph + ├─ lollipop_graph + ├─ null_graph + ├─ path_graph + ├─ star_graph + ├─ tadpole_graph + ├─ trivial_graph + ├─ turan_graph + └─ wheel_graph +community + └─ caveman_graph +small + ├─ bull_graph + ├─ chvatal_graph + ├─ cubical_graph + ├─ desargues_graph + ├─ diamond_graph + ├─ dodecahedral_graph + ├─ frucht_graph + ├─ heawood_graph + ├─ house_graph + ├─ house_x_graph + ├─ icosahedral_graph + ├─ krackhardt_kite_graph + ├─ moebius_kantor_graph + ├─ octahedral_graph + ├─ pappus_graph + ├─ petersen_graph + ├─ sedgewick_maze_graph + ├─ tetrahedral_graph + ├─ truncated_cube_graph + ├─ truncated_tetrahedron_graph + └─ tutte_graph +social + ├─ davis_southern_women_graph + ├─ florentine_families_graph + ├─ karate_club_graph + └─ les_miserables_graph +``` + +#### Other + +``` +convert_matrix + ├─ from_pandas_edgelist + └─ from_scipy_sparse_array +``` diff --git a/docs/cugraph/source/tutorials/community_resources.md b/docs/cugraph/source/tutorials/community_resources.md index 1c4362393d1..975f11965de 100644 --- a/docs/cugraph/source/tutorials/community_resources.md +++ b/docs/cugraph/source/tutorials/community_resources.md @@ -1,2 +1,4 @@ # Commmunity Resources [Rapids Community Repository](https://github.com/rapidsai-community/notebooks-contrib) +[RAPIDS Containers on Docker Hub](https://catalog.ngc.nvidia.com/containers) +[RAPIDS PyTorch Container in Docker](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pyg) diff --git a/docs/cugraph/source/tutorials/cugraph_blogs.rst b/docs/cugraph/source/tutorials/cugraph_blogs.rst index 373e846f6c3..3665f425e3f 100644 --- a/docs/cugraph/source/tutorials/cugraph_blogs.rst +++ b/docs/cugraph/source/tutorials/cugraph_blogs.rst @@ -9,6 +9,17 @@ Here, we've selected just a few that are of particular interest to cuGraph users Blogs & Conferences ==================== +2024 +------ +Coming Soon + +2023 +------ + * `Intro to Graph Neural Networks with cuGraph-DGL `_ + * `GTC 2023 Ask the Experts Q&A `_ + * `Accelerating NetworkX on NVIDIA GPUs for High Performance Graph Analytics `_ + * `Introduction to Graph Neural Networks with NVIDIA cuGraph-DGL `_ + * `Supercharge Graph Analytics at Scale with GPU-CPU Fusion for 100x Performance `_ 2022 ------ * `GTC: State of cuGraph (video & slides) `_ @@ -50,6 +61,8 @@ Media Academic Papers =============== + * Seunghwa Kang, Chuck Hastings, Joe Eaton, Brad Rees `cuGraph C++ primitives: vertex/edge-centric building blocks for parallel graph computing `_ + * Alex Fender, Brad Rees, Joe Eaton (2022) `Massive Graph Analytics `_ Bader, D. (Editor) CRC Press * S Kang, A. Fender, J. Eaton, B. Rees:`Computing PageRank Scores of Web Crawl Data Using DGX A100 Clusters`. In IEEE HPEC, Sep. 2020 @@ -58,6 +71,8 @@ Academic Papers * Richardson, B., Rees, B., Drabas, T., Oldridge, E., Bader, D. A., & Allen, R. (2020, August). Accelerating and Expanding End-to-End Data Science Workflows with DL/ML Interoperability Using RAPIDS. In Proceedings of the 26th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining (pp. 3503-3504). + * A Gondhalekar, P Sathre, W Feng `Hybrid CPU-GPU Implementation of Edge-Connected Jaccard Similarity in Graph Datasets `_ + Other Blogs ======================== From 6171bd93eee9b7c761162e913edfb597756a5ae9 Mon Sep 17 00:00:00 2001 From: Don Acosta <97529984+acostadon@users.noreply.github.com> Date: Tue, 12 Mar 2024 09:20:03 -0400 Subject: [PATCH 7/7] Added a weighted example to the jaccard notebook (#4222) Added Dining Prefs example to show weighted jaccard. Authors: - Don Acosta (https://github.com/acostadon) - Ralph Liu (https://github.com/nv-rliu) Approvers: - Chuck Hastings (https://github.com/ChuckHastings) - Brad Rees (https://github.com/BradReesWork) URL: https://github.com/rapidsai/cugraph/pull/4222 --- .../link_prediction/Jaccard-Similarity.ipynb | 420 +++++------------- notebooks/img/dorm_data_diagram.png | Bin 0 -> 21289 bytes 2 files changed, 116 insertions(+), 304 deletions(-) create mode 100644 notebooks/img/dorm_data_diagram.png diff --git a/notebooks/algorithms/link_prediction/Jaccard-Similarity.ipynb b/notebooks/algorithms/link_prediction/Jaccard-Similarity.ipynb index 86bb4d17c22..9f62fd4f421 100755 --- a/notebooks/algorithms/link_prediction/Jaccard-Similarity.ipynb +++ b/notebooks/algorithms/link_prediction/Jaccard-Similarity.ipynb @@ -8,12 +8,7 @@ "# Jaccard Similarity\n", "----\n", "\n", - "In this notebook we will explore the Jaccard vertex similarity metrics available in cuGraph.\n", - "\n", - "cuGraph supports Jaccard similarity for both unweighted and weighted graphs, but this notebook \n", - "will demonstrate Jaccard similarity only on unweighted graphs. A future update will include an \n", - "example using a graph with edge weights, where the weights are used to influence the Jaccard \n", - "similarity coefficients." + "In this notebook we will explore the Jaccard vertex similarity metrics available in cuGraph." ] }, { @@ -23,48 +18,30 @@ "source": [ "## Introduction\n", "\n", - "The Jaccard similarity between two sets is defined as the ratio of the volume of their intersection \n", - "divided by the volume of their union, where the sets used are the sets of neighboring vertices for each \n", - "vertex.\n", - "\n", - "The neighbors of a vertex, _v_, is defined as the set, _U_, of vertices connected by way of an edge to vertex v, or _N(v) = {U} where v ∈ V and ∀ u ∈ U ∃ edge(v,u)∈ E_.\n", + "The Jaccard similarity between two sets is defined as the ratio of the volume of their intersection divided by the volume of their union. \n", "\n", - "If we then let set __A__ be the set of neighbors for vertex _a_, and set __B__ be the set of neighbors for vertex _b_, then the Jaccard Similarity for the vertex pair _(a, b)_ can be expressed as\n", + "The Jaccard Similarity can then be expressed as\n", "\n", "$\\text{Jaccard similarity} = \\frac{|A \\cap B|}{|A \\cup B|}$\n", "\n", "\n", - "cuGraph's Jaccard function will, by default, compute the Jaccard similarity coefficient for every pair of \n", - "vertices in the two-hop neighborhood for every vertex.\n", - "\n", - "```df = cugraph.jaccard(G, vertex_pair=None)```\n", - "\n", - "Parameters:\n", + "To compute the Jaccard similarity between all pairs of vertices connected by an edge in cuGraph use:
\n", + "__df = cugraph.jaccard(G)__\n", "\n", " G: A cugraph.Graph object\n", "\n", - " vertex_pair: cudf.DataFrame, optional (default=None)\n", - " A GPU dataframe consisting of two columns representing pairs of\n", - " vertices. If provided, the jaccard coefficient is computed for the\n", - " given vertex pairs. If the vertex_pair is not provided then the\n", - " current implementation computes the jaccard coefficient for all\n", - " adjacent vertices in the graph.\n", - "\n", "Returns:\n", "\n", " df: cudf.DataFrame with three columns:\n", " df[\"first\"]: The first vertex id of each pair.\n", " df[\"second\"]: The second vertex id of each pair.\n", " df[\"jaccard_coeff\"]: The jaccard coefficient computed between the vertex pairs.\n", - "\n", - "To limit the computation to specific vertex pairs, including those not in the same two-hop \n", - "neighborhood, pass a `vertex_pair` value (see example below).\n", + "
\n", "\n", "__References__ \n", "- https://research.nvidia.com/publication/2017-11_Parallel-Jaccard-and \n", "\n", "__Additional Reading__ \n", - "- [Intro to Graph Analysis using cuGraph: Similarity Algorithms](https://medium.com/rapids-ai/intro-to-graph-analysis-using-cugraph-similarity-algorithms-64fa923791ac)\n", "- [Wikipedia: Jaccard](https://en.wikipedia.org/wiki/Jaccard_index)\n" ] }, @@ -94,7 +71,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": { "scrolled": true }, @@ -119,7 +96,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -138,7 +115,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -157,7 +134,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -170,189 +147,9 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "

\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
firstsecondjaccard_coeff
54114151.000000
54214181.000000
54314201.000000
54414221.000000
56115181.000000
56215201.000000
56315221.000000
58717211.000000
60518201.000000
60618221.000000
62520221.000000
2997130.800000
2856100.750000
388450.750000
44319210.666667
5029280.666667
58417190.666667
22313190.600000
4532330.526316
3107120.500000
\n", - "
" - ], - "text/plain": [ - " first second jaccard_coeff\n", - "541 14 15 1.000000\n", - "542 14 18 1.000000\n", - "543 14 20 1.000000\n", - "544 14 22 1.000000\n", - "561 15 18 1.000000\n", - "562 15 20 1.000000\n", - "563 15 22 1.000000\n", - "587 17 21 1.000000\n", - "605 18 20 1.000000\n", - "606 18 22 1.000000\n", - "625 20 22 1.000000\n", - "299 7 13 0.800000\n", - "285 6 10 0.750000\n", - "388 4 5 0.750000\n", - "443 19 21 0.666667\n", - "502 9 28 0.666667\n", - "584 17 19 0.666667\n", - "223 13 19 0.600000\n", - "45 32 33 0.526316\n", - "310 7 12 0.500000" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "# Show the top-20 most similar vertices.\n", "jaccard_coeffs.head(20)" @@ -372,63 +169,15 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "If we want to see the similarity of a pair of vertices that are not part of \n", - "the same two-hop neighborhood, we have to specify them in a `cudf.DataFrame` \n", - "to pass to the `jaccard` call." + "We have to specify vertices in a DataFrame to see their similarity if they\n", + "are not part of the same two-hop neighborhood." ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
firstsecondjaccard_coeff
016330.0
\n", - "
" - ], - "text/plain": [ - " first second jaccard_coeff\n", - "0 16 33 0.0" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "cugraph.jaccard(G, cudf.DataFrame([(16, 33)]))" ] @@ -443,19 +192,75 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ - "We can use the `cudf.DataFrame` argument to pass in any number of specific vertex pairs \n", - "to compute the similarity for, regardless of whether or not they're included by default. \n", - "This is useful to limit the computation and result size when only specific vertex \n", - "similarities are needed." + "---\n", + "# Now we look at weighted Jaccard!\n", + "\n", + "A full explanation of the weighted jaccard is found [here](https://en.wikipedia.org/wiki/Jaccard_index#Weighted_Jaccard_similarity_and_distance).\n", + "\n", + "The Dining Preferences data set is a staple of smallest scale social network analysis.\n", + "The data represents the first (weight = 1) and second (weight = 2) dining partner preference from a survey done in a small school dormitory.\n", + "\n", + "This data originated in social network publication by J.L. Moreno\n", + "\n", + "Reference: J. L. Moreno (1960). The Sociometry Reader. The Free Press, Glencoe, Illinois, pg.35\n", + "\n", + "\n", + "Here is a visualization of the dataset\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### First pull in the dining preferences data set and load it into a cuGraph." ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# import the dining preferences dataset from cugraph's examples\n", + "from cugraph.datasets import dining_prefs\n", + "# load the graph making sure to not ignore the weights\n", + "G = dining_prefs.get_graph(download=True, store_transposed=True, ignore_weights=False)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Do the calculations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# calculate both the unweighted and weighted Jaccard\n", + "jaccard_coeffs = cugraph.jaccard(G)\n", + "jaccard_weighted = cugraph.jaccard(G, use_weight=True)\n", + "# rename the weighted results\n", + "jaccard_weighted = jaccard_weighted.rename(columns={'jaccard_coeff' : 'weighted_jaccard' })" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Join the results dataframes" + ] + }, + { + "cell_type": "code", + "execution_count": 19, "metadata": {}, "outputs": [ { @@ -482,46 +287,68 @@ " first\n", " second\n", " jaccard_coeff\n", + " weighted_jaccard\n", " \n", " \n", " \n", " \n", " 0\n", - " 16\n", - " 33\n", - " 0.000000\n", + " Lena\n", + " Marion\n", + " 0.125000\n", + " 0.076923\n", " \n", " \n", " 1\n", - " 32\n", - " 33\n", - " 0.526316\n", + " Lena\n", + " Adele\n", + " 0.142857\n", + " 0.090909\n", " \n", " \n", " 2\n", - " 0\n", - " 23\n", - " 0.000000\n", + " Lena\n", + " Ellen\n", + " 0.166667\n", + " 0.100000\n", + " \n", + " \n", + " 3\n", + " Lena\n", + " Louise\n", + " 0.200000\n", + " 0.111111\n", + " \n", + " \n", + " 4\n", + " Louise\n", + " Eva\n", + " 0.111111\n", + " 0.076923\n", " \n", " \n", "\n", "" ], "text/plain": [ - " first second jaccard_coeff\n", - "0 16 33 0.000000\n", - "1 32 33 0.526316\n", - "2 0 23 0.000000" + " first second jaccard_coeff weighted_jaccard\n", + "0 Lena Marion 0.125000 0.076923\n", + "1 Lena Adele 0.142857 0.090909\n", + "2 Lena Ellen 0.166667 0.100000\n", + "3 Lena Louise 0.200000 0.111111\n", + "4 Louise Eva 0.111111 0.076923" ] }, - "execution_count": 7, + "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "pairs = cudf.DataFrame([(16, 33), (32, 33), (0, 23)])\n", - "cugraph.jaccard(G, pairs)" + "# Merge the two results together joining on the vertices pairs\n", + "jaccard_merged = jaccard_coeffs.merge(jaccard_weighted, on=['first','second'], how='left')\n", + "jaccard_merged.sort_values('weighted_jaccard',ascending=False)\n", + "jaccard_merged.head()" ] }, { @@ -539,21 +366,6 @@ "Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.\n", "___" ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Revision History\n", - "\n", - "| Author | Date | Update | cuGraph Version | Test Hardware |\n", - "| --------------|------------|------------------|-----------------|---------------------------|\n", - "| Brad Rees | 10/14/2019 | created | 0.14 | GV100 32 GB, CUDA 10.2 |\n", - "| Don Acosta | 07/20/2022 | tested/updated | 22.08 nightly | DGX Tesla V100, CUDA 11.5 |\n", - "| Ralph Liu | 06/29/2023 | updated | 23.08 nightly | DGX Tesla V100, CUDA 12.0 |\n", - "| Rick Ratzel | 02/23/2024 | tested/updated | 24.04 nightly | DGX Tesla V100, CUDA 12.0 |" - ] } ], "metadata": { diff --git a/notebooks/img/dorm_data_diagram.png b/notebooks/img/dorm_data_diagram.png new file mode 100644 index 0000000000000000000000000000000000000000..e0780c9c8a32bc7c81c2a1e0c80f2a056acbf36b GIT binary patch literal 21289 zcmeFY_g53m8wZLapwg8<=pZU3w17wl5fPQ%Lx&)s^oTS=QRyKRMVfSxUPBFC$_AAt z9U*{(-U+>P^M3C==brn+{R8efcXsmZ%hkzEzJPEI<)j~1XH?XGz0t3M_y>*H7%g_q#R4Sxwt`$xIrw`Fa~NA4K<39Fp8Eij+P{4 zljxY=+Oa*PtP|U#AKR-R*Q*!Tr}wE(=M#zc>wfCj``oYl`HwE4r~pms_m8*w+6JT^ zT)j$8PJZng1qCG~71j0Y*KgdQp{1i|V7vupWn<^yggN6%q(r;FYO&)JGpxx-gv(C^7ird3k>@3F*rN|9UYUHl$`Q4 zH7z|OGb{UB&iCBB{DQ)w;*!#`ipr{*+B$4~Lt|5OOY4uej!ppA_4C*7o}uB9(XsKr z6O;I<>Dh(FrR9~?we^jIgTuqaqod>FlarIv)3dYl^NWj%GOzIv(pYYIJTvtqBctp1 z?{}rkz2GGona-W(PgD$jtv05jGA)+;#t%(Z3-rtqr@m9`?a}#q=2myUXOWG(A*}XG zYgNeg30zR$budbQ@|s|zvjJ1wS64f$3XS?gfs|M9&l2?|_sl=KOz@~BhJDD@F0mMH zaJ}P4DZ(a8>v0PceuZbnMN%bg>F0ZxQOI5h1=>J(aklro2wr>OJ;qP7id9gBIUq0FjC8qJqrO zh1|mLN6Yaxhh;NOuJ&cV`^3zo+bm4h=Q4u++mNmP8)ZuOl`sEMOFx`?rL@CN_(Fei zKQI|$g*<3kds^b#AeEczm&#uKFB_#_3P&WX#meam$0-Hyp8m-u_3tF#{rql;V-PbZ z_r+lN_u7#zu^-V_V&zYSgqz;-`2!z*sZal`EN>2)88A`y?keVnoXN(ut?trh?Ryg` zlv_TcOt>2lf2ezKo;xdJG*ODi2F1liU)inUjYq1dAp{oN46cha+M~%nN8)TNeK4ze zfd46*LbdxwJ(B*%Q)07B%2>v291xz9O&N?%-TleE=)_$gr zy@%H;FmliDOnB}vljpFUZ~h2Rz2GX-R<1bK*7uj>;@b7b=qe~m24BYZfy&(n9gTH* z=YI+BWP3?lxr69ADw6cPU!j!E1?OoRU)e0%LEhZ7bF?1ebWFq|I;l|Ww6wnizQ(T3W zje`uADIM%~mx2SCCgs(2N`J?XhphV-7Q|3If3xah~lJH|E3(aukPYv z4ABL%KbZ)A3YA4Dq6$}LGi9ASpycHD#CzuASQw{c|5s-=aqq(~E$azD>NUbJAezlb zd(JTv5E~cUe0>$5+K#%I>yPzi-uf3)vM(I$_3G$w;C6v8c_1nkp1`J?I$v?SJMV{M zKSG0>UL%IxyWK;$(?Qkm-3fyTf+Js~$abi(A0^wD_cOKRY0Putdaz#jF~D4616qyiJ0` zDqL*B0pLh_(NN>7FYF*Ds-we=Tx95n#}t3Dua2|PC+@N08SVo-M@P(%MeJj2<%_75 z(9=q9o1=E2{XWL}rvd)blX7h+Di)>ONM<#{toi5&_v2{#~-#$%|msApWM;$>Sve9%&}H!8N}-OLrz z`3!*W#t##5%-tA`YfuOmeAnbk^1N7jf-(!St^8+yM83)+=l&^Vx}n88k+QDqZ|0m4 zv$vzV7sTqlhRT(Nt*CeTT8X!hssZGTY8-c*bud4BDg1`-Cz~|f$peW+6GpIwbH8$e zve!zB{Pu$AR=deygR7cC{0(c1wH4}}k=kuv|40#;6|ICZ!U@z22#e2o?Qbd~GrOTx z4&1YTvvwYyalT5GdAX{yp2MG66#mOFwGB`a%MjhqL9d~ zEH%-Um=?+^lM>SoWy$(55tC|eynn2(F>z?mZ$ea5AZ#EiwSFg=O1k_CuxCniGT*sZ z#~j&FZHY!LPN zuV7w@=!uEnx@|u8b5?$4@}AQB&8Gm+Y_5+< zZY_>I_xEBp=?2YJfZi@}^ntju%Rlk6XA`_Te~D9ed64^7y#(8J-S8R11&%*x@d*+& z@IPDOqat<;h$UIOI`0$9U%=F@0n&Tq2s9RAM&=T0u#YgSHQ0ZiN1x5%3CDTXty}MX zZOcLlJ{|c>{HM9ZW{ISmh{tFXq2T0(9<+h$h94~a9XRZp|4NCohdWCk0~OG1r~$-x z9KlVk8SE0hpU?Q%C`B#&3VP9`G<{Zs-5#l(XWRdFcCXGR`XqM)N&ko1a{+a@WbHXr zzU2^fzMGQ@bm-#6BkUDoWVNGZSaGfqyj)%Q&-top*X=0a#yuiPKHVO4{XKUuDj^~T zT8Cd@%>MxkP6l~um|pz+ts*eW>wv@SCl2s^H-&c5O8P%5e5-H^=t(w5y{X%^M}s`^ zkM2|%0TmCf0hAG@Q0d=q@^=@Gd4-i9hypB?1BynW<`Z|#BIw4tm8I?gN56$jlG(K$ zd&ti|w$t-p;QqCh!s*#T=hl9FD_~ff4M=JgWQ>AINOvwTYJ>|9N^Y`ah~5}jE;9!1 z<~tBQc7>SC)bHbFX#?Q5-!P%0pChD3duVw~8i2QGq`FHd_Bpi7w#rU)Lr(J_+zEGs z$J?mqP!SOsF0C6u=k=Nf#m=a@o~$9wJ5FV*3CK`%ax(2YNhP^mT(>JlpIf1ka=48M z-svttM74Bl`!fp8X_pI+F^~DmI+uEX@Wwms(qu+Yav{8nOW@MbOC|`bB-~Br!|eN_ zA~X*e4NVyFJSuhi z!GKe0vPqFFw^p4)fsbN$ZEfCAx^TC8RV^&pDJlLs{8)4zgsa_3`iYjMFLl>`IL1FD zS@jDSlMGZD6ccX$`toDBRILj9?snbyy@l?ImyTdK&f@UB_@l7PPBtm@yh|#`*<66x z5`XG&z}e*^vPks?VcFre?Be(W<>;8F5JlT=X&yfn3ePC8$mE?F+3=dDb{-b+wd1Ms ztrMrVGih`&?%4;I=<3{OPmPZjK4s^B^d`|Y{f}VSik60+laYl*{)Vh^4{h3y7f(5~@ujs=n+cak#QjIkE=2rO18f2Nc#AOTN+hk$4?NAUx`i0zcVt4U}g z|JKYG2UEnquNmygUkR^j*m3=S+jWL(!_GSRkMCkzKDoME*1b8Bu3VoE$bt!1*k_K= zsv?P6DVFem+31K#(H~s`VTqB%0iBJ@d8&ibfQ!V0Mr&-~%`6k{`h_E0O{{e+wMx*Z zZlew!=6gjjDWQPxM|`t4i26{r(L&^YgWB;!LJ8yO==|0#KjA(dA8#&3 zF&&u@lgy(ZpTsQ0Mhk2KwwNa5oLT)3zZao0rT{E3+^w6l5%;V0dfcKv7^6|~nAVAV zBIxjVUAfVEbL!p7yNSmA!TmGx{rFo57u=)BL}7`qIVhUe3`Wmid3qd4m_~RHN6OhP z$QN|7#zRiR6#*P-mnq5U-{Pt(Oad-CRgKD#pfs%>&aNN)=%lw&k?{QNGg(B^>HxPz z#a={szHeWx8YbWagQQ)bMvmcny-iqlZvA93H4KjW02R3#pFE$<5c>w)Dw@GWb)KoX5uM7I{LPVrXBRm;&dLyX58up@Jrjk&ArhebZKxroT_ARy)kfo6* zCAH^k;vIm?mJ$jgf`vpsi7e4zwMwOa*(!^C1a)0^GpRPz=#UX5FrYxbyw0uj)!%BX z+EnQ66ye5rMQp-=zN!;0*cg8uJ5;-M#?c(lE@4Lw^kAJma&mI>zM%N{aBDB9Id`i8 zd1XF1YJ>o~_DZ<{J0wGy5LHF%=2JbK_7U;6GmuH_x({0uSO3qfEA)1%0AXNs?_Qea zHt8m{W~^Q6cuOKjMwL8*F7z)uSA`2thi@R&?l<>TfeYqz*kAdT$JjU9IoY`fX43_Q zRGpllr4Y+mAg#HhsBQDxWw@JoQ4dpH^&WQC@sUOm8+b_mlX_V&K{RA(OLEqqE~jM5 z?R779jcvbpM?u6)Vsx8x=ySt_%qHML1@v8j-I|{^Z_p3WV12r;9W^dj=g_vOn+=AQ z``i!uN67789v_z#^8TJ2>u#K8vg^FFy;^YipwJ?ANYdY~`sJuq;yxAZB{Ev=YEHcB zVL6Aj{f~z_F%d;rylg|m2yXuG*N{1ZJ(jbpuYSomM;=iW{@z&HN21hxO<(IJ!wS%>M{k96o*UD zbX6l|ttU)K+lxUJvGu^}x;Lfp3r%7Z+fq#Osh%b^uI=!-Mzg&aZOK@)j{!3pV8^vw zvoizEdHRas-nBHS{QgY`>&9KjAfo?)!70CVz-#GK*dceO5VoQEWm-qlQ6GUVnXR^V z+EFizcveHRum*Mw(o)Pm^-S2mde>|nUini*x{u{Elp%(YodX%dy_^5PL@!0Qjs zLyqBYL!0(L&LaIf9POtMKF!O`&m4PJ|Iu;dqWT7l1m93C!Nxt{*rRZ>cJ^zmsiq>{ zxl#Y`$Y=ZSwaXpXvpTWg9-*&(71;_H^CpI+cw}D>S(ucvG_o${8fL$L^XDLO>#EHO zGfx(t`lU&ys%}YYrzRH0)KO3PwT2A!I~$$;Grh;e$u~D6!jRZJMQ(?csD7C+mvpS< z!#S)S-eeAW^ofSLa}RI(UXvDGah>Dx6*)YuYk*VUN%IQkKAa8B{yE%xWo^?6K>oBp1Vi&5SaWaGgP;F!S0>xf-gCQ?7$CrxDw$9dita9wXo5}kC z)qjw*s8&6sq|94|)eU%E-(g5H2-)lPJx>zFev#!L92f$;hblwAnYriqzDHU2S>Jwd zwZhkZQHUWVRVj|r-g>t#&YAj8jvC>!@O7>t13p3wq|lZpsSsdi9K?lmFH6ywghOx!{2 znGHbOJF+_c7w~9=^=)Rwq=SYY-(nTAx{C8g4UMZ>noPOZ!#5@+Q#Pk=_JRA-6HDHl zRqBGuqHhwHStse{S~i8X^$-Z!iM@>x8gNvzx!C2hdv0Me^WZC@WdWnu6*~>!d5P{J zmzK3b({2w)=c9!8d~!13cz=L*(bl3!40z4^&?>kr*F<>hRr_NJ@44$U5+e6*S@AZ# zFa!7;UGil@yGn)#W~+C8>)oVn&(<=>bbNWPi7<{RG{Jx!0LCn zBHSw1s=6l!R-zVeV)2B-4J_~X_T*c|?AFWoD3iL%^>s!$hQ=qG4y(L}w6kVTxO<(m z;EuY-8qBf*dVe~%IyIqC?Kyj8W1Ryvta_&x@X1!Alk2+O3mgc@9xVi1XC#u!Z*X70 z(V+aD`TK>w-j@dlr&gVX3eA@TQ|mDabB%xd%X?bB+8>r^a!vGlmnGyMnCV2Y130e! zC50Qzf?yNIa5kJMPy7=aRtU2gv$&6D_&Y?j^RRrlLw|Rqg8Xp{%a&@ZC5Gi>va{+kq@mP-EKS|K+%j+-*Qgm^0c zcsmc-$u07Rp`ptKN(-(Sf^ehVN&DZ_h1boaECeM`<4pGVLO(X?mMODjTNklgZf4bA zY{Hy}?`if!N@6KNL3dI%rk-$k^m7D(_AO@4?B3a%rc%kB&(S%Mgn%#{V1 z>$wf0k$A0)R6mHB>e5~OqcvBpGc<`!C z7)$%-^JShn3C%0ZNOp}hRgm9jX}zA~XJ_}$fzs~X!>KF%`8EmW=@|$x5PB=QSEK{$ z*imYxo^!Nt5_K}XD~|@Y!_B3hs;_)xzam2MszGYX!g?lzKHdwPwb2^kApLMude$>_ z&G*4Yx!@HbIgrEJpGqX8Ko2n!hWvALvSaHzAiyekAd6-xua8Ml2{XSU5_%n7=6IJR3y70kkzz{9gbK3}UKI--ZTz2=r zrrJKKVGmH*8%V4$31N zDfVQY7DV{j*nHJKX1qh(TMsRL-Aq$CZsN$>8M*hRGlQd;fu!Pa*wPXZ0=19u0pMar zX-Qn_HIY1S3lk4B?yidGO{oLa*Y>WM7LK#bad_y5>_-F$ZQ2EcqY7c+*huKKz$*mV zY>gT&7ddpm?EW|=2B}-}bUhgH?IR`vWv6Tb>A#5>Gjc)#zExQO?_D>XzThLu-$(Qt zUzt%vKEaslX5rWz0%e;>Z0;uCBmeizsXK*OvCgdVtB5gMmr}4hYqd;7rfN+n9-sBe@evaWnl)(l;(sY;ye_oWY=ou+E zYIcn(ZHXwtsQ`*rEgJ(?oLnat=rDY|R=aqZ5_Qkc)&#G#{U=LE&k&M_avs04m3S^B zy5pLOjT9KYbCB;V6Okj3H%axxTKYTc`ha~TX3)_H%6VWj@dyevZhC3VDwg;1D^93} z3!?+5s^DD{CN!Z_PbnKjH_EVPRZgi@KgIqSBI29mtbTSn-s}J~ID?SIEI?vo&CW+i zVqJWj8y!oU(O**`T&L+D6RBd{L?5%xhc7fB>5n1~fW#F*kMVGzdz)5{G-#KPzAaC+ ztR=BUDC%|&rljA0YX!;*^Y|t*DD^gXq!$X=-wY-H5UOcA*~{{Lihm7e-#`CCOspsvt=agy3Wm*AjU>U7PcHFh>Usy@nyu3R99n4 zizt`bX~)}V;xUR*jXax`FI=;H|3y$0`#xuER6v=LT^JAzkdP#WEOmcf6=`&M8CJi) zczAe6@S#knyzvZ<}#}6H8q!YBw&%! ze{pv^ly7U=u3%Lyw1)H0+T@PwspEe_ z$QzD4ZI&mLQV;3mVC!uCc_+t?9Gq=r#!nxpAi@!u5s^>e>FofAAnle0!m9YD8pgj?f5^-WtAYj$ zUFx5^q+}BoN4Gcm*V;=pyBkXN?T%oEC!2^*F}-}iPW0L3?cqkc8AjUWFTu{wH4H5g zV(1mw@EQciTZS$XVK_N(6Q`~W+Wsm#W(f^*w=`ry-)K<-sw9V0dI(<$6Wbe(;CT33 zg>Zk?z=Is<`up+C70q-e^Sj15s@nh50-46fly~#w*Wy=Z7@UWzB-QnxXRXjqQ+yVd z9tWuFU3od+drh7e-q>Zu#82H20p9l=w*Qkl;E!;#9~FTxC2x0k$bpdnsUjUg+kjMo#Npf1HA6d z;^Ez10(7z3yxW(l;|BCLAd2RToq1C-j-6QFum#8ZG0F2jt)jD*mrYj!j#8^S^WMuw~ zBD6TVOBuUZ1rvUvk(Sv?+`1SDJXp^{f<|`6+QQebXPD{KEY~V*^%WoeT;QQ=sKL6+ zc$N^H516_2gjXM8oyYbSq8%y*pi-TzYL=1U!$x<=7x)8?TL^pDscxD@30CkZbU+_- zvK%dsJ~ieDK@6+Vtbi_exwWlarHt{0yWW9ALS9dB0W+o?aBw(RCv9r48IJCd`7pg) z-95TEvDr(jLxyc`nwvOljnM57iKszQMI?iy6Y~gxjHf0XZ|oaDm_#B2XjjO~fCRSM zDKt=7hjDgpL`b2jBFAG1iRR}v&ZmKF2M%fzZBj2fEop;?nG1gR7c$Y`hC%cO(q?suxJKV>#i}bC%@fi+p`#};Ee=&yVxZq#entswvp4=_iz{Qoq&Fh zV0+!Oc14KC@?K5+NW6Hd`RakLQ-7EdDw?6Y<)YM`o}(fHWDW`dpDJvvtfJz^GnZchpFYLQ zp3;)ceaojr7`Tc6pNKw#p~=rIm%lEE$216uya59xzGyGqbtB8`odnJaxhBQb-VUX2 zv9>*KOsaf;)q7Ep6(4Mu&HAbP$7^SfM*_JToPz8lQMQwMG#G*&Wu2m#Q@I~il8)a6 zBlV9c(a@J1p&s$DHBvuVSn=Cb@xKlSCD5SEiXLe~)SH_ujZ97)8tt=P9>K6AJq0!( zomDKA@B9^Z@GU7Z|-+u+OEWZmHj`D213gY)JNfpf*Ci%CrZSezOzYD#Jz6v3d`s7-u8G zOcA-8jdh8-P+w+ckoQZL6FtFh~_0MYTOn~ zU8P9v=2UM|peVql=9~Imt%E5Q7x(`4eaSDgJSdi2KCtJ1)Bg5vxo$x?^e*q3f0*RD zIiewC>M6<3ZSK+W{oZah+Ow-NZ|L{!YZ5Ns#^Vsb->bRhB!glh`L0MsyMnu{X8sfb z_mgdJcgg`Q^IqQ%f*0>Hik;q+5^%t2=YS>PXd{90p4jW^;MFW_3Co?PbZb*3-I8Qb zV>J9{GO_8~l6c4x*M#j(JHf z^V8B=Wmz|BlKroM1CS0m9f=)NqeZhgI(R15FbwGOk@(&%CzPru( zCbd{)So^fcSz-1`jb7tWQ>`C>eUVKw0smNUHi!i}4%>sdtwMgzRaa4enV6{$1FOdsfL2CZHrfjL#YC1zh~S|tpe9yj1! z(OIHEwS^`=M}KtcXpHMumhOCkn*B2vOB&thn%VoG%zkFDMmgTmC}cV0Q8~UP+Qy8S zjqfNtAc(cokAijNgHqk$}SpGq7N~mg1$mgDE?uha)ag=8VN8 zW;!rhHqL;NY^=?m^lq}Ge9~2jO@Nk!#WF)7d7)0Yh;-^-kbbtT+J-nCz#e4eXMs~XIE7)DgD_(#(ZwZ3h+)cU-vC*tNI=OBZQz}u#uTizFLWi zo2~#ZPaz$Acu8XyO`yctV2_e}q5G59eFrjybLN-4^NQhg~b!)P|Y2 z5o_qKC@dU&8h_8KN;Lnz5RDH%T1s;BAzI{q$)ucDzbS+ZO89EKY>h?F9W*xuYXw)c zSH5vYr;Cx}r`CZ2;nF#v>w0b@`aHEkcLB~o;_AGa5-94iv?wl|hx<1=0e<*8b z%-Bm#fKk_YGZbJ5=~_f%6(6q8~^?_|96_Mgy7^H$rza~ef!PS ze|cr{-)R~|>7+Pl@^z?L)OaGXfT^Q2CGiEtQ}mLx;%VM7-dMdOu2=S&ZrU7Y6&x?x z;@&z=@(f4&G!wUtxIbKuWi*bx4xMNu@KQ-znthv4M?C8kLDPy=Hri^YQRgY<>!M=Q z`)ljUQZMj1jepa_cUCOivrwtQE9-lyc0lOAry3fStFyi%y6j1j{;=Ke3$?q^*}mr{ zWzMMpUH@w~%li>YFsR6$9&M?2f?S*}V}?(+Db<9i5{(BU+DwSD&Bq->NQ5Dv8v+lf zfnmjN+mehVE`6B;PB#`uAFkH8^WNKIFMzSO4;xg|TAAf9FNLls?%izJp*-Mw8`2L$ zAhjaq1j}-a1nAtrsGs8``E$XT^RK?Dj{MO#6tVTMOY^fn)VA~ z3uRr%hMa7)+d|eo%>lbrRexdeVv_EeCh23{;t!tTR(Qjv4|>KQVWw{G`-0S;TC_G% zl9V^qLJ4DmO?D!rSM0}Z3P~;3UKR*|U>L_njLTIl+QAm^uN~N9*RNzyMSe{A(&GoT z@PSF>C3AYGdI?tcxI3Tx4$znT$Y>;yPNN;8Cj~tSGKse@zcPqJkepET`=^6sr4b*X zkQj!Go!5HI)`*QeOW~)y=(c+%y$7Y>tdE0X2;tGZQ~ys~6oh4Tk~)IeFejGDj4W2e zJireXZG=7njn&4iOhqk-hrNp{FI1@ zfTLnv>TLe11XEWfCBhI_iwaLOL(rq6P@e6IXHY>CqtqX4URUmKt~wPzkv*C$eGTd| z&$qhYsRm%RYt_5&79^Q!CVjoVIU`;qQng`?7!Y}86gwedFU+unMH!)!ZbAi_-vd5k ziJgv%X?IM^K@8wQqec>b3>A>uQWrA0 zmW7R@d|T)!QyOF2CAes^?3scn1HF8q$dIb#e$ELM-OZ?ASLk`*L*b5bwH|O}XBEU>vF5HR2=vIx} zmY~cZiu7}%c$EebfkQk8W!PmCm8v3{2ZUdkaAU3pJNpxowN@2RwLs-he10sJ&=-dA zqL51fIq?D63`tDoO5VJu6FN)}J4ORno^C|{{NU`lvnY}Tk&jiC8Rb3riJdp;s%$_) zuy7ulxO|;Mne6RNxA6r?>p!ZQztR{PArOZjj-)8WWr&r8qKs~d+R#)VIqg1wurG%)4q=)F! z$rokr6!|&h?-4zw!c3Qgt*7E~fd?(*(g9z?JUxBcbJcWeh&!Zc7=f}`ufIjTAo73B z&0E&>Z&;OERQ-jXRXkQq;h0e{H1yvfPPv?Rs>p>dA70X$y}zZO*4HOO4zw{OfjYd2 z9a~K19p7QXf@4L*ycbcGWd)#ykf_Av;{Ac=v<&P*ZWU|aq;-V;7k9}qh$)-tYUK;^ zzHyoG1!<7>`ES@;bcdiQW!2F?J8rytE;g?=2c95$3w|rzTy@XX5r}@ZRKM1H8#N=$ ziW5ETJi0@cS4?JkD+XzpCIRQ1lZn`sayr}oBzA#xL(L{?YPk97mvn1F(_!X5jZu$? z+zC{ek|cyb<{9km>n$o%g!v2L69YByc&C27HQ--ZqZRUqk{pO1Tol58fBj2@Ry0S; zPEaF0k^l*M#9y{Fi@sz5 zZ5qT&Ty!wZdxdJ{fknNitB&hCXo2jYRNC7~k3aKxC(Sl?{7yS`te51vxHukLX0k zkrU>b>Zq4%+u8MrizhGlhh3ma@rEU;9I=?Eq(CwE2U=5mT-WY^RLNKcTB9#fKB7;7 zHgKb&m>xwIwH-OE%z1dScr7*lZ=sQbu=28)aGM^77KRSt571W0*aF{77Tpy^8S-;s z0jFC(#-Ct1j=N1sF>o%L_t&GOb=fm1MLI>z+*QS`BGha1LDiw^0cp7 z4vu-p&qgiH-ksJy(F3nBsvF_Ml0i`xOwR7-Z{<$gg{&^`EUUG#!-NO-k|?dt|=A$3^=_XURz+iVrJY)a#w{PpJlMP#CQT zF;N+>E3K>*ik8!N-Vn$e>y^?KX6^EQ<~5fgB3w@uAYjG?+(!opWxRZww|~K73Cu8^ z%9B&Z(1J3<4R%$FqF=0G%vG}0jPVyedChVCW={Q2hzUV=OGDS?DItWp?O% zZ=$~Tl%)7!F7S%6@Z)(}|LS7wOU(PO>j07UuuP^OhOzx95$SJ-DR_?f{lG^j(>Vw^ z_nj$~XY;X^=LHvIJmJ{0;=qPIg?SLx>3eMgM#N}zKn*xfxr>O96GTXwVY0WyF zBSQU!=Wm`ZP|%DWNgnAmEx8H_DPBRbaEI^!N_*X+~47;kp-C#~QCBlTzZg zK|%|6^70{F1*~0lh2=av^kqrsazZGJirGHf&Y+~hh&H}?GuG3lU-ai zbrewHmV-hEINx_C%Z0(khyAa1`T~#D#32sE|1HiSAkn{rvVTwBV@5^7e#!8_F%G4f zU8)c?d`XYvIp(_cbVG%#up<%K|9(MI(!WpHpA4vYRNiKn`V^y%t00{HIg!(}(ru!1 zJUt9zl#YKL&34f>=h1CCst}#iIyDg*q?KerJA)80JZ@RZRy@)-$TG$AO4$WP6jypQ z`g&Hj+=xnDgGayffPAk&nPo z&%6%I$;*nzJ`CwHX>V3Mq3w(|F_~} z)(s(00^64j*w#pICoi)c(XsF2OI)I-k9L%ga;&bY5~L`$0p!7AoLF0CyhlXn=+lxx z%s5o>d(wmu35j1spc}vC?cI!5`Z&WWqE>KR`Ocn(Wo9yT36(w5)p+CMP3FSd_bZ+& z(?%tNT0pe(2%I5S6VZx3^%IxA&k#>|`GOroVAbMfx=dV4*YxuZE?4}4+PwW3@J-qb z^S7&_^I3XolEyWsBzF4PDU5;;Vp2FKwCQ0Dz30AF5u^70lShjh4GJL%l7e$P7cq7^C&3tkrt6f_uewLXkOevQ4|E`A*{pr1l?MZg`%Oz-Y>d}kb{rD94O~8^WA7+_BpWga2_&0JY&lgf8GZ#DDsU%ubG%11_&q3 zJww(d-q}J0RFjc$tFpx)rzbR2aT;?n-{Fm{_;;&q3Po}V;Lj~Vt2~5o#vs%E?%vta z_DXfd{x?I2|unOK{sHH^RsL%~fOXnllilL8@9foXIJIeYe7wBO+cW zzcX5&)4p#i45;)sW~}4_2%4n9VuEuB`)A03(GWVbw%ck4t3(n z=I2`|>vAw|8uc62J3P}TQXT>ENYY1mbe)&!B8~jH0(E21Ud@RCjdBt8dR>USKp>-u z+ot`HRaDgYf&7hC?2O*J9+ZtP4~t;vz7u1YPPnU%@Ol+$z#%`mbvJEpS{=g=w0j7& z3w1KGI){~t0WFMUO!TpcU8CB|rypfEw-&^p)+kPsyc$x(K;m;u|0zl>!TfZasqSnR zag4mnaoBjWrff~};ifNEY|`cV33U>|#h}y4`h~{drOj_V$Z9O@^4(vC#GE$e{2H&I zL_ENK@^(8ef8*?~`pLz0Y>apcPM>o_$!~mls9s{k;Yn}eJZ)TXE^S3Za(`lZeDv_I z^Q%K`NfFW|W$*05GZjUSMexUik8?$n_&HXX=HZvDC@Ay*8;feM{N`XpJ&?BTu*B@ z&H{X4Rbf>gCnr%yU%15FG0QEFH%?$0;WP~V*)o=$UY6mNiWN88F$DvtCJP(>pLDU!P06kHvR?H$ zb`ao=0rs{XJ|Cos*k-vclkYL%7bOtwxDIrwXj{C@%!c1ujD3TIY3${Lt3JCZk`PiKn-T+I z;~e=lud2qMzFVV&=n1w}rPL8r{8PMeE%9eI#Y~7b&NxnF3u4UU?O39p=6>BrS9ml# zI|dyMa=!hUy6P`o^<%_p4%Yrzd+hr}%CsP0+@f;KlRg_(Z@MlXenTm{m)1vZT@L|0 z{fmDJiqYXhP>(O#4ZZN;oZaOGt5W~U5LJmBPe2R_WGDlFPU^#UXDxN{po*UNBox2i zjvCN144tKmP*q&5yk((oCKNiB&TOECqz!H{i_!VT))|9_pdz!Sre%?zxUso=B)z43 z(dSun(Fp!Kws`i(p*YA1^WG$r`L-7Fyu=4;x)ff?jEN`g+?3+@Y%1XG>2CQ9*hNkD zXgx0nRrtQ7YG-7ui>)5jx%>HIRn2du;p5Bd^_@78QaB6l4U;v`uYz&ZoIkDH12{RZ zg-O*_AT>{vW9*m5GBQ%5#_7ja6m#?ZaqP!P0vN3Z>Ncj(OQWATsm0j-lxecj%qOw! zPjMwLElxYBadEbw)s zTzGjzjyC)rRTJwU_+b+yk3x+x9HVJrGa`}Kz(x|~4q_Z{ch?Rc`*9jb92COzj9i!0 z-!*Xzh=1RCFHK{%{gyO@z8(=;^DQbvC}w}Eg_dgQi2w!cwF!U^v}L^8io>jKKZ@y4XLp3 zQ6fm&GlDlyUKY`!8SgF?GxhuD!|c^a@{0@_(-ak zhLqN{V@C>MoEeOgu_zy%CHnP7&Vpz}T^&`zps0WzHv13On?>KPCx$SyYdT6SQpQYi zFLQqT$YUVseAnXOzO)Xxd;7s1R_r}-&=JVRAYR5ky0L;1vPKqlwnkZ}IK^zMJ6fj~ zv_$>t3?uB)+`*ro{Xda1n!fBUfgit zl|9&Ps7l%s%JRHL1ti(IXz6()x;T1k#KW|0P=y_DU?C2S8PzKIND&8F7y1*5c$C5{ z_ip(&(v@@O!f)cT$llf&G!xv<{C8!ZdWE*lsFIf5n78|FalCQiOOVPT6*p`pOySS) zH)KaptsOig$saFuTq(?I59QxjwM`7zW?vVPvlwaPrm3VR5r^})4C`SZX zs4~r2Ek)^+0*~c&HEmomiUPZGq=x2+rxjk6w)qDg4T_b{$Xlgl!p{_xuP(!iK3=kL zqL@W)SA9ksoqW(}Xa-21Nc}!D_Wj(MK4{VMmNsGfH6|NsK5y-#=Fp7)O-2q{b{)3% z{Zy-)vt>oQs)p^JAkh8XGn{XDDv{m(2H@SeB|H!?sJHBvM4Ye7`Mv16IC%e*ly*+E zMmk%{FKG2VE&SjRP_4y_3Feoc?V))ljKx4Qq;M6!pO)2ZN!G+--ZiZqy}E?u!tB%6 zuW2GB0Ic-Oz~dnPDf~>_1 z-DO9Vznn&j7LjyWSzniS=6$$x$^CZcntnfnn4G)Jsz5J>I~^Q+R0NHwYMy$YKnsue zuJ$|m(!_qTOPlYElpxj~y|r~_!MLbjWBU^m%V&}oAWKT~##i56@de;wx^9`wAET-N zMOf@s)-snDzZ$MkjLBn~bk6jJP3<46<9}=awZ7wMI-qr*voWbSAqSLi!6Y|}bA=}~ z!a+xcu6*fn!Hwtm&EDsLA~&$As*qlkz48!h=3~h>7fHYC47dwtp@pJjqN7&3zHX-) zQMA}MDc!sqwaW$7rS^KaI%ApVhDihFa3#DsbF7XdS{FF~5!`G1(P@aPIl$f^Xxxs> zy}^0<7sscygjV=;rA>IZAfomU5J%oMkx)y3-j)O!9)lF6)>%wWYDzf>s|rBYE>)M= z>EXdHaI1(2HnJV^2$|~Pe*7`QWa-P&_i#xFnvOPQ+w2RRF~QDbF?x&HelSEU>nu!C z2s|PQ+ySexEk8Oi1%^nUkrp$3sYi6HOq+SckG2(J)IW6AU_APg+#Ir(o1sjvrQYA$ zoFecWF9^&8#~o{-P^qYzq**U95Gmm5saNw6sO)`a{fk}myp;^M} z0PPv)NF35M^SumJ%Ze+;T(HhawKP#r7pVWcE^LK8m6q0REWS#RW)PbY&cK1#X#Ypf zZ|x*b-!5X>!$$Ax5K-k#@V#yMNST`w@o*&@Bxr~68?0iDI9y^J)&_?7Yz`}mKlR`6 zg=4S$Fz$Sb-x#ElEk|PK^s{)DS#hI`gKqDi6mf1^<;z?wV|Y#70Ni@C`2*ar>kG3Z z4Xi^4Bwh8jORNQde`W7JIdW0x%uFNt*6&0v)zag6J1ibU(u1S54vNT19K?@xZne(X zyd+onhYLTr>U_*7<6Mo8^_@x8dqIfZ@jF~^1DATs>mf7hv2_K;CP0}^>!RtNgH0qv zJS?c?=D*r*SM7ki4CGL>_3n=kZmr=mzx4MO0=}M1g2k9!-FN8kTh4Py4#9^7xXEV+ z3=M?F>3{fE0aZueOVl5F<+SI4b(b9}&uEeJ6hthO{0)ad0aLCpFwg?XN?xVAs1_tY zzRRvfXNi>1Y*(fSzPwTU=Ct?>*BWS+X92IVdFB8Ee1I*35Uy%lsn}!4H`apAlrJf) zwe@qD+W-mW;RSKr47$_&Xn%goqxJ5nIEZQzCP%pXW%QAY`i0bU3h5m9{1doRa3Lph zqRPi`F^jSN!lSEWqJDbR@9c@P;nm8ax^B;f!w;N~pfIR(oZFp0}9 zJiwE*knHEO#*3+xS7=jm8nlTTN|2U%Nt{30^RCXjb9@|HhqdYe6b0jn7ZZ-~CdG)VTuI5?@M97hC}?W72u_rdD4P027_;NluZ>(BIxUkVaSqY^ zh?j&hWz!RfQ)$FH>eeZpF2O#BdNx58;}X+gsK~_#{jFcS0(Sf`Z?D4rLN}S#;DaDJ zsRwED_icsX`cDF)=D}04p1s5gbSF&hZ7`dN3X(uD?Z*}Cqu^2gU zY6~E&SZdb1PGr5o{Pw#~d@f|rF@)RgVMf&zn`SnN70i5TDFZY|_h?4FNT6`|Sg|W% z*Pop4P0~$Fl72>a$@jVs8}}WN~wMs2SK&)#d4?nyv3muETe zWYgQI+nRZn*5uFc(*&F^m0@~`sgJzqo{tJa4P7u`wDqF<3GrRTLT)*H`c$NCFG1A2 zJt-cpLhPofW6tmCzpGbpY!N@FChXgh-Ul7%_hoU*i|`Ge$?#?A6mZ+A`F5_61)}Zw zrM)jQTd7-zpG{`(AtYy9RnLTD!^Gq6PgD+8T_#SAJ>)ZsxW85rV`MDPAZrN-;f)0C%5P(|jam znyz$z;VHO{qMJAQ#N9s?aT273k?NrLNR<)t-g3XA>=;aGTx#!U)nIX4h?H!r{Jg&``CO&(DLWdsTi|{vlIm5 zcpK|xsKjVQ169$qWm_RL`Ew1}W9R6SW|S3%P=mTfIs)it$#LHx5q`b*lh3U5LCg9Z z{e^n+nK;Kz8-TJqA$!(9!akLj%6Z|zt=IO=(JVQ$@*Ma8=cDs1+&X_kiepY1gtIPa zhDw;?QdGDk1J`%24^u2C;uyg!yrj)0<$90z4^+)q63#?r_nh%|G+9ON;b}@L+8{y8 z;b%s*=C#0LT2jtVy85r#QbVV-v8V9UH^ksoZpIoPs)^!kyWWVQPuRv_g2+sneITF} z*#Xm*1F&7jPT_5Tt{9ZsSE8GHVEe1l&&ei%WEAg>sbKem@vE^Ab=UMDUVx?eqCMSV@646+b;_~{N?@gptPs^af2)zHOU}iAfAeyQbBwTw*@dNU!a8oLfE_NOX ziqxA5_+LgK@*853#y+_f@g5eTH()?g#3V*4YA z@gl7yqgM)_u(~4#=ec~^zz7w3QfL9hV<$QEdt10D)Ne%?1~_ew-tf@Pi=QuE06${N z@8A8LjXL}y3hsXp){{#@>tm3;vZxnCuO_YCBTM7)hoD`#zYjiA z-T=P!IcR_RB_$1vY45ANqcgI0WZzz5{M-VrtHD}l>TbQ|`CU44&{!$`!gfy)$+3EVd0-|ye z&40cJrHXhkE~cWwiHyMV@5x%yWGm9(V2NdW-i~Q%P{zL1EL@t>S`BfTaFjaeQ{4Ub zyeK*p(G+5$Sa(Oc@%v0J59I^|bbfqTPx&ENnL%#-=hMApgEi64)xJlC0e!eRdZS}L zGp?DLfZ`?^glrsRVzc+n%4(ij7NJCYB5Kjfz!5bY?lw!j@a3e#i|0|7M0zB0g0O0T z+ku{MK5EmU+wa7bJPA19%Kr<@hYR|1 zW6jiDg1J*!o;#<;#3k^~21&X0a`NyA-03q(_Wo&6_Z%P2f$N~qfkN%hy4)FVM|1dz ztNs6uME+%gKAzq9w_QD^F!#S}{NF)_mE560N%MhxT3eYP7@>E>+~md|)kdE2{{hMG B9zOs8 literal 0 HcmV?d00001