From a6c696c7de105e7691d432dd80102beec78d8fd4 Mon Sep 17 00:00:00 2001 From: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> Date: Tue, 26 Nov 2024 20:52:43 +0100 Subject: [PATCH] Generate guidance for flux (#2104) generate guidance --- optimum/onnxruntime/modeling_diffusion.py | 17 +++++++++++++++-- optimum/utils/input_generators.py | 4 ++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/optimum/onnxruntime/modeling_diffusion.py b/optimum/onnxruntime/modeling_diffusion.py index 79d302be449..66b08e1ef66 100644 --- a/optimum/onnxruntime/modeling_diffusion.py +++ b/optimum/onnxruntime/modeling_diffusion.py @@ -437,8 +437,21 @@ def to(self, device: Union[torch.device, str, int]): def _load_config(cls, config_name_or_path: Union[str, os.PathLike], **kwargs): return cls.load_config(config_name_or_path, **kwargs) - def _save_config(self, save_directory): - self.save_config(save_directory) + def _save_config(self, save_directory: Union[str, Path]): + model_dir = ( + self.model_save_dir + if not isinstance(self.model_save_dir, TemporaryDirectory) + else self.model_save_dir.name + ) + save_dir = Path(save_directory) + original_config = Path(model_dir) / self.config_name + if original_config.exists(): + if not save_dir.exists(): + save_dir.mkdir(parents=True) + + shutil.copy(original_config, save_dir) + else: + self.save_config(save_directory) @property def components(self) -> Dict[str, Any]: diff --git a/optimum/utils/input_generators.py b/optimum/utils/input_generators.py index 0ac1805f97d..fbb77e6800a 100644 --- a/optimum/utils/input_generators.py +++ b/optimum/utils/input_generators.py @@ -1508,6 +1508,7 @@ class DummyFluxTransformerTextInputGenerator(DummyTransformerTextInputGenerator) SUPPORTED_INPUT_NAMES = ( "encoder_hidden_states", "pooled_projections", + "guidance", "txt_ids", ) @@ -1519,5 +1520,8 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int else [self.batch_size, self.sequence_length, 3] ) return self.random_int_tensor(shape, max_value=1, framework=framework, dtype=int_dtype) + elif input_name == "guidance": + shape = [self.batch_size] + return self.random_float_tensor(shape, min_value=0, max_value=1, framework=framework, dtype=float_dtype) return super().generate(input_name, framework, int_dtype, float_dtype)