Skip to content

Commit

Permalink
Merge pull request #3 from nod-ai/0.29.0.dev0-shark-fix
Browse files Browse the repository at this point in the history
Adapt to shark_turbine -> iree.turbine rename
  • Loading branch information
saienduri authored Oct 10, 2024
2 parents bfce212 + 8d6b683 commit fca148c
Show file tree
Hide file tree
Showing 6 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/diffusers/models/attention_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from ..utils.import_utils import is_torch_npu_available, is_xformers_available
from ..utils.torch_utils import maybe_allow_in_graph
from .lora import LoRALinearLayer
#from shark_turbine.ops.iree import trace_tensor
#from iree.turbine.ops.iree import trace_tensor


logger = logging.get_logger(__name__) # pylint: disable=invalid-name
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/models/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from .activations import FP32SiLU, get_activation
from .attention_processor import Attention

from shark_turbine.ops.iree import trace_tensor
from iree.turbine.ops.iree import trace_tensor


def get_timestep_embedding(
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/models/unets/unet_2d_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
)
from ..transformers.dual_transformer_2d import DualTransformer2DModel
from ..transformers.transformer_2d import Transformer2DModel
import shark_turbine.ops.iree as iree_ops
import iree.turbine.ops.iree as iree_ops


logger = logging.get_logger(__name__) # pylint: disable=invalid-name
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/models/unets/unet_2d_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
get_mid_block,
get_up_block,
)
import shark_turbine.ops.iree as iree_ops
import iree.turbine.ops.iree as iree_ops


logger = logging.get_logger(__name__) # pylint: disable=invalid-name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from ..utils import BaseOutput, logging
from ..utils.torch_utils import randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
#from shark_turbine.ops.iree import trace_tensor
#from iree.turbine.ops.iree import trace_tensor

logger = logging.get_logger(__name__) # pylint: disable=invalid-name

Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/schedulers/scheduling_pndm.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
#from shark_turbine.ops.iree import trace_tensor
#from iree.turbine.ops.iree import trace_tensor

# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
Expand Down

0 comments on commit fca148c

Please sign in to comment.