From a84616aaaa6843758f525dae8c79ce770f59ee2e Mon Sep 17 00:00:00 2001 From: lcolok <425311101@qq.com> Date: Mon, 21 Oct 2024 10:57:47 +0800 Subject: [PATCH] Add SiliconCloud VLM API Node (#172) * Refactor model fetching extension for SiliconCloud LLM and VLM APIs * Add VLM API support * Add image detail control to SiliconCloudVLMAPI --- js/siliconcloud_llm_api.js | 169 +++++++++++++++++++++++++------------ llm.py | 111 +++++++++++++++++++++--- utils.py | 60 ++++++++++++- 3 files changed, 272 insertions(+), 68 deletions(-) diff --git a/js/siliconcloud_llm_api.js b/js/siliconcloud_llm_api.js index 273a9301..8c2f1eb4 100644 --- a/js/siliconcloud_llm_api.js +++ b/js/siliconcloud_llm_api.js @@ -46,69 +46,132 @@ app.registerExtension({ }); app.registerExtension({ - name: "bizyair.siliconcloud.llm.api.model_fetch", + name: "bizyair.siliconcloud.vlm.api.populate", async beforeRegisterNodeDef(nodeType, nodeData, app) { - if (nodeData.name === "BizyAirSiliconCloudLLMAPI") { - const originalNodeCreated = nodeType.prototype.onNodeCreated; - nodeType.prototype.onNodeCreated = async function () { - if (originalNodeCreated) { - originalNodeCreated.apply(this, arguments); + if (nodeData.name === "BizyAirSiliconCloudVLMAPI") { + function populate(text) { + if (this.widgets) { + const pos = this.widgets.findIndex((w) => w.name === "showtext"); + if (pos !== -1) { + for (let i = pos; i < this.widgets.length; i++) { + this.widgets[i].onRemove?.(); + } + this.widgets.length = pos; + } + } + + for (const list of text) { + const w = ComfyWidgets["STRING"](this, "showtext", ["STRING", { multiline: true }], app).widget; + w.inputEl.readOnly = true; + w.inputEl.style.opacity = 0.6; + w.value = list; } - const modelWidget = this.widgets.find((w) => w.name === "model"); - - const fetchModels = async () => { - try { - const response = await fetch("/bizyair/get_silicon_cloud_models", { - method: "POST", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify({}), - }); - - if (response.ok) { - const models = await response.json(); - console.debug("Fetched models:", models); - return models; - } else { - console.error(`Failed to fetch models: ${response.status}`); + requestAnimationFrame(() => { + const sz = this.computeSize(); + if (sz[0] < this.size[0]) { + sz[0] = this.size[0]; + } + if (sz[1] < this.size[1]) { + sz[1] = this.size[1]; + } + this.onResize?.(sz); + app.graph.setDirtyCanvas(true, false); + }); + } + + const onExecuted = nodeType.prototype.onExecuted; + nodeType.prototype.onExecuted = function (message) { + onExecuted?.apply(this, arguments); + populate.call(this, message.text); + }; + } + }, +}); + +// 通用的模型获取和更新函数 +const createModelFetchExtension = (nodeName, endpoint) => { + return { + name: `bizyair.siliconcloud.${nodeName.toLowerCase()}.api.model_fetch`, + async beforeRegisterNodeDef(nodeType, nodeData, app) { + if (nodeData.name === nodeName) { + const originalNodeCreated = nodeType.prototype.onNodeCreated; + nodeType.prototype.onNodeCreated = async function () { + if (originalNodeCreated) { + originalNodeCreated.apply(this, arguments); + } + + const modelWidget = this.widgets.find((w) => w.name === "model"); + + const fetchModels = async () => { + try { + const response = await fetch(endpoint, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({}), + }); + + if (response.ok) { + const models = await response.json(); + console.debug(`Fetched ${nodeName} models:`, models); + return models; + } else { + console.error(`Failed to fetch ${nodeName} models: ${response.status}`); + return []; + } + } catch (error) { + console.error(`Error fetching ${nodeName} models`, error); return []; } - } catch (error) { - console.error(`Error fetching models`, error); - return []; - } - }; + }; - const updateModels = async () => { - const prevValue = modelWidget.value; - modelWidget.value = ""; - modelWidget.options.values = []; + const updateModels = async () => { + const prevValue = modelWidget.value; + modelWidget.value = ""; + modelWidget.options.values = []; - const models = await fetchModels(); + const models = await fetchModels(); - modelWidget.options.values = models; - console.debug("Updated modelWidget.options.values:", modelWidget.options.values); + modelWidget.options.values = models; + console.debug(`Updated ${nodeName} modelWidget.options.values:`, modelWidget.options.values); - if (models.includes(prevValue)) { - modelWidget.value = prevValue; // stay on current. - } else if (models.length > 0) { - modelWidget.value = models[0]; // set first as default. - } + if (models.includes(prevValue)) { + modelWidget.value = prevValue; // stay on current. + } else if (models.length > 0) { + modelWidget.value = models[0]; // set first as default. + } - console.debug("Updated modelWidget.value:", modelWidget.value); - app.graph.setDirtyCanvas(true); - }; + console.debug(`Updated ${nodeName} modelWidget.value:`, modelWidget.value); + app.graph.setDirtyCanvas(true); + }; + + const dummy = async () => { + // calling async method will update the widgets with actual value from the browser and not the default from Node definition. + }; - const dummy = async () => { - // calling async method will update the widgets with actual value from the browser and not the default from Node definition. + // Initial update + await dummy(); // this will cause the widgets to obtain the actual value from web page. + await updateModels(); }; + } + }, + }; +}; - // Initial update - await dummy(); // this will cause the widgets to obtain the actual value from web page. - await updateModels(); - }; - } - }, -}); +// LLM Extension +app.registerExtension( + createModelFetchExtension( + "BizyAirSiliconCloudLLMAPI", + "/bizyair/get_silicon_cloud_llm_models" + ) +); + +// VLM Extension +app.registerExtension( + createModelFetchExtension( + "BizyAirSiliconCloudVLMAPI", + "/bizyair/get_silicon_cloud_vlm_models" + ) +); diff --git a/llm.py b/llm.py index 2f626efc..b1275d37 100644 --- a/llm.py +++ b/llm.py @@ -6,21 +6,19 @@ from server import PromptServer from bizyair.common.env_var import BIZYAIR_SERVER_ADDRESS -from bizyair.image_utils import decode_data, encode_data +from bizyair.image_utils import decode_data, encode_comfy_image, encode_data from .utils import ( decode_and_deserialize, get_api_key, get_llm_response, + get_vlm_response, send_post_request, serialize_and_encode, ) -@PromptServer.instance.routes.post("/bizyair/get_silicon_cloud_models") -async def get_silicon_cloud_models_endpoint(request): - data = await request.json() - api_key = data.get("api_key", get_api_key()) +async def fetch_all_models(api_key): url = "https://api.siliconflow.cn/v1/models" headers = {"accept": "application/json", "authorization": f"Bearer {api_key}"} params = {"type": "text", "sub_type": "chat"} @@ -32,20 +30,37 @@ async def get_silicon_cloud_models_endpoint(request): ) as response: if response.status == 200: data = await response.json() - models = [model["id"] for model in data["data"]] - models.append("No LLM Enhancement") - return web.json_response(models) + all_models = [model["id"] for model in data["data"]] + return all_models else: print(f"Error fetching models: HTTP Status {response.status}") - return web.json_response( - ["Error fetching models"], status=response.status - ) + return [] except aiohttp.ClientError as e: print(f"Error fetching models: {e}") - return web.json_response(["Error fetching models"], status=500) + return [] except asyncio.exceptions.TimeoutError: print("Request to fetch models timed out") - return web.json_response(["Request timed out"], status=504) + return [] + + +@PromptServer.instance.routes.post("/bizyair/get_silicon_cloud_llm_models") +async def get_silicon_cloud_llm_models_endpoint(request): + data = await request.json() + api_key = data.get("api_key", get_api_key()) + all_models = await fetch_all_models(api_key) + llm_models = [model for model in all_models if "vl" not in model.lower()] + llm_models.append("No LLM Enhancement") + return web.json_response(llm_models) + + +@PromptServer.instance.routes.post("/bizyair/get_silicon_cloud_vlm_models") +async def get_silicon_cloud_vlm_models_endpoint(request): + data = await request.json() + api_key = data.get("api_key", get_api_key()) + all_models = await fetch_all_models(api_key) + vlm_models = [model for model in all_models if "vl" in model.lower()] + vlm_models.append("No VLM Enhancement") + return web.json_response(vlm_models) class SiliconCloudLLMAPI: @@ -85,7 +100,6 @@ def INPUT_TYPES(s): RETURN_TYPES = ("STRING",) FUNCTION = "get_llm_model_response" OUTPUT_NODE = False - CATEGORY = "☁️BizyAir/AI Assistants" def get_llm_model_response( @@ -105,6 +119,73 @@ def get_llm_model_response( return {"ui": {"text": (text,)}, "result": (text,)} +class SiliconCloudVLMAPI: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ((), {}), + "system_prompt": ( + "STRING", + { + "default": "你是一个能分析图像的AI助手。请仔细观察图像,并根据用户的问题提供详细、准确的描述。", + "multiline": True, + }, + ), + "user_prompt": ( + "STRING", + { + "default": "请描述这张图片的内容,并指出任何有趣或不寻常的细节。", + "multiline": True, + }, + ), + "images": ("IMAGE",), + "max_tokens": ("INT", {"default": 512, "min": 100, "max": 1e5}), + "temperature": ( + "FLOAT", + {"default": 0.7, "min": 0.0, "max": 2.0, "step": 0.01}, + ), + "detail": (["auto", "low", "high"], {"default": "auto"}), + } + } + + RETURN_TYPES = ("STRING",) + FUNCTION = "get_vlm_model_response" + OUTPUT_NODE = False + CATEGORY = "☁️BizyAir/AI Assistants" + + def get_vlm_model_response( + self, model, system_prompt, user_prompt, images, max_tokens, temperature, detail + ): + if model == "No VLM Enhancement": + return (user_prompt,) + + # 使用 encode_comfy_image 函数编码图像批次 + encoded_images_json = encode_comfy_image( + images, image_format="WEBP", lossless=True + ) + encoded_images_dict = json.loads(encoded_images_json) + + # 提取所有编码后的图像 + base64_images = list(encoded_images_dict.values()) + + response = get_vlm_response( + model, + system_prompt, + user_prompt, + base64_images, + max_tokens, + temperature, + detail, + ) + ret = json.loads(response) + text = ret["choices"][0]["message"]["content"] + return {"ui": {"text": (text,)}, "result": (text,)} + + class BizyAirJoyCaption: # refer to: https://huggingface.co/spaces/fancyfeast/joy-caption-pre-alpha API_URL = f"{BIZYAIR_SERVER_ADDRESS}/supernode/joycaption" @@ -193,9 +274,11 @@ def joycaption(self, image, do_sample, temperature, max_tokens): NODE_CLASS_MAPPINGS = { "BizyAirSiliconCloudLLMAPI": SiliconCloudLLMAPI, + "BizyAirSiliconCloudVLMAPI": SiliconCloudVLMAPI, "BizyAirJoyCaption": BizyAirJoyCaption, } NODE_DISPLAY_NAME_MAPPINGS = { "BizyAirSiliconCloudLLMAPI": "☁️BizyAir SiliconCloud LLM API", + "BizyAirSiliconCloudVLMAPI": "☁️BizyAir SiliconCloud VLM API", "BizyAirJoyCaption": "☁️BizyAir Joy Caption", } diff --git a/utils.py b/utils.py index 109e9437..c0c876af 100644 --- a/utils.py +++ b/utils.py @@ -5,7 +5,7 @@ import urllib.parse import urllib.request import zlib -from typing import Tuple, Union +from typing import List, Tuple, Union import numpy as np @@ -158,3 +158,61 @@ def get_llm_response( } response = send_post_request(api_url, headers=headers, payload=payload) return response + + +def get_vlm_response( + model: str, + system_prompt: str, + user_prompt: str, + base64_images: List[str], + max_tokens: int = 1024, + temperature: float = 0.7, + detail: str = "auto", +): + api_url = "https://api.siliconflow.cn/v1/chat/completions" + API_KEY = get_api_key() + headers = { + "accept": "application/json", + "content-type": "application/json", + "Authorization": f"Bearer {API_KEY}", + } + + messages = [ + { + "role": "user", + "content": [{"type": "text", "text": system_prompt}], + }, # 此方法皆适用于两种 VL 模型 + # { + # "role": "system", + # "content": system_prompt, + # }, # role 为 "system" 的这种方式只适用于 QwenVL 系列模型,并不适用于 InternVL 系列模型 + ] + + user_content = [] + for base64_image in base64_images: + user_content.append( + { + "type": "image_url", + "image_url": { + "url": f"data:image/webp;base64,{base64_image}", + "detail": detail, + }, + } + ) + user_content.append({"type": "text", "text": user_prompt}) + + messages.append({"role": "user", "content": user_content}) + + payload = { + "model": model, + "messages": messages, + "max_tokens": max_tokens, + "temperature": temperature, + "top_p": 0.9, + "top_k": 50, + "stream": False, + "n": 1, + } + + response = send_post_request(api_url, headers=headers, payload=payload) + return response