Skip to content

Commit

Permalink
fix: add json schema
Browse files Browse the repository at this point in the history
  • Loading branch information
Yeuoly committed Nov 29, 2024
1 parent 07a02f6 commit 5bcbae8
Showing 1 changed file with 115 additions and 80 deletions.
195 changes: 115 additions & 80 deletions python/dify_plugin/entities/model/__init__.py
Original file line number Diff line number Diff line change
class DefaultParameterName(Enum):
    """
    Enum class for parameter template variable.

    Each member names a well-known model parameter that providers can
    reference via `use_template`; the matching defaults live in
    PARAMETER_RULE_TEMPLATE.
    """

    TEMPERATURE = "temperature"
    TOP_P = "top_p"
    PRESENCE_PENALTY = "presence_penalty"
    FREQUENCY_PENALTY = "frequency_penalty"
    MAX_TOKENS = "max_tokens"
    RESPONSE_FORMAT = "response_format"
    JSON_SCHEMA = "json_schema"

    @classmethod
    def value_of(cls, value: Any) -> "DefaultParameterName":
        """
        Get parameter name from value.

        :param value: the string value of a member (e.g. "temperature")
        :return: the matching enum member
        :raises ValueError: if no member has the given value
        """
        for name in cls:
            if name.value == value:
                return name
        raise ValueError(f"invalid parameter name {value}")


# Default UI/validation rule for each well-known parameter. A provider rule
# that sets `use_template` is merged on top of the matching entry here.
PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
    DefaultParameterName.TEMPERATURE: {
        "label": {
            "en_US": "Temperature",
            "zh_Hans": "温度",
        },
        "type": "float",
        "help": {
            "en_US": "Controls randomness. Lower temperature results in less random completions. As the temperature approaches zero, the model will become deterministic and repetitive. Higher temperature results in more random completions.",
            "zh_Hans": "温度控制随机性。较低的温度会导致较少的随机完成。随着温度接近零,模型将变得确定性和重复性。较高的温度会导致更多的随机完成。",
        },
        "required": False,
        "default": 0.0,
        "min": 0.0,
        "max": 1.0,
        "precision": 2,
    },
    DefaultParameterName.TOP_P: {
        "label": {
            "en_US": "Top P",
            "zh_Hans": "Top P",
        },
        "type": "float",
        "help": {
            "en_US": "Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options are considered.",
            "zh_Hans": "通过核心采样控制多样性:0.5表示考虑了一半的所有可能性加权选项。",
        },
        "required": False,
        "default": 1.0,
        "min": 0.0,
        "max": 1.0,
        "precision": 2,
    },
    DefaultParameterName.PRESENCE_PENALTY: {
        "label": {
            "en_US": "Presence Penalty",
            "zh_Hans": "存在惩罚",
        },
        "type": "float",
        "help": {
            "en_US": "Applies a penalty to the log-probability of tokens already in the text.",
            "zh_Hans": "对文本中已有的标记的对数概率施加惩罚。",
        },
        "required": False,
        "default": 0.0,
        "min": 0.0,
        "max": 1.0,
        "precision": 2,
    },
    DefaultParameterName.FREQUENCY_PENALTY: {
        "label": {
            "en_US": "Frequency Penalty",
            "zh_Hans": "频率惩罚",
        },
        "type": "float",
        "help": {
            "en_US": "Applies a penalty to the log-probability of tokens that appear in the text.",
            "zh_Hans": "对文本中出现的标记的对数概率施加惩罚。",
        },
        "required": False,
        "default": 0.0,
        "min": 0.0,
        "max": 1.0,
        "precision": 2,
    },
    DefaultParameterName.MAX_TOKENS: {
        "label": {
            "en_US": "Max Tokens",
            "zh_Hans": "最大标记",
        },
        "type": "int",
        "help": {
            "en_US": "Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.",
            "zh_Hans": "指定生成结果长度的上限。如果生成结果截断,可以调大该参数。",
        },
        "required": False,
        "default": 64,
        "min": 1,
        "max": 2048,
        "precision": 0,
    },
    DefaultParameterName.RESPONSE_FORMAT: {
        "label": {
            "en_US": "Response Format",
            "zh_Hans": "回复格式",
        },
        "type": "string",
        "help": {
            "en_US": "Set a response format, ensure the output from llm is a valid code block as possible, such as JSON, XML, etc.",
            "zh_Hans": "设置一个返回格式,确保llm的输出尽可能是有效的代码块,如JSON、XML等",
        },
        "required": False,
        "options": ["JSON", "XML"],
    },
    DefaultParameterName.JSON_SCHEMA: {
        "label": {
            "en_US": "JSON Schema",
        },
        "type": "text",
        "help": {
            "en_US": "Set a response json schema will ensure LLM to adhere it.",
            "zh_Hans": "设置返回的json schema,llm将按照它返回",
        },
        "required": False,
    },
}


class ModelType(Enum):
"""
Enum class for model type.
"""

LLM = "llm"
TEXT_EMBEDDING = "text-embedding"
RERANK = "rerank"
Expand All @@ -138,10 +153,12 @@ class ModelType(Enum):
TTS = "tts"
TEXT2IMG = "text2img"


class FetchFrom(Enum):
    """
    Enum class for fetch from.

    Indicates whether a model definition comes from the provider's
    predefined catalog or from user customization.
    """

    PREDEFINED_MODEL = "predefined-model"
    CUSTOMIZABLE_MODEL = "customizable-model"

Expand All @@ -150,6 +167,7 @@ class ModelFeature(Enum):
"""
Enum class for llm feature.
"""

TOOL_CALL = "tool-call"
MULTI_TOOL_CALL = "multi-tool-call"
AGENT_THOUGHT = "agent-thought"
Expand All @@ -164,6 +182,7 @@ class ParameterType(Enum):
"""
Enum class for parameter type.
"""

FLOAT = "float"
INT = "int"
STRING = "string"
Expand All @@ -174,6 +193,7 @@ class ModelPropertyKey(Enum):
"""
Enum class for model property key.
"""

MODE = "mode"
CONTEXT_SIZE = "context_size"
MAX_CHUNKS = "max_chunks"
Expand All @@ -191,6 +211,7 @@ class ProviderModel(BaseModel):
"""
Model class for provider model.
"""

model: str
label: I18nObject
model_type: ModelType
Expand All @@ -203,18 +224,21 @@ class ProviderModel(BaseModel):
"""
use model as label
"""
@model_validator(mode='before')

@model_validator(mode="before")
def validate_label(cls, data: dict) -> dict:
if isinstance(data, dict):
if not data.get("label"):
data["label"] = I18nObject(en_US=data["model"])

return data


class ParameterRule(BaseModel):
"""
Model class for parameter rule.
"""

name: str
use_template: Optional[str] = None
label: I18nObject
Expand All @@ -227,31 +251,39 @@ class ParameterRule(BaseModel):
precision: Optional[int] = None
options: list[str] = []

@model_validator(mode='before')
@model_validator(mode="before")
def validate_label(cls, data: dict) -> dict:
if isinstance(data, dict):
if not data.get("label"):
data["label"] = I18nObject(en_US=data["name"])

# check if there is a template
if 'use_template' in data:
if "use_template" in data:
try:
default_parameter_name = DefaultParameterName.value_of(data['use_template'])
default_parameter_rule = PARAMETER_RULE_TEMPLATE.get(default_parameter_name)
default_parameter_name = DefaultParameterName.value_of(
data["use_template"]
)
default_parameter_rule = PARAMETER_RULE_TEMPLATE.get(
default_parameter_name
)
if not default_parameter_rule:
raise Exception(f"Invalid model parameter rule name {default_parameter_name}")
raise Exception(
f"Invalid model parameter rule name {default_parameter_name}"
)
copy_default_parameter_rule = default_parameter_rule.copy()
copy_default_parameter_rule.update(data)
data = copy_default_parameter_rule
except ValueError:
pass

return data


class PriceConfig(BaseModel):
"""
Model class for pricing info.
"""

input: Decimal
output: Optional[Decimal] = None
unit: Decimal
Expand All @@ -262,6 +294,7 @@ class AIModelEntity(ProviderModel):
"""
Model class for AI model.
"""

parameter_rules: list[ParameterRule] = []
pricing: Optional[PriceConfig] = None

class PriceType(Enum):
    """
    Enum class for price type.

    Distinguishes pricing applied to prompt (input) tokens from pricing
    applied to completion (output) tokens.
    """

    INPUT = "input"
    OUTPUT = "output"

Expand All @@ -282,6 +316,7 @@ class PriceInfo(BaseModel):
"""
Model class for price info.
"""

unit_price: Decimal
unit: Decimal
total_amount: Decimal
Expand Down

0 comments on commit 5bcbae8

Please sign in to comment.