diff --git a/python/cugraph-pyg/cugraph_pyg/nn/conv/hetero_gat_conv.py b/python/cugraph-pyg/cugraph_pyg/nn/conv/hetero_gat_conv.py index 41303b973a0..08c51400e7f 100644 --- a/python/cugraph-pyg/cugraph_pyg/nn/conv/hetero_gat_conv.py +++ b/python/cugraph-pyg/cugraph_pyg/nn/conv/hetero_gat_conv.py @@ -172,13 +172,11 @@ def split_tensors( x_src_dict[src_rel] = t_list[i] for i, dst_rel in enumerate(self.relations_per_ntype[ntype][1]): - # src_type, _, dst_type = dst_rel.split("__") - # if src_type != dst_type: x_dst_dict[dst_rel] = t_list[i + n_src_rel] return x_src_dict, x_dst_dict - def reset_parameters(self, seed: Optional[int] = None) -> None: + def reset_parameters(self, seed: Optional[int] = None): if seed is not None: torch.manual_seed(seed)