diff --git a/06_gpu_and_ml/comfyui/comfyapp.py b/06_gpu_and_ml/comfyui/comfyapp.py index 41ddc98da..bc3e0bd9f 100644 --- a/06_gpu_and_ml/comfyui/comfyapp.py +++ b/06_gpu_and_ml/comfyui/comfyapp.py @@ -197,24 +197,52 @@ def launch_comfy_background(self): subprocess.run(cmd, shell=True, check=True) @modal.method() - def infer(self, workflow_path: str = "/root/workflow_api.json"): - # runs the comfy run --workflow command as a subprocess - cmd = f"comfy run --workflow {workflow_path} --wait --timeout 1200" - subprocess.run(cmd, shell=True, check=True) + def infer( + self, client_id: str, workflow_path: str = "/root/workflow_api.json" + ): + import json + import time + import urllib + + # reference locally running ComfyUI server + server_address = "127.0.0.1:8188" # completed workflows write output images to this directory output_dir = "/root/comfy/ComfyUI/output" - # looks up the name of the output image file based on the workflow - workflow = json.loads(Path(workflow_path).read_text()) - file_prefix = [ - node.get("inputs") - for node in workflow.values() - if node.get("class_type") == "SaveImage" - ][0]["filename_prefix"] + + # adapted from https://github.com/comfyanonymous/ComfyUI/blob/master/script_examples/websockets_api_example.py + def queue_prompt(prompt): + # queues a workflow run (the workflow JSON is passed in as the "prompt" variable) + p = {"prompt": prompt, "client_id": client_id} + data = json.dumps(p).encode("utf-8") + req = urllib.request.Request( + "http://{}/prompt".format(server_address), data=data + ) + return json.loads(urllib.request.urlopen(req).read()) + + def get_history(prompt_id): + # fetches the history of a queued prompt + with urllib.request.urlopen( + "http://{}/history/{}".format(server_address, prompt_id) + ) as response: + return json.loads(response.read()) + + # queue the workflow + with open(workflow_path, "r") as f: + workflow = json.load(f) + prompt_id = queue_prompt(workflow)["prompt_id"] + print(f"Running workflow with request id {client_id}...") + + # poll the history endpoint every second until the workflow is complete + while True: + history = get_history(prompt_id) + if history: + break + time.sleep(1) # returns the image as bytes for f in Path(output_dir).iterdir(): - if f.name.startswith(file_prefix): + if f.name.startswith(client_id): return f.read_bytes() @modal.web_endpoint(method="POST") @@ -237,7 +265,7 @@ def api(self, item: Dict): json.dump(workflow_data, Path(new_workflow_file).open("w")) # run inference on the currently running container - img_bytes = self.infer.local(new_workflow_file) + img_bytes = self.infer.local(client_id, new_workflow_file) return Response(img_bytes, media_type="image/jpeg")