Skip to content

Commit

Permalink
Gemini (#1520)
Browse files Browse the repository at this point in the history
This adds a Google Gemini embedding function and an RAG chat example 

TODO
- [x] JS support
- [x] Docs PR
  • Loading branch information
jeffchuber authored Dec 15, 2023
1 parent 3939974 commit 99c0e9f
Show file tree
Hide file tree
Showing 12 changed files with 1,944 additions and 15 deletions.
64 changes: 54 additions & 10 deletions chromadb/utils/embedding_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,13 +144,11 @@ def __init__(
api_key=api_key,
api_version=api_version,
azure_endpoint=api_base,
default_headers=default_headers
default_headers=default_headers,
).embeddings
else:
self._client = openai.OpenAI(
api_key=api_key,
base_url=api_base,
default_headers=default_headers
api_key=api_key, base_url=api_base, default_headers=default_headers
).embeddings
else:
self._client = openai.Embedding
Expand Down Expand Up @@ -209,7 +207,9 @@ def __call__(self, input: Documents) -> Embeddings:
# Call Cohere Embedding API for each document.
return [
embeddings
for embeddings in self._client.embed(texts=input, model=self._model_name, input_type="search_document")
for embeddings in self._client.embed(
texts=input, model=self._model_name, input_type="search_document"
)
]


Expand Down Expand Up @@ -260,9 +260,7 @@ class JinaEmbeddingFunction(EmbeddingFunction[Documents]):
It requires an API key and a model name. The default model name is "jina-embeddings-v2-base-en".
"""

def __init__(
self, api_key: str, model_name: str = "jina-embeddings-v2-base-en"
):
def __init__(self, api_key: str, model_name: str = "jina-embeddings-v2-base-en"):
"""
Initialize the JinaEmbeddingFunction.
Expand All @@ -271,9 +269,11 @@ def __init__(
model_name (str, optional): The name of the model to use for text embeddings. Defaults to "jina-embeddings-v2-base-en".
"""
self._model_name = model_name
self._api_url = 'https://api.jina.ai/v1/embeddings'
self._api_url = "https://api.jina.ai/v1/embeddings"
self._session = requests.Session()
self._session.headers.update({"Authorization": f"Bearer {api_key}", "Accept-Encoding": "identity"})
self._session.headers.update(
{"Authorization": f"Bearer {api_key}", "Accept-Encoding": "identity"}
)

def __call__(self, input: Documents) -> Embeddings:
"""
Expand Down Expand Up @@ -552,6 +552,50 @@ def __call__(self, input: Documents) -> Embeddings:
]


class GoogleGenerativeAiEmbeddingFunction(EmbeddingFunction[Documents]):
"""To use this EmbeddingFunction, you must have the google.generativeai Python package installed and have a Google API key."""

"""Use RETRIEVAL_DOCUMENT for the task_type for embedding, and RETRIEVAL_QUERY for the task_type for retrieval."""

def __init__(
self,
api_key: str,
model_name: str = "models/embedding-001",
task_type: str = "RETRIEVAL_DOCUMENT",
):
if not api_key:
raise ValueError("Please provide a Google API key.")

if not model_name:
raise ValueError("Please provide the model name.")

try:
import google.generativeai as genai
except ImportError:
raise ValueError(
"The Google Generative AI python package is not installed. Please install it with `pip install google-generativeai`"
)

genai.configure(api_key=api_key)
self._genai = genai
self._model_name = model_name
self._task_type = task_type
self._task_title = None
if self._task_type is "RETRIEVAL_DOCUMENT":
self._task_title = "Embedding of single string"

def __call__(self, input: Documents) -> Embeddings:
return [
self._genai.embed_content(
model=self._model_name,
content=text,
task_type=self._task_type,
title=self._task_title,
)["embedding"]
for text in input
]


class GoogleVertexEmbeddingFunction(EmbeddingFunction[Documents]):
# Follow API Quickstart for Google Vertex AI
# https://cloud.google.com/vertex-ai/docs/generative-ai/start/quickstarts/api-quickstart
Expand Down
13 changes: 9 additions & 4 deletions clients/js/examples/node/app.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,17 @@ var path = require("path");

var express = require("express");
var chroma = require("chromadb");
var openai = require("openai");

var app = express();
app.get("/", async (req, res) => {
const cc = new chroma.ChromaClient({ path: "http://localhost:8000" });
await cc.reset();

const openAIembedder = new chroma.OpenAIEmbeddingFunction("key")
const cohereAIEmbedder = new chroma.OpenAIEmbeddingFunction({ openai_api_key: "API_KEY" });
const google = new chroma.GoogleGenerativeAiEmbeddingFunction({ googleApiKey:"<APIKEY>" });

const collection = await cc.createCollection({
name: "test-from-js",
embeddingFunction: cohereAIEmbedder,
embeddingFunction: google,
});

await collection.add({
Expand All @@ -29,6 +27,13 @@ app.get("/", async (req, res) => {
let count = await collection.count();
console.log("count", count);

const googleQuery = new chroma.GoogleGenerativeAiEmbeddingFunction({ googleApiKey:"<APIKEY>", taskType: 'RETRIEVAL_QUERY' });

const queryCollection = await cc.getCollection({
name: "test-from-js",
embeddingFunction: googleQuery,
});

const query = await collection.query({
queryTexts: ["doc1"],
nResults: 1
Expand Down
1 change: 1 addition & 0 deletions clients/js/examples/node/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"author": "",
"license": "ISC",
"dependencies": {
"@google/generative-ai": "^0.1.1",
"chromadb": "file:../..",
"cohere-ai": "^5.0.2",
"express": "^4.18.2",
Expand Down
116 changes: 115 additions & 1 deletion clients/js/examples/node/yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
# yarn lockfile v1


"@google/generative-ai@^0.1.1":
version "0.1.1"
resolved "https://registry.yarnpkg.com/@google/generative-ai/-/generative-ai-0.1.1.tgz#ecf0cd832620527f0e35c3aecc17c058d8ba52b8"
integrity sha512-cbzKa8mT9YkTrT4XUuENIuvlqiJjwDgcD2Ks4L99Az9dWLgdXn8xnETEAZLOpqzoGx+1PuATZqlUnVRAeLbMgA==

accepts@~1.3.8:
version "1.3.8"
resolved "https://registry.npmjs.org/accepts/-/accepts-1.3.8.tgz"
Expand All @@ -10,6 +15,18 @@ accepts@~1.3.8:
mime-types "~2.1.34"
negotiator "0.6.3"

ansi-regex@^5.0.1:
version "5.0.1"
resolved "https://registry.yarnpkg.com/ansi-regex/-/ansi-regex-5.0.1.tgz#082cb2c89c9fe8659a311a53bd6a4dc5301db304"
integrity sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==

ansi-styles@^4.0.0:
version "4.3.0"
resolved "https://registry.yarnpkg.com/ansi-styles/-/ansi-styles-4.3.0.tgz#edd803628ae71c04c85ae7a0906edad34b648937"
integrity sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==
dependencies:
color-convert "^2.0.1"

[email protected]:
version "1.1.1"
resolved "https://registry.npmjs.org/array-flatten/-/array-flatten-1.1.1.tgz"
Expand Down Expand Up @@ -59,13 +76,37 @@ call-bind@^1.0.0:
get-intrinsic "^1.0.2"

"chromadb@file:../..":
version "1.5.0"
version "1.7.1-beta2"
dependencies:
cliui "^8.0.1"
isomorphic-fetch "^3.0.0"

cliui@^8.0.1:
version "8.0.1"
resolved "https://registry.yarnpkg.com/cliui/-/cliui-8.0.1.tgz#0c04b075db02cbfe60dc8e6cf2f5486b1a3608aa"
integrity sha512-BSeNnyus75C4//NQ9gQt1/csTXyo/8Sb+afLAkzAptFuMsod9HFokGNudZpi/oQV73hnVK+sR+5PVRMd+Dr7YQ==
dependencies:
string-width "^4.2.0"
strip-ansi "^6.0.1"
wrap-ansi "^7.0.0"

cohere-ai@^5.0.2:
version "5.0.2"
resolved "https://registry.npmjs.org/cohere-ai/-/cohere-ai-5.0.2.tgz"
integrity sha512-Svt8VC20/GgwCBF2kHYZI3JZkfqEoG6wCbTT6tohNK8x/aBFyMxlBUYEF0gRGXH1055vQpBjj5ewHF8LpnSSOA==

color-convert@^2.0.1:
version "2.0.1"
resolved "https://registry.yarnpkg.com/color-convert/-/color-convert-2.0.1.tgz#72d3a68d598c9bdb3af2ad1e84f21d896abd4de3"
integrity sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==
dependencies:
color-name "~1.1.4"

color-name@~1.1.4:
version "1.1.4"
resolved "https://registry.yarnpkg.com/color-name/-/color-name-1.1.4.tgz#c2a09a87acbde69543de6f63fa3995c826c536a2"
integrity sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==

combined-stream@^1.0.8:
version "1.0.8"
resolved "https://registry.npmjs.org/combined-stream/-/combined-stream-1.0.8.tgz"
Expand Down Expand Up @@ -122,6 +163,11 @@ [email protected]:
resolved "https://registry.npmjs.org/ee-first/-/ee-first-1.1.1.tgz"
integrity sha512-WMwm9LhRUo+WUaRN+vRuETqG89IgZphVSNkdFgeb6sS/E4OrDIN7t48CAewSHXc6C8lefD8KKfr5vY61brQlow==

emoji-regex@^8.0.0:
version "8.0.0"
resolved "https://registry.yarnpkg.com/emoji-regex/-/emoji-regex-8.0.0.tgz#e818fd69ce5ccfcb404594f842963bf53164cc37"
integrity sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==

encodeurl@~1.0.2:
version "1.0.2"
resolved "https://registry.npmjs.org/encodeurl/-/encodeurl-1.0.2.tgz"
Expand Down Expand Up @@ -265,6 +311,19 @@ [email protected]:
resolved "https://registry.npmjs.org/ipaddr.js/-/ipaddr.js-1.9.1.tgz"
integrity sha512-0KI/607xoxSToH7GjN1FfSbLoU0+btTicjsQSWQlh/hZykN8KpmMf7uYwPW3R+akZ6R/w18ZlXSHBYXiYUPO3g==

is-fullwidth-code-point@^3.0.0:
version "3.0.0"
resolved "https://registry.yarnpkg.com/is-fullwidth-code-point/-/is-fullwidth-code-point-3.0.0.tgz#f116f8064fe90b3f7844a38997c0b75051269f1d"
integrity sha512-zymm5+u+sCsSWyD9qNaejV3DFvhCKclKdizYaJUuHA83RLjb7nSuGnddCHGv0hk+KY7BMAlsWeK4Ueg6EV6XQg==

isomorphic-fetch@^3.0.0:
version "3.0.0"
resolved "https://registry.yarnpkg.com/isomorphic-fetch/-/isomorphic-fetch-3.0.0.tgz#0267b005049046d2421207215d45d6a262b8b8b4"
integrity sha512-qvUtwJ3j6qwsF3jLxkZ72qCgjMysPzDfeV240JHiGZsANBYd+EEuu35v7dfrJ9Up0Ak07D7GGSkGhCHTqg/5wA==
dependencies:
node-fetch "^2.6.1"
whatwg-fetch "^3.4.1"

[email protected]:
version "0.3.0"
resolved "https://registry.npmjs.org/media-typer/-/media-typer-0.3.0.tgz"
Expand Down Expand Up @@ -312,6 +371,13 @@ [email protected]:
resolved "https://registry.npmjs.org/negotiator/-/negotiator-0.6.3.tgz"
integrity sha512-+EUsqGPLsM+j/zdChZjsnX51g4XrHFOIXwfnCVPGlQk/k5giakcKsuxCObBRu6DSm9opw/O6slWbJdghQM4bBg==

node-fetch@^2.6.1:
version "2.7.0"
resolved "https://registry.yarnpkg.com/node-fetch/-/node-fetch-2.7.0.tgz#d0f0fa6e3e2dc1d27efcd8ad99d550bda94d187d"
integrity sha512-c4FRfUm/dbcWZ7U+1Wq0AwCyFL+3nt2bEw05wfxSz+DWpWsitgmSgYmy2dQdWyKC1694ELPqMs/YzUSNozLt8A==
dependencies:
whatwg-url "^5.0.0"

object-inspect@^1.9.0:
version "1.12.3"
resolved "https://registry.npmjs.org/object-inspect/-/object-inspect-1.12.3.tgz"
Expand Down Expand Up @@ -430,11 +496,32 @@ [email protected]:
resolved "https://registry.npmjs.org/statuses/-/statuses-2.0.1.tgz"
integrity sha512-RwNA9Z/7PrK06rYLIzFMlaF+l73iwpzsqRIFgbMLbTcLD6cOao82TaWefPXQvB2fOC4AjuYSEndS7N/mTCbkdQ==

string-width@^4.1.0, string-width@^4.2.0:
version "4.2.3"
resolved "https://registry.yarnpkg.com/string-width/-/string-width-4.2.3.tgz#269c7117d27b05ad2e536830a8ec895ef9c6d010"
integrity sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==
dependencies:
emoji-regex "^8.0.0"
is-fullwidth-code-point "^3.0.0"
strip-ansi "^6.0.1"

strip-ansi@^6.0.0, strip-ansi@^6.0.1:
version "6.0.1"
resolved "https://registry.yarnpkg.com/strip-ansi/-/strip-ansi-6.0.1.tgz#9e26c63d30f53443e9489495b2105d37b67a85d9"
integrity sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==
dependencies:
ansi-regex "^5.0.1"

[email protected]:
version "1.0.1"
resolved "https://registry.npmjs.org/toidentifier/-/toidentifier-1.0.1.tgz"
integrity sha512-o5sSPKEkg/DIQNmH43V0/uerLrpzVedkUh8tGNvaeXpfpuwjKenlSox/2O/BTlZUtEe+JG7s5YhEz608PlAHRA==

tr46@~0.0.3:
version "0.0.3"
resolved "https://registry.yarnpkg.com/tr46/-/tr46-0.0.3.tgz#8184fd347dac9cdc185992f3a6622e14b9d9ab6a"
integrity sha512-N3WMsuqV66lT30CrXNbEjx4GEwlow3v6rr4mCcv6prnfwhS01rkgyFdjPNBYd9br7LpXV1+Emh01fHnq2Gdgrw==

type-is@~1.6.18:
version "1.6.18"
resolved "https://registry.npmjs.org/type-is/-/type-is-1.6.18.tgz"
Expand All @@ -457,3 +544,30 @@ vary@~1.1.2:
version "1.1.2"
resolved "https://registry.npmjs.org/vary/-/vary-1.1.2.tgz"
integrity sha512-BNGbWLfd0eUPabhkXUVm0j8uuvREyTh5ovRa/dyow/BqAbZJyC+5fU+IzQOzmAKzYqYRAISoRhdQr3eIZ/PXqg==

webidl-conversions@^3.0.0:
version "3.0.1"
resolved "https://registry.yarnpkg.com/webidl-conversions/-/webidl-conversions-3.0.1.tgz#24534275e2a7bc6be7bc86611cc16ae0a5654871"
integrity sha512-2JAn3z8AR6rjK8Sm8orRC0h/bcl/DqL7tRPdGZ4I1CjdF+EaMLmYxBHyXuKL849eucPFhvBoxMsflfOb8kxaeQ==

whatwg-fetch@^3.4.1:
version "3.6.20"
resolved "https://registry.yarnpkg.com/whatwg-fetch/-/whatwg-fetch-3.6.20.tgz#580ce6d791facec91d37c72890995a0b48d31c70"
integrity sha512-EqhiFU6daOA8kpjOWTL0olhVOF3i7OrFzSYiGsEMB8GcXS+RrzauAERX65xMeNWVqxA6HXH2m69Z9LaKKdisfg==

whatwg-url@^5.0.0:
version "5.0.0"
resolved "https://registry.yarnpkg.com/whatwg-url/-/whatwg-url-5.0.0.tgz#966454e8765462e37644d3626f6742ce8b70965d"
integrity sha512-saE57nupxk6v3HY35+jzBwYa0rKSy0XR8JSxZPwgLr7ys0IBzhGviA1/TUGJLmSVqs8pb9AnvICXEuOHLprYTw==
dependencies:
tr46 "~0.0.3"
webidl-conversions "^3.0.0"

wrap-ansi@^7.0.0:
version "7.0.0"
resolved "https://registry.yarnpkg.com/wrap-ansi/-/wrap-ansi-7.0.0.tgz#67e145cff510a6a6984bdf1152911d69d2eb9e43"
integrity sha512-YVGIj2kamLSTxw6NsZjoBxfSwsn0ycdesmc4p+Q21c5zPuZ1pl+NfxVdxPtdHvmNVOQ6XSYG4AUtyt/Fi7D16Q==
dependencies:
ansi-styles "^4.0.0"
string-width "^4.1.0"
strip-ansi "^6.0.0"
69 changes: 69 additions & 0 deletions clients/js/src/embeddings/GoogleGeminiEmbeddingFunction.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import { IEmbeddingFunction } from "./IEmbeddingFunction";

let googleGenAiApi: any;

export class GoogleGenerativeAiEmbeddingFunction implements IEmbeddingFunction {
private api_key: string;
private model: string;
private googleGenAiApi?: any;
private taskType: string;

constructor({ googleApiKey, model, taskType }: { googleApiKey: string, model?: string, taskType?: string }) {
// we used to construct the client here, but we need to async import the types
// for the openai npm package, and the constructor can not be async
this.api_key = googleApiKey;
this.model = model || "embedding-001";
this.taskType = taskType || "RETRIEVAL_DOCUMENT";
}

private async loadClient() {
if(this.googleGenAiApi) return;
try {
// eslint-disable-next-line global-require,import/no-extraneous-dependencies
const { googleGenAi } = await GoogleGenerativeAiEmbeddingFunction.import();
googleGenAiApi = googleGenAi;
// googleGenAiApi.init(this.api_key);
googleGenAiApi = new googleGenAiApi(this.api_key);
} catch (_a) {
// @ts-ignore
if (_a.code === 'MODULE_NOT_FOUND') {
throw new Error("Please install the @google/generative-ai package to use the GoogleGenerativeAiEmbeddingFunction, `npm install -S @google/generative-ai`");
}
throw _a; // Re-throw other errors
}
this.googleGenAiApi = googleGenAiApi;
}

public async generate(texts: string[]) {

await this.loadClient();
const model = this.googleGenAiApi.getGenerativeModel({ model: this.model});
const response = await model.batchEmbedContents({
requests: texts.map((t) => ({
content: { parts: [{ text: t }] },
taskType: this.taskType,
})),
});
const embeddings = response.embeddings.map((e: any) => e.values);

return embeddings;
}

/** @ignore */
static async import(): Promise<{
// @ts-ignore
googleGenAi: typeof import("@google/generative-ai");
}> {
try {
// @ts-ignore
const { GoogleGenerativeAI } = await import("@google/generative-ai");
const googleGenAi = GoogleGenerativeAI;
return { googleGenAi };
} catch (e) {
throw new Error(
"Please install @google/generative-ai as a dependency with, e.g. `yarn add @google/generative-ai`"
);
}
}

}
Loading

0 comments on commit 99c0e9f

Please sign in to comment.