From df7ef211dd78d871e6c2bbb80bf54c8b59d3bb80 Mon Sep 17 00:00:00 2001 From: FengWen <109639975+ccssu@users.noreply.github.com> Date: Tue, 17 Dec 2024 11:12:55 +0800 Subject: [PATCH] Add flux upscale (#253) * tmp save * tmp save * fix BizyAirTask in client.py * fix payload error * tmp save * refine * add task_manager * tmp save * refine * refine * refine * refine * add examples/bizyair-flux1-upscale.json * refine --- README.md | 2 + bizyair_example_menu.json | 6 +- bizyair_extras/nodes_upscale_model.py | 105 +-- examples/bizyair-flux1-upscale.json | 700 ++++++++++++++++++ .../commands/processors/prompt_processor.py | 4 +- src/bizyair/commands/servers/prompt_server.py | 194 ++++- src/bizyair/common/caching.py | 198 +++++ src/bizyair/common/client.py | 19 +- src/bizyair/common/utils.py | 4 + src/bizyair/configs/conf.py | 13 +- src/bizyair/configs/models.yaml | 29 + src/bizyair/nodes_base.py | 1 + src/bizyair/path_utils/path_manager.py | 3 +- 13 files changed, 1165 insertions(+), 113 deletions(-) create mode 100644 examples/bizyair-flux1-upscale.json create mode 100644 src/bizyair/common/caching.py diff --git a/README.md b/README.md index a24c1aa7..6a5c4d07 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,6 @@ # BizyAir + +- [2024/12/17] 🌩️ BizyAir supports Flux Upscale Model. [FLUX Upscale](./examples/bizyair-flux1-upscale.json) - [2024/11/27] 🌩️ BizyAir supports Stable Diffusion 3.5 Large ControlNet Canny, Depth, and Blur. [ControlNet Canny](./examples/bizyair_sd3_5_canny.json) [ControlNet Depth](./examples/bizyair_sd3_5_depth.json) [ControlNet Blur](./examples/bizyair_sd3_5_blur.json) - [2024/11/22] 🌩️ BizyAir supports FLUX Fill, ControlNet and Redux modes. [canny](./examples/bizyair-flux1-tools-canny.json) [depth](./examples/bizyair-flux1-tools-depth.json) [fill](./examples/bizyair-flux-fill1-inpaint.json) [redux](./examples/bizyair-flux1-tools-redux.json) - [2024/11/06] 🌩️ BizyAir PixelWave Flux.1-dev Text to Image node is released. [PixelWave Flux.1-dev Text to Image](./examples/bizyair_flux_pixelwave_txt2img.json) diff --git a/bizyair_example_menu.json b/bizyair_example_menu.json index 65b934d4..23e197cf 100644 --- a/bizyair_example_menu.json +++ b/bizyair_example_menu.json @@ -10,16 +10,14 @@ "FLUX ControlNet Depth": "bizyair-flux1-tools-depth.json", "FLUX Redux": "bizyair-flux1-tools-redux.json", "FLUX Fill": "bizyair-flux-fill1-inpaint.json", - "FLUX Detail Daemon Sampler": "bizyair_flux_detail_daemon_sampler.json" + "FLUX Detail Daemon Sampler": "bizyair_flux_detail_daemon_sampler.json", + "FLUX Upscale": "bizyair-flux1-upscale.json" }, "ControlNet Union": { "Generate an image from a line drawing": "bizyair_showcase_interior_design.json", "Design a submarine like a great white shark": "bizyair_showcase_shark_submarine.json", "All types of ControlNet preprocessors": "bizyair_controlnet_preprocessor_workflow.json" }, - "SD15": { - "UltimateSDUpscale": "bizyair_ultimate_sd_upscale.json" - }, "SDXL": { "Text to Image by BizyAir KSampler": "bizyair_showcase_ksampler_txt2img.json", "Image to Image by BizyAir KSampler": "bizyair_showcase_ksampler_img2img.json", diff --git a/bizyair_extras/nodes_upscale_model.py b/bizyair_extras/nodes_upscale_model.py index 469da05c..bd115d41 100644 --- a/bizyair_extras/nodes_upscale_model.py +++ b/bizyair_extras/nodes_upscale_model.py @@ -1,88 +1,3 @@ -# import os -# import logging -# from spandrel import ModelLoader, ImageModelDescriptor -# from comfy import model_management -# import torch -# import comfy.utils -# import folder_paths - -# try: -# from spandrel_extra_arches import EXTRA_REGISTRY -# from spandrel import MAIN_REGISTRY -# MAIN_REGISTRY.add(*EXTRA_REGISTRY) -# logging.info("Successfully imported spandrel_extra_arches: support for non commercial upscale models.") -# except: -# pass - -# class UpscaleModelLoader: -# @classmethod -# def INPUT_TYPES(s): -# return {"required": { "model_name": (folder_paths.get_filename_list("upscale_models"), ), -# }} -# RETURN_TYPES = ("UPSCALE_MODEL",) -# FUNCTION = "load_model" - -# CATEGORY = "loaders" - -# def load_model(self, model_name): -# model_path = folder_paths.get_full_path("upscale_models", model_name) -# sd = comfy.utils.load_torch_file(model_path, safe_load=True) -# if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd: -# sd = comfy.utils.state_dict_prefix_replace(sd, {"module.":""}) -# out = ModelLoader().load_from_state_dict(sd).eval() - -# if not isinstance(out, ImageModelDescriptor): -# raise Exception("Upscale model must be a single-image model.") - -# return (out, ) - - -# class ImageUpscaleWithModel: -# @classmethod -# def INPUT_TYPES(s): -# return {"required": { "upscale_model": ("UPSCALE_MODEL",), -# "image": ("IMAGE",), -# }} -# RETURN_TYPES = ("IMAGE",) -# FUNCTION = "upscale" - -# CATEGORY = "image/upscaling" - -# def upscale(self, upscale_model, image): -# device = model_management.get_torch_device() - -# memory_required = model_management.module_size(upscale_model.model) -# memory_required += (512 * 512 * 3) * image.element_size() * max(upscale_model.scale, 1.0) * 384.0 #The 384.0 is an estimate of how much some of these models take, TODO: make it more accurate -# memory_required += image.nelement() * image.element_size() -# model_management.free_memory(memory_required, device) - -# upscale_model.to(device) -# in_img = image.movedim(-1,-3).to(device) - -# tile = 512 -# overlap = 32 - -# oom = True -# while oom: -# try: -# steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap) -# pbar = comfy.utils.ProgressBar(steps) -# s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar) -# oom = False -# except model_management.OOM_EXCEPTION as e: -# tile //= 2 -# if tile < 128: -# raise e - -# upscale_model.to("cpu") -# s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0) -# return (s,) - -# NODE_CLASS_MAPPINGS = { -# "UpscaleModelLoader": UpscaleModelLoader, -# "ImageUpscaleWithModel": ImageUpscaleWithModel -# } - import bizyair.path_utils as folder_paths from bizyair import BizyAirBaseNode, BizyAirNodeIO from bizyair.data_types import UPSCALE_MODEL @@ -98,10 +13,20 @@ def INPUT_TYPES(s): } RETURN_TYPES = (UPSCALE_MODEL,) - FUNCTION = "load_model" + # FUNCTION = "load_model" CATEGORY = "loaders" - def load_model(self, **kwargs): - model = BizyAirNodeIO(self.assigned_id) - model.add_node_data(class_type="UpscaleModelLoader", inputs=kwargs) - return (model,) + +class ImageUpscaleWithModel(BizyAirBaseNode): + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "upscale_model": (UPSCALE_MODEL,), + "image": ("IMAGE",), + } + } + + RETURN_TYPES = ("IMAGE",) + # FUNCTION = "upscale" + CATEGORY = "image/upscaling" diff --git a/examples/bizyair-flux1-upscale.json b/examples/bizyair-flux1-upscale.json new file mode 100644 index 00000000..cd1f9fdd --- /dev/null +++ b/examples/bizyair-flux1-upscale.json @@ -0,0 +1,700 @@ +{ + "last_node_id": 98, + "last_link_id": 164, + "nodes": [ + { + "id": 84, + "type": "BizyAir_FluxGuidance", + "pos": [ + 610.0081176757812, + -516.2211303710938 + ], + "size": [ + 418.1999816894531, + 58 + ], + "flags": {}, + "order": 8, + "mode": 0, + "inputs": [ + { + "name": "conditioning", + "type": "BIZYAIR_CONDITIONING", + "link": 134 + } + ], + "outputs": [ + { + "name": "BIZYAIR_CONDITIONING", + "type": "BIZYAIR_CONDITIONING", + "links": [ + 136 + ], + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "BizyAir_FluxGuidance" + }, + "widgets_values": [ + 3.5 + ] + }, + { + "id": 85, + "type": "BizyAir_ConditioningZeroOut", + "pos": [ + 615.65771484375, + -407.4277038574219 + ], + "size": [ + 413.81939697265625, + 26 + ], + "flags": {}, + "order": 9, + "mode": 0, + "inputs": [ + { + "name": "conditioning", + "type": "BIZYAIR_CONDITIONING", + "link": 135 + } + ], + "outputs": [ + { + "name": "BIZYAIR_CONDITIONING", + "type": "BIZYAIR_CONDITIONING", + "links": [ + 137 + ], + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "BizyAir_ConditioningZeroOut" + }, + "widgets_values": [] + }, + { + "id": 83, + "type": "BizyAir_CLIPTextEncode", + "pos": [ + 608.86376953125, + -651.6828002929688 + ], + "size": [ + 412.748291015625, + 81 + ], + "flags": {}, + "order": 5, + "mode": 0, + "inputs": [ + { + "name": "clip", + "type": "BIZYAIR_CLIP", + "link": 133 + } + ], + "outputs": [ + { + "name": "CONDITIONING", + "type": "BIZYAIR_CONDITIONING", + "links": [ + 134, + 135 + ], + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "BizyAir_CLIPTextEncode" + }, + "widgets_values": [ + "a girl", + [ + false, + true + ] + ] + }, + { + "id": 82, + "type": "BizyAir_DualCLIPLoader", + "pos": [ + 38.52742004394531, + -887.0072021484375 + ], + "size": [ + 354.45062255859375, + 106 + ], + "flags": {}, + "order": 0, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "BIZYAIR_CLIP", + "type": "BIZYAIR_CLIP", + "links": [ + 133 + ], + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "BizyAir_DualCLIPLoader" + }, + "widgets_values": [ + "clip_l.safetensors", + "t5xxl_fp16.safetensors", + "flux" + ] + }, + { + "id": 86, + "type": "BizyAir_UNETLoader", + "pos": [ + 42.18415451049805, + -672.583251953125 + ], + "size": [ + 365.7839050292969, + 82.68058013916016 + ], + "flags": {}, + "order": 1, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "BIZYAIR_MODEL", + "type": "BIZYAIR_MODEL", + "links": [ + 138 + ], + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "BizyAir_UNETLoader" + }, + "widgets_values": [ + "flux/flux1-dev.sft", + "fp8_e4m3fn" + ] + }, + { + "id": 81, + "type": "BizyAir_VAELoader", + "pos": [ + 44.5821533203125, + -492.4258728027344 + ], + "size": [ + 365.79168701171875, + 58 + ], + "flags": {}, + "order": 2, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "vae", + "type": "BIZYAIR_VAE", + "links": [ + 131 + ], + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "BizyAir_VAELoader" + }, + "widgets_values": [ + "flux/ae.sft" + ] + }, + { + "id": 37, + "type": "Reroute", + "pos": [ + 452.06280517578125, + -635.9285888671875 + ], + "size": [ + 75, + 26 + ], + "flags": {}, + "order": 6, + "mode": 0, + "inputs": [ + { + "name": "", + "type": "*", + "link": 138 + } + ], + "outputs": [ + { + "name": "", + "type": "BIZYAIR_MODEL", + "links": [ + 60 + ], + "slot_index": 0 + } + ], + "properties": { + "showOutputText": false, + "horizontal": false + } + }, + { + "id": 39, + "type": "Reroute", + "pos": [ + 458.05914306640625, + -500.0390625 + ], + "size": [ + 75, + 26 + ], + "flags": {}, + "order": 7, + "mode": 0, + "inputs": [ + { + "name": "", + "type": "*", + "link": 131 + } + ], + "outputs": [ + { + "name": "", + "type": "BIZYAIR_VAE", + "links": [ + 64 + ], + "slot_index": 0 + } + ], + "properties": { + "showOutputText": false, + "horizontal": false + } + }, + { + "id": 38, + "type": "Reroute", + "pos": [ + 629.8372192382812, + -949.8573608398438 + ], + "size": [ + 75, + 26 + ], + "flags": {}, + "order": 10, + "mode": 0, + "inputs": [ + { + "name": "", + "type": "*", + "link": 60 + } + ], + "outputs": [ + { + "name": "", + "type": "BIZYAIR_MODEL", + "links": [ + 139 + ], + "slot_index": 0 + } + ], + "properties": { + "showOutputText": false, + "horizontal": false + } + }, + { + "id": 40, + "type": "Reroute", + "pos": [ + 647.3805541992188, + -883.508544921875 + ], + "size": [ + 75, + 26 + ], + "flags": {}, + "order": 11, + "mode": 0, + "inputs": [ + { + "name": "", + "type": "*", + "link": 64 + } + ], + "outputs": [ + { + "name": "", + "type": "BIZYAIR_VAE", + "links": [ + 132 + ], + "slot_index": 0 + } + ], + "properties": { + "showOutputText": false, + "horizontal": false + } + }, + { + "id": 79, + "type": "BizyAir_UpscaleModelLoader", + "pos": [ + 618.9271850585938, + -773.9298706054688 + ], + "size": [ + 382.7853088378906, + 58.608306884765625 + ], + "flags": {}, + "order": 3, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "BIZYAIR_UPSCALE_MODEL", + "type": "BIZYAIR_UPSCALE_MODEL", + "links": [ + 130 + ], + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "BizyAir_UpscaleModelLoader" + }, + "widgets_values": [ + "4x_NMKD-Siax_200k.pth" + ] + }, + { + "id": 80, + "type": "BizyAir_UltimateSDUpscale", + "pos": [ + 1084.4683837890625, + -871.3263549804688 + ], + "size": [ + 365.4000244140625, + 614 + ], + "flags": {}, + "order": 12, + "mode": 0, + "inputs": [ + { + "name": "image", + "type": "IMAGE", + "link": 163 + }, + { + "name": "model", + "type": "BIZYAIR_MODEL", + "link": 139 + }, + { + "name": "positive", + "type": "BIZYAIR_CONDITIONING", + "link": 136 + }, + { + "name": "negative", + "type": "BIZYAIR_CONDITIONING", + "link": 137 + }, + { + "name": "vae", + "type": "BIZYAIR_VAE", + "link": 132 + }, + { + "name": "upscale_model", + "type": "BIZYAIR_UPSCALE_MODEL", + "link": 130 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 148 + ], + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "BizyAir_UltimateSDUpscale" + }, + "widgets_values": [ + 2, + 927540179114028, + "fixed", + 20, + 8, + "dpmpp_2m", + "karras", + 0.2, + "Linear", + 768, + 768, + 8, + 32, + "None", + 1, + 64, + 8, + 16, + true, + false + ] + }, + { + "id": 98, + "type": "LoadImage", + "pos": [ + 1472.5528564453125, + -918.76171875 + ], + "size": [ + 210, + 307.54327392578125 + ], + "flags": {}, + "order": 4, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 163 + ], + "slot_index": 0 + }, + { + "name": "MASK", + "type": "MASK", + "links": null + } + ], + "properties": { + "Node name for S&R": "LoadImage" + }, + "widgets_values": [ + "example.png", + "image" + ] + }, + { + "id": 88, + "type": "SaveImage", + "pos": [ + 1701.8026123046875, + -928.7681884765625 + ], + "size": [ + 423.9115905761719, + 466.6839599609375 + ], + "flags": {}, + "order": 13, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 148 + } + ], + "outputs": [], + "properties": { + "Node name for S&R": "SaveImage" + }, + "widgets_values": [ + "ComfyUI" + ] + } + ], + "links": [ + [ + 60, + 37, + 0, + 38, + 0, + "*" + ], + [ + 64, + 39, + 0, + 40, + 0, + "*" + ], + [ + 130, + 79, + 0, + 80, + 5, + "BIZYAIR_UPSCALE_MODEL" + ], + [ + 131, + 81, + 0, + 39, + 0, + "*" + ], + [ + 132, + 40, + 0, + 80, + 4, + "BIZYAIR_VAE" + ], + [ + 133, + 82, + 0, + 83, + 0, + "BIZYAIR_CLIP" + ], + [ + 134, + 83, + 0, + 84, + 0, + "BIZYAIR_CONDITIONING" + ], + [ + 135, + 83, + 0, + 85, + 0, + "BIZYAIR_CONDITIONING" + ], + [ + 136, + 84, + 0, + 80, + 2, + "BIZYAIR_CONDITIONING" + ], + [ + 137, + 85, + 0, + 80, + 3, + "BIZYAIR_CONDITIONING" + ], + [ + 138, + 86, + 0, + 37, + 0, + "*" + ], + [ + 139, + 38, + 0, + 80, + 1, + "BIZYAIR_MODEL" + ], + [ + 148, + 80, + 0, + 88, + 0, + "IMAGE" + ], + [ + 163, + 98, + 0, + 80, + 0, + "IMAGE" + ] + ], + "groups": [ + { + "id": 1, + "title": "LoadModel", + "bounding": [ + 9.072793006896973, + -980.7294311523438, + 574.950927734375, + 596.8793334960938 + ], + "color": "#3f789e", + "font_size": 24, + "flags": {} + }, + { + "id": 2, + "title": "ApplyUpscale", + "bounding": [ + 594.8491821289062, + -981.4524536132812, + 868.4437255859375, + 727.177734375 + ], + "color": "#3f789e", + "font_size": 24, + "flags": {} + } + ], + "config": {}, + "extra": { + "ds": { + "scale": 0.5644739300537773, + "offset": [ + 320.04684143941637, + 1376.536543579042 + ] + } + }, + "version": 0.4 +} diff --git a/src/bizyair/commands/processors/prompt_processor.py b/src/bizyair/commands/processors/prompt_processor.py index 4ddfc496..fc403b9f 100644 --- a/src/bizyair/commands/processors/prompt_processor.py +++ b/src/bizyair/commands/processors/prompt_processor.py @@ -3,6 +3,7 @@ from typing import Any, Dict, List from bizyair.common import client +from bizyair.common.caching import BizyAirTaskCache, CacheConfig from bizyair.common.env_var import ( BIZYAIR_DEBUG, BIZYAIR_DEV_REQUEST_URL, @@ -62,7 +63,8 @@ def process(self, prompt: Dict[str, Dict[str, Any]], last_node_ids: List[str]): base_model, out_route, out_score = None, None, None for rule in results[::-1]: - if rule.mode_type in {"unet", "vae", "checkpoint"}: + # TODO add to config models.yaml + if rule.mode_type in {"unet", "vae", "checkpoint", "upscale_models"}: base_model = rule.base_model out_route = rule.route out_score = rule.score diff --git a/src/bizyair/commands/servers/prompt_server.py b/src/bizyair/commands/servers/prompt_server.py index a8968021..7bdb80bb 100644 --- a/src/bizyair/commands/servers/prompt_server.py +++ b/src/bizyair/commands/servers/prompt_server.py @@ -1,19 +1,177 @@ +import hashlib +import json import pprint +import time import traceback +from dataclasses import dataclass, field from typing import Any, Dict, List -from bizyair.common.env_var import BIZYAIR_DEBUG +import comfy + +from bizyair.common.caching import BizyAirTaskCache, CacheConfig +from bizyair.common.client import send_request +from bizyair.common.env_var import BIZYAIR_DEBUG, BIZYAIR_SERVER_ADDRESS from bizyair.common.utils import truncate_long_strings +from bizyair.configs.conf import config_manager from bizyair.image_utils import decode_data, encode_data from ..base import Command, Processor # type: ignore +def get_task_result(task_id: str, offset: int = 0) -> dict: + """ + Get the result of a task. + """ + import requests + + task_api = config_manager.get_task_api() + url = f"{BIZYAIR_SERVER_ADDRESS}/{task_api.task_result_endpoint}/{task_id}" + response_json = send_request( + method="GET", url=url, data=json.dumps({"offset": offset}).encode("utf-8") + ) + out = response_json + events = out.get("data", {}).get("events", []) + new_events = [] + for event in events: + if ( + "data" in event + and isinstance(event["data"], str) + and event["data"].startswith("https://") + ): + # event["data"] = requests.get(event["data"]).json() + event["data"] = send_request(method="GET", url=event["data"]) + new_events.append(event) + out["data"]["events"] = new_events + return out + + +@dataclass +class BizyAirTask: + TASK_DATA_STATUS = ["PENDING", "PROCESSING", "COMPLETED"] + task_id: str + data_pool: list[dict] = field(default_factory=list) + data_status: str = None + + @staticmethod + def check_inputs(inputs: dict) -> bool: + return ( + inputs.get("code") == 20000 + and inputs.get("status", False) + and "task_id" in inputs.get("data", {}) + ) + + @classmethod + def from_data(cls, inputs: dict, check_inputs: bool = True) -> "BizyAirTask": + if check_inputs and not cls.check_inputs(inputs): + raise ValueError(f"Invalid inputs: {inputs}") + data = inputs.get("data", {}) + task_id = data.get("task_id", "") + return cls(task_id=task_id, data_pool=[], data_status="started") + + def is_finished(self) -> bool: + if not self.data_pool: + return False + if self.data_pool[-1].get("data_status") == self.TASK_DATA_STATUS[-1]: + return True + return False + + def send_request(self, offset: int = 0) -> dict: + if offset >= len(self.data_pool): + return get_task_result(self.task_id, offset) + else: + return self.data_pool[offset] + + def get_data(self, offset: int = 0) -> dict: + if offset >= len(self.data_pool): + return {} + return self.data_pool[offset] + + @staticmethod + def _fetch_remote_data(url: str) -> dict: + import requests + + return requests.get(url).json() + + def get_last_data(self) -> dict: + return self.get_data(len(self.data_pool) - 1) + + def do_task_until_completed( + self, *, timeout: int = 600, poll_interval: float = 1 + ) -> list[dict]: + offset = 0 + start_time = time.time() + pbar = None + while not self.is_finished(): + try: + data = self.send_request(offset) + data_lst = self._extract_data_list(data) + self.data_pool.extend(data_lst) + offset += len(data_lst) + for data in data_lst: + message = data.get("data", {}).get("message", {}) + if ( + isinstance(message, dict) + and message.get("event", None) == "progress" + ): + value = message["data"]["value"] + total = message["data"]["max"] + if pbar is None: + pbar = comfy.utils.ProgressBar(total) + pbar.update_absolute(value + 1, total, None) + except Exception as e: + print(f"Exception: {e}") + + if time.time() - start_time > timeout: + raise TimeoutError(f"Timeout waiting for task {self.task_id} to finish") + + time.sleep(poll_interval) + + return self.data_pool + + def _extract_data_list(self, data): + data_lst = data.get("data", {}).get("events", []) + if not data_lst: + raise ValueError(f"No data found in task {self.task_id}") + return data_lst + + class PromptServer(Command): + cache_manager: BizyAirTaskCache = BizyAirTaskCache( + config=CacheConfig.from_config(config_manager.get_cache_config()) + ) + def __init__(self, router: Processor, processor: Processor): self.router = router self.processor = processor + def get_task_id(self, result: Dict[str, Any]) -> str: + return result.get("data", {}).get("task_id", "") + + def is_async_task(self, result: Dict[str, Any]) -> str: + """Determine if the result indicates an asynchronous task.""" + return ( + result.get("code") == 20000 + and result.get("status", False) + and "task_id" in result.get("data", {}) + ) + + def _get_result(self, result: Dict[str, Any], *, cache_key: str = None): + try: + response_data = result["data"] + if BizyAirTask.check_inputs(result): + self.cache_manager.set(cache_key, result) + bz_task = BizyAirTask.from_data(result, check_inputs=False) + bz_task.do_task_until_completed(timeout=10 * 60) # 10 minutes + last_data = bz_task.get_last_data() + response_data = last_data.get("data") + out = response_data["payload"] + assert out is not None, "Output payload should not be None" + self.cache_manager.set(cache_key, out, overwrite=True) + return out + except Exception as e: + self.cache_manager.delete(cache_key) + raise RuntimeError(f"Exception: {e}, response_data: {response_data}") from e + def execute( self, prompt: Dict[str, Dict[str, Any]], @@ -23,34 +181,46 @@ def execute( ): prompt = encode_data(prompt) + if BIZYAIR_DEBUG: debug_info = { "prompt": truncate_long_strings(prompt, 50), "last_node_ids": last_node_ids, } pprint.pprint(debug_info, indent=4) + url = self.router(prompt=prompt, last_node_ids=last_node_ids) + if BIZYAIR_DEBUG: print(f"Generated URL: {url}") - result = self.processor(url, prompt=prompt, last_node_ids=last_node_ids) + start_time = time.time() + sh256 = hashlib.sha256( + json.dumps({"url": url, "prompt": prompt}).encode("utf-8") + ).hexdigest() + end_time = time.time() if BIZYAIR_DEBUG: - pprint.pprint({"result": truncate_long_strings(result, 50)}, indent=4) + print( + f"Time taken to generate sh256-{sh256}: {end_time - start_time} seconds" + ) - if result is None: - raise RuntimeError("result is None") + cached_output = self.cache_manager.get(sh256) + if cached_output: + if BIZYAIR_DEBUG: + print(f"Cache hit for sh256-{sh256}") + out = cached_output + else: + result = self.processor(url, prompt=prompt, last_node_ids=last_node_ids) + out = self._get_result(result, cache_key=sh256) + + if BIZYAIR_DEBUG: + pprint.pprint({"out": truncate_long_strings(out, 50)}, indent=4) - try: - out = result["data"]["payload"] - assert out is not None - except Exception as e: - raise RuntimeError( - f'Unexpected error accessing result["data"]["payload"]. Result: {result}' - ) from e try: real_out = decode_data(out) return real_out[0] except Exception as e: print("Exception occurred while decoding data") + self.cache_manager.delete(sh256) traceback.print_exc() raise RuntimeError(f"Exception: {e=}") from e diff --git a/src/bizyair/common/caching.py b/src/bizyair/common/caching.py new file mode 100644 index 00000000..96dc5b41 --- /dev/null +++ b/src/bizyair/common/caching.py @@ -0,0 +1,198 @@ +import glob +import json +import os +import time +from abc import ABC, abstractmethod +from collections import OrderedDict +from dataclasses import dataclass +from typing import Any, Dict + + +@dataclass +class CacheConfig: + max_size: int = 100 + expiration: int = 300 # 300 seconds + cache_dir: str = "./cache" + file_prefix: str = "bizyair_cache_" + file_suffix: str = ".json" + use_cache: bool = True + + @classmethod + def from_config(cls, config: Dict[str, Any]): + return cls( + max_size=config.get("max_size", 100), + expiration=config.get("expiration", 300), + cache_dir=config.get("cache_dir", "./cache"), + file_prefix=config.get("file_prefix", "bizyair_cache_"), + file_suffix=config.get("file_suffix", ".json"), + use_cache=config.get("use_cache", True), + ) + + +class CacheManager(ABC): + @abstractmethod + def get(self, key): + pass + + @abstractmethod + def set(self, key, value): + pass + + @abstractmethod + def clear(self): + pass + + @abstractmethod + def disable(self): + pass + + +class BizyAirTaskCache(CacheManager): + def __init__(self, config: CacheConfig): + self.config = config + self.cache = OrderedDict() + self.cache_dir = config.cache_dir + self.ensure_directory_exists() + self.cache = self.load_cache() if config.use_cache else self.cache + + def ensure_directory_exists(self): + if not os.path.exists(self.cache_dir): + os.makedirs(self.cache_dir) + + def load_cache(self): + cache_v_files = glob.glob( + os.path.join( + self.cache_dir, f"{self.config.file_prefix}*{self.config.file_suffix}" + ) + ) + output = OrderedDict() + cache_datas = [] + for cache_file in cache_v_files: + try: + file_name = os.path.basename(cache_file)[ + len(self.config.file_prefix) : -len(self.config.file_suffix) + ] + cache_key = file_name.split("-")[0] + cache_timestamp = file_name.split("-")[1] + if int(time.time()) - int(cache_timestamp) > self.config.expiration: + self.delete_file(cache_file) + continue + cache_datas.append( + { + "key": cache_key, + "timestamp": int(cache_timestamp), + "file_path": cache_file, + } + ) + except Exception as e: + print( + f"Warning: Error loading cache file {cache_file}: because {e}, will delete it" + ) + cache_datas = sorted(cache_datas, key=lambda x: x["timestamp"]) + for cache_data in cache_datas: + output[cache_data["key"]] = ( + cache_data["file_path"], + cache_data["timestamp"], + ) + return output + + def delete(self, key): + if key in self.cache: + self.delete_file(self.cache[key][0]) + del self.cache[key] + + def get(self, key): + if key not in self.cache: + return None + + file_path, timestamp = self.cache[key] + if time.time() - timestamp >= self.config.expiration: + self._remove_expired_entry(file_path, key) + return None + + cache_data = self._read_file(file_path) + if cache_data["cache_key"] == key: + return cache_data["result"] + else: + self._remove_expired_entry(file_path, key) + return None + + def _read_file(self, file_path): + try: + with open(file_path, "r") as f: + cache_data = json.load(f) + return cache_data + except Exception as e: + print(f"Error reading file {file_path}: {e}") + return None + + def _remove_expired_entry(self, file_path, key): + self.delete_file(file_path) + del self.cache[key] + + def set(self, key, value, *, overwrite=False): + if not overwrite and key in self.cache: + raise ValueError( + f"Key '{key}' already exists in cache. Use overwrite=True to replace it." + ) + assert isinstance(key, str), "Key must be a string" + + if len(self.cache) >= self.config.max_size: + self._evict_oldest() + + timestamp = int(time.time()) + file_path = os.path.join( + self.cache_dir, f"{self.config.file_prefix}{key}-{timestamp}.json" + ) + self.write_file(key, value, file_path, timestamp) + + def _evict_oldest(self): + oldest_key, (oldest_file_path, _) = self.cache.popitem(last=False) + self.delete_file(oldest_file_path) + + def write_file(self, key: str, value: Any, file_path: str, timestamp: int): + try: + with open(file_path, "w") as f: + json.dump( + {"result": value, "cache_key": key, "timestamp": timestamp}, f + ) + self.cache[key] = (file_path, timestamp) + except Exception as e: + print(f"Error writing file for key '{key}': {e}") + + def delete_file(self, file_path): + if os.path.exists(file_path): + try: + os.remove(file_path) + except Exception as e: + print(f"Error deleting file '{file_path}': {e}") + + def clear(self): + for file_path, _ in self.cache.values(): + self.delete_file(file_path) + self.cache.clear() + + def disable(self): + self.clear() + + +# Example usage +if __name__ == "__main__": + cache_config = CacheConfig(max_size=12, expiration=10, cache_dir="./cache") + cache = BizyAirTaskCache(cache_config) + + # Set some cache values + cache.set("key1", "This is the value for key1") + cache.set("key2", "This is the value for key2") + + # Retrieve values from cache + print(cache.get("key1")) # Output: This is the value for key1 + print(cache.get("key2")) # Output: This is the value for key2 + + # Wait for expiration + time.sleep(9) + print(cache.get("key1")) # Output: None (expired) + + # Clear cache + cache.clear() + print(cache.get("key2")) # Output: None (cache cleared) diff --git a/src/bizyair/common/client.py b/src/bizyair/common/client.py index ee78fbb2..d9d02f09 100644 --- a/src/bizyair/common/client.py +++ b/src/bizyair/common/client.py @@ -1,11 +1,18 @@ -import asyncio +import hashlib import json import pprint +import time import urllib.error import urllib.request import warnings +from abc import ABC, abstractmethod +from collections import defaultdict +from typing import Any, Union import aiohttp +import comfy + +from bizyair.common.caching import CacheManager __all__ = ["send_request"] @@ -111,8 +118,10 @@ def send_request( data: bytes = None, verbose=False, callback: callable = process_response_data, + response_handler: callable = json.loads, + cache_manager: CacheManager = None, **kwargs, -) -> dict: +) -> Union[dict, Any]: try: headers = kwargs.pop("headers") if "headers" in kwargs else _headers() headers["User-Agent"] = "BizyAir Client" @@ -138,9 +147,11 @@ def send_request( + "Also, verify your network settings and disable any proxies if necessary.\n" + "After checking, please restart the ComfyUI service." ) + if response_handler: + response_data = response_handler(response_data) if callback: - return callback(json.loads(response_data)) - return json.loads(response_data) + return callback(response_data) + return response_data async def async_send_request( diff --git a/src/bizyair/common/utils.py b/src/bizyair/common/utils.py index 8ce5daf4..cd3dd6c3 100644 --- a/src/bizyair/common/utils.py +++ b/src/bizyair/common/utils.py @@ -14,6 +14,10 @@ def truncate_long_strings(obj, max_length=50): return {k: truncate_long_strings(v, max_length) for k, v in obj.items()} elif isinstance(obj, list): return [truncate_long_strings(v, max_length) for v in obj] + elif isinstance(obj, tuple): + return tuple(truncate_long_strings(v, max_length) for v in obj) + elif isinstance(obj, torch.Tensor): + return obj.shape, obj.dtype, obj.device else: return obj diff --git a/src/bizyair/configs/conf.py b/src/bizyair/configs/conf.py index d59bcbdb..2dea9c22 100644 --- a/src/bizyair/configs/conf.py +++ b/src/bizyair/configs/conf.py @@ -17,6 +17,11 @@ class ModelRule: inputs: dict +@dataclass +class TaskApi: + task_result_endpoint: str + + class ModelRuleManager: def __init__(self, model_rules: list[dict]): self.model_rules = model_rules @@ -57,7 +62,7 @@ def find_rules(self, class_type: str) -> List[ModelRule]: score=self.model_rules[idx_1]["score"], route=self.model_rules[idx_1]["route"], class_type=class_type, - inputs=self.model_rules[idx_1]["nodes"][idx_2]["inputs"], + inputs=self.model_rules[idx_1]["nodes"][idx_2].get("inputs", {}), ) for idx_1, idx_2 in rule_indexes ] @@ -93,6 +98,12 @@ def get_rules(self, class_type: str) -> List[ModelRule]: def get_model_version_id_prefix(self): return self.model_rule_config["model_version_config"]["model_version_id_prefix"] + def get_cache_config(self): + return self.model_rule_config.get("cache_config", {}) + + def get_task_api(self): + return TaskApi(**self.model_rule_config["task_api"]) + model_path_config = os.path.join(os.path.dirname(__file__), "models.json") model_rule_config = os.path.join(os.path.dirname(__file__), "models.yaml") diff --git a/src/bizyair/configs/models.yaml b/src/bizyair/configs/models.yaml index 6936c6c2..c4fcf87e 100644 --- a/src/bizyair/configs/models.yaml +++ b/src/bizyair/configs/models.yaml @@ -2,6 +2,14 @@ model_version_config: model_version_id_prefix: "BIZYAIR_MODEL_VERSION_ID:" +cache_config: + max_size: 100 # 100 items + expiration: 604800 # 7 days + cache_dir: ".bizyair_cache" + file_prefix: "bizyair_task_" + file_suffix: ".json" + use_cache: true + model_hub: find_model: @@ -14,6 +22,11 @@ model_types: # checkpoints: bizyair/checkpoint # vae: bizyair/vae +task_api: + # Base URL for task-related API calls + task_result_endpoint: bizy_task + + model_rules: - mode_type: unet base_model: FLUX @@ -238,3 +251,19 @@ model_rules: inputs: vae_name: - ^flux.1-canny-vae.safetensors$ + + - mode_type: upscale_models + base_model: UPSCALE_MODEL + describe: Upscale Model + score: 1 + route: /bizy_task/bizyair-flux1-dev-fp8-async + nodes: + - class_type: UpscaleModelLoader + + - mode_type: upscale_model + base_model: FLUX + describe: Flux Upscale Model + score: 6 + route: /bizy_task/bizyair-flux1-dev-fp8-async + nodes: + - class_type: UltimateSDUpscale diff --git a/src/bizyair/nodes_base.py b/src/bizyair/nodes_base.py index 870f9d12..13288f2b 100644 --- a/src/bizyair/nodes_base.py +++ b/src/bizyair/nodes_base.py @@ -122,6 +122,7 @@ def assigned_id(self): def default_function(self, **kwargs): class_type = self._determine_class_type() + node_ios = self._process_non_send_request_types(class_type, kwargs) # TODO: add processing for send_request_types send_request_datatype_list = self._get_send_request_datatypes() diff --git a/src/bizyair/path_utils/path_manager.py b/src/bizyair/path_utils/path_manager.py index c7f9df94..ebff143d 100644 --- a/src/bizyair/path_utils/path_manager.py +++ b/src/bizyair/path_utils/path_manager.py @@ -77,7 +77,8 @@ def guess_url_from_node( out = [ rule for rule in rules - if all( + if len(rule.inputs) == 0 + or all( any(re.search(p, node["inputs"][key]) is not None for p in patterns) for key, patterns in rule.inputs.items() )