Skip to content

Commit

Permalink
prompt to nft service
Browse files Browse the repository at this point in the history
  • Loading branch information
stelios-ritual committed Oct 4, 2024
1 parent 99f3d18 commit cc20038
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 58 deletions.
87 changes: 43 additions & 44 deletions projects/prompt-to-nft/container/src/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import aiohttp
from eth_abi import decode, encode # type: ignore
from infernet_ml.utils.service_models import InfernetInput, JobLocation
from quart import Quart, request
from ritual_arweave.file_manager import FileManager

Expand Down Expand Up @@ -34,6 +33,7 @@ def ensure_env_vars() -> None:
def create_app() -> Quart:
app = Quart(__name__)
ensure_env_vars()
temp_file = "image.png"

@app.route("/")
def index() -> str:
Expand All @@ -44,28 +44,30 @@ def index() -> str:

@app.route("/service_output", methods=["POST"])
async def inference() -> dict[str, Any]:
req_data = await request.get_json()
"""
InfernetInput has the format:
Input data has the format:
source: (0 on-chain, 1 off-chain)
destination: (0 on-chain, 1 off-chain)
data: dict[str, Any]
"""
infernet_input: InfernetInput = InfernetInput(**req_data)
temp_file = "image.png"

match infernet_input:
case InfernetInput(source=JobLocation.OFFCHAIN):
prompt: str = cast(dict[str, str], infernet_input.data)["prompt"]
case InfernetInput(source=JobLocation.ONCHAIN):
# On-chain requests are sent as a generalized hex-string which we will
# decode to the appropriate format.
(prompt, mintTo) = decode(
["string", "address"], bytes.fromhex(cast(str, infernet_input.data))
)
log.info("mintTo: %s", mintTo)
log.info("prompt: %s", prompt)
case _:
raise ValueError("Invalid source")
req_data: dict[str, Any] = await request.get_json()
onchain_source = True if req_data.get("source") == 0 else False
onchain_destination = True if req_data.get("destination") == 0 else False
data = req_data.get("data")

if onchain_source:
"""
For on-chain requests, the prompt is sent as a generalized hex-string
which we will decode to the appropriate format.
"""
(prompt, mintTo) = decode(
["string", "address"], bytes.fromhex(cast(str, data))
)
log.info("mintTo: %s", mintTo)
log.info("prompt: %s", prompt)
else:
"""For off-chain requests, the prompt is sent as is."""
prompt = cast(dict[str, str], data).get("prompt")

# run the inference and download the image to a temp file
await run_inference(prompt, temp_file)
Expand All @@ -74,38 +76,35 @@ async def inference() -> dict[str, Any]:
Path(temp_file), {"Content-Type": "image/png"}
)

match infernet_input:
case InfernetInput(destination=JobLocation.OFFCHAIN):
"""
In case of an off-chain request, the result is returned as is.
"""
return {
"prompt": prompt,
"hash": tx.id,
"image_url": f"https://arweave.net/{tx.id}",
}
case InfernetInput(destination=JobLocation.ONCHAIN):
"""
In case of an on-chain request, the result is returned in the format:
if onchain_destination:
"""
For on-chain requests, the result is returned in the format:
{
"raw_input": str,
"processed_input": str,
"raw_output": str,
"processed_output": str,
"proof": str,
}
refer to: https://docs.ritual.net/infernet/node/advanced/containers for
more info.
"""
return {
"raw_input": infernet_input.data,
"processed_input": "",
"raw_output": encode(["string"], [tx.id]).hex(),
"processed_output": "",
"proof": "",
}
case _:
raise ValueError("Invalid destination")
refer to: https://docs.ritual.net/infernet/node/advanced/containers for more
info.
"""
return {
"raw_input": data,
"processed_input": "",
"raw_output": encode(["string"], [tx.id]).hex(),
"processed_output": "",
"proof": "",
}
else:
"""
For off-chain requests, the result is returned as is.
"""
return {
"prompt": prompt,
"hash": tx.id,
"image_url": f"https://arweave.net/{tx.id}",
}

return app

Expand Down
3 changes: 1 addition & 2 deletions projects/prompt-to-nft/container/src/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
quart==0.19.4
infernet-ml==1.0.0
web3==6.15.0
tqdm==4.66.4
web3==6.15.0
11 changes: 4 additions & 7 deletions projects/prompt-to-nft/stablediffusion/src/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
diffusers~=0.19
invisible_watermark~=0.1
transformers==4.41.2
accelerate~=0.21
huggingface-hub==0.23.0
quart==0.19.4
safetensors~=0.3
Quart==0.19.4
jmespath==1.0.1
huggingface-hub==0.20.3
infernet-ml==1.0.0
torch==2.2.1
transformers==4.41.2
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,18 @@
import torch
from diffusers import DiffusionPipeline
from huggingface_hub import snapshot_download
from infernet_ml.workflows.inference.base_inference_workflow import (
BaseInferenceWorkflow,
)


class StableDiffusionWorkflow(BaseInferenceWorkflow):
class StableDiffusionWorkflow:
def __init__(
self,
*args: Any,
**kwargs: Any,
):
super().__init__(*args, **kwargs)
self.args: list[Any] = list(args)
self.kwargs: dict[Any, Any] = kwargs

self.is_setup = False

def do_setup(self) -> Any:
ignore = [
Expand Down

0 comments on commit cc20038

Please sign in to comment.