Fix TensorProductConv test and improve docs (#4480)
Closes #4459

Authors:
  - Tingyu Wang (https://github.com/tingyu66)
  - Ralph Liu (https://github.com/nv-rliu)

Approvers:
  - https://github.com/DejunL
  - Rick Ratzel (https://github.com/rlratzel)

URL: #4480
tingyu66 authored Jul 9, 2024
1 parent 2e969da commit 407cdab
Showing 2 changed files with 74 additions and 47 deletions.
@@ -11,7 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Sequence, Union
from typing import Optional, Sequence, Union, NamedTuple

import torch
from torch import nn
@@ -31,6 +31,11 @@
) from exc


class Graph(NamedTuple):
edge_index: torch.Tensor
size: tuple[int, int]
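
For illustration, a minimal sketch of constructing this container by hand (the toy edge list and node counts below are made up; the test's create_random_graph helper builds the same structure with random indices):

# hypothetical toy graph: 4 source nodes, 2 destination nodes, 3 edges
edge_index = torch.tensor([[0, 1, 3],   # source node of each edge
                           [0, 1, 1]])  # destination node of each edge
graph = Graph(edge_index=edge_index, size=(4, 2))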


class FullyConnectedTensorProductConv(nn.Module):
r"""Message passing layer for tensor products in DiffDock-like architectures.
The left operand of tensor product is the spherical harmonic representation
@@ -81,27 +86,35 @@ class FullyConnectedTensorProductConv(nn.Module):
Examples
--------
>>> # Case 1: MLP with the input layer having 6 channels and 2 hidden layers
>>> # having 16 channels. edge_emb.size(1) must match the size of
>>> # the input layer: 6
>>>
Case 1: MLP with the input layer having 6 channels and 2 hidden layers
having 16 channels. edge_emb.size(1) must match the size of the input layer: 6
>>> conv1 = FullyConnectedTensorProductConv(in_irreps, sh_irreps, out_irreps,
>>> mlp_channels=[6, 16, 16], mlp_activation=nn.ReLU()).cuda()
>>> out = conv1(src_features, edge_sh, edge_emb, graph)
>>>
>>> # Case 2: Same as case 1 but with the scalar features from edges, sources
>>> # and destinations passed in separately.
>>>
Case 2: If `edge_emb` is constructed by concatenating scalar features from
edges, sources and destinations, as in DiffDock, the layer can accept each
scalar component separately:
>>> conv2 = FullyConnectedTensorProductConv(in_irreps, sh_irreps, out_irreps,
>>> mlp_channels=[6, 16, 16], mlp_activation=nn.ReLU()).cuda()
>>> out = conv3(src_features, edge_sh, edge_scalars, graph,
>>> out = conv2(src_features, edge_sh, edge_scalars, graph,
>>> src_scalars=src_scalars, dst_scalars=dst_scalars)
>>>
>>> # Case 3: No MLP, edge_emb will be directly used as the tensor product weights
>>>
This allows a smaller GEMM in the first MLP layer by performing GEMM on each
component before indexing. The first-layer weights are split into sections
for edges, sources and destinations, in that order. This is equivalent to
>>> src, dst = graph.edge_index
>>> edge_emb = torch.hstack((edge_scalars, src_scalars[src], dst_scalars[dst]))
>>> out = conv2(src_features, edge_sh, edge_emb, graph)
Case 3: No MLP, `edge_emb` will be directly used as the tensor product weights:
>>> conv3 = FullyConnectedTensorProductConv(in_irreps, sh_irreps, out_irreps,
>>> mlp_channels=None).cuda()
>>> out = conv2(src_features, edge_sh, edge_emb, graph)
>>> out = conv3(src_features, edge_sh, edge_emb, graph)
"""

@@ -174,20 +187,20 @@ def forward(
Edge embeddings that are fed into MLPs to generate tensor product weights.
Shape: (num_edges, dim), where `dim` should be:
- `tp.weight_numel` when the layer does not contain MLPs.
- num_edge_scalars, with the sum of num_[edge/src/dst]_scalars being
mlp_channels[0]
- num_edge_scalars, when scalar features from edges, sources and
destinations are passed in separately.
graph : tuple
A tuple that stores the graph information, with the first element being
the adjacency matrix in COO, and the second element being its shape:
(num_src_nodes, num_dst_nodes).
src_scalars: torch.Tensor, optional
Scalar features of source nodes.
Scalar features of source nodes. See examples for usage.
Shape: (num_src_nodes, num_src_scalars)
dst_scalars: torch.Tensor, optional
Scalar features of destination nodes.
Scalar features of destination nodes. See examples for usage.
Shape: (num_dst_nodes, num_dst_scalars)
reduce : str, optional (default="mean")
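
For illustration, a minimal sketch of choosing a compatible `edge_emb` width for each mode described above (the `tp_conv` instance, `num_edges`, and `num_edge_scalars` are assumed to be in scope; it mirrors the test below):

# no MLP: edge_emb is consumed directly as the tensor product weights
edge_emb = torch.randn(num_edges, tp_conv.tp.weight_numel, device="cuda")

# with an MLP and a single concatenated embedding: the width must match the MLP's input layer
edge_emb = torch.randn(num_edges, tp_conv.mlp[0].in_features, device="cuda")

# with an MLP and per-component scalars: edge_emb carries only the edge scalars,
# while src_scalars / dst_scalars are passed as separate forward() arguments
edge_emb = torch.randn(num_edges, num_edge_scalars, device="cuda")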
@@ -13,10 +13,6 @@

import pytest

import torch
from torch import nn
from e3nn import o3

try:
from cugraph_equivariant.nn import FullyConnectedTensorProductConv
except RuntimeError:
@@ -25,9 +21,29 @@
allow_module_level=True,
)

device = torch.device("cuda:0")
import torch
from torch import nn
from e3nn import o3
from cugraph_equivariant.nn.tensor_product_conv import Graph

device = torch.device("cuda")


def create_random_graph(
num_src_nodes,
num_dst_nodes,
num_edges,
dtype=None,
device=None,
):
row = torch.randint(num_src_nodes, (num_edges,), dtype=dtype, device=device)
col = torch.randint(num_dst_nodes, (num_edges,), dtype=dtype, device=device)
edge_index = torch.stack([row, col], dim=0)

return Graph(edge_index, (num_src_nodes, num_dst_nodes))
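
A small usage sketch of this helper (the sizes are arbitrary):

g = create_random_graph(num_src_nodes=9, num_dst_nodes=7, num_edges=40, device=device)
src, dst = g.edge_index   # two (num_edges,) index tensors, as in graph.edge_index above
assert g.size == (9, 7)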


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("e3nn_compat_mode", [True, False])
@pytest.mark.parametrize("batch_norm", [True, False])
@pytest.mark.parametrize(
@@ -39,9 +55,10 @@
],
)
def test_tensor_product_conv_equivariance(
mlp_channels, mlp_activation, scalar_sizes, batch_norm, e3nn_compat_mode
mlp_channels, mlp_activation, scalar_sizes, batch_norm, e3nn_compat_mode, dtype
):
torch.manual_seed(12345)
to_kwargs = {"device": device, "dtype": dtype}

in_irreps = o3.Irreps("10x0e + 10x1e")
out_irreps = o3.Irreps("20x0e + 10x1e")
@@ -55,68 +72,65 @@ def test_tensor_product_conv_equivariance(
mlp_activation=mlp_activation,
batch_norm=batch_norm,
e3nn_compat_mode=e3nn_compat_mode,
).to(device)
).to(**to_kwargs)

num_src_nodes, num_dst_nodes = 9, 7
num_edges = 40
src = torch.randint(num_src_nodes, (num_edges,), device=device)
dst = torch.randint(num_dst_nodes, (num_edges,), device=device)
edge_index = torch.vstack((src, dst))

src_pos = torch.randn(num_src_nodes, 3, device=device)
dst_pos = torch.randn(num_dst_nodes, 3, device=device)
edge_vec = dst_pos[dst] - src_pos[src]
edge_sh = o3.spherical_harmonics(
tp_conv.sh_irreps, edge_vec, normalize=True, normalization="component"
).to(device)
src_features = torch.randn(num_src_nodes, in_irreps.dim, device=device)
graph = create_random_graph(num_src_nodes, num_dst_nodes, num_edges, device=device)

edge_sh = torch.randn(num_edges, sh_irreps.dim, **to_kwargs)
src_features = torch.randn(num_src_nodes, in_irreps.dim, **to_kwargs)

rot = o3.rand_matrix()
D_in = tp_conv.in_irreps.D_from_matrix(rot).to(device)
D_sh = tp_conv.sh_irreps.D_from_matrix(rot).to(device)
D_out = tp_conv.out_irreps.D_from_matrix(rot).to(device)
D_in = tp_conv.in_irreps.D_from_matrix(rot).to(**to_kwargs)
D_sh = tp_conv.sh_irreps.D_from_matrix(rot).to(**to_kwargs)
D_out = tp_conv.out_irreps.D_from_matrix(rot).to(**to_kwargs)

if mlp_channels is None:
edge_emb = torch.randn(num_edges, tp_conv.tp.weight_numel, device=device)
edge_emb = torch.randn(num_edges, tp_conv.tp.weight_numel, **to_kwargs)
src_scalars = dst_scalars = None
else:
if scalar_sizes:
edge_emb = torch.randn(num_edges, scalar_sizes[0], device=device)
edge_emb = torch.randn(num_edges, scalar_sizes[0], **to_kwargs)
src_scalars = (
None
if scalar_sizes[1] == 0
else torch.randn(num_src_nodes, scalar_sizes[1], device=device)
else torch.randn(num_src_nodes, scalar_sizes[1], **to_kwargs)
)
dst_scalars = (
None
if scalar_sizes[2] == 0
else torch.randn(num_dst_nodes, scalar_sizes[2], device=device)
else torch.randn(num_dst_nodes, scalar_sizes[2], **to_kwargs)
)
else:
edge_emb = torch.randn(num_edges, tp_conv.mlp[0].in_features, device=device)
edge_emb = torch.randn(num_edges, tp_conv.mlp[0].in_features, **to_kwargs)
src_scalars = dst_scalars = None

# rotate before
torch.manual_seed(12345)
out_before = tp_conv(
src_features=src_features @ D_in.T,
edge_sh=edge_sh @ D_sh.T,
edge_emb=edge_emb,
graph=(edge_index, (num_src_nodes, num_dst_nodes)),
graph=graph,
src_scalars=src_scalars,
dst_scalars=dst_scalars,
)

# rotate after
torch.manual_seed(12345)
out_after = (
tp_conv(
src_features=src_features,
edge_sh=edge_sh,
edge_emb=edge_emb,
graph=(edge_index, (num_src_nodes, num_dst_nodes)),
graph=graph,
src_scalars=src_scalars,
dst_scalars=dst_scalars,
)
@ D_out.T
)

torch.allclose(out_before, out_after, rtol=1e-4, atol=1e-4)
atol = 1e-3 if dtype == torch.float32 else 1e-1
if e3nn_compat_mode:
assert torch.allclose(out_before, out_after, rtol=1e-4, atol=atol)
