Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use basic http polling for api #997

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 41 additions & 13 deletions 06_gpu_and_ml/comfyui/comfyapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,24 +197,52 @@ def launch_comfy_background(self):
subprocess.run(cmd, shell=True, check=True)

@modal.method()
def infer(self, workflow_path: str = "/root/workflow_api.json"):
# runs the comfy run --workflow command as a subprocess
cmd = f"comfy run --workflow {workflow_path} --wait --timeout 1200"
subprocess.run(cmd, shell=True, check=True)
def infer(
self, client_id: str, workflow_path: str = "/root/workflow_api.json"
):
import json
import time
import urllib

# reference locally running ComfyUI server
server_address = "127.0.0.1:8188"

# completed workflows write output images to this directory
output_dir = "/root/comfy/ComfyUI/output"
# looks up the name of the output image file based on the workflow
workflow = json.loads(Path(workflow_path).read_text())
file_prefix = [
node.get("inputs")
for node in workflow.values()
if node.get("class_type") == "SaveImage"
][0]["filename_prefix"]

# adapted from https://github.com/comfyanonymous/ComfyUI/blob/master/script_examples/websockets_api_example.py
def queue_prompt(prompt):
# queues a workflow run (the workflow JSON is passed in as the "prompt" variable)
p = {"prompt": prompt, "client_id": client_id}
data = json.dumps(p).encode("utf-8")
req = urllib.request.Request(
"http://{}/prompt".format(server_address), data=data
)
return json.loads(urllib.request.urlopen(req).read())

def get_history(prompt_id):
# fetches the history of a queued prompt
with urllib.request.urlopen(
"http://{}/history/{}".format(server_address, prompt_id)
) as response:
return json.loads(response.read())

# queue the workflow
with open(workflow_path, "r") as f:
workflow = json.load(f)
prompt_id = queue_prompt(workflow)["prompt_id"]
print(f"Running workflow with request id {client_id}...")

# poll the history endpoint every second until the workflow is complete
while True:
history = get_history(prompt_id)
if history:
break
time.sleep(1)

# returns the image as bytes
for f in Path(output_dir).iterdir():
if f.name.startswith(file_prefix):
if f.name.startswith(client_id):
return f.read_bytes()

@modal.web_endpoint(method="POST")
Expand All @@ -237,7 +265,7 @@ def api(self, item: Dict):
json.dump(workflow_data, Path(new_workflow_file).open("w"))

# run inference on the currently running container
img_bytes = self.infer.local(new_workflow_file)
img_bytes = self.infer.local(client_id, new_workflow_file)

return Response(img_bytes, media_type="image/jpeg")

Expand Down