Skip to content

Commit

Permalink
wip scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
entrpn committed Dec 4, 2023
1 parent 0f55c17 commit 111be88
Show file tree
Hide file tree
Showing 6 changed files with 448 additions and 1 deletion.
2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,7 @@
[
"FlaxDDIMScheduler",
"FlaxDDPMScheduler",
"FlaxLCMScheduler",
"FlaxDPMSolverMultistepScheduler",
"FlaxEulerDiscreteScheduler",
"FlaxKarrasVeScheduler",
Expand Down Expand Up @@ -705,6 +706,7 @@
from .schedulers import (
FlaxDDIMScheduler,
FlaxDDPMScheduler,
FlaxLCMScheduler,
FlaxDPMSolverMultistepScheduler,
FlaxEulerDiscreteScheduler,
FlaxKarrasVeScheduler,
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
"diffusers": {
"ModelMixin": ["save_pretrained", "from_pretrained"],
"SchedulerMixin": ["save_pretrained", "from_pretrained"],
"FlaxSchedulerMixin": ["save_pretrained", "from_pretrained"],
"DiffusionPipeline": ["save_pretrained", "from_pretrained"],
"OnnxRuntimeModel": ["save_pretrained", "from_pretrained"],
},
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/pipelines/stable_diffusion_xl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
_import_structure["pipeline_stable_diffusion_xl_img2img"] = ["StableDiffusionXLImg2ImgPipeline"]
_import_structure["pipeline_stable_diffusion_xl_inpaint"] = ["StableDiffusionXLInpaintPipeline"]
_import_structure["pipeline_stable_diffusion_xl_instruct_pix2pix"] = ["StableDiffusionXLInstructPix2PixPipeline"]
_import_structure["pipeline_flax_stable_diffusion_xl"] = ["FlaxStableDiffusionXLPipeline"]

if is_transformers_available() and is_flax_available():
from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
FlaxDPMSolverMultistepScheduler,
FlaxLMSDiscreteScheduler,
FlaxPNDMScheduler,
FlaxLCMScheduler
)
from ..pipeline_flax_utils import FlaxDiffusionPipeline
from .pipeline_output import FlaxStableDiffusionXLPipelineOutput
Expand All @@ -49,7 +50,7 @@ def __init__(
tokenizer_2: CLIPTokenizer,
unet: FlaxUNet2DConditionModel,
scheduler: Union[
FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler, FlaxLCMScheduler
],
dtype: jnp.dtype = jnp.float32,
):
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/schedulers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
else:
_import_structure["scheduling_ddim_flax"] = ["FlaxDDIMScheduler"]
_import_structure["scheduling_ddpm_flax"] = ["FlaxDDPMScheduler"]
_import_structure["scheduling_lcm_flax"] = ["FlaxLCMScheduler"]
_import_structure["scheduling_dpmsolver_multistep_flax"] = ["FlaxDPMSolverMultistepScheduler"]
_import_structure["scheduling_euler_discrete_flax"] = ["FlaxEulerDiscreteScheduler"]
_import_structure["scheduling_karras_ve_flax"] = ["FlaxKarrasVeScheduler"]
Expand Down Expand Up @@ -165,6 +166,7 @@
from .scheduling_ddim_flax import FlaxDDIMScheduler
from .scheduling_ddpm_flax import FlaxDDPMScheduler
from .scheduling_dpmsolver_multistep_flax import FlaxDPMSolverMultistepScheduler
from .scheduling_lcm_flax import FlaxLCMScheduler
from .scheduling_euler_discrete_flax import FlaxEulerDiscreteScheduler
from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler
from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
Expand Down
Loading

0 comments on commit 111be88

Please sign in to comment.