Skip to content

Commit

Permalink
Update python/cugraph-equivariant/cugraph_equivariant/tests/test_tens…
Browse files Browse the repository at this point in the history
…or_product_conv.py

Co-authored-by: Mario Geiger <[email protected]>
  • Loading branch information
tingyu66 and mariogeiger authored Jan 24, 2024
1 parent c0abf8c commit 40b972c
Showing 1 changed file with 3 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ def test_tensor_product_conv_equivariance(

# rotate before
out_before = tp_conv(
src_features=src_features @ D_in,
edge_sh=edge_sh @ D_sh,
src_features=src_features @ D_in.T,
edge_sh=edge_sh @ D_sh.T,
edge_emb=edge_emb,
graph=(edge_index, (num_src_nodes, num_dst_nodes)),
src_scalars=src_scalars,
Expand All @@ -109,7 +109,7 @@ def test_tensor_product_conv_equivariance(
src_scalars=src_scalars,
dst_scalars=dst_scalars,
)
@ D_out
@ D_out.T
)

torch.allclose(out_before, out_after, rtol=1e-4, atol=1e-4)

0 comments on commit 40b972c

Please sign in to comment.