Add a new resnet50+convergence+spmd test #1013
@yeounoh Do you have a oneshot run for this?
I tried a local run, it's very slow:
We might need to optimize our resnet training script some more to avoid timeouts. EDIT: It seems like dataloading is the bottleneck; the PjRt convergence tests are also pretty slow: http://shortn/_whmMQ4flXg
LGTM, thanks Yeounoh!
@@ -209,5 +209,6 @@ local tpus = import 'templates/tpus.libsonnet';
// SPMD
resnet50 + functional + v4_8 + timeouts.Hours(2) + spmd(['batch']),
resnet50 + functional + v4_8 + timeouts.Hours(2) + spmd(['spatial']),
resnet50 + convergence + v4_8 + timeouts.Hours(2) + spmd(['batch']),
Let's make the timeout 14 hours to match the PjRt tests. I've kicked off a oneshot here, let's merge if it passes: http://shortn/_b8rAgZ0mkY
cloudtop [~/d/x/1/ml-testing-accelerators] % git diff
diff --git a/tests/pytorch/nightly/resnet50-mp.libsonnet b/tests/pytorch/nightly/resnet50-mp.libsonnet
index b24e5bc5..ba86abab 100644
--- a/tests/pytorch/nightly/resnet50-mp.libsonnet
+++ b/tests/pytorch/nightly/resnet50-mp.libsonnet
@@ -209,5 +209,6 @@ local tpus = import 'templates/tpus.libsonnet';
// SPMD
resnet50 + functional + v4_8 + timeouts.Hours(2) + spmd(['batch']),
resnet50 + functional + v4_8 + timeouts.Hours(2) + spmd(['spatial']),
+ resnet50 + convergence + v4_8 + timeouts.Hours(14) + spmd(['batch']),
],
}
Thanks, it failed due to RESOURCE_EXHAUSTED after an epoch. Is this the same error you were running into for the convergence test?
I would guess this is related to the eval loop's dataloader sharding. I also noticed that our rate is still ~1/8 that of the regular PjRt tests, probably because data processing for all devices happens in a single thread.
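As a rough illustration of why a single data-processing thread could produce a ~1/8 rate, here is a back-of-envelope sketch. The function name and all numbers are hypothetical and chosen only to make the arithmetic visible; they are not measurements from this PR.

```python
def sustained_rate(per_worker_rate, num_workers, device_rate):
    """Estimate end-to-end samples/sec for a training loop.

    The loop can run no faster than the slower of the input pipeline
    (per_worker_rate * num_workers) and the devices (device_rate).
    All rates are hypothetical illustration values.
    """
    return min(per_worker_rate * num_workers, device_rate)


# Hypothetical: devices could consume 800 samples/sec in total,
# while one preprocessing thread produces only 100 samples/sec.
single_thread = sustained_rate(per_worker_rate=100, num_workers=1, device_rate=800)
eight_workers = sustained_rate(per_worker_rate=100, num_workers=8, device_rate=800)

# Fraction of the device-bound rate reached with one thread.
fraction = single_thread / eight_workers
```

Under these assumed numbers the single-thread pipeline reaches one eighth of the device-bound rate, and adding workers beyond eight would not help because the devices become the limit.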
4b52d17 to 2967337
If that's the case, can we use BERT or some other language model?
Yeah, we can just use llama, but it would require us to emit the correct tb metrics. Will follow up with @will-cromar @jonb377
Curious: are we going to use resnet for the convergence test or not?
I chatted with @will-cromar - there's still hope for resnet. We can increase the number of worker threads to better match the MP PjRt performance without reinventing the SPMD dataloader. I'll test this out later today (run with
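The idea of adding worker threads to the input pipeline can be sketched with the standard library alone. This is a minimal sketch, not the training script's actual loader: `preprocess` and `load_batches` are hypothetical names, and the doubling step merely stands in for real decode/augment work.

```python
from concurrent.futures import ThreadPoolExecutor


def preprocess(sample):
    # Hypothetical stand-in for per-sample work such as decoding or augmentation.
    return sample * 2


def load_batches(samples, batch_size, num_workers):
    """Preprocess samples on a pool of worker threads, then group them into batches.

    Executor.map returns results in input order, so batches stay deterministic
    regardless of which thread finishes first.
    """
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        processed = list(pool.map(preprocess, samples))
    return [processed[i:i + batch_size] for i in range(0, len(processed), batch_size)]


batches = load_batches(range(8), batch_size=4, num_workers=4)
```

Threads help most when the per-sample work releases the GIL (file I/O, native image decoding); for pure-Python preprocessing, separate processes would likely be needed instead.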
Description
Add a new test for Resnet50+SPMD for convergence. Use batch sharding.