Skip to content

Commit

Permalink
feat: allow arbitrary primary keys on parent model
Browse files Browse the repository at this point in the history
  • Loading branch information
Askir committed Dec 4, 2024
1 parent fdd2f48 commit 6c4a71e
Show file tree
Hide file tree
Showing 6 changed files with 227 additions and 46 deletions.
124 changes: 83 additions & 41 deletions projects/pgai/pgai/sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, Generic, TypeVar

from pgvector.sqlalchemy import Vector # type: ignore
from sqlalchemy import ForeignKey, Integer, Text
from sqlalchemy import ForeignKeyConstraint, Integer, Text, inspect
from sqlalchemy.orm import DeclarativeBase, Mapped, backref, mapped_column, relationship

# Type variable for the parent model
Expand All @@ -11,7 +11,6 @@
def to_pascal_case(text: str):
    """Convert *text* to PascalCase, treating any non-alphanumeric run as a word break."""
    # Map every non-alphanumeric character to a space, then split into words.
    normalized = "".join(ch if ch.isalnum() else " " for ch in text)
    words = normalized.split()
    # Capitalize each word and glue them back together.
    return "".join(word.capitalize() for word in words)

Expand All @@ -20,7 +19,6 @@ class EmbeddingModel(DeclarativeBase, Generic[T]):
"""Base type for embedding models with required attributes"""

embedding_uuid: Mapped[str]
id: Mapped[int]
chunk: Mapped[str]
embedding: Mapped[Vector]
chunk_seq: Mapped[int]
Expand All @@ -36,15 +34,24 @@ def __init__(
add_relationship: bool = False,
):
self.add_relationship = add_relationship

# Store table/view configuration
self.dimensions = dimensions
self.target_schema = target_schema
self.target_table = target_table
self.owner: type[DeclarativeBase] | None = None
self.name: str | None = None

def set_schemas_correctly(self, owner: type[T]) -> None:
self._embedding_class: type[EmbeddingModel[Any]] | None = None
self._initialized = False

def _relationship_property(
self, obj: Any = None
) -> Mapped[list[EmbeddingModel[Any]]]:
# Force initialization if not done yet
if not self._initialized:
_ = self.__get__(obj, self.owner)
# Return the actual relationship
return getattr(obj, f"_{self.name}_relation")

def set_schemas_correctly(self, owner: type[DeclarativeBase]) -> None:
table_args_schema_name = getattr(owner, "__table_args__", {}).get("schema")
self.target_schema = (
self.target_schema
Expand All @@ -54,52 +61,87 @@ def set_schemas_correctly(self, owner: type[T]) -> None:
)

def create_embedding_class(
self, owner: type[T], name: str
) -> type[EmbeddingModel[T]]:
table_name = self.target_table or f"{owner.__tablename__}_{name}_store"
self, owner: type[DeclarativeBase]
) -> type[EmbeddingModel[Any]]:
assert self.name is not None
table_name = self.target_table or f"{owner.__tablename__}_{self.name}_store"
self.set_schemas_correctly(owner)
class_name = f"{to_pascal_case(name)}Embedding"
class_name = f"{to_pascal_case(self.name)}Embedding"
registry_instance = owner.registry
base: type[DeclarativeBase] = owner.__base__ # type: ignore

class Embedding(base):
__tablename__ = table_name
__table_args__ = (
{"info": {"pgai_managed": True}, "schema": self.target_schema}
if self.target_schema
and self.target_schema != owner.registry.metadata.schema
else {"info": {"pgai_managed": True}}
)
registry = registry_instance

embedding_uuid = mapped_column(Text, primary_key=True)
id = mapped_column(
Integer, ForeignKey(f"{owner.__tablename__}.id", ondelete="CASCADE")
)
chunk = mapped_column(Text, nullable=False)
embedding = mapped_column(
Vector(self.dimensions), nullable=False
)
chunk_seq = mapped_column(Integer, nullable=False)

Embedding.__name__ = class_name
# Get primary key information from the fully initialized model
mapper = inspect(owner)
pk_cols = mapper.primary_key

# Create the complete class dictionary
class_dict: dict[str, Any] = {
"__tablename__": table_name,
"registry": registry_instance,
# Add all standard columns
"embedding_uuid": mapped_column(Text, primary_key=True),
"chunk": mapped_column(Text, nullable=False),
"embedding": mapped_column(Vector(self.dimensions), nullable=False),
"chunk_seq": mapped_column(Integer, nullable=False),
}

# Add primary key columns to the dictionary
for col in pk_cols:
class_dict[col.name] = mapped_column(col.type, nullable=False)

# Create the table args with foreign key constraint
table_args_dict: dict[str, Any] = {"info": {"pgai_managed": True}}
if self.target_schema and self.target_schema != owner.registry.metadata.schema:
table_args_dict["schema"] = self.target_schema

# Create the composite foreign key constraint
fk_constraint = ForeignKeyConstraint(
[col.name for col in pk_cols], # Local columns
[
f"{owner.__tablename__}.{col.name}" for col in pk_cols
], # Referenced columns
ondelete="CASCADE",
)

# Add table args to class dictionary
class_dict["__table_args__"] = (fk_constraint, table_args_dict)

# Create the class using type()
Embedding = type(class_name, (base,), class_dict)

return Embedding # type: ignore

def __get__(
self, obj: DeclarativeBase | None, objtype: type[DeclarativeBase] | None = None
) -> type[EmbeddingModel[Any]]:
if not self._initialized and objtype is not None:
self._embedding_class = self.create_embedding_class(objtype)

# Set up relationship if requested
if self.add_relationship:
mapper = inspect(objtype)
pk_cols = mapper.primary_key

relationship_instance = relationship(
self._embedding_class,
foreign_keys=[
getattr(self._embedding_class, col.name) for col in pk_cols
],
backref=backref("parent", lazy="select"),
)
# Store actual relationship under a private name
setattr(objtype, f"_{self.name}_relation", relationship_instance)

self._initialized = True

if self._embedding_class is None:
raise RuntimeError("Embedding class not properly initialized")

return self._embedding_class

def __set_name__(self, owner: type[DeclarativeBase], name: str):
self.owner = owner
self.name = name
self._embedding_class = self.create_embedding_class(owner, name)

# Set up relationship
if self.add_relationship:
relationship_instance = relationship(
self._embedding_class,
foreign_keys=[self._embedding_class.id],
backref=backref("parent", lazy="select"),
)
setattr(owner, f"{name}_relation", relationship_instance)
# Add the property that ensures initialization
setattr(owner, f"{name}_relation", property(self._relationship_property))
1 change: 0 additions & 1 deletion projects/pgai/tests/vectorizer/extensions/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,3 @@ def initialized_engine(
with engine.connect() as conn:
conn.execute(text("DROP SCHEMA public CASCADE; CREATE SCHEMA public;"))
conn.commit()

Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ class BlogPost(Base):
session.query(BlogPost, BlogPost.content_embeddings)
.join(
BlogPost.content_embeddings,
BlogPost.id == BlogPost.content_embeddings.id,
BlogPost.id == BlogPost.content_embeddings.id, # type: ignore
)
.filter(BlogPost.title.ilike("%Python%"))
.all()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import numpy as np
from click.testing import CliRunner
from sqlalchemy import Column, Engine, Text
from sqlalchemy.orm import DeclarativeBase, Mapped, Session
from sqlalchemy.sql import text
from testcontainers.postgres import PostgresContainer # type: ignore

from pgai.cli import vectorizer_worker
from pgai.sqlalchemy import EmbeddingModel, VectorizerField


class Base(DeclarativeBase):
    """Declarative base shared by the test models in this module."""

    pass


class Author(Base):
    """Test model whose primary key is composite (first_name, last_name)."""

    __tablename__ = "authors"
    first_name = Column(Text, primary_key=True)  # part of the composite PK
    last_name = Column(Text, primary_key=True)  # part of the composite PK
    bio = Column(Text, nullable=False)
    # Descriptor that generates the embedding model/store table for `bio`;
    # add_relationship=True also installs `bio_embeddings_relation` on the class.
    bio_embeddings = VectorizerField(
        dimensions=768,
        add_relationship=True,
    )

    # Type-only declaration for the relationship attribute the descriptor
    # installs at runtime; no column is mapped for this name here.
    bio_embeddings_relation: Mapped[list[EmbeddingModel["Author"]]]


def run_vectorizer_worker(db_url: str, vectorizer_id: int) -> None:
    """Invoke the vectorizer worker CLI once for a single vectorizer.

    Runs with --once so the worker processes pending items and exits;
    exceptions propagate (catch_exceptions=False) so test failures surface.
    """
    cli_args = [
        "--db-url",
        db_url,
        "--once",
        "--vectorizer-id",
        str(vectorizer_id),
        "--concurrency",
        "1",
    ]
    runner = CliRunner()
    runner.invoke(vectorizer_worker, cli_args, catch_exceptions=False)


def test_vectorizer_composite_key(
    postgres_container: PostgresContainer, initialized_engine: Engine
):
    """Test vectorizer with a composite primary key.

    Verifies that:
      * the generated embedding class mirrors the parent's composite
        (first_name, last_name) key columns,
      * the parent <-> embedding relationship works in both directions,
      * semantic search runs against the generated store table.
    """
    # Hoisted from mid-function: keep all imports at the top of the scope.
    from sqlalchemy import func

    db_url = postgres_container.get_connection_url()

    # Create only the parent table; the embedding store table is created by
    # the vectorizer itself. Referencing Author.__table__ explicitly is more
    # robust than metadata.sorted_tables[0], whose ordering depends on the
    # FK-based topological sort of every table registered on the metadata.
    Author.metadata.create_all(initialized_engine, tables=[Author.__table__])

    # Create the vectorizer for the authors table.
    with initialized_engine.connect() as conn:
        conn.execute(
            text("""
                SELECT ai.create_vectorizer(
                    'authors'::regclass,
                    target_table => 'authors_bio_embeddings_store',
                    embedding => ai.embedding_openai('text-embedding-3-small', 768),
                    chunking =>
                        ai.chunking_recursive_character_text_splitter('bio', 50, 10)
                );
            """)
        )
        conn.commit()

    # Insert test data.
    with Session(initialized_engine) as session:
        author = Author(
            first_name="Jane",
            last_name="Doe",
            bio="Jane is an accomplished researcher in artificial intelligence and machine learning. She has published numerous papers on neural networks.",  # noqa
        )
        session.add(author)
        session.commit()

    # Run the vectorizer worker once to populate the embeddings table.
    run_vectorizer_worker(db_url, 1)

    # Verify embeddings were created.
    with Session(initialized_engine) as session:
        # The descriptor should have generated the embedding class.
        assert Author.bio_embeddings.__name__ == "BioEmbeddingsEmbedding"

        # Check embeddings exist and have correct properties.
        embedding = session.query(Author.bio_embeddings).first()
        assert embedding is not None
        assert isinstance(embedding.embedding, np.ndarray)
        assert len(embedding.embedding) == 768
        assert embedding.chunk is not None
        assert isinstance(embedding.chunk, str)

        # The composite key columns must be mirrored onto the embedding row.
        assert hasattr(embedding, "first_name")
        assert hasattr(embedding, "last_name")
        assert embedding.first_name == "Jane"  # type: ignore
        assert embedding.last_name == "Doe"  # type: ignore

        # Parent -> embeddings relationship.
        author = session.query(Author).first()
        assert author is not None
        assert hasattr(author, "bio_embeddings_relation")
        assert author.bio_embeddings_relation is not None
        assert len(author.bio_embeddings_relation) > 0
        assert author.bio_embeddings_relation[0].chunk in author.bio

        # Embedding -> parent backref.
        embedding_entity = session.query(Author.bio_embeddings).first()
        assert embedding_entity is not None
        assert embedding_entity.chunk in author.bio
        assert embedding_entity.parent is not None
        assert embedding_entity.parent.first_name == "Jane"
        assert embedding_entity.parent.last_name == "Doe"

        # Semantic search: order chunks by cosine distance to a query
        # embedding computed server-side via ai.openai_embed.
        similar_embeddings = (
            session.query(Author.bio_embeddings)
            .order_by(
                Author.bio_embeddings.embedding.cosine_distance(
                    func.ai.openai_embed(
                        "text-embedding-3-small",
                        "machine learning",
                        text("dimensions => 768"),
                    )
                )
            )
            .all()
        )

        assert len(similar_embeddings) > 0
        # The best-matching chunk's parent bio should mention machine learning.
        assert "machine learning" in similar_embeddings[0].parent.bio
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@ class BlogPost(Base):
id = Column(Integer, primary_key=True)
title = Column(Text, nullable=False)
content = Column(Text, nullable=False)
content_embeddings = VectorizerField(
dimensions=1536
)
content_embeddings = VectorizerField(dimensions=1536)


def run_vectorizer_worker(db_url: str, vectorizer_id: int) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ def test_vectorizer_embedding_creation(
# Verify embeddings were created
with Session(initialized_engine) as session:
# Verify embedding class was created correctly

blog_post = session.query(BlogPost).first()
assert blog_post is not None
assert blog_post.content_embeddings_relation is not None
assert BlogPost.content_embeddings.__name__ == "ContentEmbeddingsEmbedding"

# Check embeddings exist and have correct properties
Expand Down

0 comments on commit 6c4a71e

Please sign in to comment.