Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add sqlalchemy vectorizer helper #265

Merged
merged 16 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ jobs:

- name: Install dependencies
working-directory: ./projects/pgai
run: uv sync
run: uv sync --all-extras

- name: Lint
run: just pgai lint
Expand Down
184 changes: 184 additions & 0 deletions docs/python-integration.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
# SQLAlchemy Integration with pgai Vectorizer

The `vectorizer_relationship` is a SQLAlchemy helper that integrates pgai's vectorization capabilities directly into your SQLAlchemy models.
Think of it as a normal SQLAlchemy [relationship](https://docs.sqlalchemy.org/en/20/orm/basic_relationships.html), but with a preconfigured model instance under the hood.
This allows you to easily query vector embeddings created by pgai using familiar SQLAlchemy patterns.

## Installation

To use the SQLAlchemy integration, install pgai with the SQLAlchemy extras:

```bash
pip install "pgai[sqlalchemy]"
```

## Basic Usage

Here's a basic example of how to use the `vectorizer_relationship`:

```python
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from pgai.sqlalchemy import vectorizer_relationship

class Base(DeclarativeBase):
pass

class BlogPost(Base):
__tablename__ = "blog_posts"

id: Mapped[int] = mapped_column(primary_key=True)
title: Mapped[str]
content: Mapped[str]

# Add vector embeddings for the content field
content_embeddings = vectorizer_relationship(
dimensions=768
)
```
Note if you work with alembics autogenerate functionality for migrations, also check [Working with alembic](#working-with-alembic).

### Semantic Search

You can then perform semantic similarity search on the field using [pgvector-python's](https://github.com/pgvector/pgvector-python) distance functions:

```python
from sqlalchemy import func, text

similar_posts = (
session.query(BlogPost.content_embeddings)
.order_by(
BlogPost.content_embeddings.embedding.cosine_distance(
func.ai.openai_embed(
"text-embedding-3-small",
"search query",
text("dimensions => 768")
)
)
)
.limit(5)
.all()
)
```

Or if you already have the embeddings in your application:

```python
similar_posts = (
session.query(BlogPost.content_embeddings)
.order_by(
BlogPost.content_embeddings.embedding.cosine_distance(
[3, 1, 2]
)
)
.limit(5)
.all()
)
```

## Configuration

The `vectorizer_relationship` accepts the following parameters:

- `dimensions` (int): The size of the embedding vector (required)
- `target_schema` (str, optional): Override the schema for the embeddings table. If not provided, inherits from the parent model's schema
- `target_table` (str, optional): Override the table name for embeddings. Default is `{table_name}_embedding_store`

Additional parameters are simply forwarded to the underlying [SQLAlchemy relationship](https://docs.sqlalchemy.org/en/20/orm/relationships.html) so you can configure it as you desire.

Think of the `vectorizer_relationship` as a normal SQLAlchemy relationship, but with a preconfigured model instance under the hood.


## Setting up the Vectorizer

After defining your model, you need to create the vectorizer using pgai's SQL functions:

```sql
SELECT ai.create_vectorizer(
'blog_posts'::regclass,
embedding => ai.embedding_openai('text-embedding-3-small', 768),
chunking => ai.chunking_recursive_character_text_splitter(
'content',
50, -- chunk_size
10 -- chunk_overlap
)
);
```

We recommend adding this to a migration script and run it via alembic.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add another reference to the Working with alembic section below.



## Querying Embeddings

The `vectorizer_relationship` provides several ways to work with embeddings:

### 1. Direct Access to Embeddings

If you access the class proeprty of your model the `vectorizer_relationship` provide a SQLAlchemy model that you can query directly:

```python
# Get all embeddings
embeddings = session.query(BlogPost.content_embeddings).all()

# Access embedding properties
for embedding in embeddings:
print(embedding.embedding) # The vector embedding
print(embedding.chunk) # The text chunk
```
The model will have the primary key fields of the parent model as well as the following fields:
- `chunk` (str): The text chunk that was embedded
- `embedding` (Vector): The vector embedding
- `chunk_seq` (int): The sequence number of the chunk
- `embedding_uuid` (str): The UUID of the embedding
- `parent` (ParentModel): The parent model instance

### 2. Relationship Access


```python
blog_post = session.query(BlogPost).first()
for embedding in blog_post.content_embeddings:
print(embedding.chunk)
```
Access the original posts through the parent relationship
```python
Askir marked this conversation as resolved.
Show resolved Hide resolved
for embedding in similar_posts:
print(embedding.parent.title)
```

### 3. Join Queries

You can combine embedding queries with regular SQL queries using the relationship:

```python
results = (
session.query(BlogPost, BlogPost.content_embeddings)
.join(BlogPost.content_embeddings)
.filter(BlogPost.title.ilike("%search term%"))
.all()
)

for post, embedding in results:
print(f"Title: {post.title}")
print(f"Chunk: {embedding.chunk}")
```

## Working with alembic


The `vectorizer_relationship` generates a new SQLAlchemy model, that is available under the attribute that you specify. If you are using alembic's autogenerate functionality to generate migrations, you will need to exclude these models from the autogenerate process.
These are added to a list in your metadata called `pgai_managed_tables` and you can exclude them by adding the following to your `env.py`:

```python
def include_object(object, name, type_, reflected, compare_to):
if type_ == "table" and name in target_metadata.info.get("pgai_managed_tables", set()):
return False
return True

context.configure(
connection=connection,
target_metadata=target_metadata,
include_object=include_object
)
```

This should now prevent alembic from generating tables for these models when you run `alembic revision --autogenerate`.
184 changes: 184 additions & 0 deletions projects/pgai/pgai/sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
from typing import Any, Generic, TypeVar, overload

from pgvector.sqlalchemy import Vector # type: ignore
from sqlalchemy import ForeignKeyConstraint, Integer, Text, event, inspect
from sqlalchemy.orm import (
DeclarativeBase,
Mapped,
Mapper,
Relationship,
RelationshipProperty,
backref,
mapped_column,
relationship,
)

# Type variable for the parent model
T = TypeVar("T", bound=DeclarativeBase)


def to_pascal_case(text: str):
# Split on any non-alphanumeric character
words = "".join(char if char.isalnum() else " " for char in text).split()
# Capitalize first letter of all words
return "".join(word.capitalize() for word in words)


class EmbeddingModel(DeclarativeBase, Generic[T]):
"""Base type for embedding models with required attributes"""

embedding_uuid: Mapped[str]
chunk: Mapped[str]
embedding: Mapped[Vector]
chunk_seq: Mapped[int]
parent: T # Type of the parent model


class _Vectorizer:
def __init__(
self,
dimensions: int,
target_schema: str | None = None,
target_table: str | None = None,
**kwargs: Any,
):
self.dimensions = dimensions
self.target_schema = target_schema
self.target_table = target_table
self.owner: type[DeclarativeBase] | None = None
self.name: str | None = None
self._embedding_class: type[EmbeddingModel[Any]] | None = None
self._relationship: RelationshipProperty[Any] | None = None
self._initialized = False
self.relationship_args = kwargs
event.listen(Mapper, "after_configured", self._initialize_all)

def _initialize_all(self):
"""Force initialization during mapper configuration"""
if not self._initialized and self.owner is not None:
self.__get__(None, self.owner)

def set_schemas_correctly(self, owner: type[DeclarativeBase]) -> None:
table_args_schema_name = getattr(owner, "__table_args__", {}).get("schema")
self.target_schema = (
self.target_schema
or table_args_schema_name
or owner.registry.metadata.schema
or "public"
)

def create_embedding_class(
self, owner: type[DeclarativeBase]
) -> type[EmbeddingModel[Any]]:
assert self.name is not None
table_name = self.target_table or f"{owner.__tablename__}_embedding_store"
self.set_schemas_correctly(owner)
class_name = f"{to_pascal_case(self.name)}Embedding"
registry_instance = owner.registry
base: type[DeclarativeBase] = owner.__base__ # type: ignore

# Check if table already exists in metadata
# There is probably a better way to do this
# than accessing the internal _class_registry
# Not doing this ends up in a recursion because
# creating the new class reconfigures tha parent mapper
# again triggering the after_configured event
key = f"{self.target_schema}.{table_name}"
if key in owner.metadata.tables:
# Find the mapped class in the registry
for cls in owner.registry._class_registry.values(): # type: ignore
if hasattr(cls, "__table__") and cls.__table__.fullname == key: # type: ignore
return cls # type: ignore

# Get primary key information from the fully initialized model
mapper = inspect(owner)
pk_cols = mapper.primary_key

# Create the complete class dictionary
class_dict: dict[str, Any] = {
"__tablename__": table_name,
"registry": registry_instance,
# Add all standard columns
"embedding_uuid": mapped_column(Text, primary_key=True),
"chunk": mapped_column(Text, nullable=False),
"embedding": mapped_column(Vector(self.dimensions), nullable=False),
"chunk_seq": mapped_column(Integer, nullable=False),
}

# Add primary key columns to the dictionary
for col in pk_cols:
class_dict[col.name] = mapped_column(col.type, nullable=False)

# Create the table args with foreign key constraint
table_args_dict: dict[str, Any] = dict()
if self.target_schema and self.target_schema != owner.registry.metadata.schema:
table_args_dict["schema"] = self.target_schema

# Create the composite foreign key constraint
fk_constraint = ForeignKeyConstraint(
[col.name for col in pk_cols], # Local columns
[
f"{owner.__tablename__}.{col.name}" for col in pk_cols
], # Referenced columns
ondelete="CASCADE",
)

# Add table args to class dictionary
class_dict["__table_args__"] = (fk_constraint, table_args_dict)

# Create the class using type()
Embedding = type(class_name, (base,), class_dict)

return Embedding # type: ignore

@overload
def __get__(
self, obj: None, objtype: type[DeclarativeBase]
) -> type[EmbeddingModel[Any]]: ...

@overload
def __get__(
self, obj: DeclarativeBase, objtype: type[DeclarativeBase] | None = None
) -> Relationship[EmbeddingModel[Any]]: ...

def __get__(
self, obj: DeclarativeBase | None, objtype: type[DeclarativeBase] | None = None
) -> Relationship[EmbeddingModel[Any]] | type[EmbeddingModel[Any]]:
assert self.name is not None
relationship_name = f"_{self.name}_relationship"
if not self._initialized and objtype is not None:
self._embedding_class = self.create_embedding_class(objtype)

mapper = inspect(objtype)
assert mapper is not None
pk_cols = mapper.primary_key
if not hasattr(objtype, relationship_name):
self.relationship_instance = relationship(
self._embedding_class,
foreign_keys=[
getattr(self._embedding_class, col.name) for col in pk_cols
],
backref=self.relationship_args.pop(
"backref", backref("parent", lazy="select")
),
**self.relationship_args,
)
setattr(objtype, f"{self.name}_model", self._embedding_class)
setattr(objtype, relationship_name, self.relationship_instance)
self._initialized = True
if obj is None and self._initialized:
return self._embedding_class # type: ignore

return getattr(obj, relationship_name)

def __set_name__(self, owner: type[DeclarativeBase], name: str):
self.owner = owner
self.name = name

metadata = owner.registry.metadata
if not hasattr(metadata, "info"):
metadata.info = {}
metadata.info.setdefault("pgai_managed_tables", set()).add(self.target_table)


vectorizer_relationship = _Vectorizer
6 changes: 6 additions & 0 deletions projects/pgai/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ classifiers = [
"Operating System :: POSIX :: Linux",
]

[project.optional-dependencies]
sqlalchemy=[
"sqlalchemy>=2.0.36",
]

[project.urls]
Homepage = "https://github.com/timescale/pgai"
Repository = "https://github.com/timescale/pgai"
Expand Down Expand Up @@ -110,4 +115,5 @@ dev-dependencies = [
"testcontainers==4.8.1",
"build==1.2.2.post1",
"twine==5.1.1",
"psycopg2==2.9.10",
]
Loading
Loading