From 915c050bd07451866053c7cb97bd4de9750569c9 Mon Sep 17 00:00:00 2001 From: Rick Staa Date: Tue, 13 Aug 2024 15:04:48 +0200 Subject: [PATCH] chore: apply black formatter (#153) This commit applies the black formatter to the codebase to ensure the code formatting is consistent. --- runner/app/pipelines/audio_to_text.py | 6 +++++- runner/app/pipelines/image_to_image.py | 22 +++++++++++++++------- runner/app/pipelines/optim/sfast.py | 3 +-- runner/app/pipelines/text_to_image.py | 2 +- runner/app/pipelines/upscale.py | 10 +++++++--- runner/app/pipelines/utils/__init__.py | 14 ++++++++++---- runner/app/pipelines/utils/utils.py | 11 +++++++---- runner/gen_openapi.py | 10 ++++++++-- runner/modal_app.py | 3 +-- 9 files changed, 55 insertions(+), 26 deletions(-) diff --git a/runner/app/pipelines/audio_to_text.py b/runner/app/pipelines/audio_to_text.py index 300831b9..670f360f 100644 --- a/runner/app/pipelines/audio_to_text.py +++ b/runner/app/pipelines/audio_to_text.py @@ -45,7 +45,11 @@ def __init__(self, model_id: str): kwargs["torch_dtype"] = torch.bfloat16 model = AutoModelForSpeechSeq2Seq.from_pretrained( - model_id, low_cpu_mem_usage=True, use_safetensors=True, cache_dir=get_model_dir(), **kwargs + model_id, + low_cpu_mem_usage=True, + use_safetensors=True, + cache_dir=get_model_dir(), + **kwargs, ).to(torch_device) processor = AutoProcessor.from_pretrained(model_id, cache_dir=get_model_dir()) diff --git a/runner/app/pipelines/image_to_image.py b/runner/app/pipelines/image_to_image.py index 9d9e7fdf..f42e6550 100644 --- a/runner/app/pipelines/image_to_image.py +++ b/runner/app/pipelines/image_to_image.py @@ -6,13 +6,21 @@ import PIL import torch from app.pipelines.base import Pipeline -from app.pipelines.utils import (SafetyChecker, get_model_dir, - get_torch_device, is_lightning_model, - is_turbo_model) -from diffusers import (AutoPipelineForImage2Image, - EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, - StableDiffusionInstructPix2PixPipeline, - StableDiffusionXLPipeline, UNet2DConditionModel) +from app.pipelines.utils import ( + SafetyChecker, + get_model_dir, + get_torch_device, + is_lightning_model, + is_turbo_model, +) +from diffusers import ( + AutoPipelineForImage2Image, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + StableDiffusionInstructPix2PixPipeline, + StableDiffusionXLPipeline, + UNet2DConditionModel, +) from huggingface_hub import file_download, hf_hub_download from PIL import ImageFile from safetensors.torch import load_file diff --git a/runner/app/pipelines/optim/sfast.py b/runner/app/pipelines/optim/sfast.py index 8bfa5b10..c449aadb 100644 --- a/runner/app/pipelines/optim/sfast.py +++ b/runner/app/pipelines/optim/sfast.py @@ -5,8 +5,7 @@ import logging -from sfast.compilers.diffusion_pipeline_compiler import (CompilationConfig, - compile) +from sfast.compilers.diffusion_pipeline_compiler import CompilationConfig, compile logger = logging.getLogger(__name__) diff --git a/runner/app/pipelines/text_to_image.py b/runner/app/pipelines/text_to_image.py index 1a062eae..84a2228e 100644 --- a/runner/app/pipelines/text_to_image.py +++ b/runner/app/pipelines/text_to_image.py @@ -17,10 +17,10 @@ from diffusers import ( AutoPipelineForText2Image, EulerDiscreteScheduler, + FluxPipeline, StableDiffusion3Pipeline, StableDiffusionXLPipeline, UNet2DConditionModel, - FluxPipeline, ) from diffusers.models import AutoencoderKL from huggingface_hub import file_download, hf_hub_download diff --git a/runner/app/pipelines/upscale.py b/runner/app/pipelines/upscale.py index c7ee404c..e36e4606 100644 --- a/runner/app/pipelines/upscale.py +++ b/runner/app/pipelines/upscale.py @@ -5,9 +5,13 @@ import PIL import torch from app.pipelines.base import Pipeline -from app.pipelines.utils import (SafetyChecker, get_model_dir, - get_torch_device, is_lightning_model, - is_turbo_model) +from app.pipelines.utils import ( + SafetyChecker, + get_model_dir, + get_torch_device, + is_lightning_model, + is_turbo_model, +) from diffusers import StableDiffusionUpscalePipeline from huggingface_hub import file_download from PIL import ImageFile diff --git a/runner/app/pipelines/utils/__init__.py b/runner/app/pipelines/utils/__init__.py index 79cc49b9..844b86e9 100644 --- a/runner/app/pipelines/utils/__init__.py +++ b/runner/app/pipelines/utils/__init__.py @@ -1,6 +1,12 @@ """This module contains several utility functions that are used across the pipelines module.""" -from app.pipelines.utils.utils import (SafetyChecker, get_model_dir, - get_model_path, get_torch_device, - is_lightning_model, is_turbo_model, - split_prompt, validate_torch_device) +from app.pipelines.utils.utils import ( + SafetyChecker, + get_model_dir, + get_model_path, + get_torch_device, + is_lightning_model, + is_turbo_model, + split_prompt, + validate_torch_device, +) diff --git a/runner/app/pipelines/utils/utils.py b/runner/app/pipelines/utils/utils.py index 7edc903e..8ba653dc 100644 --- a/runner/app/pipelines/utils/utils.py +++ b/runner/app/pipelines/utils/utils.py @@ -4,7 +4,7 @@ import os import re from pathlib import Path -from typing import Optional +from typing import Dict, Optional import numpy as np import torch @@ -12,7 +12,6 @@ from PIL import Image from torch import dtype as TorchDtype from transformers import CLIPImageProcessor -from typing import Dict logger = logging.getLogger(__name__) @@ -99,10 +98,14 @@ def split_prompt( Returns: Dict[str, str]: A dictionary of all prompts, including the main prompt. """ - prompts = [prompt.strip() for prompt in input_prompt.split(separator, max_splits) if prompt.strip()] + prompts = [ + prompt.strip() + for prompt in input_prompt.split(separator, max_splits) + if prompt.strip() + ] if not prompts: return {} - + start_index = max(1, len(prompts) - max_splits) if max_splits >= 0 else 1 prompt_dict = {f"{key_prefix}": prompts[0]} diff --git a/runner/gen_openapi.py b/runner/gen_openapi.py index bd62d71b..7fde5ee3 100644 --- a/runner/gen_openapi.py +++ b/runner/gen_openapi.py @@ -5,8 +5,14 @@ import yaml from app.main import app, use_route_names_as_operation_ids -from app.routes import (audio_to_text, health, image_to_image, image_to_video, - text_to_image, upscale) +from app.routes import ( + audio_to_text, + health, + image_to_image, + image_to_video, + text_to_image, + upscale, +) from fastapi.openapi.utils import get_openapi # Specify Endpoints for OpenAPI schema generation. diff --git a/runner/modal_app.py b/runner/modal_app.py index 3789f582..23acdf9f 100644 --- a/runner/modal_app.py +++ b/runner/modal_app.py @@ -2,8 +2,7 @@ import os from pathlib import Path -from app.main import (config_logging, load_route, - use_route_names_as_operation_ids) +from app.main import config_logging, load_route, use_route_names_as_operation_ids from app.routes import health from modal import Image, Secret, Stub, Volume, asgi_app, enter, method