Skip to content

Commit

Permalink
[Hot Fix] Fix _ModelConfig state get and set (#397)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: DavdGao <[email protected]>
  • Loading branch information
pan-x-c and DavdGao authored Aug 12, 2024
1 parent 21161fe commit 100e8cb
Showing 1 changed file with 26 additions and 53 deletions.
79 changes: 26 additions & 53 deletions src/agentscope/manager/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,16 +100,16 @@ def load_model_configs(
f"list of dicts), but got {type(model_configs)}",
)

format_configs = _ModelConfig.format_configs(configs=cfgs)
formatted_configs = _format_configs(configs=cfgs)

# check if name is unique
for cfg in format_configs:
if cfg.config_name in self.model_configs:
for cfg in formatted_configs:
if cfg["config_name"] in self.model_configs:
logger.warning(
f"config_name [{cfg.config_name}] already exists.",
f"config_name [{cfg['config_name']}] already exists.",
)
continue
self.model_configs[cfg.config_name] = cfg
self.model_configs[cfg["config_name"]] = cfg

# print the loaded model configs
logger.info(
Expand Down Expand Up @@ -137,7 +137,7 @@ def get_model_by_config_name(self, config_name: str) -> ModelWrapperBase:
f"Cannot find [{config_name}] in loaded configurations.",
)

model_type = config.model_type
model_type = config["model_type"]

kwargs = {k: v for k, v in config.items() if k != "model_type"}

Expand All @@ -164,55 +164,28 @@ def flush(self) -> None:
self.clear_model_configs()


class _ModelConfig(dict):
"""Base class for model config."""
def _format_configs(
configs: Union[Sequence[dict], dict],
) -> Sequence:
"""Check the format of model configs.
__getattr__ = dict.__getitem__
__setattr__ = dict.__setitem__
Args:
configs (Union[Sequence[dict], dict]): configs in dict format.
def __init__(
self,
config_name: str,
model_type: str = None,
**kwargs: Any,
):
"""Initialize the config with the given arguments, and checking the
type of the arguments.
Args:
config_name (`str`): A unique name of the model config.
model_type (`str`, optional): The class name (or its model type) of
the generated model wrapper. Defaults to None.
Raises:
`ValueError`: If `config_name` is not provided.
"""
if config_name is None:
raise ValueError("The `config_name` field is required for Cfg")
if model_type is None:
Returns:
Sequence[dict]: converted ModelConfig list.
"""
if isinstance(configs, dict):
configs = [configs]
for config in configs:
if "config_name" not in config:
raise ValueError(
"The `config_name` field is required for Cfg",
)
if "model_type" not in config:
logger.warning(
f"`model_type` is not provided in config [{config_name}],"
"`model_type` is not provided in config"
f"[{config['config_name']}],"
" use `PostAPIModelWrapperBase` by default.",
)
super().__init__(
config_name=config_name,
model_type=model_type,
**kwargs,
)

@classmethod
def format_configs(
cls,
configs: Union[Sequence[dict], dict],
) -> Sequence:
"""Covert config dicts into a list of _ModelConfig.
Args:
configs (Union[Sequence[dict], dict]): configs in dict format.
Returns:
Sequence[_ModelConfig]: converted ModelConfig list.
"""
if isinstance(configs, dict):
return [_ModelConfig(**configs)]
return [_ModelConfig(**cfg) for cfg in configs]
return configs

0 comments on commit 100e8cb

Please sign in to comment.