diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-vertex/.gitignore b/llama-index-integrations/embeddings/llama-index-embeddings-vertex/.gitignore new file mode 100644 index 0000000000000..990c18de22908 --- /dev/null +++ b/llama-index-integrations/embeddings/llama-index-embeddings-vertex/.gitignore @@ -0,0 +1,153 @@ +llama_index/_static +.DS_Store +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +bin/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +etc/ +include/ +lib/ +lib64/ +parts/ +sdist/ +share/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +.ruff_cache + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints +notebooks/ + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ +pyvenv.cfg + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# Jetbrains +.idea +modules/ +*.swp + +# VsCode +.vscode + +# pipenv +Pipfile +Pipfile.lock + +# pyright +pyrightconfig.json diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-vertex/BUILD b/llama-index-integrations/embeddings/llama-index-embeddings-vertex/BUILD new file mode 100644 index 0000000000000..84f2657a9f879 --- /dev/null +++ b/llama-index-integrations/embeddings/llama-index-embeddings-vertex/BUILD @@ -0,0 +1,3 @@ +poetry_requirements( + name="poetry", module_mapping={"google-cloud-aiplatform": ["vertexai"]} +) diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-vertex/Makefile b/llama-index-integrations/embeddings/llama-index-embeddings-vertex/Makefile new file mode 100644 index 0000000000000..b9eab05aa3706 --- /dev/null +++ b/llama-index-integrations/embeddings/llama-index-embeddings-vertex/Makefile @@ -0,0 +1,17 @@ +GIT_ROOT ?= $(shell git rev-parse --show-toplevel) + +help: ## Show all Makefile targets. + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[33m%-30s\033[0m %s\n", $$1, $$2}' + +format: ## Run code autoformatters (black). + pre-commit install + git ls-files | xargs pre-commit run black --files + +lint: ## Run linters: pre-commit (black, ruff, codespell) and mypy + pre-commit install && git ls-files | xargs pre-commit run --show-diff-on-failure --files + +test: ## Run tests via pytest. + pytest tests + +watch-docs: ## Build and watch documentation. + sphinx-autobuild docs/ docs/_build/html --open-browser --watch $(GIT_ROOT)/llama_index/ diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-vertex/README.md b/llama-index-integrations/embeddings/llama-index-embeddings-vertex/README.md new file mode 100644 index 0000000000000..d625c2907cffb --- /dev/null +++ b/llama-index-integrations/embeddings/llama-index-embeddings-vertex/README.md @@ -0,0 +1,14 @@ +# LlamaIndex Embeddings Integration: Vertex + +Implements Vertex AI Embeddings Models: + +| Model | Release Date | +| ------------------------------------ | ----------------- | +| textembedding-gecko@003 | December 12, 2023 | +| textembedding-gecko@002 | November 2, 2023 | +| textembedding-gecko-multilingual@001 | November 2, 2023 | +| textembedding-gecko@001 | June 7, 2023 | +| multimodalembedding | | + +**Note**: Currently Vertex AI does not support async on `multimodalembedding`. +Otherwise, `VertexTextEmbedding` supports async interface. diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-vertex/examples/multimodal_embedding.ipynb b/llama-index-integrations/embeddings/llama-index-embeddings-vertex/examples/multimodal_embedding.ipynb new file mode 100644 index 0000000000000..7be7dff32873a --- /dev/null +++ b/llama-index-integrations/embeddings/llama-index-embeddings-vertex/examples/multimodal_embedding.ipynb @@ -0,0 +1,202 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "3af200b75c4bd924", + "metadata": {}, + "source": [ + "# Vertex AI Multimodal Embedding\n", + "Uses APPLICATION_DEFAULT_CREDENTIALS if no credentials is specified. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7b43f20b2f09ff70", + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index.embeddings.vertex import VertexMultiModalEmbedding\n", + "\n", + "embed_model = VertexMultiModalEmbedding(\n", + " project=\"speedy-atom-413006\", location=\"us-central1\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a10d4efa47801541", + "metadata": {}, + "outputs": [], + "source": [ + "image_url = \"https://upload.wikimedia.org/wikipedia/commons/4/43/Cute_dog.jpg\"" + ] + }, + { + "cell_type": "markdown", + "id": "6e29951621ec9acc", + "metadata": {}, + "source": [ + "Download this image to `data/test-image.jpg`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "45aca848dd1d17e3", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from IPython.core.display import Image\n", + "\n", + "display(Image(url=image_url, width=500))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2d3394b6ce654ec4", + "metadata": {}, + "outputs": [], + "source": [ + "result = embed_model.get_image_embedding(\"data/test-image.jpg\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "75022fc91552014c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[-0.00822397694,\n", + " 0.0167199261,\n", + " 0.0195552949,\n", + " 0.00935372803,\n", + " 0.00746282,\n", + " 0.011754944,\n", + " -0.0363474153,\n", + " 0.00836938061,\n", + " -0.0170917399,\n", + " 0.0218462963]" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "result[:10]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "103e4c523d039fdd", + "metadata": {}, + "outputs": [], + "source": [ + "text_result = embed_model.get_text_embedding(\n", + " \"a brown and white puppy laying in the grass with purple daisies in the background\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c14b8971981d61e5", + "metadata": {}, + "outputs": [], + "source": [ + "text_result_2 = embed_model.get_text_embedding(\"airplanes in the sky\")" + ] + }, + { + "cell_type": "markdown", + "id": "588f1585ba25bc57", + "metadata": {}, + "source": [ + "We expect that a similar description to the image will yield a higher similarity result" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "23a129176e8d1007", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.20342717022759096" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "embed_model.similarity(result, text_result)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "be0c503c3dd57412", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.009063958097860215" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "embed_model.similarity(result, text_result_2)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-vertex/examples/text_embedding.ipynb b/llama-index-integrations/embeddings/llama-index-embeddings-vertex/examples/text_embedding.ipynb new file mode 100644 index 0000000000000..d992449a5258d --- /dev/null +++ b/llama-index-integrations/embeddings/llama-index-embeddings-vertex/examples/text_embedding.ipynb @@ -0,0 +1,221 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "8085d744ceb233ff", + "metadata": {}, + "source": [ + "# Vertex AI Text Embedding\n", + "\n", + "Imports the VertexTextEmbedding class and initializes an instance named embed_model with a specified project and location. Uses APPLICATION_DEFAULT_CREDENTIALS if no credentials is specified. The default model is `textembedding-gecko@003` in document retrival mode." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c52b0b97984c1ceb", + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index.embeddings.vertex import VertexTextEmbedding\n", + "\n", + "embed_model = VertexTextEmbedding(project=\"speedy-atom-413006\", location=\"us-central1\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "61d58ea0808d0941", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'model_name': 'textembedding-gecko@003',\n", + " 'embed_batch_size': 10,\n", + " 'embed_mode': ,\n", + " 'additional_kwargs': {},\n", + " 'class_name': 'VertexTextEmbedding'}" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "embed_model.dict()" + ] + }, + { + "cell_type": "markdown", + "id": "c98da813ca018111", + "metadata": {}, + "source": [ + "## Document and Query Retrival" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8f6e67d1951da538", + "metadata": {}, + "outputs": [], + "source": [ + "embed_text_result = embed_model.get_text_embedding(\"Hello World!\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f61a801502c3de8f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[0.05736415088176727,\n", + " 0.0049842665903270245,\n", + " -0.07065856456756592,\n", + " -0.021812528371810913,\n", + " 0.060468606650829315]" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "embed_text_result[:5]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "416ed8894817e213", + "metadata": {}, + "outputs": [], + "source": [ + "embed_query_result = embed_model.get_query_embedding(\"Hello World!\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "62510b52e204a271", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[0.05158292129635811,\n", + " -0.033334773033857346,\n", + " -0.03221268951892853,\n", + " -0.029282240197062492,\n", + " 0.020004423335194588]" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "embed_query_result[:5]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d10c0164acddc5d7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.7375430761259468" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from llama_index.core.base.embeddings.base import SimilarityMode\n", + "\n", + "embed_model.similarity(\n", + " embed_text_result, embed_query_result, SimilarityMode.DOT_PRODUCT\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "68292f47908eabad", + "metadata": {}, + "source": [ + "## Using the async interface" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10aa2c79d07d6f77", + "metadata": {}, + "outputs": [], + "source": [ + "import nest_asyncio\n", + "\n", + "nest_asyncio.apply()\n", + "\n", + "result = await embed_model.aget_text_embedding(\"Hello World!\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "596498385119ecab", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[0.05733369290828705,\n", + " 0.005178301595151424,\n", + " -0.07033716142177582,\n", + " -0.021963153034448624,\n", + " 0.06050697714090347]" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "result[:5]" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-vertex/llama_index/embeddings/vertex/BUILD b/llama-index-integrations/embeddings/llama-index-embeddings-vertex/llama_index/embeddings/vertex/BUILD new file mode 100644 index 0000000000000..db46e8d6c978c --- /dev/null +++ b/llama-index-integrations/embeddings/llama-index-embeddings-vertex/llama_index/embeddings/vertex/BUILD @@ -0,0 +1 @@ +python_sources() diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-vertex/llama_index/embeddings/vertex/__init__.py b/llama-index-integrations/embeddings/llama-index-embeddings-vertex/llama_index/embeddings/vertex/__init__.py new file mode 100644 index 0000000000000..2889c5b5e5d5a --- /dev/null +++ b/llama-index-integrations/embeddings/llama-index-embeddings-vertex/llama_index/embeddings/vertex/__init__.py @@ -0,0 +1,7 @@ +from llama_index.embeddings.vertex.base import ( + VertexTextEmbedding, + VertexMultiModalEmbedding, + VertexEmbeddingMode, +) + +__all__ = ["VertexTextEmbedding", "VertexMultiModalEmbedding", "VertexEmbeddingMode"] diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-vertex/llama_index/embeddings/vertex/base.py b/llama-index-integrations/embeddings/llama-index-embeddings-vertex/llama_index/embeddings/vertex/base.py new file mode 100644 index 0000000000000..1095758019a3b --- /dev/null +++ b/llama-index-integrations/embeddings/llama-index-embeddings-vertex/llama_index/embeddings/vertex/base.py @@ -0,0 +1,228 @@ +from enum import Enum +from typing import Optional, List, Any, Dict, Union + +import vertexai +from llama_index.core.base.embeddings.base import Embedding, BaseEmbedding +from llama_index.core.bridge.pydantic import PrivateAttr, Field +from llama_index.core.callbacks import CallbackManager +from llama_index.core.embeddings import MultiModalEmbedding +from llama_index.core.schema import ImageType +from llama_index.core.base.embeddings.base import DEFAULT_EMBED_BATCH_SIZE +from vertexai.language_models import TextEmbeddingModel, TextEmbeddingInput +from vertexai.vision_models import MultiModalEmbeddingModel, Image + +from google.auth import credentials as auth_credentials + + +class VertexEmbeddingMode(str, Enum): + """VertexAI embedding mode. + + Attributes: + DEFAULT_MODE (str): The default embedding mode, for older models before August 2023, + that does not support task_type + CLASSIFICATION_MODE (str): Optimizes embeddings for classification tasks. + CLUSTERING_MODE (str): Optimizes embeddings for clustering tasks. + SEMANTIC_SIMILARITY_MODE (str): Optimizes embeddings for tasks that require assessments of semantic similarity. + RETRIEVAL_MODE (str): Optimizes embeddings for retrieval tasks, including search and document retrieval. + """ + + DEFAULT_MODE = "default" + CLASSIFICATION_MODE = "classification" + CLUSTERING_MODE = "clustering" + SEMANTIC_SIMILARITY_MODE = "similarity" + RETRIEVAL_MODE = "retrieval" + + +_TEXT_EMBED_TASK_TYPE_MAPPING: Dict[VertexEmbeddingMode, str] = { + VertexEmbeddingMode.CLASSIFICATION_MODE: "CLASSIFICATION", + VertexEmbeddingMode.CLUSTERING_MODE: "CLUSTERING", + VertexEmbeddingMode.SEMANTIC_SIMILARITY_MODE: "SEMANTIC_SIMILARITY", + VertexEmbeddingMode.RETRIEVAL_MODE: "RETRIEVAL_DOCUMENT", +} + +_QUERY_EMBED_TASK_TYPE_MAPPING: Dict[VertexEmbeddingMode, str] = { + VertexEmbeddingMode.CLASSIFICATION_MODE: "CLASSIFICATION", + VertexEmbeddingMode.CLUSTERING_MODE: "CLUSTERING", + VertexEmbeddingMode.SEMANTIC_SIMILARITY_MODE: "SEMANTIC_SIMILARITY", + VertexEmbeddingMode.RETRIEVAL_MODE: "RETRIEVAL_QUERY", +} + + +def init_vertexai( + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, +) -> None: + """Init vertexai. + + Args: + project: The default GCP project to use when making Vertex API calls. + location: The default location to use when making API calls. + credentials: The default custom + credentials to use when making API calls. If not provided credentials + will be ascertained from the environment. + """ + vertexai.init( + project=project, + location=location, + credentials=credentials, + ) + + +def _get_embedding_request( + texts: List[str], embed_mode: VertexEmbeddingMode, is_query: bool +) -> List[Union[str, TextEmbeddingInput]]: + if embed_mode != VertexEmbeddingMode.DEFAULT_MODE: + mapping = ( + _QUERY_EMBED_TASK_TYPE_MAPPING + if is_query + else _TEXT_EMBED_TASK_TYPE_MAPPING + ) + texts = [ + TextEmbeddingInput(text=text, task_type=mapping[embed_mode]) + for text in texts + ] + return texts + + +class VertexTextEmbedding(BaseEmbedding): + embed_mode: VertexEmbeddingMode = Field(description="The embedding mode to use.") + additional_kwargs: Dict[str, Any] = Field( + default_factory=dict, description="Additional kwargs for the Vertex." + ) + + _model: TextEmbeddingModel = PrivateAttr() + + def __init__( + self, + model_name: str = "textembedding-gecko@003", + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + embed_mode: VertexEmbeddingMode = VertexEmbeddingMode.RETRIEVAL_MODE, + embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE, + callback_manager: Optional[CallbackManager] = None, + additional_kwargs: Optional[Dict[str, Any]] = None, + ) -> None: + init_vertexai(project=project, location=location, credentials=credentials) + callback_manager = callback_manager or CallbackManager([]) + additional_kwargs = additional_kwargs or {} + + super().__init__( + embed_mode=embed_mode, + additional_kwargs=additional_kwargs, + model_name=model_name, + embed_batch_size=embed_batch_size, + callback_manager=callback_manager, + ) + self._model = TextEmbeddingModel.from_pretrained(model_name) + + @classmethod + def class_name(cls) -> str: + return "VertexTextEmbedding" + + def _get_text_embeddings(self, texts: List[str]) -> List[Embedding]: + texts = _get_embedding_request( + texts=texts, embed_mode=self.embed_mode, is_query=False + ) + embeddings = self._model.get_embeddings(texts, **self.additional_kwargs) + return [embedding.values for embedding in embeddings] + + def _get_text_embedding(self, text: str) -> Embedding: + return self._get_text_embeddings([text])[0] + + async def _aget_text_embedding(self, text: str) -> Embedding: + return (await self._aget_text_embeddings([text]))[0] + + async def _aget_text_embeddings(self, texts: List[str]) -> List[Embedding]: + texts = _get_embedding_request( + texts=texts, embed_mode=self.embed_mode, is_query=False + ) + embeddings = await self._model.get_embeddings_async( + texts, **self.additional_kwargs + ) + return [embedding.values for embedding in embeddings] + + def _get_query_embedding(self, query: str) -> Embedding: + texts = _get_embedding_request( + texts=[query], embed_mode=self.embed_mode, is_query=True + ) + embeddings = self._model.get_embeddings(texts, **self.additional_kwargs) + return embeddings[0].values + + async def _aget_query_embedding(self, query: str) -> Embedding: + texts = _get_embedding_request( + texts=[query], embed_mode=self.embed_mode, is_query=True + ) + embeddings = await self._model.get_embeddings_async( + texts, **self.additional_kwargs + ) + return embeddings[0].values + + +class VertexMultiModalEmbedding(MultiModalEmbedding): + embed_dimension: int = Field(description="The vertex output embedding dimension.") + additional_kwargs: Dict[str, Any] = Field( + default_factory=dict, description="Additional kwargs for the Vertex." + ) + + _model: MultiModalEmbeddingModel = PrivateAttr() + _embed_dimension: int = PrivateAttr() + + def __init__( + self, + model_name: str = "multimodalembedding", + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[Any] = None, + embed_dimension: int = 1408, + embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE, + callback_manager: Optional[CallbackManager] = None, + additional_kwargs: Optional[Dict[str, Any]] = None, + ) -> None: + init_vertexai(project=project, location=location, credentials=credentials) + callback_manager = callback_manager or CallbackManager([]) + additional_kwargs = additional_kwargs or {} + + super().__init__( + embed_dimension=embed_dimension, + additional_kwargs=additional_kwargs, + model_name=model_name, + embed_batch_size=embed_batch_size, + callback_manager=callback_manager, + ) + self._model = MultiModalEmbeddingModel.from_pretrained(model_name) + + @classmethod + def class_name(cls) -> str: + return "VertexMultiModalEmbedding" + + def _get_text_embedding(self, text: str) -> Embedding: + return self._model.get_embeddings( + contextual_text=text, + dimension=self.embed_dimension, + **self.additional_kwargs + ).text_embedding + + def _get_image_embedding(self, img_file_path: ImageType) -> Embedding: + if isinstance(img_file_path, str): + image = Image.load_from_file(img_file_path) + else: + image = Image(image_bytes=img_file_path.getvalue()) + embeddings = self._model.get_embeddings( + image=image, dimension=self.embed_dimension, **self.additional_kwargs + ) + return embeddings.image_embedding + + def _get_query_embedding(self, query: str) -> Embedding: + return self._get_text_embedding(query) + + # Vertex AI SDK does not support async variants yet + async def _aget_text_embedding(self, text: str) -> Embedding: + return self._get_text_embedding(text) + + async def _aget_image_embedding(self, img_file_path: ImageType) -> Embedding: + return self._get_image_embedding(img_file_path) + + async def _aget_query_embedding(self, query: str) -> Embedding: + return self._get_query_embedding(query) diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-vertex/pyproject.toml b/llama-index-integrations/embeddings/llama-index-embeddings-vertex/pyproject.toml new file mode 100644 index 0000000000000..786c029857e87 --- /dev/null +++ b/llama-index-integrations/embeddings/llama-index-embeddings-vertex/pyproject.toml @@ -0,0 +1,58 @@ +[build-system] +build-backend = "poetry.core.masonry.api" +requires = ["poetry-core"] + +[tool.codespell] +check-filenames = true +check-hidden = true +# Feel free to un-skip examples, and experimental, you will just need to +# work through many typos (--write-changes and --interactive will help) +skip = "*.csv,*.html,*.json,*.jsonl,*.pdf,*.txt,*.ipynb" + +[tool.llamahub] +contains_example = true +import_path = "llama_index.embeddings.vertex" + +[tool.llamahub.class_authors] +VertexMultiModalEmbedding = "mustartt" +VertexTextEmbedding = "mustartt" + +[tool.mypy] +disallow_untyped_defs = true +# Remove venv skip when integrated with pre-commit +exclude = ["_static", "build", "examples", "notebooks", "venv"] +ignore_missing_imports = true +python_version = "3.8" + +[tool.poetry] +authors = ["Henry Jiang "] +description = "llama-index embeddings vertex integration" +license = "MIT" +name = "llama-index-embeddings-vertex" +packages = [{include = "llama_index/"}] +readme = "README.md" +version = "0.1.0" + +[tool.poetry.dependencies] +python = ">=3.9,<4.0" +llama-index-core = "^0.10.0" +google-cloud-aiplatform = ">=1.43.0" + +[tool.poetry.group.dev.dependencies] +black = {extras = ["jupyter"], version = "<=23.9.1,>=23.7.0"} +codespell = {extras = ["toml"], version = ">=v2.2.6"} +ipython = "8.10.0" +jupyter = "^1.0.0" +mypy = "0.991" +pre-commit = "3.2.0" +pylint = "2.15.10" +pytest = "7.2.1" +pytest-mock = "3.11.1" +ruff = "0.0.292" +tree-sitter-languages = "^1.8.0" +types-Deprecated = ">=0.1.0" +types-PyYAML = "^6.0.12.12" +types-protobuf = "^4.24.0.4" +types-redis = "4.5.5.0" +types-requests = "2.28.11.8" # TODO: unpin when mypy>0.991 +types-setuptools = "67.1.0.0" diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-vertex/tests/BUILD b/llama-index-integrations/embeddings/llama-index-embeddings-vertex/tests/BUILD new file mode 100644 index 0000000000000..619cac15ff840 --- /dev/null +++ b/llama-index-integrations/embeddings/llama-index-embeddings-vertex/tests/BUILD @@ -0,0 +1,3 @@ +python_tests( + interpreter_constraints=["==3.9.*", "==3.10.*"], +) diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-vertex/tests/__init__.py b/llama-index-integrations/embeddings/llama-index-embeddings-vertex/tests/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-vertex/tests/test_embeddings_vertex.py b/llama-index-integrations/embeddings/llama-index-embeddings-vertex/tests/test_embeddings_vertex.py new file mode 100644 index 0000000000000..2117305a9ff0e --- /dev/null +++ b/llama-index-integrations/embeddings/llama-index-embeddings-vertex/tests/test_embeddings_vertex.py @@ -0,0 +1,254 @@ +import io +import unittest +from unittest.mock import patch, Mock, MagicMock, AsyncMock + +from llama_index.core.base.embeddings.base import BaseEmbedding +from llama_index.core.embeddings import MultiModalEmbedding +from vertexai.language_models import TextEmbedding +from vertexai.vision_models import MultiModalEmbeddingResponse + +from PIL import Image as PillowImage + +from llama_index.embeddings.vertex import ( + VertexTextEmbedding, + VertexMultiModalEmbedding, + VertexEmbeddingMode, +) + + +class VertexTextEmbeddingTest(unittest.TestCase): + @patch("vertexai.init") + @patch("vertexai.language_models.TextEmbeddingModel.from_pretrained") + def test_init(self, model_mock: Mock, mock_init: Mock): + mock_cred = Mock(return_value="mock_credentials_instance") + embedding = VertexTextEmbedding( + model_name="textembedding-gecko@001", + project="test-project", + location="us-test-location", + credentials=mock_cred, + embed_mode=VertexEmbeddingMode.RETRIEVAL_MODE, + embed_batch_size=100, + ) + + mock_init.assert_called_once_with( + project="test-project", + location="us-test-location", + credentials=mock_cred, + ) + + self.assertIsInstance(embedding, BaseEmbedding) + + self.assertEqual(embedding.model_name, "textembedding-gecko@001") + self.assertEqual(embedding.embed_mode, VertexEmbeddingMode.RETRIEVAL_MODE) + self.assertEqual(embedding.embed_batch_size, 100) + + @patch("vertexai.init") + @patch("vertexai.language_models.TextEmbeddingModel.from_pretrained") + def test_get_embedding_retrieval(self, model_mock: Mock, init_mock: Mock): + model = MagicMock() + model_mock.return_value = model + + embedding = VertexTextEmbedding( + project="test-project", + location="us-test-location", + embed_mode=VertexEmbeddingMode.RETRIEVAL_MODE, + additional_kwargs={"auto_truncate": True}, + ) + + model.get_embeddings.return_value = [TextEmbedding(values=[0.1, 0.2, 0.3])] + result = embedding.get_text_embedding("some text") + + model.get_embeddings.assert_called_once() + positional_args, keyword_args = model.get_embeddings.call_args + model.get_embeddings.reset_mock() + + self.assertEqual(len(positional_args[0]), 1) + self.assertEqual(positional_args[0][0].text, "some text") + self.assertEqual(positional_args[0][0].task_type, "RETRIEVAL_DOCUMENT") + self.assertEqual(result, [0.1, 0.2, 0.3]) + self.assertTrue(keyword_args["auto_truncate"]) + + model.get_embeddings.return_value = [TextEmbedding(values=[0.1, 0.2, 0.3])] + result = embedding.get_query_embedding("some query text") + + model.get_embeddings.assert_called_once() + positional_args, keyword_args = model.get_embeddings.call_args + + self.assertEqual(len(positional_args[0]), 1) + self.assertEqual(positional_args[0][0].text, "some query text") + self.assertEqual(positional_args[0][0].task_type, "RETRIEVAL_QUERY") + self.assertEqual(result, [0.1, 0.2, 0.3]) + self.assertTrue(keyword_args["auto_truncate"]) + + +class VertexTextEmbeddingTestAsync(unittest.IsolatedAsyncioTestCase): + @patch("vertexai.init") + @patch("vertexai.language_models.TextEmbeddingModel.from_pretrained") + async def test_get_embedding_retrieval( + self, model_mock: AsyncMock, init_mock: AsyncMock + ): + model = MagicMock() + model.get_embeddings_async = ( + AsyncMock() + ) # Ensure get_embeddings is an AsyncMock for async calls + model_mock.return_value = model + + embedding = VertexTextEmbedding( + project="test-project", + location="us-test-location", + embed_mode=VertexEmbeddingMode.RETRIEVAL_MODE, + additional_kwargs={"auto_truncate": True}, + ) + + model.get_embeddings_async.return_value = [ + TextEmbedding(values=[0.1, 0.2, 0.3]) + ] + result = await embedding.aget_text_embedding("some text") + + model.get_embeddings_async.assert_called_once() + positional_args, keyword_args = model.get_embeddings_async.call_args + model.get_embeddings_async.reset_mock() + + self.assertEqual(len(positional_args[0]), 1) + self.assertEqual(positional_args[0][0].text, "some text") + self.assertEqual(positional_args[0][0].task_type, "RETRIEVAL_DOCUMENT") + self.assertEqual(result, [0.1, 0.2, 0.3]) + self.assertTrue(keyword_args["auto_truncate"]) + + model.get_embeddings_async.return_value = [ + TextEmbedding(values=[0.1, 0.2, 0.3]) + ] + result = await embedding.aget_query_embedding("some query text") + + model.get_embeddings_async.assert_called_once() + positional_args, keyword_args = model.get_embeddings_async.call_args + + self.assertEqual(len(positional_args[0]), 1) + self.assertEqual(positional_args[0][0].text, "some query text") + self.assertEqual(positional_args[0][0].task_type, "RETRIEVAL_QUERY") + self.assertEqual(result, [0.1, 0.2, 0.3]) + self.assertTrue(keyword_args["auto_truncate"]) + + +class VertexMultiModalEmbeddingTest(unittest.TestCase): + @patch("vertexai.init") + @patch("vertexai.vision_models.MultiModalEmbeddingModel.from_pretrained") + def test_init(self, model_mock: Mock, mock_init: Mock): + mock_cred = Mock(return_value="mock_credentials_instance") + embedding = VertexMultiModalEmbedding( + model_name="multimodalembedding", + project="test-project", + location="us-test-location", + credentials=mock_cred, + embed_dimension=1408, + embed_batch_size=100, + ) + + mock_init.assert_called_once_with( + project="test-project", + location="us-test-location", + credentials=mock_cred, + ) + + self.assertIsInstance(embedding, MultiModalEmbedding) + + self.assertEqual(embedding.model_name, "multimodalembedding") + self.assertEqual(embedding.embed_batch_size, 100) + self.assertEqual(embedding.embed_dimension, 1408) + + @patch("vertexai.init") + @patch("vertexai.vision_models.MultiModalEmbeddingModel.from_pretrained") + def test_text_embedding(self, model_mock: Mock, init_mock: Mock): + model = MagicMock() + model_mock.return_value = model + + embedding = VertexMultiModalEmbedding( + project="test-project", + location="us-test-location", + embed_dimension=1408, + additional_kwargs={"additional_kwarg": True}, + ) + + model.get_embeddings.return_value = MultiModalEmbeddingResponse( + _prediction_response=None, text_embedding=[0.1, 0.2, 0.3] + ) + + result = embedding.get_text_embedding("some text") + self.assertEqual(result, [0.1, 0.2, 0.3]) + + model.get_embeddings.assert_called_once() + positional_args, keyword_args = model.get_embeddings.call_args + + self.assertEqual(keyword_args["contextual_text"], "some text") + self.assertEqual(keyword_args["dimension"], 1408) + self.assertTrue(keyword_args["additional_kwarg"]) + + @patch("vertexai.init") + @patch("vertexai.vision_models.Image.load_from_file") + @patch("vertexai.vision_models.MultiModalEmbeddingModel.from_pretrained") + def test_image_embedding_path( + self, model_mock: Mock, load_file_mock: Mock, init_mock: Mock + ): + model = MagicMock() + model_mock.return_value = model + + embedding = VertexMultiModalEmbedding( + project="test-project", + location="us-test-location", + embed_dimension=1408, + additional_kwargs={"additional_kwarg": True}, + ) + + model.get_embeddings.return_value = MultiModalEmbeddingResponse( + _prediction_response=None, image_embedding=[0.1, 0.2, 0.3] + ) + + result = embedding.get_image_embedding("data/test-image.jpg") + self.assertEqual(result, [0.1, 0.2, 0.3]) + + model.get_embeddings.assert_called_once() + positional_args, keyword_args = model.get_embeddings.call_args + + load_file_mock.assert_called_once_with("data/test-image.jpg") + self.assertTrue("image" in keyword_args) + self.assertEqual(keyword_args["dimension"], 1408) + self.assertTrue(keyword_args["additional_kwarg"]) + + @patch("vertexai.init") + @patch("vertexai.vision_models.Image.load_from_file") + @patch("vertexai.vision_models.MultiModalEmbeddingModel.from_pretrained") + def test_image_embedding_bytes( + self, model_mock: Mock, load_file_mock: Mock, init_mock: Mock + ): + model = MagicMock() + model_mock.return_value = model + + embedding = VertexMultiModalEmbedding( + project="test-project", + location="us-test-location", + embed_dimension=1408, + additional_kwargs={"additional_kwarg": True}, + ) + + model.get_embeddings.return_value = MultiModalEmbeddingResponse( + _prediction_response=None, image_embedding=[0.1, 0.2, 0.3] + ) + + image = PillowImage.new("RGB", (128, 128)) + bytes_io = io.BytesIO() + image.save(bytes_io, "jpeg") + bytes_io.seek(0) + + result = embedding.get_image_embedding(bytes_io) + self.assertEqual(result, [0.1, 0.2, 0.3]) + + model.get_embeddings.assert_called_once() + positional_args, keyword_args = model.get_embeddings.call_args + + load_file_mock.assert_not_called() + self.assertEqual(keyword_args["dimension"], 1408) + self.assertTrue(keyword_args["additional_kwarg"]) + + +if __name__ == "__main__": + unittest.main()