Skip to content

Commit

Permalink
Allow for full dynamo config passed to Accelerator (#3251)
Browse files Browse the repository at this point in the history
* Allow for full dynamo config

* Clean
  • Loading branch information
muellerzr authored Nov 22, 2024
1 parent 08101b9 commit e11d3ce
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,9 @@ class Accelerator:
or mixed precision are created. See [kwargs](kwargs) for more information.
dynamo_backend (`str` or [`~utils.DynamoBackend`], *optional*, defaults to `"no"`):
Set to one of the possible dynamo backends to optimize your training with torch dynamo.
dynamo_plugin ([`~utils.TorchDynamoPlugin`], *optional*):
A configuration for how torch dynamo should be handled, if more tweaking than just the `backend` or `mode`
is needed.
gradient_accumulation_plugin ([`~utils.GradientAccumulationPlugin`], *optional*):
A configuration for how gradient accumulation should be handled, if more tweaking than just the
`gradient_accumulation_steps` is needed.
Expand Down Expand Up @@ -263,6 +266,7 @@ def __init__(
step_scheduler_with_optimizer: bool = True,
kwargs_handlers: list[KwargsHandler] | None = None,
dynamo_backend: DynamoBackend | str | None = None,
dynamo_plugin: TorchDynamoPlugin | None = None,
deepspeed_plugins: DeepSpeedPlugin | dict[str, DeepSpeedPlugin] | None = None,
):
self.trackers = []
Expand All @@ -279,7 +283,12 @@ def __init__(
f"Unknown mixed_precision mode: {mixed_precision}. Choose between {PrecisionType.list()}"
)

dynamo_plugin = TorchDynamoPlugin() if dynamo_backend is None else TorchDynamoPlugin(backend=dynamo_backend)
if dynamo_plugin is not None and dynamo_backend is not None:
raise ValueError("You cannot pass in both `dynamo_plugin` and `dynamo_backend`, please only pass in one.")
if dynamo_backend is not None:
dynamo_plugin = TorchDynamoPlugin(backend=dynamo_backend)
elif dynamo_plugin is None:
dynamo_plugin = TorchDynamoPlugin()

if deepspeed_plugins is not None and deepspeed_plugin is not None:
raise ValueError("You cannot pass in both `deepspeed_plugins` and `deepspeed_plugin`.")
Expand Down

0 comments on commit e11d3ce

Please sign in to comment.