Make sure DDPM and diffusers can be used without Transformers #5668
Commits in this diff: f3cc528, 50cb9ce, f35bdb8, 325370f, a73de37, d91ffe9, 56ebaee
First changed file (the LoRA text-encoder helpers):

@@ -18,14 +18,15 @@
 import torch.nn.functional as F
 from torch import nn
 
-from ..loaders import PatchedLoraProjection, text_encoder_attn_modules, text_encoder_mlp_modules
 from ..utils import logging
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
 def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0):
+    from ..loaders import PatchedLoraProjection, text_encoder_attn_modules, text_encoder_mlp_modules
+
     for _, attn_module in text_encoder_attn_modules(text_encoder):
         if isinstance(attn_module.q_proj, PatchedLoraProjection):
             attn_module.q_proj.lora_scale = lora_scale

Review comment on the removed module-level ..loaders import: Why do we need to do this change? Don't think it's necessary, no?
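This change is what the reviewer's question above is about: moving the ..loaders import from module scope into the function body defers it until adjust_lora_scale_text_encoder is actually called, so importing the module no longer drags in the Transformers-dependent loader helpers. A minimal sketch of the deferred-import pattern, using a hypothetical heavy_dep package rather than anything from this PR:

def scale_adapters(model, scale=1.0):
    # Deferred import: heavy_dep is only required if this function runs,
    # so importing the enclosing module succeeds even when heavy_dep is absent.
    import heavy_dep  # hypothetical optional dependency

    return heavy_dep.apply_scale(model, scale)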
Second changed file (the latent consistency models pipeline __init__):

@@ -1,19 +1,40 @@
 from typing import TYPE_CHECKING
 
 from ...utils import (
     DIFFUSERS_SLOW_IMPORT,
+    OptionalDependencyNotAvailable,
     _LazyModule,
+    get_objects_from_module,
+    is_torch_available,
+    is_transformers_available,
 )
 
 
-_import_structure = {
-    "pipeline_latent_consistency_img2img": ["LatentConsistencyModelImg2ImgPipeline"],
-    "pipeline_latent_consistency_text2img": ["LatentConsistencyModelPipeline"],
-}
+_dummy_objects = {}
+_import_structure = {}
 
 
-if TYPE_CHECKING:
-    from .pipeline_latent_consistency_img2img import LatentConsistencyModelImg2ImgPipeline
-    from .pipeline_latent_consistency_text2img import LatentConsistencyModelPipeline
+try:
+    if not (is_transformers_available() and is_torch_available()):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from ...utils import dummy_torch_and_transformers_objects  # noqa F403
+
+    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+    _import_structure["pipeline_latent_consistency_img2img"] = ["LatentConsistencyModelImg2ImgPipeline"]
+    _import_structure["pipeline_latent_consistency_text2img"] = ["LatentConsistencyModelPipeline"]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+    try:
+        if not (is_transformers_available() and is_torch_available()):
+            raise OptionalDependencyNotAvailable()
+
+    except OptionalDependencyNotAvailable:
+        from ...utils.dummy_torch_and_transformers_objects import *
+    else:
+        from .pipeline_latent_consistency_img2img import LatentConsistencyModelImg2ImgPipeline
+        from .pipeline_latent_consistency_text2img import LatentConsistencyModelPipeline
+
 else:
     import sys
 
@@ -24,3 +45,6 @@
         _import_structure,
         module_spec=__spec__,
     )
+
+    for name, value in _dummy_objects.items():
+        setattr(sys.modules[__name__], name, value)

Review comment on the expanded ...utils import block: Very nice clean-up!
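The net effect of this restructuring: importing the package succeeds whether or not Transformers is installed. When it is missing, _dummy_objects registers stand-ins under the real pipeline names, and any attempt to use them raises an informative error at instantiation time instead of an ImportError at import time. A simplified sketch of how such a stand-in behaves (the real implementation in diffusers.utils uses a DummyObject metaclass and a shared requires_backends helper; this is only an approximation):

class LatentConsistencyModelPipeline:
    # Simplified stand-in registered when transformers is not installed.
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        # Fail loudly at use time, not at import time.
        raise ImportError(
            "LatentConsistencyModelPipeline requires the torch and transformers "
            "libraries, but they were not found in your environment."
        )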
Third changed file (the PixArt-Alpha pipeline __init__):

@@ -1 +1,48 @@
-from .pipeline_pixart_alpha import PixArtAlphaPipeline
+from typing import TYPE_CHECKING
+
+from ...utils import (
+    DIFFUSERS_SLOW_IMPORT,
+    OptionalDependencyNotAvailable,
+    _LazyModule,
+    get_objects_from_module,
+    is_torch_available,
+    is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+
+try:
+    if not (is_transformers_available() and is_torch_available()):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from ...utils import dummy_torch_and_transformers_objects  # noqa F403
+
+    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+    _import_structure["pipeline_pixart_alpha"] = ["PixArtAlphaPipeline"]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+    try:
+        if not (is_transformers_available() and is_torch_available()):
+            raise OptionalDependencyNotAvailable()
+
+    except OptionalDependencyNotAvailable:
+        from ...utils.dummy_torch_and_transformers_objects import *
+    else:
+        from .pipeline_pixart_alpha import PixArtAlphaPipeline
+
+else:
+    import sys
+
+    sys.modules[__name__] = _LazyModule(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+        module_spec=__spec__,
+    )
+
+    for name, value in _dummy_objects.items():
+        setattr(sys.modules[__name__], name, value)

Review comment on the ...utils import block: Very nice clean-up!
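With all three files guarded, the Transformers-free path can be smoke-tested in an environment where only torch and diffusers are installed. A sketch of such a check (google/ddpm-cat-256 is a standard public DDPM checkpoint, not something referenced in this PR; substitute any DDPM checkpoint):

# Run with transformers uninstalled: both imports should still succeed.
from diffusers import DDPMPipeline, DDPMScheduler

pipe = DDPMPipeline.from_pretrained("google/ddpm-cat-256")
image = pipe(num_inference_steps=25).images[0]
image.save("ddpm_sample.png")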
Review comment: Nice! Going forward, let's make sure we always use strings for Transformers type hints.
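For context on that last comment: quoting the annotation keeps Transformers out of the runtime import path, because string annotations are never evaluated unless a type checker asks for them. A sketch of the convention (the encode_prompt helper here is illustrative, not code from this PR):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers, never at runtime.
    from transformers import CLIPTextModel, CLIPTokenizer


def encode_prompt(text_encoder: "CLIPTextModel", tokenizer: "CLIPTokenizer", prompt: str):
    # String annotations above mean this module imports fine
    # even when transformers is absent.
    inputs = tokenizer(prompt, return_tensors="pt")
    return text_encoder(**inputs)[0]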