
Speedup module imports #26308

Closed
apoorvkh opened this issue Sep 21, 2023 · 9 comments

Comments

@apoorvkh
Contributor

Feature request

Can we please consider importing the deepspeed module when needed, rather than in the import header of trainer.py?

Motivation

When deepspeed is installed, from transformers import Trainer takes a long time!

On my system that's 9 seconds!

>>> import timeit; timeit.timeit("from transformers import Trainer")
[2023-09-20 23:49:13,899] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)
8.906949461437762

I believe this import is the culprit. As shown below, it accounts for 8.5 seconds of the load time.

from accelerate.utils import DeepSpeedSchedulerWrapper

>>> timeit.timeit("from accelerate.utils import DeepSpeedSchedulerWrapper")
[2023-09-20 23:45:53,185] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)
8.525534554384649
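(As an aside, timeit re-runs the statement in the same interpreter, so everything after the first iteration hits the sys.modules cache; if it's useful, here's a rough sketch of timing just the cold import in a fresh process instead:)

import subprocess
import time

# Time a cold import in a separate interpreter so this process's module
# cache does not affect the measurement.
start = time.perf_counter()
subprocess.run(["python", "-c", "from transformers import Trainer"], check=True)
print(f"cold import: {time.perf_counter() - start:.2f}s")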

This is quite cumbersome, because all scripts that import Trainer (e.g. even for typing) are impacted!

Your contribution

Happy to submit a PR. We could make this a class variable or just import it directly in both places it's used:

is_deepspeed_custom_scheduler = self.is_deepspeed_enabled and not isinstance(
    self.lr_scheduler, DeepSpeedSchedulerWrapper
)

if self.is_deepspeed_enabled:
    # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
    if not isinstance(self.lr_scheduler, DeepSpeedSchedulerWrapper):
        with warnings.catch_warnings(record=True) as caught_warnings:
            self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
        reissue_pt_warnings(caught_warnings)
    return
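For illustration, a rough sketch of the second option (the enclosing method name and body are abbreviated from memory; only the import placement is the point here):

def _load_optimizer_and_scheduler(self, checkpoint):
    # Sketch only: import DeepSpeedSchedulerWrapper next to its use, so
    # `from transformers import Trainer` no longer pays the deepspeed cost.
    if self.is_deepspeed_enabled:
        from accelerate.utils import DeepSpeedSchedulerWrapper

        # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
        if not isinstance(self.lr_scheduler, DeepSpeedSchedulerWrapper):
            with warnings.catch_warnings(record=True) as caught_warnings:
                self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
            reissue_pt_warnings(caught_warnings)
        return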

@ArthurZucker
Collaborator

Hey! Thanks for opening this issue. Are you using main? There was a PR recently to fix this; see #26090 and #26106.

@apoorvkh
Contributor Author

I am indeed using main (specifically, transformers[deepspeed] at commit 382ba67)!

@apoorvkh
Contributor Author

The code I mentioned above runs directly in the header of trainer.py. And, if I understand correctly, accelerate is not covered by the lazy imports in #26090:

if is_accelerate_available():
    from accelerate import Accelerator, skip_first_batches
    from accelerate import __version__ as accelerate_version
    from accelerate.utils import DistributedDataParallelKwargs, GradientAccumulationPlugin

    if version.parse(accelerate_version) > version.parse("0.20.3"):
        from accelerate.utils import (
            load_fsdp_model,
            load_fsdp_optimizer,
            save_fsdp_model,
            save_fsdp_optimizer,
        )
    if is_deepspeed_available():
        from accelerate.utils import DeepSpeedSchedulerWrapper
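If we wanted to defer the whole block, one possibility (just a sketch, not what transformers does today; the helper name is made up) would be a small cached function that performs these imports on first use:

from functools import lru_cache


@lru_cache(maxsize=None)
def _lazy_accelerate():
    # Sketch only: run the accelerate imports the first time any Trainer
    # code needs them, then reuse the cached result on later calls.
    from accelerate import Accelerator, skip_first_batches
    from accelerate.utils import DistributedDataParallelKwargs, GradientAccumulationPlugin

    return {
        "Accelerator": Accelerator,
        "skip_first_batches": skip_first_batches,
        "DistributedDataParallelKwargs": DistributedDataParallelKwargs,
        "GradientAccumulationPlugin": GradientAccumulationPlugin,
    }

Call sites would then look up what they need, e.g. Accelerator = _lazy_accelerate()["Accelerator"], right before using it; the version-gated and deepspeed-conditional imports could be handled the same way.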

@ArthurZucker
Collaborator

Cc @younesbelkada: I think you mentioned that accelerate is the bottleneck we can't get rid of, no?

@younesbelkada
Contributor

Hi @apoorvkh
With huggingface/accelerate#1963 now merged in accelerate, I think you can switch to accelerate main and see if it resolves your issue.

@apoorvkh
Contributor Author

Hey, thanks! I think that commit (huggingface/accelerate@5dec654) reduces the runtime for the import from 8-9 seconds to 3-4 seconds (on my machine). That is still not ideal but is certainly more tolerable.

@younesbelkada
Contributor

younesbelkada commented Sep 28, 2023

Thanks!
Hm, I see. OK, I am curious which module takes so much time to import; would you be able to run a quick benchmark with tuna and share the results here?

# benchmark
python -X importtime -c "import transformers" 2> transformers-import-profile.log

# visualize
tuna <path to log file>

@apoorvkh
Contributor Author

apoorvkh commented Sep 29, 2023

For sure. That's a nice tool!

Checking quickly, I found that from transformers import Trainer in particular takes about 4 seconds to import, whereas plain import transformers is actually much faster (< 1 second).

We can see the result for from transformers import Trainer below:

[tuna visualization of the import-time profile for from transformers import Trainer]

Also, for from transformers import TrainingArguments:

[tuna visualization of the import-time profile for from transformers import TrainingArguments]

And we can compare to import transformers:

[tuna visualization of the import-time profile for import transformers]

Seems like accelerate is no longer the biggest culprit. A lot of time is also spent importing torch.

My point is that we sometimes just import these tools for typing purposes or in an interactive terminal for later use. From a developer perspective, it would be more convenient to have fast imports and move the time-consuming parts to the moment we actually want to init/use the modules (and are actually expecting to expend time). Thanks!
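For the typing-only case specifically, this is the kind of pattern I have in mind (a sketch; run_training is just a made-up example function):

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static type checkers, never at runtime, so the
    # heavy transformers / deepspeed import chain is skipped on startup.
    from transformers import Trainer


def run_training(trainer: Trainer) -> None:
    # With postponed annotation evaluation, the Trainer annotation stays a
    # string at runtime, so nothing is imported until this actually runs.
    trainer.train()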

apoorvkh changed the title from "Lazy import deepspeed module (10 sec speedup for from transformers import Trainer)" to "Speedup module imports" on Sep 29, 2023
@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions bot closed this as completed on Nov 1, 2023