diff --git a/projects/prompt-to-nft/container/src/app.py b/projects/prompt-to-nft/container/src/app.py
index 34c2b28..4832599 100644
--- a/projects/prompt-to-nft/container/src/app.py
+++ b/projects/prompt-to-nft/container/src/app.py
@@ -5,7 +5,6 @@
 import aiohttp
 from eth_abi import decode, encode  # type: ignore
-from infernet_ml.utils.service_models import InfernetInput, JobLocation
 from quart import Quart, request
 from ritual_arweave.file_manager import FileManager
 
 
@@ -34,6 +33,7 @@ def ensure_env_vars() -> None:
 def create_app() -> Quart:
     app = Quart(__name__)
     ensure_env_vars()
+    temp_file = "image.png"
 
     @app.route("/")
     def index() -> str:
@@ -44,28 +44,30 @@ def index() -> str:
 
     @app.route("/service_output", methods=["POST"])
     async def inference() -> dict[str, Any]:
-        req_data = await request.get_json()
         """
-        InfernetInput has the format:
+        Input data has the format:
         source: (0 on-chain, 1 off-chain)
+        destination: (0 on-chain, 1 off-chain)
         data: dict[str, Any]
         """
-        infernet_input: InfernetInput = InfernetInput(**req_data)
-        temp_file = "image.png"
-
-        match infernet_input:
-            case InfernetInput(source=JobLocation.OFFCHAIN):
-                prompt: str = cast(dict[str, str], infernet_input.data)["prompt"]
-            case InfernetInput(source=JobLocation.ONCHAIN):
-                # On-chain requests are sent as a generalized hex-string which we will
-                # decode to the appropriate format.
-                (prompt, mintTo) = decode(
-                    ["string", "address"], bytes.fromhex(cast(str, infernet_input.data))
-                )
-                log.info("mintTo: %s", mintTo)
-                log.info("prompt: %s", prompt)
-            case _:
-                raise ValueError("Invalid source")
+        req_data: dict[str, Any] = await request.get_json()
+        onchain_source = True if req_data.get("source") == 0 else False
+        onchain_destination = True if req_data.get("destination") == 0 else False
+        data = req_data.get("data")
+
+        if onchain_source:
+            """
+            For on-chain requests, the prompt is sent as a generalized hex-string
+            which we will decode to the appropriate format.
+            """
+            (prompt, mintTo) = decode(
+                ["string", "address"], bytes.fromhex(cast(str, data))
+            )
+            log.info("mintTo: %s", mintTo)
+            log.info("prompt: %s", prompt)
+        else:
+            """For off-chain requests, the prompt is sent as is."""
+            prompt = cast(dict[str, str], data).get("prompt")
 
         # run the inference and download the image to a temp file
         await run_inference(prompt, temp_file)
@@ -74,19 +76,9 @@ async def inference() -> dict[str, Any]:
             Path(temp_file), {"Content-Type": "image/png"}
         )
 
-        match infernet_input:
-            case InfernetInput(destination=JobLocation.OFFCHAIN):
-                """
-                In case of an off-chain request, the result is returned as is.
-                """
-                return {
-                    "prompt": prompt,
-                    "hash": tx.id,
-                    "image_url": f"https://arweave.net/{tx.id}",
-                }
-            case InfernetInput(destination=JobLocation.ONCHAIN):
-                """
-                In case of an on-chain request, the result is returned in the format:
+        if onchain_destination:
+            """
+            For on-chain requests, the result is returned in the format:
            {
                "raw_input": str,
                "processed_input": str,
@@ -94,18 +86,25 @@
                "processed_output": str,
                "proof": str,
            }
-                refer to: https://docs.ritual.net/infernet/node/advanced/containers for
-                more info.
-                """
-                return {
-                    "raw_input": infernet_input.data,
-                    "processed_input": "",
-                    "raw_output": encode(["string"], [tx.id]).hex(),
-                    "processed_output": "",
-                    "proof": "",
-                }
-            case _:
-                raise ValueError("Invalid destination")
+            refer to: https://docs.ritual.net/infernet/node/advanced/containers for more
+            info.
+            """
+            return {
+                "raw_input": data,
+                "processed_input": "",
+                "raw_output": encode(["string"], [tx.id]).hex(),
+                "processed_output": "",
+                "proof": "",
+            }
+        else:
+            """
+            For off-chain request, the result is returned as is.
+            """
+            return {
+                "prompt": prompt,
+                "hash": tx.id,
+                "image_url": f"https://arweave.net/{tx.id}",
+            }
 
     return app
 
diff --git a/projects/prompt-to-nft/container/src/requirements.txt b/projects/prompt-to-nft/container/src/requirements.txt
index ba89fe5..b339323 100644
--- a/projects/prompt-to-nft/container/src/requirements.txt
+++ b/projects/prompt-to-nft/container/src/requirements.txt
@@ -1,4 +1,3 @@
 quart==0.19.4
-infernet-ml==1.0.0
-web3==6.15.0
 tqdm==4.66.4
+web3==6.15.0
diff --git a/projects/prompt-to-nft/stablediffusion/src/requirements.txt b/projects/prompt-to-nft/stablediffusion/src/requirements.txt
index 74e9663..51299fb 100644
--- a/projects/prompt-to-nft/stablediffusion/src/requirements.txt
+++ b/projects/prompt-to-nft/stablediffusion/src/requirements.txt
@@ -1,9 +1,6 @@
 diffusers~=0.19
-invisible_watermark~=0.1
-transformers==4.41.2
-accelerate~=0.21
+huggingface-hub==0.23.0
+quart==0.19.4
 safetensors~=0.3
-Quart==0.19.4
-jmespath==1.0.1
-huggingface-hub==0.20.3
-infernet-ml==1.0.0
+torch==2.2.1
+transformers==4.41.2
diff --git a/projects/prompt-to-nft/stablediffusion/src/stable_diffusion_workflow.py b/projects/prompt-to-nft/stablediffusion/src/stable_diffusion_workflow.py
index 29e8736..c8c87fd 100644
--- a/projects/prompt-to-nft/stablediffusion/src/stable_diffusion_workflow.py
+++ b/projects/prompt-to-nft/stablediffusion/src/stable_diffusion_workflow.py
@@ -4,18 +4,18 @@
 import torch
 from diffusers import DiffusionPipeline
 from huggingface_hub import snapshot_download
-from infernet_ml.workflows.inference.base_inference_workflow import (
-    BaseInferenceWorkflow,
-)
 
 
-class StableDiffusionWorkflow(BaseInferenceWorkflow):
+class StableDiffusionWorkflow:
     def __init__(
         self,
         *args: Any,
         **kwargs: Any,
    ):
-        super().__init__(*args, **kwargs)
+        self.args: list[Any] = list(args)
+        self.kwargs: dict[Any, Any] = kwargs
+
+        self.is_setup = False
 
     def do_setup(self) -> Any:
         ignore = [