From 368de363c0b4e070f737d3965f279ed016f96565 Mon Sep 17 00:00:00 2001 From: BowenBao Date: Thu, 19 Oct 2023 15:11:35 -0700 Subject: [PATCH 1/2] Register BaseOutput subclasses as supported torch.utils._pytree nodes --- src/diffusers/utils/outputs.py | 15 +++++++++++++++ tests/others/test_outputs.py | 24 ++++++++++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/src/diffusers/utils/outputs.py b/src/diffusers/utils/outputs.py index 802c699eb9cc..a057b506aec0 100644 --- a/src/diffusers/utils/outputs.py +++ b/src/diffusers/utils/outputs.py @@ -51,6 +51,21 @@ class BaseOutput(OrderedDict): """ + def __init_subclass__(cls) -> None: + """Register subclasses as pytree nodes. + + This is necessary to synchronize gradients when using `torch.nn.parallel.DistributedDataParallel` with + `static_graph=True` with modules that output `ModelOutput` subclasses. + """ + if is_torch_available(): + import torch.utils._pytree + + torch.utils._pytree._register_pytree_node( + cls, + torch.utils._pytree._dict_flatten, + lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)), + ) + def __post_init__(self): class_fields = fields(self) diff --git a/tests/others/test_outputs.py b/tests/others/test_outputs.py index 492e71f0ba31..82f40f0f2e0e 100644 --- a/tests/others/test_outputs.py +++ b/tests/others/test_outputs.py @@ -7,6 +7,7 @@ import PIL.Image from diffusers.utils.outputs import BaseOutput +from diffusers.utils.testing_utils import require_torch @dataclass @@ -69,3 +70,26 @@ def test_outputs_serialization(self): assert dir(outputs_orig) == dir(outputs_copy) assert dict(outputs_orig) == dict(outputs_copy) assert vars(outputs_orig) == vars(outputs_copy) + + @require_torch + def test_torch_pytree(self): + # ensure torch.utils._pytree treats ModelOutput subclasses as nodes (and not leaves) + # this is important for DistributedDataParallel gradient synchronization with static_graph=True + import torch + import torch.utils._pytree + + data = np.random.rand(1, 3, 4, 4) + x = CustomOutput(images=data) + self.assertFalse(torch.utils._pytree._is_leaf(x)) + + expected_flat_outs = [data] + expected_tree_spec = torch.utils._pytree.TreeSpec( + CustomOutput, ["images"], [torch.utils._pytree.LeafSpec()] + ) + + actual_flat_outs, actual_tree_spec = torch.utils._pytree.tree_flatten(x) + self.assertEqual(expected_flat_outs, actual_flat_outs) + self.assertEqual(expected_tree_spec, actual_tree_spec) + + unflattened_x = torch.utils._pytree.tree_unflatten(actual_flat_outs, actual_tree_spec) + self.assertEqual(x, unflattened_x) From 6f76eb4c5a79db7b9fe5fdd450915f40039e53db Mon Sep 17 00:00:00 2001 From: BowenBao Date: Fri, 20 Oct 2023 09:56:54 -0700 Subject: [PATCH 2/2] lint --- tests/others/test_outputs.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/others/test_outputs.py b/tests/others/test_outputs.py index 82f40f0f2e0e..cf709d93f709 100644 --- a/tests/others/test_outputs.py +++ b/tests/others/test_outputs.py @@ -83,9 +83,7 @@ def test_torch_pytree(self): self.assertFalse(torch.utils._pytree._is_leaf(x)) expected_flat_outs = [data] - expected_tree_spec = torch.utils._pytree.TreeSpec( - CustomOutput, ["images"], [torch.utils._pytree.LeafSpec()] - ) + expected_tree_spec = torch.utils._pytree.TreeSpec(CustomOutput, ["images"], [torch.utils._pytree.LeafSpec()]) actual_flat_outs, actual_tree_spec = torch.utils._pytree.tree_flatten(x) self.assertEqual(expected_flat_outs, actual_flat_outs)