Skip to content

Commit

Permalink
feat: allow users to configure a base_url for the vectorizer OpenAI e…
Browse files Browse the repository at this point in the history
…mbedder (#351)
  • Loading branch information
smoya authored Jan 14, 2025
1 parent 596b9a6 commit 66ceb3d
Show file tree
Hide file tree
Showing 18 changed files with 2,070 additions and 990 deletions.
1 change: 1 addition & 0 deletions docs/vectorizer-api-reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ The function takes several parameters to customize the OpenAI embedding configur
| dimensions | int | - | ✔ | Define the number of dimensions for the embedding vectors. This should match the output dimensions of the chosen model. |
| chat_user | text | - | ✖ | The identifier for the user making the API call. This can be useful for tracking API usage or for OpenAI's monitoring purposes. |
| api_key_name | text | `OPENAI_API_KEY` | ✖ | Set [the name of the environment variable that contains the OpenAI API key][openai-use-env-var]. This allows for flexible API key management without hardcoding keys in the database. On Timescale Cloud, you should set this to the name of the secret that contains the OpenAI API key. |
| base_url | text | - | ✖ | Set the base URL of the OpenAI API. No default is configured here, so that the vectorizer worker can instead be configured through the `OPENAI_BASE_URL` environment variable. |
#### Returns
A JSON configuration object that you can use in [ai.create_vectorizer](#create-vectorizers).
Expand Down
2 changes: 2 additions & 0 deletions projects/extension/sql/idempotent/008-embedding.sql
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ create or replace function ai.embedding_openai
, dimensions pg_catalog.int4
, chat_user pg_catalog.text default null
, api_key_name pg_catalog.text default 'OPENAI_API_KEY'
, base_url text default null
) returns pg_catalog.jsonb
as $func$
select json_object
Expand All @@ -15,6 +16,7 @@ as $func$
, 'dimensions': dimensions
, 'user': chat_user
, 'api_key_name': api_key_name
, 'base_url': base_url
absent on null
)
$func$ language sql immutable security invoker
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@

-- Drop the previous four-argument signature of ai.embedding_openai in favour
-- of the new signature, which adds a base_url parameter.
-- `create or replace function` cannot change a function's argument list: with
-- a different argument list it would define a second, overloaded function
-- rather than replace this one, so the old signature is removed explicitly.
-- `if exists` lets this run without error when the old signature is not present.
drop function if exists ai.embedding_openai(text,integer,text,text);
2 changes: 1 addition & 1 deletion projects/extension/tests/contents/output16.expected
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ CREATE EXTENSION
function ai.disable_vectorizer_schedule(integer)
function ai.drop_vectorizer(integer,boolean)
function ai.embedding_ollama(text,integer,text,jsonb,text)
function ai.embedding_openai(text,integer,text,text)
function ai.embedding_openai(text,integer,text,text,text)
function ai.embedding_voyageai(text,integer,text,text)
function ai.enable_vectorizer_schedule(integer)
function ai.execute_vectorizer(integer)
Expand Down
2 changes: 1 addition & 1 deletion projects/extension/tests/contents/output17.expected
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ CREATE EXTENSION
function ai.disable_vectorizer_schedule(integer)
function ai.drop_vectorizer(integer,boolean)
function ai.embedding_ollama(text,integer,text,jsonb,text)
function ai.embedding_openai(text,integer,text,text)
function ai.embedding_openai(text,integer,text,text,text)
function ai.embedding_voyageai(text,integer,text,text)
function ai.enable_vectorizer_schedule(integer)
function ai.execute_vectorizer(integer)
Expand Down
8 changes: 4 additions & 4 deletions projects/extension/tests/privileges/function.expected
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,10 @@
f | bob | execute | no | ai | embedding_ollama(model text, dimensions integer, base_url text, options jsonb, keep_alive text)
f | fred | execute | no | ai | embedding_ollama(model text, dimensions integer, base_url text, options jsonb, keep_alive text)
f | jill | execute | YES | ai | embedding_ollama(model text, dimensions integer, base_url text, options jsonb, keep_alive text)
f | alice | execute | YES | ai | embedding_openai(model text, dimensions integer, chat_user text, api_key_name text)
f | bob | execute | no | ai | embedding_openai(model text, dimensions integer, chat_user text, api_key_name text)
f | fred | execute | no | ai | embedding_openai(model text, dimensions integer, chat_user text, api_key_name text)
f | jill | execute | YES | ai | embedding_openai(model text, dimensions integer, chat_user text, api_key_name text)
f | alice | execute | YES | ai | embedding_openai(model text, dimensions integer, chat_user text, api_key_name text, base_url text)
f | bob | execute | no | ai | embedding_openai(model text, dimensions integer, chat_user text, api_key_name text, base_url text)
f | fred | execute | no | ai | embedding_openai(model text, dimensions integer, chat_user text, api_key_name text, base_url text)
f | jill | execute | YES | ai | embedding_openai(model text, dimensions integer, chat_user text, api_key_name text, base_url text)
f | alice | execute | YES | ai | embedding_voyageai(model text, dimensions integer, input_type text, api_key_name text)
f | bob | execute | no | ai | embedding_voyageai(model text, dimensions integer, input_type text, api_key_name text)
f | fred | execute | no | ai | embedding_voyageai(model text, dimensions integer, input_type text, api_key_name text)
Expand Down
5 changes: 2 additions & 3 deletions projects/pgai/pgai/vectorizer/embedders/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing_extensions import TypedDict, override

from ..embeddings import (
BaseURLMixin,
BatchApiCaller,
Embedder,
EmbeddingResponse,
Expand Down Expand Up @@ -58,22 +59,20 @@ class OllamaOptions(TypedDict, total=False):
stop: Sequence[str]


class Ollama(BaseModel, Embedder):
class Ollama(BaseModel, BaseURLMixin, Embedder):
"""
Embedder that uses Ollama to embed documents into vector representations.
Attributes:
implementation (Literal["ollama"]): The literal identifier for this
implementation.
model (str): The name of the Ollama model used for embeddings.
base_url (str): The base url used to access the Ollama API.
options (dict): Additional ollama-specific runtime options
keep_alive (str): How long to keep the model loaded after the request
"""

implementation: Literal["ollama"]
model: str
base_url: str | None = None
options: OllamaOptions | None = None
keep_alive: str | None = None # this is only `str` because of the SQL API

Expand Down
7 changes: 5 additions & 2 deletions projects/pgai/pgai/vectorizer/embedders/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from ..embeddings import (
ApiKeyMixin,
BaseURLMixin,
BatchApiCaller,
ChunkEmbeddingError,
Embedder,
Expand All @@ -29,7 +30,7 @@
)


class OpenAI(ApiKeyMixin, BaseModel, Embedder):
class OpenAI(ApiKeyMixin, BaseURLMixin, BaseModel, Embedder):
"""
Embedder that uses OpenAI's API to embed documents into vector representations.
Expand Down Expand Up @@ -60,7 +61,9 @@ def _openai_user(self) -> str | openai.NotGiven:

@cached_property
def _embedder(self) -> resources.AsyncEmbeddings:
return openai.AsyncOpenAI(api_key=self._api_key, max_retries=3).embeddings
return openai.AsyncOpenAI(
base_url=self.base_url, api_key=self._api_key, max_retries=3
).embeddings

@override
def _max_chunks_per_batch(self) -> int:
Expand Down
11 changes: 11 additions & 0 deletions projects/pgai/pgai/vectorizer/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,17 @@ async def setup(self) -> None: # noqa: B027 empty on purpose
"""


class BaseURLMixin:
    """
    A mixin class that declares an optional base URL for an API.

    The class defines no methods; it only declares the `base_url` attribute
    so that classes mixing it in share a single definition of it.

    Attributes:
        base_url (str | None): The base URL for the API. Defaults to None.
    """

    # None is the default, i.e. no base URL has been set explicitly.
    base_url: str | None = None


class ApiKeyMixin:
"""
A mixin class that provides functionality for managing API keys.
Expand Down
1 change: 1 addition & 0 deletions projects/pgai/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -118,4 +118,5 @@ dev-dependencies = [
"testcontainers==4.8.1",
"build==1.2.2.post1",
"twine==5.1.1",
"mitmproxy>=11.0.2",
]
Loading

0 comments on commit 66ceb3d

Please sign in to comment.