
Add a new resnet50+convergence+spmd test #1013

Closed

Conversation

@yeounoh (Contributor) commented Nov 11, 2023:

Description

Add a new convergence test for ResNet50 + SPMD, using batch sharding.
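
For context, batch sharding with PyTorch/XLA SPMD is typically set up along these lines. This is a minimal sketch against the torch_xla SPMD API as it stood around late 2023 (the module has since moved); the mesh, tensor shapes, and variable names are illustrative, not taken from the actual test script:

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs
from torch_xla.experimental.xla_sharding import Mesh

# Enable SPMD mode before creating any XLA tensors.
xr.use_spmd()

# Build a 1-D device mesh over all runtime devices.
num_devices = xr.global_runtime_device_count()
mesh = Mesh(np.arange(num_devices), (num_devices,), ('batch',))

# Shard an input batch along dim 0 (the batch axis); other dims stay replicated.
images = torch.randn(128, 3, 224, 224).to(xm.xla_device())
xs.mark_sharding(images, mesh, (0, None, None, None))
```

Sharding only the batch dimension keeps the per-device computation equivalent to the data-parallel MP setup while a single process drives all devices.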

@yeounoh yeounoh requested review from jonb377 and ManfeiBai November 11, 2023 01:40
@yeounoh yeounoh self-assigned this Nov 11, 2023
@yeounoh yeounoh requested a review from alanwaketan November 11, 2023 01:40
@jonb377 (Collaborator) commented Nov 12, 2023:

@yeounoh Do you have a oneshot run for this?

@jonb377 (Collaborator) commented Nov 12, 2023:

I tried a local run; it's very slow:

Epoch 1 train begin 22:36:40
| Training Device=xla:0/0 Epoch=1 Step=0 Loss=7.11059 Rate=9.68 GlobalRate=9.68 Time=22:38:26
| Training Device=xla:0/0 Epoch=1 Step=200 Loss=6.48232 Rate=67.85 GlobalRate=101.57 Time=23:10:26

We might need to optimize our resnet training script some more to avoid timeouts.

EDIT: It seems like dataloading is the bottleneck; the PjRt convergence tests are also pretty slow: http://shortn/_whmMQ4flXg

@jonb377 (Collaborator) previously approved these changes Nov 13, 2023 and left a comment:

LGTM, thanks Yeounoh!

@@ -209,5 +209,6 @@ local tpus = import 'templates/tpus.libsonnet';
     // SPMD
     resnet50 + functional + v4_8 + timeouts.Hours(2) + spmd(['batch']),
     resnet50 + functional + v4_8 + timeouts.Hours(2) + spmd(['spatial']),
+    resnet50 + convergence + v4_8 + timeouts.Hours(2) + spmd(['batch']),
@jonb377 (Collaborator) commented on this diff, Nov 13, 2023:
Let's make the timeout 14 hours to match the PjRt tests. I've kicked off a oneshot here; let's merge if it passes: http://shortn/_b8rAgZ0mkY

cloudtop [~/d/x/1/ml-testing-accelerators] % git diff
diff --git a/tests/pytorch/nightly/resnet50-mp.libsonnet b/tests/pytorch/nightly/resnet50-mp.libsonnet
index b24e5bc5..ba86abab 100644
--- a/tests/pytorch/nightly/resnet50-mp.libsonnet
+++ b/tests/pytorch/nightly/resnet50-mp.libsonnet
@@ -209,5 +209,6 @@ local tpus = import 'templates/tpus.libsonnet';
     // SPMD
     resnet50 + functional + v4_8 + timeouts.Hours(2) + spmd(['batch']),
     resnet50 + functional + v4_8 + timeouts.Hours(2) + spmd(['spatial']),
+    resnet50 + convergence + v4_8 + timeouts.Hours(14) + spmd(['batch']),
   ],
 }

@yeounoh (Contributor, Author) replied:

Thanks, it failed due to RESOURCE_EXHAUSTED after an epoch. Is this the same error you were running into for the convergence test?

@jonb377 (Collaborator) replied:

I would guess this is related to the eval loop's dataloader sharding. I also noticed that our rate is still ~1/8 that of the regular PjRt tests - probably because data processing for all devices happens in a single thread.
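
For reference, one way to avoid handing the device a fully replicated global batch is to shard inputs as they are loaded. This is a minimal sketch assuming the input_sharding option on MpDeviceLoader and reusing the mesh and a train_loader from the sharding sketch above; it is not the exact code used by this test:

```python
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.experimental.xla_sharding as xs

# Shard each batch along dim 0 while it is transferred to the device,
# instead of materializing the full unsharded batch on one device.
sharding_spec = xs.ShardingSpec(mesh, (0, None, None, None))
train_device_loader = pl.MpDeviceLoader(
    train_loader, xm.xla_device(), input_sharding=sharding_spec)

for step, (images, labels) in enumerate(train_device_loader):
    ...  # forward/backward as usual; inputs arrive already batch-sharded
```

If the eval loop builds its loader without an equivalent sharding spec, that would be consistent with the RESOURCE_EXHAUSTED guess above, though this is only speculation.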

@alanwaketan (Collaborator) commented:

> I tried a local run; it's very slow:
>
> Epoch 1 train begin 22:36:40
> | Training Device=xla:0/0 Epoch=1 Step=0 Loss=7.11059 Rate=9.68 GlobalRate=9.68 Time=22:38:26
> | Training Device=xla:0/0 Epoch=1 Step=200 Loss=6.48232 Rate=67.85 GlobalRate=101.57 Time=23:10:26
>
> We might need to optimize our resnet training script some more to avoid timeouts.
>
> EDIT: It seems like dataloading is the bottleneck; the PjRt convergence tests are also pretty slow: http://shortn/_whmMQ4flXg

If that's the case, can we use bert or some other language models?

@yeounoh (Contributor, Author) commented Nov 14, 2023:

> > I tried a local run; it's very slow:
> >
> > Epoch 1 train begin 22:36:40
> > | Training Device=xla:0/0 Epoch=1 Step=0 Loss=7.11059 Rate=9.68 GlobalRate=9.68 Time=22:38:26
> > | Training Device=xla:0/0 Epoch=1 Step=200 Loss=6.48232 Rate=67.85 GlobalRate=101.57 Time=23:10:26
> >
> > We might need to optimize our resnet training script some more to avoid timeouts.
> > EDIT: It seems like dataloading is the bottleneck; the PjRt convergence tests are also pretty slow: http://shortn/_whmMQ4flXg
>
> If that's the case, can we use bert or some other language models?

Yeah, we can just use llama, but it would require us to emit the correct TensorBoard metrics. Will follow up with @will-cromar @jonb377.

@jonb377 jonb377 self-requested a review November 14, 2023 01:13
@jonb377 jonb377 dismissed their stale review November 14, 2023 01:13 (reason: "Convergence too slow")

@yeounoh yeounoh marked this pull request as draft November 14, 2023 20:28
@alanwaketan (Collaborator) commented:

Curious: are we going to use resnet for the convergence test or not?

@jonb377 (Collaborator) commented Dec 11, 2023:

> Curious: are we going to use resnet for the convergence test or not?

I chatted with @will-cromar - there's still hope for resnet. We can increase the number of worker threads to better match the MP PjRt performance without reinventing the SPMD dataloader.

I'll test this out later today (run with --num_workers 64; it's still pretty slow).
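
As a rough sketch of the worker-count change being suggested (train_dataset and global_batch_size are placeholders; 64 is simply the value mentioned above, not a verified setting):

```python
from torch.utils.data import DataLoader

# With SPMD a single host process feeds the global batch for every device,
# so it needs far more loader workers than a per-device MP process would.
train_loader = DataLoader(
    train_dataset,
    batch_size=global_batch_size,
    shuffle=True,
    num_workers=64,   # --num_workers 64 from the comment above
    drop_last=True)
```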

@yeounoh yeounoh closed this Mar 15, 2024