Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DashScope model wrapper into AgentScope #54

Merged
merged 21 commits into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/agentscope/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
)
from .tongyi_model import (
TongyiWrapper,
TongyiChatWrapper,
QwenChatWrapper,
)


Expand All @@ -36,7 +36,7 @@
"read_model_configs",
"clear_model_configs",
"TongyiWrapper",
"TongyiChatWrapper",
"QwenChatWrapper",
]

_MODEL_CONFIGS: dict[str, dict] = {}
Expand Down
117 changes: 89 additions & 28 deletions src/agentscope/models/tongyi_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
"""Model wrapper for Tongyi models"""
"""Model wrapper for Qwen chat models"""
from http import HTTPStatus
from typing import Any

try:
Expand All @@ -13,9 +14,13 @@

from ..utils.monitor import MonitorFactory
from ..utils.monitor import get_full_name
from ..utils import QuotaExceededError
from ..constants import _DEFAULT_API_BUDGET

# The models in this list require that the roles of messages must alternate
# between "user" and "assistant".
# TODO: add more models
SPECIAL_MODEL_LIST = ["qwen-turbo", "qwen-plus", "qwen1.5-72b-chat"]
pan-x-c marked this conversation as resolved.
Show resolved Hide resolved


class TongyiWrapper(ModelWrapperBase):
qbc2016 marked this conversation as resolved.
Show resolved Hide resolved
"""The model wrapper for Tongyi API."""
Expand Down Expand Up @@ -100,10 +105,10 @@ def _metric(self, metric_name: str) -> str:
return get_full_name(name=metric_name, prefix=self.model)


class TongyiChatWrapper(TongyiWrapper):
"""The model wrapper for Tongyi's chat API."""
class QwenChatWrapper(TongyiWrapper):
"""The model wrapper for Qwen's chat API."""

model_type: str = "tongyi_chat"
model_type: str = "qwen_chat"

def _register_default_metrics(self) -> None:
# Set monitor accordingly
Expand All @@ -127,22 +132,22 @@ def __call__(
messages: list,
**kwargs: Any,
) -> ModelResponse:
"""Processes a list of messages to construct a payload for the Tongyi
API call. It then makes a request to the Tongyi API and returns the
"""Processes a list of messages to construct a payload for the Qwen
API call. It then makes a request to the Qwen API and returns the
response. This method also updates monitoring metrics based on the
API response.

Each message in the 'messages' list can contain text content and
optionally an 'image_urls' key. If 'image_urls' is provided,
it is expected to be a list of strings representing URLs to images.
These URLs will be transformed to a suitable format for the Tongyi
These URLs will be transformed to a suitable format for the Qwen
API, which might involve converting local file paths to data URIs.

Args:
messages (`list`):
A list of messages to process.
**kwargs (`Any`):
The keyword arguments to Tongyi chat completions API,
The keyword arguments to Qwen chat completions API,
e.g. `temperature`, `max_tokens`, `top_p`, etc. Please refer to

for more detailed arguments.
Expand Down Expand Up @@ -173,15 +178,9 @@ def __call__(
if not all("role" in msg and "content" in msg for msg in messages):
raise ValueError(
"Each message in the 'messages' list must contain a 'role' "
"and 'content' key for Tongyi API.",
"and 'content' key for Qwen API.",
)

# For Tongyi model, the "role" value of the first and the last message
# must be "user"
if len(messages) > 0:
messages[0]["role"] = "user"
messages[-1]["role"] = "user"

# step3: forward to generate response
response = dashscope.Generation.call(
model=self.model,
Expand All @@ -190,6 +189,23 @@ def __call__(
**kwargs,
)

if response.status_code == 400:
logger.warning(
"Initial API call failed with status 400. Attempting role "
"preprocessing and retrying. You'd better do it yourself in "
"the prompt engineering to satisfy the model call rule.",
)
# TODO: remove this and leave prompt engineering to user
messages = self._preprocess_role(messages)
# Retry the API call
response = dashscope.Generation.call(
model=self.model,
messages=messages,
result_format="message",
# set the result to be "message" format.
**kwargs,
)

# step4: record the api invocation if needed
self._save_model_invocation(
arguments={
Expand All @@ -200,18 +216,63 @@ def __call__(
json_response=response,
)

# TODO: Add monitor for Qwen?
# step5: update monitor accordingly
try:
self.monitor.update(
response.usage,
prefix=self.model,
)
except QuotaExceededError as e:
# TODO: optimize quota exceeded error handling process
logger.error(e.message)
# try:
# self.monitor.update(
# response.usage,
qbc2016 marked this conversation as resolved.
Show resolved Hide resolved
# prefix=self.model,
# )
# except QuotaExceededError as e:
# logger.error(e.message)

# step6: return response
return ModelResponse(
text=response.output["choices"][0]["message"]["content"],
raw=response,
)
if response.status_code == HTTPStatus.OK:
return ModelResponse(
text=response.output["choices"][0]["message"]["content"],
raw=response,
)
else:
error_msg = (
f"Request id: {response.request_id},"
f" Status code: {response.status_code},"
f" error code: {response.code},"
f" error message: {response.message}."
)

raise RuntimeError(error_msg)

def _preprocess_role(self, messages: list) -> list:
qbc2016 marked this conversation as resolved.
Show resolved Hide resolved
"""preprocess role rules for Qwen"""
if self.model in SPECIAL_MODEL_LIST:
# The models in this list require that the roles of messages must
# alternate between "user" and "assistant".
message_length = len(messages)
if message_length % 2 == 1:
# If the length of the message list is odd, roles will
# alternate, starting with "user"
roles = [
"user" if i % 2 == 0 else "assistant"
for i in range(message_length)
]
else:
# If the length of the message list is even, the first role
# will be "system", followed by alternating "user" and
# "assistant"
roles = ["system"] + [
"user" if i % 2 == 1 else "assistant"
for i in range(1, message_length)
]

# Assign the roles list to the "role" key for each message in
# the messages list
for message, role in zip(messages, roles):
message["role"] = role
else:
# For other Qwen models, the "role" value of the first and the
# last messages must be "user"
if len(messages) > 0:
messages[0]["role"] = "user"
messages[-1]["role"] = "user"

return messages