From e11d3ceff3a49378796cdff5b466586d877d5c60 Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Fri, 22 Nov 2024 15:18:15 -0500
Subject: [PATCH] Allow for full dynamo config passed to Accelerator (#3251)

* Allow for full dynamo config

* Clean
---
 src/accelerate/accelerator.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index ab949c42e43..9baad9b56df 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -224,6 +224,9 @@ class Accelerator:
             or mixed precision are created. See [kwargs](kwargs) for more information.
         dynamo_backend (`str` or [`~utils.DynamoBackend`], *optional*, defaults to `"no"`):
             Set to one of the possible dynamo backends to optimize your training with torch dynamo.
+        dynamo_plugin ([`~utils.TorchDynamoPlugin`], *optional*):
+            A configuration for how torch dynamo should be handled, if more tweaking than just the `backend` or `mode`
+            is needed.
         gradient_accumulation_plugin ([`~utils.GradientAccumulationPlugin`], *optional*):
             A configuration for how gradient accumulation should be handled, if more tweaking than just the
             `gradient_accumulation_steps` is needed.
@@ -263,6 +266,7 @@ def __init__(
         step_scheduler_with_optimizer: bool = True,
         kwargs_handlers: list[KwargsHandler] | None = None,
         dynamo_backend: DynamoBackend | str | None = None,
+        dynamo_plugin: TorchDynamoPlugin | None = None,
         deepspeed_plugins: DeepSpeedPlugin | dict[str, DeepSpeedPlugin] | None = None,
     ):
         self.trackers = []
@@ -279,7 +283,12 @@ def __init__(
                 f"Unknown mixed_precision mode: {mixed_precision}. Choose between {PrecisionType.list()}"
             )
 
-        dynamo_plugin = TorchDynamoPlugin() if dynamo_backend is None else TorchDynamoPlugin(backend=dynamo_backend)
+        if dynamo_plugin is not None and dynamo_backend is not None:
+            raise ValueError("You cannot pass in both `dynamo_plugin` and `dynamo_backend`, please only pass in one.")
+
+        if dynamo_backend is not None:
+            dynamo_plugin = TorchDynamoPlugin(backend=dynamo_backend)
+        elif dynamo_plugin is None:
+            dynamo_plugin = TorchDynamoPlugin()
 
         if deepspeed_plugins is not None and deepspeed_plugin is not None:
             raise ValueError("You cannot pass in both `deepspeed_plugins` and `deepspeed_plugin`.")
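
A minimal usage sketch of the new argument (not part of the patch itself): the
`TorchDynamoPlugin` fields shown here (`backend`, `mode`, `fullgraph`,
`dynamic`) are assumed from `accelerate.utils`, and `"inductor"` /
`"max-autotune"` are illustrative values only.

    from accelerate import Accelerator
    from accelerate.utils import TorchDynamoPlugin

    # Build a full dynamo configuration instead of passing only a backend name.
    plugin = TorchDynamoPlugin(
        backend="inductor",   # which torch.compile backend to use
        mode="max-autotune",  # torch.compile tuning mode
        fullgraph=True,       # require a single graph with no breaks
        dynamic=True,         # enable dynamic shapes
    )

    # Pass the plugin directly to the Accelerator.
    accelerator = Accelerator(dynamo_plugin=plugin)

Note that after this change, passing both `dynamo_plugin` and `dynamo_backend`
raises a ValueError, so callers should supply exactly one of the two.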