From 72a9cc20c8e90da37d45a853674cf39ce36718af Mon Sep 17 00:00:00 2001
From: BowenBao
Date: Mon, 30 Oct 2023 09:34:50 -0700
Subject: [PATCH] Support 'BaseOutput' and subclasses from 'diffusers' in dynamo (#111978)

Extending the workarounds for `transformers` `ModelOutput` to cover
`diffusers` `BaseOutput`. Together with
https://github.com/huggingface/diffusers/pull/5459, it should unblock
export for `diffusers` models.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111978
Approved by: https://github.com/jansel
---
 test/dynamo/test_base_output.py  | 97 ++++++++++++++++++++++++++++++++
 torch/_dynamo/convert_frame.py   |  6 +-
 torch/_dynamo/variables/dicts.py | 82 +++++++++++++++++----------
 3 files changed, 155 insertions(+), 30 deletions(-)
 create mode 100644 test/dynamo/test_base_output.py

diff --git a/test/dynamo/test_base_output.py b/test/dynamo/test_base_output.py
new file mode 100644
index 0000000000000..0db9c7d0cfa7f
--- /dev/null
+++ b/test/dynamo/test_base_output.py
@@ -0,0 +1,97 @@
+# Owner(s): ["module: dynamo"]
+import unittest.mock
+
+import torch
+
+import torch._dynamo.test_case
+import torch._dynamo.testing
+from torch._dynamo.testing import same
+
+try:
+    from diffusers.models import unet_2d
+except ImportError:
+    unet_2d = None
+
+
+def maybe_skip(fn):
+    if unet_2d is None:
+        return unittest.skip("requires diffusers")(fn)
+    return fn
+
+
+class TestBaseOutput(torch._dynamo.test_case.TestCase):
+    @maybe_skip
+    def test_create(self):
+        def fn(a):
+            tmp = unet_2d.UNet2DOutput(a + 1)
+            return tmp
+
+        torch._dynamo.testing.standard_test(self, fn=fn, nargs=1, expected_ops=1)
+
+    @maybe_skip
+    def test_assign(self):
+        def fn(a):
+            tmp = unet_2d.UNet2DOutput(a + 1)
+            tmp.sample = a + 2
+            return tmp
+
+        args = [torch.randn(10)]
+        obj1 = fn(*args)
+
+        cnts = torch._dynamo.testing.CompileCounter()
+        opt_fn = torch._dynamo.optimize_assert(cnts)(fn)
+        obj2 = opt_fn(*args)
+        self.assertTrue(same(obj1.sample, obj2.sample))
+        self.assertEqual(cnts.frame_count, 1)
+        self.assertEqual(cnts.op_count, 2)
+
+    def _common(self, fn, op_count):
+        args = [
+            unet_2d.UNet2DOutput(
+                sample=torch.randn(10),
+            )
+        ]
+        obj1 = fn(*args)
+        cnts = torch._dynamo.testing.CompileCounter()
+        opt_fn = torch._dynamo.optimize_assert(cnts)(fn)
+        obj2 = opt_fn(*args)
+        self.assertTrue(same(obj1, obj2))
+        self.assertEqual(cnts.frame_count, 1)
+        self.assertEqual(cnts.op_count, op_count)
+
+    @maybe_skip
+    def test_getattr(self):
+        def fn(obj: unet_2d.UNet2DOutput):
+            x = obj.sample * 10
+            return x
+
+        self._common(fn, 1)
+
+    @maybe_skip
+    def test_getitem(self):
+        def fn(obj: unet_2d.UNet2DOutput):
+            x = obj["sample"] * 10
+            return x
+
+        self._common(fn, 1)
+
+    @maybe_skip
+    def test_tuple(self):
+        def fn(obj: unet_2d.UNet2DOutput):
+            a = obj.to_tuple()
+            return a[0] * 10
+
+        self._common(fn, 1)
+
+    @maybe_skip
+    def test_index(self):
+        def fn(obj: unet_2d.UNet2DOutput):
+            return obj[0] * 10
+
+        self._common(fn, 1)
+
+
+if __name__ == "__main__":
+    from torch._dynamo.test_case import run_tests
+
+    run_tests()
diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py
index 4efbb718b11db..357691d885b79 100644
--- a/torch/_dynamo/convert_frame.py
+++ b/torch/_dynamo/convert_frame.py
@@ -303,7 +303,11 @@ def _convert_frame_assert(
     ):
         return None
     if code.co_name == "<genexpr>" and code.co_filename.endswith(
-        ("transformers/file_utils.py", "transformers/utils/generic.py")
+        (
+            "transformers/file_utils.py",
+            "transformers/utils/generic.py",
+            "diffusers/utils/outputs.py",
+        )
     ):
         # not needed, but cleans up torchbench error stats
         return None
diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py
index b353b96da6890..6126caabdd8a0 100644
--- a/torch/_dynamo/variables/dicts.py
+++ b/torch/_dynamo/variables/dicts.py
@@ -437,6 +437,30 @@ def unpack_var_sequence(self, tx):
         return [x.add_options(self) for x in self.items]
 
 
+def _is_matching_transformers_cls(cls) -> bool:
+    if not cls.__module__.startswith("transformers."):
+        return False
+
+    try:
+        from transformers.file_utils import ModelOutput
+
+        return issubclass(cls, ModelOutput)
+    except ImportError:
+        return False
+
+
+def _is_matching_diffusers_cls(cls) -> bool:
+    if not cls.__module__.startswith("diffusers."):
+        return False
+
+    try:
+        from diffusers.utils import BaseOutput
+
+        return issubclass(cls, BaseOutput)
+    except ImportError:
+        return False
+
+
 class DataClassVariable(ConstDictVariable):
     """
     This is a bit of a hack to deal with
@@ -452,20 +476,27 @@ class DataClassVariable(ConstDictVariable):
     @staticmethod
     @functools.lru_cache(None)
     def _patch_once():
-        from transformers.file_utils import ModelOutput
+        try:
+            from transformers.file_utils import ModelOutput
 
-        for obj in ModelOutput.__dict__.values():
-            if callable(obj):
-                skip_code(obj.__code__)
+            for obj in ModelOutput.__dict__.values():
+                if callable(obj):
+                    skip_code(obj.__code__)
+        except ImportError:
+            pass
 
-    @staticmethod
-    def is_matching_cls(cls):
         try:
-            from transformers.file_utils import ModelOutput
+            from diffusers.utils import BaseOutput
 
-            return issubclass(cls, ModelOutput)
+            for obj in BaseOutput.__dict__.values():
+                if callable(obj):
+                    skip_code(obj.__code__)
         except ImportError:
-            return False
+            pass
+
+    @staticmethod
+    def is_matching_cls(cls):
+        return _is_matching_transformers_cls(cls) or _is_matching_diffusers_cls(cls)
 
     @classmethod
     def is_matching_object(cls, obj):
@@ -578,26 +609,19 @@ def var_getattr(self, tx, name: str) -> "VariableTracker":
 class CustomizedDictVariable(ConstDictVariable):
     @staticmethod
     def is_matching_cls(cls):
-        try:
-            # True if using default OrderedDict.__init__ and did not implement __post_init__
-            if (
-                issubclass(cls, collections.OrderedDict)
-                and cls.__init__ is collections.OrderedDict.__init__
-                and not hasattr(cls, "__post_init__")
-            ):
-                return True
-            # hack for HF usecase:
-            #   assume dataclass annotation for ModelOutput subclass
-            #   assume self.create is AA to ModelOutput.__post_init__
-            # for non-HF usecase:
-            #   check __module__ string to avoid costy HF import
-            if cls.__module__ != "transformers.modeling_outputs":
-                return False
-            from transformers.file_utils import ModelOutput
-
-            return issubclass(cls, ModelOutput)
-        except ImportError:
-            return False
+        # True if using default OrderedDict.__init__ and did not implement __post_init__
+        if (
+            issubclass(cls, collections.OrderedDict)
+            and cls.__init__ is collections.OrderedDict.__init__
+            and not hasattr(cls, "__post_init__")
+        ):
+            return True
+        # hack for HF usecase:
+        #   assume dataclass annotation for ModelOutput subclass
+        #   assume self.create is AA to ModelOutput.__post_init__
+        # for non-HF usecase:
+        #   check __module__ string to avoid costy HF import
+        return _is_matching_transformers_cls(cls) or _is_matching_diffusers_cls(cls)
 
     @classmethod
     def is_matching_object(cls, obj):
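
As a usage sketch, not part of the patch itself: after this change, constructing and reading a `diffusers` `BaseOutput` subclass inside a compiled function should trace without graph breaks, mirroring the new tests above. The `unet_2d.UNet2DOutput` import path is taken from those tests; using `torch.compile(..., fullgraph=True)` to assert break-free tracing is an assumption of this sketch.

    import torch
    from diffusers.models import unet_2d  # requires `diffusers` to be installed

    def fn(a):
        # UNet2DOutput is a BaseOutput subclass; with this patch, Dynamo traces
        # its construction, attribute access, and indexing instead of falling
        # back to eager.
        out = unet_2d.UNet2DOutput(sample=a + 1)
        return out.sample * 10 + out[0]

    # fullgraph=True makes torch.compile raise if any graph break occurs.
    compiled_fn = torch.compile(fn, fullgraph=True)
    print(compiled_fn(torch.randn(4)))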