add IDGNN #2
Conversation
Force-pushed from c708da0 to c9444a2
hybridgnn/nn/encoder.py
Outdated
tf_dict: Dict[NodeType, torch_frame.TensorFrame],
) -> Dict[NodeType, Tensor]:
x_dict = {
    node_type: self.encoders[node_type](tf)[0].mean(axis=1)
Why do we take the mean here? Maybe sum is a little bit better. Or we could use a ResNet, etc.
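For context, the encoder output being reduced here has shape `[num_rows, num_cols, channels]`, so the question is which reduction to apply over the column dimension. A minimal sketch of the two options (shapes are illustrative, not taken from the PR):

```python
import torch

# Simulated per-column encoder output: [num_rows, num_cols, channels],
# standing in for the first element returned by self.encoders[node_type](tf).
num_rows, num_cols, channels = 4, 3, 8
x = torch.randn(num_rows, num_cols, channels)

# Option in the diff above: mean over the column dimension.
x_mean = x.mean(dim=1)  # shape: [num_rows, channels]

# Alternative suggested in this thread: sum over the column dimension.
# Unlike mean, sum scales with the number of columns in the table.
x_sum = x.sum(dim=1)  # shape: [num_rows, channels]

assert x_mean.shape == x_sum.shape == (num_rows, channels)
```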
Do we really want to include ResNet?
I want to compare 4 versions: IDGNN, IDGNN with ResNet, HybridGNN, HybridGNN with ResNet.
ResNet is for a single table, right? If we don't want to use ResNet, MLP https://github.com/pyg-team/pytorch-frame/blob/d81b8f7a9e0643fa553d7cb7a1343ef662fd6835/torch_frame/nn/models/mlp.py#L28 or sum may be a better choice.
Okay, I changed it to sum. Do you know why it's better to use sum? Is it because it's the best performing one in Kumo?
See comments; mostly minor suggestions.
hybridgnn/nn/encoder.py
Outdated
tf_dict: Dict[NodeType, torch_frame.TensorFrame],
) -> Dict[NodeType, Tensor]:
x_dict = {
    node_type: self.encoders[node_type](tf)[0].sum(axis=1)
Is sum the only aggregation method here? If not, maybe provide it as an argument so it is easy to change later on.
I fixed the code to support more advanced aggregations.
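One lightweight way to make the reduction configurable, as the review asks for, is to dispatch on an `aggr` string (the helper name and supported set here are illustrative, not the actual code in this PR):

```python
import torch
from torch import Tensor


def aggregate_columns(x: Tensor, aggr: str = "sum") -> Tensor:
    """Hypothetical helper: reduce a [num_rows, num_cols, channels]
    encoder output over the column dimension with a chosen aggregation,
    instead of hard-coding .sum(axis=1)."""
    if aggr == "sum":
        return x.sum(dim=1)
    if aggr == "mean":
        return x.mean(dim=1)
    if aggr == "max":
        return x.max(dim=1).values
    raise ValueError(f"Unsupported aggregation: {aggr}")


x = torch.randn(4, 3, 8)
for aggr in ("sum", "mean", "max"):
    assert aggregate_columns(x, aggr).shape == (4, 8)
```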
(channels, channels), channels, aggr=aggr)
for edge_type in edge_types
},
aggr="sum",
The input argument has aggr="mean", but here aggr is hard-coded.
I think it's intended to use sum here? cc @zechengz
I think we can use sum here for now. The aggr="mean" is used for the SAGEConv aggregation. Here the aggr has a different meaning: it aggregates the embeddings for the same node type together (if I remember correctly).
) -> Tensor:
seed_time = batch[entity_table].seed_time
x_dict = self.encoder(batch.tf_dict)
# Add ID-awareness to the root node
So a standard GNN is basically just IDGNN without this id_awareness_emb? This class could then be reused as a standard GNN without ID awareness just by making it optional, maybe as an argument?
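If that refactor were made, the ID-awareness term could be toggled by a constructor flag, something like this sketch (the class, method, and argument names are hypothetical, not the PR's actual API):

```python
import torch
from torch import nn


class ToyIDGNN(nn.Module):
    """Sketch: with id_awareness=False this degenerates to a standard GNN
    input, since no seed-node embedding is added."""

    def __init__(self, channels: int, id_awareness: bool = True):
        super().__init__()
        self.id_awareness_emb = (
            nn.Embedding(1, channels) if id_awareness else None
        )

    def mark_seed_nodes(self, x: torch.Tensor, seed_idx: torch.Tensor):
        # Add a learned embedding only to the seed (root) nodes.
        if self.id_awareness_emb is not None:
            x = x.clone()
            x[seed_idx] = x[seed_idx] + self.id_awareness_emb.weight[0]
        return x


x = torch.zeros(6, 4)
seed = torch.tensor([0, 2])
plain = ToyIDGNN(4, id_awareness=False).mark_seed_nodes(x, seed)
aware = ToyIDGNN(4, id_awareness=True).mark_seed_nodes(x, seed)
assert torch.equal(plain, x)       # standard-GNN path: unchanged
assert not torch.equal(aware, x)   # ID-aware path: seed rows shifted
```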
batch = next(iter(train_loader))
assert len(batch[task.dst_entity_table].batch) > 0
model = IDGNN(data=data, col_stats_dict=col_stats_dict, num_layers=2,
Again, the aggr is hard-coded in IDGNN, so this sum here is actually not used.
Add IDGNN model and test cases.
A slight difference is that we are not using ResNet as the encoder, only the StypeWiseEncoder.