
Commit

Xiang Song committed Mar 19, 2024
1 parent 6fc805a commit ae2f232
Showing 3 changed files with 14 additions and 3 deletions.
9 changes: 8 additions & 1 deletion python/graphstorm/gsf.py
@@ -584,7 +584,8 @@ def set_encoder(model, g, config, train_task):
                                   num_ffn_layers_in_gnn=config.num_ffn_layers_in_gnn)
     elif model_encoder_type == "sage":
         # we need to check if the graph is homogeneous
-        assert check_homo(g) == True, 'The graph is not a homogeneous graph'
+        assert check_homo(g) == True, \
+            'The graph is not a homogeneous graph, can not use sage model encoder'
         # we need to set the num_layers -1 because there is an output layer that is hard coded.
         gnn_encoder = SAGEEncoder(h_dim=config.hidden_size,
                                   out_dim=config.hidden_size,
@@ -594,13 +595,19 @@ def set_encoder(model, g, config, train_task):
                                   num_ffn_layers_in_gnn=config.num_ffn_layers_in_gnn,
                                   norm=config.gnn_norm)
     elif model_encoder_type == "gat":
+        # we need to check if the graph is homogeneous
+        assert check_homo(g) == True, \
+            'The graph is not a homogeneous graph, can not use gat model encoder'
         gnn_encoder = GATEncoder(h_dim=config.hidden_size,
                                  out_dim=config.hidden_size,
                                  num_heads=config.num_heads,
                                  num_hidden_layers=config.num_layers -1,
                                  dropout=dropout,
                                  num_ffn_layers_in_gnn=config.num_ffn_layers_in_gnn)
     elif model_encoder_type == "gatv2":
+        # we need to check if the graph is homogeneous
+        assert check_homo(g) == True, \
+            'The graph is not a homogeneous graph, can not use gatv2 model encoder'
         gnn_encoder = GATv2Encoder(h_dim=config.hidden_size,
                                    out_dim=config.hidden_size,
                                    num_heads=config.num_heads,
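
Note on the new asserts: check_homo itself is not part of this diff, but the check amounts to requiring a graph with a single node type and a single edge type. Below is a minimal sketch under that assumption; is_homogeneous is a hypothetical stand-in, not GraphStorm's actual check_homo implementation.

# Hypothetical sketch (not GraphStorm's actual check_homo): treat a DGL graph
# as homogeneous when it has exactly one node type and one edge type.
import dgl
import torch

def is_homogeneous(g):
    """Return True if the DGL graph has a single node type and a single edge type."""
    return len(g.ntypes) == 1 and len(g.etypes) == 1

# A graph with more than one node/edge type would now trip the asserts
# for the "sage", "gat", and "gatv2" encoders in set_encoder.
hetero_g = dgl.heterograph({
    ("user", "follows", "user"): (torch.tensor([0]), torch.tensor([1])),
    ("user", "clicks", "item"): (torch.tensor([0]), torch.tensor([0])),
})
assert not is_homogeneous(hetero_g)
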
4 changes: 3 additions & 1 deletion python/graphstorm/model/gat_encoder.py
@@ -162,14 +162,16 @@ class GATEncoder(GraphConvEncoder):
         Hidden dimension
     out_dim : int
         Output dimension
-    num_head : int
+    num_heads : int
         Number of multi-heads attention
     num_hidden_layers : int
         Number of hidden layers. Total GNN layers is equal to num_hidden_layers + 1. Default 1
     dropout : float
         Dropout. Default 0.
     activation : callable, optional
         Activation function. Default: None
+    last_layer_act: bool
+        Whether call activation function in the last GNN layer.
     num_ffn_layers_in_gnn: int
         Number of ngnn gnn layers between GNN layers
     """
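
The renamed num_heads parameter matches how set_encoder constructs the encoder in gsf.py above. A minimal construction sketch with illustrative values (the import path follows the changed file's location; in gsf.py these values come from the run config):

from graphstorm.model.gat_encoder import GATEncoder

# Illustrative values only.
gnn_encoder = GATEncoder(h_dim=128,             # hidden dimension
                         out_dim=128,           # output dimension
                         num_heads=4,           # documented as num_heads after this fix
                         num_hidden_layers=1,   # total GNN layers = num_hidden_layers + 1
                         dropout=0.0,
                         num_ffn_layers_in_gnn=0)
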
4 changes: 3 additions & 1 deletion python/graphstorm/model/gatv2_encoder.py
@@ -166,14 +166,16 @@ class GATv2Encoder(GraphConvEncoder):
         Hidden dimension
     out_dim : int
         Output dimension
-    num_head : int
+    num_heads : int
         Number of multi-heads attention
     num_hidden_layers : int
         Number of hidden layers. Total GNN layers is equal to num_hidden_layers + 1. Default 1
     dropout : float
         Dropout. Default 0.
     activation : callable, optional
         Activation function. Default: None
+    last_layer_act: bool
+        Whether call activation function in the last GNN layer.
     num_ffn_layers_in_gnn: int
         Number of ngnn gnn layers between GNN layers
     """
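
GATv2Encoder documents the same parameters, and the (truncated) call in gsf.py above constructs it the same way. A one-line sketch, assuming the remaining keyword arguments mirror GATEncoder's (illustrative values again):

from graphstorm.model.gatv2_encoder import GATv2Encoder

# Same argument names as GATEncoder; only the class and module differ.
gatv2_encoder = GATv2Encoder(h_dim=128, out_dim=128, num_heads=4,
                             num_hidden_layers=1, dropout=0.0,
                             num_ffn_layers_in_gnn=0)
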
