diff --git a/README.md b/README.md index 9a6403cdacb..9a81e69e126 100644 --- a/README.md +++ b/README.md @@ -268,3 +268,34 @@ You can find more examples in the [documentation](https://huggingface.co/docs/op ``` You can find more examples in the [documentation](https://huggingface.co/docs/optimum/onnxruntime/usage_guides/trainer) and in the [examples](https://github.com/huggingface/optimum/tree/main/examples/onnxruntime/training). + + +### Quanto + +[Quanto](https://github.com/huggingface/optimum-quanto) is a pytorch quantization backend. + +You can quantize a model either using the python API or the `optimum-cli`. + +```python +from transformers import AutoModelForCausalLM +from optimum.quanto import QuantizedModelForCausalLM, qint4 + +model = AutoModelForCausalLM.from_pretrained('meta-llama/Meta-Llama-3.1-8B') +qmodel = QuantizedModelForCausalLM.quantize(model, weights=qint4, exclude='lm_head') +``` + +The quantized model can be saved using `save_pretrained`: + +```python +qmodel.save_pretrained('./Llama-3.1-8B-quantized') +``` + +It can later be reloaded using `from_pretrained`: + +```python +from optimum.quanto import QuantizedModelForCausalLM + +qmodel = QuantizedModelForCausalLM.from_pretrained('Llama-3.1-8B-quantized') +``` + +You can see more details and [examples](https://github.com/huggingface/optimum-quanto/tree/main/examples) in the [Quanto](https://github.com/huggingface/optimum-quanto) repository. diff --git a/docs/source/bettertransformer/overview.mdx b/docs/source/bettertransformer/overview.mdx index 3d575c93c25..1a525ba4c8b 100644 --- a/docs/source/bettertransformer/overview.mdx +++ b/docs/source/bettertransformer/overview.mdx @@ -24,7 +24,7 @@ In the 2.0 version, PyTorch includes a native scaled dot-product attention opera We provide an integration with these optimizations out of the box in 🤗 Optimum, so that you can convert any supported 🤗 Transformers model so as to use the optimized paths & `scaled_dot_product_attention` function when relevant. -PyTorch-native `scaled_dot_product_attention` is slowly being natively [made default and integrated in 🤗 Transformers](https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-and-memory-efficient-attention-through-pytorchs-scaleddotproductattention). For models that do support SDPA in Transformers, we deprecate BetterTransformer and recommend you to use directly Transformers and PyTorc latest version for the attention optimizations (Flash Attention, memory-efficient attention) through SDPA. +PyTorch-native `scaled_dot_product_attention` is slowly being natively [made default and integrated in 🤗 Transformers](https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-and-memory-efficient-attention-through-pytorchs-scaleddotproductattention). For models that do support SDPA in Transformers, we deprecate BetterTransformer and recommend you to use directly Transformers and PyTorch latest version for the attention optimizations (Flash Attention, memory-efficient attention) through SDPA. diff --git a/examples/onnxruntime/training/image-classification/README.md b/examples/onnxruntime/training/image-classification/README.md index bf4bed8ee43..967942e7a93 100644 --- a/examples/onnxruntime/training/image-classification/README.md +++ b/examples/onnxruntime/training/image-classification/README.md @@ -39,7 +39,7 @@ torchrun --nproc_per_node=NUM_GPUS_YOU_HAVE run_image_classification.py \ --per_device_eval_batch_size 32 \ --logging_strategy steps \ --logging_steps 10 \ - --evaluation_strategy epoch \ + --eval_strategy epoch \ --seed 1337 ``` diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index 97053040879..a489f34fb06 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -308,9 +308,9 @@ class TasksManager: "image-feature-extraction": "feature-extraction", # for backward compatibility and testing (where # model task and model type are still the same) - "lcm": "text-to-image", "stable-diffusion": "text-to-image", "stable-diffusion-xl": "text-to-image", + "latent-consistency": "text-to-image", } _CUSTOM_CLASSES = { diff --git a/optimum/fx/parallelization/backend/base.py b/optimum/fx/parallelization/backend/base.py index 9b9ad782dd4..11cdd766065 100644 --- a/optimum/fx/parallelization/backend/base.py +++ b/optimum/fx/parallelization/backend/base.py @@ -13,16 +13,23 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC, abstractmethod -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import torch import torch.distributed as dist import torch.nn as nn +import torch.nn.functional as F from torch.fx import GraphModule from ..core import Config, ParallelExecutionCtx, ParameterMeta from ..distributed import scatter -from ..parallel_layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding +from ..parallel_layers import ( + ColumnParallelLinear, + RowParallelLinear, + VocabParallelEmbedding, + VocabParallelCrossEntropyLoss, + sharded_cross_entropy_wrapper_fn, +) from ..passes import ( ParallelAxisSolverPass, ParallelLayerAnnotatePass, @@ -64,6 +71,17 @@ def create_parallel_embedding( ) -> nn.Module: raise NotImplementedError + @abstractmethod + def create_parallel_cross_entropy( + self, + mod_or_fn: Union[nn.CrossEntropyLoss, F.cross_entropy], + parallel_ctx: "ParallelExecutionCtx", + ): + if isinstance(mod_or_fn, nn.CrossEntropyLoss): + return VocabParallelCrossEntropyLoss(ctx=parallel_ctx, reduction=mod_or_fn.reduction) + else: + return sharded_cross_entropy_wrapper_fn(process_group=parallel_ctx.tp_group) + def pre_process(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", config: "Config") -> GraphModule: """ Mark tie information right before we run passes because dynamo tracing will alter the parameter name while our diff --git a/optimum/fx/parallelization/decomp.py b/optimum/fx/parallelization/decomp.py index 26258d451bf..5410818e929 100644 --- a/optimum/fx/parallelization/decomp.py +++ b/optimum/fx/parallelization/decomp.py @@ -197,7 +197,7 @@ def run(self, *args, **kwargs): def decompose_and_functionalize( graph_module: GraphModule, decomposition_table: Dict[torch._ops.OperatorBase, Callable] = core_aten_decompositions(), - leaf_function_targets: List[Callable] = [F.scaled_dot_product_attention], + leaf_function_targets: List[Callable] = [F.scaled_dot_product_attention, F.cross_entropy], ) -> Callable: """ API to decompose and functionalize a high-level graph module. diff --git a/optimum/fx/parallelization/op_registry/op_handlers.py b/optimum/fx/parallelization/op_registry/op_handlers.py index 56b8fc16bc0..4a9c55e3764 100644 --- a/optimum/fx/parallelization/op_registry/op_handlers.py +++ b/optimum/fx/parallelization/op_registry/op_handlers.py @@ -19,7 +19,7 @@ from torch.fx import Node from ..core import Config -from ..utils import is_activation, is_embedding, is_linear +from ..utils import is_activation, is_cross_entropy, is_cross_entropy_parallel_compatible, is_embedding, is_linear class Registry: @@ -334,7 +334,16 @@ def propagate(self) -> List[int]: ndim = arg.meta["val"].ndim slice_dim = (slice_dim + ndim) % ndim if slice_dim == axis: - # slice on the parallel axis is not allowed + # slice on the parallel axis is not allowed, except it's a nop + start, stop, step = 0, arg.meta["val"].shape[axis], 1 + if len(self.node.args) > 2: + start = self.node.args[2] + elif len(self.node.args) > 3: + stop = self.node.args[3] + elif len(self.node.args) > 4: + step = self.node.args[4] + if start == 0 and stop >= arg.meta["val"].shape[axis] and step == 1: + return [axis] return [] return [axis] @@ -404,12 +413,12 @@ def propagate(self) -> List[int]: if self.node.op in ["placeholder", "get_attr"]: return [None] elif self.node.op == "output": - for node in self.node.all_input_nodes: - # TODO: allow parallelized nodes in output, and append comm ops in graph tp all-gather - # parallelized output if intructed - if self.extract_axis(node) is not None: - return [] - return [None] + # does not care about if output is being parallelized right now, because if the output is loss, + # then it must be not parallelized as long as it comes from sharded cross entropy. + # TODO: append all-gather comm ops before all parallelized output nodes if instructed. + input_arg = self.node.all_input_nodes[0] + axis = self.extract_axis(input_arg) + return [axis] elif is_linear(self.node): input_arg = self.node.all_input_nodes[0] axis = self.extract_axis(input_arg) @@ -438,6 +447,16 @@ def propagate(self) -> List[int]: return [1, None] if self.config.enable_sequence_parallel else [None] else: return [] + elif is_cross_entropy(self.node): + logits = self.node.all_input_nodes[0] + axis = self.extract_axis(logits) + if axis is None or ( + is_cross_entropy_parallel_compatible(self.node) and axis == logits.meta["val"].ndim - 1 + ): + # for cross entropy, the input logits parallel axis can only be the last axis or None + return [None] + else: + return [] elif is_activation(self.node): return UnaryOpParallelAxisPropagateHandler(self.node, self.meta_key, self.config).propagate() diff --git a/optimum/fx/parallelization/parallel_layers/__init__.py b/optimum/fx/parallelization/parallel_layers/__init__.py index 9bfb13afdf6..474ae7f7eef 100644 --- a/optimum/fx/parallelization/parallel_layers/__init__.py +++ b/optimum/fx/parallelization/parallel_layers/__init__.py @@ -14,3 +14,4 @@ # limitations under the License. from .embedding import VocabParallelEmbedding from .linear import ColumnParallelLinear, RowParallelLinear +from .loss import VocabParallelCrossEntropyLoss, sharded_cross_entropy_wrapper_fn diff --git a/optimum/fx/parallelization/parallel_layers/loss.py b/optimum/fx/parallelization/parallel_layers/loss.py new file mode 100644 index 00000000000..0a11e33c08e --- /dev/null +++ b/optimum/fx/parallelization/parallel_layers/loss.py @@ -0,0 +1,163 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import wraps +from typing import Optional + +import torch +import torch.distributed as dist +import torch.nn as nn + +from ..core import ParallelExecutionCtx + + +# Adapted from https://github.com/huggingface/nanotron/blob/main/src/nanotron/parallel/tensor_parallel/functional.py +class _ShardedCrossEntropy(torch.autograd.Function): + @staticmethod + def forward( + ctx, + sharded_logits: torch.Tensor, # (batch_size, length, sharded_hidden_size) + target: torch.Tensor, # (batch_size, length) + group: dist.ProcessGroup, + ): + # Maximum value along last dimension across all GPUs. + logits_max = torch.max(sharded_logits, dim=-1)[0] + dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=group) + # Subtract the maximum value. + sharded_logits = sharded_logits - logits_max.unsqueeze(dim=-1) + + # Get the shard's indices + sharded_hidden_size = sharded_logits.shape[-1] + rank = dist.get_rank(group) + start_index = rank * sharded_hidden_size + end_index = start_index + sharded_hidden_size + + # Create a mask of valid ids (1 means it needs to be masked). + target_mask = (target < start_index) | (target >= end_index) + masked_target = target.clone() - start_index + masked_target[target_mask] = 0 + + # Get predicted-logits = logits[target]. + # For Simplicity, we convert logits to a 2-D tensor with size + # [*, shard-size] and target to a 1-D tensor of size [*]. + logits_2d = sharded_logits.view(-1, sharded_hidden_size) + masked_target_1d = masked_target.view(-1) + arange_1d = torch.arange(start=0, end=logits_2d.shape[0], device=logits_2d.device) + predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] + if predicted_logits_1d.is_contiguous(): + predicted_logits_1d = predicted_logits_1d.clone() + else: + predicted_logits_1d = predicted_logits_1d.contiguous() + predicted_logits = predicted_logits_1d.view_as(target) + predicted_logits[target_mask] = 0.0 + # All reduce is needed to get the chunks from other GPUs. + dist.all_reduce(predicted_logits, op=dist.ReduceOp.SUM, group=group) + + # Sum of exponential of logits along vocab dimension across all GPUs. + exp_logits = sharded_logits + torch.exp(sharded_logits, out=exp_logits) + sum_exp_logits = exp_logits.sum(dim=-1) + dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=group) + + # Loss = log(sum(exp(logits))) - predicted-logit. + loss = torch.log(sum_exp_logits) - predicted_logits + + # Normalize and optionally smooth logits + exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) + + # Store softmax, target-mask and masked-target for backward pass. + ctx.save_for_backward(exp_logits, target_mask, masked_target_1d) + + return loss + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + # Retrieve tensors from the forward path. + softmax, target_mask, masked_target_1d = ctx.saved_tensors + + # All the inputs have softmax as their gradient. + grad_input = softmax + # For simplicity, work with the 2D gradient. + sharded_hidden_size = softmax.size()[-1] + grad_2d = grad_input.view(-1, sharded_hidden_size) + + # Add the gradient from matching classes. + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device) + grad_2d[arange_1d, masked_target_1d] -= 1.0 - target_mask.view(-1).float() + + # Finally elementwise multiplication with the output gradients. + grad_input.mul_(grad_output.unsqueeze(dim=-1)) + + return grad_input, None, None + + +def sharded_cross_entropy(sharded_logits: torch.Tensor, target: torch.Tensor, process_group: dist.ProcessGroup): + return _ShardedCrossEntropy.apply(sharded_logits, target, process_group) + + +def sharded_cross_entropy_wrapper_fn(process_group: dist.ProcessGroup): + @wraps(sharded_cross_entropy) + def wrapper( + sharded_logits: torch.Tensor, + target: torch.Tensor, + weight: Optional[torch.Tensor] = None, + size_average: Optional[bool] = None, + ignore_index: int = -100, + reduce: Optional[bool] = None, + reduction: str = "mean", + label_smoothing: float = 0.0, + ): + if weight is not None or ignore_index != -100 or label_smoothing != 0.0: + raise ValueError( + "Does not support weighted mode, index ignoring and label smoothing in current parallel cross entropy implementation." + ) + loss: torch.Tensor = sharded_cross_entropy(sharded_logits, target, process_group) + + if size_average is not None or reduce is not None: + size_average = True if size_average is None else size_average + reduce = True if reduce is None else reduce + + if size_average and reduce: + reduction = "mean" + elif reduce: + reduction = "sum" + else: + reduction = "none" + + if reduction == "mean": + return loss.mean() + elif reduction == "sum": + return loss.sum() + return loss + + return wrapper + + +class VocabParallelCrossEntropyLoss(nn.Module): + """ + Simple parallel cross entropy implementation which does not support weighted mode and label smoothing yet. + """ + + def __init__(self, ctx: ParallelExecutionCtx, reduction: str = "mean") -> None: + super(VocabParallelCrossEntropyLoss, self).__init__() + self.process_group = ctx.tp_group + self.reduction = reduction + + def forward(self, sharded_logits: torch.Tensor, target: torch.Tensor): + loss: torch.Tensor = _ShardedCrossEntropy.apply(sharded_logits, target, self.process_group) + if self.reduction == "mean": + return loss.mean() + elif self.reduction == "sum": + return loss.sum() + return loss diff --git a/optimum/fx/parallelization/passes.py b/optimum/fx/parallelization/passes.py index 18c17fb26b7..b547d5ef148 100644 --- a/optimum/fx/parallelization/passes.py +++ b/optimum/fx/parallelization/passes.py @@ -27,6 +27,7 @@ from .op_registry import REGISTRY, FallbackParallelAxisPropagateHandler from .utils import ( ensure_divisibility, + is_cross_entropy, is_embedding, is_linear, is_shape_consumer, @@ -276,6 +277,11 @@ def run(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", config: "C info["sequence_parallel"] = False self.place_marker_per_node(node, info) + elif is_cross_entropy(node): + axis_before = ParallelAxisSolverPass.get_stored_field_info(node.args[0], "parallel_axis") + if axis_before is not None: + self.place_marker_per_node(node, {"axis": "vocab"}) + return graph_module @@ -383,7 +389,38 @@ def handle_embedding(self, node: Node, ctx: "ParallelExecutionCtx") -> None: layer_cache[key] = new_mod setattr(parent_mod, field, new_mod) - def handle_hard_coded_axis_param(self, node: Node, ctx: "ParallelExecutionCtx") -> None: + @staticmethod + def handle_cross_entropy(node: Node, ctx: "ParallelExecutionCtx") -> None: + axis = ParallelLayerAnnotatePass.get_stored_field_info(node, field="axis") + if axis is None: + return + + assert axis in {"vocab"}, "Only support parallelization on vocab dim for now." + backend = ctx.backend + if node.op == "call_module": + graph_module = node.graph.owning_module + prefix_and_field = node.target.rsplit(".", maxsplit=1) + if len(prefix_and_field) == 2: + parent_mod = graph_module.get_submodule(prefix_and_field[0]) + field = prefix_and_field[1] + else: + parent_mod = graph_module + field = node.target + + mod: nn.CrossEntropyLoss = graph_module.get_submodule(node.target) + key, layer_cache = node.target, ctx.parallel_layer_cache + if key in layer_cache: + new_mod = layer_cache[key] + else: + assert ctx.compile_times == 0, "illegal path for recompilation" + new_mod = backend.create_parallel_cross_entropy(mod, ctx) + layer_cache[key] = new_mod + setattr(parent_mod, field, new_mod) + else: + node.target = backend.create_parallel_cross_entropy(node.target, ctx) + + @staticmethod + def handle_hard_coded_axis_param(node: Node, ctx: "ParallelExecutionCtx") -> None: def extract_shape_from_node(node: Node) -> List[Any]: if "size" in node.kwargs: return list(node.kwargs["size"]) @@ -423,6 +460,8 @@ def run(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", config: "C self.handle_linear(node, ctx) elif is_embedding(node): self.handle_embedding(node, ctx) + elif is_cross_entropy(node): + self.handle_cross_entropy(node, ctx) # correct the attention head num in parallel setting elif is_shape_consumer(node): self.handle_hard_coded_axis_param(node, ctx) diff --git a/optimum/fx/parallelization/utils.py b/optimum/fx/parallelization/utils.py index e4852bde90f..2590430df9c 100644 --- a/optimum/fx/parallelization/utils.py +++ b/optimum/fx/parallelization/utils.py @@ -82,6 +82,40 @@ def is_shape_generator(node: Node) -> bool: return node.op == "call_method" and node.target == "size" +def is_cross_entropy(node: Node) -> bool: + if node.op == "call_function": + return node.target is F.cross_entropy + elif node.op == "call_module": + mod = node.graph.owning_module + return isinstance(mod.get_submodule(node.target), nn.CrossEntropyLoss) + return False + + +def is_cross_entropy_parallel_compatible(node: Node) -> bool: + """ + For now `VocabParallelCrossEntropyLoss` does not support weighted mode, index ignoring and label smoothing. + """ + if node.op == "call_function": + weight = node.kwargs.get("weight", None) + ignore_index = node.kwargs.get("ignore_index", -100) + label_smoothing = node.kwargs.get("label_smoothing", 0.0) + if len(node.args) > 2 and weight is None: + weight = node.args[2] + if len(node.args) > 4 and ignore_index == -100: + ignore_index = node.args[4] + if len(node.args) > 7 and label_smoothing == 0.0: + label_smoothing = node.args[7] + + return weight is None and ignore_index == -100 and label_smoothing == 0.0 + + elif node.op == "call_module": + mod: nn.CrossEntropyLoss = node.graph.owning_module.get_submodule(node.target) + weight, label_smoothing, ignore_index = mod.weight, mod.label_smoothing, mod.ignore_index + return weight is None and ignore_index == -100 and label_smoothing == 0.0 + + return False + + def stable_topological_sort(graph: Graph): def _args(n: torch.fx.Node) -> List[torch.fx.node.Argument]: args: List[torch.fx.node.Argument] = [] diff --git a/optimum/gptq/quantizer.py b/optimum/gptq/quantizer.py index 902af87bbb0..949d4d260df 100644 --- a/optimum/gptq/quantizer.py +++ b/optimum/gptq/quantizer.py @@ -546,7 +546,7 @@ def tmp(_, input, output): if self.bits == 4: # device not on gpu - if device == torch.device("cpu") or (has_device_map and any(d in devices for d in ["cpu", "disk"])): + if device.type != "cuda" or (has_device_map and any(d in devices for d in ["cpu", "disk", "hpu"])): if not self.disable_exllama: logger.warning( "Found modules on cpu/disk. Using Exllama/Exllamav2 backend requires all the modules to be on GPU. Setting `disable_exllama=True`" @@ -589,13 +589,14 @@ def post_init_model(self, model): The input model """ if self.bits == 4 and not self.disable_exllama: - if get_device(model) == torch.device("cpu") or ( - hasattr(model, "hf_device_map") and any(d in model.hf_device_map for d in ["cpu", "disk"]) + if get_device(model).type != "cuda" or ( + hasattr(model, "hf_device_map") and any(d in model.hf_device_map for d in ["cpu", "disk", "hpu"]) ): - raise ValueError( - "Found modules on cpu/disk. Using Exllama or Exllamav2 backend requires all the modules to be on GPU." - "You can deactivate exllama backend by setting `disable_exllama=True` in the quantization config object" - ) + if not self.disable_exllama: + logger.warning( + "Found modules on cpu/disk. Using Exllama/Exllamav2 backend requires all the modules to be on GPU. Setting `disable_exllama=True`" + ) + self.disable_exllama = True class StoreAttr(object): pass diff --git a/optimum/modeling_base.py b/optimum/modeling_base.py index 5bab0622de4..3da2d9d0d21 100644 --- a/optimum/modeling_base.py +++ b/optimum/modeling_base.py @@ -85,7 +85,6 @@ class PreTrainedModel(ABC): # noqa: F811 class OptimizedModel(PreTrainedModel): config_class = AutoConfig - load_tf_weights = None base_model_prefix = "optimized_model" config_name = CONFIG_NAME @@ -378,10 +377,14 @@ def from_pretrained( ) model_id, revision = model_id.split("@") - library_name = TasksManager.infer_library_from_model(model_id, subfolder, revision, cache_dir, token=token) + library_name = TasksManager.infer_library_from_model( + model_id, subfolder=subfolder, revision=revision, cache_dir=cache_dir, token=token + ) if library_name == "timm": - config = PretrainedConfig.from_pretrained(model_id, subfolder, revision) + config = PretrainedConfig.from_pretrained( + model_id, subfolder=subfolder, revision=revision, cache_dir=cache_dir, token=token + ) if config is None: if os.path.isdir(os.path.join(model_id, subfolder)) and cls.config_name == CONFIG_NAME: diff --git a/optimum/onnxruntime/__init__.py b/optimum/onnxruntime/__init__.py index f1d4f63a9ff..09a48ec955c 100644 --- a/optimum/onnxruntime/__init__.py +++ b/optimum/onnxruntime/__init__.py @@ -79,6 +79,10 @@ "ORTStableDiffusionXLPipeline", "ORTStableDiffusionXLImg2ImgPipeline", "ORTLatentConsistencyModelPipeline", + "ORTPipelineForImage2Image", + "ORTPipelineForInpainting", + "ORTPipelineForText2Image", + "ORTDiffusionPipeline", ] else: _import_structure["modeling_diffusion"] = [ @@ -88,6 +92,10 @@ "ORTStableDiffusionXLPipeline", "ORTStableDiffusionXLImg2ImgPipeline", "ORTLatentConsistencyModelPipeline", + "ORTPipelineForImage2Image", + "ORTPipelineForInpainting", + "ORTPipelineForText2Image", + "ORTDiffusionPipeline", ] @@ -137,7 +145,11 @@ raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils.dummy_diffusers_objects import ( + ORTDiffusionPipeline, ORTLatentConsistencyModelPipeline, + ORTPipelineForImage2Image, + ORTPipelineForInpainting, + ORTPipelineForText2Image, ORTStableDiffusionImg2ImgPipeline, ORTStableDiffusionInpaintPipeline, ORTStableDiffusionPipeline, @@ -146,7 +158,11 @@ ) else: from .modeling_diffusion import ( + ORTDiffusionPipeline, ORTLatentConsistencyModelPipeline, + ORTPipelineForImage2Image, + ORTPipelineForInpainting, + ORTPipelineForText2Image, ORTStableDiffusionImg2ImgPipeline, ORTStableDiffusionInpaintPipeline, ORTStableDiffusionPipeline, diff --git a/optimum/onnxruntime/base.py b/optimum/onnxruntime/base.py index d9877670ba8..0e54bafed78 100644 --- a/optimum/onnxruntime/base.py +++ b/optimum/onnxruntime/base.py @@ -41,17 +41,11 @@ class ORTModelPart: _prepare_onnx_inputs = ORTModel._prepare_onnx_inputs _prepare_onnx_outputs = ORTModel._prepare_onnx_outputs - def __init__( - self, - session: InferenceSession, - parent_model: "ORTModel", - ): + def __init__(self, session: InferenceSession, parent_model: "ORTModel"): self.session = session self.parent_model = parent_model - self.normalized_config = NormalizedConfigManager.get_normalized_config_class( - self.parent_model.config.model_type - )(self.parent_model.config) self.main_input_name = self.parent_model.main_input_name + self.input_names = {input_key.name: idx for idx, input_key in enumerate(self.session.get_inputs())} self.output_names = {output_key.name: idx for idx, output_key in enumerate(self.session.get_outputs())} self.input_dtypes = {input_key.name: input_key.type for input_key in session.get_inputs()} @@ -90,12 +84,18 @@ class ORTEncoder(ORTModelPart): Encoder part of the encoder-decoder model for ONNX Runtime inference. """ - def forward( - self, - input_ids: torch.LongTensor, - attention_mask: torch.LongTensor, - **kwargs, - ) -> BaseModelOutput: + def __init__(self, session: InferenceSession, parent_model: "ORTModel"): + super().__init__(session, parent_model) + + config = ( + self.parent_model.config.encoder + if hasattr(self.parent_model.config, "encoder") + else self.parent_model.config + ) + + self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config) + + def forward(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor, **kwargs) -> BaseModelOutput: use_torch = isinstance(input_ids, torch.Tensor) self.parent_model.raise_on_numpy_input_io_binding(use_torch) @@ -138,6 +138,14 @@ def __init__( ): super().__init__(session, parent_model) + config = ( + self.parent_model.config.decoder + if hasattr(self.parent_model.config, "decoder") + else self.parent_model.config + ) + + self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config) + # TODO: make this less hacky. self.key_value_input_names = [key for key in self.input_names if (".key" in key) or (".value" in key)] self.key_value_output_names = [key for key in self.output_names if (".key" in key) or (".value" in key)] @@ -153,11 +161,7 @@ def __init__( self.use_past_in_outputs = len(self.key_value_output_names) > 0 self.use_past_in_inputs = len(self.key_value_input_names) > 0 - self.use_fp16 = False - for inp in session.get_inputs(): - if "past_key_values" in inp.name and inp.type == "tensor(float16)": - self.use_fp16 = True - break + self.use_fp16 = self.dtype == torch.float16 # We may use ORTDecoderForSeq2Seq for vision-encoder-decoder models, where models as gpt2 # can be used but do not support KV caching for the cross-attention key/values, see: @@ -461,11 +465,3 @@ def prepare_inputs_for_merged( cache_position = cache_position.to(self.device) return use_cache_branch_tensor, past_key_values, cache_position - - -class ORTDecoder(ORTDecoderForSeq2Seq): - def __init__(self, *args, **kwargs): - logger.warning( - "The class `ORTDecoder` is deprecated and will be removed in optimum v1.15.0, please use `ORTDecoderForSeq2Seq` instead." - ) - super().__init__(*args, **kwargs) diff --git a/optimum/onnxruntime/io_binding/io_binding_helper.py b/optimum/onnxruntime/io_binding/io_binding_helper.py index 31da5379184..f32ecc56e6e 100644 --- a/optimum/onnxruntime/io_binding/io_binding_helper.py +++ b/optimum/onnxruntime/io_binding/io_binding_helper.py @@ -157,9 +157,9 @@ def prepare_io_binding(ort_model: "ORTModel", **inputs) -> ort.IOBinding: Returns an IOBinding object for an inference session. This method is for general purpose, if the inputs and outputs are determined, you can prepare data buffers directly to avoid tensor transfers across frameworks. """ - if not all(input_name in inputs.keys() for input_name in ort_model.inputs_names): + if not all(input_name in inputs.keys() for input_name in ort_model.input_names): raise ValueError( - f"The ONNX model takes {ort_model.inputs_names.keys()} as inputs, but only {inputs.keys()} are given." + f"The ONNX model takes {ort_model.input_names.keys()} as inputs, but only {inputs.keys()} are given." ) name_to_np_type = TypeHelper.get_io_numpy_type_map(ort_model.model) @@ -168,7 +168,7 @@ def prepare_io_binding(ort_model: "ORTModel", **inputs) -> ort.IOBinding: io_binding = ort_model.model.io_binding() # Bind inputs - for input_name in ort_model.inputs_names: + for input_name in ort_model.input_names: onnx_input = inputs.pop(input_name) onnx_input = onnx_input.contiguous() diff --git a/optimum/onnxruntime/modeling_diffusion.py b/optimum/onnxruntime/modeling_diffusion.py index 4bbfb2eda2a..18cd38c5f29 100644 --- a/optimum/onnxruntime/modeling_diffusion.py +++ b/optimum/onnxruntime/modeling_diffusion.py @@ -17,7 +17,7 @@ import os import shutil import warnings -from abc import abstractmethod +from collections import OrderedDict from pathlib import Path from tempfile import TemporaryDirectory from typing import Any, Dict, Optional, Union @@ -25,18 +25,28 @@ import numpy as np import torch from diffusers import ( + AutoPipelineForImage2Image, + AutoPipelineForInpainting, + AutoPipelineForText2Image, + ConfigMixin, DDIMScheduler, + LatentConsistencyModelPipeline, LMSDiscreteScheduler, PNDMScheduler, + StableDiffusionImg2ImgPipeline, + StableDiffusionInpaintPipeline, StableDiffusionPipeline, StableDiffusionXLImg2ImgPipeline, + StableDiffusionXLPipeline, ) from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME from diffusers.utils import CONFIG_NAME, is_invisible_watermark_available from huggingface_hub import snapshot_download from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE +from huggingface_hub.utils import validate_hf_hub_args from transformers import CLIPFeatureExtractor, CLIPTokenizer from transformers.file_utils import add_end_docstrings +from transformers.modeling_outputs import ModelOutput import onnxruntime as ort @@ -56,9 +66,10 @@ DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER, DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER, ) +from .base import ORTModelPart +from .io_binding import TypeHelper from .modeling_ort import ONNX_MODEL_END_DOCSTRING, ORTModel from .utils import ( - _ORT_TO_NP_TYPE, ONNX_WEIGHTS_NAME, get_provider_for_device, parse_device, @@ -69,23 +80,23 @@ logger = logging.getLogger(__name__) -class ORTStableDiffusionPipelineBase(ORTModel): - auto_model_class = StableDiffusionPipeline - main_input_name = "input_ids" - base_model_prefix = "onnx_model" +class ORTPipeline(ORTModel): + auto_model_class = None + model_type = "onnx_pipeline" + config_name = "model_index.json" sub_component_config_name = "config.json" def __init__( self, vae_decoder_session: ort.InferenceSession, - text_encoder_session: ort.InferenceSession, unet_session: ort.InferenceSession, - config: Dict[str, Any], tokenizer: CLIPTokenizer, + config: Dict[str, Any], scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], feature_extractor: Optional[CLIPFeatureExtractor] = None, vae_encoder_session: Optional[ort.InferenceSession] = None, + text_encoder_session: Optional[ort.InferenceSession] = None, text_encoder_2_session: Optional[ort.InferenceSession] = None, tokenizer_2: Optional[CLIPTokenizer] = None, use_io_binding: Optional[bool] = None, @@ -94,23 +105,28 @@ def __init__( """ Args: vae_decoder_session (`ort.InferenceSession`): - The ONNX Runtime inference session associated to the VAE decoder. - text_encoder_session (`ort.InferenceSession`): - The ONNX Runtime inference session associated to the text encoder. + The ONNX Runtime inference session associated to the VAE decoder unet_session (`ort.InferenceSession`): The ONNX Runtime inference session associated to the U-NET. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer) + for the text encoder. config (`Dict[str, Any]`): A config dictionary from which the model components will be instantiated. Make sure to only load configuration files of compatible classes. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). scheduler (`Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]`): A scheduler to be used in combination with the U-NET component to denoise the encoded image latents. feature_extractor (`Optional[CLIPFeatureExtractor]`, defaults to `None`): A model extracting features from generated images to be used as inputs for the `safety_checker` vae_encoder_session (`Optional[ort.InferenceSession]`, defaults to `None`): The ONNX Runtime inference session associated to the VAE encoder. + text_encoder_session (`Optional[ort.InferenceSession]`, defaults to `None`): + The ONNX Runtime inference session associated to the text encoder. + tokenizer_2 (`Optional[CLIPTokenizer]`, defaults to `None`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer) + for the second text encoder. use_io_binding (`Optional[bool]`, defaults to `None`): Whether to use IOBinding during inference to avoid memory copy between the host and devices. Defaults to `True` if the device is CUDA, otherwise defaults to `False`. @@ -118,7 +134,7 @@ def __init__( The directory under which the model exported to ONNX was saved. """ self.shared_attributes_init( - vae_decoder_session, + model=vae_decoder_session, use_io_binding=use_io_binding, model_save_dir=model_save_dir, ) @@ -350,9 +366,9 @@ def _from_pretrained( text_encoder_path=new_model_save_dir / DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER / text_encoder_file_name, unet_path=new_model_save_dir / DIFFUSION_MODEL_UNET_SUBFOLDER / unet_file_name, vae_encoder_path=new_model_save_dir / DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER / vae_encoder_file_name, - text_encoder_2_path=new_model_save_dir - / DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER - / text_encoder_2_file_name, + text_encoder_2_path=( + new_model_save_dir / DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER / text_encoder_2_file_name + ), provider=provider, session_options=session_options, provider_options=provider_options, @@ -399,7 +415,7 @@ def _from_transformers( provider_options: Optional[Dict[str, Any]] = None, use_io_binding: Optional[bool] = None, task: Optional[str] = None, - ) -> "ORTStableDiffusionPipeline": + ) -> "ORTPipeline": if use_auth_token is not None: warnings.warn( "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", @@ -480,131 +496,142 @@ def _save_config(self, save_directory): self.save_config(save_directory) -# TODO : Use ORTModelPart once IOBinding support is added -class _ORTDiffusionModelPart: - """ - For multi-file ONNX models, represents a part of the model. - It has its own `onnxruntime.InferenceSession`, and can perform a forward pass. - """ - +class ORTPipelinePart(ORTModelPart): CONFIG_NAME = "config.json" - def __init__(self, session: ort.InferenceSession, parent_model: ORTModel): - self.session = session - self.parent_model = parent_model - self.input_names = {input_key.name: idx for idx, input_key in enumerate(self.session.get_inputs())} - self.output_names = {output_key.name: idx for idx, output_key in enumerate(self.session.get_outputs())} + def __init__(self, session: ort.InferenceSession, parent_model: ORTPipeline): config_path = Path(session._model_path).parent / self.CONFIG_NAME - self.config = self.parent_model._dict_from_json_file(config_path) if config_path.is_file() else {} - self.input_dtype = {inputs.name: _ORT_TO_NP_TYPE[inputs.type] for inputs in self.session.get_inputs()} + + if config_path.is_file(): + # TODO: use FrozenDict + self.config = parent_model._dict_from_json_file(config_path) + else: + self.config = {} + + super().__init__(session, parent_model) @property - def device(self): - return self.parent_model.device + def input_dtype(self): + # for backward compatibility and diffusion mixins (will be standardized in the future) + return {name: TypeHelper.ort_type_to_numpy_type(ort_type) for name, ort_type in self.input_dtypes.items()} - @abstractmethod - def forward(self, *args, **kwargs): - pass - def __call__(self, *args, **kwargs): - return self.forward(*args, **kwargs) +class ORTModelTextEncoder(ORTPipelinePart): + def forward(self, input_ids: Union[np.ndarray, torch.Tensor]): + use_torch = isinstance(input_ids, torch.Tensor) + model_inputs = {"input_ids": input_ids} -class ORTModelTextEncoder(_ORTDiffusionModelPart): - def forward(self, input_ids: np.ndarray): - onnx_inputs = { - "input_ids": input_ids, - } - outputs = self.session.run(None, onnx_inputs) - return outputs + onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_outputs = self.session.run(None, onnx_inputs) + model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) + return ModelOutput(**model_outputs) -class ORTModelUnet(_ORTDiffusionModelPart): - def __init__(self, session: ort.InferenceSession, parent_model: ORTModel): - super().__init__(session, parent_model) +class ORTModelUnet(ORTPipelinePart): def forward( self, - sample: np.ndarray, - timestep: np.ndarray, - encoder_hidden_states: np.ndarray, - text_embeds: Optional[np.ndarray] = None, - time_ids: Optional[np.ndarray] = None, - timestep_cond: Optional[np.ndarray] = None, + sample: Union[np.ndarray, torch.Tensor], + timestep: Union[np.ndarray, torch.Tensor], + encoder_hidden_states: Union[np.ndarray, torch.Tensor], + text_embeds: Optional[Union[np.ndarray, torch.Tensor]] = None, + time_ids: Optional[Union[np.ndarray, torch.Tensor]] = None, + timestep_cond: Optional[Union[np.ndarray, torch.Tensor]] = None, ): - onnx_inputs = { + use_torch = isinstance(sample, torch.Tensor) + + model_inputs = { "sample": sample, "timestep": timestep, "encoder_hidden_states": encoder_hidden_states, + "text_embeds": text_embeds, + "time_ids": time_ids, + "timestep_cond": timestep_cond, } - if text_embeds is not None: - onnx_inputs["text_embeds"] = text_embeds - if time_ids is not None: - onnx_inputs["time_ids"] = time_ids - if timestep_cond is not None: - onnx_inputs["timestep_cond"] = timestep_cond - outputs = self.session.run(None, onnx_inputs) - return outputs + onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_outputs = self.session.run(None, onnx_inputs) + model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) + return ModelOutput(**model_outputs) -class ORTModelVaeDecoder(_ORTDiffusionModelPart): - def forward(self, latent_sample: np.ndarray): - onnx_inputs = { - "latent_sample": latent_sample, - } - outputs = self.session.run(None, onnx_inputs) - return outputs +class ORTModelVaeDecoder(ORTPipelinePart): + def forward(self, latent_sample: Union[np.ndarray, torch.Tensor]): + use_torch = isinstance(latent_sample, torch.Tensor) -class ORTModelVaeEncoder(_ORTDiffusionModelPart): - def forward(self, sample: np.ndarray): - onnx_inputs = { - "sample": sample, - } - outputs = self.session.run(None, onnx_inputs) - return outputs + model_inputs = {"latent_sample": latent_sample} + + onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_outputs = self.session.run(None, onnx_inputs) + model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) + + return ModelOutput(**model_outputs) + + +class ORTModelVaeEncoder(ORTPipelinePart): + def forward(self, sample: Union[np.ndarray, torch.Tensor]): + use_torch = isinstance(sample, torch.Tensor) + + model_inputs = {"sample": sample} + + onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_outputs = self.session.run(None, onnx_inputs) + model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) + + return ModelOutput(**model_outputs) @add_end_docstrings(ONNX_MODEL_END_DOCSTRING) -class ORTStableDiffusionPipeline(ORTStableDiffusionPipelineBase, StableDiffusionPipelineMixin): +class ORTStableDiffusionPipeline(ORTPipeline, StableDiffusionPipelineMixin): """ ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusionPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline). """ + main_input_name = "prompt" + auto_model_class = StableDiffusionPipeline + __call__ = StableDiffusionPipelineMixin.__call__ @add_end_docstrings(ONNX_MODEL_END_DOCSTRING) -class ORTStableDiffusionImg2ImgPipeline(ORTStableDiffusionPipelineBase, StableDiffusionImg2ImgPipelineMixin): +class ORTStableDiffusionImg2ImgPipeline(ORTPipeline, StableDiffusionImg2ImgPipelineMixin): """ ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusionImg2ImgPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/img2img#diffusers.StableDiffusionImg2ImgPipeline). """ + main_input_name = "prompt" + auto_model_class = StableDiffusionImg2ImgPipeline + __call__ = StableDiffusionImg2ImgPipelineMixin.__call__ @add_end_docstrings(ONNX_MODEL_END_DOCSTRING) -class ORTStableDiffusionInpaintPipeline(ORTStableDiffusionPipelineBase, StableDiffusionInpaintPipelineMixin): +class ORTStableDiffusionInpaintPipeline(ORTPipeline, StableDiffusionInpaintPipelineMixin): """ ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusionInpaintPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/inpaint#diffusers.StableDiffusionInpaintPipeline). """ + main_input_name = "prompt" + auto_model_class = StableDiffusionInpaintPipeline + __call__ = StableDiffusionInpaintPipelineMixin.__call__ @add_end_docstrings(ONNX_MODEL_END_DOCSTRING) -class ORTLatentConsistencyModelPipeline(ORTStableDiffusionPipelineBase, LatentConsistencyPipelineMixin): +class ORTLatentConsistencyModelPipeline(ORTPipeline, LatentConsistencyPipelineMixin): """ ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.LatentConsistencyModelPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/latent_consistency#diffusers.LatentConsistencyModelPipeline). """ - __call__ = LatentConsistencyPipelineMixin.__call__ + main_input_name = "prompt" + auto_model_class = LatentConsistencyModelPipeline + __call__ = LatentConsistencyPipelineMixin.__call__ -class ORTStableDiffusionXLPipelineBase(ORTStableDiffusionPipelineBase): - auto_model_class = StableDiffusionXLImg2ImgPipeline +class ORTStableDiffusionXLPipelineBase(ORTPipeline): def __init__( self, vae_decoder_session: ort.InferenceSession, @@ -657,6 +684,9 @@ class ORTStableDiffusionXLPipeline(ORTStableDiffusionXLPipelineBase, StableDiffu ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusionXLPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLPipeline). """ + main_input_name = "prompt" + auto_model_class = StableDiffusionXLPipeline + __call__ = StableDiffusionXLPipelineMixin.__call__ @@ -666,4 +696,140 @@ class ORTStableDiffusionXLImg2ImgPipeline(ORTStableDiffusionXLPipelineBase, Stab ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusionXLImg2ImgPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLImg2ImgPipeline). """ + main_input_name = "prompt" + auto_model_class = StableDiffusionXLImg2ImgPipeline + __call__ = StableDiffusionXLImg2ImgPipelineMixin.__call__ + + +SUPPORTED_ORT_PIPELINES = [ + ORTStableDiffusionPipeline, + ORTStableDiffusionImg2ImgPipeline, + ORTStableDiffusionInpaintPipeline, + ORTLatentConsistencyModelPipeline, + ORTStableDiffusionXLPipeline, + ORTStableDiffusionXLImg2ImgPipeline, +] + + +def _get_pipeline_class(pipeline_class_name: str, throw_error_if_not_exist: bool = True): + for ort_pipeline_class in SUPPORTED_ORT_PIPELINES: + if ( + ort_pipeline_class.__name__ == pipeline_class_name + or ort_pipeline_class.auto_model_class.__name__ == pipeline_class_name + ): + return ort_pipeline_class + + if throw_error_if_not_exist: + raise ValueError(f"ORTDiffusionPipeline can't find a pipeline linked to {pipeline_class_name}") + + +class ORTDiffusionPipeline(ConfigMixin): + config_name = "model_index.json" + + @classmethod + @validate_hf_hub_args + def from_pretrained(cls, pretrained_model_or_path, **kwargs): + load_config_kwargs = { + "force_download": kwargs.get("force_download", False), + "resume_download": kwargs.get("resume_download", None), + "local_files_only": kwargs.get("local_files_only", False), + "cache_dir": kwargs.get("cache_dir", None), + "revision": kwargs.get("revision", None), + "proxies": kwargs.get("proxies", None), + "token": kwargs.get("token", None), + } + + config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) + config = config[0] if isinstance(config, tuple) else config + class_name = config["_class_name"] + + ort_pipeline_class = _get_pipeline_class(class_name) + + return ort_pipeline_class.from_pretrained(pretrained_model_or_path, **kwargs) + + +ORT_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict( + [ + ("stable-diffusion", ORTStableDiffusionPipeline), + ("stable-diffusion-xl", ORTStableDiffusionXLPipeline), + ("latent-consistency", ORTLatentConsistencyModelPipeline), + ] +) + +ORT_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict( + [ + ("stable-diffusion", ORTStableDiffusionImg2ImgPipeline), + ("stable-diffusion-xl", ORTStableDiffusionXLImg2ImgPipeline), + ] +) + +ORT_INPAINT_PIPELINES_MAPPING = OrderedDict( + [ + ("stable-diffusion", ORTStableDiffusionInpaintPipeline), + ] +) + +SUPPORTED_ORT_PIPELINES_MAPPINGS = [ + ORT_TEXT2IMAGE_PIPELINES_MAPPING, + ORT_IMAGE2IMAGE_PIPELINES_MAPPING, + ORT_INPAINT_PIPELINES_MAPPING, +] + + +def _get_task_class(mapping, pipeline_class_name): + def _get_model_name(pipeline_class_name): + for ort_pipelines_mapping in SUPPORTED_ORT_PIPELINES_MAPPINGS: + for model_name, ort_pipeline_class in ort_pipelines_mapping.items(): + if ( + ort_pipeline_class.__name__ == pipeline_class_name + or ort_pipeline_class.auto_model_class.__name__ == pipeline_class_name + ): + return model_name + + model_name = _get_model_name(pipeline_class_name) + + if model_name is not None: + task_class = mapping.get(model_name, None) + if task_class is not None: + return task_class + + raise ValueError(f"ORTPipelineForTask can't find a pipeline linked to {pipeline_class_name} for {model_name}") + + +class ORTPipelineForTask(ConfigMixin): + config_name = "model_index.json" + + @classmethod + def from_pretrained(cls, pretrained_model_or_path, **kwargs): + load_config_kwargs = { + "force_download": kwargs.get("force_download", False), + "resume_download": kwargs.get("resume_download", None), + "local_files_only": kwargs.get("local_files_only", False), + "cache_dir": kwargs.get("cache_dir", None), + "revision": kwargs.get("revision", None), + "proxies": kwargs.get("proxies", None), + "token": kwargs.get("token", None), + } + config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) + config = config[0] if isinstance(config, tuple) else config + class_name = config["_class_name"] + + ort_pipeline_class = _get_task_class(cls.ort_pipelines_mapping, class_name) + + return ort_pipeline_class.from_pretrained(pretrained_model_or_path, **kwargs) + + +class ORTPipelineForText2Image(ORTPipelineForTask): + auto_model_class = AutoPipelineForText2Image + ort_pipelines_mapping = ORT_TEXT2IMAGE_PIPELINES_MAPPING + + +class ORTPipelineForImage2Image(ORTPipelineForTask): + auto_model_class = AutoPipelineForImage2Image + ort_pipelines_mapping = ORT_IMAGE2IMAGE_PIPELINES_MAPPING + + +class ORTPipelineForInpainting(ORTPipelineForTask): + auto_model_class = AutoPipelineForInpainting + ort_pipelines_mapping = ORT_INPAINT_PIPELINES_MAPPING diff --git a/optimum/onnxruntime/modeling_seq2seq.py b/optimum/onnxruntime/modeling_seq2seq.py index 4ce3e4707ed..3cecadafe3e 100644 --- a/optimum/onnxruntime/modeling_seq2seq.py +++ b/optimum/onnxruntime/modeling_seq2seq.py @@ -46,7 +46,6 @@ from ..onnx.utils import _get_external_data_paths from ..utils import check_if_transformers_greater from ..utils.file_utils import validate_file_exists -from ..utils.normalized_config import NormalizedConfigManager from ..utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors from .base import ORTDecoderForSeq2Seq, ORTEncoder from .constants import ( @@ -72,16 +71,6 @@ from transformers.generation_utils import GenerationMixin -# if check_if_transformers_greater("4.37.0"): -# # starting from transformers v4.37.0, the whisper generation loop is implemented in the `WhisperGenerationMixin` -# # and it implements many new features including short and long form generation, and starts with 2 init tokens -# from transformers.models.whisper.generation_whisper import WhisperGenerationMixin -# else: - -# class WhisperGenerationMixin(WhisperForConditionalGeneration, GenerationMixin): -# pass - - if check_if_transformers_greater("4.43.0"): from transformers.cache_utils import EncoderDecoderCache else: @@ -1165,49 +1154,6 @@ class ORTModelForSeq2SeqLM(ORTModelForConditionalGeneration, GenerationMixin): auto_model_class = AutoModelForSeq2SeqLM main_input_name = "input_ids" - def __init__( - self, - encoder_session: ort.InferenceSession, - decoder_session: ort.InferenceSession, - config: "PretrainedConfig", - onnx_paths: List[str], - decoder_with_past_session: Optional[ort.InferenceSession] = None, - use_cache: bool = True, - use_io_binding: Optional[bool] = None, - model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, - preprocessors: Optional[List] = None, - generation_config: Optional[GenerationConfig] = None, - **kwargs, - ): - super().__init__( - encoder_session, - decoder_session, - config, - onnx_paths, - decoder_with_past_session, - use_cache, - use_io_binding, - model_save_dir, - preprocessors, - generation_config, - **kwargs, - ) - - # The normalized_config initialization in ORTModelPart is unfortunately wrong as the top level config is initialized. - if config.model_type == "encoder-decoder": - self.encoder.normalized_config = NormalizedConfigManager.get_normalized_config_class( - config.encoder.model_type - )(config.encoder) - - self.decoder.normalized_config = NormalizedConfigManager.get_normalized_config_class( - config.decoder.model_type - )(config.decoder) - - if self.decoder_with_past is not None: - self.decoder_with_past.normalized_config = NormalizedConfigManager.get_normalized_config_class( - config.decoder.model_type - )(config.decoder) - def _initialize_encoder(self, session: ort.InferenceSession) -> ORTEncoder: return ORTEncoder(session, self) @@ -1521,20 +1467,6 @@ def __init__( **kwargs, ) - # The normalized_config initialization in ORTModelPart is unfortunately wrong as the top level config is initialized. - self.encoder.normalized_config = NormalizedConfigManager.get_normalized_config_class( - config.encoder.model_type - )(config.encoder) - - self.decoder.normalized_config = NormalizedConfigManager.get_normalized_config_class( - config.decoder.model_type - )(config.decoder) - - if self.decoder_with_past is not None: - self.decoder_with_past.normalized_config = NormalizedConfigManager.get_normalized_config_class( - config.decoder.model_type - )(config.decoder) - def _initialize_encoder(self, session: ort.InferenceSession) -> ORTEncoder: return ORTEncoderForVisionEncoderDecoder(session, self) diff --git a/optimum/onnxruntime/trainer.py b/optimum/onnxruntime/trainer.py index 9bc2bb5134d..66273cbcf96 100644 --- a/optimum/onnxruntime/trainer.py +++ b/optimum/onnxruntime/trainer.py @@ -55,7 +55,6 @@ from torch.utils.data import Dataset, RandomSampler from transformers.data.data_collator import DataCollator from transformers.debug_utils import DebugOption, DebugUnderflowOverflow -from transformers.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_zero3_enabled from transformers.modeling_utils import PreTrainedModel, unwrap_model from transformers.tokenization_utils_base import PreTrainedTokenizerBase from transformers.trainer import Trainer @@ -81,10 +80,10 @@ is_apex_available, is_sagemaker_dp_enabled, is_sagemaker_mp_enabled, - is_torch_tpu_available, ) from ..utils import logging +from ..utils.import_utils import check_if_transformers_greater from .training_args import ORTOptimizerNames, ORTTrainingArguments from .utils import ( is_onnxruntime_training_available, @@ -94,8 +93,25 @@ if is_apex_available(): from apex import amp -if is_torch_tpu_available(check_device=False): - import torch_xla.core.xla_model as xm +if check_if_transformers_greater("4.33"): + from transformers.integrations.deepspeed import ( + deepspeed_init, + deepspeed_load_checkpoint, + is_deepspeed_zero3_enabled, + ) +else: + from transformers.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_zero3_enabled + +if check_if_transformers_greater("4.39"): + from transformers.utils import is_torch_xla_available as is_torch_tpu_xla_available + + if is_torch_tpu_xla_available(): + import torch_xla.core.xla_model as xm +else: + from transformers.utils import is_torch_tpu_available as is_torch_tpu_xla_available + + if is_torch_tpu_xla_available(check_device=False): + import torch_xla.core.xla_model as xm if TYPE_CHECKING: import optuna @@ -719,7 +735,7 @@ def get_dataloader_sampler(dataloader): if ( args.logging_nan_inf_filter - and not is_torch_tpu_available() + and not is_torch_tpu_xla_available() and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)) ): # if loss is nan or inf simply add the average of previous logged losses diff --git a/optimum/onnxruntime/trainer_seq2seq.py b/optimum/onnxruntime/trainer_seq2seq.py index 2e43ee89e00..1565ffa6acb 100644 --- a/optimum/onnxruntime/trainer_seq2seq.py +++ b/optimum/onnxruntime/trainer_seq2seq.py @@ -19,10 +19,10 @@ import torch from torch import nn from torch.utils.data import Dataset -from transformers.deepspeed import is_deepspeed_zero3_enabled from transformers.trainer_utils import PredictionOutput from transformers.utils import is_accelerate_available, logging +from ..utils.import_utils import check_if_transformers_greater from .trainer import ORTTrainer @@ -33,6 +33,11 @@ "The package `accelerate` is required to use the ORTTrainer. Please install it following https://huggingface.co/docs/accelerate/basic_tutorials/install." ) +if check_if_transformers_greater("4.33"): + from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled +else: + from transformers.deepspeed import is_deepspeed_zero3_enabled + logger = logging.get_logger(__name__) diff --git a/optimum/onnxruntime/training_args.py b/optimum/onnxruntime/training_args.py index 6aec362c07c..6135abc1376 100644 --- a/optimum/onnxruntime/training_args.py +++ b/optimum/onnxruntime/training_args.py @@ -117,32 +117,32 @@ def __post_init__(self): if self.disable_tqdm is None: self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN - if isinstance(self.evaluation_strategy, EvaluationStrategy): + if isinstance(self.eval_strategy, EvaluationStrategy): warnings.warn( - "using `EvaluationStrategy` for `evaluation_strategy` is deprecated and will be removed in version 5" + "using `EvaluationStrategy` for `eval_strategy` is deprecated and will be removed in version 5" " of 🤗 Transformers. Use `IntervalStrategy` instead", FutureWarning, ) # Go back to the underlying string or we won't be able to instantiate `IntervalStrategy` on it. - self.evaluation_strategy = self.evaluation_strategy.value + self.eval_strategy = self.eval_strategy.value - self.evaluation_strategy = IntervalStrategy(self.evaluation_strategy) + self.eval_strategy = IntervalStrategy(self.eval_strategy) self.logging_strategy = IntervalStrategy(self.logging_strategy) self.save_strategy = IntervalStrategy(self.save_strategy) self.hub_strategy = HubStrategy(self.hub_strategy) self.lr_scheduler_type = SchedulerType(self.lr_scheduler_type) - if self.do_eval is False and self.evaluation_strategy != IntervalStrategy.NO: + if self.do_eval is False and self.eval_strategy != IntervalStrategy.NO: self.do_eval = True # eval_steps has to be defined and non-zero, fallbacks to logging_steps if the latter is non-zero - if self.evaluation_strategy == IntervalStrategy.STEPS and (self.eval_steps is None or self.eval_steps == 0): + if self.eval_strategy == IntervalStrategy.STEPS and (self.eval_steps is None or self.eval_steps == 0): if self.logging_steps > 0: logger.info(f"using `logging_steps` to initialize `eval_steps` to {self.logging_steps}") self.eval_steps = self.logging_steps else: raise ValueError( - f"evaluation strategy {self.evaluation_strategy} requires either non-zero --eval_steps or" + f"evaluation strategy {self.eval_strategy} requires either non-zero --eval_steps or" " --logging_steps" ) @@ -154,7 +154,7 @@ def __post_init__(self): if self.logging_steps != int(self.logging_steps): raise ValueError(f"--logging_steps must be an integer if bigger than 1: {self.logging_steps}") self.logging_steps = int(self.logging_steps) - if self.evaluation_strategy == IntervalStrategy.STEPS and self.eval_steps > 1: + if self.eval_strategy == IntervalStrategy.STEPS and self.eval_steps > 1: if self.eval_steps != int(self.eval_steps): raise ValueError(f"--eval_steps must be an integer if bigger than 1: {self.eval_steps}") self.eval_steps = int(self.eval_steps) @@ -165,13 +165,13 @@ def __post_init__(self): # Sanity checks for load_best_model_at_end: we require save and eval strategies to be compatible. if self.load_best_model_at_end: - if self.evaluation_strategy != self.save_strategy: + if self.eval_strategy != self.save_strategy: raise ValueError( "--load_best_model_at_end requires the saving steps to be a multiple of the evaluation " "steps, which cannot get guaranteed when mixing ratio and absolute steps for save_steps " f"{self.save_steps} and eval_steps {self.eval_steps}." ) - if self.evaluation_strategy == IntervalStrategy.STEPS and self.save_steps % self.eval_steps != 0: + if self.eval_strategy == IntervalStrategy.STEPS and self.save_steps % self.eval_steps != 0: if self.eval_steps < 1 or self.save_steps < 1: if not (self.eval_steps < 1 and self.save_steps < 1): raise ValueError( @@ -244,7 +244,7 @@ def __post_init__(self): ) if self.lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU: - if self.evaluation_strategy == IntervalStrategy.NO: + if self.eval_strategy == IntervalStrategy.NO: raise ValueError("lr_scheduler_type reduce_lr_on_plateau requires an eval strategy") if not is_torch_available(): raise ValueError("lr_scheduler_type reduce_lr_on_plateau requires torch>=0.2.0") diff --git a/optimum/onnxruntime/utils.py b/optimum/onnxruntime/utils.py index ad40af92b9d..985980e31b0 100644 --- a/optimum/onnxruntime/utils.py +++ b/optimum/onnxruntime/utils.py @@ -13,6 +13,7 @@ # limitations under the License. """Utility functions, classes and constants for ONNX Runtime.""" +import importlib import os import re from enum import Enum @@ -31,7 +32,6 @@ import onnxruntime as ort from ..exporters.onnx import OnnxConfig, OnnxConfigWithLoss -from ..utils.import_utils import _is_package_available if TYPE_CHECKING: @@ -91,9 +91,11 @@ def is_onnxruntime_training_available(): def is_cupy_available(): """ - Checks if onnxruntime-training is available. + Checks if CuPy is available. """ - return _is_package_available("cupy") + # Don't use _is_package_available as it doesn't work with CuPy installed + # with `cupy-cuda*` and `cupy-rocm-*` package name (prebuilt wheels). + return importlib.util.find_spec("cupy") is not None class ORTConfigManager: diff --git a/optimum/pipelines/diffusers/pipeline_latent_consistency.py b/optimum/pipelines/diffusers/pipeline_latent_consistency.py index 41c85b5b6ac..630d463de73 100644 --- a/optimum/pipelines/diffusers/pipeline_latent_consistency.py +++ b/optimum/pipelines/diffusers/pipeline_latent_consistency.py @@ -36,7 +36,7 @@ def __call__( original_inference_steps: int = None, guidance_scale: float = 8.5, num_images_per_prompt: int = 1, - generator: Optional[np.random.RandomState] = None, + generator: Optional[Union[np.random.RandomState, torch.Generator]] = None, latents: Optional[np.ndarray] = None, prompt_embeds: Optional[np.ndarray] = None, output_type: str = "pil", @@ -66,7 +66,7 @@ def __call__( usually at the expense of lower image quality. num_images_per_prompt (`int`, defaults to 1): The number of images to generate per prompt. - generator (`Optional[np.random.RandomState]`, defaults to `None`):: + generator (`Optional[Union[np.random.RandomState, torch.Generator]]`, defaults to `None`): A np.random.RandomState to make generation deterministic. latents (`Optional[np.ndarray]`, defaults to `None`): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image @@ -121,7 +121,7 @@ def __call__( batch_size = prompt_embeds.shape[0] if generator is None: - generator = np.random + generator = np.random.RandomState() prompt_embeds = self._encode_prompt( prompt, diff --git a/optimum/pipelines/diffusers/pipeline_stable_diffusion.py b/optimum/pipelines/diffusers/pipeline_stable_diffusion.py index 98bff0de44d..6cc47fab1b9 100644 --- a/optimum/pipelines/diffusers/pipeline_stable_diffusion.py +++ b/optimum/pipelines/diffusers/pipeline_stable_diffusion.py @@ -189,7 +189,15 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype ) if latents is None: - latents = generator.randn(*shape).astype(dtype) + if isinstance(generator, np.random.RandomState): + latents = generator.randn(*shape).astype(dtype) + elif isinstance(generator, torch.Generator): + latents = torch.randn(*shape, generator=generator).numpy().astype(dtype) + else: + raise ValueError( + f"Expected `generator` to be of type `np.random.RandomState` or `torch.Generator`, but got" + f" {type(generator)}." + ) elif latents.shape != shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") @@ -209,7 +217,7 @@ def __call__( negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: int = 1, eta: float = 0.0, - generator: Optional[np.random.RandomState] = None, + generator: Optional[Union[np.random.RandomState, torch.Generator]] = None, latents: Optional[np.ndarray] = None, prompt_embeds: Optional[np.ndarray] = None, negative_prompt_embeds: Optional[np.ndarray] = None, @@ -248,7 +256,7 @@ def __call__( eta (`float`, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`Optional[np.random.RandomState]`, defaults to `None`):: + generator (`Optional[Union[np.random.RandomState, torch.Generator]]`, defaults to `None`):: A np.random.RandomState to make generation deterministic. latents (`Optional[np.ndarray]`, defaults to `None`): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image @@ -303,7 +311,7 @@ def __call__( batch_size = prompt_embeds.shape[0] if generator is None: - generator = np.random + generator = np.random.RandomState() # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` diff --git a/optimum/pipelines/diffusers/pipeline_stable_diffusion_img2img.py b/optimum/pipelines/diffusers/pipeline_stable_diffusion_img2img.py index 81a6ffa1e04..a66035a789b 100644 --- a/optimum/pipelines/diffusers/pipeline_stable_diffusion_img2img.py +++ b/optimum/pipelines/diffusers/pipeline_stable_diffusion_img2img.py @@ -16,10 +16,9 @@ from typing import Callable, List, Optional, Union import numpy as np -import PIL +import PIL.Image import torch from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput -from diffusers.utils import deprecate from .pipeline_stable_diffusion import StableDiffusionPipelineMixin @@ -72,6 +71,43 @@ def check_inputs( f" {negative_prompt_embeds.shape}." ) + # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, image, timesteps, batch_size, num_images_per_prompt, dtype, generator=None): + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + else: + init_latents = self.vae_encoder(sample=image)[0] * self.vae_decoder.config.get("scaling_factor", 0.18215) + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = np.concatenate([init_latents] * additional_image_per_prompt, axis=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = np.concatenate([init_latents], axis=0) + + # add noise to latents using the timesteps + if isinstance(generator, np.random.RandomState): + noise = generator.randn(*init_latents.shape).astype(dtype) + elif isinstance(generator, torch.Generator): + noise = torch.randn(*init_latents.shape, generator=generator).numpy().astype(dtype) + else: + raise ValueError( + f"Expected `generator` to be of type `np.random.RandomState` or `torch.Generator`, but got" + f" {type(generator)}." + ) + + init_latents = self.scheduler.add_noise( + torch.from_numpy(init_latents), torch.from_numpy(noise), torch.from_numpy(timesteps) + ).numpy() + + return init_latents + # Adapted from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionImg2ImgPipeline.__call__ def __call__( self, @@ -83,7 +119,7 @@ def __call__( negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: int = 1, eta: float = 0.0, - generator: Optional[np.random.RandomState] = None, + generator: Optional[Union[np.random.RandomState, torch.Generator]] = None, prompt_embeds: Optional[np.ndarray] = None, negative_prompt_embeds: Optional[np.ndarray] = None, output_type: str = "pil", @@ -125,7 +161,7 @@ def __call__( eta (`float`, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`Optional[np.random.RandomState]`, defaults to `None`):: + generator (`Optional[Union[np.random.RandomState, torch.Generator]]`, defaults to `None`): A np.random.RandomState to make generation deterministic. prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not @@ -168,7 +204,7 @@ def __call__( batch_size = prompt_embeds.shape[0] if generator is None: - generator = np.random + generator = np.random.RandomState() # set timesteps self.scheduler.set_timesteps(num_inference_steps) @@ -191,31 +227,7 @@ def __call__( latents_dtype = prompt_embeds.dtype image = image.astype(latents_dtype) - # encode the init image into latents and scale the latents - init_latents = self.vae_encoder(sample=image)[0] - scaling_factor = self.vae_decoder.config.get("scaling_factor", 0.18215) - init_latents = scaling_factor * init_latents - - if isinstance(prompt, str): - prompt = [prompt] - if len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] == 0: - # expand init_latents for batch_size - deprecation_message = ( - f"You have passed {len(prompt)} text prompts (`prompt`), but only {init_latents.shape[0]} initial" - " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" - " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" - " your script to pass as many initial images as text prompts to suppress this warning." - ) - deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) - additional_image_per_prompt = len(prompt) // init_latents.shape[0] - init_latents = np.concatenate([init_latents] * additional_image_per_prompt * num_images_per_prompt, axis=0) - elif len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {len(prompt)} text prompts." - ) - else: - init_latents = np.concatenate([init_latents] * num_images_per_prompt, axis=0) # get the original timestep using init_timestep offset = self.scheduler.config.get("steps_offset", 0) @@ -225,12 +237,8 @@ def __call__( timesteps = self.scheduler.timesteps.numpy()[-init_timestep] timesteps = np.array([timesteps] * batch_size * num_images_per_prompt) - # add noise to latents using the timesteps - noise = generator.randn(*init_latents.shape).astype(latents_dtype) - init_latents = self.scheduler.add_noise( - torch.from_numpy(init_latents), torch.from_numpy(noise), torch.from_numpy(timesteps) - ) - init_latents = init_latents.numpy() + # 5. Prepare latent variables + latents = self.prepare_latents(image, timesteps, batch_size, num_images_per_prompt, latents_dtype, generator) # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. @@ -241,8 +249,6 @@ def __call__( if accepts_eta: extra_step_kwargs["eta"] = eta - latents = init_latents - t_start = max(num_inference_steps - init_timestep + offset, 0) timesteps = self.scheduler.timesteps[t_start:].numpy() @@ -276,7 +282,8 @@ def __call__( # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): if callback is not None and i % callback_steps == 0: - callback(i, t, latents) + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) if output_type == "latent": image = latents diff --git a/optimum/pipelines/diffusers/pipeline_stable_diffusion_inpaint.py b/optimum/pipelines/diffusers/pipeline_stable_diffusion_inpaint.py index 19de793ccd0..cb3c7db96e9 100644 --- a/optimum/pipelines/diffusers/pipeline_stable_diffusion_inpaint.py +++ b/optimum/pipelines/diffusers/pipeline_stable_diffusion_inpaint.py @@ -16,7 +16,7 @@ from typing import Callable, List, Optional, Union import numpy as np -import PIL +import PIL.Image import torch from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.utils import PIL_INTERPOLATION @@ -108,7 +108,7 @@ def __call__( negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: int = 1, eta: float = 0.0, - generator: Optional[np.random.RandomState] = None, + generator: Optional[Union[np.random.RandomState, torch.Generator]] = None, latents: Optional[np.ndarray] = None, prompt_embeds: Optional[np.ndarray] = None, negative_prompt_embeds: Optional[np.ndarray] = None, @@ -200,7 +200,7 @@ def __call__( batch_size = prompt_embeds.shape[0] if generator is None: - generator = np.random + generator = np.random.RandomState() # set timesteps self.scheduler.set_timesteps(num_inference_steps) @@ -229,11 +229,19 @@ def __call__( width // self.vae_scale_factor, ) latents_dtype = prompt_embeds.dtype + if latents is None: - latents = generator.randn(*latents_shape).astype(latents_dtype) - else: - if latents.shape != latents_shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + if isinstance(generator, np.random.RandomState): + latents = generator.randn(*latents_shape).astype(latents_dtype) + elif isinstance(generator, torch.Generator): + latents = torch.randn(*latents_shape, generator=generator).numpy().astype(latents_dtype) + else: + raise ValueError( + f"Expected `generator` to be of type `np.random.RandomState` or `torch.Generator`, but got" + f" {type(generator)}." + ) + elif latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") # prepare mask and masked_image mask, masked_image = prepare_mask_and_masked_image( diff --git a/optimum/pipelines/diffusers/pipeline_stable_diffusion_xl.py b/optimum/pipelines/diffusers/pipeline_stable_diffusion_xl.py index 2a5e7bf78b0..0407c16a77a 100644 --- a/optimum/pipelines/diffusers/pipeline_stable_diffusion_xl.py +++ b/optimum/pipelines/diffusers/pipeline_stable_diffusion_xl.py @@ -235,7 +235,15 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype ) if latents is None: - latents = generator.randn(*shape).astype(dtype) + if isinstance(generator, np.random.RandomState): + latents = generator.randn(*shape).astype(dtype) + elif isinstance(generator, torch.Generator): + latents = torch.randn(*shape, generator=generator).numpy().astype(dtype) + else: + raise ValueError( + f"Expected `generator` to be of type `np.random.RandomState` or `torch.Generator`, but got" + f" {type(generator)}." + ) elif latents.shape != shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") @@ -270,7 +278,7 @@ def __call__( negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: int = 1, eta: float = 0.0, - generator: Optional[np.random.RandomState] = None, + generator: Optional[Union[np.random.RandomState, torch.Generator]] = None, latents: Optional[np.ndarray] = None, prompt_embeds: Optional[np.ndarray] = None, negative_prompt_embeds: Optional[np.ndarray] = None, @@ -315,7 +323,7 @@ def __call__( eta (`float`, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`Optional[np.random.RandomState]`, defaults to `None`):: + generator (`Optional[Union[np.random.RandomState, torch.Generator]]`, defaults to `None`):: A np.random.RandomState to make generation deterministic. latents (`Optional[np.ndarray]`, defaults to `None`): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image @@ -383,7 +391,7 @@ def __call__( batch_size = prompt_embeds.shape[0] if generator is None: - generator = np.random + generator = np.random.RandomState() # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` @@ -440,6 +448,7 @@ def __call__( timestep_dtype = self.unet.input_dtype.get("timestep", np.float32) # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order for i, t in enumerate(self.progress_bar(timesteps)): # expand the latents if we are doing classifier free guidance @@ -475,7 +484,8 @@ def __call__( # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): if callback is not None and i % callback_steps == 0: - callback(i, t, latents) + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) if output_type == "latent": image = latents diff --git a/optimum/pipelines/diffusers/pipeline_stable_diffusion_xl_img2img.py b/optimum/pipelines/diffusers/pipeline_stable_diffusion_xl_img2img.py index a07903a735e..19988599b64 100644 --- a/optimum/pipelines/diffusers/pipeline_stable_diffusion_xl_img2img.py +++ b/optimum/pipelines/diffusers/pipeline_stable_diffusion_xl_img2img.py @@ -17,7 +17,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np -import PIL +import PIL.Image import torch from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput @@ -222,7 +222,7 @@ def get_timesteps(self, num_inference_steps, strength): return timesteps, num_inference_steps - t_start # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, generator=None): + def prepare_latents(self, image, timesteps, batch_size, num_images_per_prompt, dtype, generator=None): batch_size = batch_size * num_images_per_prompt if image.shape[1] == 4: @@ -242,11 +242,22 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt init_latents = np.concatenate([init_latents], axis=0) # add noise to latents using the timesteps - noise = generator.randn(*init_latents.shape).astype(dtype) + if isinstance(generator, np.random.RandomState): + noise = generator.randn(*init_latents.shape).astype(dtype) + elif isinstance(generator, torch.Generator): + noise = torch.randn(*init_latents.shape, generator=generator).numpy().astype(dtype) + else: + raise ValueError( + f"Expected `generator` to be of type `np.random.RandomState` or `torch.Generator`, but got" + f" {type(generator)}." + ) + init_latents = self.scheduler.add_noise( - torch.from_numpy(init_latents), torch.from_numpy(noise), torch.from_numpy(timestep) + torch.from_numpy(init_latents), torch.from_numpy(noise), torch.from_numpy(timesteps) ) - return init_latents.numpy() + init_latents = init_latents.numpy() + + return init_latents def _get_add_time_ids( self, original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score, dtype @@ -274,7 +285,7 @@ def __call__( negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: int = 1, eta: float = 0.0, - generator: Optional[np.random.RandomState] = None, + generator: Optional[Union[np.random.RandomState, torch.Generator]] = None, latents: Optional[np.ndarray] = None, prompt_embeds: Optional[np.ndarray] = None, negative_prompt_embeds: Optional[np.ndarray] = None, @@ -375,7 +386,7 @@ def __call__( batch_size = prompt_embeds.shape[0] if generator is None: - generator = np.random + generator = np.random.RandomState() # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` @@ -482,7 +493,8 @@ def __call__( # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): if callback is not None and i % callback_steps == 0: - callback(i, t, latents) + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) if output_type == "latent": image = latents diff --git a/optimum/pipelines/diffusers/pipeline_utils.py b/optimum/pipelines/diffusers/pipeline_utils.py index 869b91ffe59..e9d5986b61c 100644 --- a/optimum/pipelines/diffusers/pipeline_utils.py +++ b/optimum/pipelines/diffusers/pipeline_utils.py @@ -17,7 +17,7 @@ from typing import List, Optional, Union import numpy as np -import PIL +import PIL.Image import torch from diffusers import ConfigMixin from diffusers.image_processor import VaeImageProcessor as DiffusersVaeImageProcessor @@ -206,7 +206,7 @@ def postprocess( def get_height_width( self, - image: [PIL.Image.Image, np.ndarray], + image: Union[PIL.Image.Image, np.ndarray], height: Optional[int] = None, width: Optional[int] = None, ): @@ -264,10 +264,10 @@ def reshape(images: np.ndarray) -> np.ndarray: # TODO : remove after diffusers v0.21.0 release def resize( self, - image: [PIL.Image.Image, np.ndarray, torch.Tensor], + image: Union[PIL.Image.Image, np.ndarray, torch.Tensor], height: Optional[int] = None, width: Optional[int] = None, - ) -> [PIL.Image.Image, np.ndarray, torch.Tensor]: + ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]: """ Resize image. """ diff --git a/optimum/utils/dummy_diffusers_objects.py b/optimum/utils/dummy_diffusers_objects.py index f6914bbcd3a..35d1ffe9fc7 100644 --- a/optimum/utils/dummy_diffusers_objects.py +++ b/optimum/utils/dummy_diffusers_objects.py @@ -79,3 +79,47 @@ def __init__(self, *args, **kwargs): @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["diffusers"]) + + +class ORTDiffusionPipeline(metaclass=DummyObject): + _backends = ["diffusers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["diffusers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["diffusers"]) + + +class ORTPipelineForText2Image(metaclass=DummyObject): + _backends = ["diffusers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["diffusers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["diffusers"]) + + +class ORTPipelineForImage2Image(metaclass=DummyObject): + _backends = ["diffusers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["diffusers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["diffusers"]) + + +class ORTPipelineForInpainting(metaclass=DummyObject): + _backends = ["diffusers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["diffusers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["diffusers"]) diff --git a/optimum/version.py b/optimum/version.py index 8eeeb9d05a7..4a8a7edab63 100644 --- a/optimum/version.py +++ b/optimum/version.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "1.22.0.dev0" +__version__ = "1.23.0.dev0" diff --git a/setup.py b/setup.py index 98ee4f36a3f..ac5db71a74b 100644 --- a/setup.py +++ b/setup.py @@ -88,6 +88,7 @@ "graphcore": "optimum-graphcore", "furiosa": "optimum-furiosa", "amd": "optimum-amd", + "quanto": ["optimum-quanto>=0.2.4"], "dev": TESTS_REQUIRE + QUALITY_REQUIRE, "tests": TESTS_REQUIRE, "quality": QUALITY_REQUIRE, diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index a55c7a124df..c8a33b0be35 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -298,7 +298,7 @@ PYTORCH_DIFFUSION_MODEL = { "stable-diffusion": "hf-internal-testing/tiny-stable-diffusion-torch", "stable-diffusion-xl": "echarlaix/tiny-random-stable-diffusion-xl", - "lcm": "echarlaix/tiny-random-latent-consistency", + "latent-consistency": "echarlaix/tiny-random-latent-consistency", } PYTORCH_TIMM_MODEL = { diff --git a/tests/fx/parallelization/test_tensor_parallel.py b/tests/fx/parallelization/test_tensor_parallel.py index 4c9ba131e4b..1fee30f7d43 100644 --- a/tests/fx/parallelization/test_tensor_parallel.py +++ b/tests/fx/parallelization/test_tensor_parallel.py @@ -36,6 +36,7 @@ "output_attentions": False, "output_hidden_states": False, "tie_word_embeddings": True, + "return_dict": True, } DUMMY_MODELS_TO_TEST = ( @@ -64,11 +65,10 @@ def prepare_dummy_inputs( seq_len: int = 10, device: Union[str, torch.device] = "cuda", ): - return { - "input_ids": torch.randint(low=1, high=model_config.vocab_size, size=(batch_size, seq_len), device=device), - "attention_mask": torch.ones((batch_size, seq_len), dtype=torch.int64, device=device), - "position_ids": torch.arange(0, seq_len, device=device).unsqueeze(0).expand(batch_size, -1), - } + input_ids = torch.randint(low=1, high=model_config.vocab_size, size=(batch_size, seq_len), device=device) + attention_mask = torch.ones((batch_size, seq_len), dtype=torch.int64, device=device) + labels = input_ids.clone() + return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels} def run_test_all_rank_results_match(rank: int, world_size: int, model_id: str, model_kwargs: Dict[str, Any]): @@ -82,8 +82,8 @@ def run_test_all_rank_results_match(rank: int, world_size: int, model_id: str, m model = parallelize_model(ctx, model_id_or_path=model_id, skip_load_weights=True, **model_kwargs) inputs = prepare_dummy_inputs(model.config) - logits = model(**inputs)[0] - tensors = gather_at_main_process(tensor=logits, group=tp_group, rank=rank, world_size=world_size) + loss = model(**inputs).loss + tensors = gather_at_main_process(tensor=loss, group=tp_group, rank=rank, world_size=world_size) # check results at main worker process if rank == 0: @@ -145,7 +145,7 @@ def run_test_parallel_results_matches_non_parallel( inputs = prepare_dummy_inputs(model.config) set_seed(SEED) - logits = model(**inputs)[0] + loss = model(**inputs).loss torch._dynamo.reset() del model @@ -154,9 +154,9 @@ def run_test_parallel_results_matches_non_parallel( set_seed(SEED) ctx = ParallelExecutionCtx(tp_group=tp_group, current_device=device) model = parallelize_model(ctx, model_id_or_path=model_id, skip_load_weights=True, **model_kwargs) - parallel_logits = model(**inputs)[0] + parallel_loss = model(**inputs).loss - torch.testing.assert_close(logits.cpu(), parallel_logits.cpu(), rtol=1e-4, atol=1e-4) + torch.testing.assert_close(loss.cpu(), parallel_loss.cpu(), rtol=1e-4, atol=1e-4) dist.barrier(tp_group) tearDown() diff --git a/tests/onnxruntime/test_diffusion.py b/tests/onnxruntime/test_diffusion.py new file mode 100644 index 00000000000..9f480b2d1a0 --- /dev/null +++ b/tests/onnxruntime/test_diffusion.py @@ -0,0 +1,793 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import numpy as np +import PIL +import pytest +import torch +from diffusers import ( + AutoPipelineForImage2Image, + AutoPipelineForInpainting, + AutoPipelineForText2Image, + DiffusionPipeline, +) +from diffusers.utils import load_image +from parameterized import parameterized +from transformers.testing_utils import require_torch_gpu +from utils_onnxruntime_tests import MODEL_NAMES, SEED, ORTModelTestMixin + +from optimum.onnxruntime import ( + ORTDiffusionPipeline, + ORTPipelineForImage2Image, + ORTPipelineForInpainting, + ORTPipelineForText2Image, +) +from optimum.pipelines.diffusers.pipeline_utils import VaeImageProcessor +from optimum.utils.testing_utils import grid_parameters, require_diffusers, require_ort_rocm + + +def get_generator(framework, seed): + if framework == "np": + return np.random.RandomState(seed) + elif framework == "pt": + return torch.Generator().manual_seed(seed) + else: + raise ValueError(f"Unknown framework: {framework}") + + +def _generate_prompts(batch_size=1): + inputs = { + "prompt": ["sailing ship in storm by Leonardo da Vinci"] * batch_size, + "num_inference_steps": 3, + "guidance_scale": 7.5, + "output_type": "np", + } + return inputs + + +def _generate_images(height=128, width=128, batch_size=1, channel=3, input_type="pil"): + if input_type == "pil": + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo.png" + ).resize((width, height)) + elif input_type == "np": + image = np.random.rand(height, width, channel) + elif input_type == "pt": + image = torch.rand((channel, height, width)) + + return [image] * batch_size + + +def to_np(image): + if isinstance(image[0], PIL.Image.Image): + return np.stack([np.array(i) for i in image], axis=0) + elif isinstance(image, torch.Tensor): + return image.cpu().numpy().transpose(0, 2, 3, 1) + return image + + +class ORTPipelineForText2ImageTest(ORTModelTestMixin): + SUPPORTED_ARCHITECTURES = ["latent-consistency", "stable-diffusion", "stable-diffusion-xl"] + + ORTMODEL_CLASS = ORTPipelineForText2Image + AUTOMODEL_CLASS = AutoPipelineForText2Image + + TASK = "text-to-image" + + def generate_inputs(self, height=128, width=128, batch_size=1): + inputs = _generate_prompts(batch_size=batch_size) + + inputs["height"] = height + inputs["width"] = width + + return inputs + + @require_diffusers + def test_load_vanilla_model_which_is_not_supported(self): + with self.assertRaises(Exception) as context: + _ = self.ORTMODEL_CLASS.from_pretrained(MODEL_NAMES["bert"], export=True) + + self.assertIn( + f"does not appear to have a file named {self.ORTMODEL_CLASS.config_name}", str(context.exception) + ) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_diffusers + def test_ort_pipeline_class_dispatch(self, model_arch: str): + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + + auto_pipeline = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch]) + ort_pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + + self.assertEqual(ort_pipeline.auto_model_class, auto_pipeline.__class__) + + auto_pipeline = DiffusionPipeline.from_pretrained(MODEL_NAMES[model_arch]) + ort_pipeline = ORTDiffusionPipeline.from_pretrained(self.onnx_model_dirs[model_arch]) + + self.assertEqual(ort_pipeline.auto_model_class, auto_pipeline.__class__) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_diffusers + def test_num_images_per_prompt(self, model_arch: str): + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + self.assertEqual(pipeline.vae_scale_factor, 2) + self.assertEqual(pipeline.vae_decoder.config["latent_channels"], 4) + self.assertEqual(pipeline.unet.config["in_channels"], 4) + + height, width, batch_size = 64, 64, 1 + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + + for num_images in [1, 3]: + outputs = pipeline(**inputs, num_images_per_prompt=num_images).images + self.assertEqual(outputs.shape, (batch_size * num_images, height, width, 3)) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_diffusers + def test_compare_to_diffusers_pipeline(self, model_arch: str): + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + + height, width, batch_size = 128, 128, 1 + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + + ort_pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + diffusers_pipeline = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch]) + + if model_arch == "latent-consistency": + # Latent Consistency Model (LCM) doesn't support deterministic outputs beyond the first inference step + # TODO: Investigate why this is the case + inputs["num_inference_steps"] = 1 + + for output_type in ["latent", "np"]: + inputs["output_type"] = output_type + + ort_output = ort_pipeline(**inputs, generator=torch.Generator().manual_seed(SEED)).images + diffusers_output = diffusers_pipeline(**inputs, generator=torch.Generator().manual_seed(SEED)).images + + self.assertTrue( + np.allclose(ort_output, diffusers_output, atol=1e-4), + np.testing.assert_allclose(ort_output, diffusers_output, atol=1e-4), + ) + self.assertEqual(ort_pipeline.device, diffusers_pipeline.device) + + @parameterized.expand( + grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "provider": ["CUDAExecutionProvider"]}) + ) + @require_torch_gpu + @pytest.mark.cuda_ep_test + @require_diffusers + def test_pipeline_on_gpu(self, test_name: str, model_arch: str, provider: str): + model_args = {"test_name": test_name, "model_arch": model_arch} + self._setup(model_args) + pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[test_name], provider=provider) + height, width, batch_size = 32, 64, 1 + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + outputs = pipeline(**inputs).images + # Verify model devices + self.assertEqual(pipeline.device.type.lower(), "cuda") + # Verify model outptus + self.assertIsInstance(outputs, np.ndarray) + self.assertEqual(outputs.shape, (batch_size, height, width, 3)) + + @parameterized.expand( + grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "provider": ["ROCMExecutionProvider"]}) + ) + @require_torch_gpu + @require_ort_rocm + @pytest.mark.rocm_ep_test + @require_diffusers + def test_pipeline_on_rocm_ep(self, test_name: str, model_arch: str, provider: str): + model_args = {"test_name": test_name, "model_arch": model_arch} + self._setup(model_args) + pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[test_name], provider=provider) + height, width, batch_size = 64, 32, 1 + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + outputs = pipeline(**inputs).images + # Verify model devices + self.assertEqual(pipeline.device.type.lower(), "cuda") + # Verify model outptus + self.assertIsInstance(outputs, np.ndarray) + self.assertEqual(outputs.shape, (batch_size, height, width, 3)) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_diffusers + def test_callback(self, model_arch: str): + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + + height, width, batch_size = 64, 128, 1 + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + + class Callback: + def __init__(self): + self.has_been_called = False + self.number_of_steps = 0 + + def __call__(self, step: int, timestep: int, latents: np.ndarray) -> None: + self.has_been_called = True + self.number_of_steps += 1 + + ort_callback = Callback() + auto_callback = Callback() + + ort_pipe = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + auto_pipe = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch]) + + # callback_steps=1 to trigger callback every step + ort_pipe(**inputs, callback=ort_callback, callback_steps=1) + auto_pipe(**inputs, callback=auto_callback, callback_steps=1) + + self.assertTrue(ort_callback.has_been_called) + self.assertTrue(auto_callback.has_been_called) + self.assertEqual(auto_callback.number_of_steps, ort_callback.number_of_steps) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_diffusers + def test_shape(self, model_arch: str): + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + height, width, batch_size = 128, 64, 1 + pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + + for output_type in ["np", "pil", "latent"]: + inputs["output_type"] = output_type + outputs = pipeline(**inputs).images + if output_type == "pil": + self.assertEqual((len(outputs), outputs[0].height, outputs[0].width), (batch_size, height, width)) + elif output_type == "np": + self.assertEqual(outputs.shape, (batch_size, height, width, 3)) + else: + self.assertEqual( + outputs.shape, + (batch_size, 4, height // pipeline.vae_scale_factor, width // pipeline.vae_scale_factor), + ) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_diffusers + def test_image_reproducibility(self, model_arch: str): + if model_arch in ["latent-consistency"]: + pytest.skip("Latent Consistency Model (LCM) doesn't support deterministic outputs") + + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + + height, width, batch_size = 64, 64, 1 + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + + pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + + for generator_framework in ["np", "pt"]: + ort_outputs_1 = pipeline(**inputs, generator=get_generator(generator_framework, SEED)) + ort_outputs_2 = pipeline(**inputs, generator=get_generator(generator_framework, SEED)) + ort_outputs_3 = pipeline(**inputs, generator=get_generator(generator_framework, SEED + 1)) + + self.assertTrue(np.array_equal(ort_outputs_1.images[0], ort_outputs_2.images[0])) + self.assertFalse(np.array_equal(ort_outputs_1.images[0], ort_outputs_3.images[0])) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + def test_negative_prompt(self, model_arch: str): + if model_arch in ["latent-consistency"]: + pytest.skip("Latent Consistency Model (LCM) does not support negative prompts") + + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + + height, width, batch_size = 64, 64, 1 + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + + negative_prompt = ["This is a negative prompt"] + pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + image_slice_1 = pipeline( + **inputs, negative_prompt=negative_prompt, generator=np.random.RandomState(SEED) + ).images[0, -3:, -3:, -1] + prompt = inputs.pop("prompt") + + if model_arch == "stable-diffusion-xl": + ( + inputs["prompt_embeds"], + inputs["negative_prompt_embeds"], + inputs["pooled_prompt_embeds"], + inputs["negative_pooled_prompt_embeds"], + ) = pipeline._encode_prompt(prompt, 1, False, negative_prompt) + else: + text_ids = pipeline.tokenizer( + prompt, + max_length=pipeline.tokenizer.model_max_length, + padding="max_length", + return_tensors="np", + truncation=True, + ).input_ids + negative_text_ids = pipeline.tokenizer( + negative_prompt, + max_length=pipeline.tokenizer.model_max_length, + padding="max_length", + return_tensors="np", + truncation=True, + ).input_ids + inputs["prompt_embeds"] = pipeline.text_encoder(text_ids)[0] + inputs["negative_prompt_embeds"] = pipeline.text_encoder(negative_text_ids)[0] + + image_slice_2 = pipeline(**inputs, generator=np.random.RandomState(SEED)).images[0, -3:, -3:, -1] + + self.assertTrue(np.allclose(image_slice_1, image_slice_2, rtol=1e-1)) + + +class ORTPipelineForImage2ImageTest(ORTModelTestMixin): + SUPPORTED_ARCHITECTURES = ["stable-diffusion", "stable-diffusion-xl"] + + AUTOMODEL_CLASS = AutoPipelineForImage2Image + ORTMODEL_CLASS = ORTPipelineForImage2Image + + TASK = "image-to-image" + + def generate_inputs(self, height=128, width=128, batch_size=1, channel=3, input_type="np"): + inputs = _generate_prompts(batch_size=batch_size) + + inputs["image"] = _generate_images( + height=height, width=width, batch_size=batch_size, channel=channel, input_type=input_type + ) + + inputs["strength"] = 0.75 + + return inputs + + @require_diffusers + def test_load_vanilla_model_which_is_not_supported(self): + with self.assertRaises(Exception) as context: + _ = self.ORTMODEL_CLASS.from_pretrained(MODEL_NAMES["bert"], export=True) + + self.assertIn( + f"does not appear to have a file named {self.ORTMODEL_CLASS.config_name}", str(context.exception) + ) + + @parameterized.expand(list(SUPPORTED_ARCHITECTURES)) + @require_diffusers + def test_ort_pipeline_class_dispatch(self, model_arch: str): + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + + auto_pipeline = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch]) + ort_pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + + self.assertEqual(ort_pipeline.auto_model_class, auto_pipeline.__class__) + + # auto_pipeline = DiffusionPipeline.from_pretrained(MODEL_NAMES[model_arch]) + # ort_pipeline = ORTDiffusionPipeline.from_pretrained(self.onnx_model_dirs[model_arch]) + + # self.assertEqual(ort_pipeline.auto_model_class, auto_pipeline.__class__) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_diffusers + def test_num_images_per_prompt(self, model_arch: str): + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + + pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + self.assertEqual(pipeline.vae_scale_factor, 2) + self.assertEqual(pipeline.vae_decoder.config["latent_channels"], 4) + self.assertEqual(pipeline.unet.config["in_channels"], 4) + + batch_size, height = 1, 32 + for width in [64, 32]: + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + for num_images in [1, 3]: + outputs = pipeline(**inputs, num_images_per_prompt=num_images).images + self.assertEqual(outputs.shape, (batch_size * num_images, height, width, 3)) + + @parameterized.expand( + grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "provider": ["CUDAExecutionProvider"]}) + ) + @require_torch_gpu + @pytest.mark.cuda_ep_test + @require_diffusers + def test_pipeline_on_gpu(self, test_name: str, model_arch: str, provider: str): + model_args = {"test_name": test_name, "model_arch": model_arch} + self._setup(model_args) + + height, width, batch_size = 32, 64, 1 + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + + pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[test_name], provider=provider) + outputs = pipeline(**inputs).images + # Verify model devices + self.assertEqual(pipeline.device.type.lower(), "cuda") + # Verify model outptus + self.assertIsInstance(outputs, np.ndarray) + self.assertEqual(outputs.shape, (batch_size, height, width, 3)) + + @parameterized.expand( + grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "provider": ["ROCMExecutionProvider"]}) + ) + @require_torch_gpu + @require_ort_rocm + @pytest.mark.rocm_ep_test + @require_diffusers + def test_pipeline_on_rocm_ep(self, test_name: str, model_arch: str, provider: str): + model_args = {"test_name": test_name, "model_arch": model_arch} + self._setup(model_args) + + height, width, batch_size = 32, 64, 1 + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + + pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[test_name], provider=provider) + outputs = pipeline(**inputs).images + # Verify model devices + self.assertEqual(pipeline.device.type.lower(), "cuda") + # Verify model outptus + self.assertIsInstance(outputs, np.ndarray) + self.assertEqual(outputs.shape, (batch_size, height, width, 3)) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_diffusers + def test_callback(self, model_arch: str): + if model_arch in ["stable-diffusion"]: + pytest.skip( + "Stable Diffusion For Img2Img doesn't behave as expected with callbacks (doesn't call it every step with callback_steps=1)" + ) + + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + + height, width, batch_size = 32, 64, 1 + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + inputs["num_inference_steps"] = 3 + + class Callback: + def __init__(self): + self.has_been_called = False + self.number_of_steps = 0 + + def __call__(self, step: int, timestep: int, latents: np.ndarray) -> None: + self.has_been_called = True + self.number_of_steps += 1 + + ort_pipe = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + auto_pipe = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch]) + + ort_callback = Callback() + auto_callback = Callback() + # callback_steps=1 to trigger callback every step + ort_pipe(**inputs, callback=ort_callback, callback_steps=1) + auto_pipe(**inputs, callback=auto_callback, callback_steps=1) + + self.assertTrue(ort_callback.has_been_called) + self.assertEqual(ort_callback.number_of_steps, auto_callback.number_of_steps) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_diffusers + def test_shape(self, model_arch: str): + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + + pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + height, width, batch_size = 32, 64, 1 + + for input_type in ["np", "pil", "pt"]: + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size, input_type=input_type) + + for output_type in ["np", "pil", "latent"]: + inputs["output_type"] = output_type + outputs = pipeline(**inputs).images + if output_type == "pil": + self.assertEqual((len(outputs), outputs[0].height, outputs[0].width), (batch_size, height, width)) + elif output_type == "np": + self.assertEqual(outputs.shape, (batch_size, height, width, 3)) + else: + self.assertEqual( + outputs.shape, + (batch_size, 4, height // pipeline.vae_scale_factor, width // pipeline.vae_scale_factor), + ) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_diffusers + def test_compare_to_diffusers_pipeline(self, model_arch: str): + pytest.skip("Img2Img models do not support support output reproducibility for some reason") + + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + + height, width, batch_size = 128, 128, 1 + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + + ort_pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + ort_output = ort_pipeline(**inputs, generator=torch.Generator().manual_seed(SEED)).images + + diffusers_pipeline = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch]) + diffusers_output = diffusers_pipeline(**inputs, generator=torch.Generator().manual_seed(SEED)).images + + self.assertTrue(np.allclose(ort_output, diffusers_output, rtol=1e-2)) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_diffusers + def test_image_reproducibility(self, model_arch: str): + pytest.skip("Img2Img models do not support support output reproducibility for some reason") + + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + + height, width, batch_size = 64, 64, 1 + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + + pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + + for generator_framework in ["np", "pt"]: + ort_outputs_1 = pipeline(**inputs, generator=get_generator(generator_framework, SEED)) + ort_outputs_2 = pipeline(**inputs, generator=get_generator(generator_framework, SEED)) + ort_outputs_3 = pipeline(**inputs, generator=get_generator(generator_framework, SEED + 1)) + + self.assertTrue(np.array_equal(ort_outputs_1.images[0], ort_outputs_2.images[0])) + self.assertFalse(np.array_equal(ort_outputs_1.images[0], ort_outputs_3.images[0])) + + +class ORTPipelineForInpaintingTest(ORTModelTestMixin): + SUPPORTED_ARCHITECTURES = ["stable-diffusion"] + + AUTOMODEL_CLASS = AutoPipelineForInpainting + ORTMODEL_CLASS = ORTPipelineForInpainting + + TASK = "inpainting" + + def generate_inputs(self, height=128, width=128, batch_size=1, channel=3, input_type="pil"): + assert batch_size == 1, "Inpainting models only support batch_size=1" + assert input_type == "pil", "Inpainting models only support input_type='pil'" + + inputs = _generate_prompts(batch_size=batch_size) + + inputs["image"] = _generate_images( + height=height, width=width, batch_size=1, channel=channel, input_type="pil" + )[0] + inputs["mask_image"] = _generate_images( + height=height, width=width, batch_size=1, channel=channel, input_type="pil" + )[0] + + inputs["height"] = height + inputs["width"] = width + + return inputs + + @require_diffusers + def test_load_vanilla_model_which_is_not_supported(self): + with self.assertRaises(Exception) as context: + _ = self.ORTMODEL_CLASS.from_pretrained(MODEL_NAMES["bert"], export=True) + + self.assertIn( + f"does not appear to have a file named {self.ORTMODEL_CLASS.config_name}", str(context.exception) + ) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_diffusers + def test_ort_pipeline_class_dispatch(self, model_arch: str): + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + + auto_pipeline = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch]) + ort_pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + + self.assertEqual(ort_pipeline.auto_model_class, auto_pipeline.__class__) + + # auto_pipeline = DiffusionPipeline.from_pretrained(MODEL_NAMES[model_arch]) + # ort_pipeline = ORTDiffusionPipeline.from_pretrained(self.onnx_model_dirs[model_arch]) + + # self.assertEqual(ort_pipeline.auto_model_class, auto_pipeline.__class__) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_diffusers + def test_num_images_per_prompt(self, model_arch: str): + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + + pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + self.assertEqual(pipeline.vae_scale_factor, 2) + self.assertEqual(pipeline.vae_decoder.config["latent_channels"], 4) + self.assertEqual(pipeline.unet.config["in_channels"], 4) + + batch_size, height = 1, 32 + for width in [64, 32]: + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + for num_images in [1, 3]: + outputs = pipeline(**inputs, num_images_per_prompt=num_images).images + self.assertEqual(outputs.shape, (batch_size * num_images, height, width, 3)) + + @parameterized.expand( + grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "provider": ["CUDAExecutionProvider"]}) + ) + @require_torch_gpu + @pytest.mark.cuda_ep_test + @require_diffusers + def test_pipeline_on_gpu(self, test_name: str, model_arch: str, provider: str): + model_args = {"test_name": test_name, "model_arch": model_arch} + self._setup(model_args) + + height, width, batch_size = 32, 64, 1 + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + + pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[test_name], provider=provider) + outputs = pipeline(**inputs).images + # Verify model devices + self.assertEqual(pipeline.device.type.lower(), "cuda") + # Verify model outptus + self.assertIsInstance(outputs, np.ndarray) + self.assertEqual(outputs.shape, (batch_size, height, width, 3)) + + @parameterized.expand( + grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "provider": ["ROCMExecutionProvider"]}) + ) + @require_torch_gpu + @require_ort_rocm + @pytest.mark.rocm_ep_test + @require_diffusers + def test_pipeline_on_rocm_ep(self, test_name: str, model_arch: str, provider: str): + model_args = {"test_name": test_name, "model_arch": model_arch} + self._setup(model_args) + + height, width, batch_size = 32, 64, 1 + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + + pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[test_name], provider=provider) + outputs = pipeline(**inputs).images + # Verify model devices + self.assertEqual(pipeline.device.type.lower(), "cuda") + # Verify model outptus + self.assertIsInstance(outputs, np.ndarray) + self.assertEqual(outputs.shape, (batch_size, height, width, 3)) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_diffusers + def test_callback(self, model_arch: str): + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + + height, width, batch_size = 32, 64, 1 + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + inputs["num_inference_steps"] = 3 + + class Callback: + def __init__(self): + self.has_been_called = False + self.number_of_steps = 0 + + def __call__(self, step: int, timestep: int, latents: np.ndarray) -> None: + self.has_been_called = True + self.number_of_steps += 1 + + ort_pipe = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + auto_pipe = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch]) + + ort_callback = Callback() + auto_callback = Callback() + # callback_steps=1 to trigger callback every step + ort_pipe(**inputs, callback=ort_callback, callback_steps=1) + auto_pipe(**inputs, callback=auto_callback, callback_steps=1) + + self.assertTrue(ort_callback.has_been_called) + self.assertEqual(ort_callback.number_of_steps, auto_callback.number_of_steps) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_diffusers + def test_shape(self, model_arch: str): + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + + pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + height, width, batch_size = 32, 64, 1 + + for input_type in ["pil"]: + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size, input_type=input_type) + + for output_type in ["np", "pil", "latent"]: + inputs["output_type"] = output_type + outputs = pipeline(**inputs).images + if output_type == "pil": + self.assertEqual((len(outputs), outputs[0].height, outputs[0].width), (batch_size, height, width)) + elif output_type == "np": + self.assertEqual(outputs.shape, (batch_size, height, width, 3)) + else: + self.assertEqual( + outputs.shape, + (batch_size, 4, height // pipeline.vae_scale_factor, width // pipeline.vae_scale_factor), + ) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_diffusers + def test_compare_to_diffusers_pipeline(self, model_arch: str): + if model_arch in ["stable-diffusion"]: + pytest.skip( + "Stable Diffusion For Inpainting fails, it was used to be compared to StableDiffusionPipeline for some reason which is the text-to-image variant" + ) + + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + + ort_pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + diffusers_pipeline = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch]) + + height, width, batch_size = 64, 64, 1 + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + + latents_shape = ( + batch_size, + ort_pipeline.vae_decoder.config["latent_channels"], + height // ort_pipeline.vae_scale_factor, + width // ort_pipeline.vae_scale_factor, + ) + + np_latents = np.random.rand(*latents_shape).astype(np.float32) + torch_latents = torch.from_numpy(np_latents) + + ort_output = ort_pipeline(**inputs, latents=np_latents).images + diffusers_output = diffusers_pipeline(**inputs, latents=torch_latents).images + + self.assertTrue( + np.allclose(ort_output, diffusers_output, atol=1e-4), + np.testing.assert_allclose(ort_output, diffusers_output, atol=1e-4), + ) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_diffusers + def test_image_reproducibility(self, model_arch: str): + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + + height, width, batch_size = 64, 64, 1 + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + + pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + + for generator_framework in ["np", "pt"]: + ort_outputs_1 = pipeline(**inputs, generator=get_generator(generator_framework, SEED)) + ort_outputs_2 = pipeline(**inputs, generator=get_generator(generator_framework, SEED)) + ort_outputs_3 = pipeline(**inputs, generator=get_generator(generator_framework, SEED + 1)) + + self.assertTrue(np.array_equal(ort_outputs_1.images[0], ort_outputs_2.images[0])) + self.assertFalse(np.array_equal(ort_outputs_1.images[0], ort_outputs_3.images[0])) + + +class ImageProcessorTest(unittest.TestCase): + def test_vae_image_processor_pt(self): + image_processor = VaeImageProcessor(do_resize=False, do_normalize=True) + input_pt = torch.stack(_generate_images(height=8, width=8, batch_size=1, input_type="pt")) + input_np = to_np(input_pt) + + for output_type in ["np", "pil"]: + out = image_processor.postprocess(image_processor.preprocess(input_pt), output_type=output_type) + out_np = to_np(out) + in_np = (input_np * 255).round() if output_type == "pil" else input_np + self.assertTrue(np.allclose(in_np, out_np, atol=1e-6)) + + def test_vae_image_processor_np(self): + image_processor = VaeImageProcessor(do_resize=False, do_normalize=True) + input_np = np.stack(_generate_images(height=8, width=8, input_type="np")) + for output_type in ["np", "pil"]: + out = image_processor.postprocess(image_processor.preprocess(input_np), output_type=output_type) + out_np = to_np(out) + in_np = (input_np * 255).round() if output_type == "pil" else input_np + self.assertTrue(np.allclose(in_np, out_np, atol=1e-6)) + + def test_vae_image_processor_pil(self): + image_processor = VaeImageProcessor(do_resize=False, do_normalize=True) + input_pil = _generate_images(height=8, width=8, batch_size=1, input_type="pil") + + for output_type in ["np", "pil"]: + out = image_processor.postprocess(image_processor.preprocess(input_pil), output_type=output_type) + for i, o in zip(input_pil, out): + in_np = np.array(i) + out_np = to_np(out) if output_type == "pil" else (to_np(out) * 255).round() + self.assertTrue(np.allclose(in_np, out_np, atol=1e-6)) diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index 4b44acb38ab..199b96342e7 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -89,15 +89,8 @@ ORTModelForSpeechSeq2Seq, ORTModelForTokenClassification, ORTModelForVision2Seq, - ORTStableDiffusionPipeline, ) from optimum.onnxruntime.base import ORTDecoderForSeq2Seq, ORTEncoder -from optimum.onnxruntime.modeling_diffusion import ( - ORTModelTextEncoder, - ORTModelUnet, - ORTModelVaeDecoder, - ORTModelVaeEncoder, -) from optimum.onnxruntime.modeling_ort import ORTModel from optimum.pipelines import pipeline from optimum.utils import ( @@ -108,7 +101,24 @@ DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER, logging, ) -from optimum.utils.testing_utils import grid_parameters, remove_directory, require_hf_token, require_ort_rocm +from optimum.utils.import_utils import is_diffusers_available +from optimum.utils.testing_utils import ( + grid_parameters, + remove_directory, + require_diffusers, + require_hf_token, + require_ort_rocm, +) + + +if is_diffusers_available(): + from optimum.onnxruntime.modeling_diffusion import ( + ORTModelTextEncoder, + ORTModelUnet, + ORTModelVaeDecoder, + ORTModelVaeEncoder, + ORTStableDiffusionPipeline, + ) logger = logging.get_logger() @@ -205,6 +215,7 @@ def test_load_seq2seq_model_from_empty_cache(self): with self.assertRaises(Exception): _ = ORTModelForSeq2SeqLM.from_pretrained(self.TINY_ONNX_SEQ2SEQ_MODEL_ID, local_files_only=True) + @require_diffusers def test_load_stable_diffusion_model_from_cache(self): _ = ORTStableDiffusionPipeline.from_pretrained(self.TINY_ONNX_STABLE_DIFFUSION_MODEL_ID) # caching @@ -218,6 +229,7 @@ def test_load_stable_diffusion_model_from_cache(self): self.assertIsInstance(model.unet, ORTModelUnet) self.assertIsInstance(model.config, Dict) + @require_diffusers def test_load_stable_diffusion_model_from_empty_cache(self): dirpath = os.path.join( default_cache_path, "models--" + self.TINY_ONNX_STABLE_DIFFUSION_MODEL_ID.replace("/", "--") @@ -300,6 +312,7 @@ def test_load_seq2seq_model_unknown_provider(self): with self.assertRaises(ValueError): ORTModelForSeq2SeqLM.from_pretrained(self.ONNX_SEQ2SEQ_MODEL_ID, provider="FooExecutionProvider") + @require_diffusers def test_load_stable_diffusion_model_from_hub(self): model = ORTStableDiffusionPipeline.from_pretrained(self.TINY_ONNX_STABLE_DIFFUSION_MODEL_ID) self.assertIsInstance(model.text_encoder, ORTModelTextEncoder) @@ -308,6 +321,7 @@ def test_load_stable_diffusion_model_from_hub(self): self.assertIsInstance(model.unet, ORTModelUnet) self.assertIsInstance(model.config, Dict) + @require_diffusers @require_torch_gpu @pytest.mark.cuda_ep_test def test_load_stable_diffusion_model_cuda_provider(self): @@ -321,6 +335,7 @@ def test_load_stable_diffusion_model_cuda_provider(self): self.assertListEqual(model.vae_encoder.session.get_providers(), model.providers) self.assertEqual(model.device, torch.device("cuda:0")) + @require_diffusers @require_torch_gpu @require_ort_rocm @pytest.mark.rocm_ep_test @@ -335,6 +350,7 @@ def test_load_stable_diffusion_model_rocm_provider(self): self.assertListEqual(model.vae_encoder.session.get_providers(), model.providers) self.assertEqual(model.device, torch.device("cuda:0")) + @require_diffusers def test_load_stable_diffusion_model_cpu_provider(self): model = ORTStableDiffusionPipeline.from_pretrained( self.TINY_ONNX_STABLE_DIFFUSION_MODEL_ID, provider="CPUExecutionProvider" @@ -346,6 +362,7 @@ def test_load_stable_diffusion_model_cpu_provider(self): self.assertListEqual(model.vae_encoder.session.get_providers(), model.providers) self.assertEqual(model.device, torch.device("cpu")) + @require_diffusers def test_load_stable_diffusion_model_unknown_provider(self): with self.assertRaises(ValueError): ORTStableDiffusionPipeline.from_pretrained( @@ -478,6 +495,7 @@ def test_passing_session_options_seq2seq(self): self.assertEqual(model.encoder.session.get_session_options().intra_op_num_threads, 3) self.assertEqual(model.decoder.session.get_session_options().intra_op_num_threads, 3) + @require_diffusers def test_passing_session_options_stable_diffusion(self): options = onnxruntime.SessionOptions() options.intra_op_num_threads = 3 @@ -772,6 +790,7 @@ def test_seq2seq_model_on_rocm_ep_str(self): self.assertEqual(model.decoder_with_past.session.get_providers()[0], "ROCMExecutionProvider") self.assertListEqual(model.providers, ["ROCMExecutionProvider", "CPUExecutionProvider"]) + @require_diffusers @require_torch_gpu @pytest.mark.cuda_ep_test def test_passing_provider_options_stable_diffusion(self): @@ -810,6 +829,7 @@ def test_passing_provider_options_stable_diffusion(self): model.vae_encoder.session.get_provider_options()["CUDAExecutionProvider"]["do_copy_in_default_stream"], "0" ) + @require_diffusers def test_stable_diffusion_model_on_cpu(self): model = ORTStableDiffusionPipeline.from_pretrained(self.TINY_ONNX_STABLE_DIFFUSION_MODEL_ID) cpu = torch.device("cpu") @@ -825,7 +845,7 @@ def test_stable_diffusion_model_on_cpu(self): self.assertEqual(model.vae_encoder.session.get_providers()[0], "CPUExecutionProvider") self.assertListEqual(model.providers, ["CPUExecutionProvider"]) - # test string device input for to() + @require_diffusers def test_stable_diffusion_model_on_cpu_str(self): model = ORTStableDiffusionPipeline.from_pretrained(self.TINY_ONNX_STABLE_DIFFUSION_MODEL_ID) cpu = torch.device("cpu") @@ -841,6 +861,7 @@ def test_stable_diffusion_model_on_cpu_str(self): self.assertEqual(model.vae_encoder.session.get_providers()[0], "CPUExecutionProvider") self.assertListEqual(model.providers, ["CPUExecutionProvider"]) + @require_diffusers @require_torch_gpu @pytest.mark.cuda_ep_test def test_stable_diffusion_model_on_gpu(self): @@ -858,6 +879,7 @@ def test_stable_diffusion_model_on_gpu(self): self.assertEqual(model.vae_encoder.session.get_providers()[0], "CUDAExecutionProvider") self.assertListEqual(model.providers, ["CUDAExecutionProvider", "CPUExecutionProvider"]) + @require_diffusers @require_torch_gpu @require_ort_rocm @pytest.mark.rocm_ep_test @@ -876,6 +898,7 @@ def test_stable_diffusion_model_on_rocm_ep(self): self.assertEqual(model.vae_encoder.session.get_providers()[0], "ROCMExecutionProvider") self.assertListEqual(model.providers, ["ROCMExecutionProvider", "CPUExecutionProvider"]) + @require_diffusers @unittest.skipIf(get_gpu_count() <= 1, "this test requires multi-gpu") def test_stable_diffusion_model_on_gpu_id(self): model = ORTStableDiffusionPipeline.from_pretrained(self.TINY_ONNX_STABLE_DIFFUSION_MODEL_ID) @@ -899,7 +922,7 @@ def test_stable_diffusion_model_on_gpu_id(self): self.assertEqual(model.vae_decoder.session.get_provider_options()["CUDAExecutionProvider"]["device_id"], "1") self.assertEqual(model.vae_encoder.session.get_provider_options()["CUDAExecutionProvider"]["device_id"], "1") - # test string device input for to() + @require_diffusers @require_torch_gpu @pytest.mark.cuda_ep_test def test_stable_diffusion_model_on_gpu_str(self): @@ -916,6 +939,7 @@ def test_stable_diffusion_model_on_gpu_str(self): self.assertEqual(model.vae_encoder.session.get_providers()[0], "CUDAExecutionProvider") self.assertListEqual(model.providers, ["CUDAExecutionProvider", "CPUExecutionProvider"]) + @require_diffusers @require_torch_gpu @require_ort_rocm @pytest.mark.rocm_ep_test @@ -975,6 +999,7 @@ def test_save_seq2seq_model_without_past(self): self.assertTrue(ONNX_DECODER_WITH_PAST_NAME not in folder_contents) self.assertTrue(CONFIG_NAME in folder_contents) + @require_diffusers def test_save_stable_diffusion_model(self): with tempfile.TemporaryDirectory() as tmpdirname: model = ORTStableDiffusionPipeline.from_pretrained(self.TINY_ONNX_STABLE_DIFFUSION_MODEL_ID) @@ -1050,6 +1075,7 @@ def test_save_load_seq2seq_model_with_external_data(self, use_cache: bool): os.environ.pop("FORCE_ONNX_EXTERNAL_DATA") remove_directory(tmpdirname) + @require_diffusers def test_save_load_stable_diffusion_model_with_external_data(self): with tempfile.TemporaryDirectory() as tmpdirname: os.environ["FORCE_ONNX_EXTERNAL_DATA"] = "1" # force exporting small model with external data @@ -1180,6 +1206,7 @@ def test_push_seq2seq_model_with_external_data_to_hub(self): ) os.environ.pop("FORCE_ONNX_EXTERNAL_DATA") + @require_diffusers @require_hf_token def test_push_stable_diffusion_model_with_external_data_to_hub(self): with tempfile.TemporaryDirectory() as tmpdirname: diff --git a/tests/onnxruntime/test_stable_diffusion_pipeline.py b/tests/onnxruntime/test_stable_diffusion_pipeline.py deleted file mode 100644 index 44cd22ffecc..00000000000 --- a/tests/onnxruntime/test_stable_diffusion_pipeline.py +++ /dev/null @@ -1,562 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import random -import unittest -from typing import Dict - -import numpy as np -import PIL -import pytest -import torch -from diffusers import ( - OnnxStableDiffusionImg2ImgPipeline, - StableDiffusionPipeline, - StableDiffusionXLPipeline, -) -from diffusers.utils import load_image -from diffusers.utils.testing_utils import floats_tensor -from packaging.version import Version, parse -from parameterized import parameterized -from transformers.testing_utils import require_torch_gpu -from utils_onnxruntime_tests import MODEL_NAMES, SEED, ORTModelTestMixin - -from optimum.onnxruntime import ( - ORTLatentConsistencyModelPipeline, - ORTStableDiffusionImg2ImgPipeline, - ORTStableDiffusionInpaintPipeline, - ORTStableDiffusionPipeline, - ORTStableDiffusionXLImg2ImgPipeline, - ORTStableDiffusionXLPipeline, -) -from optimum.onnxruntime.modeling_diffusion import ( - ORTModelTextEncoder, - ORTModelUnet, - ORTModelVaeDecoder, - ORTModelVaeEncoder, -) -from optimum.pipelines.diffusers.pipeline_utils import VaeImageProcessor -from optimum.utils.import_utils import _diffusers_version -from optimum.utils.testing_utils import grid_parameters, require_diffusers, require_ort_rocm - - -if parse(_diffusers_version) > Version("0.21.4"): - from diffusers import LatentConsistencyModelPipeline - - -def _generate_inputs(batch_size=1): - inputs = { - "prompt": ["sailing ship in storm by Leonardo da Vinci"] * batch_size, - "num_inference_steps": 3, - "guidance_scale": 7.5, - "output_type": "np", - } - return inputs - - -def _create_image(height=128, width=128, batch_size=1, channel=3, input_type="pil"): - if input_type == "pil": - image = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - "/in_paint/overture-creations-5sI6fQgYIuo.png" - ).resize((width, height)) - elif input_type == "np": - image = np.random.rand(height, width, channel) - elif input_type == "pt": - image = torch.rand((channel, height, width)) - - return [image] * batch_size - - -def to_np(image): - if isinstance(image[0], PIL.Image.Image): - return np.stack([np.array(i) for i in image], axis=0) - elif isinstance(image, torch.Tensor): - return image.cpu().numpy().transpose(0, 2, 3, 1) - return image - - -class ORTStableDiffusionPipelineBase(ORTModelTestMixin): - SUPPORTED_ARCHITECTURES = [ - "stable-diffusion", - ] - ORTMODEL_CLASS = ORTStableDiffusionPipeline - TASK = "text-to-image" - - @require_diffusers - def test_load_vanilla_model_which_is_not_supported(self): - with self.assertRaises(Exception) as context: - _ = self.ORTMODEL_CLASS.from_pretrained(MODEL_NAMES["bert"], export=True) - - self.assertIn( - f"does not appear to have a file named {self.ORTMODEL_CLASS.config_name}", str(context.exception) - ) - - @parameterized.expand(SUPPORTED_ARCHITECTURES) - @require_diffusers - def test_num_images_per_prompt(self, model_arch: str): - model_args = {"test_name": model_arch, "model_arch": model_arch} - self._setup(model_args) - pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) - self.assertEqual(pipeline.vae_scale_factor, 2) - self.assertEqual(pipeline.vae_decoder.config["latent_channels"], 4) - self.assertEqual(pipeline.unet.config["in_channels"], 4) - - batch_size, height = 1, 32 - for width in [64, 32]: - inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) - for num_images in [1, 3]: - outputs = pipeline(**inputs, num_images_per_prompt=num_images).images - self.assertEqual(outputs.shape, (batch_size * num_images, height, width, 3)) - - @parameterized.expand( - grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "provider": ["CUDAExecutionProvider"]}) - ) - @require_torch_gpu - @pytest.mark.cuda_ep_test - @require_diffusers - def test_pipeline_on_gpu(self, test_name: str, model_arch: str, provider: str): - model_args = {"test_name": test_name, "model_arch": model_arch} - self._setup(model_args) - pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[test_name], provider=provider) - height, width, batch_size = 32, 64, 1 - inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) - outputs = pipeline(**inputs).images - # Verify model devices - self.assertEqual(pipeline.device.type.lower(), "cuda") - # Verify model outptus - self.assertIsInstance(outputs, np.ndarray) - self.assertEqual(outputs.shape, (batch_size, height, width, 3)) - - @parameterized.expand( - grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "provider": ["ROCMExecutionProvider"]}) - ) - @require_torch_gpu - @require_ort_rocm - @pytest.mark.rocm_ep_test - @require_diffusers - def test_pipeline_on_rocm_ep(self, test_name: str, model_arch: str, provider: str): - model_args = {"test_name": test_name, "model_arch": model_arch} - self._setup(model_args) - pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[test_name], provider=provider) - height, width, batch_size = 32, 64, 1 - inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) - outputs = pipeline(**inputs).images - # Verify model devices - self.assertEqual(pipeline.device.type.lower(), "cuda") - # Verify model outptus - self.assertIsInstance(outputs, np.ndarray) - self.assertEqual(outputs.shape, (batch_size, height, width, 3)) - - @parameterized.expand(SUPPORTED_ARCHITECTURES) - @require_diffusers - def test_callback(self, model_arch: str): - def callback_fn(step: int, timestep: int, latents: np.ndarray) -> None: - callback_fn.has_been_called = True - callback_fn.number_of_steps += 1 - - pipe = self.ORTMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], export=True) - callback_fn.has_been_called = False - callback_fn.number_of_steps = 0 - inputs = self.generate_inputs(height=64, width=64) - pipe(**inputs, callback=callback_fn, callback_steps=1) - self.assertTrue(callback_fn.has_been_called) - self.assertEqual(callback_fn.number_of_steps, inputs["num_inference_steps"]) - - @parameterized.expand(SUPPORTED_ARCHITECTURES) - @require_diffusers - def test_shape(self, model_arch: str): - model_args = {"test_name": model_arch, "model_arch": model_arch} - self._setup(model_args) - height, width, batch_size = 128, 64, 1 - pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) - - if self.TASK == "image-to-image": - input_types = ["np", "pil", "pt"] - elif self.TASK == "text-to-image": - input_types = ["np"] - else: - input_types = ["pil"] - - for input_type in input_types: - if self.TASK == "image-to-image": - inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size, input_type=input_type) - else: - inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) - for output_type in ["np", "pil", "latent"]: - inputs["output_type"] = output_type - outputs = pipeline(**inputs).images - if output_type == "pil": - self.assertEqual((len(outputs), outputs[0].height, outputs[0].width), (batch_size, height, width)) - elif output_type == "np": - self.assertEqual(outputs.shape, (batch_size, height, width, 3)) - else: - self.assertEqual( - outputs.shape, - (batch_size, 4, height // pipeline.vae_scale_factor, width // pipeline.vae_scale_factor), - ) - - def generate_inputs(self, height=128, width=128, batch_size=1): - inputs = _generate_inputs(batch_size=batch_size) - inputs["height"] = height - inputs["width"] = width - return inputs - - -class ORTStableDiffusionImg2ImgPipelineTest(ORTStableDiffusionPipelineBase): - SUPPORTED_ARCHITECTURES = [ - "stable-diffusion", - ] - ORTMODEL_CLASS = ORTStableDiffusionImg2ImgPipeline - TASK = "image-to-image" - - @parameterized.expand(SUPPORTED_ARCHITECTURES) - @require_diffusers - def test_compare_diffusers_pipeline(self, model_arch: str): - model_args = {"test_name": model_arch, "model_arch": model_arch} - self._setup(model_args) - height, width = 128, 128 - - inputs = self.generate_inputs(height=height, width=width) - inputs["prompt"] = "A painting of a squirrel eating a burger" - inputs["image"] = floats_tensor((1, 3, height, width), rng=random.Random(SEED)) - - ort_pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) - ort_output = ort_pipeline(**inputs, generator=np.random.RandomState(SEED)).images - - diffusers_onnx_pipeline = OnnxStableDiffusionImg2ImgPipeline.from_pretrained(self.onnx_model_dirs[model_arch]) - diffusers_onnx_output = diffusers_onnx_pipeline(**inputs, generator=np.random.RandomState(SEED)).images - - self.assertTrue(np.allclose(ort_output, diffusers_onnx_output, atol=1e-1)) - - def generate_inputs(self, height=128, width=128, batch_size=1, input_type="np"): - inputs = _generate_inputs(batch_size=batch_size) - inputs["image"] = _create_image(height=height, width=width, batch_size=batch_size, input_type=input_type) - inputs["strength"] = 0.75 - return inputs - - -class ORTStableDiffusionPipelineTest(unittest.TestCase): - SUPPORTED_ARCHITECTURES = [ - "stable-diffusion", - ] - ORTMODEL_CLASS = ORTStableDiffusionPipeline - TASK = "text-to-image" - - @parameterized.expand(SUPPORTED_ARCHITECTURES) - @require_diffusers - def test_compare_to_diffusers(self, model_arch: str): - ort_pipeline = self.ORTMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], export=True) - self.assertIsInstance(ort_pipeline.text_encoder, ORTModelTextEncoder) - self.assertIsInstance(ort_pipeline.vae_decoder, ORTModelVaeDecoder) - self.assertIsInstance(ort_pipeline.vae_encoder, ORTModelVaeEncoder) - self.assertIsInstance(ort_pipeline.unet, ORTModelUnet) - self.assertIsInstance(ort_pipeline.config, Dict) - - pipeline = StableDiffusionPipeline.from_pretrained(MODEL_NAMES[model_arch]) - pipeline.safety_checker = None - batch_size, num_images_per_prompt, height, width = 1, 2, 64, 32 - - latents = ort_pipeline.prepare_latents( - batch_size * num_images_per_prompt, - ort_pipeline.unet.config["in_channels"], - height, - width, - dtype=np.float32, - generator=np.random.RandomState(0), - ) - - kwargs = { - "prompt": "sailing ship in storm by Leonardo da Vinci", - "num_inference_steps": 1, - "num_images_per_prompt": num_images_per_prompt, - "height": height, - "width": width, - "guidance_rescale": 0.1, - } - - for output_type in ["latent", "np"]: - ort_outputs = ort_pipeline(latents=latents, output_type=output_type, **kwargs).images - with torch.no_grad(): - outputs = pipeline(latents=torch.from_numpy(latents), output_type=output_type, **kwargs).images - - self.assertIsInstance(ort_outputs, np.ndarray) - # Compare model outputs - self.assertTrue(np.allclose(ort_outputs, outputs, atol=1e-4)) - # Compare model devices - self.assertEqual(pipeline.device, ort_pipeline.device) - - @parameterized.expand(SUPPORTED_ARCHITECTURES) - @require_diffusers - def test_image_reproducibility(self, model_arch: str): - pipeline = self.ORTMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], export=True) - inputs = _generate_inputs() - height, width = 64, 32 - np.random.seed(0) - ort_outputs_1 = pipeline(**inputs, height=height, width=width) - np.random.seed(0) - ort_outputs_2 = pipeline(**inputs, height=height, width=width) - ort_outputs_3 = pipeline(**inputs, height=height, width=width) - # Compare model outputs - self.assertTrue(np.array_equal(ort_outputs_1.images[0], ort_outputs_2.images[0])) - self.assertFalse(np.array_equal(ort_outputs_1.images[0], ort_outputs_3.images[0])) - - @parameterized.expand(SUPPORTED_ARCHITECTURES) - def test_negative_prompt(self, model_arch: str): - pipeline = self.ORTMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], export=True) - inputs = _generate_inputs() - inputs["height"], inputs["width"] = 64, 32 - negative_prompt = ["This is a negative prompt"] - np.random.seed(0) - image_slice_1 = pipeline(**inputs, negative_prompt=negative_prompt).images[0, -3:, -3:, -1] - prompt = inputs.pop("prompt") - embeds = [] - for p in [prompt, negative_prompt]: - text_inputs = pipeline.tokenizer( - p, - padding="max_length", - max_length=pipeline.tokenizer.model_max_length, - truncation=True, - return_tensors="np", - ) - text_inputs = text_inputs["input_ids"].astype(pipeline.text_encoder.input_dtype.get("input_ids", np.int32)) - embeds.append(pipeline.text_encoder(text_inputs)[0]) - - inputs["prompt_embeds"], inputs["negative_prompt_embeds"] = embeds - np.random.seed(0) - image_slice_2 = pipeline(**inputs).images[0, -3:, -3:, -1] - self.assertTrue(np.allclose(image_slice_1, image_slice_2, atol=1e-4)) - - -class ORTStableDiffusionXLPipelineTest(ORTModelTestMixin): - SUPPORTED_ARCHITECTURES = [ - "stable-diffusion-xl", - ] - ORTMODEL_CLASS = ORTStableDiffusionXLPipeline - TASK = "text-to-image" - - @parameterized.expand(SUPPORTED_ARCHITECTURES) - @require_diffusers - def test_compare_to_diffusers(self, model_arch: str): - ort_pipeline = self.ORTMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], export=True) - self.assertIsInstance(ort_pipeline.text_encoder, ORTModelTextEncoder) - self.assertIsInstance(ort_pipeline.text_encoder_2, ORTModelTextEncoder) - self.assertIsInstance(ort_pipeline.vae_decoder, ORTModelVaeDecoder) - self.assertIsInstance(ort_pipeline.vae_encoder, ORTModelVaeEncoder) - self.assertIsInstance(ort_pipeline.unet, ORTModelUnet) - self.assertIsInstance(ort_pipeline.config, Dict) - - pipeline = StableDiffusionXLPipeline.from_pretrained(MODEL_NAMES[model_arch]) - batch_size, num_images_per_prompt, height, width = 2, 2, 64, 32 - latents = ort_pipeline.prepare_latents( - batch_size * num_images_per_prompt, - ort_pipeline.unet.config["in_channels"], - height, - width, - dtype=np.float32, - generator=np.random.RandomState(0), - ) - - kwargs = { - "prompt": ["sailing ship in storm by Leonardo da Vinci"] * batch_size, - "num_inference_steps": 1, - "num_images_per_prompt": num_images_per_prompt, - "height": height, - "width": width, - "guidance_rescale": 0.1, - } - - for output_type in ["latent", "np"]: - ort_outputs = ort_pipeline(latents=latents, output_type=output_type, **kwargs).images - self.assertIsInstance(ort_outputs, np.ndarray) - with torch.no_grad(): - outputs = pipeline(latents=torch.from_numpy(latents), output_type=output_type, **kwargs).images - - # Compare model outputs - self.assertTrue(np.allclose(ort_outputs, outputs, atol=1e-4)) - # Compare model devices - self.assertEqual(pipeline.device, ort_pipeline.device) - - @parameterized.expand(SUPPORTED_ARCHITECTURES) - @require_diffusers - def test_image_reproducibility(self, model_arch: str): - pipeline = self.ORTMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], export=True) - inputs = _generate_inputs() - height, width = 64, 32 - np.random.seed(0) - ort_outputs_1 = pipeline(**inputs, height=height, width=width) - np.random.seed(0) - ort_outputs_2 = pipeline(**inputs, height=height, width=width) - ort_outputs_3 = pipeline(**inputs, height=height, width=width) - self.assertTrue(np.array_equal(ort_outputs_1.images[0], ort_outputs_2.images[0])) - self.assertFalse(np.array_equal(ort_outputs_1.images[0], ort_outputs_3.images[0])) - - -class ORTStableDiffusionInpaintPipelineTest(ORTStableDiffusionPipelineBase): - SUPPORTED_ARCHITECTURES = [ - "stable-diffusion", - ] - ORTMODEL_CLASS = ORTStableDiffusionInpaintPipeline - TASK = "inpainting" - - @parameterized.expand(SUPPORTED_ARCHITECTURES) - @require_diffusers - def test_compare_diffusers_pipeline(self, model_arch: str): - model_args = {"test_name": model_arch, "model_arch": model_arch} - self._setup(model_args) - ort_pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) - diffusers_pipeline = self.ORTMODEL_CLASS.auto_model_class.from_pretrained(MODEL_NAMES[model_arch]) - height, width = 64, 64 - latents_shape = ( - 1, - ort_pipeline.vae_decoder.config["latent_channels"], - height // ort_pipeline.vae_scale_factor, - width // ort_pipeline.vae_scale_factor, - ) - inputs = self.generate_inputs(height=height, width=width) - - np_latents = np.random.rand(*latents_shape).astype(np.float32) - torch_latents = torch.from_numpy(np_latents) - - ort_outputs = ort_pipeline(**inputs, latents=np_latents).images - self.assertEqual(ort_outputs.shape, (1, height, width, 3)) - - diffusers_outputs = diffusers_pipeline(**inputs, latents=torch_latents).images - self.assertEqual(diffusers_outputs.shape, (1, height, width, 3)) - - self.assertTrue(np.allclose(ort_outputs, diffusers_outputs, atol=1e-4)) - - def generate_inputs(self, height=128, width=128, batch_size=1): - inputs = super(ORTStableDiffusionInpaintPipelineTest, self).generate_inputs(height, width) - inputs["image"] = _create_image(height=height, width=width, batch_size=1, input_type="pil")[0] - inputs["mask_image"] = _create_image(height=height, width=width, batch_size=1, input_type="pil")[0] - return inputs - - -class ORTStableDiffusionXLImg2ImgPipelineTest(ORTModelTestMixin): - SUPPORTED_ARCHITECTURES = [ - "stable-diffusion-xl", - ] - ORTMODEL_CLASS = ORTStableDiffusionXLImg2ImgPipeline - TASK = "image-to-image" - - @parameterized.expand(SUPPORTED_ARCHITECTURES) - @require_diffusers - def test_inference(self, model_arch: str): - model_args = {"test_name": model_arch, "model_arch": model_arch} - self._setup(model_args) - pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) - - height, width = 128, 128 - inputs = self.generate_inputs(height=height, width=width) - inputs["image"] = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - "/in_paint/overture-creations-5sI6fQgYIuo.png" - ).resize((width, height)) - output = pipeline(**inputs, generator=np.random.RandomState(0)).images[0, -3:, -3:, -1] - expected_slice = np.array([0.6515, 0.5405, 0.4858, 0.5632, 0.5174, 0.5681, 0.4948, 0.4253, 0.5080]) - - self.assertTrue(np.allclose(output.flatten(), expected_slice, atol=1e-1)) - - def generate_inputs(self, height=128, width=128, batch_size=1, input_type="np"): - inputs = _generate_inputs(batch_size=batch_size) - inputs["image"] = _create_image(height=height, width=width, batch_size=batch_size, input_type=input_type) - inputs["strength"] = 0.75 - return inputs - - -class ImageProcessorTest(unittest.TestCase): - def test_vae_image_processor_pt(self): - image_processor = VaeImageProcessor(do_resize=False, do_normalize=True) - input_pt = torch.stack(_create_image(height=8, width=8, batch_size=1, input_type="pt")) - input_np = to_np(input_pt) - - for output_type in ["np", "pil"]: - out = image_processor.postprocess(image_processor.preprocess(input_pt), output_type=output_type) - out_np = to_np(out) - in_np = (input_np * 255).round() if output_type == "pil" else input_np - self.assertTrue(np.allclose(in_np, out_np, atol=1e-6)) - - def test_vae_image_processor_np(self): - image_processor = VaeImageProcessor(do_resize=False, do_normalize=True) - input_np = np.stack(_create_image(height=8, width=8, input_type="np")) - for output_type in ["np", "pil"]: - out = image_processor.postprocess(image_processor.preprocess(input_np), output_type=output_type) - out_np = to_np(out) - in_np = (input_np * 255).round() if output_type == "pil" else input_np - self.assertTrue(np.allclose(in_np, out_np, atol=1e-6)) - - def test_vae_image_processor_pil(self): - image_processor = VaeImageProcessor(do_resize=False, do_normalize=True) - input_pil = _create_image(height=8, width=8, batch_size=1, input_type="pil") - - for output_type in ["np", "pil"]: - out = image_processor.postprocess(image_processor.preprocess(input_pil), output_type=output_type) - for i, o in zip(input_pil, out): - in_np = np.array(i) - out_np = to_np(out) if output_type == "pil" else (to_np(out) * 255).round() - self.assertTrue(np.allclose(in_np, out_np, atol=1e-6)) - - -class ORTLatentConsistencyModelPipelineTest(ORTModelTestMixin): - SUPPORTED_ARCHITECTURES = [ - "latent-consistency", - ] - ORTMODEL_CLASS = ORTLatentConsistencyModelPipeline - TASK = "text-to-image" - - @parameterized.expand(SUPPORTED_ARCHITECTURES) - @require_diffusers - @unittest.skipIf( - parse(_diffusers_version) <= Version("0.21.4"), - "not supported with this diffusers version, needs diffusers>=v0.22.0", - ) - def test_compare_to_diffusers(self, model_arch: str): - ort_pipeline = self.ORTMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], export=True) - self.assertIsInstance(ort_pipeline.text_encoder, ORTModelTextEncoder) - self.assertIsInstance(ort_pipeline.vae_decoder, ORTModelVaeDecoder) - self.assertIsInstance(ort_pipeline.vae_encoder, ORTModelVaeEncoder) - self.assertIsInstance(ort_pipeline.unet, ORTModelUnet) - self.assertIsInstance(ort_pipeline.config, Dict) - - pipeline = LatentConsistencyModelPipeline.from_pretrained(MODEL_NAMES[model_arch]) - batch_size, num_images_per_prompt, height, width = 2, 2, 64, 32 - latents = ort_pipeline.prepare_latents( - batch_size * num_images_per_prompt, - ort_pipeline.unet.config["in_channels"], - height, - width, - dtype=np.float32, - generator=np.random.RandomState(0), - ) - - kwargs = { - "prompt": ["sailing ship in storm by Leonardo da Vinci"] * batch_size, - "num_inference_steps": 1, - "num_images_per_prompt": num_images_per_prompt, - "height": height, - "width": width, - "guidance_scale": 8.5, - } - - for output_type in ["latent", "np"]: - ort_outputs = ort_pipeline(latents=latents, output_type=output_type, **kwargs).images - self.assertIsInstance(ort_outputs, np.ndarray) - with torch.no_grad(): - outputs = pipeline(latents=torch.from_numpy(latents), output_type=output_type, **kwargs).images - - # Compare model outputs - self.assertTrue(np.allclose(ort_outputs, outputs, atol=1e-4)) - # Compare model devices - self.assertEqual(pipeline.device, ort_pipeline.device)