diff --git a/camel/agents/role_playing.py b/camel/agents/role_playing.py index 3d12a1e8a..48d979971 100644 --- a/camel/agents/role_playing.py +++ b/camel/agents/role_playing.py @@ -12,7 +12,7 @@ # limitations under the License. # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== import copy -from typing import Dict, List, Optional, Sequence, Tuple +from typing import Dict, List, Optional, Sequence, Tuple, Union from camel.agents import ( ChatAgent, @@ -167,8 +167,8 @@ def __init__( else: self.critic = None - def init_chat(self, phase_type: PhaseType = None, - placeholders=None, phase_prompt=None): + def init_chat(self, phase_type: Union[PhaseType, None], + placeholders, phase_prompt: str): r"""Initializes the chat by resetting both the assistant and user agents, and sending the system messages again to the agents using chat messages. Returns the assistant's introductory message and the diff --git a/chatdev/chat_chain.py b/chatdev/chat_chain.py index 61f5d5f5d..da0386612 100644 --- a/chatdev/chat_chain.py +++ b/chatdev/chat_chain.py @@ -5,6 +5,7 @@ import shutil import time from datetime import datetime +from typing import Union from camel.agents import RolePlaying from camel.configs import ChatGPTConfig @@ -14,21 +15,31 @@ from chatdev.utils import log_and_print_online, now -def check_bool(s): - return s.lower() == "true" - +def check_bool(s: Union[str, bool]) -> bool: + """ Normalizes a string or bool to a bool value. + String must be either "True" or "False" (case insensitive). + """ + if isinstance(s, bool): + return s + else: + if s.lower() == "true": + return True + elif s.lower() == "false": + return False + else: + raise ValueError(f"Cannot convert '{s}' in config to bool") class ChatChain: def __init__(self, - config_path: str = None, - config_phase_path: str = None, - config_role_path: str = None, - task_prompt: str = None, - project_name: str = None, - org_name: str = None, + config_path: str, + config_phase_path: str, + config_role_path: str, + task_prompt: str, + project_name: Union[str, None] = None, + org_name: Union[str, None] = None, model_type: ModelType = ModelType.GPT_3_5_TURBO, - code_path: str = None) -> None: + code_path: Union[str, None] = None): """ Args: @@ -38,6 +49,8 @@ def __init__(self, task_prompt: the user input prompt for software project_name: the user input name for software org_name: the organization name of the human user + model_type: the model type for chatbot + code_path: the path to the code files, if working incrementally """ # load config file @@ -70,6 +83,11 @@ def __init__(self, incremental_develop=check_bool(self.config["incremental_develop"])) self.chat_env = ChatEnv(self.chat_env_config) + if not check_bool(self.config["incremental_develop"]): + if self.code_path: + # TODO: in this case, the code_path is used as the target (instead of the WareHouse directory) + raise RuntimeError("code_path is given, but Phase Config specifies incremental_develop=False. code_path will be ignored.") + # the user input prompt will be self-improved (if set "self_improve": "True" in ChatChainConfig.json) # the self-improvement is done in self.preprocess self.task_prompt_raw = task_prompt @@ -81,7 +99,8 @@ def __init__(self, self.role_prompts[role] = "\n".join(self.config_role[role]) # init log - self.start_time, self.log_filepath = self.get_logfilepath() + self.start_time: str = now() + self.log_filepath = self.get_logfilepath(self.start_time) # init SimplePhase instances # import all used phases in PhaseConfig.json from chatdev.phase @@ -162,23 +181,15 @@ def execute_chain(self): for phase_item in self.chain: self.execute_step(phase_item) - def get_logfilepath(self): + def get_logfilepath(self, start_time: str) -> str: """ - get the log path (under the software path) - Returns: - start_time: time for starting making the software - log_filepath: path to the log - + Returns log_filepath as a str path to the log (under the project's path). """ - start_time = now() filepath = os.path.dirname(__file__) - # root = "/".join(filepath.split("/")[:-1]) root = os.path.dirname(filepath) - # directory = root + "/WareHouse/" directory = os.path.join(root, "WareHouse") - log_filepath = os.path.join(directory, - "{}.log".format("_".join([self.project_name, self.org_name, start_time]))) - return start_time, log_filepath + log_filepath = os.path.join(directory, f"{self.project_name}_{self.org_name}_{start_time}.log") + return log_filepath def pre_processing(self): """ @@ -195,9 +206,9 @@ def pre_processing(self): # logs with error trials are left in WareHouse/ if os.path.isfile(file_path) and not filename.endswith(".py") and not filename.endswith(".log"): os.remove(file_path) - print("{} Removed.".format(file_path)) + print(f"{file_path} Removed.") - software_path = os.path.join(directory, "_".join([self.project_name, self.org_name, self.start_time])) + software_path = os.path.join(directory, f"{self.project_name}_{self.org_name}_{self.start_time}") self.chat_env.set_directory(software_path) # copy config files to software path @@ -207,6 +218,9 @@ def pre_processing(self): # copy code files to software path in incremental_develop mode if check_bool(self.config["incremental_develop"]): + if not self.code_path: + raise RuntimeError("code_path is not given, but working in incremental_develop mode.") + for root, dirs, files in os.walk(self.code_path): relative_path = os.path.relpath(root, self.code_path) target_dir = os.path.join(software_path, 'base', relative_path) @@ -218,7 +232,7 @@ def pre_processing(self): self.chat_env._load_from_hardware(os.path.join(software_path, 'base')) # write task prompt to software - with open(os.path.join(software_path, self.project_name + ".prompt"), "w") as f: + with open(os.path.join(software_path, f"{self.project_name}.prompt"), "w") as f: f.write(self.task_prompt_raw) preprocess_msg = "**[Preprocessing]**\n\n" @@ -306,7 +320,7 @@ def post_processing(self): time.sleep(1) shutil.move(self.log_filepath, - os.path.join(root + "/WareHouse", "_".join([self.project_name, self.org_name, self.start_time]), + os.path.join(root, "WareHouse", f"{self.project_name}_{self.org_name}_{self.start_time}", os.path.basename(self.log_filepath))) # @staticmethod diff --git a/chatdev/utils.py b/chatdev/utils.py index 80b4ff1ad..846980854 100644 --- a/chatdev/utils.py +++ b/chatdev/utils.py @@ -9,7 +9,7 @@ from online_log.app import send_msg -def now(): +def now() -> str: return time.strftime("%Y%m%d%H%M%S", time.localtime()) diff --git a/run.py b/run.py index b92139d68..659a148f1 100644 --- a/run.py +++ b/run.py @@ -15,6 +15,7 @@ import logging import os import sys +from typing import Tuple, List from camel.typing import ModelType @@ -24,15 +25,16 @@ from chatdev.chat_chain import ChatChain -def get_config(company): +def get_config(company: str) -> Tuple[str, str, str]: """ - return configuration json files for ChatChain - user can customize only parts of configuration json files, other files will be left for default + Returns configuration JSON files for ChatChain. + User can modify certain configuration JSON files, and the Default will be used for the other files. + Args: - company: customized configuration name under CompanyConfig/ + company: customized configuration name (subdir name in CompanyConfig/) Returns: - path to three configuration jsons: [config_path, config_phase_path, config_role_path] + tuple of str paths to three configuration JSON filess: [config_path, config_phase_path, config_role_path] """ config_dir = os.path.join(root, "CompanyConfig", company) default_config_dir = os.path.join(root, "CompanyConfig", "Default") @@ -43,7 +45,7 @@ def get_config(company): "RoleConfig.json" ] - config_paths = [] + config_paths: List[str] = [] for config_file in config_files: company_config_path = os.path.join(config_dir, config_file) @@ -51,25 +53,27 @@ def get_config(company): if os.path.exists(company_config_path): config_paths.append(company_config_path) - else: + elif os.path.exists(default_config_path): config_paths.append(default_config_path) - - return tuple(config_paths) + else: + raise FileNotFoundError(f"Cannot find {config_file} in config_dir={config_dir} nor default_config_dir={default_config_dir}") + + return (config_paths[0], config_paths[1], config_paths[2]) parser = argparse.ArgumentParser(description='argparse') parser.add_argument('--config', type=str, default="Default", - help="Name of config, which is used to load configuration under CompanyConfig/") + help="Name of config (subdir in CompanyConfig/)") parser.add_argument('--org', type=str, default="DefaultOrganization", - help="Name of organization, your software will be generated in WareHouse/name_org_timestamp") + help="Name of organization (your software will be generated in WareHouse/name_org_timestamp)") parser.add_argument('--task', type=str, default="Develop a basic Gomoku game.", help="Prompt of software") parser.add_argument('--name', type=str, default="Gomoku", - help="Name of software, your software will be generated in WareHouse/name_org_timestamp") -parser.add_argument('--model', type=str, default="GPT_3_5_TURBO", + help="Name of software (your software will be generated in WareHouse/name_org_timestamp)") +parser.add_argument('--model', type=str, default="GPT_3_5_TURBO", choices=['GPT_3_5_TURBO', 'GPT_4', 'GPT_4_32K'], help="GPT Model, choose from {'GPT_3_5_TURBO','GPT_4','GPT_4_32K'}") -parser.add_argument('--path', type=str, default="", - help="Your file directory, ChatDev will build upon your software in the Incremental mode") +parser.add_argument('--path', type=str, default=None, + help="Your file directory. If given, ChatDev will build upon your software in the Incremental mode.") args = parser.parse_args() # Start ChatDev