diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml
index aafa4d51ce6..4593962bb1c 100644
--- a/.github/workflows/build.yaml
+++ b/.github/workflows/build.yaml
@@ -173,3 +173,22 @@ jobs:
       sha: ${{ inputs.sha }}
       date: ${{ inputs.date }}
       package-name: cugraph-pyg
+  wheel-build-cugraph-equivariant:
+    secrets: inherit
+    uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@branch-24.04
+    with:
+      build_type: ${{ inputs.build_type || 'branch' }}
+      branch: ${{ inputs.branch }}
+      sha: ${{ inputs.sha }}
+      date: ${{ inputs.date }}
+      script: ci/build_wheel_cugraph-equivariant.sh
+  wheel-publish-cugraph-equivariant:
+    needs: wheel-build-cugraph-equivariant
+    secrets: inherit
+    uses: rapidsai/shared-workflows/.github/workflows/wheels-publish.yaml@branch-24.04
+    with:
+      build_type: ${{ inputs.build_type || 'branch' }}
+      branch: ${{ inputs.branch }}
+      sha: ${{ inputs.sha }}
+      date: ${{ inputs.date }}
+      package-name: cugraph-equivariant
diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml
index 8fde0522515..9d0b682f2f5 100644
--- a/.github/workflows/pr.yaml
+++ b/.github/workflows/pr.yaml
@@ -29,6 +29,8 @@ jobs:
       - wheel-tests-cugraph-dgl
       - wheel-build-cugraph-pyg
      - wheel-tests-cugraph-pyg
+      - wheel-build-cugraph-equivariant
+      - wheel-tests-cugraph-equivariant
       - devcontainer
     secrets: inherit
     uses: rapidsai/shared-workflows/.github/workflows/pr-builder.yaml@branch-24.04
@@ -161,6 +163,20 @@ jobs:
       build_type: pull-request
       script: ci/test_wheel_cugraph-pyg.sh
       matrix_filter: map(select(.ARCH == "amd64" and .CUDA_VER == "11.8.0"))
+  wheel-build-cugraph-equivariant:
+    secrets: inherit
+    uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@branch-24.04
+    with:
+      build_type: pull-request
+      script: ci/build_wheel_cugraph-equivariant.sh
+  wheel-tests-cugraph-equivariant:
+    needs: wheel-build-cugraph-equivariant
+    secrets: inherit
+    uses: rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@branch-24.04
+    with:
+      build_type: pull-request
+      script: ci/test_wheel_cugraph-equivariant.sh
+      matrix_filter: map(select(.ARCH == "amd64"))
   devcontainer:
     secrets: inherit
     uses: rapidsai/shared-workflows/.github/workflows/build-in-devcontainer.yaml@branch-24.04
diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml
index 1c150653bc7..b21229b318e 100644
--- a/.github/workflows/test.yaml
+++ b/.github/workflows/test.yaml
@@ -75,3 +75,12 @@ jobs:
       date: ${{ inputs.date }}
       sha: ${{ inputs.sha }}
       script: ci/test_wheel_cugraph-pyg.sh
+  wheel-tests-cugraph-equivariant:
+    secrets: inherit
+    uses: rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@branch-24.04
+    with:
+      build_type: nightly
+      branch: ${{ inputs.branch }}
+      date: ${{ inputs.date }}
+      sha: ${{ inputs.sha }}
+      script: ci/test_wheel_cugraph-equivariant.sh
diff --git a/build.sh b/build.sh
index 0ba3a4defed..5cfd2b5af1c 100755
--- a/build.sh
+++ b/build.sh
@@ -1,6 +1,6 @@
 #!/bin/bash
 
-# Copyright (c) 2019-2023, NVIDIA CORPORATION.
+# Copyright (c) 2019-2024, NVIDIA CORPORATION.
 
 # cugraph build script
 
@@ -31,6 +31,7 @@ VALIDARGS="
    cugraph-service
    cugraph-pyg
    cugraph-dgl
+   cugraph-equivariant
    nx-cugraph
    cpp-mgtests
    cpp-mtmgtests
@@ -60,6 +61,7 @@ HELP="$0 [<target> ...] [<flag> ...]
cugraph-service - build the cugraph-service_client and cugraph-service_server Python package cugraph-pyg - build the cugraph-pyg Python package cugraph-dgl - build the cugraph-dgl extensions for DGL + cugraph-equivariant - build the cugraph-equivariant Python package nx-cugraph - build the nx-cugraph Python package cpp-mgtests - build libcugraph and libcugraph_etl MG tests. Builds MPI communicator, adding MPI as a dependency. cpp-mtmgtests - build libcugraph MTMG tests. Adds UCX as a dependency (temporary). @@ -222,7 +224,7 @@ if hasArg uninstall; then # removes the latest one and leaves the others installed. build.sh uninstall # can be run multiple times to remove all of them, but that is not obvious. pip uninstall -y pylibcugraph cugraph cugraph-service-client cugraph-service-server \ - cugraph-dgl cugraph-pyg nx-cugraph + cugraph-dgl cugraph-pyg cugraph-equivariant nx-cugraph fi if hasArg clean; then @@ -359,6 +361,15 @@ if hasArg cugraph-dgl || hasArg all; then fi fi +# Build and install the cugraph-equivariant Python package +if hasArg cugraph-equivariant || hasArg all; then + if hasArg --clean; then + cleanPythonDir ${REPODIR}/python/cugraph-equivariant + else + python ${PYTHON_ARGS_FOR_INSTALL} ${REPODIR}/python/cugraph-equivariant + fi +fi + # Build and install the nx-cugraph Python package if hasArg nx-cugraph || hasArg all; then if hasArg --clean; then diff --git a/ci/build_python.sh b/ci/build_python.sh index a99e5ce63e8..07a4f59396b 100755 --- a/ci/build_python.sh +++ b/ci/build_python.sh @@ -89,4 +89,9 @@ if [[ ${RAPIDS_CUDA_MAJOR} == "11" ]]; then conda/recipes/cugraph-dgl fi +rapids-conda-retry mambabuild \ + --no-test \ + --channel "${RAPIDS_CONDA_BLD_OUTPUT_DIR}" \ + conda/recipes/cugraph-equivariant + rapids-upload-conda-to-s3 python diff --git a/ci/build_wheel.sh b/ci/build_wheel.sh index 828d8948143..30a1c98c106 100755 --- a/ci/build_wheel.sh +++ b/ci/build_wheel.sh @@ -57,7 +57,8 @@ python -m pip wheel . -w dist -vvv --no-deps --disable-pip-version-check # pure-python packages should not have auditwheel run on them. if [[ ${package_name} == "nx-cugraph" ]] || \ [[ ${package_name} == "cugraph-dgl" ]] || \ - [[ ${package_name} == "cugraph-pyg" ]]; then + [[ ${package_name} == "cugraph-pyg" ]] || \ + [[ ${package_name} == "cugraph-equivariant" ]]; then RAPIDS_PY_WHEEL_NAME="${package_name}_${RAPIDS_PY_CUDA_SUFFIX}" rapids-upload-wheels-to-s3 dist else mkdir -p final_dist diff --git a/ci/build_wheel_cugraph-equivariant.sh b/ci/build_wheel_cugraph-equivariant.sh new file mode 100755 index 00000000000..fcc8e0f774c --- /dev/null +++ b/ci/build_wheel_cugraph-equivariant.sh @@ -0,0 +1,6 @@ +#!/bin/bash +# Copyright (c) 2024, NVIDIA CORPORATION. 
+
+set -euo pipefail
+
+./ci/build_wheel.sh cugraph-equivariant python/cugraph-equivariant
diff --git a/ci/test_python.sh b/ci/test_python.sh
index 7eb5a08edc8..5892c37e35b 100755
--- a/ci/test_python.sh
+++ b/ci/test_python.sh
@@ -247,5 +247,46 @@ else
   rapids-logger "skipping cugraph_pyg pytest on CUDA != 11.8"
 fi
 
+# test cugraph-equivariant
+if [[ "${RAPIDS_CUDA_VERSION}" == "11.8.0" ]]; then
+  if [[ "${RUNNER_ARCH}" != "ARM64" ]]; then
+    # Reuse cugraph-dgl's test env for cugraph-equivariant
+    set +u
+    conda activate test_cugraph_dgl
+    set -u
+    rapids-mamba-retry install \
+      --channel "${CPP_CHANNEL}" \
+      --channel "${PYTHON_CHANNEL}" \
+      --channel pytorch \
+      --channel nvidia \
+      cugraph-equivariant
+    pip install e3nn==0.5.1
+
+    rapids-print-env
+
+    rapids-logger "pytest cugraph-equivariant"
+    pushd python/cugraph-equivariant/cugraph_equivariant
+    pytest \
+      --cache-clear \
+      --junitxml="${RAPIDS_TESTS_DIR}/junit-cugraph-equivariant.xml" \
+      --cov-config=../../.coveragerc \
+      --cov=cugraph_equivariant \
+      --cov-report=xml:"${RAPIDS_COVERAGE_DIR}/cugraph-equivariant-coverage.xml" \
+      --cov-report=term \
+      .
+    popd
+
+    # Switch back to the default test environment
+    set +u
+    conda deactivate
+    conda activate test
+    set -u
+  else
+    rapids-logger "skipping cugraph-equivariant pytest on ARM64"
+  fi
+else
+  rapids-logger "skipping cugraph-equivariant pytest on CUDA != 11.8"
+fi
+
 rapids-logger "Test script exiting with value: $EXITCODE"
 exit ${EXITCODE}
diff --git a/ci/test_wheel_cugraph-equivariant.sh b/ci/test_wheel_cugraph-equivariant.sh
new file mode 100755
index 00000000000..f054780b03a
--- /dev/null
+++ b/ci/test_wheel_cugraph-equivariant.sh
+#!/bin/bash
+# Copyright (c) 2024, NVIDIA CORPORATION.
+
+set -eoxu pipefail
+
+package_name="cugraph-equivariant"
+package_dir="python/cugraph-equivariant"
+
+python_package_name=$(echo ${package_name} | sed 's/-/_/g')
+
+mkdir -p ./dist
+RAPIDS_PY_CUDA_SUFFIX="$(rapids-wheel-ctk-name-gen ${RAPIDS_CUDA_VERSION})"
+
+RAPIDS_PY_WHEEL_NAME="${package_name}_${RAPIDS_PY_CUDA_SUFFIX}" rapids-download-wheels-from-s3 ./dist
+
+# pip names wheels after the Python package; use 'ls' to expand the wildcard
+# before appending the `[test]` extra.
+python -m pip install $(ls ./dist/${python_package_name}*.whl)[test]
+
+PKG_CUDA_VER="$(echo ${CUDA_VERSION} | cut -d '.' -f1,2 | tr -d '.')"
+PKG_CUDA_VER_MAJOR=${PKG_CUDA_VER:0:2}
+if [[ "${PKG_CUDA_VER_MAJOR}" == "12" ]]; then
+    PYTORCH_CUDA_VER="121"
+else
+    PYTORCH_CUDA_VER=$PKG_CUDA_VER
+fi
+PYTORCH_URL="https://download.pytorch.org/whl/cu${PYTORCH_CUDA_VER}"
+
+rapids-logger "Installing PyTorch and e3nn"
+rapids-retry python -m pip install torch --index-url ${PYTORCH_URL}
+rapids-retry python -m pip install e3nn
+
+python -m pytest python/cugraph-equivariant/cugraph_equivariant/tests
diff --git a/conda/recipes/cugraph-equivariant/build.sh b/conda/recipes/cugraph-equivariant/build.sh
new file mode 100644
index 00000000000..f0ff1688b55
--- /dev/null
+++ b/conda/recipes/cugraph-equivariant/build.sh
+#!/usr/bin/env bash
+
+# Copyright (c) 2024, NVIDIA CORPORATION.
+
+# This assumes the script is executed from the root of the repo directory
+
+./build.sh cugraph-equivariant
diff --git a/conda/recipes/cugraph-equivariant/meta.yaml b/conda/recipes/cugraph-equivariant/meta.yaml
new file mode 100644
index 00000000000..a952812f845
--- /dev/null
+++ b/conda/recipes/cugraph-equivariant/meta.yaml
+# Copyright (c) 2024, NVIDIA CORPORATION.
+
+{% set version = environ['RAPIDS_PACKAGE_VERSION'].lstrip('v') + environ.get('VERSION_SUFFIX', '') %}
+{% set minor_version = version.split('.')[0] + '.' + version.split('.')[1] %}
+{% set py_version = environ['CONDA_PY'] %}
+{% set date_string = environ['RAPIDS_DATE_STRING'] %}
+
+package:
+  name: cugraph-equivariant
+  version: {{ version }}
+
+source:
+  path: ../../..
+
+build:
+  number: {{ GIT_DESCRIBE_NUMBER }}
+  string: py{{ py_version }}_{{ date_string }}_{{ GIT_DESCRIBE_HASH }}_{{ GIT_DESCRIBE_NUMBER }}
+
+requirements:
+  host:
+    - python
+  run:
+    - pylibcugraphops ={{ minor_version }}
+    - python
+
+test:
+  imports:
+    - cugraph_equivariant
+
+about:
+  home: https://rapids.ai/
+  dev_url: https://github.com/rapidsai/cugraph
+  license: Apache-2.0
+  license_file: ../../../LICENSE
+  summary: GPU-accelerated equivariant convolutional layers.
diff --git a/dependencies.yaml b/dependencies.yaml
index e7806aa86f3..cfefe3b9ff9 100644
--- a/dependencies.yaml
+++ b/dependencies.yaml
@@ -198,6 +198,28 @@ files:
       key: test
     includes:
       - test_python_common
+  py_build_cugraph_equivariant:
+    output: pyproject
+    pyproject_dir: python/cugraph-equivariant
+    extras:
+      table: build-system
+    includes:
+      - python_build_wheel
+  py_run_cugraph_equivariant:
+    output: pyproject
+    pyproject_dir: python/cugraph-equivariant
+    extras:
+      table: project
+    includes:
+      - depends_on_pylibcugraphops
+  py_test_cugraph_equivariant:
+    output: pyproject
+    pyproject_dir: python/cugraph-equivariant
+    extras:
+      table: project.optional-dependencies
+      key: test
+    includes:
+      - test_python_common
   py_build_cugraph_service_client:
     output: pyproject
     pyproject_dir: python/cugraph-service/client
diff --git a/python/cugraph-equivariant/LICENSE b/python/cugraph-equivariant/LICENSE
new file mode 120000
index 00000000000..30cff7403da
--- /dev/null
+++ b/python/cugraph-equivariant/LICENSE
+../../LICENSE
\ No newline at end of file
diff --git a/python/cugraph-equivariant/README.md b/python/cugraph-equivariant/README.md
new file mode 100644
index 00000000000..d5de8852709
--- /dev/null
+++ b/python/cugraph-equivariant/README.md
+# cugraph-equivariant
+
+## Description
+
+The cugraph-equivariant library provides fast symmetry-preserving (equivariant) operations and convolutional layers to accelerate equivariant neural networks in drug discovery and other domains.
diff --git a/python/cugraph-equivariant/cugraph_equivariant/VERSION b/python/cugraph-equivariant/cugraph_equivariant/VERSION
new file mode 120000
index 00000000000..d62dc733efd
--- /dev/null
+++ b/python/cugraph-equivariant/cugraph_equivariant/VERSION
+../../../VERSION
\ No newline at end of file
diff --git a/python/cugraph-equivariant/cugraph_equivariant/__init__.py b/python/cugraph-equivariant/cugraph_equivariant/__init__.py
new file mode 100644
index 00000000000..20507bd9329
--- /dev/null
+++ b/python/cugraph-equivariant/cugraph_equivariant/__init__.py
+# Copyright (c) 2024, NVIDIA CORPORATION.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +from cugraph_equivariant._version import __git_commit__, __version__ diff --git a/python/cugraph-equivariant/cugraph_equivariant/_version.py b/python/cugraph-equivariant/cugraph_equivariant/_version.py new file mode 100644 index 00000000000..31a707bb17e --- /dev/null +++ b/python/cugraph-equivariant/cugraph_equivariant/_version.py @@ -0,0 +1,27 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib.resources + +# Read VERSION file from the module that is symlinked to VERSION file +# in the root of the repo at build time or copied to the module at +# installation. VERSION is a separate file that allows CI build-time scripts +# to update version info (including commit hashes) without modifying +# source files. +__version__ = ( + importlib.resources.files("cugraph_equivariant") + .joinpath("VERSION") + .read_text() + .strip() +) +__git_commit__ = "" diff --git a/python/cugraph-equivariant/cugraph_equivariant/nn/__init__.py b/python/cugraph-equivariant/cugraph_equivariant/nn/__init__.py new file mode 100644 index 00000000000..8f4d8de0042 --- /dev/null +++ b/python/cugraph-equivariant/cugraph_equivariant/nn/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .tensor_product_conv import FullyConnectedTensorProductConv + +DiffDockTensorProductConv = FullyConnectedTensorProductConv + +__all__ = [ + "FullyConnectedTensorProductConv", + "DiffDockTensorProductConv", +] diff --git a/python/cugraph-equivariant/cugraph_equivariant/nn/tensor_product_conv.py b/python/cugraph-equivariant/cugraph_equivariant/nn/tensor_product_conv.py new file mode 100644 index 00000000000..5120a23180d --- /dev/null +++ b/python/cugraph-equivariant/cugraph_equivariant/nn/tensor_product_conv.py @@ -0,0 +1,259 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from typing import Optional, Sequence, Union
+
+import torch
+from torch import nn
+from e3nn import o3
+from e3nn.nn import BatchNorm
+
+from cugraph_equivariant.utils import scatter_reduce
+
+from pylibcugraphops.pytorch.operators import FusedFullyConnectedTensorProduct
+
+
+class FullyConnectedTensorProductConv(nn.Module):
+    r"""Message passing layer for tensor products in DiffDock-like architectures.
+    The left operand of the tensor product is the spherical harmonic representation
+    of the edge vector; the right operand consists of node features in irreps.
+
+    .. math::
+        \sum_{b \in \mathcal{N}_a} Y\left(\hat{r}_{a b}\right)
+        \otimes_{\psi_{a b}} \mathbf{h}_b
+
+    where the path weights :math:`\psi_{a b}` can be constructed from edge
+    embeddings and scalar features using an MLP:
+
+    .. math::
+        \psi_{a b} = \operatorname{MLP}
+        \left(e_{a b}, \mathbf{h}_a^0, \mathbf{h}_b^0\right)
+
+    Users have the option to either directly input the weights or provide the
+    MLP parameters and scalar features from edges and nodes.
+
+    Parameters
+    ----------
+    in_irreps : e3nn.o3.Irreps
+        Irreps for the input node features.
+
+    sh_irreps : e3nn.o3.Irreps
+        Irreps for the spherical harmonic representations of edge vectors.
+
+    out_irreps : e3nn.o3.Irreps
+        Irreps for the output.
+
+    batch_norm : bool, optional (default=True)
+        If `True`, batch normalization is applied.
+
+    mlp_channels : sequence of ints, optional (default=None)
+        A sequence of integers defining the number of neurons in each layer of
+        the MLP before the output layer. If `None`, no MLP will be added. The
+        input layer contains edge embeddings and node scalar features.
+
+    mlp_activation : nn.Module or sequence of nn.Module, optional (default=nn.GELU())
+        A sequence of functions to be applied between the linear layers of the
+        MLP, e.g., `nn.Sequential(nn.ReLU(), nn.Dropout(0.4))`.
+
+    e3nn_compat_mode : bool, optional (default=False)
+        cugraph-ops and e3nn use different memory layouts for Irreps-tensors:
+        the last (fastest moving) dimension is num_channels for cugraph-ops and
+        ir.dim for e3nn. When enabled, the input and output of this layer will
+        follow e3nn's memory layout.
+
+    Examples
+    --------
+    >>> # Case 1: MLP with the input layer having 6 channels and 2 hidden layers
+    >>> # having 16 channels. edge_emb.size(1) must match the size of
+    >>> # the input layer: 6
+    >>>
+    >>> conv1 = FullyConnectedTensorProductConv(in_irreps, sh_irreps, out_irreps,
+    >>>     mlp_channels=[6, 16, 16], mlp_activation=nn.ReLU()).cuda()
+    >>> out = conv1(src_features, edge_sh, edge_emb, graph)
+    >>>
+    >>> # Case 2: Same as case 1 but with the scalar features from edges, sources
+    >>> # and destinations passed in separately.
+    >>>
+    >>> conv2 = FullyConnectedTensorProductConv(in_irreps, sh_irreps, out_irreps,
+    >>>     mlp_channels=[6, 16, 16], mlp_activation=nn.ReLU()).cuda()
+    >>> out = conv2(src_features, edge_sh, edge_scalars, graph,
+    >>>     src_scalars=src_scalars, dst_scalars=dst_scalars)
+    >>>
+    >>> # Case 3: No MLP, edge_emb will be directly used as the tensor product weights
+    >>>
+    >>> conv3 = FullyConnectedTensorProductConv(in_irreps, sh_irreps, out_irreps,
+    >>>     mlp_channels=None).cuda()
+    >>> out = conv3(src_features, edge_sh, edge_emb, graph)
+
+    """
+
+    def __init__(
+        self,
+        in_irreps: o3.Irreps,
+        sh_irreps: o3.Irreps,
+        out_irreps: o3.Irreps,
+        batch_norm: bool = True,
+        mlp_channels: Optional[Sequence[int]] = None,
+        mlp_activation: Union[nn.Module, Sequence[nn.Module]] = nn.GELU(),
+        e3nn_compat_mode: bool = False,
+    ):
+        super().__init__()
+        self.in_irreps = in_irreps
+        self.out_irreps = out_irreps
+        self.sh_irreps = sh_irreps
+        self.e3nn_compat_mode = e3nn_compat_mode
+
+        self.tp = FusedFullyConnectedTensorProduct(
+            in_irreps, sh_irreps, out_irreps, e3nn_compat_mode=e3nn_compat_mode
+        )
+
+        self.batch_norm = BatchNorm(out_irreps) if batch_norm else None
+
+        if mlp_activation is None:
+            mlp_activation = []
+        elif hasattr(mlp_activation, "__len__") and hasattr(
+            mlp_activation, "__getitem__"
+        ):
+            mlp_activation = list(mlp_activation)
+        else:
+            mlp_activation = [mlp_activation]
+
+        if mlp_channels is not None:
+            dims = list(mlp_channels) + [self.tp.weight_numel]
+            mlp = []
+            for i in range(len(dims) - 1):
+                mlp.append(nn.Linear(dims[i], dims[i + 1]))
+                if i != len(dims) - 2:
+                    mlp.extend(mlp_activation)
+            self.mlp = nn.Sequential(*mlp)
+        else:
+            self.mlp = None
+
+    def forward(
+        self,
+        src_features: torch.Tensor,
+        edge_sh: torch.Tensor,
+        edge_emb: torch.Tensor,
+        graph: tuple[torch.Tensor, tuple[int, int]],
+        src_scalars: Optional[torch.Tensor] = None,
+        dst_scalars: Optional[torch.Tensor] = None,
+        reduce: str = "mean",
+        edge_envelope: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Forward pass.
+
+        Parameters
+        ----------
+        src_features : torch.Tensor
+            Source node features.
+            Shape: (num_src_nodes, in_irreps.dim)
+
+        edge_sh : torch.Tensor
+            The spherical harmonic representations of the edge vectors.
+            Shape: (num_edges, sh_irreps.dim)
+
+        edge_emb : torch.Tensor
+            Edge embeddings that are fed into MLPs to generate tensor product weights.
+            Shape: (num_edges, dim), where `dim` should be:
+            - `tp.weight_numel` when the layer does not contain MLPs.
+            - num_edge_scalars, with the sum of num_[edge/src/dst]_scalars being
+              mlp_channels[0]
+
+        graph : tuple
+            A tuple that stores the graph information, with the first element being
+            the adjacency matrix in COO, and the second element being its shape:
+            (num_src_nodes, num_dst_nodes).
+
+        src_scalars : torch.Tensor, optional
+            Scalar features of source nodes.
+            Shape: (num_src_nodes, num_src_scalars)
+
+        dst_scalars : torch.Tensor, optional
+            Scalar features of destination nodes.
+            Shape: (num_dst_nodes, num_dst_scalars)
+
+        reduce : str, optional (default="mean")
+            Reduction operator. Choose between "mean" and "sum".
+
+        edge_envelope : torch.Tensor, optional
+            Typically used as attenuation factors to fade out messages coming
+            from nodes close to the cutoff distance used to create the graph.
+            This is important to make the model smooth with respect to changes
+            in the nodes' coordinates.
+            Shape: (num_edges,)
+
+        Returns
+        -------
+        torch.Tensor
+            Output node features.
+ Shape: (num_dst_nodes, out_irreps.dim) + """ + edge_emb_size = edge_emb.size(-1) + src_scalars_size = 0 if src_scalars is None else src_scalars.size(-1) + dst_scalars_size = 0 if dst_scalars is None else dst_scalars.size(-1) + + if self.mlp is None: + if self.tp.weight_numel != edge_emb_size: + raise RuntimeError( + f"When MLP is not present, edge_emb's last dimension must " + f"equal tp.weight_numel (but got {edge_emb_size} and " + f"{self.tp.weight_numel})" + ) + else: + total_size = edge_emb_size + src_scalars_size + dst_scalars_size + if self.mlp[0].in_features != total_size: + raise RuntimeError( + f"The size of MLP's input layer ({self.mlp[0].in_features}) " + f"does not match the total number of scalar features from " + f"edge_emb, src_scalars and dst_scalars ({total_size})" + ) + + if reduce not in ["mean", "sum"]: + raise RuntimeError( + f"reduce argument must be either 'mean' or 'sum', got {reduce}." + ) + + (src, dst), (num_src_nodes, num_dst_nodes) = graph + + if self.mlp is not None: + if src_scalars is None and dst_scalars is None: + tp_weights = self.mlp(edge_emb) + else: + w_edge, w_src, w_dst = torch.split( + self.mlp[0].weight, + (edge_emb_size, src_scalars_size, dst_scalars_size), + dim=-1, + ) + tp_weights = edge_emb @ w_edge.T + self.mlp[0].bias + + if src_scalars is not None: + tp_weights += (src_scalars @ w_src.T)[src] + + if dst_scalars is not None: + tp_weights += (dst_scalars @ w_dst.T)[dst] + + tp_weights = self.mlp[1:](tp_weights) + else: + tp_weights = edge_emb + + out = self.tp(src_features[src], edge_sh, tp_weights) + + if edge_envelope is not None: + out = out * edge_envelope.view(-1, 1) + + out = scatter_reduce(out, dst, dim=0, dim_size=num_dst_nodes, reduce=reduce) + + if self.batch_norm: + out = self.batch_norm(out) + + return out diff --git a/python/cugraph-equivariant/cugraph_equivariant/tests/conftest.py b/python/cugraph-equivariant/cugraph_equivariant/tests/conftest.py new file mode 100644 index 00000000000..c7c6bad07db --- /dev/null +++ b/python/cugraph-equivariant/cugraph_equivariant/tests/conftest.py @@ -0,0 +1,31 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch + + +@pytest.fixture +def example_scatter_data(): + src_feat = torch.Tensor([3, 1, 0, 1, 1, 2]) + dst_indices = torch.Tensor([0, 1, 2, 2, 3, 1]) + + results = { + "sum": torch.Tensor([3.0, 3.0, 1.0, 1.0]), + "mean": torch.Tensor([3.0, 1.5, 0.5, 1.0]), + "prod": torch.Tensor([3.0, 2.0, 0.0, 1.0]), + "amax": torch.Tensor([3.0, 2.0, 1.0, 1.0]), + "amin": torch.Tensor([3.0, 1.0, 0.0, 1.0]), + } + + return src_feat, dst_indices, results diff --git a/python/cugraph-equivariant/cugraph_equivariant/tests/test_scatter.py b/python/cugraph-equivariant/cugraph_equivariant/tests/test_scatter.py new file mode 100644 index 00000000000..ff8048468ee --- /dev/null +++ b/python/cugraph-equivariant/cugraph_equivariant/tests/test_scatter.py @@ -0,0 +1,28 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. 
+# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch +from cugraph_equivariant.utils import scatter_reduce + + +@pytest.mark.parametrize("reduce", ["sum", "mean", "prod", "amax", "amin"]) +def test_scatter_reduce(example_scatter_data, reduce): + device = torch.device("cuda:0") + src, index, out_true = example_scatter_data + src = src.to(device) + index = index.to(device) + + out = scatter_reduce(src, index, dim=0, dim_size=None, reduce=reduce) + + assert torch.allclose(out.cpu(), out_true[reduce]) diff --git a/python/cugraph-equivariant/cugraph_equivariant/tests/test_tensor_product_conv.py b/python/cugraph-equivariant/cugraph_equivariant/tests/test_tensor_product_conv.py new file mode 100644 index 00000000000..a2a13b32cd2 --- /dev/null +++ b/python/cugraph-equivariant/cugraph_equivariant/tests/test_tensor_product_conv.py @@ -0,0 +1,115 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import pytest
+
+import torch
+from torch import nn
+from e3nn import o3
+from cugraph_equivariant.nn import FullyConnectedTensorProductConv
+
+device = torch.device("cuda:0")
+
+
+@pytest.mark.parametrize("e3nn_compat_mode", [True, False])
+@pytest.mark.parametrize("batch_norm", [True, False])
+@pytest.mark.parametrize(
+    "mlp_channels, mlp_activation, scalar_sizes",
+    [
+        [(30, 8, 8), nn.Sequential(nn.Dropout(0.3), nn.ReLU()), (15, 15, 0)],
+        [(7,), nn.GELU(), (2, 3, 2)],
+        [None, None, None],
+    ],
+)
+def test_tensor_product_conv_equivariance(
+    mlp_channels, mlp_activation, scalar_sizes, batch_norm, e3nn_compat_mode
+):
+    torch.manual_seed(12345)
+
+    in_irreps = o3.Irreps("10x0e + 10x1e")
+    out_irreps = o3.Irreps("20x0e + 10x1e")
+    sh_irreps = o3.Irreps.spherical_harmonics(lmax=2)
+
+    tp_conv = FullyConnectedTensorProductConv(
+        in_irreps=in_irreps,
+        sh_irreps=sh_irreps,
+        out_irreps=out_irreps,
+        mlp_channels=mlp_channels,
+        mlp_activation=mlp_activation,
+        batch_norm=batch_norm,
+        e3nn_compat_mode=e3nn_compat_mode,
+    ).to(device)
+
+    num_src_nodes, num_dst_nodes = 9, 7
+    num_edges = 40
+    src = torch.randint(num_src_nodes, (num_edges,), device=device)
+    dst = torch.randint(num_dst_nodes, (num_edges,), device=device)
+    edge_index = torch.vstack((src, dst))
+
+    src_pos = torch.randn(num_src_nodes, 3, device=device)
+    dst_pos = torch.randn(num_dst_nodes, 3, device=device)
+    edge_vec = dst_pos[dst] - src_pos[src]
+    edge_sh = o3.spherical_harmonics(
+        tp_conv.sh_irreps, edge_vec, normalize=True, normalization="component"
+    ).to(device)
+    src_features = torch.randn(num_src_nodes, in_irreps.dim, device=device)
+
+    rot = o3.rand_matrix()
+    D_in = tp_conv.in_irreps.D_from_matrix(rot).to(device)
+    D_sh = tp_conv.sh_irreps.D_from_matrix(rot).to(device)
+    D_out = tp_conv.out_irreps.D_from_matrix(rot).to(device)
+
+    if mlp_channels is None:
+        edge_emb = torch.randn(num_edges, tp_conv.tp.weight_numel, device=device)
+        src_scalars = dst_scalars = None
+    else:
+        if scalar_sizes:
+            edge_emb = torch.randn(num_edges, scalar_sizes[0], device=device)
+            src_scalars = (
+                None
+                if scalar_sizes[1] == 0
+                else torch.randn(num_src_nodes, scalar_sizes[1], device=device)
+            )
+            dst_scalars = (
+                None
+                if scalar_sizes[2] == 0
+                else torch.randn(num_dst_nodes, scalar_sizes[2], device=device)
+            )
+        else:
+            edge_emb = torch.randn(num_edges, tp_conv.mlp[0].in_features, device=device)
+            src_scalars = dst_scalars = None
+
+    # rotate before
+    out_before = tp_conv(
+        src_features=src_features @ D_in.T,
+        edge_sh=edge_sh @ D_sh.T,
+        edge_emb=edge_emb,
+        graph=(edge_index, (num_src_nodes, num_dst_nodes)),
+        src_scalars=src_scalars,
+        dst_scalars=dst_scalars,
+    )
+
+    # rotate after
+    out_after = (
+        tp_conv(
+            src_features=src_features,
+            edge_sh=edge_sh,
+            edge_emb=edge_emb,
+            graph=(edge_index, (num_src_nodes, num_dst_nodes)),
+            src_scalars=src_scalars,
+            dst_scalars=dst_scalars,
+        )
+        @ D_out.T
+    )
+
+    assert torch.allclose(out_before, out_after, rtol=1e-4, atol=1e-4)
diff --git a/python/cugraph-equivariant/cugraph_equivariant/utils/__init__.py b/python/cugraph-equivariant/cugraph_equivariant/utils/__init__.py
new file mode 100644
index 00000000000..b4acfe8d090
--- /dev/null
+++ b/python/cugraph-equivariant/cugraph_equivariant/utils/__init__.py
+# Copyright (c) 2024, NVIDIA CORPORATION.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .scatter import scatter_reduce + +__all__ = [ + "scatter_reduce", +] diff --git a/python/cugraph-equivariant/cugraph_equivariant/utils/scatter.py b/python/cugraph-equivariant/cugraph_equivariant/utils/scatter.py new file mode 100644 index 00000000000..45cc541fc7b --- /dev/null +++ b/python/cugraph-equivariant/cugraph_equivariant/utils/scatter.py @@ -0,0 +1,43 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch + + +def broadcast(src: torch.Tensor, ref: torch.Tensor, dim: int) -> torch.Tensor: + size = ((1,) * dim) + (-1,) + ((1,) * (ref.dim() - dim - 1)) + return src.view(size).expand_as(ref) + + +def scatter_reduce( + src: torch.Tensor, + index: torch.Tensor, + dim: int = 0, + dim_size: Optional[int] = None, # value of out.size(dim) + reduce: str = "sum", # "sum", "prod", "mean", "amax", "amin" +): + # scatter() expects index to be int64 + index = broadcast(index, src, dim).to(torch.int64) + + size = list(src.size()) + + if dim_size is not None: + assert dim_size >= int(index.max()) + 1 + size[dim] = dim_size + else: + size[dim] = int(index.max()) + 1 + + out = torch.zeros(size, dtype=src.dtype, device=src.device) + return out.scatter_reduce_(dim, index, src, reduce, include_self=False) diff --git a/python/cugraph-equivariant/pyproject.toml b/python/cugraph-equivariant/pyproject.toml new file mode 100644 index 00000000000..f261b0e3535 --- /dev/null +++ b/python/cugraph-equivariant/pyproject.toml @@ -0,0 +1,64 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +[build-system] +requires = [ + "setuptools>=61.0.0", + "wheel", +] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. + +[project] +name = "cugraph-equivariant" +dynamic = ["version"] +description = "Fast GPU-based equivariant operations and convolutional layers." 
+readme = { file = "README.md", content-type = "text/markdown" } +authors = [ + { name = "NVIDIA Corporation" }, +] +license = { text = "Apache 2.0" } +requires-python = ">=3.9" +classifiers = [ + "Intended Audience :: Developers", + "Programming Language :: Python", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", +] +dependencies = [ + "pylibcugraphops==24.2.*", +] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. + +[project.urls] +Homepage = "https://github.com/rapidsai/cugraph" +Documentation = "https://docs.rapids.ai/api/cugraph/stable/api_docs/cugraph-ops/" + +[project.optional-dependencies] +test = [ + "pandas", + "pytest", + "pytest-benchmark", + "pytest-cov", + "pytest-xdist", + "scipy", +] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. + +[tool.setuptools] +license-files = ["LICENSE"] + +[tool.setuptools.dynamic] +version = {file = "cugraph_equivariant/VERSION"} + +[tool.setuptools.packages.find] +include = [ + "cugraph_equivariant*", + "cugraph_equivariant.*", +] diff --git a/python/cugraph-equivariant/setup.py b/python/cugraph-equivariant/setup.py new file mode 100644 index 00000000000..acd0df3f717 --- /dev/null +++ b/python/cugraph-equivariant/setup.py @@ -0,0 +1,20 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from setuptools import find_packages, setup + +if __name__ == "__main__": + packages = find_packages(include=["cugraph_equivariant*"]) + setup( + package_data={key: ["VERSION"] for key in packages}, + )
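Reviewer note: the snippet below is a minimal usage sketch of the new layer, distilled from the class docstring and test_tensor_product_conv.py above; it is not part of the diff. It assumes a CUDA device plus the e3nn and pylibcugraphops dependencies that the CI scripts install, and all tensor sizes are made-up illustrative values.

import torch
from torch import nn
from e3nn import o3

from cugraph_equivariant.nn import FullyConnectedTensorProductConv

device = torch.device("cuda:0")

# Irreps copied from the equivariance test above; any valid irreps work.
in_irreps = o3.Irreps("10x0e + 10x1e")   # dim = 40
out_irreps = o3.Irreps("20x0e + 10x1e")  # dim = 50
sh_irreps = o3.Irreps.spherical_harmonics(lmax=2)

# "Case 1" from the docstring: the MLP input width (6) must equal
# edge_emb.size(1).
conv = FullyConnectedTensorProductConv(
    in_irreps, sh_irreps, out_irreps,
    mlp_channels=[6, 16, 16], mlp_activation=nn.ReLU(),
).to(device)

num_src_nodes, num_dst_nodes, num_edges = 9, 7, 40  # arbitrary sizes
src = torch.randint(num_src_nodes, (num_edges,), device=device)
dst = torch.randint(num_dst_nodes, (num_edges,), device=device)
# graph = (COO edge index, (num_src_nodes, num_dst_nodes))
graph = (torch.vstack((src, dst)), (num_src_nodes, num_dst_nodes))

src_features = torch.randn(num_src_nodes, in_irreps.dim, device=device)
edge_vec = torch.randn(num_edges, 3, device=device)
edge_sh = o3.spherical_harmonics(
    sh_irreps, edge_vec, normalize=True, normalization="component"
)
edge_emb = torch.randn(num_edges, 6, device=device)

out = conv(src_features, edge_sh, edge_emb, graph)
print(out.shape)  # torch.Size([7, 50]) == (num_dst_nodes, out_irreps.dim)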
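Likewise, a small worked example of the scatter_reduce helper's semantics, reusing the inputs from the example_scatter_data fixture; a sketch only. It runs on CPU, since the helper just wraps torch.Tensor.scatter_reduce_ with include_self=False.

import torch

from cugraph_equivariant.utils import scatter_reduce

src = torch.tensor([3.0, 1.0, 0.0, 1.0, 1.0, 2.0])
index = torch.tensor([0, 1, 2, 2, 3, 1])

# Rows of `src` are grouped by destination index and reduced:
#   index 0 -> {3}, index 1 -> {1, 2}, index 2 -> {0, 1}, index 3 -> {1}
out = scatter_reduce(src, index, dim=0, dim_size=None, reduce="mean")
print(out)  # tensor([3.0000, 1.5000, 0.5000, 1.0000])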