diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 542e9cacb77..a314b8c7185 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -52,7 +52,7 @@ repos: - id: rapids-dependency-file-generator args: ["--clean"] - repo: https://github.com/rapidsai/pre-commit-hooks - rev: v0.0.1 + rev: v0.0.3 hooks: - id: verify-copyright files: | diff --git a/python/cugraph-equivariant/cugraph_equivariant/nn/tensor_product_conv.py b/python/cugraph-equivariant/cugraph_equivariant/nn/tensor_product_conv.py index 5120a23180d..af1d0efa76c 100644 --- a/python/cugraph-equivariant/cugraph_equivariant/nn/tensor_product_conv.py +++ b/python/cugraph-equivariant/cugraph_equivariant/nn/tensor_product_conv.py @@ -251,7 +251,10 @@ def forward( if edge_envelope is not None: out = out * edge_envelope.view(-1, 1) - out = scatter_reduce(out, dst, dim=0, dim_size=num_dst_nodes, reduce=reduce) + dtype = out.dtype + out = scatter_reduce( + out.float(), dst, dim=0, dim_size=num_dst_nodes, reduce=reduce + ).to(dtype) if self.batch_norm: out = self.batch_norm(out)