Commit 924e55b
update based on cugraph-ops pr
tingyu66 committed Jan 18, 2024
1 parent 3fd56dc commit 924e55b
Showing 2 changed files with 61 additions and 25 deletions.
@@ -20,12 +20,11 @@

 from cugraph_equivariant.utils import scatter_reduce
 
-try:
-    from pylibcugraphops.equivariant import TensorProduct
-
-    HAS_TP_LIB = True
-except ImportError:
-    HAS_TP_LIB = False
+from pylibcugraphops.pytorch.operators import (
+    FusedFullyConnectedTensorProduct,
+    transpose_irrep_to_m_last,
+    transpose_irrep_to_channels_last,
+)
 
 
 class FullyConnectedTensorProductConv(nn.Module):
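With the guarded import gone, pylibcugraphops becomes a hard dependency of this module: the fused tensor-product operator and the two irrep-layout helpers are imported unconditionally. A minimal sketch of how a downstream package might surface a clearer error when the dependency is missing; the version bound is an assumption, not something stated in this commit:

import importlib.util

if importlib.util.find_spec("pylibcugraphops") is None:
    raise ImportError(
        "cugraph-equivariant requires pylibcugraphops (assumed >= 24.02) "
        "for FusedFullyConnectedTensorProduct and the irrep layout helpers."
    )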
@@ -73,6 +72,33 @@ class FullyConnectedTensorProductConv(nn.Module):
         leading to a lower complexity in most use cases. This option requires
         users to explicitly pass in `src_scalars` and `dst_scalars` in
         `forward()` call.
+    use_e3nn_tp: bool, optional (default=False)
+        If `True`, use TensorProduct functions from e3nn.
+
+    Examples
+    --------
+    >>> # Case 1: MLP with the input layer having 6 channels and 2 hidden layers
+    >>> # having 16 channels. edge_emb.size(1) must match the size of
+    >>> # the input layer: 6
+    >>>
+    >>> conv1 = FullyConnectedTensorProductConv(in_irreps, sh_irreps, out_irreps,
+    >>>     mlp_channels=[6, 16, 16], mlp_activation=nn.ReLU).cuda()
+    >>> out = conv1(src_features, edge_sh, edge_emb, graph)
+    >>>
+    >>> # Case 2: No MLP, edge_emb will be directly used as the tensor product weights
+    >>>
+    >>> conv2 = FullyConnectedTensorProductConv(in_irreps, sh_irreps, out_irreps,
+    >>>     mlp_channels=None).cuda()
+    >>> out = conv2(src_features, edge_sh, edge_emb, graph)
+    >>>
+    >>> # Case 3: Same as case 1 but with `mlp_fast_first_layer=True`. The scalar features
+    >>> # from edges, sources and destinations have to be passed in separately.
+    >>>
+    >>> conv3 = FullyConnectedTensorProductConv(in_irreps, sh_irreps, out_irreps,
+    >>>     mlp_channels=[6, 16, 16], mlp_fast_first_layer=True).cuda()
+    >>> out = conv3(src_features, edge_sh, edge_scalars, graph,
+    >>>     src_scalars=src_scalars, dst_scalars=dst_scalars)
     """
 
     def __init__(
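The docstring examples above assume that the irreps and input tensors already exist. Below is a minimal setup sketch using e3nn; the specific irreps strings, tensor shapes, and the `(edge_index, (num_src_nodes, num_dst_nodes))` graph format are illustrative assumptions rather than values fixed by this commit:

import torch
from e3nn import o3

# Assumed irreps; any valid e3nn irreps would do.
in_irreps = o3.Irreps("8x0e + 8x1o")           # node feature irreps
sh_irreps = o3.Irreps.spherical_harmonics(2)   # edge irreps: 1x0e + 1x1o + 1x2e
out_irreps = o3.Irreps("8x0e + 8x1o")          # output irreps

num_src, num_dst, num_edges = 9, 7, 40
src = torch.randint(num_src, (num_edges,), device="cuda")
dst = torch.randint(num_dst, (num_edges,), device="cuda")
graph = (torch.vstack((src, dst)), (num_src, num_dst))  # assumed (edge_index, size) format

src_features = torch.randn(num_src, in_irreps.dim, device="cuda")
edge_vec = torch.randn(num_edges, 3, device="cuda")
edge_sh = o3.spherical_harmonics(
    sh_irreps, edge_vec, normalize=True, normalization="component"
)
edge_emb = torch.randn(num_edges, 6, device="cuda")  # 6 matches the first MLP channel in Case 1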
@@ -84,18 +110,20 @@ def __init__(
         mlp_channels: Optional[Sequence[int]] = None,
         mlp_activation: Optional[Callable[..., nn.Module]] = nn.GELU,
         mlp_fast_first_layer: bool = False,
+        use_e3nn_tp: bool = False,
     ):
         super().__init__()
         self.in_irreps = in_irreps
         self.out_irreps = out_irreps
         self.sh_irreps = sh_irreps
 
-        if HAS_TP_LIB:
-            self.tp = TensorProduct(str(in_irreps), str(sh_irreps), str(out_irreps))
-        else:
+        if use_e3nn_tp:
             self.tp = o3.FullyConnectedTensorProduct(
                 in_irreps, sh_irreps, out_irreps, shared_weights=False
             )
+        else:
+            self.tp = FusedFullyConnectedTensorProduct(in_irreps, sh_irreps, out_irreps)
+        self.use_e3nn_tp = use_e3nn_tp
 
         self.batch_norm = BatchNorm(out_irreps) if batch_norm else None
 
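The constructor now selects the backend explicitly through `use_e3nn_tp` instead of the removed import-time `HAS_TP_LIB` check: the default builds the fused cugraph-ops operator, while `use_e3nn_tp=True` keeps e3nn's implementation. A short sketch, reusing the names from the setup sketch above:

from cugraph_equivariant.nn import FullyConnectedTensorProductConv

conv_fused = FullyConnectedTensorProductConv(in_irreps, sh_irreps, out_irreps).cuda()
conv_e3nn = FullyConnectedTensorProductConv(
    in_irreps, sh_irreps, out_irreps, use_e3nn_tp=True
).cuda()

print(type(conv_fused.tp).__name__)  # FusedFullyConnectedTensorProduct (pylibcugraphops)
print(type(conv_e3nn.tp).__name__)   # FullyConnectedTensorProduct (e3nn.o3)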
@@ -204,7 +232,13 @@ def forward(
         else:
             tp_weights = edge_emb
 
-        out = self.tp(src_features[src], edge_sh, tp_weights)
+        if not self.use_e3nn_tp:
+            out = self.tp(src_features[src], edge_sh, tp_weights)
+        else:
+            src_features = transpose_irrep_to_m_last(src_features, self.in_irreps)
+            edge_sh = transpose_irrep_to_m_last(edge_sh, self.sh_irreps)
+            out = self.tp(src_features[src], edge_sh, tp_weights)
+            out = transpose_irrep_to_channels_last(out, self.out_irreps)
 
         if edge_envelope is not None:
             out = out * edge_envelope.view(-1, 1)
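In `forward()`, both branches evaluate the same per-edge tensor product `self.tp(src_features[src], edge_sh, tp_weights)`; the e3nn branch merely brackets it with `transpose_irrep_to_m_last` / `transpose_irrep_to_channels_last` because the two backends expect different irrep memory layouts. A Case-2 style end-to-end sketch under the assumptions introduced above (no MLP, so `edge_emb` supplies the tensor-product weights directly; the output shape is the expected per-destination-node result under those assumptions):

conv = FullyConnectedTensorProductConv(
    in_irreps, sh_irreps, out_irreps, mlp_channels=None
).cuda()
tp_weights = torch.randn(num_edges, conv.tp.weight_numel, device="cuda")
out = conv(src_features, edge_sh, tp_weights, graph)  # graph format assumed above
assert out.shape == (num_dst, out_irreps.dim)

The second changed file is the equivariance test, which now places the model, inputs, and Wigner matrices on the GPU.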
@@ -17,6 +17,8 @@
 from e3nn import o3
 from cugraph_equivariant.nn import FullyConnectedTensorProductConv
 
+device = torch.device("cuda:0")
+
 
 @pytest.mark.parametrize(
     "mlp_channels, mlp_fast_first_layer",
@@ -35,37 +37,37 @@ def test_tensor_product_conv_equivariance(mlp_channels, mlp_fast_first_layer):
         out_irreps=out_irreps,
         mlp_channels=mlp_channels,
         mlp_fast_first_layer=mlp_fast_first_layer,
-    )
+    ).to(device)
 
     num_src_nodes, num_dst_nodes = 9, 7
     num_edges = 40
-    src = torch.randint(num_src_nodes, (num_edges,))
-    dst = torch.randint(num_dst_nodes, (num_edges,))
+    src = torch.randint(num_src_nodes, (num_edges,), device=device)
+    dst = torch.randint(num_dst_nodes, (num_edges,), device=device)
     edge_index = torch.vstack((src, dst))
 
-    src_pos = torch.randn(num_src_nodes, 3)
-    dst_pos = torch.randn(num_dst_nodes, 3)
+    src_pos = torch.randn(num_src_nodes, 3, device=device)
+    dst_pos = torch.randn(num_dst_nodes, 3, device=device)
     edge_vec = dst_pos[dst] - src_pos[src]
     edge_sh = o3.spherical_harmonics(
         tp_conv.sh_irreps, edge_vec, normalize=True, normalization="component"
-    )
-    src_features = torch.randn(num_src_nodes, in_irreps.dim)
+    ).to(device)
+    src_features = torch.randn(num_src_nodes, in_irreps.dim, device=device)
 
     rot = o3.rand_matrix()
-    D_in = tp_conv.in_irreps.D_from_matrix(rot)
-    D_sh = tp_conv.sh_irreps.D_from_matrix(rot)
-    D_out = tp_conv.out_irreps.D_from_matrix(rot)
+    D_in = tp_conv.in_irreps.D_from_matrix(rot).to(device)
+    D_sh = tp_conv.sh_irreps.D_from_matrix(rot).to(device)
+    D_out = tp_conv.out_irreps.D_from_matrix(rot).to(device)
 
     if mlp_channels is None:
-        edge_emb = torch.randn(num_edges, tp_conv.tp.weight_numel)
+        edge_emb = torch.randn(num_edges, tp_conv.tp.weight_numel, device=device)
         src_scalars = dst_scalars = None
     else:
         if mlp_fast_first_layer:
-            edge_emb = torch.randn(num_edges, tp_conv.num_scalars)
-            src_scalars = torch.randn(num_src_nodes, tp_conv.num_scalars)
-            dst_scalars = torch.randn(num_dst_nodes, tp_conv.num_scalars)
+            edge_emb = torch.randn(num_edges, tp_conv.num_scalars, device=device)
+            src_scalars = torch.randn(num_src_nodes, tp_conv.num_scalars, device=device)
+            dst_scalars = torch.randn(num_dst_nodes, tp_conv.num_scalars, device=device)
         else:
-            edge_emb = torch.randn(num_edges, tp_conv.mlp[0].in_features)
+            edge_emb = torch.randn(num_edges, tp_conv.mlp[0].in_features, device=device)
             src_scalars = dst_scalars = None
 
     # rotate before
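The remainder of the test is unchanged and elided by the diff. The `# rotate before` comment marks the usual equivariance check: rotate the inputs with the Wigner matrices, run the layer, and compare against rotating the output instead. A sketch of that pattern with the tensors defined above; the exact forward-call format and assertions in the committed test are not shown, so the calls below are assumptions:

graph = (edge_index, (num_src_nodes, num_dst_nodes))  # assumed graph format

# Rotate inputs, then apply the layer.
out_rot_input = tp_conv(
    src_features @ D_in.T, edge_sh @ D_sh.T, edge_emb, graph,
    src_scalars=src_scalars, dst_scalars=dst_scalars,
)
# Apply the layer, then rotate the output.
out_rot_output = tp_conv(
    src_features, edge_sh, edge_emb, graph,
    src_scalars=src_scalars, dst_scalars=dst_scalars,
) @ D_out.T

assert torch.allclose(out_rot_input, out_rot_output, rtol=1e-4, atol=1e-4)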
