
add IDGNN #2

Merged — 3 commits merged into master on Jul 22, 2024

Conversation

@yiweny (Contributor) commented Jul 10, 2024

Add the IDGNN model and test cases.
A slight difference is that we are not using ResNet as the encoder, only the StypeWiseEncoder.

@yiweny yiweny force-pushed the yyuan/add-id-gnn branch 11 times, most recently from c708da0 to c9444a2 Compare July 14, 2024 04:44
@yiweny yiweny force-pushed the yyuan/add-id-gnn branch from 58c2467 to ac1f221 Compare July 14, 2024 06:22
@yiweny yiweny changed the title wip add encoders wip add IDGNN Jul 18, 2024
@yiweny yiweny changed the title wip add IDGNN add IDGNN Jul 20, 2024
hybridgnn/nn/encoder.py — outdated, resolved
hybridgnn/nn/encoder.py — outdated, resolved
tf_dict: Dict[NodeType, torch_frame.TensorFrame],
) -> Dict[NodeType, Tensor]:
x_dict = {
node_type: self.encoders[node_type](tf)[0].mean(axis=1)
Member:

Why do we take the mean here? Maybe sum is a little better, or we could use ResNet etc.

@yiweny (Contributor, Author) replied Jul 20, 2024:

Do we really want to include ResNet? I want to compare four versions: IDGNN, IDGNN with ResNet, HybridGNN, and HybridGNN with ResNet.

Member:

ResNet is for single tables, right? If we don't want to use ResNet, an MLP (https://github.com/pyg-team/pytorch-frame/blob/d81b8f7a9e0643fa553d7cb7a1343ef662fd6835/torch_frame/nn/models/mlp.py#L28) or sum may be a better choice.

Contributor (Author):

Okay, I changed it to sum. Do you know why it's better to use sum? Is it because it's the best-performing one in Kumo?
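For context, the pooling being debated above can be sketched as a standalone function. This is a hypothetical sketch (the function name and shapes are assumptions); in the PR, this kind of pooling is applied to the per-column output of the StypeWiseEncoder:

```python
import torch

def pool_columns(x: torch.Tensor, aggr: str = "sum") -> torch.Tensor:
    """Pool per-column embeddings of shape [num_rows, num_cols, channels]
    into a single row embedding of shape [num_rows, channels]."""
    if aggr == "sum":
        return x.sum(dim=1)
    if aggr == "mean":
        return x.mean(dim=1)
    if aggr == "max":
        return x.max(dim=1).values
    raise ValueError(f"Unsupported aggregation '{aggr}'")

x = torch.ones(4, 3, 8)  # 4 rows, 3 columns, 8 channels each
print(pool_columns(x, "sum").shape)  # torch.Size([4, 8])
```

With sum, columns that carry little signal can still contribute magnitude; with mean, the result is invariant to the number of columns. Which works better is an empirical question, which is why the thread below suggests making it configurable.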

@yiweny yiweny requested a review from zechengz July 20, 2024 21:16
@yiweny yiweny force-pushed the yyuan/add-id-gnn branch from 357d288 to 3e67440 Compare July 20, 2024 21:39
@yiweny yiweny force-pushed the yyuan/add-id-gnn branch from 3e67440 to 7fed12d Compare July 20, 2024 21:40
@andyhuang-kumo (Contributor) left a review:

See comments, mostly minor suggestions.

hybridgnn/nn/encoder.py — outdated, resolved
tf_dict: Dict[NodeType, torch_frame.TensorFrame],
) -> Dict[NodeType, Tensor]:
x_dict = {
node_type: self.encoders[node_type](tf)[0].sum(axis=1)
Contributor:

Is sum the only aggregation method here? If not, maybe provide it as an argument so it is easy to change later on.

Contributor (Author):

I fixed the code to support more advanced aggregations.
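The suggested fix amounts to lifting the aggregation into a constructor argument. A minimal sketch of that pattern (class name and the set of supported aggregations are assumptions, not the PR's actual code):

```python
import torch
from torch import nn

class PooledEncoder(nn.Module):
    """Expose the column aggregation as a constructor argument
    instead of hard-coding `.sum(dim=1)` in forward()."""
    def __init__(self, aggr: str = "sum"):
        super().__init__()
        if aggr not in ("sum", "mean", "max"):
            raise ValueError(f"Unsupported aggregation '{aggr}'")
        self.aggr = aggr

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [num_rows, num_cols, channels] per-column embeddings
        if self.aggr == "max":
            return x.max(dim=1).values
        return getattr(x, self.aggr)(dim=1)
```

Validating the argument in `__init__` rather than in `forward` surfaces a typo at construction time instead of mid-training.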

hybridgnn/nn/encoder.py — outdated, resolved
(channels, channels), channels, aggr=aggr)
for edge_type in edge_types
},
aggr="sum",
Contributor:

The input argument has aggr="mean", but here aggr is hard-coded.

Contributor (Author):

I think it's intended to use sum here? cc @zechengz

Member:

I think we can use sum here for now. The aggr="mean" is used for the SAGEConv neighbor aggregation. Here the aggr has a different meaning: it aggregates the embeddings produced by different edge types for the same node type (if I remember correctly).
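The distinction can be illustrated with plain tensors. This sketch mimics what HeteroConv's outer `aggr` does, which operates on top of each SAGEConv's own neighbor-level `aggr` (the function and variable names here are illustrative, not the PR's code):

```python
import torch

def combine_edge_type_outputs(outs, aggr: str = "sum") -> torch.Tensor:
    # HeteroConv-style `aggr`: combine messages arriving at the SAME
    # destination node type from DIFFERENT edge types. This is separate
    # from the `aggr` inside each SAGEConv, which pools a node's neighbors
    # within a single edge type.
    stacked = torch.stack(outs, dim=0)  # [num_edge_types, num_nodes, channels]
    return stacked.sum(dim=0) if aggr == "sum" else stacked.mean(dim=0)

out_a = torch.ones(5, 8)         # e.g. one edge type's output for a node type
out_b = torch.full((5, 8), 2.0)  # another edge type's output for the same type
merged = combine_edge_type_outputs([out_a, out_b], aggr="sum")  # all 3.0
```

So hard-coding "sum" at this level does not conflict with passing aggr="mean" to the per-edge-type convolutions: the two knobs control different aggregation steps.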

) -> Tensor:
seed_time = batch[entity_table].seed_time
x_dict = self.encoder(batch.tf_dict)
# Add ID-awareness to the root node
Contributor:

So a standard GNN is basically just IDGNN without this id_awareness_emb? These classes could then be reused for a standard GNN without ID awareness just by making this optional, maybe as an argument?
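The suggestion above can be sketched as follows: the ID-awareness embedding is a learned vector added only to the seed (root) nodes, and making it optional recovers a plain GNN. This is a hypothetical, simplified module (the class name, mask-based interface, and shapes are assumptions):

```python
import torch
from torch import nn

class RootEncoder(nn.Module):
    """Optionally add a learned ID-awareness embedding to seed nodes.

    With id_awareness=False, this reduces to the identity, i.e. a
    standard (non-ID-aware) GNN encoder stage."""
    def __init__(self, channels: int, id_awareness: bool = True):
        super().__init__()
        self.id_awareness_emb = (
            nn.Embedding(1, channels) if id_awareness else None
        )

    def forward(self, x: torch.Tensor, seed_mask: torch.Tensor) -> torch.Tensor:
        # x: [num_nodes, channels]; seed_mask: [num_nodes] bool
        if self.id_awareness_emb is not None:
            x = x.clone()
            x[seed_mask] = x[seed_mask] + self.id_awareness_emb.weight
        return x
```

Because the flag only toggles whether the embedding exists, one class can serve both the IDGNN and standard-GNN variants the reviewer describes.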

test/nn/test_model.py — resolved
batch = next(iter(train_loader))

assert len(batch[task.dst_entity_table].batch) > 0
model = IDGNN(data=data, col_stats_dict=col_stats_dict, num_layers=2,
Contributor:

Again, the aggr is hard-coded in IDGNN, so this sum here is actually not used.

@yiweny yiweny merged commit a3c0583 into master Jul 22, 2024
2 checks passed
@akihironitta akihironitta deleted the yyuan/add-id-gnn branch July 29, 2024 10:20
4 participants