-
Notifications
You must be signed in to change notification settings - Fork 179
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: add sqlalchemy vectorizer helper #265
Merged
Merged
Changes from all commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
ce234bd
feat: add sqlalchemy vectorizer field
Askir 39d2d35
chore: add installing extras to ci
Askir 203a4bc
chore: simplify interface, add simple docs
Askir 3d75279
feat: allow arbitrary primary keys on parent model
Askir 34991b1
docs: update docs with simplified vectorizer field
Askir 43caade
chore: rename VectorizerField to Vectorizer
Askir a2c59e4
chore: update alembic exclusion mechanism
Askir a35be41
docs: update docs with review comments
Askir 37b662e
chore: align automatic table name with create_vectorizer
Askir df75e7e
chore: add option for any relationship properties
Askir 9a1fb56
chore: setup class event based rather than lazy so relationship works…
Askir 023b0f4
chore: update to embedding_relationship
Askir ddb37ff
chore: refactor tests add vcr mocks
Askir 717e23d
chore: rename to vectorizer_model; cleanup
Askir 882f91e
chore: fix uv lock
Askir 22233a0
chore: remove dummy key
Askir File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,184 @@ | ||
# SQLAlchemy Integration with pgai Vectorizer | ||
|
||
The `vectorizer_relationship` is a SQLAlchemy helper that integrates pgai's vectorization capabilities directly into your SQLAlchemy models. | ||
Think of it as a normal SQLAlchemy [relationship](https://docs.sqlalchemy.org/en/20/orm/basic_relationships.html), but with a preconfigured model instance under the hood. | ||
This allows you to easily query vector embeddings created by pgai using familiar SQLAlchemy patterns. | ||
|
||
## Installation | ||
|
||
To use the SQLAlchemy integration, install pgai with the SQLAlchemy extras: | ||
|
||
```bash | ||
pip install "pgai[sqlalchemy]" | ||
``` | ||
|
||
## Basic Usage | ||
|
||
Here's a basic example of how to use the `vectorizer_relationship`: | ||
|
||
```python | ||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column | ||
from pgai.sqlalchemy import vectorizer_relationship | ||
|
||
class Base(DeclarativeBase): | ||
pass | ||
|
||
class BlogPost(Base): | ||
__tablename__ = "blog_posts" | ||
|
||
id: Mapped[int] = mapped_column(primary_key=True) | ||
title: Mapped[str] | ||
content: Mapped[str] | ||
|
||
# Add vector embeddings for the content field | ||
content_embeddings = vectorizer_relationship( | ||
dimensions=768 | ||
) | ||
``` | ||
Note if you work with alembics autogenerate functionality for migrations, also check [Working with alembic](#working-with-alembic). | ||
|
||
### Semantic Search | ||
|
||
You can then perform semantic similarity search on the field using [pgvector-python's](https://github.com/pgvector/pgvector-python) distance functions: | ||
|
||
```python | ||
from sqlalchemy import func, text | ||
|
||
similar_posts = ( | ||
session.query(BlogPost.content_embeddings) | ||
.order_by( | ||
BlogPost.content_embeddings.embedding.cosine_distance( | ||
func.ai.openai_embed( | ||
"text-embedding-3-small", | ||
"search query", | ||
text("dimensions => 768") | ||
) | ||
) | ||
) | ||
.limit(5) | ||
.all() | ||
) | ||
``` | ||
|
||
Or if you already have the embeddings in your application: | ||
|
||
```python | ||
similar_posts = ( | ||
session.query(BlogPost.content_embeddings) | ||
.order_by( | ||
BlogPost.content_embeddings.embedding.cosine_distance( | ||
[3, 1, 2] | ||
) | ||
) | ||
.limit(5) | ||
.all() | ||
) | ||
``` | ||
|
||
## Configuration | ||
|
||
The `vectorizer_relationship` accepts the following parameters: | ||
|
||
- `dimensions` (int): The size of the embedding vector (required) | ||
- `target_schema` (str, optional): Override the schema for the embeddings table. If not provided, inherits from the parent model's schema | ||
- `target_table` (str, optional): Override the table name for embeddings. Default is `{table_name}_embedding_store` | ||
|
||
Additional parameters are simply forwarded to the underlying [SQLAlchemy relationship](https://docs.sqlalchemy.org/en/20/orm/relationships.html) so you can configure it as you desire. | ||
|
||
Think of the `vectorizer_relationship` as a normal SQLAlchemy relationship, but with a preconfigured model instance under the hood. | ||
|
||
|
||
## Setting up the Vectorizer | ||
|
||
After defining your model, you need to create the vectorizer using pgai's SQL functions: | ||
|
||
```sql | ||
SELECT ai.create_vectorizer( | ||
'blog_posts'::regclass, | ||
embedding => ai.embedding_openai('text-embedding-3-small', 768), | ||
chunking => ai.chunking_recursive_character_text_splitter( | ||
'content', | ||
50, -- chunk_size | ||
10 -- chunk_overlap | ||
) | ||
); | ||
``` | ||
|
||
We recommend adding this to a migration script and run it via alembic. | ||
|
||
|
||
## Querying Embeddings | ||
|
||
The `vectorizer_relationship` provides several ways to work with embeddings: | ||
|
||
### 1. Direct Access to Embeddings | ||
|
||
If you access the class proeprty of your model the `vectorizer_relationship` provide a SQLAlchemy model that you can query directly: | ||
|
||
```python | ||
# Get all embeddings | ||
embeddings = session.query(BlogPost.content_embeddings).all() | ||
|
||
# Access embedding properties | ||
for embedding in embeddings: | ||
print(embedding.embedding) # The vector embedding | ||
print(embedding.chunk) # The text chunk | ||
``` | ||
The model will have the primary key fields of the parent model as well as the following fields: | ||
- `chunk` (str): The text chunk that was embedded | ||
- `embedding` (Vector): The vector embedding | ||
- `chunk_seq` (int): The sequence number of the chunk | ||
- `embedding_uuid` (str): The UUID of the embedding | ||
- `parent` (ParentModel): The parent model instance | ||
|
||
### 2. Relationship Access | ||
|
||
|
||
```python | ||
blog_post = session.query(BlogPost).first() | ||
for embedding in blog_post.content_embeddings: | ||
print(embedding.chunk) | ||
``` | ||
Access the original posts through the parent relationship | ||
```python | ||
Askir marked this conversation as resolved.
Show resolved
Hide resolved
|
||
for embedding in similar_posts: | ||
print(embedding.parent.title) | ||
``` | ||
|
||
### 3. Join Queries | ||
|
||
You can combine embedding queries with regular SQL queries using the relationship: | ||
|
||
```python | ||
results = ( | ||
session.query(BlogPost, BlogPost.content_embeddings) | ||
.join(BlogPost.content_embeddings) | ||
.filter(BlogPost.title.ilike("%search term%")) | ||
.all() | ||
) | ||
|
||
for post, embedding in results: | ||
print(f"Title: {post.title}") | ||
print(f"Chunk: {embedding.chunk}") | ||
``` | ||
|
||
## Working with alembic | ||
|
||
|
||
The `vectorizer_relationship` generates a new SQLAlchemy model, that is available under the attribute that you specify. If you are using alembic's autogenerate functionality to generate migrations, you will need to exclude these models from the autogenerate process. | ||
These are added to a list in your metadata called `pgai_managed_tables` and you can exclude them by adding the following to your `env.py`: | ||
|
||
```python | ||
def include_object(object, name, type_, reflected, compare_to): | ||
if type_ == "table" and name in target_metadata.info.get("pgai_managed_tables", set()): | ||
return False | ||
return True | ||
|
||
context.configure( | ||
connection=connection, | ||
target_metadata=target_metadata, | ||
include_object=include_object | ||
) | ||
``` | ||
|
||
This should now prevent alembic from generating tables for these models when you run `alembic revision --autogenerate`. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,184 @@ | ||
from typing import Any, Generic, TypeVar, overload | ||
|
||
from pgvector.sqlalchemy import Vector # type: ignore | ||
from sqlalchemy import ForeignKeyConstraint, Integer, Text, event, inspect | ||
from sqlalchemy.orm import ( | ||
DeclarativeBase, | ||
Mapped, | ||
Mapper, | ||
Relationship, | ||
RelationshipProperty, | ||
backref, | ||
mapped_column, | ||
relationship, | ||
) | ||
|
||
# Type variable for the parent model | ||
T = TypeVar("T", bound=DeclarativeBase) | ||
|
||
|
||
def to_pascal_case(text: str): | ||
# Split on any non-alphanumeric character | ||
words = "".join(char if char.isalnum() else " " for char in text).split() | ||
# Capitalize first letter of all words | ||
return "".join(word.capitalize() for word in words) | ||
|
||
|
||
class EmbeddingModel(DeclarativeBase, Generic[T]): | ||
"""Base type for embedding models with required attributes""" | ||
|
||
embedding_uuid: Mapped[str] | ||
chunk: Mapped[str] | ||
embedding: Mapped[Vector] | ||
chunk_seq: Mapped[int] | ||
parent: T # Type of the parent model | ||
|
||
|
||
class _Vectorizer: | ||
def __init__( | ||
self, | ||
dimensions: int, | ||
target_schema: str | None = None, | ||
target_table: str | None = None, | ||
**kwargs: Any, | ||
): | ||
self.dimensions = dimensions | ||
self.target_schema = target_schema | ||
self.target_table = target_table | ||
self.owner: type[DeclarativeBase] | None = None | ||
self.name: str | None = None | ||
self._embedding_class: type[EmbeddingModel[Any]] | None = None | ||
self._relationship: RelationshipProperty[Any] | None = None | ||
self._initialized = False | ||
self.relationship_args = kwargs | ||
event.listen(Mapper, "after_configured", self._initialize_all) | ||
|
||
def _initialize_all(self): | ||
"""Force initialization during mapper configuration""" | ||
if not self._initialized and self.owner is not None: | ||
self.__get__(None, self.owner) | ||
|
||
def set_schemas_correctly(self, owner: type[DeclarativeBase]) -> None: | ||
table_args_schema_name = getattr(owner, "__table_args__", {}).get("schema") | ||
self.target_schema = ( | ||
self.target_schema | ||
or table_args_schema_name | ||
or owner.registry.metadata.schema | ||
or "public" | ||
) | ||
|
||
def create_embedding_class( | ||
self, owner: type[DeclarativeBase] | ||
) -> type[EmbeddingModel[Any]]: | ||
assert self.name is not None | ||
table_name = self.target_table or f"{owner.__tablename__}_embedding_store" | ||
self.set_schemas_correctly(owner) | ||
class_name = f"{to_pascal_case(self.name)}Embedding" | ||
registry_instance = owner.registry | ||
base: type[DeclarativeBase] = owner.__base__ # type: ignore | ||
|
||
# Check if table already exists in metadata | ||
# There is probably a better way to do this | ||
# than accessing the internal _class_registry | ||
# Not doing this ends up in a recursion because | ||
# creating the new class reconfigures tha parent mapper | ||
# again triggering the after_configured event | ||
key = f"{self.target_schema}.{table_name}" | ||
if key in owner.metadata.tables: | ||
# Find the mapped class in the registry | ||
for cls in owner.registry._class_registry.values(): # type: ignore | ||
if hasattr(cls, "__table__") and cls.__table__.fullname == key: # type: ignore | ||
return cls # type: ignore | ||
|
||
# Get primary key information from the fully initialized model | ||
mapper = inspect(owner) | ||
pk_cols = mapper.primary_key | ||
|
||
# Create the complete class dictionary | ||
class_dict: dict[str, Any] = { | ||
"__tablename__": table_name, | ||
"registry": registry_instance, | ||
# Add all standard columns | ||
"embedding_uuid": mapped_column(Text, primary_key=True), | ||
"chunk": mapped_column(Text, nullable=False), | ||
"embedding": mapped_column(Vector(self.dimensions), nullable=False), | ||
"chunk_seq": mapped_column(Integer, nullable=False), | ||
} | ||
|
||
# Add primary key columns to the dictionary | ||
for col in pk_cols: | ||
class_dict[col.name] = mapped_column(col.type, nullable=False) | ||
|
||
# Create the table args with foreign key constraint | ||
table_args_dict: dict[str, Any] = dict() | ||
if self.target_schema and self.target_schema != owner.registry.metadata.schema: | ||
table_args_dict["schema"] = self.target_schema | ||
|
||
# Create the composite foreign key constraint | ||
fk_constraint = ForeignKeyConstraint( | ||
[col.name for col in pk_cols], # Local columns | ||
[ | ||
f"{owner.__tablename__}.{col.name}" for col in pk_cols | ||
], # Referenced columns | ||
ondelete="CASCADE", | ||
) | ||
|
||
# Add table args to class dictionary | ||
class_dict["__table_args__"] = (fk_constraint, table_args_dict) | ||
|
||
# Create the class using type() | ||
Embedding = type(class_name, (base,), class_dict) | ||
|
||
return Embedding # type: ignore | ||
|
||
@overload | ||
def __get__( | ||
self, obj: None, objtype: type[DeclarativeBase] | ||
) -> type[EmbeddingModel[Any]]: ... | ||
|
||
@overload | ||
def __get__( | ||
self, obj: DeclarativeBase, objtype: type[DeclarativeBase] | None = None | ||
) -> Relationship[EmbeddingModel[Any]]: ... | ||
|
||
def __get__( | ||
self, obj: DeclarativeBase | None, objtype: type[DeclarativeBase] | None = None | ||
) -> Relationship[EmbeddingModel[Any]] | type[EmbeddingModel[Any]]: | ||
assert self.name is not None | ||
relationship_name = f"_{self.name}_relationship" | ||
if not self._initialized and objtype is not None: | ||
self._embedding_class = self.create_embedding_class(objtype) | ||
|
||
mapper = inspect(objtype) | ||
assert mapper is not None | ||
pk_cols = mapper.primary_key | ||
if not hasattr(objtype, relationship_name): | ||
self.relationship_instance = relationship( | ||
self._embedding_class, | ||
foreign_keys=[ | ||
getattr(self._embedding_class, col.name) for col in pk_cols | ||
], | ||
backref=self.relationship_args.pop( | ||
"backref", backref("parent", lazy="select") | ||
), | ||
**self.relationship_args, | ||
) | ||
setattr(objtype, f"{self.name}_model", self._embedding_class) | ||
setattr(objtype, relationship_name, self.relationship_instance) | ||
self._initialized = True | ||
if obj is None and self._initialized: | ||
return self._embedding_class # type: ignore | ||
|
||
return getattr(obj, relationship_name) | ||
|
||
def __set_name__(self, owner: type[DeclarativeBase], name: str): | ||
self.owner = owner | ||
self.name = name | ||
|
||
metadata = owner.registry.metadata | ||
if not hasattr(metadata, "info"): | ||
metadata.info = {} | ||
metadata.info.setdefault("pgai_managed_tables", set()).add(self.target_table) | ||
|
||
|
||
vectorizer_relationship = _Vectorizer |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's add another reference to the Working with alembic section below.