diff --git a/.github/workflows/contrib-tests.yml b/.github/workflows/contrib-tests.yml index f13bfdbb985c..e257dc7f8ec5 100644 --- a/.github/workflows/contrib-tests.yml +++ b/.github/workflows/contrib-tests.yml @@ -87,6 +87,10 @@ jobs: --health-retries 5 ports: - 5432:5432 + mongodb: + image: mongodb/mongodb-atlas-local:latest + ports: + - 27017:27017 steps: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} @@ -104,6 +108,9 @@ jobs: - name: Install pgvector when on linux run: | pip install -e .[retrievechat-pgvector] + - name: Install mongodb when on linux + run: | + pip install -e .[retrievechat-mongodb] - name: Install unstructured when python-version is 3.9 and on linux if: matrix.python-version == '3.9' run: | diff --git a/autogen/agentchat/chat.py b/autogen/agentchat/chat.py index 97411e9fc004..d07b4d15cb62 100644 --- a/autogen/agentchat/chat.py +++ b/autogen/agentchat/chat.py @@ -107,6 +107,15 @@ def __find_async_chat_order(chat_ids: Set[int], prerequisites: List[Prerequisite return chat_order +def _post_process_carryover_item(carryover_item): + if isinstance(carryover_item, str): + return carryover_item + elif isinstance(carryover_item, dict) and "content" in carryover_item: + return str(carryover_item["content"]) + else: + return str(carryover_item) + + def __post_carryover_processing(chat_info: Dict[str, Any]) -> None: iostream = IOStream.get_default() @@ -116,7 +125,7 @@ def __post_carryover_processing(chat_info: Dict[str, Any]) -> None: UserWarning, ) print_carryover = ( - ("\n").join([t for t in chat_info["carryover"]]) + ("\n").join([_post_process_carryover_item(t) for t in chat_info["carryover"]]) if isinstance(chat_info["carryover"], list) else chat_info["carryover"] ) @@ -153,7 +162,7 @@ def initiate_chats(chat_queue: List[Dict[str, Any]]) -> List[ChatResult]: For example: - `"sender"` - the sender agent. - `"recipient"` - the recipient agent. - - `"clear_history" (bool) - whether to clear the chat history with the agent. + - `"clear_history"` (bool) - whether to clear the chat history with the agent. Default is True. - `"silent"` (bool or None) - (Experimental) whether to print the messages in this conversation. Default is False. diff --git a/autogen/agentchat/contrib/vectordb/base.py b/autogen/agentchat/contrib/vectordb/base.py index 20b6376d01d9..d7d49d6200ca 100644 --- a/autogen/agentchat/contrib/vectordb/base.py +++ b/autogen/agentchat/contrib/vectordb/base.py @@ -186,7 +186,8 @@ def get_docs_by_ids( ids: List[ItemID] | A list of document ids. If None, will return all the documents. Default is None. collection_name: str | The name of the collection. Default is None. include: List[str] | The fields to include. Default is None. - If None, will include ["metadatas", "documents"], ids will always be included. + If None, will include ["metadatas", "documents"], ids will always be included. This may differ + depending on the implementation. kwargs: dict | Additional keyword arguments. Returns: @@ -200,7 +201,7 @@ class VectorDBFactory: Factory class for creating vector databases. """ - PREDEFINED_VECTOR_DB = ["chroma", "pgvector", "qdrant"] + PREDEFINED_VECTOR_DB = ["chroma", "pgvector", "mongodb", "qdrant"] @staticmethod def create_vector_db(db_type: str, **kwargs) -> VectorDB: @@ -222,6 +223,10 @@ def create_vector_db(db_type: str, **kwargs) -> VectorDB: from .pgvectordb import PGVectorDB return PGVectorDB(**kwargs) + if db_type.lower() in ["mdb", "mongodb", "atlas"]: + from .mongodb import MongoDBAtlasVectorDB + + return MongoDBAtlasVectorDB(**kwargs) if db_type.lower() in ["qdrant", "qdrantdb"]: from .qdrant import QdrantVectorDB diff --git a/autogen/agentchat/contrib/vectordb/mongodb.py b/autogen/agentchat/contrib/vectordb/mongodb.py new file mode 100644 index 000000000000..2e0580fe826b --- /dev/null +++ b/autogen/agentchat/contrib/vectordb/mongodb.py @@ -0,0 +1,553 @@ +from copy import deepcopy +from time import monotonic, sleep +from typing import Any, Callable, Dict, Iterable, List, Literal, Mapping, Set, Tuple, Union + +import numpy as np +from pymongo import MongoClient, UpdateOne, errors +from pymongo.collection import Collection +from pymongo.driver_info import DriverInfo +from pymongo.operations import SearchIndexModel +from sentence_transformers import SentenceTransformer + +from .base import Document, ItemID, QueryResults, VectorDB +from .utils import get_logger + +logger = get_logger(__name__) + +DEFAULT_INSERT_BATCH_SIZE = 100_000 +_SAMPLE_SENTENCE = ["The weather is lovely today in paradise."] +_DELAY = 0.5 + + +def with_id_rename(docs: Iterable) -> List[Dict[str, Any]]: + """Utility changes _id field from Collection into id for Document.""" + return [{**{k: v for k, v in d.items() if k != "_id"}, "id": d["_id"]} for d in docs] + + +class MongoDBAtlasVectorDB(VectorDB): + """ + A Collection object for MongoDB. + """ + + def __init__( + self, + connection_string: str = "", + database_name: str = "vector_db", + embedding_function: Callable = SentenceTransformer("all-MiniLM-L6-v2").encode, + collection_name: str = None, + index_name: str = "vector_index", + overwrite: bool = False, + wait_until_index_ready: float = None, + wait_until_document_ready: float = None, + ): + """ + Initialize the vector database. + + Args: + connection_string: str | The MongoDB connection string to connect to. Default is ''. + database_name: str | The name of the database. Default is 'vector_db'. + embedding_function: Callable | The embedding function used to generate the vector representation. + collection_name: str | The name of the collection to create for this vector database + Defaults to None + index_name: str | Index name for the vector database, defaults to 'vector_index' + overwrite: bool = False + wait_until_index_ready: float | None | Blocking call to wait until the + database indexes are ready. None, the default, means no wait. + wait_until_document_ready: float | None | Blocking call to wait until the + database indexes are ready. None, the default, means no wait. + """ + self.embedding_function = embedding_function + self.index_name = index_name + self._wait_until_index_ready = wait_until_index_ready + self._wait_until_document_ready = wait_until_document_ready + + # This will get the model dimension size by computing the embeddings dimensions + self.dimensions = self._get_embedding_size() + + try: + self.client = MongoClient(connection_string, driver=DriverInfo(name="autogen")) + self.client.admin.command("ping") + logger.debug("Successfully created MongoClient") + except errors.ServerSelectionTimeoutError as err: + raise ConnectionError("Could not connect to MongoDB server") from err + + self.db = self.client[database_name] + logger.debug(f"Atlas Database name: {self.db.name}") + if collection_name: + self.active_collection = self.create_collection(collection_name, overwrite) + else: + self.active_collection = None + + def _is_index_ready(self, collection: Collection, index_name: str): + """Check for the index name in the list of available search indexes to see if the + specified index is of status READY + + Args: + collection (Collection): MongoDB Collection to for the search indexes + index_name (str): Vector Search Index name + + Returns: + bool : True if the index is present and READY false otherwise + """ + for index in collection.list_search_indexes(index_name): + if index["type"] == "vectorSearch" and index["status"] == "READY": + return True + return False + + def _wait_for_index(self, collection: Collection, index_name: str, action: str = "create"): + """Waits for the index action to be completed. Otherwise throws a TimeoutError. + + Timeout set on instantiation. + action: "create" or "delete" + """ + assert action in ["create", "delete"], f"{action=} must be create or delete." + start = monotonic() + while monotonic() - start < self._wait_until_index_ready: + if action == "create" and self._is_index_ready(collection, index_name): + return + elif action == "delete" and len(list(collection.list_search_indexes())) == 0: + return + sleep(_DELAY) + + raise TimeoutError(f"Index {self.index_name} is not ready!") + + def _wait_for_document(self, collection: Collection, index_name: str, doc: Document): + start = monotonic() + while monotonic() - start < self._wait_until_document_ready: + query_result = _vector_search( + embedding_vector=np.array(self.embedding_function(doc["content"])).tolist(), + n_results=1, + collection=collection, + index_name=index_name, + ) + if query_result and query_result[0][0]["_id"] == doc["id"]: + return + sleep(_DELAY) + + raise TimeoutError(f"Document {self.index_name} is not ready!") + + def _get_embedding_size(self): + return len(self.embedding_function(_SAMPLE_SENTENCE)[0]) + + def list_collections(self): + """ + List the collections in the vector database. + + Returns: + List[str] | The list of collections. + """ + return self.db.list_collection_names() + + def create_collection( + self, + collection_name: str, + overwrite: bool = False, + get_or_create: bool = True, + ) -> Collection: + """ + Create a collection in the vector database and create a vector search index in the collection. + + Args: + collection_name: str | The name of the collection. + overwrite: bool | Whether to overwrite the collection if it exists. Default is False. + get_or_create: bool | Whether to get or create the collection. Default is True + """ + if overwrite: + self.delete_collection(collection_name) + + if collection_name not in self.db.list_collection_names(): + # Create a new collection + coll = self.db.create_collection(collection_name) + self.create_index_if_not_exists(index_name=self.index_name, collection=coll) + return coll + + if get_or_create: + # The collection already exists, return it. + coll = self.db[collection_name] + self.create_index_if_not_exists(index_name=self.index_name, collection=coll) + return coll + else: + # get_or_create is False and the collection already exists, raise an error. + raise ValueError(f"Collection {collection_name} already exists.") + + def create_index_if_not_exists(self, index_name: str = "vector_index", collection: Collection = None) -> None: + """ + Creates a vector search index on the specified collection in MongoDB. + + Args: + MONGODB_INDEX (str, optional): The name of the vector search index to create. Defaults to "vector_search_index". + collection (Collection, optional): The MongoDB collection to create the index on. Defaults to None. + """ + if not self._is_index_ready(collection, index_name): + self.create_vector_search_index(collection, index_name) + + def get_collection(self, collection_name: str = None) -> Collection: + """ + Get the collection from the vector database. + + Args: + collection_name: str | The name of the collection. Default is None. If None, return the + current active collection. + + Returns: + Collection | The collection object. + """ + if collection_name is None: + if self.active_collection is None: + raise ValueError("No collection is specified.") + else: + logger.debug( + f"No collection is specified. Using current active collection {self.active_collection.name}." + ) + else: + self.active_collection = self.db[collection_name] + + return self.active_collection + + def delete_collection(self, collection_name: str) -> None: + """ + Delete the collection from the vector database. + + Args: + collection_name: str | The name of the collection. + """ + for index in self.db[collection_name].list_search_indexes(): + self.db[collection_name].drop_search_index(index["name"]) + if self._wait_until_index_ready: + self._wait_for_index(self.db[collection_name], index["name"], "delete") + return self.db[collection_name].drop() + + def create_vector_search_index( + self, + collection: Collection, + index_name: Union[str, None] = "vector_index", + similarity: Literal["euclidean", "cosine", "dotProduct"] = "cosine", + ) -> None: + """Create a vector search index in the collection. + + Args: + collection: An existing Collection in the Atlas Database. + index_name: Vector Search Index name. + similarity: Algorithm used for measuring vector similarity. + kwargs: Additional keyword arguments. + + Returns: + None + """ + search_index_model = SearchIndexModel( + definition={ + "fields": [ + { + "type": "vector", + "numDimensions": self.dimensions, + "path": "embedding", + "similarity": similarity, + }, + ] + }, + name=index_name, + type="vectorSearch", + ) + # Create the search index + try: + collection.create_search_index(model=search_index_model) + if self._wait_until_index_ready: + self._wait_for_index(collection, index_name, "create") + logger.debug(f"Search index {index_name} created successfully.") + except Exception as e: + logger.error( + f"Error creating search index: {e}. \n" + f"Your client must be connected to an Atlas cluster. " + f"You may have to manually create a Collection and Search Index " + f"if you are on a free/shared cluster." + ) + raise e + + def insert_docs( + self, + docs: List[Document], + collection_name: str = None, + upsert: bool = False, + batch_size=DEFAULT_INSERT_BATCH_SIZE, + **kwargs, + ) -> None: + """Insert Documents and Vector Embeddings into the collection of the vector database. + + For large numbers of Documents, insertion is performed in batches. + + Args: + docs: List[Document] | A list of documents. Each document is a TypedDict `Document`. + collection_name: str | The name of the collection. Default is None. + upsert: bool | Whether to update the document if it exists. Default is False. + batch_size: Number of documents to be inserted in each batch + """ + if not docs: + logger.info("No documents to insert.") + return + + collection = self.get_collection(collection_name) + if upsert: + self.update_docs(docs, collection.name, upsert=True) + else: + # Sanity checking the first document + if docs[0].get("content") is None: + raise ValueError("The document content is required.") + if docs[0].get("id") is None: + raise ValueError("The document id is required.") + + input_ids = set() + result_ids = set() + id_batch = [] + text_batch = [] + metadata_batch = [] + size = 0 + i = 0 + for doc in docs: + id = doc["id"] + text = doc["content"] + metadata = doc.get("metadata", {}) + id_batch.append(id) + text_batch.append(text) + metadata_batch.append(metadata) + id_size = 1 if isinstance(id, int) else len(id) + size += len(text) + len(metadata) + id_size + if (i + 1) % batch_size == 0 or size >= 47_000_000: + result_ids.update(self._insert_batch(collection, text_batch, metadata_batch, id_batch)) + input_ids.update(id_batch) + id_batch = [] + text_batch = [] + metadata_batch = [] + size = 0 + i += 1 + if text_batch: + result_ids.update(self._insert_batch(collection, text_batch, metadata_batch, id_batch)) # type: ignore + input_ids.update(id_batch) + + if result_ids != input_ids: + logger.warning( + "Possible data corruption. " + "input_ids not in result_ids: {in_diff}.\n" + "result_ids not in input_ids: {out_diff}".format( + in_diff=input_ids.difference(result_ids), out_diff=result_ids.difference(input_ids) + ) + ) + if self._wait_until_document_ready and docs: + self._wait_for_document(collection, self.index_name, docs[-1]) + + def _insert_batch( + self, collection: Collection, texts: List[str], metadatas: List[Mapping[str, Any]], ids: List[ItemID] + ) -> Set[ItemID]: + """Compute embeddings for and insert a batch of Documents into the Collection. + + For performance reasons, we chose to call self.embedding_function just once, + with the hopefully small tradeoff of having recreating Document dicts. + + Args: + collection: MongoDB Collection + texts: List of the main contents of each document + metadatas: List of metadata mappings + ids: List of ids. Note that these are stored as _id in Collection. + + Returns: + List of ids inserted. + """ + n_texts = len(texts) + if n_texts == 0: + return [] + # Embed and create the documents + embeddings = self.embedding_function(texts).tolist() + assert ( + len(embeddings) == n_texts + ), f"The number of embeddings produced by self.embedding_function ({len(embeddings)} does not match the number of texts provided to it ({n_texts})." + to_insert = [ + {"_id": i, "content": t, "metadata": m, "embedding": e} + for i, t, m, e in zip(ids, texts, metadatas, embeddings) + ] + # insert the documents in MongoDB Atlas + insert_result = collection.insert_many(to_insert) # type: ignore + return insert_result.inserted_ids # TODO Remove this. Replace by log like update_docs + + def update_docs(self, docs: List[Document], collection_name: str = None, **kwargs: Any) -> None: + """Update documents, including their embeddings, in the Collection. + + Optionally allow upsert as kwarg. + + Uses deepcopy to avoid changing docs. + + Args: + docs: List[Document] | A list of documents. + collection_name: str | The name of the collection. Default is None. + kwargs: Any | Use upsert=True` to insert documents whose ids are not present in collection. + """ + + n_docs = len(docs) + logger.info(f"Preparing to embed and update {n_docs=}") + # Compute the embeddings + embeddings: list[list[float]] = self.embedding_function([doc["content"] for doc in docs]).tolist() + # Prepare the updates + all_updates = [] + for i in range(n_docs): + doc = deepcopy(docs[i]) + doc["embedding"] = embeddings[i] + doc["_id"] = doc.pop("id") + + all_updates.append(UpdateOne({"_id": doc["_id"]}, {"$set": doc}, upsert=kwargs.get("upsert", False))) + # Perform update in bulk + collection = self.get_collection(collection_name) + result = collection.bulk_write(all_updates) + + if self._wait_until_document_ready and docs: + self._wait_for_document(collection, self.index_name, docs[-1]) + + # Log a result summary + logger.info( + "Matched: %s, Modified: %s, Upserted: %s", + result.matched_count, + result.modified_count, + result.upserted_count, + ) + + def delete_docs(self, ids: List[ItemID], collection_name: str = None, **kwargs): + """ + Delete documents from the collection of the vector database. + + Args: + ids: List[ItemID] | A list of document ids. Each id is a typed `ItemID`. + collection_name: str | The name of the collection. Default is None. + """ + collection = self.get_collection(collection_name) + return collection.delete_many({"_id": {"$in": ids}}) + + def get_docs_by_ids( + self, ids: List[ItemID] = None, collection_name: str = None, include: List[str] = None, **kwargs + ) -> List[Document]: + """ + Retrieve documents from the collection of the vector database based on the ids. + + Args: + ids: List[ItemID] | A list of document ids. If None, will return all the documents. Default is None. + collection_name: str | The name of the collection. Default is None. + include: List[str] | The fields to include. + If None, will include ["metadata", "content"], ids will always be included. + Basically, use include to choose whether to include embedding and metadata + kwargs: dict | Additional keyword arguments. + + Returns: + List[Document] | The results. + """ + if include is None: + include_fields = {"_id": 1, "content": 1, "metadata": 1} + else: + include_fields = {k: 1 for k in set(include).union({"_id"})} + collection = self.get_collection(collection_name) + if ids is not None: + docs = collection.find({"_id": {"$in": ids}}, include_fields) + # Return with _id field from Collection into id for Document + return with_id_rename(docs) + else: + docs = collection.find({}, include_fields) + # Return with _id field from Collection into id for Document + return with_id_rename(docs) + + def retrieve_docs( + self, + queries: List[str], + collection_name: str = None, + n_results: int = 10, + distance_threshold: float = -1, + **kwargs, + ) -> QueryResults: + """ + Retrieve documents from the collection of the vector database based on the queries. + + Args: + queries: List[str] | A list of queries. Each query is a string. + collection_name: str | The name of the collection. Default is None. + n_results: int | The number of relevant documents to return. Default is 10. + distance_threshold: float | The threshold for the distance score, only distance smaller than it will be + returned. Don't filter with it if < 0. Default is -1. + kwargs: Dict | Additional keyword arguments. Ones of importance follow: + oversampling_factor: int | This times n_results is 'ef' in the HNSW algorithm. + It determines the number of nearest neighbor candidates to consider during the search phase. + A higher value leads to more accuracy, but is slower. Default is 10 + + Returns: + QueryResults | For each query string, a list of nearest documents and their scores. + """ + collection = self.get_collection(collection_name) + # Trivial case of an empty collection + if collection.count_documents({}) == 0: + return [] + + logger.debug(f"Using index: {self.index_name}") + results = [] + for query_text in queries: + # Compute embedding vector from semantic query + logger.debug(f"Query: {query_text}") + query_vector = np.array(self.embedding_function([query_text])).tolist()[0] + # Find documents with similar vectors using the specified index + query_result = _vector_search( + query_vector, + n_results, + collection, + self.index_name, + distance_threshold, + **kwargs, + oversampling_factor=kwargs.get("oversampling_factor", 10), + ) + # Change each _id key to id. with_id_rename, but with (doc, score) tuples + results.append( + [({**{k: v for k, v in d[0].items() if k != "_id"}, "id": d[0]["_id"]}, d[1]) for d in query_result] + ) + return results + + +def _vector_search( + embedding_vector: List[float], + n_results: int, + collection: Collection, + index_name: str, + distance_threshold: float = -1.0, + oversampling_factor=10, + include_embedding=False, +) -> List[Tuple[Dict, float]]: + """Core $vectorSearch Aggregation pipeline. + + Args: + embedding_vector: Embedding vector of semantic query + n_results: Number of documents to return. Defaults to 4. + collection: MongoDB Collection with vector index + index_name: Name of the vector index + distance_threshold: Only distance measures smaller than this will be returned. + Don't filter with it if 1 < x < 0. Default is -1. + oversampling_factor: int | This times n_results is 'ef' in the HNSW algorithm. + It determines the number of nearest neighbor candidates to consider during the search phase. + A higher value leads to more accuracy, but is slower. Default = 10 + + Returns: + List of tuples of length n_results from Collection. + Each tuple contains a document dict and a score. + """ + + pipeline = [ + { + "$vectorSearch": { + "index": index_name, + "limit": n_results, + "numCandidates": n_results * oversampling_factor, + "queryVector": embedding_vector, + "path": "embedding", + } + }, + {"$set": {"score": {"$meta": "vectorSearchScore"}}}, + ] + if distance_threshold >= 0.0: + similarity_threshold = 1.0 - distance_threshold + pipeline.append({"$match": {"score": {"$gte": similarity_threshold}}}) + + if not include_embedding: + pipeline.append({"$project": {"embedding": 0}}) + + logger.debug("pipeline: %s", pipeline) + agg = collection.aggregate(pipeline) + return [(doc, doc.pop("score")) for doc in agg] diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py index 81c666de022c..674c8b9248d7 100644 --- a/autogen/agentchat/conversable_agent.py +++ b/autogen/agentchat/conversable_agent.py @@ -11,6 +11,7 @@ from openai import BadRequestError +from autogen.agentchat.chat import _post_process_carryover_item from autogen.exception_utils import InvalidCarryOverType, SenderRequired from .._pydantic import model_dump @@ -1722,7 +1723,7 @@ def check_termination_and_human_reply( sender_name = "the sender" if sender is None else sender.name if self.human_input_mode == "ALWAYS": reply = self.get_human_input( - f"Provide feedback to {sender_name}. Press enter to skip and use auto-reply, or type 'exit' to end the conversation: " + f"Replying as {self.name}. Provide feedback to {sender_name}. Press enter to skip and use auto-reply, or type 'exit' to end the conversation: " ) no_human_input_msg = "NO HUMAN INPUT RECEIVED." if not reply else "" # if the human input is empty, and the message is a termination message, then we will terminate the conversation @@ -1835,7 +1836,7 @@ async def a_check_termination_and_human_reply( sender_name = "the sender" if sender is None else sender.name if self.human_input_mode == "ALWAYS": reply = await self.a_get_human_input( - f"Provide feedback to {sender_name}. Press enter to skip and use auto-reply, or type 'exit' to end the conversation: " + f"Replying as {self.name}. Provide feedback to {sender_name}. Press enter to skip and use auto-reply, or type 'exit' to end the conversation: " ) no_human_input_msg = "NO HUMAN INPUT RECEIVED." if not reply else "" # if the human input is empty, and the message is a termination message, then we will terminate the conversation @@ -2364,7 +2365,7 @@ def _process_carryover(self, content: str, kwargs: dict) -> str: if isinstance(kwargs["carryover"], str): content += "\nContext: \n" + kwargs["carryover"] elif isinstance(kwargs["carryover"], list): - content += "\nContext: \n" + ("\n").join([t for t in kwargs["carryover"]]) + content += "\nContext: \n" + ("\n").join([_post_process_carryover_item(t) for t in kwargs["carryover"]]) else: raise InvalidCarryOverType( "Carryover should be a string or a list of strings. Not adding carryover to the message." diff --git a/autogen/oai/anthropic.py b/autogen/oai/anthropic.py index 62078d42631d..8ed6f909e6bc 100644 --- a/autogen/oai/anthropic.py +++ b/autogen/oai/anthropic.py @@ -99,17 +99,11 @@ def __init__(self, **kwargs: Any): if not self._aws_secret_key: self._aws_secret_key = os.getenv("AWS_SECRET_KEY") - if not self._aws_session_token: - self._aws_session_token = os.getenv("AWS_SESSION_TOKEN") - if not self._aws_region: self._aws_region = os.getenv("AWS_REGION") if self._api_key is None and ( - self._aws_access_key is None - or self._aws_secret_key is None - or self._aws_session_token is None - or self._aws_region is None + self._aws_access_key is None or self._aws_secret_key is None or self._aws_region is None ): raise ValueError("API key or AWS credentials are required to use the Anthropic API.") diff --git a/autogen/oai/client.py b/autogen/oai/client.py index 4e9d794a1f75..4cc7c697f738 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -454,12 +454,20 @@ def _configure_azure_openai(self, config: Dict[str, Any], openai_config: Dict[st azure.identity.DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default" ) + def _configure_openai_config_for_bedrock(self, config: Dict[str, Any], openai_config: Dict[str, Any]) -> None: + """Update openai_config with AWS credentials from config.""" + required_keys = ["aws_access_key", "aws_secret_key", "aws_region"] + + for key in required_keys: + if key in config: + openai_config[key] = config[key] + def _register_default_client(self, config: Dict[str, Any], openai_config: Dict[str, Any]) -> None: """Create a client with the given config to override openai_config, after removing extra kwargs. For Azure models/deployment names there's a convenience modification of model removing dots in - the it's value (Azure deploment names can't have dots). I.e. if you have Azure deployment name + the it's value (Azure deployment names can't have dots). I.e. if you have Azure deployment name "gpt-35-turbo" and define model "gpt-3.5-turbo" in the config the function will remove the dot from the name and create a client that connects to "gpt-35-turbo" Azure deployment. """ @@ -485,6 +493,8 @@ def _register_default_client(self, config: Dict[str, Any], openai_config: Dict[s client = GeminiClient(**openai_config) self._clients.append(client) elif api_type is not None and api_type.startswith("anthropic"): + if "api_key" not in config: + self._configure_openai_config_for_bedrock(config, openai_config) if anthropic_import_exception: raise ImportError("Please install `anthropic` to use Anthropic API.") client = AnthropicClient(**openai_config) diff --git a/autogen/oai/gemini.py b/autogen/oai/gemini.py index 73d41cddbf53..33790c9851c6 100644 --- a/autogen/oai/gemini.py +++ b/autogen/oai/gemini.py @@ -6,7 +6,7 @@ "config_list": [{ "api_type": "google", "model": "gemini-pro", - "api_key": os.environ.get("GOOGLE_API_KEY"), + "api_key": os.environ.get("GOOGLE_GEMINI_API_KEY"), "safety_settings": [ {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH"}, {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_ONLY_HIGH"}, @@ -32,6 +32,7 @@ from __future__ import annotations import base64 +import logging import os import random import re @@ -45,13 +46,19 @@ import vertexai from google.ai.generativelanguage import Content, Part from google.api_core.exceptions import InternalServerError +from google.auth.credentials import Credentials from openai.types.chat import ChatCompletion from openai.types.chat.chat_completion import ChatCompletionMessage, Choice from openai.types.completion_usage import CompletionUsage from PIL import Image from vertexai.generative_models import Content as VertexAIContent from vertexai.generative_models import GenerativeModel +from vertexai.generative_models import HarmBlockThreshold as VertexAIHarmBlockThreshold +from vertexai.generative_models import HarmCategory as VertexAIHarmCategory from vertexai.generative_models import Part as VertexAIPart +from vertexai.generative_models import SafetySetting as VertexAISafetySetting + +logger = logging.getLogger(__name__) class GeminiClient: @@ -81,29 +88,36 @@ def _initialize_vertexai(self, **params): vertexai_init_args["project"] = params["project_id"] if "location" in params: vertexai_init_args["location"] = params["location"] + if "credentials" in params: + assert isinstance( + params["credentials"], Credentials + ), "Object type google.auth.credentials.Credentials is expected!" + vertexai_init_args["credentials"] = params["credentials"] if vertexai_init_args: vertexai.init(**vertexai_init_args) def __init__(self, **kwargs): """Uses either either api_key for authentication from the LLM config - (specifying the GOOGLE_API_KEY environment variable also works), + (specifying the GOOGLE_GEMINI_API_KEY environment variable also works), or follows the Google authentication mechanism for VertexAI in Google Cloud if no api_key is specified, - where project_id and location can also be passed as parameters. Service account key file can also be used. - If neither a service account key file, nor the api_key are passed, then the default credentials will be used, - which could be a personal account if the user is already authenticated in, like in Google Cloud Shell. + where project_id and location can also be passed as parameters. Previously created credentials object can be provided, + or a Service account key file can also be used. If neither a service account key file, nor the api_key are passed, + then the default credentials will be used, which could be a personal account if the user is already authenticated in, + like in Google Cloud Shell. Args: api_key (str): The API key for using Gemini. + credentials (google.auth.credentials.Credentials): credentials to be used for authentication with vertexai. google_application_credentials (str): Path to the JSON service account key file of the service account. - Alternatively, the GOOGLE_APPLICATION_CREDENTIALS environment variable - can also be set instead of using this argument. + Alternatively, the GOOGLE_APPLICATION_CREDENTIALS environment variable + can also be set instead of using this argument. project_id (str): Google Cloud project id, which is only valid in case no API key is specified. location (str): Compute region to be used, like 'us-west1'. - This parameter is only valid in case no API key is specified. + This parameter is only valid in case no API key is specified. """ self.api_key = kwargs.get("api_key", None) if not self.api_key: - self.api_key = os.getenv("GOOGLE_API_KEY") + self.api_key = os.getenv("GOOGLE_GEMINI_API_KEY") if self.api_key is None: self.use_vertexai = True self._initialize_vertexai(**kwargs) @@ -159,13 +173,18 @@ def create(self, params: Dict) -> ChatCompletion: messages = params.get("messages", []) stream = params.get("stream", False) n_response = params.get("n", 1) + system_instruction = params.get("system_instruction", None) + response_validation = params.get("response_validation", True) generation_config = { gemini_term: params[autogen_term] for autogen_term, gemini_term in self.PARAMS_MAPPING.items() if autogen_term in params } - safety_settings = params.get("safety_settings", {}) + if self.use_vertexai: + safety_settings = GeminiClient._to_vertexai_safety_settings(params.get("safety_settings", {})) + else: + safety_settings = params.get("safety_settings", {}) if stream: warnings.warn( @@ -181,20 +200,29 @@ def create(self, params: Dict) -> ChatCompletion: gemini_messages = self._oai_messages_to_gemini_messages(messages) if self.use_vertexai: model = GenerativeModel( - model_name, generation_config=generation_config, safety_settings=safety_settings + model_name, + generation_config=generation_config, + safety_settings=safety_settings, + system_instruction=system_instruction, ) + chat = model.start_chat(history=gemini_messages[:-1], response_validation=response_validation) else: # we use chat model by default model = genai.GenerativeModel( - model_name, generation_config=generation_config, safety_settings=safety_settings + model_name, + generation_config=generation_config, + safety_settings=safety_settings, + system_instruction=system_instruction, ) genai.configure(api_key=self.api_key) - chat = model.start_chat(history=gemini_messages[:-1]) + chat = model.start_chat(history=gemini_messages[:-1]) max_retries = 5 for attempt in range(max_retries): ans = None try: - response = chat.send_message(gemini_messages[-1], stream=stream) + response = chat.send_message( + gemini_messages[-1].parts, stream=stream, safety_settings=safety_settings + ) except InternalServerError: delay = 5 * (2**attempt) warnings.warn( @@ -218,16 +246,22 @@ def create(self, params: Dict) -> ChatCompletion: # B. handle the vision model if self.use_vertexai: model = GenerativeModel( - model_name, generation_config=generation_config, safety_settings=safety_settings + model_name, + generation_config=generation_config, + safety_settings=safety_settings, + system_instruction=system_instruction, ) else: model = genai.GenerativeModel( - model_name, generation_config=generation_config, safety_settings=safety_settings + model_name, + generation_config=generation_config, + safety_settings=safety_settings, + system_instruction=system_instruction, ) genai.configure(api_key=self.api_key) # Gemini's vision model does not support chat history yet # chat = model.start_chat(history=gemini_messages[:-1]) - # response = chat.send_message(gemini_messages[-1]) + # response = chat.send_message(gemini_messages[-1].parts) user_message = self._oai_content_to_gemini_content(messages[-1]["content"]) if len(messages) > 2: warnings.warn( @@ -270,6 +304,8 @@ def _oai_content_to_gemini_content(self, content: Union[str, List]) -> List: """Convert content from OAI format to Gemini format""" rst = [] if isinstance(content, str): + if content == "": + content = "empty" # Empty content is not allowed. if self.use_vertexai: rst.append(VertexAIPart.from_text(content)) else: @@ -372,6 +408,35 @@ def _oai_messages_to_gemini_messages(self, messages: list[Dict[str, Any]]) -> li return rst + @staticmethod + def _to_vertexai_safety_settings(safety_settings): + """Convert safety settings to VertexAI format if needed, + like when specifying them in the OAI_CONFIG_LIST + """ + if isinstance(safety_settings, list) and all( + [ + isinstance(safety_setting, dict) and not isinstance(safety_setting, VertexAISafetySetting) + for safety_setting in safety_settings + ] + ): + vertexai_safety_settings = [] + for safety_setting in safety_settings: + if safety_setting["category"] not in VertexAIHarmCategory.__members__: + invalid_category = safety_setting["category"] + logger.error(f"Safety setting category {invalid_category} is invalid") + elif safety_setting["threshold"] not in VertexAIHarmBlockThreshold.__members__: + invalid_threshold = safety_setting["threshold"] + logger.error(f"Safety threshold {invalid_threshold} is invalid") + else: + vertexai_safety_setting = VertexAISafetySetting( + category=safety_setting["category"], + threshold=safety_setting["threshold"], + ) + vertexai_safety_settings.append(vertexai_safety_setting) + return vertexai_safety_settings + else: + return safety_settings + def _to_pil(data: str) -> Image.Image: """ diff --git a/autogen/oai/openai_utils.py b/autogen/oai/openai_utils.py index 749727d952c0..df70e01ff7df 100644 --- a/autogen/oai/openai_utils.py +++ b/autogen/oai/openai_utils.py @@ -13,7 +13,15 @@ from openai.types.beta.assistant import Assistant from packaging.version import parse -NON_CACHE_KEY = ["api_key", "base_url", "api_type", "api_version", "azure_ad_token", "azure_ad_token_provider"] +NON_CACHE_KEY = [ + "api_key", + "base_url", + "api_type", + "api_version", + "azure_ad_token", + "azure_ad_token_provider", + "credentials", +] DEFAULT_AZURE_API_VERSION = "2024-02-01" OAI_PRICE1K = { # https://openai.com/api/pricing/ @@ -25,6 +33,9 @@ # gpt-4 "gpt-4": (0.03, 0.06), "gpt-4-32k": (0.06, 0.12), + # gpt-4o-mini + "gpt-4o-mini": (0.000150, 0.000600), + "gpt-4o-mini-2024-07-18": (0.000150, 0.000600), # gpt-3.5 turbo "gpt-3.5-turbo": (0.0005, 0.0015), # default is 0125 "gpt-3.5-turbo-0125": (0.0005, 0.0015), # 16k diff --git a/autogen/token_count_utils.py b/autogen/token_count_utils.py index 365285e09551..220007a2bd12 100644 --- a/autogen/token_count_utils.py +++ b/autogen/token_count_utils.py @@ -36,6 +36,8 @@ def get_max_token_limit(model: str = "gpt-3.5-turbo-0613") -> int: "gpt-4-vision-preview": 128000, "gpt-4o": 128000, "gpt-4o-2024-05-13": 128000, + "gpt-4o-mini": 128000, + "gpt-4o-mini-2024-07-18": 128000, } return max_token_limit[model] diff --git a/dotnet/Directory.Build.props b/dotnet/Directory.Build.props index 29e40fff384c..b5663fe4c578 100644 --- a/dotnet/Directory.Build.props +++ b/dotnet/Directory.Build.props @@ -4,7 +4,8 @@ - net8.0 + netstandard2.0;net6.0;net8.0 + net8.0 preview enable True diff --git a/dotnet/sample/AutoGen.Anthropic.Samples/AutoGen.Anthropic.Samples.csproj b/dotnet/sample/AutoGen.Anthropic.Samples/AutoGen.Anthropic.Samples.csproj index 2948c9bf283c..fe7553b937f4 100644 --- a/dotnet/sample/AutoGen.Anthropic.Samples/AutoGen.Anthropic.Samples.csproj +++ b/dotnet/sample/AutoGen.Anthropic.Samples/AutoGen.Anthropic.Samples.csproj @@ -2,7 +2,7 @@ Exe - $(TestTargetFramework) + $(TestTargetFrameworks) enable enable True diff --git a/dotnet/sample/AutoGen.BasicSamples/AutoGen.BasicSample.csproj b/dotnet/sample/AutoGen.BasicSamples/AutoGen.BasicSample.csproj index 6f55a04592f5..d4323ee4c924 100644 --- a/dotnet/sample/AutoGen.BasicSamples/AutoGen.BasicSample.csproj +++ b/dotnet/sample/AutoGen.BasicSamples/AutoGen.BasicSample.csproj @@ -2,7 +2,7 @@ Exe - $(TestTargetFramework) + $(TestTargetFrameworks) enable True $(NoWarn);CS8981;CS8600;CS8602;CS8604;CS8618;CS0219;SKEXP0054;SKEXP0050;SKEXP0110 diff --git a/dotnet/sample/AutoGen.Gemini.Sample/AutoGen.Gemini.Sample.csproj b/dotnet/sample/AutoGen.Gemini.Sample/AutoGen.Gemini.Sample.csproj index b1779b56c390..d1df8a8ed161 100644 --- a/dotnet/sample/AutoGen.Gemini.Sample/AutoGen.Gemini.Sample.csproj +++ b/dotnet/sample/AutoGen.Gemini.Sample/AutoGen.Gemini.Sample.csproj @@ -2,7 +2,7 @@ Exe - net8.0 + $(TestTargetFrameworks) enable enable true diff --git a/dotnet/sample/AutoGen.Ollama.Sample/AutoGen.Ollama.Sample.csproj b/dotnet/sample/AutoGen.Ollama.Sample/AutoGen.Ollama.Sample.csproj index 5277408d595d..62c9d61633c9 100644 --- a/dotnet/sample/AutoGen.Ollama.Sample/AutoGen.Ollama.Sample.csproj +++ b/dotnet/sample/AutoGen.Ollama.Sample/AutoGen.Ollama.Sample.csproj @@ -1,7 +1,7 @@  Exe - $(TestTargetFramework) + $(TestTargetFrameworks) enable True $(NoWarn);CS8981;CS8600;CS8602;CS8604;CS8618;CS0219;SKEXP0054;SKEXP0050;SKEXP0110 diff --git a/dotnet/sample/AutoGen.OpenAI.Sample/AutoGen.OpenAI.Sample.csproj b/dotnet/sample/AutoGen.OpenAI.Sample/AutoGen.OpenAI.Sample.csproj index ffe18f8a616a..49c0e21c9ece 100644 --- a/dotnet/sample/AutoGen.OpenAI.Sample/AutoGen.OpenAI.Sample.csproj +++ b/dotnet/sample/AutoGen.OpenAI.Sample/AutoGen.OpenAI.Sample.csproj @@ -2,7 +2,7 @@ Exe - net8.0 + $(TestTargetFrameworks) enable enable True diff --git a/dotnet/sample/AutoGen.SemanticKernel.Sample/AutoGen.SemanticKernel.Sample.csproj b/dotnet/sample/AutoGen.SemanticKernel.Sample/AutoGen.SemanticKernel.Sample.csproj index 6c2266512929..df1064e18c44 100644 --- a/dotnet/sample/AutoGen.SemanticKernel.Sample/AutoGen.SemanticKernel.Sample.csproj +++ b/dotnet/sample/AutoGen.SemanticKernel.Sample/AutoGen.SemanticKernel.Sample.csproj @@ -2,7 +2,7 @@ Exe - $(TestTargetFramework) + $(TestTargetFrameworks) True $(NoWarn);CS8981;CS8600;CS8602;CS8604;CS8618;CS0219;SKEXP0054;SKEXP0050;SKEXP0110 enable diff --git a/dotnet/sample/AutoGen.WebAPI.Sample/AutoGen.WebAPI.Sample.csproj b/dotnet/sample/AutoGen.WebAPI.Sample/AutoGen.WebAPI.Sample.csproj index 41f3b7d1d381..76675ba12346 100644 --- a/dotnet/sample/AutoGen.WebAPI.Sample/AutoGen.WebAPI.Sample.csproj +++ b/dotnet/sample/AutoGen.WebAPI.Sample/AutoGen.WebAPI.Sample.csproj @@ -1,7 +1,7 @@  - net8.0 + $(TestTargetFrameworks) enable enable diff --git a/dotnet/src/AutoGen.Anthropic/AutoGen.Anthropic.csproj b/dotnet/src/AutoGen.Anthropic/AutoGen.Anthropic.csproj index fefc439e00ba..a4fd32e7e345 100644 --- a/dotnet/src/AutoGen.Anthropic/AutoGen.Anthropic.csproj +++ b/dotnet/src/AutoGen.Anthropic/AutoGen.Anthropic.csproj @@ -1,8 +1,8 @@  - netstandard2.0 - AutoGen.Anthropic + $(PackageTargetFrameworks) + AutoGen.Anthropic diff --git a/dotnet/src/AutoGen.Anthropic/Converters/JsonPropertyNameEnumCoverter.cs b/dotnet/src/AutoGen.Anthropic/Converters/JsonPropertyNameEnumCoverter.cs index cd95d837cffd..68b3c14bdee6 100644 --- a/dotnet/src/AutoGen.Anthropic/Converters/JsonPropertyNameEnumCoverter.cs +++ b/dotnet/src/AutoGen.Anthropic/Converters/JsonPropertyNameEnumCoverter.cs @@ -29,7 +29,7 @@ public override T Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerial public override void Write(Utf8JsonWriter writer, T value, JsonSerializerOptions options) { var field = value.GetType().GetField(value.ToString()); - var attribute = field.GetCustomAttribute(); + var attribute = field?.GetCustomAttribute(); if (attribute != null) { diff --git a/dotnet/src/AutoGen.Core/AutoGen.Core.csproj b/dotnet/src/AutoGen.Core/AutoGen.Core.csproj index 60aeb3ae3fca..8cf9e9183d40 100644 --- a/dotnet/src/AutoGen.Core/AutoGen.Core.csproj +++ b/dotnet/src/AutoGen.Core/AutoGen.Core.csproj @@ -1,6 +1,6 @@  - netstandard2.0 + $(PackageTargetFrameworks) AutoGen.Core @@ -17,7 +17,10 @@ - + + + + diff --git a/dotnet/src/AutoGen.DotnetInteractive/AutoGen.DotnetInteractive.csproj b/dotnet/src/AutoGen.DotnetInteractive/AutoGen.DotnetInteractive.csproj index 72c67fe78016..96b331f2df3b 100644 --- a/dotnet/src/AutoGen.DotnetInteractive/AutoGen.DotnetInteractive.csproj +++ b/dotnet/src/AutoGen.DotnetInteractive/AutoGen.DotnetInteractive.csproj @@ -1,7 +1,7 @@  - netstandard2.0 + $(PackageTargetFrameworks) enable enable AutoGen.DotnetInteractive diff --git a/dotnet/src/AutoGen.DotnetInteractive/InteractiveService.cs b/dotnet/src/AutoGen.DotnetInteractive/InteractiveService.cs index 1ca19fcbcfff..3797dfcff649 100644 --- a/dotnet/src/AutoGen.DotnetInteractive/InteractiveService.cs +++ b/dotnet/src/AutoGen.DotnetInteractive/InteractiveService.cs @@ -104,6 +104,11 @@ public async Task StartAsync(string workingDirectory, CancellationToken ct public bool RestoreDotnetInteractive() { + if (this.installingDirectory is null) + { + throw new Exception("Installing directory is not set"); + } + this.WriteLine("Restore dotnet interactive tool"); // write RestoreInteractive.config from embedded resource to this.workingDirectory var assembly = Assembly.GetAssembly(typeof(InteractiveService))!; diff --git a/dotnet/src/AutoGen.Gemini/AutoGen.Gemini.csproj b/dotnet/src/AutoGen.Gemini/AutoGen.Gemini.csproj index 29c4d1bb9c6f..9a60596503bc 100644 --- a/dotnet/src/AutoGen.Gemini/AutoGen.Gemini.csproj +++ b/dotnet/src/AutoGen.Gemini/AutoGen.Gemini.csproj @@ -1,7 +1,7 @@  - netstandard2.0 + $(PackageTargetFrameworks) diff --git a/dotnet/src/AutoGen.LMStudio/AutoGen.LMStudio.csproj b/dotnet/src/AutoGen.LMStudio/AutoGen.LMStudio.csproj index f45a2f7eba5f..8725d564df41 100644 --- a/dotnet/src/AutoGen.LMStudio/AutoGen.LMStudio.csproj +++ b/dotnet/src/AutoGen.LMStudio/AutoGen.LMStudio.csproj @@ -1,7 +1,7 @@  - netstandard2.0 + $(PackageTargetFrameworks) AutoGen.LMStudio diff --git a/dotnet/src/AutoGen.LMStudio/LMStudioAgent.cs b/dotnet/src/AutoGen.LMStudio/LMStudioAgent.cs index 9d0daa535b23..c3930abc0def 100644 --- a/dotnet/src/AutoGen.LMStudio/LMStudioAgent.cs +++ b/dotnet/src/AutoGen.LMStudio/LMStudioAgent.cs @@ -80,7 +80,7 @@ protected override Task SendAsync(HttpRequestMessage reques { // request.RequestUri = new Uri($"{_modelServiceUrl}{request.RequestUri.PathAndQuery}"); var uriBuilder = new UriBuilder(_modelServiceUrl); - uriBuilder.Path = request.RequestUri.PathAndQuery; + uriBuilder.Path = request.RequestUri?.PathAndQuery ?? throw new InvalidOperationException("RequestUri is null"); request.RequestUri = uriBuilder.Uri; return base.SendAsync(request, cancellationToken); } diff --git a/dotnet/src/AutoGen.Mistral/AutoGen.Mistral.csproj b/dotnet/src/AutoGen.Mistral/AutoGen.Mistral.csproj index 25cc05fec922..ee905d117791 100644 --- a/dotnet/src/AutoGen.Mistral/AutoGen.Mistral.csproj +++ b/dotnet/src/AutoGen.Mistral/AutoGen.Mistral.csproj @@ -1,7 +1,7 @@  - netstandard2.0 + $(PackageTargetFrameworks) AutoGen.Mistral diff --git a/dotnet/src/AutoGen.Mistral/Converters/JsonPropertyNameEnumConverter.cs b/dotnet/src/AutoGen.Mistral/Converters/JsonPropertyNameEnumConverter.cs index 5a4f9f9cb189..9ecf11428397 100644 --- a/dotnet/src/AutoGen.Mistral/Converters/JsonPropertyNameEnumConverter.cs +++ b/dotnet/src/AutoGen.Mistral/Converters/JsonPropertyNameEnumConverter.cs @@ -29,7 +29,7 @@ public override T Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerial public override void Write(Utf8JsonWriter writer, T value, JsonSerializerOptions options) { var field = value.GetType().GetField(value.ToString()); - var attribute = field.GetCustomAttribute(); + var attribute = field?.GetCustomAttribute(); if (attribute != null) { diff --git a/dotnet/src/AutoGen.Mistral/MistralClient.cs b/dotnet/src/AutoGen.Mistral/MistralClient.cs index 5fc3d110985e..8c6802f30eb1 100644 --- a/dotnet/src/AutoGen.Mistral/MistralClient.cs +++ b/dotnet/src/AutoGen.Mistral/MistralClient.cs @@ -49,7 +49,7 @@ public async IAsyncEnumerable StreamingChatCompletionsAs var response = await HttpRequestRaw(HttpMethod.Post, chatCompletionRequest, streaming: true); using var stream = await response.Content.ReadAsStreamAsync(); using StreamReader reader = new StreamReader(stream); - string line; + string? line = null; SseEvent currentEvent = new SseEvent(); while ((line = await reader.ReadLineAsync()) != null) @@ -67,13 +67,13 @@ public async IAsyncEnumerable StreamingChatCompletionsAs else if (currentEvent.EventType == null) { var res = await JsonSerializer.DeserializeAsync( - new MemoryStream(Encoding.UTF8.GetBytes(currentEvent.Data))) ?? throw new Exception("Failed to deserialize response"); + new MemoryStream(Encoding.UTF8.GetBytes(currentEvent.Data ?? string.Empty))) ?? throw new Exception("Failed to deserialize response"); yield return res; } else if (currentEvent.EventType != null) { var res = await JsonSerializer.DeserializeAsync( - new MemoryStream(Encoding.UTF8.GetBytes(currentEvent.Data))); + new MemoryStream(Encoding.UTF8.GetBytes(currentEvent.Data ?? string.Empty))); throw new Exception(res?.Error.Message); } diff --git a/dotnet/src/AutoGen.Ollama/AutoGen.Ollama.csproj b/dotnet/src/AutoGen.Ollama/AutoGen.Ollama.csproj index a939f138c1c1..512fe92f3e3e 100644 --- a/dotnet/src/AutoGen.Ollama/AutoGen.Ollama.csproj +++ b/dotnet/src/AutoGen.Ollama/AutoGen.Ollama.csproj @@ -1,7 +1,7 @@  - netstandard2.0 + $(PackageTargetFrameworks) AutoGen.Ollama True diff --git a/dotnet/src/AutoGen.Ollama/Middlewares/OllamaMessageConnector.cs b/dotnet/src/AutoGen.Ollama/Middlewares/OllamaMessageConnector.cs index 3919b238d659..9e85ca12fd9e 100644 --- a/dotnet/src/AutoGen.Ollama/Middlewares/OllamaMessageConnector.cs +++ b/dotnet/src/AutoGen.Ollama/Middlewares/OllamaMessageConnector.cs @@ -101,7 +101,7 @@ private IEnumerable ProcessMultiModalMessage(MultiModalMessage multiMo // collect all the images var images = imageMessages.SelectMany(m => ProcessImageMessage((ImageMessage)m, agent) - .SelectMany(m => (m as IMessage)?.Content.Images)); + .SelectMany(m => (m as IMessage)?.Content.Images ?? [])); var message = new Message() { diff --git a/dotnet/src/AutoGen.OpenAI/Agent/OpenAIChatAgent.cs b/dotnet/src/AutoGen.OpenAI/Agent/OpenAIChatAgent.cs index b192cde1024b..c957801f0238 100644 --- a/dotnet/src/AutoGen.OpenAI/Agent/OpenAIChatAgent.cs +++ b/dotnet/src/AutoGen.OpenAI/Agent/OpenAIChatAgent.cs @@ -5,6 +5,7 @@ using System.Collections.Generic; using System.Linq; using System.Runtime.CompilerServices; +using System.Text.Json; using System.Threading; using System.Threading.Tasks; using AutoGen.OpenAI.Extension; @@ -32,13 +33,8 @@ namespace AutoGen.OpenAI; public class OpenAIChatAgent : IStreamingAgent { private readonly OpenAIClient openAIClient; - private readonly string modelName; - private readonly float _temperature; - private readonly int _maxTokens = 1024; - private readonly IEnumerable? _functions; - private readonly string _systemMessage; - private readonly ChatCompletionsResponseFormat? _responseFormat; - private readonly int? _seed; + private readonly ChatCompletionsOptions options; + private readonly string systemMessage; /// /// Create a new instance of . @@ -62,16 +58,36 @@ public OpenAIChatAgent( int? seed = null, ChatCompletionsResponseFormat? responseFormat = null, IEnumerable? functions = null) + : this( + openAIClient: openAIClient, + name: name, + options: CreateChatCompletionOptions(modelName, temperature, maxTokens, seed, responseFormat, functions), + systemMessage: systemMessage) { + } + + /// + /// Create a new instance of . + /// + /// openai client + /// agent name + /// system message + /// chat completion option. The option can't contain messages + public OpenAIChatAgent( + OpenAIClient openAIClient, + string name, + ChatCompletionsOptions options, + string systemMessage = "You are a helpful AI assistant") + { + if (options.Messages is { Count: > 0 }) + { + throw new ArgumentException("Messages should not be provided in options"); + } + this.openAIClient = openAIClient; - this.modelName = modelName; this.Name = name; - _temperature = temperature; - _maxTokens = maxTokens; - _functions = functions; - _systemMessage = systemMessage; - _responseFormat = responseFormat; - _seed = seed; + this.options = options; + this.systemMessage = systemMessage; } public string Name { get; } @@ -116,22 +132,35 @@ private ChatCompletionsOptions CreateChatCompletionsOptions(GenerateReplyOptions // add system message if there's no system message in messages if (!oaiMessages.Any(m => m is ChatRequestSystemMessage)) { - oaiMessages = new[] { new ChatRequestSystemMessage(_systemMessage) }.Concat(oaiMessages); + oaiMessages = new[] { new ChatRequestSystemMessage(systemMessage) }.Concat(oaiMessages); } - var settings = new ChatCompletionsOptions(this.modelName, oaiMessages) + // clone the options by serializing and deserializing + var json = JsonSerializer.Serialize(this.options); + var settings = JsonSerializer.Deserialize(json) ?? throw new InvalidOperationException("Failed to clone options"); + + foreach (var m in oaiMessages) { - MaxTokens = options?.MaxToken ?? _maxTokens, - Temperature = options?.Temperature ?? _temperature, - ResponseFormat = _responseFormat, - Seed = _seed, - }; + settings.Messages.Add(m); + } + + settings.Temperature = options?.Temperature ?? settings.Temperature; + settings.MaxTokens = options?.MaxToken ?? settings.MaxTokens; - var openAIFunctionDefinitions = options?.Functions?.Select(f => f.ToOpenAIFunctionDefinition()); - var functions = openAIFunctionDefinitions ?? _functions; - if (functions is not null && functions.Count() > 0) + foreach (var functions in this.options.Tools) { - foreach (var f in functions) + settings.Tools.Add(functions); + } + + foreach (var stopSequence in this.options.StopSequences) + { + settings.StopSequences.Add(stopSequence); + } + + var openAIFunctionDefinitions = options?.Functions?.Select(f => f.ToOpenAIFunctionDefinition()).ToList(); + if (openAIFunctionDefinitions is { Count: > 0 }) + { + foreach (var f in openAIFunctionDefinitions) { settings.Tools.Add(new ChatCompletionsFunctionToolDefinition(f)); } @@ -147,4 +176,31 @@ private ChatCompletionsOptions CreateChatCompletionsOptions(GenerateReplyOptions return settings; } + + private static ChatCompletionsOptions CreateChatCompletionOptions( + string modelName, + float temperature = 0.7f, + int maxTokens = 1024, + int? seed = null, + ChatCompletionsResponseFormat? responseFormat = null, + IEnumerable? functions = null) + { + var options = new ChatCompletionsOptions(modelName, []) + { + Temperature = temperature, + MaxTokens = maxTokens, + Seed = seed, + ResponseFormat = responseFormat, + }; + + if (functions is not null) + { + foreach (var f in functions) + { + options.Tools.Add(new ChatCompletionsFunctionToolDefinition(f)); + } + } + + return options; + } } diff --git a/dotnet/src/AutoGen.OpenAI/AutoGen.OpenAI.csproj b/dotnet/src/AutoGen.OpenAI/AutoGen.OpenAI.csproj index 7220cfe5c628..e3a2f41c8f7a 100644 --- a/dotnet/src/AutoGen.OpenAI/AutoGen.OpenAI.csproj +++ b/dotnet/src/AutoGen.OpenAI/AutoGen.OpenAI.csproj @@ -1,6 +1,6 @@  - netstandard2.0 + $(PackageTargetFrameworks) AutoGen.OpenAI diff --git a/dotnet/src/AutoGen.SemanticKernel/AutoGen.SemanticKernel.csproj b/dotnet/src/AutoGen.SemanticKernel/AutoGen.SemanticKernel.csproj index 3bd96f93b687..1cc4d8e127a7 100644 --- a/dotnet/src/AutoGen.SemanticKernel/AutoGen.SemanticKernel.csproj +++ b/dotnet/src/AutoGen.SemanticKernel/AutoGen.SemanticKernel.csproj @@ -1,7 +1,7 @@  - netstandard2.0 + $(PackageTargetFrameworks) AutoGen.SemanticKernel $(NoWarn);SKEXP0110 diff --git a/dotnet/src/AutoGen/AutoGen.csproj b/dotnet/src/AutoGen/AutoGen.csproj index ddc34a071cbf..3cb5a23da14c 100644 --- a/dotnet/src/AutoGen/AutoGen.csproj +++ b/dotnet/src/AutoGen/AutoGen.csproj @@ -1,6 +1,6 @@  - netstandard2.0 + $(PackageTargetFrameworks) AutoGen diff --git a/dotnet/src/AutoGen/Middleware/HumanInputMiddleware.cs b/dotnet/src/AutoGen/Middleware/HumanInputMiddleware.cs index 1a742b11c799..eda3c001a249 100644 --- a/dotnet/src/AutoGen/Middleware/HumanInputMiddleware.cs +++ b/dotnet/src/AutoGen/Middleware/HumanInputMiddleware.cs @@ -18,7 +18,7 @@ public class HumanInputMiddleware : IMiddleware private readonly string prompt; private readonly string exitKeyword; private Func, CancellationToken, Task> isTermination; - private Func getInput = Console.ReadLine; + private Func getInput = Console.ReadLine; private Action writeLine = Console.WriteLine; public string? Name => nameof(HumanInputMiddleware); @@ -27,7 +27,7 @@ public HumanInputMiddleware( string exitKeyword = "exit", HumanInputMode mode = HumanInputMode.AUTO, Func, CancellationToken, Task>? isTermination = null, - Func? getInput = null, + Func? getInput = null, Action? writeLine = null) { this.prompt = prompt; @@ -56,6 +56,8 @@ public async Task InvokeAsync(MiddlewareContext context, IAgent agent, return new TextMessage(Role.Assistant, GroupChatExtension.TERMINATE, agent.Name); } + input ??= string.Empty; + return new TextMessage(Role.Assistant, input, agent.Name); } @@ -74,6 +76,8 @@ public async Task InvokeAsync(MiddlewareContext context, IAgent agent, return new TextMessage(Role.Assistant, GroupChatExtension.TERMINATE, agent.Name); } + input ??= string.Empty; + return new TextMessage(Role.Assistant, input, agent.Name); } @@ -85,7 +89,7 @@ private async Task DefaultIsTermination(IEnumerable messages, Ca return messages?.Last().IsGroupChatTerminateMessage() is true; } - private string GetInput() + private string? GetInput() { return Console.ReadLine(); } diff --git a/dotnet/test/AutoGen.Anthropic.Tests/AutoGen.Anthropic.Tests.csproj b/dotnet/test/AutoGen.Anthropic.Tests/AutoGen.Anthropic.Tests.csproj index ac479ed2e722..ac9617c1a573 100644 --- a/dotnet/test/AutoGen.Anthropic.Tests/AutoGen.Anthropic.Tests.csproj +++ b/dotnet/test/AutoGen.Anthropic.Tests/AutoGen.Anthropic.Tests.csproj @@ -1,7 +1,7 @@ - $(TestTargetFramework) + $(TestTargetFrameworks) enable false True diff --git a/dotnet/test/AutoGen.DotnetInteractive.Tests/AutoGen.DotnetInteractive.Tests.csproj b/dotnet/test/AutoGen.DotnetInteractive.Tests/AutoGen.DotnetInteractive.Tests.csproj index 0f77db2c1c36..7f7001a877d1 100644 --- a/dotnet/test/AutoGen.DotnetInteractive.Tests/AutoGen.DotnetInteractive.Tests.csproj +++ b/dotnet/test/AutoGen.DotnetInteractive.Tests/AutoGen.DotnetInteractive.Tests.csproj @@ -1,7 +1,7 @@  - $(TestTargetFramework) + $(TestTargetFrameworks) enable false True diff --git a/dotnet/test/AutoGen.Gemini.Tests/AutoGen.Gemini.Tests.csproj b/dotnet/test/AutoGen.Gemini.Tests/AutoGen.Gemini.Tests.csproj index f4fb55825e54..0b9b7e2a24b0 100644 --- a/dotnet/test/AutoGen.Gemini.Tests/AutoGen.Gemini.Tests.csproj +++ b/dotnet/test/AutoGen.Gemini.Tests/AutoGen.Gemini.Tests.csproj @@ -2,7 +2,7 @@ Exe - $(TestTargetFramework) + $(TestTargetFrameworks) enable enable True diff --git a/dotnet/test/AutoGen.Mistral.Tests/AutoGen.Mistral.Tests.csproj b/dotnet/test/AutoGen.Mistral.Tests/AutoGen.Mistral.Tests.csproj index d734119dbb09..aa20a835e9b9 100644 --- a/dotnet/test/AutoGen.Mistral.Tests/AutoGen.Mistral.Tests.csproj +++ b/dotnet/test/AutoGen.Mistral.Tests/AutoGen.Mistral.Tests.csproj @@ -1,7 +1,7 @@  - $(TestTargetFramework) + $(TestTargetFrameworks) enable false True diff --git a/dotnet/test/AutoGen.Ollama.Tests/AutoGen.Ollama.Tests.csproj b/dotnet/test/AutoGen.Ollama.Tests/AutoGen.Ollama.Tests.csproj index 1e26b38d8a4f..c5ca19556244 100644 --- a/dotnet/test/AutoGen.Ollama.Tests/AutoGen.Ollama.Tests.csproj +++ b/dotnet/test/AutoGen.Ollama.Tests/AutoGen.Ollama.Tests.csproj @@ -1,7 +1,7 @@ - $(TestTargetFramework) + $(TestTargetFrameworks) enable false True diff --git a/dotnet/test/AutoGen.OpenAI.Tests/AutoGen.OpenAI.Tests.csproj b/dotnet/test/AutoGen.OpenAI.Tests/AutoGen.OpenAI.Tests.csproj index 04800a631ee6..b176bc3e6ac2 100644 --- a/dotnet/test/AutoGen.OpenAI.Tests/AutoGen.OpenAI.Tests.csproj +++ b/dotnet/test/AutoGen.OpenAI.Tests/AutoGen.OpenAI.Tests.csproj @@ -1,7 +1,7 @@  - $(TestTargetFramework) + $(TestTargetFrameworks) false True True diff --git a/dotnet/test/AutoGen.OpenAI.Tests/OpenAIChatAgentTest.cs b/dotnet/test/AutoGen.OpenAI.Tests/OpenAIChatAgentTest.cs index 8ff66f5c86bf..85f898547b00 100644 --- a/dotnet/test/AutoGen.OpenAI.Tests/OpenAIChatAgentTest.cs +++ b/dotnet/test/AutoGen.OpenAI.Tests/OpenAIChatAgentTest.cs @@ -28,10 +28,8 @@ public async Task GetWeatherAsync(string location) [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")] public async Task BasicConversationTestAsync() { - var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable."); - var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new Exception("Please set AZURE_OPENAI_API_KEY environment variable."); var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable."); - var openaiClient = new OpenAIClient(new Uri(endpoint), new Azure.AzureKeyCredential(key)); + var openaiClient = CreateOpenAIClientFromAzureOpenAI(); var openAIChatAgent = new OpenAIChatAgent( openAIClient: openaiClient, name: "assistant", @@ -60,10 +58,8 @@ public async Task BasicConversationTestAsync() [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")] public async Task OpenAIChatMessageContentConnectorTestAsync() { - var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable."); - var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new Exception("Please set AZURE_OPENAI_API_KEY environment variable."); var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable."); - var openaiClient = new OpenAIClient(new Uri(endpoint), new Azure.AzureKeyCredential(key)); + var openaiClient = CreateOpenAIClientFromAzureOpenAI(); var openAIChatAgent = new OpenAIChatAgent( openAIClient: openaiClient, name: "assistant", @@ -107,10 +103,8 @@ public async Task OpenAIChatMessageContentConnectorTestAsync() [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")] public async Task OpenAIChatAgentToolCallTestAsync() { - var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable."); - var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new Exception("Please set AZURE_OPENAI_API_KEY environment variable."); var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable."); - var openaiClient = new OpenAIClient(new Uri(endpoint), new Azure.AzureKeyCredential(key)); + var openaiClient = CreateOpenAIClientFromAzureOpenAI(); var openAIChatAgent = new OpenAIChatAgent( openAIClient: openaiClient, name: "assistant", @@ -176,10 +170,8 @@ public async Task OpenAIChatAgentToolCallTestAsync() [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")] public async Task OpenAIChatAgentToolCallInvokingTestAsync() { - var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable."); - var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new Exception("Please set AZURE_OPENAI_API_KEY environment variable."); var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable."); - var openaiClient = new OpenAIClient(new Uri(endpoint), new Azure.AzureKeyCredential(key)); + var openaiClient = CreateOpenAIClientFromAzureOpenAI(); var openAIChatAgent = new OpenAIChatAgent( openAIClient: openaiClient, name: "assistant", @@ -236,4 +228,52 @@ public async Task OpenAIChatAgentToolCallInvokingTestAsync() } } } + + [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")] + public async Task ItCreateOpenAIChatAgentWithChatCompletionOptionAsync() + { + var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable."); + var openaiClient = CreateOpenAIClientFromAzureOpenAI(); + var options = new ChatCompletionsOptions(deployName, []) + { + Temperature = 0.7f, + MaxTokens = 1, + }; + + var openAIChatAgent = new OpenAIChatAgent( + openAIClient: openaiClient, + name: "assistant", + options: options) + .RegisterMessageConnector(); + + var respond = await openAIChatAgent.SendAsync("hello"); + respond.GetContent()?.Should().NotBeNullOrEmpty(); + } + + [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")] + public async Task ItThrowExceptionWhenChatCompletionOptionContainsMessages() + { + var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable."); + var openaiClient = CreateOpenAIClientFromAzureOpenAI(); + var options = new ChatCompletionsOptions(deployName, [new ChatRequestUserMessage("hi")]) + { + Temperature = 0.7f, + MaxTokens = 1, + }; + + var action = () => new OpenAIChatAgent( + openAIClient: openaiClient, + name: "assistant", + options: options) + .RegisterMessageConnector(); + + action.Should().ThrowExactly().WithMessage("Messages should not be provided in options"); + } + + private OpenAIClient CreateOpenAIClientFromAzureOpenAI() + { + var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable."); + var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new Exception("Please set AZURE_OPENAI_API_KEY environment variable."); + return new OpenAIClient(new Uri(endpoint), new Azure.AzureKeyCredential(key)); + } } diff --git a/dotnet/test/AutoGen.SemanticKernel.Tests/AutoGen.SemanticKernel.Tests.csproj b/dotnet/test/AutoGen.SemanticKernel.Tests/AutoGen.SemanticKernel.Tests.csproj index 8be4b55b1722..7f42b67da715 100644 --- a/dotnet/test/AutoGen.SemanticKernel.Tests/AutoGen.SemanticKernel.Tests.csproj +++ b/dotnet/test/AutoGen.SemanticKernel.Tests/AutoGen.SemanticKernel.Tests.csproj @@ -1,7 +1,7 @@  - $(TestTargetFramework) + $(TestTargetFrameworks) enable false $(NoWarn);SKEXP0110 diff --git a/dotnet/test/AutoGen.SourceGenerator.Tests/AutoGen.SourceGenerator.Tests.csproj b/dotnet/test/AutoGen.SourceGenerator.Tests/AutoGen.SourceGenerator.Tests.csproj index 2e0ead045bef..f7d814a6cdef 100644 --- a/dotnet/test/AutoGen.SourceGenerator.Tests/AutoGen.SourceGenerator.Tests.csproj +++ b/dotnet/test/AutoGen.SourceGenerator.Tests/AutoGen.SourceGenerator.Tests.csproj @@ -1,7 +1,7 @@  - $(TestTargetFramework) + $(TestTargetFrameworks) enable false True diff --git a/dotnet/test/AutoGen.Tests/AutoGen.Tests.csproj b/dotnet/test/AutoGen.Tests/AutoGen.Tests.csproj index 3dc669b5edd8..ce968b91f556 100644 --- a/dotnet/test/AutoGen.Tests/AutoGen.Tests.csproj +++ b/dotnet/test/AutoGen.Tests/AutoGen.Tests.csproj @@ -1,7 +1,7 @@  - $(TestTargetFramework) + $(TestTargetFrameworks) True True $(NoWarn);xUnit1013;SKEXP0110 diff --git a/dotnet/test/AutoGen.WebAPI.Tests/AutoGen.WebAPI.Tests.csproj b/dotnet/test/AutoGen.WebAPI.Tests/AutoGen.WebAPI.Tests.csproj index 3a9caf38fc8e..7ec6c408cfe8 100644 --- a/dotnet/test/AutoGen.WebAPI.Tests/AutoGen.WebAPI.Tests.csproj +++ b/dotnet/test/AutoGen.WebAPI.Tests/AutoGen.WebAPI.Tests.csproj @@ -1,7 +1,7 @@  - $(TestTargetFramework) + $(TestTargetFrameworks) enable enable false diff --git a/notebook/agentchat_RetrieveChat_mongodb.ipynb b/notebook/agentchat_RetrieveChat_mongodb.ipynb new file mode 100644 index 000000000000..18494e28401d --- /dev/null +++ b/notebook/agentchat_RetrieveChat_mongodb.ipynb @@ -0,0 +1,591 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Using RetrieveChat Powered by MongoDB Atlas for Retrieve Augmented Code Generation and Question Answering\n", + "\n", + "AutoGen offers conversable agents powered by LLM, tool or human, which can be used to perform tasks collectively via automated chat. This framework allows tool use and human participation through multi-agent conversation.\n", + "Please find documentation about this feature [here](https://microsoft.github.io/autogen/docs/Use-Cases/agent_chat).\n", + "\n", + "RetrieveChat is a conversational system for retrieval-augmented code generation and question answering. In this notebook, we demonstrate how to utilize RetrieveChat to generate code and answer questions based on customized documentations that are not present in the LLM's training dataset. RetrieveChat uses the `RetrieveAssistantAgent` and `RetrieveUserProxyAgent`, which is similar to the usage of `AssistantAgent` and `UserProxyAgent` in other notebooks (e.g., [Automated Task Solving with Code Generation, Execution & Debugging](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_auto_feedback_from_code_execution.ipynb)). Essentially, `RetrieveAssistantAgent` and `RetrieveUserProxyAgent` implement a different auto-reply mechanism corresponding to the RetrieveChat prompts.\n", + "\n", + "## Table of Contents\n", + "We'll demonstrate six examples of using RetrieveChat for code generation and question answering:\n", + "\n", + "- [Example 1: Generate code based off docstrings w/o human feedback](#example-1)\n", + "\n", + "````{=mdx}\n", + ":::info Requirements\n", + "Some extra dependencies are needed for this notebook, which can be installed via pip:\n", + "\n", + "```bash\n", + "pip install pyautogen[retrievechat-mongodb] flaml[automl]\n", + "```\n", + "\n", + "For more information, please refer to the [installation guide](/docs/installation/).\n", + ":::\n", + "````\n", + "\n", + "Ensure you have a MongoDB Atlas instance with Cluster Tier >= M30." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Set your API Endpoint\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "models to use: ['gpt-3.5-turbo-0125']\n" + ] + } + ], + "source": [ + "import json\n", + "import os\n", + "\n", + "import autogen\n", + "from autogen.agentchat.contrib.retrieve_assistant_agent import RetrieveAssistantAgent\n", + "from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent\n", + "\n", + "# Accepted file formats for that can be stored in\n", + "# a vector database instance\n", + "from autogen.retrieve_utils import TEXT_FORMATS\n", + "\n", + "config_list = [{\"model\": \"gpt-3.5-turbo-0125\", \"api_key\": os.environ[\"OPENAI_API_KEY\"], \"api_type\": \"openai\"}]\n", + "assert len(config_list) > 0\n", + "print(\"models to use: \", [config_list[i][\"model\"] for i in range(len(config_list))])" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "````{=mdx}\n", + ":::tip\n", + "Learn more about configuring LLMs for agents [here](/docs/topics/llm_configuration).\n", + ":::\n", + "````\n", + "\n", + "## Construct agents for RetrieveChat\n", + "\n", + "We start by initializing the `RetrieveAssistantAgent` and `RetrieveUserProxyAgent`. The system message needs to be set to \"You are a helpful assistant.\" for RetrieveAssistantAgent. The detailed instructions are given in the user message. Later we will use the `RetrieveUserProxyAgent.message_generator` to combine the instructions and a retrieval augmented generation task for an initial prompt to be sent to the LLM assistant." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Accepted file formats for `docs_path`:\n", + "['txt', 'json', 'csv', 'tsv', 'md', 'html', 'htm', 'rtf', 'rst', 'jsonl', 'log', 'xml', 'yaml', 'yml', 'pdf']\n" + ] + } + ], + "source": [ + "print(\"Accepted file formats for `docs_path`:\")\n", + "print(TEXT_FORMATS)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "# 1. create an RetrieveAssistantAgent instance named \"assistant\"\n", + "assistant = RetrieveAssistantAgent(\n", + " name=\"assistant\",\n", + " system_message=\"You are a helpful assistant.\",\n", + " llm_config={\n", + " \"timeout\": 600,\n", + " \"cache_seed\": 42,\n", + " \"config_list\": config_list,\n", + " },\n", + ")\n", + "\n", + "# 2. create the RetrieveUserProxyAgent instance named \"ragproxyagent\"\n", + "# By default, the human_input_mode is \"ALWAYS\", which means the agent will ask for human input at every step. We set it to \"NEVER\" here.\n", + "# `docs_path` is the path to the docs directory. It can also be the path to a single file, or the url to a single file. By default,\n", + "# it is set to None, which works only if the collection is already created.\n", + "# `task` indicates the kind of task we're working on. In this example, it's a `code` task.\n", + "# `chunk_token_size` is the chunk token size for the retrieve chat. By default, it is set to `max_tokens * 0.6`, here we set it to 2000.\n", + "# `custom_text_types` is a list of file types to be processed. Default is `autogen.retrieve_utils.TEXT_FORMATS`.\n", + "# This only applies to files under the directories in `docs_path`. Explicitly included files and urls will be chunked regardless of their types.\n", + "# In this example, we set it to [\"non-existent-type\"] to only process markdown files. Since no \"non-existent-type\" files are included in the `websit/docs`,\n", + "# no files there will be processed. However, the explicitly included urls will still be processed.\n", + "# **NOTE** Upon the first time adding in the documents, initial query may be slower due to index creation and document indexing time\n", + "ragproxyagent = RetrieveUserProxyAgent(\n", + " name=\"ragproxyagent\",\n", + " human_input_mode=\"NEVER\",\n", + " max_consecutive_auto_reply=3,\n", + " retrieve_config={\n", + " \"task\": \"code\",\n", + " \"docs_path\": [\n", + " \"https://raw.githubusercontent.com/microsoft/FLAML/main/website/docs/Examples/Integrate%20-%20Spark.md\",\n", + " \"https://raw.githubusercontent.com/microsoft/FLAML/main/website/docs/Research.md\",\n", + " os.path.join(os.path.abspath(\"\"), \"..\", \"website\", \"docs\"),\n", + " ],\n", + " \"custom_text_types\": [\"non-existent-type\"],\n", + " \"chunk_token_size\": 2000,\n", + " \"model\": config_list[0][\"model\"],\n", + " \"vector_db\": \"mongodb\", # MongoDB Atlas database\n", + " \"collection_name\": \"demo_collection\",\n", + " \"db_config\": {\n", + " \"connection_string\": os.environ[\"MONGODB_URI\"], # MongoDB Atlas connection string\n", + " \"database_name\": \"test_db\", # MongoDB Atlas database\n", + " \"index_name\": \"vector_index\",\n", + " \"wait_until_index_ready\": 120.0, # Setting to wait 120 seconds or until index is constructed before querying\n", + " \"wait_until_document_ready\": 120.0, # Setting to wait 120 seconds or until document is properly indexed after insertion/update\n", + " },\n", + " \"get_or_create\": True, # set to False if you don't want to reuse an existing collection\n", + " \"overwrite\": False, # set to True if you want to overwrite an existing collection, each overwrite will force a index creation and reupload of documents\n", + " },\n", + " code_execution_config=False, # set to False if you don't want to execute the code\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Example 1\n", + "\n", + "[Back to top](#table-of-contents)\n", + "\n", + "Use RetrieveChat to help generate sample code and automatically run the code and fix errors if there is any.\n", + "\n", + "Problem: Which API should I use if I want to use FLAML for a classification task and I want to train the model in 30 seconds. Use spark to parallel the training. Force cancel jobs if time limit is reached." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-07-25 13:47:30,700 - autogen.agentchat.contrib.retrieve_user_proxy_agent - INFO - \u001b[32mUse the existing collection `demo_collection`.\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Trying to create collection.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-07-25 13:47:31,048 - autogen.agentchat.contrib.retrieve_user_proxy_agent - INFO - Found 2 chunks.\u001b[0m\n", + "2024-07-25 13:47:31,051 - autogen.agentchat.contrib.vectordb.mongodb - INFO - No documents to insert.\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "VectorDB returns doc_ids: [['bdfbc921', '7968cf3c']]\n", + "\u001b[32mAdding content of doc bdfbc921 to context.\u001b[0m\n", + "\u001b[32mAdding content of doc 7968cf3c to context.\u001b[0m\n", + "\u001b[33mragproxyagent\u001b[0m (to assistant):\n", + "\n", + "You're a retrieve augmented coding assistant. You answer user's questions based on your own knowledge and the\n", + "context provided by the user.\n", + "If you can't answer the question with or without the current context, you should reply exactly `UPDATE CONTEXT`.\n", + "For code generation, you must obey the following rules:\n", + "Rule 1. You MUST NOT install any packages because all the packages needed are already installed.\n", + "Rule 2. You must follow the formats below to write your code:\n", + "```language\n", + "# your code\n", + "```\n", + "\n", + "User's question is: How can I use FLAML to perform a classification task and use spark to do parallel training. Train 30 seconds and force cancel jobs if time limit is reached.\n", + "\n", + "Context is: # Integrate - Spark\n", + "\n", + "FLAML has integrated Spark for distributed training. There are two main aspects of integration with Spark:\n", + "\n", + "- Use Spark ML estimators for AutoML.\n", + "- Use Spark to run training in parallel spark jobs.\n", + "\n", + "## Spark ML Estimators\n", + "\n", + "FLAML integrates estimators based on Spark ML models. These models are trained in parallel using Spark, so we called them Spark estimators. To use these models, you first need to organize your data in the required format.\n", + "\n", + "### Data\n", + "\n", + "For Spark estimators, AutoML only consumes Spark data. FLAML provides a convenient function `to_pandas_on_spark` in the `flaml.automl.spark.utils` module to convert your data into a pandas-on-spark (`pyspark.pandas`) dataframe/series, which Spark estimators require.\n", + "\n", + "This utility function takes data in the form of a `pandas.Dataframe` or `pyspark.sql.Dataframe` and converts it into a pandas-on-spark dataframe. It also takes `pandas.Series` or `pyspark.sql.Dataframe` and converts it into a [pandas-on-spark](https://spark.apache.org/docs/latest/api/python/user_guide/pandas_on_spark/index.html) series. If you pass in a `pyspark.pandas.Dataframe`, it will not make any changes.\n", + "\n", + "This function also accepts optional arguments `index_col` and `default_index_type`.\n", + "\n", + "- `index_col` is the column name to use as the index, default is None.\n", + "- `default_index_type` is the default index type, default is \"distributed-sequence\". More info about default index type could be found on Spark official [documentation](https://spark.apache.org/docs/latest/api/python/user_guide/pandas_on_spark/options.html#default-index-type)\n", + "\n", + "Here is an example code snippet for Spark Data:\n", + "\n", + "```python\n", + "import pandas as pd\n", + "from flaml.automl.spark.utils import to_pandas_on_spark\n", + "\n", + "# Creating a dictionary\n", + "data = {\n", + " \"Square_Feet\": [800, 1200, 1800, 1500, 850],\n", + " \"Age_Years\": [20, 15, 10, 7, 25],\n", + " \"Price\": [100000, 200000, 300000, 240000, 120000],\n", + "}\n", + "\n", + "# Creating a pandas DataFrame\n", + "dataframe = pd.DataFrame(data)\n", + "label = \"Price\"\n", + "\n", + "# Convert to pandas-on-spark dataframe\n", + "psdf = to_pandas_on_spark(dataframe)\n", + "```\n", + "\n", + "To use Spark ML models you need to format your data appropriately. Specifically, use [`VectorAssembler`](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.feature.VectorAssembler.html) to merge all feature columns into a single vector column.\n", + "\n", + "Here is an example of how to use it:\n", + "\n", + "```python\n", + "from pyspark.ml.feature import VectorAssembler\n", + "\n", + "columns = psdf.columns\n", + "feature_cols = [col for col in columns if col != label]\n", + "featurizer = VectorAssembler(inputCols=feature_cols, outputCol=\"features\")\n", + "psdf = featurizer.transform(psdf.to_spark(index_col=\"index\"))[\"index\", \"features\"]\n", + "```\n", + "\n", + "Later in conducting the experiment, use your pandas-on-spark data like non-spark data and pass them using `X_train, y_train` or `dataframe, label`.\n", + "\n", + "### Estimators\n", + "\n", + "#### Model List\n", + "\n", + "- `lgbm_spark`: The class for fine-tuning Spark version LightGBM models, using [SynapseML](https://microsoft.github.io/SynapseML/docs/features/lightgbm/about/) API.\n", + "\n", + "#### Usage\n", + "\n", + "First, prepare your data in the required format as described in the previous section.\n", + "\n", + "By including the models you intend to try in the `estimators_list` argument to `flaml.automl`, FLAML will start trying configurations for these models. If your input is Spark data, FLAML will also use estimators with the `_spark` postfix by default, even if you haven't specified them.\n", + "\n", + "Here is an example code snippet using SparkML models in AutoML:\n", + "\n", + "```python\n", + "import flaml\n", + "\n", + "# prepare your data in pandas-on-spark format as we previously mentioned\n", + "\n", + "automl = flaml.AutoML()\n", + "settings = {\n", + " \"time_budget\": 30,\n", + " \"metric\": \"r2\",\n", + " \"estimator_list\": [\"lgbm_spark\"], # this setting is optional\n", + " \"task\": \"regression\",\n", + "}\n", + "\n", + "automl.fit(\n", + " dataframe=psdf,\n", + " label=label,\n", + " **settings,\n", + ")\n", + "```\n", + "\n", + "[Link to notebook](https://github.com/microsoft/FLAML/blob/main/notebook/automl_bankrupt_synapseml.ipynb) | [Open in colab](https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/automl_bankrupt_synapseml.ipynb)\n", + "\n", + "## Parallel Spark Jobs\n", + "\n", + "You can activate Spark as the parallel backend during parallel tuning in both [AutoML](/docs/Use-Cases/Task-Oriented-AutoML#parallel-tuning) and [Hyperparameter Tuning](/docs/Use-Cases/Tune-User-Defined-Function#parallel-tuning), by setting the `use_spark` to `true`. FLAML will dispatch your job to the distributed Spark backend using [`joblib-spark`](https://github.com/joblib/joblib-spark).\n", + "\n", + "Please note that you should not set `use_spark` to `true` when applying AutoML and Tuning for Spark Data. This is because only SparkML models will be used for Spark Data in AutoML and Tuning. As SparkML models run in parallel, there is no need to distribute them with `use_spark` again.\n", + "\n", + "All the Spark-related arguments are stated below. These arguments are available in both Hyperparameter Tuning and AutoML:\n", + "\n", + "- `use_spark`: boolean, default=False | Whether to use spark to run the training in parallel spark jobs. This can be used to accelerate training on large models and large datasets, but will incur more overhead in time and thus slow down training in some cases. GPU training is not supported yet when use_spark is True. For Spark clusters, by default, we will launch one trial per executor. However, sometimes we want to launch more trials than the number of executors (e.g., local mode). In this case, we can set the environment variable `FLAML_MAX_CONCURRENT` to override the detected `num_executors`. The final number of concurrent trials will be the minimum of `n_concurrent_trials` and `num_executors`.\n", + "- `n_concurrent_trials`: int, default=1 | The number of concurrent trials. When n_concurrent_trials > 1, FLAML performes parallel tuning.\n", + "- `force_cancel`: boolean, default=False | Whether to forcely cancel Spark jobs if the search time exceeded the time budget. Spark jobs include parallel tuning jobs and Spark-based model training jobs.\n", + "\n", + "An example code snippet for using parallel Spark jobs:\n", + "\n", + "```python\n", + "import flaml\n", + "\n", + "automl_experiment = flaml.AutoML()\n", + "automl_settings = {\n", + " \"time_budget\": 30,\n", + " \"metric\": \"r2\",\n", + " \"task\": \"regression\",\n", + " \"n_concurrent_trials\": 2,\n", + " \"use_spark\": True,\n", + " \"force_cancel\": True, # Activating the force_cancel option can immediately halt Spark jobs once they exceed the allocated time_budget.\n", + "}\n", + "\n", + "automl.fit(\n", + " dataframe=dataframe,\n", + " label=label,\n", + " **automl_settings,\n", + ")\n", + "```\n", + "\n", + "[Link to notebook](https://github.com/microsoft/FLAML/blob/main/notebook/integrate_spark.ipynb) | [Open in colab](https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/integrate_spark.ipynb)\n", + "# Research\n", + "\n", + "For technical details, please check our research publications.\n", + "\n", + "- [FLAML: A Fast and Lightweight AutoML Library](https://www.microsoft.com/en-us/research/publication/flaml-a-fast-and-lightweight-automl-library/). Chi Wang, Qingyun Wu, Markus Weimer, Erkang Zhu. MLSys 2021.\n", + "\n", + "```bibtex\n", + "@inproceedings{wang2021flaml,\n", + " title={FLAML: A Fast and Lightweight AutoML Library},\n", + " author={Chi Wang and Qingyun Wu and Markus Weimer and Erkang Zhu},\n", + " year={2021},\n", + " booktitle={MLSys},\n", + "}\n", + "```\n", + "\n", + "- [Frugal Optimization for Cost-related Hyperparameters](https://arxiv.org/abs/2005.01571). Qingyun Wu, Chi Wang, Silu Huang. AAAI 2021.\n", + "\n", + "```bibtex\n", + "@inproceedings{wu2021cfo,\n", + " title={Frugal Optimization for Cost-related Hyperparameters},\n", + " author={Qingyun Wu and Chi Wang and Silu Huang},\n", + " year={2021},\n", + " booktitle={AAAI},\n", + "}\n", + "```\n", + "\n", + "- [Economical Hyperparameter Optimization With Blended Search Strategy](https://www.microsoft.com/en-us/research/publication/economical-hyperparameter-optimization-with-blended-search-strategy/). Chi Wang, Qingyun Wu, Silu Huang, Amin Saied. ICLR 2021.\n", + "\n", + "```bibtex\n", + "@inproceedings{wang2021blendsearch,\n", + " title={Economical Hyperparameter Optimization With Blended Search Strategy},\n", + " author={Chi Wang and Qingyun Wu and Silu Huang and Amin Saied},\n", + " year={2021},\n", + " booktitle={ICLR},\n", + "}\n", + "```\n", + "\n", + "- [An Empirical Study on Hyperparameter Optimization for Fine-Tuning Pre-trained Language Models](https://aclanthology.org/2021.acl-long.178.pdf). Susan Xueqing Liu, Chi Wang. ACL 2021.\n", + "\n", + "```bibtex\n", + "@inproceedings{liuwang2021hpolm,\n", + " title={An Empirical Study on Hyperparameter Optimization for Fine-Tuning Pre-trained Language Models},\n", + " author={Susan Xueqing Liu and Chi Wang},\n", + " year={2021},\n", + " booktitle={ACL},\n", + "}\n", + "```\n", + "\n", + "- [ChaCha for Online AutoML](https://www.microsoft.com/en-us/research/publication/chacha-for-online-automl/). Qingyun Wu, Chi Wang, John Langford, Paul Mineiro and Marco Rossi. ICML 2021.\n", + "\n", + "```bibtex\n", + "@inproceedings{wu2021chacha,\n", + " title={ChaCha for Online AutoML},\n", + " author={Qingyun Wu and Chi Wang and John Langford and Paul Mineiro and Marco Rossi},\n", + " year={2021},\n", + " booktitle={ICML},\n", + "}\n", + "```\n", + "\n", + "- [Fair AutoML](https://arxiv.org/abs/2111.06495). Qingyun Wu, Chi Wang. ArXiv preprint arXiv:2111.06495 (2021).\n", + "\n", + "```bibtex\n", + "@inproceedings{wuwang2021fairautoml,\n", + " title={Fair AutoML},\n", + " author={Qingyun Wu and Chi Wang},\n", + " year={2021},\n", + " booktitle={ArXiv preprint arXiv:2111.06495},\n", + "}\n", + "```\n", + "\n", + "- [Mining Robust Default Configurations for Resource-constrained AutoML](https://arxiv.org/abs/2202.09927). Moe Kayali, Chi Wang. ArXiv preprint arXiv:2202.09927 (2022).\n", + "\n", + "```bibtex\n", + "@inproceedings{kayaliwang2022default,\n", + " title={Mining Robust Default Configurations for Resource-constrained AutoML},\n", + " author={Moe Kayali and Chi Wang},\n", + " year={2022},\n", + " booktitle={ArXiv preprint arXiv:2202.09927},\n", + "}\n", + "```\n", + "\n", + "- [Targeted Hyperparameter Optimization with Lexicographic Preferences Over Multiple Objectives](https://openreview.net/forum?id=0Ij9_q567Ma). Shaokun Zhang, Feiran Jia, Chi Wang, Qingyun Wu. ICLR 2023 (notable-top-5%).\n", + "\n", + "```bibtex\n", + "@inproceedings{zhang2023targeted,\n", + " title={Targeted Hyperparameter Optimization with Lexicographic Preferences Over Multiple Objectives},\n", + " author={Shaokun Zhang and Feiran Jia and Chi Wang and Qingyun Wu},\n", + " booktitle={International Conference on Learning Representations},\n", + " year={2023},\n", + " url={https://openreview.net/forum?id=0Ij9_q567Ma},\n", + "}\n", + "```\n", + "\n", + "- [Cost-Effective Hyperparameter Optimization for Large Language Model Generation Inference](https://arxiv.org/abs/2303.04673). Chi Wang, Susan Xueqing Liu, Ahmed H. Awadallah. ArXiv preprint arXiv:2303.04673 (2023).\n", + "\n", + "```bibtex\n", + "@inproceedings{wang2023EcoOptiGen,\n", + " title={Cost-Effective Hyperparameter Optimization for Large Language Model Generation Inference},\n", + " author={Chi Wang and Susan Xueqing Liu and Ahmed H. Awadallah},\n", + " year={2023},\n", + " booktitle={ArXiv preprint arXiv:2303.04673},\n", + "}\n", + "```\n", + "\n", + "- [An Empirical Study on Challenging Math Problem Solving with GPT-4](https://arxiv.org/abs/2306.01337). Yiran Wu, Feiran Jia, Shaokun Zhang, Hangyu Li, Erkang Zhu, Yue Wang, Yin Tat Lee, Richard Peng, Qingyun Wu, Chi Wang. ArXiv preprint arXiv:2306.01337 (2023).\n", + "\n", + "```bibtex\n", + "@inproceedings{wu2023empirical,\n", + " title={An Empirical Study on Challenging Math Problem Solving with GPT-4},\n", + " author={Yiran Wu and Feiran Jia and Shaokun Zhang and Hangyu Li and Erkang Zhu and Yue Wang and Yin Tat Lee and Richard Peng and Qingyun Wu and Chi Wang},\n", + " year={2023},\n", + " booktitle={ArXiv preprint arXiv:2306.01337},\n", + "}\n", + "```\n", + "\n", + "\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[33massistant\u001b[0m (to ragproxyagent):\n", + "\n", + "To use FLAML to perform a classification task and use Spark for parallel training with a timeout of 30 seconds and force canceling jobs if the time limit is reached, you can follow the below code snippet:\n", + "\n", + "```python\n", + "import flaml\n", + "from flaml.automl.spark.utils import to_pandas_on_spark\n", + "from pyspark.ml.feature import VectorAssembler\n", + "\n", + "# Prepare your data in pandas-on-spark format\n", + "data = {\n", + " \"feature1\": [val1, val2, val3, val4],\n", + " \"feature2\": [val5, val6, val7, val8],\n", + " \"target\": [class1, class2, class1, class2],\n", + "}\n", + "\n", + "dataframe = pd.DataFrame(data)\n", + "label = \"target\"\n", + "psdf = to_pandas_on_spark(dataframe)\n", + "\n", + "# Prepare your features using VectorAssembler\n", + "columns = psdf.columns\n", + "feature_cols = [col for col in columns if col != label]\n", + "featurizer = VectorAssembler(inputCols=feature_cols, outputCol=\"features\")\n", + "psdf = featurizer.transform(psdf)\n", + "\n", + "# Define AutoML settings and fit the model\n", + "automl = flaml.AutoML()\n", + "settings = {\n", + " \"time_budget\": 30,\n", + " \"metric\": \"accuracy\",\n", + " \"task\": \"classification\",\n", + " \"estimator_list\": [\"lgbm_spark\"], # Optional\n", + "}\n", + "\n", + "automl.fit(\n", + " dataframe=psdf,\n", + " label=label,\n", + " **settings,\n", + ")\n", + "```\n", + "\n", + "In the code:\n", + "- Replace `val1, val2, ..., class1, class2` with your actual data values.\n", + "- Ensure the features and target columns are correctly specified in the data dictionary.\n", + "- Set the `time_budget` parameter to 30 to limit the training time.\n", + "- The `force_cancel` parameter is set to `True` to force cancel Spark jobs if the time limit is exceeded.\n", + "\n", + "Make sure to adapt the code to your specific dataset and requirements.\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[33mragproxyagent\u001b[0m (to assistant):\n", + "\n", + "\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[33massistant\u001b[0m (to ragproxyagent):\n", + "\n", + "UPDATE CONTEXT\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[32mUpdating context and resetting conversation.\u001b[0m\n", + "VectorDB returns doc_ids: [['bdfbc921', '7968cf3c']]\n", + "VectorDB returns doc_ids: [['bdfbc921', '7968cf3c']]\n", + "VectorDB returns doc_ids: [['bdfbc921', '7968cf3c']]\n", + "VectorDB returns doc_ids: [['bdfbc921', '7968cf3c']]\n", + "\u001b[32mNo more context, will terminate.\u001b[0m\n", + "\u001b[33mragproxyagent\u001b[0m (to assistant):\n", + "\n", + "TERMINATE\n", + "\n", + "--------------------------------------------------------------------------------\n" + ] + } + ], + "source": [ + "# reset the assistant. Always reset the assistant before starting a new conversation.\n", + "assistant.reset()\n", + "\n", + "# given a problem, we use the ragproxyagent to generate a prompt to be sent to the assistant as the initial message.\n", + "# the assistant receives the message and generates a response. The response will be sent back to the ragproxyagent for processing.\n", + "# The conversation continues until the termination condition is met, in RetrieveChat, the termination condition when no human-in-loop is no code block detected.\n", + "# With human-in-loop, the conversation will continue until the user says \"exit\".\n", + "code_problem = \"How can I use FLAML to perform a classification task and use spark to do parallel training. Train 30 seconds and force cancel jobs if time limit is reached.\"\n", + "chat_result = ragproxyagent.initiate_chat(assistant, message=ragproxyagent.message_generator, problem=code_problem)" + ] + } + ], + "metadata": { + "front_matter": { + "description": "Explore the use of AutoGen's RetrieveChat for tasks like code generation from docstrings, answering complex questions with human feedback, and exploiting features like Update Context, custom prompts, and few-shot learning.", + "tags": [ + "RAG" + ] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + }, + "skip_test": "Requires interactive usage" + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/notebook/agentchat_society_of_mind.ipynb b/notebook/agentchat_society_of_mind.ipynb index 79e5990a2aff..df3a6c543397 100644 --- a/notebook/agentchat_society_of_mind.ipynb +++ b/notebook/agentchat_society_of_mind.ipynb @@ -57,7 +57,7 @@ "\n", "### Example Group Chat with Two Agents\n", "\n", - "In this example, we will use an AssistantAgent and a UserProxy agent (configured for code execution) to work together to solve a problem. Executing code requires *at least* two conversation turns (one to write the code, and one to execute the code). If the code fails, or needs further refinement, then additional turns may also be needed. When will then wrap these agents in a SocietyOfMindAgent, hiding the internal discussion from other agents (though will still appear in the console), and ensuring that the response is suitable as a standalone message." + "In this example, we will use an AssistantAgent and a UserProxy agent (configured for code execution) to work together to solve a problem. Executing code requires *at least* two conversation turns (one to write the code, and one to execute the code). If the code fails, or needs further refinement, then additional turns may also be needed. We will then wrap these agents in a SocietyOfMindAgent, hiding the internal discussion from other agents (though will still appear in the console), and ensuring that the response is suitable as a standalone message." ] }, { diff --git a/samples/apps/cap/py/README.md b/samples/apps/cap/py/README.md index a885e001e155..2e11d38330d0 100644 --- a/samples/apps/cap/py/README.md +++ b/samples/apps/cap/py/README.md @@ -3,7 +3,7 @@ ## I just want to run the remote AutoGen agents! *Python Instructions (Windows, Linux, MacOS):* -pip install autogencap-rajan.jedi +pip install autogencap 1) AutoGen require OAI_CONFIG_LIST. AutoGen python requirements: 3.8 <= python <= 3.11 diff --git a/samples/apps/cap/py/autogencap/ActorConnector.py b/samples/apps/cap/py/autogencap/ActorConnector.py index 588d4c463b62..1595f641fc83 100644 --- a/samples/apps/cap/py/autogencap/ActorConnector.py +++ b/samples/apps/cap/py/autogencap/ActorConnector.py @@ -120,6 +120,16 @@ def send_txt_msg(self, msg): def send_bin_msg(self, msg_type: str, msg): self._sender.send_bin_msg(msg_type, msg) + def send_proto_msg(self, msg): + bin_msg = msg.SerializeToString() + class_type = type(msg) + self._sender.send_bin_msg(class_type.__name__, bin_msg) + + def send_recv_proto_msg(self, msg, num_attempts=5): + bin_msg = msg.SerializeToString() + class_type = type(msg) + return self.send_recv_msg(class_type.__name, bin_msg, num_attempts) + def send_recv_msg(self, msg_type: str, msg, num_attempts=5): original_timeout: int = 0 if num_attempts == -1: diff --git a/samples/apps/cap/py/autogencap/DebugLog.py b/samples/apps/cap/py/autogencap/DebugLog.py index c3d6ca421276..f8a3f209ee3a 100644 --- a/samples/apps/cap/py/autogencap/DebugLog.py +++ b/samples/apps/cap/py/autogencap/DebugLog.py @@ -34,19 +34,25 @@ def WriteLog(self, level, context, msg): class ConsoleLogger(BaseLogger): - def __init__(self): + def __init__(self, use_color=True): super().__init__() + self._use_color = use_color + + def _colorize(self, msg, color): + if self._use_color: + return colored(msg, color) + return msg def WriteLog(self, level, context, msg): - timestamp = colored(datetime.datetime.now().strftime("%m/%d/%y %H:%M:%S"), "dark_grey") + timestamp = self._colorize(datetime.datetime.now().strftime("%m/%d/%y %H:%M:%S"), "dark_grey") # Translate level number to name and color - level_name = colored(LEVEL_NAMES[level], LEVEL_COLOR[level]) + level_name = self._colorize(LEVEL_NAMES[level], LEVEL_COLOR[level]) # Left justify the context and color it blue - context = colored(context.ljust(14), "blue") + context = self._colorize(context.ljust(14), "blue") # Left justify the threadid and color it blue - thread_id = colored(str(threading.get_ident()).ljust(5), "blue") + thread_id = self._colorize(str(threading.get_ident()).ljust(5), "blue") # color the msg based on the level - msg = colored(msg, LEVEL_COLOR[level]) + msg = self._colorize(msg, LEVEL_COLOR[level]) print(f"{thread_id} {timestamp} {level_name}: [{context}] {msg}") diff --git a/samples/apps/cap/py/demo/CAPAutoGenPairDemo.py b/samples/apps/cap/py/demo/CAPAutoGenPairDemo.py index 732bfecad17b..00ff7a892878 100644 --- a/samples/apps/cap/py/demo/CAPAutoGenPairDemo.py +++ b/samples/apps/cap/py/demo/CAPAutoGenPairDemo.py @@ -1,13 +1,16 @@ import time +import autogencap.DebugLog as DebugLog from autogencap.ag_adapter.CAPPair import CAPPair from autogencap.ComponentEnsemble import ComponentEnsemble -from autogencap.DebugLog import Info +from autogencap.DebugLog import ConsoleLogger, Info from autogen import AssistantAgent, UserProxyAgent, config_list_from_json def cap_ag_pair_demo(): + DebugLog.LOGGER = ConsoleLogger(use_color=False) + config_list = config_list_from_json(env_or_file="OAI_CONFIG_LIST") assistant = AssistantAgent("assistant", llm_config={"config_list": config_list}) user_proxy = UserProxyAgent( @@ -20,7 +23,10 @@ def cap_ag_pair_demo(): ensemble = ComponentEnsemble() pair = CAPPair(ensemble, user_proxy, assistant) - pair.initiate_chat("Plot a chart of MSFT daily closing prices for last 1 Month.") + user_cmd = "Plot a chart of MSFT daily closing prices for last 1 Month" + print(f"Default: {user_cmd}") + user_cmd = input("Enter a command: ") or user_cmd + pair.initiate_chat(user_cmd) # Wait for the pair to finish try: diff --git a/samples/apps/cap/py/demo/single_threaded.py b/samples/apps/cap/py/demo/single_threaded.py index 43cffbf02c89..d95f67128e64 100644 --- a/samples/apps/cap/py/demo/single_threaded.py +++ b/samples/apps/cap/py/demo/single_threaded.py @@ -19,14 +19,19 @@ def single_threaded_demo(): greeter_link.send_txt_msg("Hello World!") no_msg = 0 + + # This is where we process the messages in this thread + # instead of using a separate thread + + # 5 consecutive times with no message received + # will break the loop while no_msg < 5: + # Get the message for the actor message = agent.get_message() + # Let the actor process the message agent.dispatch_message(message) - if message is None: - no_msg += 1 - - message = agent.get_message() - agent.dispatch_message(message) + # If no message is received, increment the counter otherwise reset it + no_msg = no_msg + 1 if message is None else 0 ensemble.disconnect() diff --git a/samples/apps/cap/py/demo/standalone/Assistant.py b/samples/apps/cap/py/demo/standalone/assistant.py similarity index 100% rename from samples/apps/cap/py/demo/standalone/Assistant.py rename to samples/apps/cap/py/demo/standalone/assistant.py diff --git a/samples/apps/cap/py/pyproject.toml b/samples/apps/cap/py/pyproject.toml index 8988604a334f..8a0fe227e805 100644 --- a/samples/apps/cap/py/pyproject.toml +++ b/samples/apps/cap/py/pyproject.toml @@ -3,8 +3,8 @@ requires = ["hatchling"] build-backend = "hatchling.build" [project] -name = "autogencap_rajan.jedi" -version = "0.0.10" +name = "autogencap" +version = "0.0.11" authors = [ { name="Rajan Chari", email="rajan.jedi@gmail.com" }, ] diff --git a/setup.py b/setup.py index 9117ed45ceac..13a88be5f0a4 100644 --- a/setup.py +++ b/setup.py @@ -72,6 +72,7 @@ "mathchat": ["sympy", "pydantic==1.10.9", "wolframalpha"], "retrievechat": retrieve_chat, "retrievechat-pgvector": retrieve_chat_pgvector, + "retrievechat-mongodb": [*retrieve_chat, "pymongo>=4.0.0"], "retrievechat-qdrant": [*retrieve_chat, "qdrant_client", "fastembed>=0.3.1"], "autobuild": ["chromadb", "sentence-transformers", "huggingface-hub", "pysqlite3"], "teachable": ["chromadb"], diff --git a/test/agentchat/contrib/test_gpt_assistant.py b/test/agentchat/contrib/test_gpt_assistant.py index 6fc69097fc0d..7132cb72053b 100755 --- a/test/agentchat/contrib/test_gpt_assistant.py +++ b/test/agentchat/contrib/test_gpt_assistant.py @@ -28,6 +28,7 @@ filter_dict={ "api_type": ["openai"], "model": [ + "gpt-4o-mini", "gpt-4o", "gpt-4-turbo", "gpt-4-turbo-preview", diff --git a/test/agentchat/contrib/vectordb/test_mongodb.py b/test/agentchat/contrib/vectordb/test_mongodb.py new file mode 100644 index 000000000000..3ae1ed572591 --- /dev/null +++ b/test/agentchat/contrib/vectordb/test_mongodb.py @@ -0,0 +1,402 @@ +import logging +import os +import random +from time import monotonic, sleep +from typing import List + +import pytest + +from autogen.agentchat.contrib.vectordb.base import Document + +try: + import pymongo + import sentence_transformers + + from autogen.agentchat.contrib.vectordb.mongodb import MongoDBAtlasVectorDB +except ImportError: + # To display warning in pyproject.toml [tool.pytest.ini_options] set log_cli = true + logger = logging.getLogger(__name__) + logger.warning(f"skipping {__name__}. It requires one to pip install pymongo or the extra [retrievechat-mongodb]") + pytest.skip(allow_module_level=True) + +from pymongo import MongoClient +from pymongo.collection import Collection +from pymongo.errors import OperationFailure + +logger = logging.getLogger(__name__) + +MONGODB_URI = os.environ.get("MONGODB_URI", "mongodb://localhost:27017/?directConnection=true") +MONGODB_DATABASE = os.environ.get("DATABASE", "autogen_test_db") +MONGODB_COLLECTION = os.environ.get("MONGODB_COLLECTION", "autogen_test_vectorstore") +MONGODB_INDEX = os.environ.get("MONGODB_INDEX", "vector_index") + +RETRIES = 10 +DELAY = 2 +TIMEOUT = 120.0 + + +def _wait_for_predicate(predicate, err, timeout=TIMEOUT, interval=DELAY): + """Generic to block until the predicate returns true + + Args: + predicate (Callable[, bool]): A function that returns a boolean value + err (str): Error message to raise if nothing occurs + timeout (float, optional): Length of time to wait for predicate. Defaults to TIMEOUT. + interval (float, optional): Interval to check predicate. Defaults to DELAY. + + Raises: + TimeoutError: _description_ + """ + start = monotonic() + while not predicate(): + if monotonic() - start > TIMEOUT: + raise TimeoutError(err) + sleep(DELAY) + + +def _delete_search_indexes(collection: Collection, wait=True): + """Deletes all indexes in a collection + + Args: + collection (pymongo.Collection): MongoDB Collection Abstraction + """ + for index in collection.list_search_indexes(): + try: + collection.drop_search_index(index["name"]) + except OperationFailure: + # Delete already issued + pass + if wait: + _wait_for_predicate(lambda: not list(collection.list_search_indexes()), "Not all collections deleted") + + +def _empty_collections_and_delete_indexes(database, collections=None, wait=True): + """Empty all collections within the database and remove indexes + + Args: + database (pymongo.Database): MongoDB Database Abstraction + """ + for collection_name in collections or database.list_collection_names(): + _delete_search_indexes(database[collection_name], wait) + database[collection_name].drop() + + +@pytest.fixture +def db(): + """VectorDB setup and teardown, including collections and search indexes""" + database = MongoClient(MONGODB_URI)[MONGODB_DATABASE] + _empty_collections_and_delete_indexes(database) + vectorstore = MongoDBAtlasVectorDB( + connection_string=MONGODB_URI, + database_name=MONGODB_DATABASE, + wait_until_index_ready=TIMEOUT, + overwrite=True, + ) + yield vectorstore + _empty_collections_and_delete_indexes(database) + + +@pytest.fixture +def example_documents() -> List[Document]: + """Note mix of integers and strings as ids""" + return [ + Document(id=1, content="Dogs are tough.", metadata={"a": 1}), + Document(id=2, content="Cats have fluff.", metadata={"b": 1}), + Document(id="1", content="What is a sandwich?", metadata={"c": 1}), + Document(id="2", content="A sandwich makes a great lunch.", metadata={"d": 1, "e": 2}), + ] + + +@pytest.fixture +def db_with_indexed_clxn(collection_name): + """VectorDB with a collection created immediately""" + database = MongoClient(MONGODB_URI)[MONGODB_DATABASE] + _empty_collections_and_delete_indexes(database, [collection_name], wait=True) + vectorstore = MongoDBAtlasVectorDB( + connection_string=MONGODB_URI, + database_name=MONGODB_DATABASE, + wait_until_index_ready=TIMEOUT, + collection_name=collection_name, + overwrite=True, + ) + yield vectorstore, vectorstore.db[collection_name] + _empty_collections_and_delete_indexes(database, [collection_name]) + + +_COLLECTION_NAMING_CACHE = [] + + +@pytest.fixture +def collection_name(): + collection_id = random.randint(0, 100) + while collection_id in _COLLECTION_NAMING_CACHE: + collection_id = random.randint(0, 100) + _COLLECTION_NAMING_CACHE.append(collection_id) + + return f"{MONGODB_COLLECTION}_{collection_id}" + + +def test_create_collection(db, collection_name): + """ + def create_collection(collection_name: str, + overwrite: bool = False) -> Collection + Create a collection in the vector database. + - Case 1. if the collection does not exist, create the collection. + - Case 2. the collection exists, if overwrite is True, it will overwrite the collection. + - Case 3. the collection exists and overwrite is False return the existing collection. + - Case 4. the collection exists and overwrite is False and get_or_create is False, raise a ValueError + """ + collection_case_1 = db.create_collection( + collection_name=collection_name, + ) + assert collection_case_1.name == collection_name + + collection_case_2 = db.create_collection( + collection_name=collection_name, + overwrite=True, + ) + assert collection_case_2.name == collection_name + + collection_case_3 = db.create_collection( + collection_name=collection_name, + ) + assert collection_case_3.name == collection_name + + with pytest.raises(ValueError): + db.create_collection(collection_name=collection_name, overwrite=False, get_or_create=False) + + +def test_get_collection(db, collection_name): + with pytest.raises(ValueError): + db.get_collection() + + collection_created = db.create_collection(collection_name) + assert isinstance(collection_created, Collection) + assert collection_created.name == collection_name + + collection_got = db.get_collection(collection_name) + assert collection_got.name == collection_created.name + assert collection_got.name == db.active_collection.name + + +def test_delete_collection(db, collection_name): + assert collection_name not in db.list_collections() + collection = db.create_collection(collection_name) + assert collection_name in db.list_collections() + db.delete_collection(collection.name) + assert collection_name not in db.list_collections() + + +def test_insert_docs(db, collection_name, example_documents): + # Test that there's an active collection + with pytest.raises(ValueError) as exc: + db.insert_docs(example_documents) + assert "No collection is specified" in str(exc.value) + + # Test upsert + db.insert_docs(example_documents, collection_name, upsert=True) + + # Create a collection + db.delete_collection(collection_name) + collection = db.create_collection(collection_name) + + # Insert example documents + db.insert_docs(example_documents, collection_name=collection_name) + found = list(collection.find({})) + assert len(found) == len(example_documents) + # Check that documents have correct fields, including "_id" and "embedding" but not "id" + assert all([set(doc.keys()) == {"_id", "content", "metadata", "embedding"} for doc in found]) + # Check ids + assert {doc["_id"] for doc in found} == {1, "1", 2, "2"} + # Check embedding lengths + assert len(found[0]["embedding"]) == 384 + + +def test_update_docs(db_with_indexed_clxn, example_documents): + db, collection = db_with_indexed_clxn + # Use update_docs to insert new documents + db.update_docs(example_documents, collection.name, upsert=True) + # Test that no changes were made to example_documents + assert set(example_documents[0].keys()) == {"id", "content", "metadata"} + assert collection.count_documents({}) == len(example_documents) + found = list(collection.find({})) + # Check that documents have correct fields, including "_id" and "embedding" but not "id" + assert all([set(doc.keys()) == {"_id", "content", "metadata", "embedding"} for doc in found]) + assert all([isinstance(doc["embedding"][0], float) for doc in found]) + assert all([len(doc["embedding"]) == db.dimensions for doc in found]) + # Check ids + assert {doc["_id"] for doc in found} == {1, "1", 2, "2"} + + # Update an *existing* Document + updated_doc = Document(id=1, content="Cats are tough.", metadata={"a": 10}) + db.update_docs([updated_doc], collection.name) + assert collection.find_one({"_id": 1})["content"] == "Cats are tough." + + # Upsert a *new* Document + new_id = 3 + new_doc = Document(id=new_id, content="Cats are tough.") + db.update_docs([new_doc], collection.name, upsert=True) + assert collection.find_one({"_id": new_id})["content"] == "Cats are tough." + + # Attempting to use update to insert a new doc + # *without* setting upsert set to True + # is a no-op in MongoDB. # TODO Confirm behaviour and autogen's preference. + new_id = 4 + new_doc = Document(id=new_id, content="That is NOT a sandwich?") + db.update_docs([new_doc], collection.name) + assert collection.find_one({"_id": new_id}) is None + + +def test_delete_docs(db_with_indexed_clxn, example_documents): + db, clxn = db_with_indexed_clxn + # Insert example documents + db.insert_docs(example_documents, collection_name=clxn.name) + # Delete the 1s + db.delete_docs(ids=[1, "1"], collection_name=clxn.name) + # Confirm just the 2s remain + assert {2, "2"} == {doc["_id"] for doc in clxn.find({})} + + +def test_get_docs_by_ids(db_with_indexed_clxn, example_documents): + db, clxn = db_with_indexed_clxn + # Insert example documents + db.insert_docs(example_documents, collection_name=clxn.name) + + # Test without setting "include" kwarg + docs = db.get_docs_by_ids(ids=[2, "2"], collection_name=clxn.name) + assert len(docs) == 2 + assert all([doc["id"] in [2, "2"] for doc in docs]) + assert set(docs[0].keys()) == {"id", "content", "metadata"} + + # Test with include + docs = db.get_docs_by_ids(ids=[2], include=["content"], collection_name=clxn.name) + assert len(docs) == 1 + assert set(docs[0].keys()) == {"id", "content"} + + # Test with empty ids list + docs = db.get_docs_by_ids(ids=[], include=["content"], collection_name=clxn.name) + assert len(docs) == 0 + + # Test with empty ids list + docs = db.get_docs_by_ids(ids=None, include=["content"], collection_name=clxn.name) + assert len(docs) == 4 + + +def test_retrieve_docs_empty(db_with_indexed_clxn): + db, clxn = db_with_indexed_clxn + assert db.retrieve_docs(queries=["Cats"], collection_name=clxn.name, n_results=2) == [] + + +def test_retrieve_docs_populated_db_empty_query(db_with_indexed_clxn, example_documents): + db, clxn = db_with_indexed_clxn + db.insert_docs(example_documents, collection_name=clxn.name) + # Empty list of queries returns empty list of results + results = db.retrieve_docs(queries=[], collection_name=clxn.name, n_results=2) + assert results == [] + + +def test_retrieve_docs(db_with_indexed_clxn, example_documents): + """Begin testing Atlas Vector Search + NOTE: Indexing may take some time, so we must be patient on the first query. + We have the wait_until_index_ready flag to ensure index is created and ready + Immediately adding documents and then querying is only standard for testing + """ + db, clxn = db_with_indexed_clxn + # Insert example documents + db.insert_docs(example_documents, collection_name=clxn.name) + + n_results = 2 # Number of closest docs to return + + def results_ready(): + results = db.retrieve_docs(queries=["Cats"], collection_name=clxn.name, n_results=n_results) + return len(results[0]) == n_results + + _wait_for_predicate(results_ready, f"Failed to retrieve docs after waiting {TIMEOUT} seconds after each.") + + results = db.retrieve_docs(queries=["Cats"], collection_name=clxn.name, n_results=n_results) + assert {doc[0]["id"] for doc in results[0]} == {1, 2} + assert all(["embedding" not in doc[0] for doc in results[0]]) + + +def test_retrieve_docs_with_embedding(db_with_indexed_clxn, example_documents): + """Begin testing Atlas Vector Search + NOTE: Indexing may take some time, so we must be patient on the first query. + We have the wait_until_index_ready flag to ensure index is created and ready + Immediately adding documents and then querying is only standard for testing + """ + db, clxn = db_with_indexed_clxn + # Insert example documents + db.insert_docs(example_documents, collection_name=clxn.name) + + n_results = 2 # Number of closest docs to return + + def results_ready(): + results = db.retrieve_docs(queries=["Cats"], collection_name=clxn.name, n_results=n_results) + return len(results[0]) == n_results + + _wait_for_predicate(results_ready, f"Failed to retrieve docs after waiting {TIMEOUT} seconds after each.") + + results = db.retrieve_docs(queries=["Cats"], collection_name=clxn.name, n_results=n_results, include_embedding=True) + assert {doc[0]["id"] for doc in results[0]} == {1, 2} + assert all(["embedding" in doc[0] for doc in results[0]]) + + +def test_retrieve_docs_multiple_queries(db_with_indexed_clxn, example_documents): + db, clxn = db_with_indexed_clxn + # Insert example documents + db.insert_docs(example_documents, collection_name=clxn.name) + n_results = 2 # Number of closest docs to return + + queries = ["Some good pets", "What kind of Sandwich?"] + + def results_ready(): + results = db.retrieve_docs(queries=queries, collection_name=clxn.name, n_results=n_results) + return all([len(res) == n_results for res in results]) + + _wait_for_predicate(results_ready, f"Failed to retrieve docs after waiting {TIMEOUT} seconds after each.") + + results = db.retrieve_docs(queries=queries, collection_name=clxn.name, n_results=2) + + assert len(results) == len(queries) + assert all([len(res) == n_results for res in results]) + assert {doc[0]["id"] for doc in results[0]} == {1, 2} + assert {doc[0]["id"] for doc in results[1]} == {"1", "2"} + + +def test_retrieve_docs_with_threshold(db_with_indexed_clxn, example_documents): + db, clxn = db_with_indexed_clxn + # Insert example documents + db.insert_docs(example_documents, collection_name=clxn.name) + + n_results = 2 # Number of closest docs to return + queries = ["Cats"] + + def results_ready(): + results = db.retrieve_docs(queries=queries, collection_name=clxn.name, n_results=n_results) + return len(results[0]) == n_results + + _wait_for_predicate(results_ready, f"Failed to retrieve docs after waiting {TIMEOUT} seconds after each.") + + # Distance Threshold of .3 means that the score must be .7 or greater + # only one result should be that value + results = db.retrieve_docs(queries=queries, collection_name=clxn.name, n_results=n_results, distance_threshold=0.3) + assert len(results[0]) == 1 + assert all([doc[1] >= 0.7 for doc in results[0]]) + + +def test_wait_until_document_ready(collection_name, example_documents): + database = MongoClient(MONGODB_URI)[MONGODB_DATABASE] + _empty_collections_and_delete_indexes(database, [collection_name], wait=True) + try: + vectorstore = MongoDBAtlasVectorDB( + connection_string=MONGODB_URI, + database_name=MONGODB_DATABASE, + wait_until_index_ready=TIMEOUT, + collection_name=collection_name, + overwrite=True, + wait_until_document_ready=TIMEOUT, + ) + vectorstore.insert_docs(example_documents) + assert vectorstore.retrieve_docs(queries=["Cats"], n_results=4) + finally: + _empty_collections_and_delete_indexes(database, [collection_name]) diff --git a/test/agentchat/test_chats.py b/test/agentchat/test_chats.py index 480a28051b4b..896287de2786 100755 --- a/test/agentchat/test_chats.py +++ b/test/agentchat/test_chats.py @@ -10,6 +10,7 @@ import autogen from autogen import AssistantAgent, GroupChat, GroupChatManager, UserProxyAgent, filter_config, initiate_chats +from autogen.agentchat.chat import _post_process_carryover_item sys.path.append(os.path.join(os.path.dirname(__file__), "..")) from conftest import reason, skip_openai # noqa: E402 @@ -620,6 +621,15 @@ def my_writing_task(sender, recipient, context): print(chat_results[1].summary, chat_results[1].cost) +def test_post_process_carryover_item(): + gemini_carryover_item = {"content": "How can I help you?", "role": "model"} + assert ( + _post_process_carryover_item(gemini_carryover_item) == gemini_carryover_item["content"] + ), "Incorrect carryover postprocessing" + carryover_item = "How can I help you?" + assert _post_process_carryover_item(carryover_item) == carryover_item, "Incorrect carryover postprocessing" + + if __name__ == "__main__": test_chats() # test_chats_general() @@ -628,3 +638,4 @@ def my_writing_task(sender, recipient, context): # test_chats_w_func() # test_chat_messages_for_summary() # test_udf_message_in_chats() + test_post_process_carryover_item() diff --git a/test/agentchat/test_conversable_agent.py b/test/agentchat/test_conversable_agent.py index 0a8d1daebc80..c0d37a7bd7a1 100755 --- a/test/agentchat/test_conversable_agent.py +++ b/test/agentchat/test_conversable_agent.py @@ -25,7 +25,13 @@ here = os.path.abspath(os.path.dirname(__file__)) -gpt4_config_list = [{"model": "gpt-4"}, {"model": "gpt-4-turbo"}, {"model": "gpt-4-32k"}, {"model": "gpt-4o"}] +gpt4_config_list = [ + {"model": "gpt-4"}, + {"model": "gpt-4-turbo"}, + {"model": "gpt-4-32k"}, + {"model": "gpt-4o"}, + {"model": "gpt-4o-mini"}, +] @pytest.fixture @@ -1463,6 +1469,58 @@ def sample_function(): ) +def test_process_gemini_carryover(): + dummy_agent_1 = ConversableAgent(name="dummy_agent_1", llm_config=False, human_input_mode="NEVER") + content = "I am your assistant." + carryover_content = "How can I help you?" + gemini_kwargs = {"carryover": [{"content": carryover_content}]} + proc_content = dummy_agent_1._process_carryover(content=content, kwargs=gemini_kwargs) + assert proc_content == content + "\nContext: \n" + carryover_content, "Incorrect carryover processing" + + +def test_process_carryover(): + dummy_agent_1 = ConversableAgent(name="dummy_agent_1", llm_config=False, human_input_mode="NEVER") + content = "I am your assistant." + carryover = "How can I help you?" + kwargs = {"carryover": carryover} + proc_content = dummy_agent_1._process_carryover(content=content, kwargs=kwargs) + assert proc_content == content + "\nContext: \n" + carryover, "Incorrect carryover processing" + + carryover_l = ["How can I help you?"] + kwargs = {"carryover": carryover_l} + proc_content = dummy_agent_1._process_carryover(content=content, kwargs=kwargs) + assert proc_content == content + "\nContext: \n" + carryover, "Incorrect carryover processing" + + proc_content_empty_carryover = dummy_agent_1._process_carryover(content=content, kwargs={"carryover": None}) + assert proc_content_empty_carryover == content, "Incorrect carryover processing" + + +def test_handle_gemini_carryover(): + dummy_agent_1 = ConversableAgent(name="dummy_agent_1", llm_config=False, human_input_mode="NEVER") + content = "I am your assistant" + carryover_content = "How can I help you?" + gemini_kwargs = {"carryover": [{"content": carryover_content}]} + proc_content = dummy_agent_1._handle_carryover(message=content, kwargs=gemini_kwargs) + assert proc_content == content + "\nContext: \n" + carryover_content, "Incorrect carryover processing" + + +def test_handle_carryover(): + dummy_agent_1 = ConversableAgent(name="dummy_agent_1", llm_config=False, human_input_mode="NEVER") + content = "I am your assistant." + carryover = "How can I help you?" + kwargs = {"carryover": carryover} + proc_content = dummy_agent_1._handle_carryover(message=content, kwargs=kwargs) + assert proc_content == content + "\nContext: \n" + carryover, "Incorrect carryover processing" + + carryover_l = ["How can I help you?"] + kwargs = {"carryover": carryover_l} + proc_content = dummy_agent_1._handle_carryover(message=content, kwargs=kwargs) + assert proc_content == content + "\nContext: \n" + carryover, "Incorrect carryover processing" + + proc_content_empty_carryover = dummy_agent_1._handle_carryover(message=content, kwargs={"carryover": None}) + assert proc_content_empty_carryover == content, "Incorrect carryover processing" + + if __name__ == "__main__": # test_trigger() # test_context() @@ -1473,6 +1531,10 @@ def sample_function(): # test_max_turn() # test_process_before_send() # test_message_func() + test_summary() test_adding_duplicate_function_warning() # test_function_registration_e2e_sync() + + test_process_gemini_carryover() + test_process_carryover() diff --git a/test/agentchat/test_function_call.py b/test/agentchat/test_function_call.py index d3e174949b4b..0f1d4f909426 100755 --- a/test/agentchat/test_function_call.py +++ b/test/agentchat/test_function_call.py @@ -213,7 +213,7 @@ def test_update_function(): config_list_gpt4 = autogen.config_list_from_json( OAI_CONFIG_LIST, filter_dict={ - "tags": ["gpt-4", "gpt-4-32k", "gpt-4o"], + "tags": ["gpt-4", "gpt-4-32k", "gpt-4o", "gpt-4o-mini"], }, file_location=KEY_LOC, ) diff --git a/test/oai/test_gemini.py b/test/oai/test_gemini.py index 35f35620b8b9..61fdbe6d735a 100644 --- a/test/oai/test_gemini.py +++ b/test/oai/test_gemini.py @@ -1,15 +1,26 @@ +import os from unittest.mock import MagicMock, patch import pytest try: + import google.auth from google.api_core.exceptions import InternalServerError + from google.auth.credentials import Credentials + from google.cloud.aiplatform.initializer import global_config as vertexai_global_config + from vertexai.generative_models import HarmBlockThreshold as VertexAIHarmBlockThreshold + from vertexai.generative_models import HarmCategory as VertexAIHarmCategory + from vertexai.generative_models import SafetySetting as VertexAISafetySetting from autogen.oai.gemini import GeminiClient skip = False except ImportError: GeminiClient = object + VertexAIHarmBlockThreshold = object + VertexAIHarmCategory = object + VertexAISafetySetting = object + vertexai_global_config = object InternalServerError = object skip = True @@ -30,7 +41,24 @@ def __init__(self, text, choices, usage, cost, model): @pytest.fixture def gemini_client(): - return GeminiClient(api_key="fake_api_key") + system_message = [ + "You are a helpful AI assistant.", + ] + return GeminiClient(api_key="fake_api_key", system_message=system_message) + + +@pytest.fixture +def gemini_google_auth_default_client(): + system_message = [ + "You are a helpful AI assistant.", + ] + return GeminiClient(system_message=system_message) + + +@pytest.fixture +def gemini_client_with_credentials(): + mock_credentials = MagicMock(Credentials) + return GeminiClient(credentials=mock_credentials) # Test compute location initialization and configuration @@ -42,9 +70,13 @@ def test_compute_location_initialization(): ) # Should raise an AssertionError due to specifying API key and compute location -@pytest.fixture -def gemini_google_auth_default_client(): - return GeminiClient() +# Test project initialization and configuration +@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed") +def test_project_initialization(): + with pytest.raises(AssertionError): + GeminiClient( + api_key="fake_api_key", project_id="fake-project-id" + ) # Should raise an AssertionError due to specifying API key and compute location @pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed") @@ -52,6 +84,23 @@ def test_valid_initialization(gemini_client): assert gemini_client.api_key == "fake_api_key", "API Key should be correctly set" +@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed") +def test_google_application_credentials_initialization(): + GeminiClient(google_application_credentials="credentials.json", project_id="fake-project-id") + assert ( + os.environ["GOOGLE_APPLICATION_CREDENTIALS"] == "credentials.json" + ), "Incorrect Google Application Credentials initialization" + + +@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed") +def test_vertexai_initialization(): + mock_credentials = MagicMock(Credentials) + GeminiClient(credentials=mock_credentials, project_id="fake-project-id", location="us-west1") + assert vertexai_global_config.location == "us-west1", "Incorrect VertexAI location initialization" + assert vertexai_global_config.project == "fake-project-id", "Incorrect VertexAI project initialization" + assert vertexai_global_config.credentials == mock_credentials, "Incorrect VertexAI credentials initialization" + + @pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed") def test_gemini_message_handling(gemini_client): messages = [ @@ -94,6 +143,113 @@ def test_gemini_message_handling(gemini_client): assert converted_messages[i].parts[j].text == part, "Incorrect mapped message text" +@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed") +def test_gemini_empty_message_handling(gemini_client): + messages = [ + {"role": "system", "content": "You are my personal assistant."}, + {"role": "model", "content": "How can I help you?"}, + {"role": "user", "content": ""}, + { + "role": "model", + "content": "Please provide me with some context or a request! I need more information to assist you.", + }, + {"role": "user", "content": ""}, + ] + + converted_messages = gemini_client._oai_messages_to_gemini_messages(messages) + assert converted_messages[-3].parts[0].text == "empty", "Empty message is not converted to 'empty' correctly" + assert converted_messages[-1].parts[0].text == "empty", "Empty message is not converted to 'empty' correctly" + + +@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed") +def test_vertexai_safety_setting_conversion(gemini_client): + safety_settings = [ + {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH"}, + {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_ONLY_HIGH"}, + {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_ONLY_HIGH"}, + {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_ONLY_HIGH"}, + ] + converted_safety_settings = GeminiClient._to_vertexai_safety_settings(safety_settings) + harm_categories = [ + VertexAIHarmCategory.HARM_CATEGORY_HARASSMENT, + VertexAIHarmCategory.HARM_CATEGORY_HATE_SPEECH, + VertexAIHarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + VertexAIHarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + ] + expected_safety_settings = [ + VertexAISafetySetting(category=category, threshold=VertexAIHarmBlockThreshold.BLOCK_ONLY_HIGH) + for category in harm_categories + ] + + def compare_safety_settings(converted_safety_settings, expected_safety_settings): + for i, expected_setting in enumerate(expected_safety_settings): + converted_setting = converted_safety_settings[i] + yield expected_setting.to_dict() == converted_setting.to_dict() + + assert len(converted_safety_settings) == len( + expected_safety_settings + ), "The length of the safety settings is incorrect" + settings_comparison = compare_safety_settings(converted_safety_settings, expected_safety_settings) + assert all(settings_comparison), "Converted safety settings are incorrect" + + +@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed") +def test_vertexai_default_safety_settings_dict(gemini_client): + safety_settings = { + VertexAIHarmCategory.HARM_CATEGORY_HARASSMENT: VertexAIHarmBlockThreshold.BLOCK_ONLY_HIGH, + VertexAIHarmCategory.HARM_CATEGORY_HATE_SPEECH: VertexAIHarmBlockThreshold.BLOCK_ONLY_HIGH, + VertexAIHarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: VertexAIHarmBlockThreshold.BLOCK_ONLY_HIGH, + VertexAIHarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: VertexAIHarmBlockThreshold.BLOCK_ONLY_HIGH, + } + converted_safety_settings = GeminiClient._to_vertexai_safety_settings(safety_settings) + + expected_safety_settings = { + category: VertexAIHarmBlockThreshold.BLOCK_ONLY_HIGH for category in safety_settings.keys() + } + + def compare_safety_settings(converted_safety_settings, expected_safety_settings): + for expected_setting_key in expected_safety_settings.keys(): + expected_setting = expected_safety_settings[expected_setting_key] + converted_setting = converted_safety_settings[expected_setting_key] + yield expected_setting == converted_setting + + assert len(converted_safety_settings) == len( + expected_safety_settings + ), "The length of the safety settings is incorrect" + settings_comparison = compare_safety_settings(converted_safety_settings, expected_safety_settings) + assert all(settings_comparison), "Converted safety settings are incorrect" + + +@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed") +def test_vertexai_safety_setting_list(gemini_client): + harm_categories = [ + VertexAIHarmCategory.HARM_CATEGORY_HARASSMENT, + VertexAIHarmCategory.HARM_CATEGORY_HATE_SPEECH, + VertexAIHarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + VertexAIHarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + ] + + expected_safety_settings = safety_settings = [ + VertexAISafetySetting(category=category, threshold=VertexAIHarmBlockThreshold.BLOCK_ONLY_HIGH) + for category in harm_categories + ] + + print(safety_settings) + + converted_safety_settings = GeminiClient._to_vertexai_safety_settings(safety_settings) + + def compare_safety_settings(converted_safety_settings, expected_safety_settings): + for i, expected_setting in enumerate(expected_safety_settings): + converted_setting = converted_safety_settings[i] + yield expected_setting.to_dict() == converted_setting.to_dict() + + assert len(converted_safety_settings) == len( + expected_safety_settings + ), "The length of the safety settings is incorrect" + settings_comparison = compare_safety_settings(converted_safety_settings, expected_safety_settings) + assert all(settings_comparison), "Converted safety settings are incorrect" + + # Test error handling @patch("autogen.oai.gemini.genai") @pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed") @@ -150,6 +306,62 @@ def test_create_response(mock_configure, mock_generative_model, gemini_client): assert response.choices[0].message.content == "Example response", "Response content should match expected output" +@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed") +@patch("autogen.oai.gemini.GenerativeModel") +@patch("autogen.oai.gemini.vertexai.init") +def test_vertexai_create_response(mock_init, mock_generative_model, gemini_client_with_credentials): + # Mock the genai model configuration and creation process + mock_chat = MagicMock() + mock_model = MagicMock() + mock_init.return_value = None + mock_generative_model.return_value = mock_model + mock_model.start_chat.return_value = mock_chat + + # Set up a mock for the chat history item access and the text attribute return + mock_history_part = MagicMock() + mock_history_part.text = "Example response" + mock_chat.history.__getitem__.return_value.parts.__getitem__.return_value = mock_history_part + + # Setup the mock to return a mocked chat response + mock_chat.send_message.return_value = MagicMock(history=[MagicMock(parts=[MagicMock(text="Example response")])]) + + # Call the create method + response = gemini_client_with_credentials.create( + {"model": "gemini-pro", "messages": [{"content": "Hello", "role": "user"}], "stream": False} + ) + + # Assertions to check if response is structured as expected + assert response.choices[0].message.content == "Example response", "Response content should match expected output" + + +@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed") +@patch("autogen.oai.gemini.GenerativeModel") +@patch("autogen.oai.gemini.vertexai.init") +def test_vertexai_default_auth_create_response(mock_init, mock_generative_model, gemini_google_auth_default_client): + # Mock the genai model configuration and creation process + mock_chat = MagicMock() + mock_model = MagicMock() + mock_init.return_value = None + mock_generative_model.return_value = mock_model + mock_model.start_chat.return_value = mock_chat + + # Set up a mock for the chat history item access and the text attribute return + mock_history_part = MagicMock() + mock_history_part.text = "Example response" + mock_chat.history.__getitem__.return_value.parts.__getitem__.return_value = mock_history_part + + # Setup the mock to return a mocked chat response + mock_chat.send_message.return_value = MagicMock(history=[MagicMock(parts=[MagicMock(text="Example response")])]) + + # Call the create method + response = gemini_google_auth_default_client.create( + {"model": "gemini-pro", "messages": [{"content": "Hello", "role": "user"}], "stream": False} + ) + + # Assertions to check if response is structured as expected + assert response.choices[0].message.content == "Example response", "Response content should match expected output" + + @pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed") @patch("autogen.oai.gemini.genai.GenerativeModel") @patch("autogen.oai.gemini.genai.configure") @@ -195,3 +407,49 @@ def test_create_vision_model_response(mock_configure, mock_generative_model, gem assert ( response.choices[0].message.content == "Vision model output" ), "Response content should match expected output from vision model" + + +@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed") +@patch("autogen.oai.gemini.GenerativeModel") +@patch("autogen.oai.gemini.vertexai.init") +def test_vertexai_create_vision_model_response(mock_init, mock_generative_model, gemini_google_auth_default_client): + # Mock the genai model configuration and creation process + mock_model = MagicMock() + mock_init.return_value = None + mock_generative_model.return_value = mock_model + + # Set up a mock to simulate the vision model behavior + mock_vision_response = MagicMock() + mock_vision_part = MagicMock(text="Vision model output") + + # Setting up the chain of return values for vision model response + mock_vision_response.candidates.__getitem__.return_value.content.parts.__getitem__.return_value = mock_vision_part + + mock_model.generate_content.return_value = mock_vision_response + + # Call the create method with vision model parameters + response = gemini_google_auth_default_client.create( + { + "model": "gemini-pro-vision", # Vision model name + "messages": [ + { + "content": [ + {"type": "text", "text": "Let's play a game."}, + { + "type": "image_url", + "image_url": { + "url": "" + }, + }, + ], + "role": "user", + } + ], # Assuming a simple content input for vision + "stream": False, + } + ) + + # Assertions to check if response is structured as expected + assert ( + response.choices[0].message.content == "Vision model output" + ), "Response content should match expected output from vision model" diff --git a/test/oai/test_utils.py b/test/oai/test_utils.py index 96956d07d90d..fd81d3f9f548 100755 --- a/test/oai/test_utils.py +++ b/test/oai/test_utils.py @@ -55,8 +55,8 @@ { "model": "gpt-35-turbo-v0301", "tags": ["gpt-3.5-turbo", "gpt35_turbo"], - "api_key": "111113fc7e8a46419bfac511bb301111", - "base_url": "https://1111.openai.azure.com", + "api_key": "Your Azure OAI API Key", + "base_url": "https://deployment_name.openai.azure.com", "api_type": "azure", "api_version": "2024-02-01" }, diff --git a/website/docs/Use-Cases/agent_chat.md b/website/docs/Use-Cases/agent_chat.md index c55b0d29d5d6..59156c0eb046 100644 --- a/website/docs/Use-Cases/agent_chat.md +++ b/website/docs/Use-Cases/agent_chat.md @@ -18,7 +18,7 @@ designed to solve tasks through inter-agent conversations. Specifically, the age The figure below shows the built-in agents in AutoGen. ![Agent Chat Example](images/autogen_agents.png) -We have designed a generic [`ConversableAgent`](../reference/agentchat/conversable_agent#conversableagent-objects) +We have designed a generic [`ConversableAgent`](../reference/agentchat/conversable_agent.md#conversableagent-objects) class for Agents that are capable of conversing with each other through the exchange of messages to jointly finish a task. An agent can communicate with other agents and perform actions. Different agents can differ in what actions they perform after receiving messages. Two representative subclasses are [`AssistantAgent`](../reference/agentchat/assistant_agent.md#assistantagent-objects) and [`UserProxyAgent`](../reference/agentchat/user_proxy_agent.md#userproxyagent-objects) - The [`AssistantAgent`](../reference/agentchat/assistant_agent.md#assistantagent-objects) is designed to act as an AI assistant, using LLMs by default but not requiring human input or code execution. It could write Python code (in a Python coding block) for a user to execute when a message (typically a description of a task that needs to be solved) is received. Under the hood, the Python code is written by LLM (e.g., GPT-4). It can also receive the execution results and suggest corrections or bug fixes. Its behavior can be altered by passing a new system message. The LLM [inference](#enhanced-inference) configuration can be configured via [`llm_config`]. diff --git a/website/docs/topics/non-openai-models/cloud-gemini.ipynb b/website/docs/topics/non-openai-models/cloud-gemini.ipynb index da773e0d4472..70dc808df616 100644 --- a/website/docs/topics/non-openai-models/cloud-gemini.ipynb +++ b/website/docs/topics/non-openai-models/cloud-gemini.ipynb @@ -64,7 +64,7 @@ " },\n", " {\n", " \"model\": \"gemini-1.5-pro\",\n", - " \"project\": \"your-awesome-google-cloud-project-id\",\n", + " \"project_id\": \"your-awesome-google-cloud-project-id\",\n", " \"location\": \"us-west1\",\n", " \"google_application_credentials\": \"your-google-service-account-key.json\"\n", " },\n", diff --git a/website/docs/topics/non-openai-models/cloud-gemini_vertexai.ipynb b/website/docs/topics/non-openai-models/cloud-gemini_vertexai.ipynb index eaec2b72b268..e618966dc6cc 100644 --- a/website/docs/topics/non-openai-models/cloud-gemini_vertexai.ipynb +++ b/website/docs/topics/non-openai-models/cloud-gemini_vertexai.ipynb @@ -14,9 +14,14 @@ "\n", "## Requirements\n", "\n", - "AutoGen requires `Python>=3.8`. To run this notebook example, please install with the [gemini] option:\n", + "Install AutoGen with Gemini features:\n", "```bash\n", - "pip install \"pyautogen[gemini]\"\n", + "pip install pyautogen[gemini]\n", + "```\n", + "\n", + "### Install other Dependencies of this Notebook\n", + "```bash\n", + "pip install chromadb markdownify pypdf\n", "```\n", "\n", "### Google Cloud Account\n", @@ -66,41 +71,6 @@ " * Please consider restricting the permissions on the key file. For example, you could run `chmod 600 autogen-with-gemini-service-account-key.json` if your keyfile is called autogen-with-gemini-service-account-key.json." ] }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "execution": { - "iopub.execute_input": "2023-02-13T23:40:52.317406Z", - "iopub.status.busy": "2023-02-13T23:40:52.316561Z", - "iopub.status.idle": "2023-02-13T23:40:52.321193Z", - "shell.execute_reply": "2023-02-13T23:40:52.320628Z" - } - }, - "outputs": [], - "source": [ - "# %pip install \"pyautogen[gemini]\"" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "execution": { - "iopub.execute_input": "2023-02-13T23:40:54.634335Z", - "iopub.status.busy": "2023-02-13T23:40:54.633929Z", - "iopub.status.idle": "2023-02-13T23:40:56.105700Z", - "shell.execute_reply": "2023-02-13T23:40:56.105085Z" - }, - "slideshow": { - "slide_type": "slide" - } - }, - "outputs": [], - "source": [ - "import autogen" - ] - }, { "attachments": {}, "cell_type": "markdown", @@ -109,7 +79,9 @@ "### Configure Authentication\n", "\n", "Authentication happens using standard [Google Cloud authentication methods](https://cloud.google.com/docs/authentication),
which means\n", - "that either an already active session can be reused, or by specifying the Google application credentials of a service account.\n", + "that either an already active session can be reused, or by specifying the Google application credentials of a service account.

\n", + "Additionally, AutoGen also supports authentication using `Credentials` objects in Python with the [google-auth library](https://google-auth.readthedocs.io/), which enables even more flexibility.
\n", + "For example, we can even use impersonated credentials.\n", "\n", "#### Use Service Account Keyfile\n", "\n", @@ -121,7 +93,13 @@ "\n", "If you are using [Cloud Shell](https://shell.cloud.google.com/cloudshell) or [Cloud Shell editor](https://shell.cloud.google.com/cloudshell/editor) in Google Cloud,
then you are already authenticated. If you have the Google Cloud SDK installed locally,
then you can login by running `gcloud auth login` in the command line. \n", "\n", - "Detailed instructions for installing the Google Cloud SDK can be found [here](https://cloud.google.com/sdk/docs/install)." + "Detailed instructions for installing the Google Cloud SDK can be found [here](https://cloud.google.com/sdk/docs/install).\n", + "\n", + "#### Authentication with the Google Auth Library for Python\n", + "\n", + "The google-auth library supports a wide range of authentication scenarios, and you can simply pass a previously created `Credentials` object to the `llm_config`.
\n", + "The [official documentation](https://google-auth.readthedocs.io/) of the Python package provides a detailed overview of the supported methods and usage examples.
\n", + "If you are already authenticated, like in [Cloud Shell](https://shell.cloud.google.com/cloudshell), or after running the `gcloud auth login` command in a CLI, then the `google.auth.default()` Python method will automatically return your currently active credentials." ] }, { @@ -147,7 +125,7 @@ " {\n", " \"model\": \"gemini-1.5-pro\",\n", " \"api_type\": \"google\",\n", - " \"project\": \"autogen-with-gemini\",\n", + " \"project_id\": \"autogen-with-gemini\",\n", " \"location\": \"us-west1\",\n", " \"google_application_credentials\": \"autogen-with-gemini-service-account-key.json\"\n", " },\n", @@ -172,17 +150,11 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ - "from vertexai.generative_models import (\n", - " GenerationConfig,\n", - " GenerativeModel,\n", - " HarmBlockThreshold,\n", - " HarmCategory,\n", - " Part,\n", - ")\n", + "from vertexai.generative_models import HarmBlockThreshold, HarmCategory\n", "\n", "safety_settings = {\n", " HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,\n", @@ -194,7 +166,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -205,6 +177,7 @@ "from PIL import Image\n", "from termcolor import colored\n", "\n", + "import autogen\n", "from autogen import Agent, AssistantAgent, ConversableAgent, UserProxyAgent\n", "from autogen.agentchat.contrib.img_utils import _to_pil, get_image_data\n", "from autogen.agentchat.contrib.multimodal_conversable_agent import MultimodalConversableAgent\n", @@ -215,7 +188,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -242,7 +215,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -252,33 +225,28 @@ "\u001b[33muser_proxy\u001b[0m (to assistant):\n", "\n", "\n", - " Compute the integral of the function f(x)=x^2 on the interval 0 to 1 using a Python script, \n", - " which returns the value of the definite integral.\n", + " Compute the integral of the function f(x)=x^2 on the interval 0 to 1 using a Python script,\n", + " which returns the value of the definite integral\n", "\n", - "--------------------------------------------------------------------------------\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ + "--------------------------------------------------------------------------------\n", "\u001b[33massistant\u001b[0m (to user_proxy):\n", "\n", "Plan:\n", - "1. (Code) Use Python's numerical integration library to compute the integral.\n", - "2. (Language) Output the result.\n", + "1. (code) Use Python's `scipy.integrate.quad` function to compute the integral. \n", "\n", "```python\n", "# filename: integral.py\n", - "import scipy.integrate\n", + "from scipy.integrate import quad\n", "\n", - "f = lambda x: x**2\n", - "result, error = scipy.integrate.quad(f, 0, 1)\n", + "def f(x):\n", + " return x**2\n", + "\n", + "result, error = quad(f, 0, 1)\n", "\n", "print(f\"The definite integral of x^2 from 0 to 1 is: {result}\")\n", "```\n", "\n", - "Let me know when you have executed the code. \n", + "Let me know when you have executed this code. \n", "\n", "\n", "--------------------------------------------------------------------------------\n", @@ -294,13 +262,11 @@ "--------------------------------------------------------------------------------\n", "\u001b[33massistant\u001b[0m (to user_proxy):\n", "\n", - "The code executed successfully and returned the value of the definite integral as approximately 0.33333333333333337. \n", - "\n", - "This aligns with the analytical solution:\n", + "The script executed successfully and returned the definite integral's value as approximately 0.33333333333333337. \n", "\n", - "The integral of x^2 is (x^3)/3. Evaluating this from 0 to 1 gives us (1^3)/3 - (0^3)/3 = 1/3 = 0.33333...\n", + "This aligns with the analytical solution. The indefinite integral of x^2 is (x^3)/3. Evaluating this from 0 to 1 gives us (1^3)/3 - (0^3)/3 = 1/3 = 0.33333...\n", "\n", - "Therefore, the answer is verified to be correct.\n", + "Therefore, the script successfully computed the integral of x^2 from 0 to 1.\n", "\n", "TERMINATE\n", "\n", @@ -325,7 +291,7 @@ " assistant,\n", " message=\"\"\"\n", " Compute the integral of the function f(x)=x^2 on the interval 0 to 1 using a Python script,\n", - " which returns the value of the definite integral.\"\"\",\n", + " which returns the value of the definite integral\"\"\",\n", ")" ] }, @@ -334,12 +300,52 @@ "metadata": {}, "source": [ "## Example with Gemini Multimodal\n", - "Authentication is the same for vision models as for the text based Gemini models" + "Authentication is the same for vision models as for the text based Gemini models.
\n", + "In this example an object of type `Credentials` will be supplied in order to authenticate.
\n", + "Here, we will use the google application default credentials, so make sure to run the following commands if you are not yet authenticated:\n", + "```bash\n", + "export GOOGLE_APPLICATION_CREDENTIALS=autogen-with-gemini-service-account-key.json\n", + "gcloud auth application-default login\n", + "gcloud config set project autogen-with-gemini\n", + "```\n", + "The `GOOGLE_APPLICATION_CREDENTIALS` environment variable is a path to our service account JSON keyfile, as described in the [Use Service Account Keyfile](#Use Service Account Keyfile) section above.
\n", + "We also need to set the Google cloud project, which is `autogen-with-gemini` in this example.

\n", + "\n", + "Note, we could also run `gcloud auth login` in case we wish to use our personal Google account instead of a service account.\n", + "In this case we need to run the following commands:\n", + "```bash\n", + "gcloud auth login\n", + "gcloud config set project autogen-with-gemini\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "import google.auth\n", + "\n", + "scopes = [\"https://www.googleapis.com/auth/cloud-platform\"]\n", + "\n", + "credentials, project_id = google.auth.default(scopes)\n", + "\n", + "gemini_vision_config = [\n", + " {\n", + " \"model\": \"gemini-pro-vision\",\n", + " \"api_type\": \"google\",\n", + " \"project_id\": project_id,\n", + " \"credentials\": credentials,\n", + " \"location\": \"us-west1\",\n", + " \"safety_settings\": safety_settings,\n", + " }\n", + "]" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -356,7 +362,7 @@ ">>>>>>>> USING AUTO REPLY...\u001b[0m\n", "\u001b[33mGemini Vision\u001b[0m (to user_proxy):\n", "\n", - " The image shows a taxonomy of different types of conversational agents. The taxonomy is based on two dimensions: agent customization and flexible conversation patterns. Agent customization refers to the ability of the agent to be tailored to the individual user. Flexible conversation patterns refer to the ability of the agent to engage in different types of conversations, such as joint chat and hierarchical chat.\n", + " The image describes a conversational agent that is able to have a conversation with a human user. The agent can be customized to the user's preferences. The conversation can be in form of a joint chat or hierarchical chat.\n", "\n", "--------------------------------------------------------------------------------\n" ] @@ -364,17 +370,17 @@ { "data": { "text/plain": [ - "ChatResult(chat_id=None, chat_history=[{'content': 'Describe what is in this image?\\n.', 'role': 'assistant'}, {'content': ' The image shows a taxonomy of different types of conversational agents. The taxonomy is based on two dimensions: agent customization and flexible conversation patterns. Agent customization refers to the ability of the agent to be tailored to the individual user. Flexible conversation patterns refer to the ability of the agent to engage in different types of conversations, such as joint chat and hierarchical chat.', 'role': 'user'}], summary=' The image shows a taxonomy of different types of conversational agents. The taxonomy is based on two dimensions: agent customization and flexible conversation patterns. Agent customization refers to the ability of the agent to be tailored to the individual user. Flexible conversation patterns refer to the ability of the agent to engage in different types of conversations, such as joint chat and hierarchical chat.', cost={'usage_including_cached_inference': {'total_cost': 0.0002385, 'gemini-pro-vision': {'cost': 0.0002385, 'prompt_tokens': 267, 'completion_tokens': 70, 'total_tokens': 337}}, 'usage_excluding_cached_inference': {'total_cost': 0.0002385, 'gemini-pro-vision': {'cost': 0.0002385, 'prompt_tokens': 267, 'completion_tokens': 70, 'total_tokens': 337}}}, human_input=[])" + "ChatResult(chat_id=None, chat_history=[{'content': 'Describe what is in this image?\\n.', 'role': 'assistant'}, {'content': \" The image describes a conversational agent that is able to have a conversation with a human user. The agent can be customized to the user's preferences. The conversation can be in form of a joint chat or hierarchical chat.\", 'role': 'user'}], summary=\" The image describes a conversational agent that is able to have a conversation with a human user. The agent can be customized to the user's preferences. The conversation can be in form of a joint chat or hierarchical chat.\", cost={'usage_including_cached_inference': {'total_cost': 0.0001995, 'gemini-pro-vision': {'cost': 0.0001995, 'prompt_tokens': 267, 'completion_tokens': 44, 'total_tokens': 311}}, 'usage_excluding_cached_inference': {'total_cost': 0}}, human_input=[])" ] }, - "execution_count": 7, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "image_agent = MultimodalConversableAgent(\n", - " \"Gemini Vision\", llm_config={\"config_list\": config_list_gemini_vision, \"seed\": seed}, max_consecutive_auto_reply=1\n", + " \"Gemini Vision\", llm_config={\"config_list\": gemini_vision_config, \"seed\": seed}, max_consecutive_auto_reply=1\n", ")\n", "\n", "user_proxy = UserProxyAgent(\"user_proxy\", human_input_mode=\"NEVER\", max_consecutive_auto_reply=0)\n", @@ -415,7 +421,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.10.14" }, "vscode": { "interpreter": {