Skip to content

Commit

Permalink
Merge branch 'multi-task-trainer' into multi-task
Browse files Browse the repository at this point in the history
  • Loading branch information
Xiang Song committed May 31, 2024
2 parents 062cde0 + 14a49a0 commit 0303ceb
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 10 deletions.
21 changes: 13 additions & 8 deletions python/graphstorm/model/multitask_gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,21 +171,26 @@ def forward(self, task_mini_batches):
""" The forward function for multi-task learning
It will iterate over the mini-batches and call
forward for each task.
Returns:
    mt_loss: the overall (aggregated) multi-task loss
    losses: a list of per-task losses (used for debugging/logging)
"""
losses = []
for (task_info, mini_batch) in task_mini_batches:
loss, weight = self._run_mini_batch(task_info, mini_batch)
losses.append((loss, weight))

reg_loss = th.tensor(0.).to(losses[0][0].device)
for d_para in self.get_dense_params():
reg_loss += d_para.square().sum()
alpha_l2norm = self.alpha_l2norm
reg_loss = th.tensor(0.).to(losses[0][0].device)
for d_para in self.get_dense_params():
reg_loss += d_para.square().sum()
alpha_l2norm = self.alpha_l2norm

mt_loss = reg_loss * alpha_l2norm
for loss, weight in losses:
mt_loss += loss * weight

mt_loss = reg_loss * alpha_l2norm
for loss, weight in losses:
mt_loss += loss * weight
return mt_loss
return mt_loss, losses

# pylint: disable=unused-argument
def _forward(self, task_id, encoder_data, decoder_data):
Expand Down
1 change: 0 additions & 1 deletion python/graphstorm/run/gsgnn_mt/gsgnn_mt.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from graphstorm.dataloading import (GSgnnNodeDataLoader,
GSgnnEdgeDataLoader,
GSgnnMultiTaskDataLoader)

from graphstorm.eval import (GSgnnClassificationEvaluator,
GSgnnRegressionEvaluator,
GSgnnPerEtypeMrrLPEvaluator,
Expand Down
7 changes: 6 additions & 1 deletion python/graphstorm/trainer/mt_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def fit(self, train_loader,
mini_batches.append((task_info, \
self._prepare_mini_batch(data, task_info, mini_batch, device)))

loss = model(mini_batches)
loss, task_losses = model(mini_batches)

rt_profiler.record('train_forward')
self.optimizer.zero_grad()
Expand All @@ -377,8 +377,13 @@ def fit(self, train_loader,

if i % 20 == 0 and get_rank() == 0:
rt_profiler.print_stats()
per_task_loss = {}
for mini_batch, task_loss in zip(mini_batches, task_losses):
task_info, _ = mini_batch
per_task_loss[task_info.task_id] = task_loss.item()
logging.info("Epoch %05d | Batch %03d | Train Loss: %.4f | Time: %.4f",
epoch, i, loss.item(), time.time() - batch_tic)
logging.debug("Per task Loss: %s", per_task_loss)

val_score = None
if self.evaluator is not None and \
Expand Down
12 changes: 12 additions & 0 deletions training_scripts/gsgnn_mt/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,16 @@ python3 -m graphstorm.gconstruct.construct_graph \

## Run the example
```
python3 -m graphstorm.run.gs_multi_task_learning \
--workspace $GS_HOME/training_scripts/gsgnn_mt \
--num-trainers 4 \
--num-servers 1 \
--num-samplers 0 \
--part-config movielen_100k_multi_task_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt \
--ssh-port 2222 \
--cf ml_nc_ec_er_lp.yaml \
--save-model-path /data/gsgnn_mt/ \
--save-model-frequency 1000 \
--logging-file /tmp/train_log.txt \
--logging-level debug
```

0 comments on commit 0303ceb

Please sign in to comment.