[Multi-task Learning] Add multi-task trainer #849

Merged: 90 commits merged on May 30, 2024
Changes from 85 commits
Commits
a59b038
Drafting multi-task learning
May 3, 2024
8825cb1
Update
May 3, 2024
82489ed
Update
May 6, 2024
3efbaa1
Update gsf
May 6, 2024
e82da16
Update
May 6, 2024
9a8f8b4
update
May 6, 2024
00f1211
Update
May 6, 2024
e8573e4
Merge branch 'master' into multi-task-trainer
May 7, 2024
2813d3a
Update
May 7, 2024
20d306a
update
May 9, 2024
9613c6a
Update
May 9, 2024
166d9d5
Update
May 9, 2024
ab62c45
Add unit tests
May 9, 2024
b6a7267
Fix bugs
May 10, 2024
71ad941
Fix CI
May 10, 2024
6344f91
Merge branch 'multi-task' into multi-task-trainer
May 11, 2024
75b5435
Update
May 13, 2024
729a5e1
Update
May 13, 2024
515a642
update
May 13, 2024
23c24ca
Add test
May 13, 2024
3c86ab3
Fix
May 13, 2024
e541837
update
May 14, 2024
1dcc4aa
add test for evaluator
May 14, 2024
1335ce1
Update multi-task evaluator
May 14, 2024
d4a74a9
Add movielens test data
May 15, 2024
858b751
update
May 15, 2024
9de160b
Update
May 15, 2024
0c630ce
Update
May 15, 2024
771adf7
Add multi-task entry point
May 15, 2024
c79b15b
Merge branch 'multi-task' into multi-task-trainer
May 15, 2024
0d9f30d
Update
May 15, 2024
6ca1d65
Fix some bugs
May 16, 2024
1fba83a
Merge branch 'multi-task' into multi-task-trainer
May 16, 2024
b265660
Update
May 16, 2024
3b6f8bf
Update
May 16, 2024
3775b94
Update
May 16, 2024
5c41c8a
Fix some bugs
May 16, 2024
25f4817
Fix bugs
May 16, 2024
216c610
Update
May 17, 2024
017c00f
Update
May 17, 2024
32d905f
Merge branch 'multi-task' into multi-task-trainer
May 19, 2024
a778c08
Merge branch 'multi-task' into multi-task-trainer
May 20, 2024
3b592c4
clean up duplicated code
May 20, 2024
d7e0405
Update
May 20, 2024
08e3fe6
update init
May 20, 2024
4945c2c
update ep_gnn.py
May 20, 2024
d0b37b4
update lp_gnn.py
May 20, 2024
46da6ca
update
May 20, 2024
1dc4dcc
update
May 20, 2024
9b19427
update
May 20, 2024
28ec04c
Add unitests
May 21, 2024
2675b6d
Update docstr
May 21, 2024
775d4b2
Update
May 21, 2024
44c5357
update
May 21, 2024
c719891
Add test
May 21, 2024
958e2b9
Update
May 21, 2024
87e1df2
Merge branch 'multi-task-refactor-nn' into multi-task-trainer
May 21, 2024
4f56db3
Merge branch 'multi-task' into multi-task-trainer
classicsong May 21, 2024
04ec3e9
Add test for multitask_gnn.py
May 23, 2024
601aa1b
Add GSgnnMultiTaskSharedEncoderModel
May 23, 2024
4a3ec01
Add unitests
May 23, 2024
46125ca
Update
May 23, 2024
2403930
update dataloader
May 23, 2024
caa851a
Fix lint
May 23, 2024
e3e33f8
Fix lint
May 24, 2024
36f4c60
update
May 24, 2024
3543281
Merge branch 'multi-task-model' into multi-task-trainer
May 24, 2024
4a0a7df
Fix DDP bug
May 24, 2024
424b4b2
Update
May 24, 2024
f62e035
update
May 24, 2024
d8dd046
update test
May 24, 2024
1a5a165
Update
May 24, 2024
1d92008
Update
May 24, 2024
d6c2cb3
Update
May 24, 2024
c90fb5f
Update
May 26, 2024
3a98131
Update
May 26, 2024
9baac44
Update
May 27, 2024
6003cdc
update
May 27, 2024
a7c14e3
Update
May 27, 2024
2a9aa88
Fix lint
May 27, 2024
fc45ea2
Fix lint
May 27, 2024
fd509c9
Merge branch 'multi-task-model' into multi-task-trainer
May 27, 2024
7f9947c
Update
May 27, 2024
2da6621
update
May 28, 2024
bdd41cb
Merge branch 'multi-task' into multi-task-trainer
May 28, 2024
2818176
Update
May 28, 2024
0340223
Update
May 29, 2024
6442b60
Update
May 29, 2024
0e7c660
Update
May 30, 2024
8ffeea6
update
May 30, 2024
1 change: 1 addition & 0 deletions .github/workflow_scripts/e2e_mgpu_check.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ sh ./tests/end2end-tests/create_data.sh
bash ./tests/end2end-tests/graphstorm-lp/mgpu_test.sh
bash ./tests/end2end-tests/graphstorm-nc/mgpu_test.sh
bash ./tests/end2end-tests/graphstorm-ec/mgpu_test.sh
bash ./tests/end2end-tests/graphstorm-mt/mgpu_test.sh

5 changes: 5 additions & 0 deletions python/graphstorm/__init__.py
Expand Up @@ -28,5 +28,10 @@
from .gsf import create_builtin_lp_model
from .gsf import create_builtin_edge_model
from .gsf import create_builtin_node_model
from .gsf import create_task_decoder

from .gsf import (create_builtin_node_decoder,
create_builtin_edge_decoder,
create_builtin_lp_decoder)
from .gsf import (get_builtin_lp_train_dataloader_class,
get_builtin_lp_eval_dataloader_class)
26 changes: 18 additions & 8 deletions python/graphstorm/config/argument.py
Expand Up @@ -150,9 +150,9 @@ def __init__(self, cmd_args):
configuration = self.load_yaml_config(cmd_args.yaml_config_file)

multi_task_config = None
if 'multi_task_learning' in configuration:
multi_task_config = configuration['multi_task_learning']
del configuration['multi_task_learning']
if 'multi_task_learning' in configuration['gsf']:
multi_task_config = configuration['gsf']['multi_task_learning']
del configuration['gsf']['multi_task_learning']

self.set_attributes(configuration)
# Override class attributes using command-line arguments
Expand Down Expand Up @@ -305,7 +305,9 @@ def _parse_node_classification_task(self, task_config):
task_id = get_mttask_id(task_type=task_type,
ntype=target_ntype,
label=label_field)
setattr(task_info, "mask_fields", mask_fields)
setattr(task_info, "train_mask", mask_fields[0])
setattr(task_info, "val_mask", mask_fields[1])
setattr(task_info, "test_mask", mask_fields[2])
setattr(task_info, "task_weight", task_weight)

return TaskInfo(task_type=task_type,
Expand Down Expand Up @@ -336,7 +338,9 @@ def _parse_node_regression_task(self, task_config):
task_id = get_mttask_id(task_type=task_type,
ntype=target_ntype,
label=label_field)
setattr(task_info, "mask_fields", mask_fields)
setattr(task_info, "train_mask", mask_fields[0])
setattr(task_info, "val_mask", mask_fields[1])
setattr(task_info, "test_mask", mask_fields[2])
setattr(task_info, "task_weight", task_weight)

return TaskInfo(task_type=task_type,
Expand Down Expand Up @@ -367,7 +371,9 @@ def _parse_edge_classification_task(self, task_config):
task_id = get_mttask_id(task_type=task_type,
etype=target_etype,
label=label_field)
setattr(task_info, "mask_fields", mask_fields)
setattr(task_info, "train_mask", mask_fields[0])
setattr(task_info, "val_mask", mask_fields[1])
setattr(task_info, "test_mask", mask_fields[2])
setattr(task_info, "task_weight", task_weight)
return TaskInfo(task_type=task_type,
task_id=task_id,
Expand Down Expand Up @@ -397,7 +403,9 @@ def _parse_edge_regression_task(self, task_config):
task_id = get_mttask_id(task_type=task_type,
etype=target_etype,
label=label_field)
setattr(task_info, "mask_fields", mask_fields)
setattr(task_info, "train_mask", mask_fields[0])
setattr(task_info, "val_mask", mask_fields[1])
setattr(task_info, "test_mask", mask_fields[2])
setattr(task_info, "task_weight", task_weight)
return TaskInfo(task_type=task_type,
task_id=task_id,
Expand Down Expand Up @@ -425,7 +433,9 @@ def _parse_link_prediction_task(self, task_config):
task_id = get_mttask_id(
task_type=task_type,
etype=train_etype if train_etype is not None else "ALL_ETYPE")
setattr(task_info, "mask_fields", mask_fields)
setattr(task_info, "train_mask", mask_fields[0])
setattr(task_info, "val_mask", mask_fields[1])
setattr(task_info, "test_mask", mask_fields[2])
setattr(task_info, "task_weight", task_weight)
return TaskInfo(task_type=task_type,
task_id=task_id,
Expand Down
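
The parsing hunks above split each task's `mask_fields` triple into `train_mask`, `val_mask`, and `test_mask` attributes, and attach a `task_weight`, with the `multi_task_learning` list now nested under the `gsf` key. A hypothetical sketch of what such a config block might look like — the nesting and the `mask_fields`/`task_weight` keys follow the parsing code above, but the surrounding field names and values are assumptions, not the documented GraphStorm schema:

```yaml
# Hypothetical multi_task_learning block nested under `gsf`, matching the
# parsing logic in argument.py. Each task lists three mask fields
# (train, validation, test) and a per-task loss weight.
gsf:
  multi_task_learning:
    - node_classification:
        target_ntype: "movie"        # assumed field name
        label_field: "genre"         # assumed field name
        mask_fields:
          - "train_mask_nc"          # becomes task_info.train_mask
          - "val_mask_nc"            # becomes task_info.val_mask
          - "test_mask_nc"           # becomes task_info.test_mask
        task_weight: 1.0
    - link_prediction:
        mask_fields:
          - "train_mask_lp"
          - "val_mask_lp"
          - "test_mask_lp"
        task_weight: 0.5
```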
37 changes: 33 additions & 4 deletions python/graphstorm/gsf.py
Expand Up @@ -28,10 +28,11 @@

from .utils import sys_tracker, get_rank
from .utils import setup_device
from .config import BUILTIN_TASK_NODE_CLASSIFICATION
from .config import BUILTIN_TASK_NODE_REGRESSION
from .config import BUILTIN_TASK_EDGE_CLASSIFICATION
from .config import BUILTIN_TASK_EDGE_REGRESSION
from graphstorm.config import (BUILTIN_TASK_NODE_CLASSIFICATION,
BUILTIN_TASK_NODE_REGRESSION,
BUILTIN_TASK_EDGE_CLASSIFICATION,
BUILTIN_TASK_EDGE_REGRESSION,
BUILTIN_TASK_LINK_PREDICTION)
from .config import BUILTIN_LP_DOT_DECODER
from .config import BUILTIN_LP_DISTMULT_DECODER
from .config import (BUILTIN_LP_LOSS_CROSS_ENTROPY,
Expand Down Expand Up @@ -842,3 +843,31 @@ def get_builtin_lp_train_dataloader_class(config):
raise ValueError('Unknown negative sampler')

return dataloader_cls

def create_task_decoder(task_info, g, decoder_input_dim, train_task):
""" Create task decoders according to task_info.

Parameters
----------
task_info: TaskInfo
Task info.
g: Dist DGLGraph
Graph
decoder_input_dim: int
The dimension of the input embedding of the decoder
train_task: bool
Whether the task is a training task

Return
------
decoder: The task decoder(s)
loss_func: The loss function(s)
"""
if task_info.task_type in [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]:
return create_builtin_node_decoder(g, decoder_input_dim, task_info.task_config, train_task)
elif task_info.task_type in [BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION]:
return create_builtin_edge_decoder(g, decoder_input_dim, task_info.task_config, train_task)
elif task_info.task_type in [BUILTIN_TASK_LINK_PREDICTION]:
return create_builtin_lp_decoder(g, decoder_input_dim, task_info.task_config, train_task)

return None, None
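
`create_task_decoder` is a plain dispatch on the task type: node tasks, edge tasks, and link prediction each get their own decoder factory, and unknown types fall through to `(None, None)`. A minimal standalone sketch of the same pattern — the names below are simplified stand-ins, not the GraphStorm API:

```python
# Simplified stand-in for the dispatch in create_task_decoder: map a task
# type to a (decoder, loss_func) pair. All names here are illustrative.
from typing import Optional, Tuple

NODE_TASKS = {"node_classification", "node_regression"}
EDGE_TASKS = {"edge_classification", "edge_regression"}
LP_TASKS = {"link_prediction"}

def make_node_decoder() -> Tuple[str, str]:
    return "node_decoder", "node_loss"

def make_edge_decoder() -> Tuple[str, str]:
    return "edge_decoder", "edge_loss"

def make_lp_decoder() -> Tuple[str, str]:
    return "lp_decoder", "lp_loss"

def create_task_decoder_sketch(task_type: str) -> Tuple[Optional[str], Optional[str]]:
    """Return (decoder, loss_func) for a task type, or (None, None) if unknown."""
    if task_type in NODE_TASKS:
        return make_node_decoder()
    if task_type in EDGE_TASKS:
        return make_edge_decoder()
    if task_type in LP_TASKS:
        return make_lp_decoder()
    return None, None
```

One decoder factory per task family keeps the shared GNN encoder agnostic of the task mix; each task only contributes its own head and loss.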
2 changes: 1 addition & 1 deletion python/graphstorm/inference/lp_infer.py
Expand Up @@ -72,7 +72,7 @@ def infer(self, data, loader, save_embed_path,
save_embed_format : str
Specify the format of saved embeddings.
infer_batch_size: int
Specify the inference batch size when computing node embeddings
Specify the inference batch size when computing node embeddings
with mini batch inference.
"""
sys_tracker.check('start inferencing')
Expand Down
52 changes: 52 additions & 0 deletions python/graphstorm/run/gs_multi_task_learning.py
@@ -0,0 +1,52 @@
"""
Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License").
You may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

Entry point for running multi-task learning.

Run as:
python3 -m graphstorm.run.gs_multi_task_learning <Launch args> <Train/Infer args>
"""
import os
import logging

from .launch import get_argument_parser
from .launch import check_input_arguments
from .launch import submit_jobs

def main():
""" Main function
"""
parser = get_argument_parser()
args, exec_script_args = parser.parse_known_args()
check_input_arguments(args)

lib_dir = os.path.abspath(os.path.dirname(__file__))
if args.inference:
cmd = "gsgnn_mt/gsgnn_infer_mt.py"
else:
cmd = "gsgnn_mt/gsgnn_mt.py"
cmd_path = os.path.join(lib_dir, cmd)
exec_script_args = [cmd_path] + exec_script_args

if "coo" not in args.graph_format:
args.graph_format = f"{args.graph_format},coo"
logging.debug("Automatically add COO format to graph formats for link prediction. " + \
"New graph_format is %s", args.graph_format)
submit_jobs(args, exec_script_args)

if __name__ == "__main__":
FMT = "%(asctime)s %(levelname)s %(message)s"
logging.basicConfig(format=FMT, level=logging.INFO)
main()
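
The entry point above normalizes `args.graph_format` before submitting jobs: link prediction needs edges in COO format, so `,coo` is appended whenever it is missing. That normalization, extracted as a small function (a sketch, not part of the GraphStorm API):

```python
# Sketch of the graph-format normalization done by gs_multi_task_learning:
# the format string is a comma-separated list, and "coo" is appended when absent.
def ensure_coo_format(graph_format: str) -> str:
    """Append ',coo' to a comma-separated graph-format string if missing."""
    if "coo" not in graph_format:
        return f"{graph_format},coo"
    return graph_format
```

Note that, like the original, this uses a substring check rather than splitting on commas, which is sufficient here because no other builtin format name contains "coo".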
Empty file.