diff --git a/agent.py b/agent.py
index 0cdb572b6..b3c6cc208 100644
--- a/agent.py
+++ b/agent.py
@@ -15,73 +15,79 @@ from python.helpers.dirty_json import DirtyJson
 from python.helpers.defer import DeferredTask
 
+
 class AgentContext:
 
-    _contexts: dict[str, 'AgentContext'] = {}
+    _contexts: dict[str, "AgentContext"] = {}
     _counter: int = 0
-    
-    def __init__(self, config: 'AgentConfig', id:str|None = None, agent0: 'Agent|None' = None):
+
+    def __init__(
+        self, config: "AgentConfig", id: str | None = None, agent0: "Agent|None" = None
+    ):
         # build context
         self.id = id or str(uuid.uuid4())
         self.config = config
         self.log = Log.Log()
         self.agent0 = agent0 or Agent(0, self.config, self)
         self.paused = False
-        self.streaming_agent: Agent|None = None
-        self.process: DeferredTask|None = None
+        self.streaming_agent: Agent | None = None
+        self.process: DeferredTask | None = None
         AgentContext._counter += 1
-        self.no = AgentContext._counter        
+        self.no = AgentContext._counter
 
         self._contexts[self.id] = self
 
     @staticmethod
-    def get(id:str):
+    def get(id: str):
        return AgentContext._contexts.get(id, None)
 
     @staticmethod
     def first():
-        if not AgentContext._contexts: return None
+        if not AgentContext._contexts:
+            return None
         return list(AgentContext._contexts.values())[0]
-    
 
     @staticmethod
-    def remove(id:str):
+    def remove(id: str):
         context = AgentContext._contexts.pop(id, None)
-        if context and context.process: context.process.kill()
+        if context and context.process:
+            context.process.kill()
         return context
 
     def reset(self):
-        if self.process: self.process.kill()
+        if self.process:
+            self.process.kill()
         self.log.reset()
         self.agent0 = Agent(0, self.config, self)
         self.streaming_agent = None
-        self.paused = False        
+        self.paused = False
 
-        
     def communicate(self, msg: str, broadcast_level: int = 1):
-        self.paused=False #unpause if paused
-        
+        self.paused = False  # unpause if paused
+
         if self.process and self.process.is_alive():
-            if self.streaming_agent: current_agent = self.streaming_agent
-            else: current_agent = self.agent0
+            if self.streaming_agent:
+                current_agent = self.streaming_agent
+            else:
+                current_agent = self.agent0
 
             # set intervention messages to agent(s):
             intervention_agent = current_agent
-            while intervention_agent and broadcast_level !=0:
+            while intervention_agent and broadcast_level != 0:
                 intervention_agent.intervention_message = msg
                 broadcast_level -= 1
-                intervention_agent = intervention_agent.data.get("superior",None)
+                intervention_agent = intervention_agent.data.get("superior", None)
         else:
             self.process = DeferredTask(self.agent0.message_loop, msg)
 
         return self.process
-        
-                
+
+
 @dataclass
-class AgentConfig: 
+class AgentConfig:
     chat_model: BaseChatModel | BaseLLM
     utility_model: BaseChatModel | BaseLLM
-    embeddings_model:Embeddings
+    embeddings_model: Embeddings
     prompts_subdir: str = ""
     memory_subdir: str = ""
     knowledge_subdir: str = ""
@@ -99,8 +105,14 @@ class AgentConfig:
     code_exec_docker_enabled: bool = True
     code_exec_docker_name: str = "agent-zero-exe"
     code_exec_docker_image: str = "frdel/agent-zero-exe:latest"
-    code_exec_docker_ports: dict[str,int] = field(default_factory=lambda: {"22/tcp": 50022})
-    code_exec_docker_volumes: dict[str, dict[str, str]] = field(default_factory=lambda: {files.get_abs_path("work_dir"): {"bind": "/root", "mode": "rw"}})
+    code_exec_docker_ports: dict[str, int] = field(
+        default_factory=lambda: {"22/tcp": 50022}
+    )
+    code_exec_docker_volumes: dict[str, dict[str, str]] = field(
+        default_factory=lambda: {
+            files.get_abs_path("work_dir"): {"bind": "/root", "mode": "rw"}
+        }
+    )
     code_exec_ssh_enabled: bool = True
     code_exec_ssh_addr: str = "localhost"
     code_exec_ssh_port: int = 50022
@@ -108,20 +120,25 @@ class AgentConfig:
     code_exec_ssh_pass: str = "toor"
     additional: Dict[str, Any] = field(default_factory=dict)
 
+
 # intervention exception class - skips rest of message loop iteration
 class InterventionException(Exception):
     pass
 
-# killer exception class - not forwarded to LLM, cannot be fixed on its own, ends message loop
-class KillerException(Exception):
+
+# repairable exception class - forwarded to LLM, may be fixed on its own
+class RepairableException(Exception):
     pass
 
+
 class Agent:
-
-    def __init__(self, number:int, config: AgentConfig, context: AgentContext|None = None):
-        # agent config
-        self.config = config
+    def __init__(
+        self, number: int, config: AgentConfig, context: AgentContext | None = None
+    ):
+
+        # agent config
+        self.config = config
 
         # agent context
         self.context = context or AgentContext(config)
 
@@ -133,104 +150,162 @@ def __init__(self, number:int, config: AgentConfig, context: AgentContext|None =
         self.history = []
         self.last_message = ""
         self.intervention_message = ""
-        self.rate_limiter = rate_limiter.RateLimiter(self.context.log,max_calls=self.config.rate_limit_requests,max_input_tokens=self.config.rate_limit_input_tokens,max_output_tokens=self.config.rate_limit_output_tokens,window_seconds=self.config.rate_limit_seconds)
-        self.data = {} # free data object all the tools can use
+        self.rate_limiter = rate_limiter.RateLimiter(
+            self.context.log,
+            max_calls=self.config.rate_limit_requests,
+            max_input_tokens=self.config.rate_limit_input_tokens,
+            max_output_tokens=self.config.rate_limit_output_tokens,
+            window_seconds=self.config.rate_limit_seconds,
+        )
+        self.data = {}  # free data object all the tools can use
 
     async def message_loop(self, msg: str):
         try:
-            printer = PrintStyle(italic=True, font_color="#b3ffd9", padding=False)
+            printer = PrintStyle(italic=True, font_color="#b3ffd9", padding=False)
             user_message = self.read_prompt("fw.user_message.md", message=msg)
-            await self.append_message(user_message, human=True) # Append the user's input to the history
+            await self.append_message(
+                user_message, human=True
+            )  # Append the user's input to the history
             memories = await self.fetch_memories(True)
-
-            while True: # let the agent iterate on his thoughts until he stops by using a tool
-                self.context.streaming_agent = self #mark self as current streamer
+
+            while (
+                True
+            ):  # let the agent iterate on his thoughts until he stops by using a tool
+                self.context.streaming_agent = self  # mark self as current streamer
                 agent_response = ""
 
                 try:
-                    system = self.read_prompt("agent.system.md", agent_name=self.agent_name) + "\n\n" + self.read_prompt("agent.tools.md")
+                    system = (
+                        self.read_prompt("agent.system.md", agent_name=self.agent_name)
+                        + "\n\n"
+                        + self.read_prompt("agent.tools.md")
+                    )
                     memories = await self.fetch_memories()
-                    if memories: system+= "\n\n"+memories
+                    if memories:
+                        system += "\n\n" + memories
+
+                    prompt = ChatPromptTemplate.from_messages(
+                        [
+                            SystemMessage(content=system),
+                            MessagesPlaceholder(variable_name="messages"),
+                        ]
+                    )
 
-                    prompt = ChatPromptTemplate.from_messages([
-                        SystemMessage(content=system),
-                        MessagesPlaceholder(variable_name="messages") ])
-                    
                     inputs = {"messages": self.history}
                     chain = prompt | self.config.chat_model
 
                     formatted_inputs = prompt.format(messages=self.history)
-                    tokens = int(len(formatted_inputs)/4)
+                    tokens = int(len(formatted_inputs) / 4)
                     self.rate_limiter.limit_call_and_input(tokens)
-                    
+
                     # output that the agent is starting
-                    PrintStyle(bold=True, font_color="green", padding=True, background_color="white").print(f"{self.agent_name}: Generating:")
-                    log = self.context.log.log(type="agent", heading=f"{self.agent_name}: Generating:")
-
+                    PrintStyle(
+                        bold=True,
+                        font_color="green",
+                        padding=True,
+                        background_color="white",
+                    ).print(f"{self.agent_name}: Generating:")
+                    log = self.context.log.log(
+                        type="agent", heading=f"{self.agent_name}: Generating:"
+                    )
+
                     async for chunk in chain.astream(inputs):
-                        await self.handle_intervention(agent_response) # wait for intervention and handle it, if paused
+                        await self.handle_intervention(
+                            agent_response
+                        )  # wait for intervention and handle it, if paused
+
+                        if isinstance(chunk, str):
+                            content = chunk
+                        elif hasattr(chunk, "content"):
+                            content = str(chunk.content)
+                        else:
+                            content = str(chunk)
 
-                        if isinstance(chunk, str): content = chunk
-                        elif hasattr(chunk, "content"): content = str(chunk.content)
-                        else: content = str(chunk)
-
                         if content:
-                            printer.stream(content) # output the agent response stream
-                            agent_response += content # concatenate stream into the response
+                            printer.stream(content)  # output the agent response stream
+                            agent_response += (
+                                content  # concatenate stream into the response
+                            )
                             self.log_from_stream(agent_response, log)
 
-                    self.rate_limiter.set_output_tokens(int(len(agent_response)/4)) # rough estimation
-
+                    self.rate_limiter.set_output_tokens(
+                        int(len(agent_response) / 4)
+                    )  # rough estimation
+
                     await self.handle_intervention(agent_response)
 
-                    if self.last_message == agent_response: #if assistant_response is the same as last message in history, let him know
-                        await self.append_message(agent_response) # Append the assistant's response to the history
+                    if (
+                        self.last_message == agent_response
+                    ):  # if assistant_response is the same as last message in history, let him know
+                        await self.append_message(
+                            agent_response
+                        )  # Append the assistant's response to the history
                         warning_msg = self.read_prompt("fw.msg_repeat.md")
-                        await self.append_message(warning_msg, human=True) # Append warning message to the history
+                        await self.append_message(
+                            warning_msg, human=True
+                        )  # Append warning message to the history
                         PrintStyle(font_color="orange", padding=True).print(warning_msg)
                         self.context.log.log(type="warning", content=warning_msg)
 
-                    else: #otherwise proceed with tool
-                        await self.append_message(agent_response) # Append the assistant's response to the history
-                        tools_result = await self.process_tools(agent_response) # process tools requested in agent message
-                        if tools_result: #final response of message loop available
-                            return tools_result #break the execution if the task is done
+                    else:  # otherwise proceed with tool
+                        await self.append_message(
+                            agent_response
+                        )  # Append the assistant's response to the history
+                        tools_result = await self.process_tools(
+                            agent_response
+                        )  # process tools requested in agent message
+                        if tools_result:  # final response of message loop available
+                            return (
+                                tools_result  # break the execution if the task is done
+                            )
 
                 except InterventionException as e:
-                    pass # intervention message has been handled in handle_intervention(), proceed with conversation loop
+                    pass  # intervention message has been handled in handle_intervention(), proceed with conversation loop
                 except asyncio.CancelledError as e:
-                    PrintStyle(font_color="white", background_color="red", padding=True).print(f"Context {self.context.id} terminated during message loop")
-                    raise e # process cancelled from outside, kill the loop
-                except KillerException as e:
-                    error_message = errors.format_error(e)
-                    self.context.log.log(type="error", content=error_message)
-                    raise e # kill the loop
-                except Exception as e: # Forward other errors to the LLM, maybe it can fix them
+                    PrintStyle(
+                        font_color="white", background_color="red", padding=True
+                    ).print(f"Context {self.context.id} terminated during message loop")
+                    raise e  # process cancelled from outside, kill the loop
+                except RepairableException as e:  # Forward repairable errors to the LLM, maybe it can fix them
                     error_message = errors.format_error(e)
-                    msg_response = self.read_prompt("fw.error.md", error=error_message) # error message template
+                    msg_response = self.read_prompt(
+                        "fw.error.md", error=error_message
+                    )  # error message template
                     await self.append_message(msg_response, human=True)
                     PrintStyle(font_color="red", padding=True).print(msg_response)
                     self.context.log.log(type="error", content=msg_response)
-                    
+                except Exception as e:  # Other exception kill the loop
+                    error_message = errors.format_error(e)
+                    PrintStyle(font_color="red", padding=True).print(error_message)
+                    self.context.log.log(type="error", content=error_message)
+                    raise e  # kill the loop
+
                 finally:
-                    self.context.streaming_agent = None # unset current streamer
+                    self.context.streaming_agent = None  # unset current streamer
 
-    def read_prompt(self, file:str, **kwargs):
+    def read_prompt(self, file: str, **kwargs):
         content = ""
         if self.config.prompts_subdir:
             try:
-                content = files.read_file(files.get_abs_path(f"./prompts/{self.config.prompts_subdir}/{file}"), **kwargs)
+                content = files.read_file(
+                    files.get_abs_path(
+                        f"./prompts/{self.config.prompts_subdir}/{file}"
+                    ),
+                    **kwargs,
+                )
             except Exception as e:
                 pass
         if not content:
-            content = files.read_file(files.get_abs_path(f"./prompts/default/{file}"), **kwargs)
+            content = files.read_file(
+                files.get_abs_path(f"./prompts/default/{file}"), **kwargs
+            )
         return content
 
-    def get_data(self, field:str):
+    def get_data(self, field: str):
         return self.data.get(field, None)
 
-    def set_data(self, field:str, value):
+    def set_data(self, field: str, value):
         self.data[field] = value
 
     async def append_message(self, msg: str, human: bool = False):
@@ -240,17 +315,21 @@ async def append_message(self, msg: str, human: bool = False):
         else:
             new_message = HumanMessage(content=msg) if human else AIMessage(content=msg)
         self.history.append(new_message)
-        await self.cleanup_history(self.config.msgs_keep_max, self.config.msgs_keep_start, self.config.msgs_keep_end)
-        if message_type=="ai":
+        await self.cleanup_history(
+            self.config.msgs_keep_max,
+            self.config.msgs_keep_start,
+            self.config.msgs_keep_end,
+        )
+        if message_type == "ai":
             self.last_message = msg
 
-    def concat_messages(self,messages):
+    def concat_messages(self, messages):
         return "\n".join([f"{msg.type}: {msg.content}" for msg in messages])
 
-    async def send_adhoc_message(self, system: str, msg: str, output_label:str):
-        prompt = ChatPromptTemplate.from_messages([
-            SystemMessage(content=system),
-            HumanMessage(content=msg)])
+    async def send_adhoc_message(self, system: str, msg: str, output_label: str):
+        prompt = ChatPromptTemplate.from_messages(
+            [SystemMessage(content=system), HumanMessage(content=msg)]
+        )
 
         chain = prompt | self.config.utility_model
         response = ""
@@ -258,40 +337,54 @@ async def send_adhoc_message(self, system: str, msg: str, output_label:str):
         logger = None
 
         if output_label:
-            PrintStyle(bold=True, font_color="orange", padding=True, background_color="white").print(f"{self.agent_name}: {output_label}:")
+            PrintStyle(
+                bold=True, font_color="orange", padding=True, background_color="white"
+            ).print(f"{self.agent_name}: {output_label}:")
             printer = PrintStyle(italic=True, font_color="orange", padding=False)
-            logger = self.context.log.log(type="adhoc", heading=f"{self.agent_name}: {output_label}:")
+            logger = self.context.log.log(
+                type="adhoc", heading=f"{self.agent_name}: {output_label}:"
+            )
 
         formatted_inputs = prompt.format()
-        tokens = int(len(formatted_inputs)/4)
+        tokens = int(len(formatted_inputs) / 4)
         self.rate_limiter.limit_call_and_input(tokens)
-        
+
         async for chunk in chain.astream({}):
-            if self.handle_intervention(): break # wait for intervention and handle it, if paused
+            if self.handle_intervention():
+                break  # wait for intervention and handle it, if paused
 
-            if isinstance(chunk, str): content = chunk
-            elif hasattr(chunk, "content"): content = str(chunk.content)
-            else: content = str(chunk)
+            if isinstance(chunk, str):
+                content = chunk
+            elif hasattr(chunk, "content"):
+                content = str(chunk.content)
+            else:
+                content = str(chunk)
 
-            if printer: printer.stream(content)
-            response+=content
-            if logger: logger.update(content=response)
+            if printer:
+                printer.stream(content)
+            response += content
+            if logger:
+                logger.update(content=response)
 
-        self.rate_limiter.set_output_tokens(int(len(response)/4))
+        self.rate_limiter.set_output_tokens(int(len(response) / 4))
 
         return response
-            
+
     def get_last_message(self):
         if self.history:
             return self.history[-1]
 
-    async def replace_middle_messages(self,middle_messages):
+    async def replace_middle_messages(self, middle_messages):
         cleanup_prompt = self.read_prompt("fw.msg_cleanup.md")
-        summary = await self.send_adhoc_message(system=cleanup_prompt,msg=self.concat_messages(middle_messages), output_label="Mid messages cleanup summary")
+        summary = await self.send_adhoc_message(
+            system=cleanup_prompt,
+            msg=self.concat_messages(middle_messages),
+            output_label="Mid messages cleanup summary",
+        )
         new_human_message = HumanMessage(content=summary)
         return [new_human_message]
 
-    async def cleanup_history(self, max:int, keep_start:int, keep_end:int):
+    async def cleanup_history(self, max: int, keep_start: int, keep_end: int):
         if len(self.history) <= max:
             return self.history
 
@@ -317,14 +410,24 @@ async def cleanup_history(self, max:int, keep_start:int, keep_end:int):
 
         return self.history
 
-    async def handle_intervention(self, progress:str=""):
-        while self.context.paused: await asyncio.sleep(0.1) # wait if paused
-        if self.intervention_message: # if there is an intervention message, but not yet processed
+    async def handle_intervention(self, progress: str = ""):
+        while self.context.paused:
+            await asyncio.sleep(0.1)  # wait if paused
+        if (
+            self.intervention_message
+        ):  # if there is an intervention message, but not yet processed
             msg = self.intervention_message
-            self.intervention_message = "" # reset the intervention message
-            if progress.strip(): await self.append_message(progress) # append the response generated so far
-            user_msg = self.read_prompt("fw.intervention.md", user_message=msg) # format the user intervention template
-            await self.append_message(user_msg,human=True) # append the intervention message
+            self.intervention_message = ""  # reset the intervention message
+            if progress.strip():
+                await self.append_message(
+                    progress
+                )  # append the response generated so far
+            user_msg = self.read_prompt(
+                "fw.intervention.md", user_message=msg
+            )  # format the user intervention template
+            await self.append_message(
+                user_msg, human=True
+            )  # append the intervention message
             raise InterventionException(msg)
 
     async def process_tools(self, msg: str):
@@ -335,30 +438,36 @@ async def process_tools(self, msg: str):
             tool_name = tool_request.get("tool_name", "")
             tool_args = tool_request.get("tool_args", {})
             tool = self.get_tool(tool_name, tool_args, msg)
-                
-            await self.handle_intervention() # wait if paused and handle intervention message if needed
+
+            await self.handle_intervention()  # wait if paused and handle intervention message if needed
             await tool.before_execution(**tool_args)
-            await self.handle_intervention() # wait if paused and handle intervention message if needed
+            await self.handle_intervention()  # wait if paused and handle intervention message if needed
             response = await tool.execute(**tool_args)
-            await self.handle_intervention() # wait if paused and handle intervention message if needed
+            await self.handle_intervention()  # wait if paused and handle intervention message if needed
             await tool.after_execution(response)
-            await self.handle_intervention() # wait if paused and handle intervention message if needed
-            if response.break_loop: return response.message
+            await self.handle_intervention()  # wait if paused and handle intervention message if needed
+            if response.break_loop:
+                return response.message
         else:
             msg = self.read_prompt("fw.msg_misformat.md")
             await self.append_message(msg, human=True)
             PrintStyle(font_color="red", padding=True).print(msg)
-            self.context.log.log(type="error", content=f"{self.agent_name}: Message misformat:")
-        
+            self.context.log.log(
+                type="error", content=f"{self.agent_name}: Message misformat:"
+            )
 
     def get_tool(self, name: str, args: dict, message: str, **kwargs):
-        from python.tools.unknown import Unknown 
+        from python.tools.unknown import Unknown
         from python.helpers.tool import Tool
-        
+
         tool_class = Unknown
-        if files.exists("python/tools",f"{name}.py"):
-            module = importlib.import_module("python.tools." + name) # Import the module
-            class_list = inspect.getmembers(module, inspect.isclass) # Get all functions in the module
+        if files.exists("python/tools", f"{name}.py"):
+            module = importlib.import_module(
+                "python.tools." + name
+            )  # Import the module
+            class_list = inspect.getmembers(
+                module, inspect.isclass
+            )  # Get all functions in the module
 
             for cls in class_list:
                 if cls[1] is not Tool and issubclass(cls[1], Tool):
@@ -367,33 +476,41 @@ def get_tool(self, name: str, args: dict, message: str, **kwargs):
 
         return tool_class(agent=self, name=name, args=args, message=message, **kwargs)
 
-    async def fetch_memories(self,reset_skip=False):
-        if self.config.auto_memory_count<=0: return ""
-        if reset_skip: self.memory_skip_counter = 0
+    async def fetch_memories(self, reset_skip=False):
+        if self.config.auto_memory_count <= 0:
+            return ""
+        if reset_skip:
+            self.memory_skip_counter = 0
 
         if self.memory_skip_counter > 0:
-            self.memory_skip_counter-=1
+            self.memory_skip_counter -= 1
             return ""
         else:
             self.memory_skip_counter = self.config.auto_memory_skip
             from python.tools import memory_tool
+
             messages = self.concat_messages(self.history)
-            memories = memory_tool.search(self,messages)
-            input = {
-                "conversation_history" : messages,
-                "raw_memories": memories
-            }
-            cleanup_prompt = self.read_prompt("msg.memory_cleanup.md").replace("{", "{{")
-            clean_memories = await self.send_adhoc_message(cleanup_prompt,json.dumps(input), output_label="Memory injection")
+            memories = memory_tool.search(self, messages)
+            input = {"conversation_history": messages, "raw_memories": memories}
+            cleanup_prompt = self.read_prompt("msg.memory_cleanup.md").replace(
+                "{", "{{"
+            )
+            clean_memories = await self.send_adhoc_message(
+                cleanup_prompt, json.dumps(input), output_label="Memory injection"
+            )
             return clean_memories
 
     def log_from_stream(self, stream: str, logItem: Log.LogItem):
         try:
-            if len(stream) < 25: return # no reason to try
+            if len(stream) < 25:
+                return  # no reason to try
             response = DirtyJson.parse_string(stream)
-            if isinstance(response, dict): logItem.update(content=stream, kvps=response) #log if result is a dictionary already
+            if isinstance(response, dict):
+                logItem.update(
+                    content=stream, kvps=response
+                )  # log if result is a dictionary already
         except Exception as e:
             pass
 
     def call_extension(self, name: str, **kwargs) -> Any:
-        pass
\ No newline at end of file
+        pass
diff --git a/example.env b/example.env
index 327dac344..1663a38be 100644
--- a/example.env
+++ b/example.env
@@ -1,4 +1,4 @@
-API_KEY_OPENAI=
+API_KEY_OPENAI=sk-hyBlbkFJCJjaYGCbqPTyT3uaYGCbqFBlbkFJCyJCyuPhYGCb
 API_KEY_ANTHROPIC=
 API_KEY_GROQ=
 API_KEY_PERPLEXITY=
@@ -16,4 +16,8 @@ WEB_UI_PORT=50001
 
 TOKENIZERS_PARALLELISM=true
-PYDEVD_DISABLE_FILE_VALIDATION=1
\ No newline at end of file
+PYDEVD_DISABLE_FILE_VALIDATION=1
+
+OLLAMA_BASE_URL="http://127.0.0.1:11434"
+LM_STUDIO_BASE_URL="http://127.0.0.1:1234/v1"
+OPEN_ROUTER_BASE_URL="https://openrouter.ai/api/v1"
\ No newline at end of file
diff --git a/initialize.py b/initialize.py
index fc6bac5d2..e2c92401c 100644
--- a/initialize.py
+++ b/initialize.py
@@ -7,7 +7,7 @@ def initialize():
     chat_llm = models.get_openai_chat(model_name="gpt-4o-mini", temperature=0)
     # chat_llm = models.get_ollama_chat(model_name="gemma2:latest", temperature=0)
     # chat_llm = models.get_lmstudio_chat(model_name="TheBloke/Mistral-7B-Instruct-v0.2-GGUF", temperature=0)
-    # chat_llm = models.get_openrouter(model_name="meta-llama/llama-3-8b-instruct:free")
+    # chat_llm = models.get_openrouter_chat(model_name="nousresearch/hermes-3-llama-3.1-405b")
     # chat_llm = models.get_azure_openai_chat(deployment_name="gpt-4o-mini", temperature=0)
     # chat_llm = models.get_anthropic_chat(model_name="claude-3-5-sonnet-20240620", temperature=0)
     # chat_llm = models.get_google_chat(model_name="gemini-1.5-flash", temperature=0)
diff --git a/models.py b/models.py
index 33af468b3..55498278b 100644
--- a/models.py
+++ b/models.py
@@ -2,11 +2,12 @@
 from dotenv import load_dotenv
 from langchain_openai import ChatOpenAI, OpenAI, OpenAIEmbeddings, AzureChatOpenAI, AzureOpenAIEmbeddings, AzureOpenAI
 from langchain_community.llms.ollama import Ollama
+from langchain_ollama import ChatOllama
 from langchain_community.embeddings import OllamaEmbeddings
 from langchain_anthropic import ChatAnthropic
 from langchain_groq import ChatGroq
 from langchain_huggingface import HuggingFaceEmbeddings
-from langchain_google_genai import ChatGoogleGenerativeAI, HarmBlockThreshold, HarmCategory
+from langchain_google_genai import GoogleGenerativeAI, HarmBlockThreshold, HarmCategory
 from pydantic.v1.types import SecretStr
 
@@ -22,11 +23,12 @@ def get_api_key(service):
 
 
 # Ollama models
-def get_ollama_chat(model_name:str, temperature=DEFAULT_TEMPERATURE, base_url="http://localhost:11434"):
-    return Ollama(model=model_name,temperature=temperature, base_url=base_url)
+def get_ollama_chat(model_name:str, temperature=DEFAULT_TEMPERATURE, base_url=os.getenv("OLLAMA_BASE_URL") or "http://127.0.0.1:11434", num_ctx=8192):
+    return ChatOllama(model=model_name,temperature=temperature, base_url=base_url, num_ctx=num_ctx)
 
-def get_ollama_embedding(model_name:str, temperature=DEFAULT_TEMPERATURE):
-    return OllamaEmbeddings(model=model_name,temperature=temperature)
+def get_ollama_embedding(model_name:str, temperature=DEFAULT_TEMPERATURE, base_url=os.getenv("OLLAMA_BASE_URL") or "http://127.0.0.1:11434"):
+
+    return OllamaEmbeddings(model=model_name,temperature=temperature, base_url=base_url)
 
 # HuggingFace models
@@ -34,63 +36,46 @@ def get_huggingface_embedding(model_name:str):
     return HuggingFaceEmbeddings(model_name=model_name)
 
 # LM Studio and other OpenAI compatible interfaces
-def get_lmstudio_chat(model_name:str, base_url="http://localhost:1234/v1", temperature=DEFAULT_TEMPERATURE):
+def get_lmstudio_chat(model_name:str, temperature=DEFAULT_TEMPERATURE, base_url=os.getenv("LM_STUDIO_BASE_URL") or "http://127.0.0.1:1234/v1"):
     return ChatOpenAI(model_name=model_name, base_url=base_url, temperature=temperature, api_key="none") # type: ignore
 
-def get_lmstudio_embedding(model_name:str, base_url="http://localhost:1234/v1"):
+def get_lmstudio_embedding(model_name:str, base_url=os.getenv("LM_STUDIO_BASE_URL") or "http://127.0.0.1:1234/v1"):
     return OpenAIEmbeddings(model_name=model_name, base_url=base_url) # type: ignore
 
 # Anthropic models
-def get_anthropic_chat(model_name:str, api_key=None, temperature=DEFAULT_TEMPERATURE):
-    api_key = api_key or get_api_key("anthropic")
+def get_anthropic_chat(model_name:str, api_key=get_api_key("anthropic"), temperature=DEFAULT_TEMPERATURE):
     return ChatAnthropic(model_name=model_name, temperature=temperature, api_key=api_key) # type: ignore
 
 # OpenAI models
-def get_openai_chat(model_name:str, api_key=None, temperature=DEFAULT_TEMPERATURE):
-    api_key = api_key or get_api_key("openai")
+def get_openai_chat(model_name:str, api_key=get_api_key("openai"), temperature=DEFAULT_TEMPERATURE):
     return ChatOpenAI(model_name=model_name, temperature=temperature, api_key=api_key) # type: ignore
 
-def get_openai_instruct(model_name:str,api_key=None, temperature=DEFAULT_TEMPERATURE):
-    api_key = api_key or get_api_key("openai")
+def get_openai_instruct(model_name:str, api_key=get_api_key("openai"), temperature=DEFAULT_TEMPERATURE):
     return OpenAI(model=model_name, temperature=temperature, api_key=api_key) # type: ignore
 
-def get_openai_embedding(model_name:str, api_key=None):
-    api_key = api_key or get_api_key("openai")
+def get_openai_embedding(model_name:str, api_key=get_api_key("openai")):
     return OpenAIEmbeddings(model=model_name, api_key=api_key) # type: ignore
 
-def get_azure_openai_chat(deployment_name:str, api_key=None, temperature=DEFAULT_TEMPERATURE, azure_endpoint=None):
-    api_key = api_key or get_api_key("openai_azure")
-    azure_endpoint = azure_endpoint or os.getenv("OPENAI_AZURE_ENDPOINT")
+def get_azure_openai_chat(deployment_name:str, api_key=get_api_key("openai_azure"), temperature=DEFAULT_TEMPERATURE, azure_endpoint=os.getenv("OPENAI_AZURE_ENDPOINT")):
     return AzureChatOpenAI(deployment_name=deployment_name, temperature=temperature, api_key=api_key, azure_endpoint=azure_endpoint) # type: ignore
 
-def get_azure_openai_instruct(deployment_name:str, api_key=None, temperature=DEFAULT_TEMPERATURE, azure_endpoint=None):
-    api_key = api_key or get_api_key("openai_azure")
-    azure_endpoint = azure_endpoint or os.getenv("OPENAI_AZURE_ENDPOINT")
+def get_azure_openai_instruct(deployment_name:str, api_key=get_api_key("openai_azure"), temperature=DEFAULT_TEMPERATURE, azure_endpoint=os.getenv("OPENAI_AZURE_ENDPOINT")):
     return AzureOpenAI(deployment_name=deployment_name, temperature=temperature, api_key=api_key, azure_endpoint=azure_endpoint) # type: ignore
 
-def get_azure_openai_embedding(deployment_name:str, api_key=None, azure_endpoint=None):
-    api_key = api_key or get_api_key("openai_azure")
-    azure_endpoint = azure_endpoint or os.getenv("OPENAI_AZURE_ENDPOINT")
+def get_azure_openai_embedding(deployment_name:str, api_key=get_api_key("openai_azure"), azure_endpoint=os.getenv("OPENAI_AZURE_ENDPOINT")):
     return AzureOpenAIEmbeddings(deployment_name=deployment_name, api_key=api_key, azure_endpoint=azure_endpoint) # type: ignore
 
 # Google models
-def get_google_chat(model_name:str, api_key=None, temperature=DEFAULT_TEMPERATURE):
-    api_key = api_key or get_api_key("google")
-    return ChatGoogleGenerativeAI(model=model_name, temperature=temperature, google_api_key=api_key, safety_settings={HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE }) # type: ignore
+def get_google_chat(model_name:str, api_key=get_api_key("google"), temperature=DEFAULT_TEMPERATURE):
+    return GoogleGenerativeAI(model=model_name, temperature=temperature, google_api_key=api_key, safety_settings={HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE }) # type: ignore
 
 # Groq models
-def get_groq_chat(model_name:str, api_key=None, temperature=DEFAULT_TEMPERATURE):
-    api_key = api_key or get_api_key("groq")
+def get_groq_chat(model_name:str, api_key=get_api_key("groq"), temperature=DEFAULT_TEMPERATURE):
     return ChatGroq(model_name=model_name, temperature=temperature, api_key=api_key) # type: ignore
 
 # OpenRouter models
-def get_openrouter(model_name: str="meta-llama/llama-3.1-8b-instruct:free", api_key=None, temperature=DEFAULT_TEMPERATURE):
-    api_key = api_key or get_api_key("openrouter")
-    return ChatOpenAI(api_key=api_key, base_url="https://openrouter.ai/api/v1", model=model_name, temperature=temperature) # type: ignore
-
-def get_embedding_hf(model_name="sentence-transformers/all-MiniLM-L6-v2"):
-    return HuggingFaceEmbeddings(model_name=model_name)
-
-def get_embedding_openai(api_key=None):
-    api_key = api_key or get_api_key("openai")
-    return OpenAIEmbeddings(api_key=api_key) #type: ignore
+def get_openrouter_chat(model_name: str, api_key=get_api_key("openrouter"), temperature=DEFAULT_TEMPERATURE, base_url=os.getenv("OPEN_ROUTER_BASE_URL") or "https://openrouter.ai/api/v1"):
+    return ChatOpenAI(api_key=api_key, model=model_name, temperature=temperature, base_url=base_url) # type: ignore
+
+def get_openrouter_embedding(model_name: str, api_key=get_api_key("openrouter"), base_url=os.getenv("OPEN_ROUTER_BASE_URL") or "https://openrouter.ai/api/v1"):
+    return OpenAIEmbeddings(model=model_name, api_key=api_key, base_url=base_url) # type: ignore
\ No newline at end of file
diff --git a/python/helpers/defer.py b/python/helpers/defer.py
index 274e00330..8ef474493 100644
--- a/python/helpers/defer.py
+++ b/python/helpers/defer.py
@@ -4,20 +4,21 @@ class DeferredTask:
 
     def __init__(self, func, *args, **kwargs):
-        self._loop = asyncio.new_event_loop()
+        self._loop: asyncio.AbstractEventLoop = None # type: ignore
         self._task = None
         self._future = Future()
-        self._task_initialized = threading.Event() # Event to signal task initialization
+        self._task_initialized = threading.Event()
         self._start_task(func, *args, **kwargs)
 
     def _start_task(self, func, *args, **kwargs):
-        def run_in_thread(loop, func, args, kwargs):
-            asyncio.set_event_loop(loop)
-            self._task = loop.create_task(self._run(func, *args, **kwargs))
-            self._task_initialized.set() # Signal that the task has been initialized
-            loop.run_forever()
+        def run_in_thread():
+            self._loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(self._loop)
+            self._task = self._loop.create_task(self._run(func, *args, **kwargs))
+            self._task_initialized.set()
+            self._loop.run_forever()
 
-        self._thread = threading.Thread(target=run_in_thread, args=(self._loop, func, args, kwargs))
+        self._thread = threading.Thread(target=run_in_thread)
         self._thread.start()
 
     async def _run(self, func, *args, **kwargs):
@@ -27,13 +28,16 @@ async def _run(self, func, *args, **kwargs):
         except Exception as e:
             self._future.set_exception(e)
         finally:
-            self._loop.call_soon_threadsafe(self._loop.stop)
+            self._loop.call_soon_threadsafe(self._cleanup)
+
+    def _cleanup(self):
+        self._loop.stop()
 
     def is_ready(self):
         return self._future.done()
 
     async def result(self, timeout=None):
-        if not self._task_initialized.wait(timeout): # Wait until the task is initialized
+        if not self._task_initialized.wait(timeout):
             raise RuntimeError("Task was not initialized properly.")
 
         try:
@@ -42,7 +46,7 @@ async def result(self, timeout=None):
             raise TimeoutError("The task did not complete within the specified timeout.")
 
     def result_sync(self, timeout=None):
-        if not self._task_initialized.wait(timeout): # Wait until the task is initialized
+        if not self._task_initialized.wait(timeout):
             raise RuntimeError("Task was not initialized properly.")
 
         try:
@@ -58,8 +62,9 @@ def is_alive(self):
         return self._thread.is_alive() and not self._future.done()
 
     def __del__(self):
-        if self._loop.is_running():
-            self._loop.call_soon_threadsafe(self._loop.stop)
-        if self._thread.is_alive():
+        if self._loop and self._loop.is_running():
+            self._loop.call_soon_threadsafe(self._cleanup)
+        if self._thread and self._thread.is_alive():
             self._thread.join()
-        self._loop.close()
\ No newline at end of file
+        if self._loop:
+            self._loop.close()
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 63d179149..b90ea7ad0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,6 +2,7 @@ ansio==0.0.1
 python-dotenv==1.0.1
 langchain-groq==0.1.6
 langchain-huggingface==0.0.3
+langchain-ollama==0.1.3
 langchain-openai==0.1.15
 langchain-community==0.2.7
 langchain-anthropic==0.1.19