Add SiliconCloud VLM API Node (#172)
* Refactor model fetching extension for SiliconCloud LLM and VLM APIs

* Add VLM API support

* Add image detail control to SiliconCloudVLMAPI
lcolok authored Oct 21, 2024
1 parent 96664f3 commit a84616a
Showing 3 changed files with 272 additions and 68 deletions.
169 changes: 116 additions & 53 deletions js/siliconcloud_llm_api.js
@@ -46,69 +46,132 @@ app.registerExtension({
});

app.registerExtension({
name: "bizyair.siliconcloud.llm.api.model_fetch",
name: "bizyair.siliconcloud.vlm.api.populate",
async beforeRegisterNodeDef(nodeType, nodeData, app) {
if (nodeData.name === "BizyAirSiliconCloudLLMAPI") {
const originalNodeCreated = nodeType.prototype.onNodeCreated;
nodeType.prototype.onNodeCreated = async function () {
if (originalNodeCreated) {
originalNodeCreated.apply(this, arguments);
if (nodeData.name === "BizyAirSiliconCloudVLMAPI") {
function populate(text) {
if (this.widgets) {
const pos = this.widgets.findIndex((w) => w.name === "showtext");
if (pos !== -1) {
for (let i = pos; i < this.widgets.length; i++) {
this.widgets[i].onRemove?.();
}
this.widgets.length = pos;
}
}

for (const list of text) {
const w = ComfyWidgets["STRING"](this, "showtext", ["STRING", { multiline: true }], app).widget;
w.inputEl.readOnly = true;
w.inputEl.style.opacity = 0.6;
w.value = list;
}

const modelWidget = this.widgets.find((w) => w.name === "model");

const fetchModels = async () => {
try {
const response = await fetch("/bizyair/get_silicon_cloud_models", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({}),
});

if (response.ok) {
const models = await response.json();
console.debug("Fetched models:", models);
return models;
} else {
console.error(`Failed to fetch models: ${response.status}`);
requestAnimationFrame(() => {
const sz = this.computeSize();
if (sz[0] < this.size[0]) {
sz[0] = this.size[0];
}
if (sz[1] < this.size[1]) {
sz[1] = this.size[1];
}
this.onResize?.(sz);
app.graph.setDirtyCanvas(true, false);
});
}

const onExecuted = nodeType.prototype.onExecuted;
nodeType.prototype.onExecuted = function (message) {
onExecuted?.apply(this, arguments);
populate.call(this, message.text);
};
}
},
});
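The text array that populate renders is supplied by the node's Python side: whatever a node returns under the "ui" key is forwarded to the frontend and arrives in onExecuted as message. A minimal sketch of that contract, using the same return shape both API classes in llm.py below adopt (the function body here is a placeholder):

    # Python side of the contract: content under "ui" reaches the JS widget,
    # so message.text in onExecuted receives the tuple below.
    def get_vlm_model_response(self, *args):
        text = "model output"  # placeholder for the real API response content
        return {"ui": {"text": (text,)}, "result": (text,)}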

// Generic helper: builds a model-fetching-and-updating extension for a given node and endpoint
const createModelFetchExtension = (nodeName, endpoint) => {
  return {
    name: `bizyair.siliconcloud.${nodeName.toLowerCase()}.api.model_fetch`,
    async beforeRegisterNodeDef(nodeType, nodeData, app) {
      if (nodeData.name === nodeName) {
        const originalNodeCreated = nodeType.prototype.onNodeCreated;
        nodeType.prototype.onNodeCreated = async function () {
          if (originalNodeCreated) {
            originalNodeCreated.apply(this, arguments);
          }

          const modelWidget = this.widgets.find((w) => w.name === "model");

          const fetchModels = async () => {
            try {
              const response = await fetch(endpoint, {
                method: "POST",
                headers: {
                  "Content-Type": "application/json",
                },
                body: JSON.stringify({}),
              });

              if (response.ok) {
                const models = await response.json();
                console.debug(`Fetched ${nodeName} models:`, models);
                return models;
              } else {
                console.error(`Failed to fetch ${nodeName} models: ${response.status}`);
                return [];
              }
            } catch (error) {
              console.error(`Error fetching ${nodeName} models`, error);
              return [];
            }
          };

          const updateModels = async () => {
            const prevValue = modelWidget.value;
            modelWidget.value = "";
            modelWidget.options.values = [];

            const models = await fetchModels();

            modelWidget.options.values = models;
            console.debug(`Updated ${nodeName} modelWidget.options.values:`, modelWidget.options.values);

            if (models.includes(prevValue)) {
              modelWidget.value = prevValue; // stay on current.
            } else if (models.length > 0) {
              modelWidget.value = models[0]; // set first as default.
            }

            console.debug(`Updated ${nodeName} modelWidget.value:`, modelWidget.value);
            app.graph.setDirtyCanvas(true);
          };

          const dummy = async () => {
            // calling an async method updates the widgets with the actual value from the browser, not the default from the node definition.
          };

          // Initial update
          await dummy(); // this makes the widgets obtain the actual value from the web page.
          await updateModels();
        };
      }
    },
  };
};

// LLM Extension
app.registerExtension(
  createModelFetchExtension(
    "BizyAirSiliconCloudLLMAPI",
    "/bizyair/get_silicon_cloud_llm_models"
  )
);

// VLM Extension
app.registerExtension(
  createModelFetchExtension(
    "BizyAirSiliconCloudVLMAPI",
    "/bizyair/get_silicon_cloud_vlm_models"
  )
);
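Both routes the factory targets are registered on the ComfyUI server by llm.py (next file). As a sanity check outside the browser, a sketch of the same request the widget makes, assuming a local ComfyUI instance on the default port 8188 (the port is an assumption):

    import requests

    # Mirrors the frontend fetch: POST with an empty JSON body.
    resp = requests.post(
        "http://127.0.0.1:8188/bizyair/get_silicon_cloud_vlm_models",
        json={},
        timeout=10,
    )
    print(resp.json())  # a list of model ids ending with "No VLM Enhancement"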
111 changes: 97 additions & 14 deletions llm.py
@@ -6,21 +6,19 @@
from server import PromptServer

from bizyair.common.env_var import BIZYAIR_SERVER_ADDRESS
from bizyair.image_utils import decode_data, encode_comfy_image, encode_data

from .utils import (
    decode_and_deserialize,
    get_api_key,
    get_llm_response,
    get_vlm_response,
    send_post_request,
    serialize_and_encode,
)


async def fetch_all_models(api_key):
    url = "https://api.siliconflow.cn/v1/models"
    headers = {"accept": "application/json", "authorization": f"Bearer {api_key}"}
    params = {"type": "text", "sub_type": "chat"}
@@ -32,20 +30,37 @@ async def get_silicon_cloud_models_endpoint(request):
            ) as response:
                if response.status == 200:
                    data = await response.json()
                    all_models = [model["id"] for model in data["data"]]
                    return all_models
                else:
                    print(f"Error fetching models: HTTP Status {response.status}")
                    return []
    except aiohttp.ClientError as e:
        print(f"Error fetching models: {e}")
        return []
    except asyncio.exceptions.TimeoutError:
        print("Request to fetch models timed out")
        return []


@PromptServer.instance.routes.post("/bizyair/get_silicon_cloud_llm_models")
async def get_silicon_cloud_llm_models_endpoint(request):
    data = await request.json()
    api_key = data.get("api_key", get_api_key())
    all_models = await fetch_all_models(api_key)
    llm_models = [model for model in all_models if "vl" not in model.lower()]
    llm_models.append("No LLM Enhancement")
    return web.json_response(llm_models)


@PromptServer.instance.routes.post("/bizyair/get_silicon_cloud_vlm_models")
async def get_silicon_cloud_vlm_models_endpoint(request):
    data = await request.json()
    api_key = data.get("api_key", get_api_key())
    all_models = await fetch_all_models(api_key)
    vlm_models = [model for model in all_models if "vl" in model.lower()]
    vlm_models.append("No VLM Enhancement")
    return web.json_response(vlm_models)
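The LLM/VLM split is a case-insensitive substring test on the model id; a quick illustration with hypothetical ids:

    # Hypothetical ids; the filtering logic is the one used above.
    all_models = ["Qwen/Qwen2-VL-72B-Instruct", "deepseek-ai/DeepSeek-V2.5"]
    vlm_models = [m for m in all_models if "vl" in m.lower()]      # -> ["Qwen/Qwen2-VL-72B-Instruct"]
    llm_models = [m for m in all_models if "vl" not in m.lower()]  # -> ["deepseek-ai/DeepSeek-V2.5"]

Any id containing "vl" anywhere in its name is routed to the VLM list, so the heuristic relies on SiliconCloud's naming convention.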


class SiliconCloudLLMAPI:
@@ -85,7 +100,6 @@ def INPUT_TYPES(s):
    RETURN_TYPES = ("STRING",)
    FUNCTION = "get_llm_model_response"
    OUTPUT_NODE = False
    CATEGORY = "☁️BizyAir/AI Assistants"

    def get_llm_model_response(
@@ -105,6 +119,73 @@ def get_llm_model_response(
        return {"ui": {"text": (text,)}, "result": (text,)}


class SiliconCloudVLMAPI:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ((), {}),
                "system_prompt": (
                    "STRING",
                    {
                        "default": "你是一个能分析图像的AI助手。请仔细观察图像,并根据用户的问题提供详细、准确的描述。",
                        "multiline": True,
                    },
                ),
                "user_prompt": (
                    "STRING",
                    {
                        "default": "请描述这张图片的内容,并指出任何有趣或不寻常的细节。",
                        "multiline": True,
                    },
                ),
                "images": ("IMAGE",),
                "max_tokens": ("INT", {"default": 512, "min": 100, "max": 1e5}),
                "temperature": (
                    "FLOAT",
                    {"default": 0.7, "min": 0.0, "max": 2.0, "step": 0.01},
                ),
                "detail": (["auto", "low", "high"], {"default": "auto"}),
            }
        }

    RETURN_TYPES = ("STRING",)
    FUNCTION = "get_vlm_model_response"
    OUTPUT_NODE = False
    CATEGORY = "☁️BizyAir/AI Assistants"

    def get_vlm_model_response(
        self, model, system_prompt, user_prompt, images, max_tokens, temperature, detail
    ):
        if model == "No VLM Enhancement":
            return (user_prompt,)

        # Encode the image batch with encode_comfy_image
        encoded_images_json = encode_comfy_image(
            images, image_format="WEBP", lossless=True
        )
        encoded_images_dict = json.loads(encoded_images_json)

        # Collect all encoded images
        base64_images = list(encoded_images_dict.values())

        response = get_vlm_response(
            model,
            system_prompt,
            user_prompt,
            base64_images,
            max_tokens,
            temperature,
            detail,
        )
        ret = json.loads(response)
        text = ret["choices"][0]["message"]["content"]
        return {"ui": {"text": (text,)}, "result": (text,)}


class BizyAirJoyCaption:
    # refer to: https://huggingface.co/spaces/fancyfeast/joy-caption-pre-alpha
    API_URL = f"{BIZYAIR_SERVER_ADDRESS}/supernode/joycaption"
@@ -193,9 +274,11 @@ def joycaption(self, image, do_sample, temperature, max_tokens):

NODE_CLASS_MAPPINGS = {
    "BizyAirSiliconCloudLLMAPI": SiliconCloudLLMAPI,
    "BizyAirSiliconCloudVLMAPI": SiliconCloudVLMAPI,
    "BizyAirJoyCaption": BizyAirJoyCaption,
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "BizyAirSiliconCloudLLMAPI": "☁️BizyAir SiliconCloud LLM API",
    "BizyAirSiliconCloudVLMAPI": "☁️BizyAir SiliconCloud VLM API",
    "BizyAirJoyCaption": "☁️BizyAir Joy Caption",
}
