From 924e55b6316909c3c94b14def90782327aad9d0c Mon Sep 17 00:00:00 2001
From: Tingyu Wang
Date: Thu, 18 Jan 2024 18:29:57 -0500
Subject: [PATCH] update based on cugraph-ops pr

---
 .../nn/tensor_product_conv.py                 | 54 +++++++++++++++----
 .../tests/test_tensor_product_conv.py         | 32 +++++------
 2 files changed, 61 insertions(+), 25 deletions(-)

diff --git a/python/cugraph-equivariant/cugraph_equivariant/nn/tensor_product_conv.py b/python/cugraph-equivariant/cugraph_equivariant/nn/tensor_product_conv.py
index 46f1eea019f..b6a10d738a9 100644
--- a/python/cugraph-equivariant/cugraph_equivariant/nn/tensor_product_conv.py
+++ b/python/cugraph-equivariant/cugraph_equivariant/nn/tensor_product_conv.py
@@ -20,12 +20,11 @@
 
 from cugraph_equivariant.utils import scatter_reduce
 
-try:
-    from pylibcugraphops.equivariant import TensorProduct
-
-    HAS_TP_LIB = True
-except ImportError:
-    HAS_TP_LIB = False
+from pylibcugraphops.pytorch.operators import (
+    FusedFullyConnectedTensorProduct,
+    transpose_irrep_to_m_last,
+    transpose_irrep_to_channels_last,
+)
 
 
 class FullyConnectedTensorProductConv(nn.Module):
@@ -73,6 +72,33 @@ class FullyConnectedTensorProductConv(nn.Module):
         leading to a lower complexity in most use cases. This option requires
         users to explicitly pass in `src_scalars` and `dst_scalars` in
         `forward()` call.
+
+    use_e3nn_tp: bool, optional (default=False)
+        If `True`, use e3nn's TensorProduct instead of the fused cugraph-ops one.
+
+    Examples
+    --------
+    >>> # Case 1: MLP with the input layer having 6 channels and 2 hidden layers
+    >>> # having 16 channels. edge_emb.size(1) must match the size of
+    >>> # the input layer: 6
+    >>>
+    >>> conv1 = FullyConnectedTensorProductConv(in_irreps, sh_irreps, out_irreps,
+    ...     mlp_channels=[6, 16, 16], mlp_activation=nn.ReLU).cuda()
+    >>> out = conv1(src_features, edge_sh, edge_emb, graph)
+    >>>
+    >>> # Case 2: No MLP; edge_emb is used directly as the tensor product weights
+    >>>
+    >>> conv2 = FullyConnectedTensorProductConv(in_irreps, sh_irreps, out_irreps,
+    ...     mlp_channels=None).cuda()
+    >>> out = conv2(src_features, edge_sh, edge_emb, graph)
+    >>>
+    >>> # Case 3: Same as Case 1, but with `mlp_fast_first_layer=True`. The scalar
+    >>> # features from edges, sources and destinations must be passed in separately.
+    >>>
+    >>> conv3 = FullyConnectedTensorProductConv(in_irreps, sh_irreps, out_irreps,
+    ...     mlp_channels=[6, 16, 16], mlp_fast_first_layer=True).cuda()
+    >>> out = conv3(src_features, edge_sh, edge_scalars, graph,
+    ...     src_scalars=src_scalars, dst_scalars=dst_scalars)
     """
 
     def __init__(
@@ -84,18 +110,20 @@ def __init__(
         mlp_channels: Optional[Sequence[int]] = None,
         mlp_activation: Optional[Callable[..., nn.Module]] = nn.GELU,
         mlp_fast_first_layer: bool = False,
+        use_e3nn_tp: bool = False,
     ):
         super().__init__()
         self.in_irreps = in_irreps
         self.out_irreps = out_irreps
         self.sh_irreps = sh_irreps
 
-        if HAS_TP_LIB:
-            self.tp = TensorProduct(str(in_irreps), str(sh_irreps), str(out_irreps))
-        else:
+        if use_e3nn_tp:
             self.tp = o3.FullyConnectedTensorProduct(
                 in_irreps, sh_irreps, out_irreps, shared_weights=False
             )
+        else:
+            self.tp = FusedFullyConnectedTensorProduct(in_irreps, sh_irreps, out_irreps)
+        self.use_e3nn_tp = use_e3nn_tp
 
         self.batch_norm = BatchNorm(out_irreps) if batch_norm else None
 
@@ -204,7 +232,13 @@ def forward(
         else:
             tp_weights = edge_emb
 
-        out = self.tp(src_features[src], edge_sh, tp_weights)
+        if not self.use_e3nn_tp:
+            out = self.tp(src_features[src], edge_sh, tp_weights)
+        else:
+            src_features = transpose_irrep_to_m_last(src_features, self.in_irreps)
+            edge_sh = transpose_irrep_to_m_last(edge_sh, self.sh_irreps)
+            out = self.tp(src_features[src], edge_sh, tp_weights)
+            out = transpose_irrep_to_channels_last(out, self.out_irreps)
 
         if edge_envelope is not None:
             out = out * edge_envelope.view(-1, 1)

diff --git a/python/cugraph-equivariant/cugraph_equivariant/tests/test_tensor_product_conv.py b/python/cugraph-equivariant/cugraph_equivariant/tests/test_tensor_product_conv.py
index 3e9c0ca3df6..5dc90fe3f0d 100644
--- a/python/cugraph-equivariant/cugraph_equivariant/tests/test_tensor_product_conv.py
+++ b/python/cugraph-equivariant/cugraph_equivariant/tests/test_tensor_product_conv.py
@@ -17,6 +17,8 @@
 from e3nn import o3
 from cugraph_equivariant.nn import FullyConnectedTensorProductConv
 
+device = torch.device("cuda:0")
+
 
 @pytest.mark.parametrize(
     "mlp_channels, mlp_fast_first_layer",
@@ -35,37 +37,37 @@ def test_tensor_product_conv_equivariance(mlp_channels, mlp_fast_first_layer):
         out_irreps=out_irreps,
         mlp_channels=mlp_channels,
         mlp_fast_first_layer=mlp_fast_first_layer,
-    )
+    ).to(device)
 
     num_src_nodes, num_dst_nodes = 9, 7
     num_edges = 40
-    src = torch.randint(num_src_nodes, (num_edges,))
-    dst = torch.randint(num_dst_nodes, (num_edges,))
+    src = torch.randint(num_src_nodes, (num_edges,), device=device)
+    dst = torch.randint(num_dst_nodes, (num_edges,), device=device)
     edge_index = torch.vstack((src, dst))
 
-    src_pos = torch.randn(num_src_nodes, 3)
-    dst_pos = torch.randn(num_dst_nodes, 3)
+    src_pos = torch.randn(num_src_nodes, 3, device=device)
+    dst_pos = torch.randn(num_dst_nodes, 3, device=device)
     edge_vec = dst_pos[dst] - src_pos[src]
     edge_sh = o3.spherical_harmonics(
         tp_conv.sh_irreps, edge_vec, normalize=True, normalization="component"
-    )
-    src_features = torch.randn(num_src_nodes, in_irreps.dim)
+    ).to(device)
+    src_features = torch.randn(num_src_nodes, in_irreps.dim, device=device)
 
     rot = o3.rand_matrix()
-    D_in = tp_conv.in_irreps.D_from_matrix(rot)
-    D_sh = tp_conv.sh_irreps.D_from_matrix(rot)
-    D_out = tp_conv.out_irreps.D_from_matrix(rot)
+    D_in = tp_conv.in_irreps.D_from_matrix(rot).to(device)
+    D_sh = tp_conv.sh_irreps.D_from_matrix(rot).to(device)
+    D_out = tp_conv.out_irreps.D_from_matrix(rot).to(device)
 
     if mlp_channels is None:
-        edge_emb = torch.randn(num_edges, tp_conv.tp.weight_numel)
+        edge_emb = torch.randn(num_edges, tp_conv.tp.weight_numel, device=device)
         src_scalars = dst_scalars = None
     else:
         if mlp_fast_first_layer:
-            edge_emb = torch.randn(num_edges, tp_conv.num_scalars)
-            src_scalars = torch.randn(num_src_nodes, tp_conv.num_scalars)
-            dst_scalars = torch.randn(num_dst_nodes, tp_conv.num_scalars)
+            edge_emb = torch.randn(num_edges, tp_conv.num_scalars, device=device)
+            src_scalars = torch.randn(num_src_nodes, tp_conv.num_scalars, device=device)
+            dst_scalars = torch.randn(num_dst_nodes, tp_conv.num_scalars, device=device)
         else:
-            edge_emb = torch.randn(num_edges, tp_conv.mlp[0].in_features)
+            edge_emb = torch.randn(num_edges, tp_conv.mlp[0].in_features, device=device)
             src_scalars = dst_scalars = None
 
     # rotate before
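
Usage sketch (reviewer note, not part of the patch): a minimal end-to-end driver
for the updated layer, assembled from the docstring examples and the test setup
above. The irreps strings, graph sizes, and the
(edge_index, (num_src_nodes, num_dst_nodes)) form of the `graph` argument are
illustrative assumptions; only the constructor and forward signatures, the new
`use_e3nn_tp` flag, and `tp.weight_numel` come from this diff.

    import torch
    from e3nn import o3
    from cugraph_equivariant.nn import FullyConnectedTensorProductConv

    device = torch.device("cuda:0")

    # Assumed irreps, chosen only to make the sketch self-contained.
    in_irreps = o3.Irreps("10x0e + 10x1e")
    sh_irreps = o3.Irreps.spherical_harmonics(lmax=2)
    out_irreps = o3.Irreps("20x0e + 10x1e")

    # Default path uses the fused cugraph-ops tensor product; pass
    # use_e3nn_tp=True to fall back to e3nn's implementation instead.
    conv = FullyConnectedTensorProductConv(
        in_irreps, sh_irreps, out_irreps, mlp_channels=None
    ).to(device)

    num_src_nodes, num_dst_nodes, num_edges = 9, 7, 40
    src = torch.randint(num_src_nodes, (num_edges,), device=device)
    dst = torch.randint(num_dst_nodes, (num_edges,), device=device)
    # Assumption: `graph` packs the edge index together with the node counts,
    # mirroring how the test builds edge_index.
    graph = (torch.vstack((src, dst)), (num_src_nodes, num_dst_nodes))

    src_features = torch.randn(num_src_nodes, in_irreps.dim, device=device)
    edge_vec = torch.randn(num_edges, 3, device=device)
    edge_sh = o3.spherical_harmonics(
        sh_irreps, edge_vec, normalize=True, normalization="component"
    )
    # With mlp_channels=None, edge_emb is consumed directly as the tensor
    # product weights (Case 2 in the docstring).
    edge_emb = torch.randn(num_edges, conv.tp.weight_numel, device=device)

    out = conv(src_features, edge_sh, edge_emb, graph)

With `use_e3nn_tp=True` the same call takes the e3nn path: forward() transposes
the inputs with transpose_irrep_to_m_last before the tensor product and converts
the result back with transpose_irrep_to_channels_last, as the
`@@ -204,7 +232,13 @@` hunk shows.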