Skip to content

Commit

Permalink
Merge pull request #262 from Ravleen-Solulab/tool_fix_stability
Browse files Browse the repository at this point in the history
Tool_fix/stability
  • Loading branch information
0xArdi authored Nov 6, 2024
2 parents ec4ffdc + c162664 commit 2c6a77f
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 16 deletions.
2 changes: 1 addition & 1 deletion packages/packages.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"custom/valory/prediction_request_embedding/0.1.0": "bafybeihtwykqnoxluqo2n4w2ccoh4xqoc6pifevol6obho3fneg7touzj4",
"custom/valory/resolve_market/0.1.0": "bafybeidog2vsqmezxe63jqjpf7p6qmqy3opq3rppvihqtehf6k44hzyo74",
"custom/valory/prediction_request/0.1.0": "bafybeigupgsneg4nsaljassdcq4mu53abrglmw42vfrss5kwxy7fybtisu",
"custom/valory/stability_ai_request/0.1.0": "bafybeiamqdkh3nqsul6ihgijvkxyyretpwzpssh6dps3cmovippaau7wmy",
"custom/valory/stability_ai_request/0.1.0": "bafybeifzlmtyvvo2x43sx6y2f53gqyakoqrxk5yclembw5xc3gihf2vrxm",
"custom/polywrap/prediction_with_research_report/0.1.0": "bafybeiebis63otzt7vy44zxk4uwfknrttfsibnas5x7sttwgh4lzuhrnna",
"custom/jhehemann/prediction_sum_url_content/0.1.0": "bafybeih6wp7icu5apa2uyuyisg65reh6ptl5umeji7qvgoluwplufkrypy",
"custom/psouranis/optimization_by_prompting/0.1.0": "bafybeigvweriadejipt7rhsekoksf6ff6tqwaovjywzmhnzh22khdtfbfa",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ license: Apache-2.0
aea_version: '>=1.0.0, <2.0.0'
fingerprint:
__init__.py: bafybeibbn67pnrrm4qm3n3kbelvbs3v7fjlrjniywmw2vbizarippidtvi
stabilityai_request.py: bafybeiccpaoydpdi4jvpcupooolzvgd4hefgmi2mzi6gd5nhz4liipzv6e
stabilityai_request.py: bafybeiaafasblkv2uuccrxexenprpwns7r5yuqai2g7i2a5dzjlmwyff3e
fingerprint_ignore_patterns: []
entry_point: stabilityai_request.py
callable: run
Expand Down
39 changes: 25 additions & 14 deletions packages/valory/customs/stability_ai_request/stabilityai_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,22 @@
PREFIX = "stabilityai-"
ENGINES = {
"picture": [
"stable-diffusion-v1-5",
"stable-diffusion-xl-beta-v2-2-2",
"stable-diffusion-512-v2-1",
"stable-diffusion-768-v2-1",
"stable-diffusion-xl-1024-v1-0",
"stable-diffusion-v1-6",
]
}
ENGINE_SIZE_CHART = {
"stable-diffusion-v1-5": {"height": 512, "width": 512},
"stable-diffusion-xl-beta-v2-2-2": {"height": 512, "width": 512},
"stable-diffusion-512-v2-1": {"height": 512, "width": 512},
"stable-diffusion-768-v2-1": {"height": 768, "width": 768},
"stable-diffusion-xl-1024-v1-0": [
{"height": 1024, "width": 1024},
{"height": 1152, "width": 896},
{"height": 896, "width": 1152},
{"height": 1216, "width": 832},
{"height": 1344, "width": 768},
{"height": 768, "width": 1344},
{"height": 1536, "width": 640},
{"height": 640, "width": 1536},
],
"stable-diffusion-v1-6": {"height": 512, "width": 512},
}

ALLOWED_TOOLS = [PREFIX + value for value in ENGINES["picture"]]
Expand Down Expand Up @@ -131,18 +136,24 @@ def run(**kwargs) -> Tuple[str, Optional[str], Optional[Dict[str, Any]], Any]:
api_key = kwargs["api_keys"]["stabilityai"]
tool = kwargs["tool"]
prompt = kwargs["prompt"]

if tool not in ALLOWED_TOOLS:
return f"Tool {tool} is not in the list of supported tools.", None, None, None

# Place content moderation request here if needed
engine = tool.replace(PREFIX, "")
cfg_scale = kwargs.get("cfg_scale", DEFAULT_STABILITYAI_SETTINGS["cfg_scale"])
weight = kwargs.get("weight", DEFAULT_STABILITYAI_SETTINGS["weight"])
clip_guidance_preset = kwargs.get(
"clip_guidance_preset", DEFAULT_STABILITYAI_SETTINGS["clip_guidance_preset"]
)
height = kwargs.get("height", ENGINE_SIZE_CHART[engine]["height"])
width = kwargs.get("width", ENGINE_SIZE_CHART[engine]["width"])
clip_guidance_preset = kwargs.get("clip_guidance_preset", DEFAULT_STABILITYAI_SETTINGS["clip_guidance_preset"])

# Handle different engine types
if engine == "stable-diffusion-xl-1024-v1-0":
height = kwargs.get("height", ENGINE_SIZE_CHART[engine][0]["height"]) # Access first size as default
width = kwargs.get("width", ENGINE_SIZE_CHART[engine][0]["width"])
else: # For stable-diffusion-v1-6
height = kwargs.get("height", ENGINE_SIZE_CHART[engine]["height"])
width = kwargs.get("width", ENGINE_SIZE_CHART[engine]["width"])

samples = kwargs.get("samples", DEFAULT_STABILITYAI_SETTINGS["samples"])
steps = kwargs.get("steps", DEFAULT_STABILITYAI_SETTINGS["steps"])

Expand Down Expand Up @@ -181,4 +192,4 @@ def run(**kwargs) -> Tuple[str, Optional[str], Optional[Dict[str, Any]], Any]:
None,
None,
None,
)
)

0 comments on commit 2c6a77f

Please sign in to comment.