Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Register BaseOutput subclasses as supported torch.utils._pytree nodes #5459

Merged
merged 3 commits into from
Oct 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions src/diffusers/utils/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,21 @@ class BaseOutput(OrderedDict):
</Tip>
"""

def __init_subclass__(cls) -> None:
"""Register subclasses as pytree nodes.

This is necessary to synchronize gradients when using `torch.nn.parallel.DistributedDataParallel` with
`static_graph=True` with modules that output `ModelOutput` subclasses.
"""
if is_torch_available():
import torch.utils._pytree

torch.utils._pytree._register_pytree_node(
cls,
torch.utils._pytree._dict_flatten,
lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)),
)

def __post_init__(self):
class_fields = fields(self)

Expand Down
22 changes: 22 additions & 0 deletions tests/others/test_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import PIL.Image

from diffusers.utils.outputs import BaseOutput
from diffusers.utils.testing_utils import require_torch


@dataclass
Expand Down Expand Up @@ -69,3 +70,24 @@ def test_outputs_serialization(self):
assert dir(outputs_orig) == dir(outputs_copy)
assert dict(outputs_orig) == dict(outputs_copy)
assert vars(outputs_orig) == vars(outputs_copy)

@require_torch
def test_torch_pytree(self):
# ensure torch.utils._pytree treats ModelOutput subclasses as nodes (and not leaves)
# this is important for DistributedDataParallel gradient synchronization with static_graph=True
import torch
import torch.utils._pytree

data = np.random.rand(1, 3, 4, 4)
x = CustomOutput(images=data)
self.assertFalse(torch.utils._pytree._is_leaf(x))

expected_flat_outs = [data]
expected_tree_spec = torch.utils._pytree.TreeSpec(CustomOutput, ["images"], [torch.utils._pytree.LeafSpec()])

actual_flat_outs, actual_tree_spec = torch.utils._pytree.tree_flatten(x)
self.assertEqual(expected_flat_outs, actual_flat_outs)
self.assertEqual(expected_tree_spec, actual_tree_spec)

unflattened_x = torch.utils._pytree.tree_unflatten(actual_flat_outs, actual_tree_spec)
self.assertEqual(x, unflattened_x)