From 6c4a71e959bb26c37bb6ba01ab15cbb2befae721 Mon Sep 17 00:00:00 2001
From: Jascha
Date: Tue, 3 Dec 2024 10:35:35 +0100
Subject: [PATCH] feat: allow arbitrary primary keys on parent model

---
 projects/pgai/pgai/sqlalchemy/__init__.py     | 124 ++++++++++------
 .../tests/vectorizer/extensions/conftest.py   |   1 -
 .../vectorizer/extensions/test_sqlalchemy.py  |   2 +-
 .../test_sqlalchemy_composite_primary.py      | 138 ++++++++++++++++++
 .../test_sqlalchemy_large_embeddings.py       |   4 +-
 .../test_sqlalchemy_relationship.py           |   4 +
 6 files changed, 227 insertions(+), 46 deletions(-)

diff --git a/projects/pgai/pgai/sqlalchemy/__init__.py b/projects/pgai/pgai/sqlalchemy/__init__.py
index bc7e4949..d8df7884 100644
--- a/projects/pgai/pgai/sqlalchemy/__init__.py
+++ b/projects/pgai/pgai/sqlalchemy/__init__.py
@@ -1,7 +1,7 @@
 from typing import Any, Generic, TypeVar
 
 from pgvector.sqlalchemy import Vector  # type: ignore
-from sqlalchemy import ForeignKey, Integer, Text
+from sqlalchemy import ForeignKeyConstraint, Integer, Text, inspect
 from sqlalchemy.orm import DeclarativeBase, Mapped, backref, mapped_column, relationship
 
 # Type variable for the parent model
@@ -11,7 +11,6 @@
 def to_pascal_case(text: str):
     # Split on any non-alphanumeric character
     words = "".join(char if char.isalnum() else " " for char in text).split()
-    # Capitalize first letter of all words
     return "".join(word.capitalize() for word in words)
 
 
@@ -20,7 +19,6 @@ class EmbeddingModel(DeclarativeBase, Generic[T]):
     """Base type for embedding models with required attributes"""
 
     embedding_uuid: Mapped[str]
-    id: Mapped[int]
    chunk: Mapped[str]
     embedding: Mapped[Vector]
     chunk_seq: Mapped[int]
@@ -36,15 +34,24 @@ def __init__(
         add_relationship: bool = False,
     ):
         self.add_relationship = add_relationship
-
-        # Store table/view configuration
         self.dimensions = dimensions
         self.target_schema = target_schema
         self.target_table = target_table
         self.owner: type[DeclarativeBase] | None = None
         self.name: str | None = None
-
-    def set_schemas_correctly(self, owner: type[T]) -> None:
+        self._embedding_class: type[EmbeddingModel[Any]] | None = None
+        self._initialized = False
+
+    def _relationship_property(
+        self, obj: Any = None
+    ) -> Mapped[list[EmbeddingModel[Any]]]:
+        # Force initialization if not done yet
+        if not self._initialized:
+            _ = self.__get__(obj, self.owner)
+        # Return the actual relationship
+        return getattr(obj, f"_{self.name}_relation")
+
+    def set_schemas_correctly(self, owner: type[DeclarativeBase]) -> None:
         table_args_schema_name = getattr(owner, "__table_args__", {}).get("schema")
         self.target_schema = (
             self.target_schema
@@ -54,52 +61,87 @@ def set_schemas_correctly(self, owner: type[T]) -> None:
         )
 
     def create_embedding_class(
-        self, owner: type[T], name: str
-    ) -> type[EmbeddingModel[T]]:
-        table_name = self.target_table or f"{owner.__tablename__}_{name}_store"
+        self, owner: type[DeclarativeBase]
+    ) -> type[EmbeddingModel[Any]]:
+        assert self.name is not None
+        table_name = self.target_table or f"{owner.__tablename__}_{self.name}_store"
         self.set_schemas_correctly(owner)
-        class_name = f"{to_pascal_case(name)}Embedding"
+        class_name = f"{to_pascal_case(self.name)}Embedding"
         registry_instance = owner.registry
         base: type[DeclarativeBase] = owner.__base__  # type: ignore
 
-        class Embedding(base):
-            __tablename__ = table_name
-            __table_args__ = (
-                {"info": {"pgai_managed": True}, "schema": self.target_schema}
-                if self.target_schema
-                and self.target_schema != owner.registry.metadata.schema
-                else {"info": {"pgai_managed": True}}
-            )
-            registry = registry_instance
-
-            embedding_uuid = mapped_column(Text, primary_key=True)
-            id = mapped_column(
-                Integer, ForeignKey(f"{owner.__tablename__}.id", ondelete="CASCADE")
-            )
-            chunk = mapped_column(Text, nullable=False)
-            embedding = mapped_column(
-                Vector(self.dimensions), nullable=False
-            )
-            chunk_seq = mapped_column(Integer, nullable=False)
-
-        Embedding.__name__ = class_name
+        # Get primary key information from the fully initialized model
+        mapper = inspect(owner)
+        pk_cols = mapper.primary_key
+
+        # Create the complete class dictionary
+        class_dict: dict[str, Any] = {
+            "__tablename__": table_name,
+            "registry": registry_instance,
+            # Add all standard columns
+            "embedding_uuid": mapped_column(Text, primary_key=True),
+            "chunk": mapped_column(Text, nullable=False),
+            "embedding": mapped_column(Vector(self.dimensions), nullable=False),
+            "chunk_seq": mapped_column(Integer, nullable=False),
+        }
+
+        # Add primary key columns to the dictionary
+        for col in pk_cols:
+            class_dict[col.name] = mapped_column(col.type, nullable=False)
+
+        # Create the table args with foreign key constraint
+        table_args_dict: dict[str, Any] = {"info": {"pgai_managed": True}}
+        if self.target_schema and self.target_schema != owner.registry.metadata.schema:
+            table_args_dict["schema"] = self.target_schema
+
+        # Create the composite foreign key constraint
+        fk_constraint = ForeignKeyConstraint(
+            [col.name for col in pk_cols],  # Local columns
+            [
+                f"{owner.__tablename__}.{col.name}" for col in pk_cols
+            ],  # Referenced columns
+            ondelete="CASCADE",
+        )
+
+        # Add table args to class dictionary
+        class_dict["__table_args__"] = (fk_constraint, table_args_dict)
+
+        # Create the class using type()
+        Embedding = type(class_name, (base,), class_dict)
+
         return Embedding  # type: ignore
 
     def __get__(
         self, obj: DeclarativeBase | None, objtype: type[DeclarativeBase] | None = None
     ) -> type[EmbeddingModel[Any]]:
+        if not self._initialized and objtype is not None:
+            self._embedding_class = self.create_embedding_class(objtype)
+
+            # Set up relationship if requested
+            if self.add_relationship:
+                mapper = inspect(objtype)
+                pk_cols = mapper.primary_key
+
+                relationship_instance = relationship(
+                    self._embedding_class,
+                    foreign_keys=[
+                        getattr(self._embedding_class, col.name) for col in pk_cols
+                    ],
+                    backref=backref("parent", lazy="select"),
+                )
+                # Store actual relationship under a private name
+                setattr(objtype, f"_{self.name}_relation", relationship_instance)
+
+            self._initialized = True
+
+        if self._embedding_class is None:
+            raise RuntimeError("Embedding class not properly initialized")
+
         return self._embedding_class
 
     def __set_name__(self, owner: type[DeclarativeBase], name: str):
         self.owner = owner
         self.name = name
-        self._embedding_class = self.create_embedding_class(owner, name)
-
-        # Set up relationship
         if self.add_relationship:
-            relationship_instance = relationship(
-                self._embedding_class,
-                foreign_keys=[self._embedding_class.id],
-                backref=backref("parent", lazy="select"),
-            )
-            setattr(owner, f"{name}_relation", relationship_instance)
+            # Add the property that ensures initialization
+            setattr(owner, f"{name}_relation", property(self._relationship_property))
diff --git a/projects/pgai/tests/vectorizer/extensions/conftest.py b/projects/pgai/tests/vectorizer/extensions/conftest.py
index 9893da9e..3733575f 100644
--- a/projects/pgai/tests/vectorizer/extensions/conftest.py
+++ b/projects/pgai/tests/vectorizer/extensions/conftest.py
@@ -44,4 +44,3 @@ def initialized_engine(
     with engine.connect() as conn:
         conn.execute(text("DROP SCHEMA public CASCADE; CREATE SCHEMA public;"))
         conn.commit()
-
diff --git a/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy.py b/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy.py
index 92827da8..abb64747 100644
--- a/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy.py
+++ b/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy.py
@@ -127,7 +127,7 @@ class BlogPost(Base):
         session.query(BlogPost, BlogPost.content_embeddings)
         .join(
             BlogPost.content_embeddings,
-            BlogPost.id == BlogPost.content_embeddings.id,
+            BlogPost.id == BlogPost.content_embeddings.id,  # type: ignore
         )
         .filter(BlogPost.title.ilike("%Python%"))
         .all()
diff --git a/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy_composite_primary.py b/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy_composite_primary.py
index e69de29b..4841b2e5 100644
--- a/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy_composite_primary.py
+++ b/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy_composite_primary.py
@@ -0,0 +1,138 @@
+import numpy as np
+from click.testing import CliRunner
+from sqlalchemy import Column, Engine, Text
+from sqlalchemy.orm import DeclarativeBase, Mapped, Session
+from sqlalchemy.sql import text
+from testcontainers.postgres import PostgresContainer  # type: ignore
+
+from pgai.cli import vectorizer_worker
+from pgai.sqlalchemy import EmbeddingModel, VectorizerField
+
+
+class Base(DeclarativeBase):
+    pass
+
+
+class Author(Base):
+    __tablename__ = "authors"
+    first_name = Column(Text, primary_key=True)
+    last_name = Column(Text, primary_key=True)
+    bio = Column(Text, nullable=False)
+    bio_embeddings = VectorizerField(
+        dimensions=768,
+        add_relationship=True,
+    )
+
+    bio_embeddings_relation: Mapped[list[EmbeddingModel["Author"]]]
+
+
+def run_vectorizer_worker(db_url: str, vectorizer_id: int) -> None:
+    CliRunner().invoke(
+        vectorizer_worker,
+        [
+            "--db-url",
+            db_url,
+            "--once",
+            "--vectorizer-id",
+            str(vectorizer_id),
+            "--concurrency",
+            "1",
+        ],
+        catch_exceptions=False,
+    )
+
+
+def test_vectorizer_composite_key(
+    postgres_container: PostgresContainer, initialized_engine: Engine
+):
+    """Test vectorizer with a composite primary key."""
+    db_url = postgres_container.get_connection_url()
+
+    # Create tables
+    metadata = Author.metadata
+    metadata.create_all(initialized_engine, tables=[metadata.sorted_tables[0]])
+
+    # Create vectorizer
+    with initialized_engine.connect() as conn:
+        conn.execute(
+            text("""
+            SELECT ai.create_vectorizer(
+                'authors'::regclass,
+                target_table => 'authors_bio_embeddings_store',
+                embedding => ai.embedding_openai('text-embedding-3-small', 768),
+                chunking =>
+                    ai.chunking_recursive_character_text_splitter('bio', 50, 10)
+            );
+        """)
+        )
+        conn.commit()
+
+    # Insert test data
+    with Session(initialized_engine) as session:
+        author = Author(
+            first_name="Jane",
+            last_name="Doe",
+            bio="Jane is an accomplished researcher in artificial intelligence and machine learning. She has published numerous papers on neural networks.",  # noqa
+        )
+        session.add(author)
+        session.commit()
+
+    # Run vectorizer worker
+    run_vectorizer_worker(db_url, 1)
+
+    # Verify embeddings were created
+    with Session(initialized_engine) as session:
+        # Verify embedding class was created correctly
+        assert Author.bio_embeddings.__name__ == "BioEmbeddingsEmbedding"
+
+        # Check embeddings exist and have correct properties
+        embedding = session.query(Author.bio_embeddings).first()
+        assert embedding is not None
+        assert isinstance(embedding.embedding, np.ndarray)
+        assert len(embedding.embedding) == 768
+        assert embedding.chunk is not None
+        assert isinstance(embedding.chunk, str)
+
+        # Check composite key fields were created
+        assert hasattr(embedding, "first_name")
+        assert hasattr(embedding, "last_name")
+        assert embedding.first_name == "Jane"  # type: ignore
+        assert embedding.last_name == "Doe"  # type: ignore
+
+        # Verify relationship works
+        author = session.query(Author).first()
+        assert author is not None
+        assert hasattr(author, "bio_embeddings_relation")
+        assert author.bio_embeddings_relation is not None
+        assert len(author.bio_embeddings_relation) > 0
+        assert author.bio_embeddings_relation[0].chunk in author.bio
+
+        # Test that parent relationship works
+        embedding_entity = session.query(Author.bio_embeddings).first()
+        assert embedding_entity is not None
+        assert embedding_entity.chunk in author.bio
+        assert embedding_entity.parent is not None
+        assert embedding_entity.parent.first_name == "Jane"
+        assert embedding_entity.parent.last_name == "Doe"
+
+        # Test semantic search with composite keys
+        from sqlalchemy import func
+
+        # Search for content similar to "machine learning"
+        similar_embeddings = (
+            session.query(Author.bio_embeddings)
+            .order_by(
+                Author.bio_embeddings.embedding.cosine_distance(
+                    func.ai.openai_embed(
+                        "text-embedding-3-small",
+                        "machine learning",
+                        text("dimensions => 768"),
+                    )
+                )
+            )
+            .all()
+        )
+
+        assert len(similar_embeddings) > 0
+        # The bio should contain machine learning related content
+        assert "machine learning" in similar_embeddings[0].parent.bio
diff --git a/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy_large_embeddings.py b/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy_large_embeddings.py
index 9bf9eacd..c4787d31 100644
--- a/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy_large_embeddings.py
+++ b/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy_large_embeddings.py
@@ -18,9 +18,7 @@ class BlogPost(Base):
     id = Column(Integer, primary_key=True)
     title = Column(Text, nullable=False)
     content = Column(Text, nullable=False)
-    content_embeddings = VectorizerField(
-        dimensions=1536
-    )
+    content_embeddings = VectorizerField(dimensions=1536)
 
 
 def run_vectorizer_worker(db_url: str, vectorizer_id: int) -> None:
diff --git a/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy_relationship.py b/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy_relationship.py
index 62fe0a42..b9328d41 100644
--- a/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy_relationship.py
+++ b/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy_relationship.py
@@ -80,6 +80,10 @@ def test_vectorizer_embedding_creation(
     # Verify embeddings were created
     with Session(initialized_engine) as session:
         # Verify embedding class was created correctly
+
+        blog_post = session.query(BlogPost).first()
+        assert blog_post is not None
+        assert blog_post.content_embeddings_relation is not None
         assert BlogPost.content_embeddings.__name__ == "ContentEmbeddingsEmbedding"
 
         # Check embeddings exist and have correct properties
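
Usage note (not part of the patch): the sketch below illustrates what this change enables, assuming the VectorizerField / EmbeddingModel API added in pgai/sqlalchemy/__init__.py above. The Document model, its column names, and the dimension value are hypothetical, chosen only for illustration.

# Hypothetical parent model with a non-integer primary key; before this patch the
# generated embedding class assumed an Integer "id" column on the parent.
from sqlalchemy import Column, Text
from sqlalchemy.orm import DeclarativeBase

from pgai.sqlalchemy import VectorizerField


class Base(DeclarativeBase):
    pass


class Document(Base):
    __tablename__ = "documents"
    slug = Column(Text, primary_key=True)  # arbitrary (non-"id") primary key
    body = Column(Text, nullable=False)

    body_embeddings = VectorizerField(dimensions=768, add_relationship=True)


# On first access, Document.body_embeddings resolves to a mapped class named
# "BodyEmbeddingsEmbedding" whose columns are embedding_uuid, chunk, embedding,
# chunk_seq plus a mirrored "slug" column, tied back to documents.slug through
# an ON DELETE CASCADE ForeignKeyConstraint. With add_relationship=True,
# Document also exposes a body_embeddings_relation property and each embedding
# row exposes a "parent" backref.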