Skip to content

Commit

Permalink
Generate guidance for flux (#2104)
Browse files Browse the repository at this point in the history
generate guidance
  • Loading branch information
IlyasMoutawwakil authored Nov 26, 2024
1 parent 65a8a94 commit a6c696c
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 2 deletions.
17 changes: 15 additions & 2 deletions optimum/onnxruntime/modeling_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,8 +437,21 @@ def to(self, device: Union[torch.device, str, int]):
def _load_config(cls, config_name_or_path: Union[str, os.PathLike], **kwargs):
return cls.load_config(config_name_or_path, **kwargs)

def _save_config(self, save_directory):
self.save_config(save_directory)
def _save_config(self, save_directory: Union[str, Path]):
model_dir = (
self.model_save_dir
if not isinstance(self.model_save_dir, TemporaryDirectory)
else self.model_save_dir.name
)
save_dir = Path(save_directory)
original_config = Path(model_dir) / self.config_name
if original_config.exists():
if not save_dir.exists():
save_dir.mkdir(parents=True)

shutil.copy(original_config, save_dir)
else:
self.save_config(save_directory)

@property
def components(self) -> Dict[str, Any]:
Expand Down
4 changes: 4 additions & 0 deletions optimum/utils/input_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -1508,6 +1508,7 @@ class DummyFluxTransformerTextInputGenerator(DummyTransformerTextInputGenerator)
SUPPORTED_INPUT_NAMES = (
"encoder_hidden_states",
"pooled_projections",
"guidance",
"txt_ids",
)

Expand All @@ -1519,5 +1520,8 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
else [self.batch_size, self.sequence_length, 3]
)
return self.random_int_tensor(shape, max_value=1, framework=framework, dtype=int_dtype)
elif input_name == "guidance":
shape = [self.batch_size]
return self.random_float_tensor(shape, min_value=0, max_value=1, framework=framework, dtype=float_dtype)

return super().generate(input_name, framework, int_dtype, float_dtype)

0 comments on commit a6c696c

Please sign in to comment.