Skip to content

Commit

Permalink
stable diffusion fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
stelios-ritual committed Oct 7, 2024
1 parent 20ad930 commit b415b1d
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,6 @@ remote_sync

# Virtual envs
**/env/**

# Arweave keyfile
keyfile-*.json
2 changes: 2 additions & 0 deletions projects/prompt-to-nft/stablediffusion/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
accelerate==1.0.0
diffusers~=0.19
huggingface-hub==0.23.0
numpy==1.26.4
quart==0.19.4
safetensors~=0.3
torch==2.2.1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def __init__(

self.is_setup = False

def do_setup(self) -> Any:
def setup(self) -> Any:
ignore = [
"*.bin",
"*.onnx_data",
Expand All @@ -35,7 +35,6 @@ def do_setup(self) -> Any:
torch_dtype=torch.float16,
use_safetensors=True,
variant="fp16",
device_map="auto",
)

# Load base model
Expand All @@ -51,10 +50,9 @@ def do_setup(self) -> Any:
**load_options,
)

def do_preprocessing(self, input_data: dict[str, Any]) -> dict[str, Any]:
return input_data
self.is_setup = True

def do_run_model(self, input: dict[str, Any]) -> bytes:
def inference(self, input: dict[str, Any]) -> bytes:
negative_prompt = input.get("negative_prompt", "disfigured, ugly, deformed")
prompt = input["prompt"]
n_steps = input.get("n_steps", 24)
Expand All @@ -81,6 +79,3 @@ def do_run_model(self, input: dict[str, Any]) -> bytes:
image_bytes = byte_stream.getvalue()

return image_bytes

def do_postprocessing(self, input: Any, output: Any) -> Any:
return output

0 comments on commit b415b1d

Please sign in to comment.