Flax pipeline pndm #583

Merged Sep 27, 2022 (30 commits)
Commits
b9ca406  WIP: flax FlaxDiffusionPipeline & FlaxStableDiffusionPipeline (mishig25, Sep 19, 2022)
30abc63  todo comment (mishig25, Sep 19, 2022)
9b54559  Merge branch 'main' into flax_pipeline (Sep 19, 2022)
4b2becb  Fix imports (mishig25, Sep 19, 2022)
7f0e429  Fix imports (mishig25, Sep 19, 2022)
d9e2ae1  add dummies (patrickvonplaten, Sep 19, 2022)
d51e881  Fix empty init (mishig25, Sep 19, 2022)
741046d  Merge branch 'flax_pipeline' of https://github.com/huggingface/diffus… (mishig25, Sep 19, 2022)
7aab68d  make pipeline work (patrickvonplaten, Sep 19, 2022)
7d3fff6  merge conflict (patrickvonplaten, Sep 19, 2022)
47d7739  up (patrickvonplaten, Sep 19, 2022)
4dfcf21  Allow dtype to be overridden on model load. (pcuenca, Sep 20, 2022)
d480534  Convert params to bfloat16 or fp16 after loading. (pcuenca, Sep 20, 2022)
0c2a868  Use Flax schedulers (typing, docstring) (pcuenca, Sep 20, 2022)
a71e6be  Merge branch 'flax_pipeline' into flax_pipeline_bf16 (pcuenca, Sep 20, 2022)
aa3c010  PNDM: replace control flow with jax functions. (pcuenca, Sep 19, 2022)
d6dbb89  Pass latents shape to scheduler set_timesteps() (pcuenca, Sep 20, 2022)
69b1d7a  Wrap model imports inside availability checks. (pcuenca, Sep 20, 2022)
7091c1d  Merge branch 'flax_pipeline' into flax_pipeline_pndm (pcuenca, Sep 20, 2022)
23f7d73  Optionally return state in from_config. (pcuenca, Sep 20, 2022)
163df38  Do not convert model weights to dtype. (pcuenca, Sep 20, 2022)
039d1d6  Merge branch 'flax_pipeline_bf16' into flax_pipeline_pndm (pcuenca, Sep 20, 2022)
8bc06b0  Re-enable PRK steps with functional implementation. (pcuenca, Sep 20, 2022)
3752bbc  Merge remote-tracking branch 'origin/main' into flax_pipeline_pndm (pcuenca, Sep 21, 2022)
8a9ccf2  Remove left over has_state var. (pcuenca, Sep 21, 2022)
cf6cd7a  make style (pcuenca, Sep 21, 2022)
f974a41  Apply suggestion list -> tuple (pcuenca, Sep 22, 2022)
ce0a327  Apply suggestion list -> tuple (pcuenca, Sep 22, 2022)
7fcbc32  Remove unused comments. (pcuenca, Sep 22, 2022)
cd17c56  Use zeros instead of empty. (pcuenca, Sep 22, 2022)
6 changes: 6 additions & 0 deletions src/diffusers/__init__.py
@@ -66,6 +66,7 @@
     from .modeling_flax_utils import FlaxModelMixin
     from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
     from .models.vae_flax import FlaxAutoencoderKL
+    from .pipeline_flax_utils import FlaxDiffusionPipeline
     from .schedulers import (
         FlaxDDIMScheduler,
         FlaxDDPMScheduler,
@@ -76,3 +77,8 @@
     )
 else:
     from .utils.dummy_flax_objects import *  # noqa F403
+
+if is_flax_available() and is_transformers_available():
+    from .pipelines import FlaxStableDiffusionPipeline
+else:
+    from .utils.dummy_flax_and_transformers_objects import *  # noqa F403
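
For context, the `else` branches fall back to dummy modules whose job is to keep `import diffusers` working when a backend is missing, failing only when the unavailable class is actually used. A minimal sketch of that fallback behavior (an illustrative stand-in, not the repository's generated code, which is typically auto-generated):

# Hypothetical stand-in for an entry in dummy_flax_and_transformers_objects.py.
# Importing the name always succeeds; instantiating it raises a clear error
# about the missing backends instead of breaking `import diffusers`.
class FlaxStableDiffusionPipeline:
    def __init__(self, *args, **kwargs):
        raise ImportError(
            "FlaxStableDiffusionPipeline requires the `flax` and `transformers` "
            "libraries, but at least one of them is not installed."
        )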
16 changes: 13 additions & 3 deletions src/diffusers/configuration_utils.py
@@ -154,15 +154,25 @@ def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], ret

"""
config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)

init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)

# Allow dtype to be specified on initialization
if "dtype" in unused_kwargs:
init_dict["dtype"] = unused_kwargs.pop("dtype")

# Return model and optionally state and/or unused_kwargs
model = cls(**init_dict)
return_tuple = (model,)

# Some components (Flax schedulers) have a state.
if getattr(cls, "has_state", False): # Check for "create_state" in model instead?
state = model.create_state()
return_tuple += (state,)
pcuenca marked this conversation as resolved.
Show resolved Hide resolved

if return_unused_kwargs:
return model, unused_kwargs
return return_tuple + (unused_kwargs,)
else:
return model
return return_tuple if len(return_tuple) > 1 else model

     @classmethod
     def get_config_dict(
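
Taken together, the `from_config` changes mean stateless components still return a single object (plus `unused_kwargs` when requested), while components that set `has_state` (the Flax schedulers) now come back with a state freshly created by `create_state()`; the new `dtype` branch similarly lets callers pass `dtype` straight through `from_config`. A usage sketch under those assumptions (the checkpoint id and `subfolder` argument are illustrative):

from diffusers import FlaxPNDMScheduler

# Flax schedulers keep their mutable step data in an explicit state object,
# so from_config hands back both pieces for JAX code to thread through.
scheduler, scheduler_state = FlaxPNDMScheduler.from_config(
    "CompVis/stable-diffusion-v1-4", subfolder="scheduler"
)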
13 changes: 10 additions & 3 deletions src/diffusers/models/__init__.py
@@ -12,6 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .unet_2d import UNet2DModel
-from .unet_2d_condition import UNet2DConditionModel
-from .vae import AutoencoderKL, VQModel
+from ..utils import is_torch_available, is_flax_available
+
+if is_torch_available():
+    from .unet_2d import UNet2DModel
+    from .unet_2d_condition import UNet2DConditionModel
+    from .vae import AutoencoderKL, VQModel
+
+if is_flax_available():
+    from .unet_2d_condition_flax import FlaxUNet2DConditionModel
+    from .vae_flax import FlaxAutoencoderKL
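
Gating the model imports this way means `import diffusers` succeeds with only one of the two backends installed. A small sketch of the same guard as it might look in downstream code (using the availability helpers imported in the diff above):

from diffusers.utils import is_flax_available, is_torch_available

# Select whichever UNet variant the current environment provides.
if is_flax_available():
    from diffusers import FlaxUNet2DConditionModel as UNet2D
elif is_torch_available():
    from diffusers import UNet2DConditionModel as UNet2D
else:
    raise ImportError("Install JAX/Flax or PyTorch to use diffusers models.")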