diff --git a/python/cugraph-equivariant/cugraph_equivariant/nn/tensor_product_conv.py b/python/cugraph-equivariant/cugraph_equivariant/nn/tensor_product_conv.py
index 5a67fbe1502..923edbfc44a 100644
--- a/python/cugraph-equivariant/cugraph_equivariant/nn/tensor_product_conv.py
+++ b/python/cugraph-equivariant/cugraph_equivariant/nn/tensor_product_conv.py
@@ -11,7 +11,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Sequence, Union
+from typing import Optional, Sequence, Union, NamedTuple
 
 import torch
 from torch import nn
@@ -31,6 +31,11 @@
     ) from exc
 
 
+class Graph(NamedTuple):
+    edge_index: torch.Tensor
+    size: tuple[int, int]
+
+
 class FullyConnectedTensorProductConv(nn.Module):
     r"""Message passing layer for tensor products in DiffDock-like architectures.
     The left operand of tensor product is the spherical harmonic representation
@@ -81,27 +86,35 @@ class FullyConnectedTensorProductConv(nn.Module):
 
     Examples
     --------
-    >>> # Case 1: MLP with the input layer having 6 channels and 2 hidden layers
-    >>> # having 16 channels. edge_emb.size(1) must match the size of
-    >>> # the input layer: 6
-    >>>
+    Case 1: MLP with the input layer having 6 channels and 2 hidden layers
+    having 16 channels. edge_emb.size(1) must match the size of the input layer: 6
+
     >>> conv1 = FullyConnectedTensorProductConv(in_irreps, sh_irreps, out_irreps,
     >>>     mlp_channels=[6, 16, 16], mlp_activation=nn.ReLU()).cuda()
     >>> out = conv1(src_features, edge_sh, edge_emb, graph)
-    >>>
-    >>> # Case 2: Same as case 1 but with the scalar features from edges, sources
-    >>> # and destinations passed in separately.
-    >>>
+
+    Case 2: If `edge_emb` is constructed by concatenating scalar features from
+    edges, sources and destinations, as in DiffDock, the layer can accept each
+    scalar component separately:
+
     >>> conv2 = FullyConnectedTensorProductConv(in_irreps, sh_irreps, out_irreps,
     >>>     mlp_channels=[6, 16, 16], mlp_activation=nn.ReLU()).cuda()
-    >>> out = conv3(src_features, edge_sh, edge_scalars, graph,
+    >>> out = conv2(src_features, edge_sh, edge_scalars, graph,
    >>>     src_scalars=src_scalars, dst_scalars=dst_scalars)
-    >>>
-    >>> # Case 3: No MLP, edge_emb will be directly used as the tensor product weights
-    >>>
+
+    This allows a smaller GEMM in the first MLP layer by performing GEMM on each
+    component before indexing. The first-layer weights are split into sections
+    for edges, sources and destinations, in that order. This is equivalent to
+
+    >>> src, dst = graph.edge_index
+    >>> edge_emb = torch.hstack((edge_scalars, src_scalars[src], dst_scalars[dst]))
+    >>> out = conv2(src_features, edge_sh, edge_emb, graph)
+
+    Case 3: No MLP, `edge_emb` will be directly used as the tensor product weights:
+
     >>> conv3 = FullyConnectedTensorProductConv(in_irreps, sh_irreps, out_irreps,
     >>>     mlp_channels=None).cuda()
-    >>> out = conv2(src_features, edge_sh, edge_emb, graph)
+    >>> out = conv3(src_features, edge_sh, edge_emb, graph)
 
     """
 
@@ -174,8 +187,8 @@ def forward(
             Edge embeddings that are fed into MLPs to generate tensor product weights.
             Shape: (num_edges, dim), where `dim` should be:
             - `tp.weight_numel` when the layer does not contain MLPs.
-            - num_edge_scalars, with the sum of num_[edge/src/dst]_scalars being
-              mlp_channels[0]
+            - num_edge_scalars, when scalar features from edges, sources and
+              destinations are passed in separately.
 
         graph : tuple
             A tuple that stores the graph information, with the first element being
@@ -183,11 +196,11 @@
             (num_src_nodes, num_dst_nodes).
 
         src_scalars: torch.Tensor, optional
-            Scalar features of source nodes.
+            Scalar features of source nodes. See examples for usage.
             Shape: (num_src_nodes, num_src_scalars)
 
         dst_scalars: torch.Tensor, optional
-            Scalar features of destination nodes.
+            Scalar features of destination nodes. See examples for usage.
             Shape: (num_dst_nodes, num_dst_scalars)
 
         reduce : str, optional (default="mean")
diff --git a/python/cugraph-equivariant/cugraph_equivariant/tests/test_tensor_product_conv.py b/python/cugraph-equivariant/cugraph_equivariant/tests/test_tensor_product_conv.py
index 7fbab1dc934..ce325c47aa0 100644
--- a/python/cugraph-equivariant/cugraph_equivariant/tests/test_tensor_product_conv.py
+++ b/python/cugraph-equivariant/cugraph_equivariant/tests/test_tensor_product_conv.py
@@ -13,10 +13,6 @@
 
 import pytest
 
-import torch
-from torch import nn
-from e3nn import o3
-
 try:
     from cugraph_equivariant.nn import FullyConnectedTensorProductConv
 except RuntimeError:
@@ -25,9 +21,29 @@
         allow_module_level=True,
     )
 
-device = torch.device("cuda:0")
+import torch
+from torch import nn
+from e3nn import o3
+from cugraph_equivariant.nn.tensor_product_conv import Graph
+
+device = torch.device("cuda")
 
 
+def create_random_graph(
+    num_src_nodes,
+    num_dst_nodes,
+    num_edges,
+    dtype=None,
+    device=None,
+):
+    row = torch.randint(num_src_nodes, (num_edges,), dtype=dtype, device=device)
+    col = torch.randint(num_dst_nodes, (num_edges,), dtype=dtype, device=device)
+    edge_index = torch.stack([row, col], dim=0)
+
+    return Graph(edge_index, (num_src_nodes, num_dst_nodes))
+
+
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("e3nn_compat_mode", [True, False])
 @pytest.mark.parametrize("batch_norm", [True, False])
 @pytest.mark.parametrize(
@@ -39,9 +55,10 @@
     ],
 )
 def test_tensor_product_conv_equivariance(
-    mlp_channels, mlp_activation, scalar_sizes, batch_norm, e3nn_compat_mode
+    mlp_channels, mlp_activation, scalar_sizes, batch_norm, e3nn_compat_mode, dtype
 ):
     torch.manual_seed(12345)
+    to_kwargs = {"device": device, "dtype": dtype}
 
     in_irreps = o3.Irreps("10x0e + 10x1e")
     out_irreps = o3.Irreps("20x0e + 10x1e")
@@ -55,68 +72,65 @@ def test_tensor_product_conv_equivariance(
         mlp_activation=mlp_activation,
         batch_norm=batch_norm,
         e3nn_compat_mode=e3nn_compat_mode,
-    ).to(device)
+    ).to(**to_kwargs)
 
     num_src_nodes, num_dst_nodes = 9, 7
     num_edges = 40
-    src = torch.randint(num_src_nodes, (num_edges,), device=device)
-    dst = torch.randint(num_dst_nodes, (num_edges,), device=device)
-    edge_index = torch.vstack((src, dst))
-
-    src_pos = torch.randn(num_src_nodes, 3, device=device)
-    dst_pos = torch.randn(num_dst_nodes, 3, device=device)
-    edge_vec = dst_pos[dst] - src_pos[src]
-    edge_sh = o3.spherical_harmonics(
-        tp_conv.sh_irreps, edge_vec, normalize=True, normalization="component"
-    ).to(device)
-    src_features = torch.randn(num_src_nodes, in_irreps.dim, device=device)
+    graph = create_random_graph(num_src_nodes, num_dst_nodes, num_edges, device=device)
+
+    edge_sh = torch.randn(num_edges, sh_irreps.dim, **to_kwargs)
+    src_features = torch.randn(num_src_nodes, in_irreps.dim, **to_kwargs)
 
     rot = o3.rand_matrix()
-    D_in = tp_conv.in_irreps.D_from_matrix(rot).to(device)
-    D_sh = tp_conv.sh_irreps.D_from_matrix(rot).to(device)
-    D_out = tp_conv.out_irreps.D_from_matrix(rot).to(device)
+    D_in = tp_conv.in_irreps.D_from_matrix(rot).to(**to_kwargs)
+    D_sh = tp_conv.sh_irreps.D_from_matrix(rot).to(**to_kwargs)
+    D_out = tp_conv.out_irreps.D_from_matrix(rot).to(**to_kwargs)
 
     if mlp_channels is None:
-        edge_emb = torch.randn(num_edges, tp_conv.tp.weight_numel, device=device)
+        edge_emb = torch.randn(num_edges, tp_conv.tp.weight_numel, **to_kwargs)
         src_scalars = dst_scalars = None
     else:
         if scalar_sizes:
-            edge_emb = torch.randn(num_edges, scalar_sizes[0], device=device)
+            edge_emb = torch.randn(num_edges, scalar_sizes[0], **to_kwargs)
             src_scalars = (
                 None
                 if scalar_sizes[1] == 0
-                else torch.randn(num_src_nodes, scalar_sizes[1], device=device)
+                else torch.randn(num_src_nodes, scalar_sizes[1], **to_kwargs)
             )
             dst_scalars = (
                 None
                 if scalar_sizes[2] == 0
-                else torch.randn(num_dst_nodes, scalar_sizes[2], device=device)
+                else torch.randn(num_dst_nodes, scalar_sizes[2], **to_kwargs)
             )
         else:
-            edge_emb = torch.randn(num_edges, tp_conv.mlp[0].in_features, device=device)
+            edge_emb = torch.randn(num_edges, tp_conv.mlp[0].in_features, **to_kwargs)
             src_scalars = dst_scalars = None
 
     # rotate before
+    torch.manual_seed(12345)
     out_before = tp_conv(
         src_features=src_features @ D_in.T,
         edge_sh=edge_sh @ D_sh.T,
         edge_emb=edge_emb,
-        graph=(edge_index, (num_src_nodes, num_dst_nodes)),
+        graph=graph,
         src_scalars=src_scalars,
         dst_scalars=dst_scalars,
     )
 
     # rotate after
+    torch.manual_seed(12345)
     out_after = (
         tp_conv(
             src_features=src_features,
             edge_sh=edge_sh,
             edge_emb=edge_emb,
-            graph=(edge_index, (num_src_nodes, num_dst_nodes)),
+            graph=graph,
             src_scalars=src_scalars,
             dst_scalars=dst_scalars,
         )
         @ D_out.T
     )
 
-    torch.allclose(out_before, out_after, rtol=1e-4, atol=1e-4)
+    atol = 1e-3 if dtype == torch.float32 else 1e-1
+    if e3nn_compat_mode:
+        assert torch.allclose(out_before, out_after, rtol=1e-4, atol=atol)
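
The Case 2 equivalence documented in the new docstring can be exercised end to end with a short script. The sketch below is illustrative only and not part of the diff: the sh_irreps choice, the scalar widths (2 each, summing to mlp_channels[0] = 6), batch_norm=False, and the tolerance are assumptions; it requires a CUDA build of cugraph-equivariant plus e3nn.

# Minimal sketch (assumptions noted above); compares the split-scalar call path
# against the concatenated edge_emb call path described in the docstring.
import torch
from torch import nn
from e3nn import o3
from cugraph_equivariant.nn import FullyConnectedTensorProductConv
from cugraph_equivariant.nn.tensor_product_conv import Graph

device = torch.device("cuda")

in_irreps = o3.Irreps("10x0e + 10x1e")
sh_irreps = o3.Irreps.spherical_harmonics(2)   # illustrative choice
out_irreps = o3.Irreps("20x0e + 10x1e")

# mlp_channels[0] must equal num_edge_scalars + num_src_scalars + num_dst_scalars (2 + 2 + 2).
conv = FullyConnectedTensorProductConv(
    in_irreps, sh_irreps, out_irreps,
    mlp_channels=[6, 16, 16], mlp_activation=nn.ReLU(), batch_norm=False,
).to(device)

num_src_nodes, num_dst_nodes, num_edges = 9, 7, 40
edge_index = torch.stack(
    [
        torch.randint(num_src_nodes, (num_edges,), device=device),
        torch.randint(num_dst_nodes, (num_edges,), device=device),
    ],
    dim=0,
)
graph = Graph(edge_index, (num_src_nodes, num_dst_nodes))

src_features = torch.randn(num_src_nodes, in_irreps.dim, device=device)
edge_sh = torch.randn(num_edges, sh_irreps.dim, device=device)
edge_scalars = torch.randn(num_edges, 2, device=device)
src_scalars = torch.randn(num_src_nodes, 2, device=device)
dst_scalars = torch.randn(num_dst_nodes, 2, device=device)

# Case 2: pass the scalar components separately ...
out_split = conv(src_features, edge_sh, edge_scalars, graph,
                 src_scalars=src_scalars, dst_scalars=dst_scalars)

# ... which the docstring states is equivalent to concatenating them into edge_emb.
src, dst = graph.edge_index
edge_emb = torch.hstack((edge_scalars, src_scalars[src], dst_scalars[dst]))
out_concat = conv(src_features, edge_sh, edge_emb, graph)

print(torch.allclose(out_split, out_concat, atol=1e-5))  # expected: True (up to float error)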