From b415b1d344b37a541428b6370a718fc70410bbe4 Mon Sep 17 00:00:00 2001 From: Stelios Rousoglou Date: Mon, 7 Oct 2024 13:21:27 -0400 Subject: [PATCH] stable diffusion fixes --- .gitignore | 3 +++ .../prompt-to-nft/stablediffusion/requirements.txt | 2 ++ .../stablediffusion/src/stable_diffusion_workflow.py | 11 +++-------- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/.gitignore b/.gitignore index 3d2dd57..478f9dc 100644 --- a/.gitignore +++ b/.gitignore @@ -48,3 +48,6 @@ remote_sync # Virtual envs **/env/** + +# Arweave keyfile +keyfile-*.json \ No newline at end of file diff --git a/projects/prompt-to-nft/stablediffusion/requirements.txt b/projects/prompt-to-nft/stablediffusion/requirements.txt index 51299fb..9ab791c 100644 --- a/projects/prompt-to-nft/stablediffusion/requirements.txt +++ b/projects/prompt-to-nft/stablediffusion/requirements.txt @@ -1,5 +1,7 @@ +accelerate==1.0.0 diffusers~=0.19 huggingface-hub==0.23.0 +numpy==1.26.4 quart==0.19.4 safetensors~=0.3 torch==2.2.1 diff --git a/projects/prompt-to-nft/stablediffusion/src/stable_diffusion_workflow.py b/projects/prompt-to-nft/stablediffusion/src/stable_diffusion_workflow.py index c8c87fd..e7c3cf8 100644 --- a/projects/prompt-to-nft/stablediffusion/src/stable_diffusion_workflow.py +++ b/projects/prompt-to-nft/stablediffusion/src/stable_diffusion_workflow.py @@ -17,7 +17,7 @@ def __init__( self.is_setup = False - def do_setup(self) -> Any: + def setup(self) -> Any: ignore = [ "*.bin", "*.onnx_data", @@ -35,7 +35,6 @@ def do_setup(self) -> Any: torch_dtype=torch.float16, use_safetensors=True, variant="fp16", - device_map="auto", ) # Load base model @@ -51,10 +50,9 @@ def do_setup(self) -> Any: **load_options, ) - def do_preprocessing(self, input_data: dict[str, Any]) -> dict[str, Any]: - return input_data + self.is_setup = True - def do_run_model(self, input: dict[str, Any]) -> bytes: + def inference(self, input: dict[str, Any]) -> bytes: negative_prompt = input.get("negative_prompt", "disfigured, ugly, deformed") prompt = input["prompt"] n_steps = input.get("n_steps", 24) @@ -81,6 +79,3 @@ def do_run_model(self, input: dict[str, Any]) -> bytes: image_bytes = byte_stream.getvalue() return image_bytes - - def do_postprocessing(self, input: Any, output: Any) -> Any: - return output