From 9a7835b330bd9444b57be18dc693f8156de56293 Mon Sep 17 00:00:00 2001 From: Iain Date: Mon, 16 Dec 2024 12:50:11 +0100 Subject: [PATCH 01/27] feat: Add a readme for the pgai docs. --- docs/README.md | 65 +++++++++++++++++++ docs/moderate.md | 2 +- docs/privileges.md | 1 + ...vectorizer-add-a-embedding-integration.md} | 2 +- 4 files changed, 68 insertions(+), 2 deletions(-) create mode 100644 docs/README.md rename docs/{adding-embedding-integration.md => vectorizer-add-a-embedding-integration.md} (98%) diff --git a/docs/README.md b/docs/README.md new file mode 100644 index 00000000..7ad2bfb8 --- /dev/null +++ b/docs/README.md @@ -0,0 +1,65 @@ + +

+<!-- pgai logo -->
+
+<div align=center>
+
+<h3>pgai documentation</h3>
+
+[![Discord](https://img.shields.io/badge/Join_us_on_Discord-black?style=for-the-badge&logo=discord&logoColor=white)](https://discord.gg/KRdHVXAmkp)
+[![Try Timescale for free](https://img.shields.io/badge/Try_Timescale_for_free-black?style=for-the-badge&logo=timescale&logoColor=white)](https://tsdb.co/gh-pgai-signup)
+
+</div>
+
+pgai is a PostgreSQL extension that simplifies data storage and retrieval for [Retrieval Augmented Generation](https://en.wikipedia.org/wiki/Prompt_engineering#Retrieval-augmented_generation) (RAG) and other AI applications.
+In particular, it automates the creation and sync of embeddings for your data stored in PostgreSQL, simplifies
+[semantic search](https://en.wikipedia.org/wiki/Semantic_search), and allows you to call LLMs from SQL.
+
+The pgai documentation helps you set up, use, and develop the projects that make up pgai.
+
+
+## Vectorizer
+
+Vectorizer automates the embedding process within your database by treating embeddings as a declarative,
+DDL-like feature — like an index.
+
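+For example, a vectorizer for a hypothetical `blog` table with a `content` column might look like the
+following sketch, which assumes the OpenAI embedding integration is configured (table, column, and
+destination names are illustrative):
+
+```sql
+SELECT ai.create_vectorizer(
+    'public.blog'::regclass,           -- source table to embed
+    destination => 'blog_embeddings',  -- where the embeddings are stored
+    embedding   => ai.embedding_openai('text-embedding-3-small', 768),
+    chunking    => ai.chunking_recursive_character_text_splitter('content')
+);
+```
+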
+- **Get started**:
+  * [Vectorizer quickstart](/docs/vectorizer-quick-start.md): set up your developer environment, then create and run a vectorizer.
+  * [Vectorizer quickstart for OpenAI](/docs/vectorizer-quick-start-openai.md): set up your developer environment, then create and run a vectorizer using OpenAI.
+  * [Vectorizer quickstart for Voyage](/docs/vectorizer-quick-start-voyage.md): set up your developer environment, then create and run a vectorizer using Voyage.
+- **Use**:
+  * [Automate AI embedding with pgai Vectorizer](/docs/vectorizer.md): a comprehensive overview of Vectorizer features,
+    demonstrating how it streamlines the process of working with vector embeddings in your database.
+  * [Run vectorizers using pgai vectorizer worker](/docs/vectorizer-worker.md): run vectorizers on a self-hosted TimescaleDB instance.
+- **Develop**:
+  * [Add a Vectorizer embedding integration](/docs/vectorizer-add-a-embedding-integration.md): contribute a new embedding integration to Vectorizer.
+- **Reference**:
+  * [pgai Vectorizer API reference](/docs/vectorizer-api-reference.md): API reference for the Vectorizer functions.
+
+## pgai
+
+Simplifies data storage and retrieval for AI apps.
+
+- **Get started**:
+  * [Install pgai with Docker](/docs/install_docker.md): run pgai in a container environment.
+  * [Set up pgai with Anthropic](/docs/anthropic.md): configure pgai to connect to your Anthropic account.
+  * [Set up pgai with Cohere](/docs/cohere.md): configure pgai to connect to your Cohere account.
+  * [Set up pgai with Ollama](/docs/ollama.md): configure pgai to connect to your Ollama instance.
+  * [Set up pgai with OpenAI](/docs/openai.md): configure pgai to connect to your OpenAI account.
+  * [Set up pgai with Voyage AI](/docs/voyageai.md): configure pgai to connect to your Voyage AI account.
+- **Use**:
+  * [Delayed embed](/docs/delayed_embed.md): generate embeddings asynchronously using pgai or TimescaleDB background actions.
+  * [Load dataset from Hugging Face](/docs/load_dataset_from_huggingface.md): load datasets from Hugging Face's datasets library directly into your PostgreSQL database.
+  * [Moderate comments using OpenAI](/docs/moderate.md): use triggers or actions to moderate comments using OpenAI.
+  * [Secure pgai with user privileges](/docs/privileges.md): grant the necessary permissions for a specific user or role to use pgai functionality.
+- **Develop**:
+  * [Install pgai from source](/docs/install_from_source.md): create an environment to develop pgai.
+
+
diff --git a/docs/moderate.md b/docs/moderate.md
index c4882e5f..4623ad3b 100644
--- a/docs/moderate.md
+++ b/docs/moderate.md
@@ -1,4 +1,4 @@
-# Moderate
+# Moderate comments using OpenAI
 
 Let's say you want to moderate comments using OpenAI.
 You can do it in two ways:
diff --git a/docs/privileges.md b/docs/privileges.md
index 0d6175dd..d9c95268 100644
--- a/docs/privileges.md
+++ b/docs/privileges.md
@@ -1,3 +1,4 @@
+# Secure pgai with user privileges
 The ai.grant_ai_usage function is an important
 security and access control tool in the pgai extension.
 Its primary purpose is to grant the necessary permissions
diff --git a/docs/adding-embedding-integration.md b/docs/vectorizer-add-a-embedding-integration.md
similarity index 98%
rename from docs/adding-embedding-integration.md
rename to docs/vectorizer-add-a-embedding-integration.md
index 5beac515..ae0d1d6e 100644
--- a/docs/adding-embedding-integration.md
+++ b/docs/vectorizer-add-a-embedding-integration.md
@@ -1,4 +1,4 @@
-# Adding a Vectorizer embedding integration
+# Add a Vectorizer embedding integration
 
 We welcome contributions to add new vectorizer embedding integrations.

From 8b033e848e342b753058ae1dfe199fe2e910e44e Mon Sep 17 00:00:00 2001
From: Iain Cox
Date: Mon, 16 Dec 2024 16:38:51 +0100
Subject: [PATCH 02/27] Apply suggestions from code review

Co-authored-by: Matvey Arye
Signed-off-by: Iain Cox
---
 docs/README.md | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/docs/README.md b/docs/README.md
index 7ad2bfb8..337fdd9b 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -18,13 +18,13 @@ In particular, it automates the creation and sync of embeddings for your data st
 The pgai documentation helps you set up, use, and develop the projects that make up pgai.
 
 
-## Vectorizer
+## pgai Vectorizer
 
 Vectorizer automates the embedding process within your database by treating embeddings as a declarative,
 DDL-like feature — like an index.
 
 - **Get started**:
-  * [Vectorizer quickstart](/docs/vectorizer-quick-start.md): set up your developer environment, then create and run a vectorizer.
+  * [Vectorizer quickstart for Ollama](/docs/vectorizer-quick-start.md): set up your developer environment, then create and run a vectorizer.
   * [Vectorizer quickstart for OpenAI](/docs/vectorizer-quick-start-openai.md): set up your developer environment, then create and run a vectorizer using OpenAI.
   * [Vectorizer quickstart for Voyage](/docs/vectorizer-quick-start-voyage.md): set up your developer environment, then create and run a vectorizer using Voyage.
 - **Use**:
   * [Automate AI embedding with pgai Vectorizer](/docs/vectorizer.md): a comprehensive overview of Vectorizer features,
     demonstrating how it streamlines the process of working with vector embeddings in your database.
   * [Run vectorizers using pgai vectorizer worker](/docs/vectorizer-worker.md): run vectorizers on a self-hosted TimescaleDB instance.
 - **Develop**:
   * [Add a Vectorizer embedding integration](/docs/vectorizer-add-a-embedding-integration.md): contribute a new embedding integration to Vectorizer.
 - **Reference**:
   * [pgai Vectorizer API reference](/docs/vectorizer-api-reference.md): API reference for the Vectorizer functions.
 
-## pgai
+## pgai model calling
 
 Simplifies data storage and retrieval for AI apps.

From c779dbcaf8baf210033918a86532df4e4c2db258 Mon Sep 17 00:00:00 2001
From: Iain
Date: Mon, 16 Dec 2024 16:57:31 +0100
Subject: [PATCH 03/27] feat: update on review.

---
 docs/README.md | 26 +++++++++++++++++---------
 1 file changed, 17 insertions(+), 9 deletions(-)

diff --git a/docs/README.md b/docs/README.md
index 337fdd9b..73252105 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -36,24 +36,32 @@ DDL-like feature — like an index.
 - **Reference**:
   * [pgai Vectorizer API reference](/docs/vectorizer-api-reference.md): API reference for the Vectorizer functions.
 
+## pgai install
+
+* [Install pgai with Docker](/docs/install_docker.md): run pgai in a container environment.
+* [Install pgai from source](/docs/install_from_source.md): create a developer environment for pgai.
+
 ## pgai model calling
 
 Simplifies data storage and retrieval for AI apps.
 
-- **Get started**:
-  * [Install pgai with Docker](/docs/install_docker.md): run pgai in a container environment.
-  * [Set up pgai with Anthropic](/docs/anthropic.md): configure pgai to connect to your Anthropic account.
-  * [Set up pgai with Cohere](/docs/cohere.md): configure pgai to connect to your Cohere account.
-  * [Set up pgai with Ollama](/docs/ollama.md): configure pgai to connect to your Ollama instance.
-  * [Set up pgai with OpenAI](/docs/openai.md): configure pgai to connect to your OpenAI account.
-  * [Set up pgai with Voyage AI](/docs/voyageai.md): configure pgai to connect to your Voyage AI account.
+- **Choose your model**:
+
+  | **Model** | **Tokenize** | **Embed** | **Chat Complete** | **Generate** | **Moderate** | **Classify** | **Rerank** |
+  |-----------|:------------:|:---------:|:-----------------:|:------------:|:------------:|:------------:|:----------:|
+  | **[Ollama](/docs/ollama.md)** | | ✔️ | ✔️ | ✔️ | | | |
+  | **[OpenAI](/docs/openai.md)** | ✔️ | ✔️ | ✔️ | | ✔️ | | |
+  | **[Anthropic](/docs/anthropic.md)** | | | | ✔️ | | | |
+  | **[Cohere](/docs/cohere.md)** | ✔️ | ✔️ | ✔️ | | | ✔️ | ✔️ |
+  | **[Voyage AI](/docs/voyageai.md)** | | ✔️ | | | | | |
+
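+  For example, with an OpenAI API key configured you can generate an embedding straight from SQL
+  (a sketch; the model name is illustrative):
+
+  ```sql
+  SELECT ai.openai_embed('text-embedding-3-small', 'the quick brown fox');
+  ```
+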
 - **Use**:
   * [Delayed embed](/docs/delayed_embed.md): generate embeddings asynchronously using pgai or TimescaleDB background actions.
   * [Load dataset from Hugging Face](/docs/load_dataset_from_huggingface.md): load datasets from Hugging Face's datasets library directly into your PostgreSQL database.
   * [Moderate comments using OpenAI](/docs/moderate.md): use triggers or actions to moderate comments using OpenAI.
   * [Secure pgai with user privileges](/docs/privileges.md): grant the necessary permissions for a specific user or role to use pgai functionality.
-- **Develop**:
-  * [Install pgai from source](/docs/install_from_source.md): create an environment to develop pgai.
+

From c374cdc5c81280f03359b4216539d6afcd47c5d1 Mon Sep 17 00:00:00 2001
From: James Guthrie
Date: Mon, 16 Dec 2024 14:30:31 +0100
Subject: [PATCH 04/27] chore: fix broken pgai build by pinning hatchling (#308)

The build started failing with the following error message:

```
Checking ./dist/pgai-0.3.0-py3-none-any.whl: ERROR InvalidDistribution: Metadata is missing required fields: Name, Version.
Make sure the distribution includes the files where those fields are specified, and is using a supported Metadata-Version: 1.0, 1.1, 1.2, 2.0, 2.1, 2.2, 2.3.
```

This error comes from `twine check dist/*`. It is caused by the fact that
`hatchling`, which we use to build wheels, released a new version which bumped
the `Metadata-Version` field to `2.4`, and twine/pkginfo aren't able to
process this version.

By pinning `hatchling` to 1.26.3, the previous release, we solve this
immediate problem, and avoid new breakage from future hatchling releases.
---
 projects/pgai/pyproject.toml | 2 +-
 projects/pgai/uv.lock        | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/projects/pgai/pyproject.toml b/projects/pgai/pyproject.toml
index b7c6e8cc..5d6da0d3 100644
--- a/projects/pgai/pyproject.toml
+++ b/projects/pgai/pyproject.toml
@@ -1,5 +1,5 @@
 [build-system]
-requires = ["hatchling"]
+requires = ["hatchling==1.26.3"]
 build-backend = "hatchling.build"
 
 [project]
diff --git a/projects/pgai/uv.lock b/projects/pgai/uv.lock
index 398af021..386ef1a5 100644
--- a/projects/pgai/uv.lock
+++ b/projects/pgai/uv.lock
@@ -1161,7 +1161,7 @@ wheels = [
 
 [[package]]
 name = "pgai"
-version = "0.1.0"
+version = "0.3.0"
 source = { editable = "."
} dependencies = [ { name = "click" }, From 3ced1f6b6bfa1f728f7579223a10a4de2db57850 Mon Sep 17 00:00:00 2001 From: James Guthrie Date: Mon, 16 Dec 2024 15:54:14 +0100 Subject: [PATCH 05/27] chore: support uv in extension install, use for dev (#309) We would like to support uv in an opt-in fashion. The `build.py install` now uses uv if it is available, otherwise falling back to pip. To use uv, we now install it in our dev containers. --- projects/extension/Dockerfile | 9 +++++---- projects/extension/build.py | 22 +++++++++++++++++----- projects/extension/requirements-test.txt | 1 + 3 files changed, 23 insertions(+), 9 deletions(-) diff --git a/projects/extension/Dockerfile b/projects/extension/Dockerfile index fefbfd95..470b297b 100644 --- a/projects/extension/Dockerfile +++ b/projects/extension/Dockerfile @@ -47,6 +47,7 @@ RUN set -e; \ FROM base AS pgai-test-db ENV PG_MAJOR=${PG_MAJOR} ENV PIP_BREAK_SYSTEM_PACKAGES=1 +RUN pip install uv==0.5.9 WORKDIR /pgai COPY . . RUN just build install @@ -58,17 +59,17 @@ FROM base ENV WHERE_AM_I=docker USER root +RUN pip install --break-system-packages uv==0.5.9 + # install pgspot -ENV PIP_BREAK_SYSTEM_PACKAGES=1 RUN set -eux; \ git clone https://github.com/timescale/pgspot.git /build/pgspot; \ - pip install /build/pgspot; \ + uv pip install --system --break-system-packages /build/pgspot; \ rm -rf /build/pgspot # install our test python dependencies -ENV PIP_BREAK_SYSTEM_PACKAGES=1 COPY requirements-test.txt /build/requirements-test.txt -RUN pip install -r /build/requirements-test.txt +RUN uv pip install --system --break-system-packages -r /build/requirements-test.txt RUN rm -r /build WORKDIR /pgai diff --git a/projects/extension/build.py b/projects/extension/build.py index fc0adeee..0a0dbd10 100755 --- a/projects/extension/build.py +++ b/projects/extension/build.py @@ -451,9 +451,13 @@ def install_old_py_deps() -> None: old_reqs_file = ext_dir().joinpath("old_requirements.txt").resolve() if old_reqs_file.is_file(): env = {k: v for k, v in os.environ.items()} - env["PIP_BREAK_SYSTEM_PACKAGES"] = "1" + cmd = ( + f"pip3 install -v --compile --break-system-packages -r {old_reqs_file}" + if shutil.which("uv") is None + else f"uv pip install -v --compile --system --break-system-packages -r {old_reqs_file}" + ) subprocess.run( - f"pip3 install -v --compile -r {old_reqs_file}", + cmd, shell=True, check=True, env=env, @@ -482,8 +486,10 @@ def install_prior_py() -> None: env=os.environ, ) tmp_src_dir = tmp_dir.joinpath("projects", "extension").resolve() + bin = "pip3" if shutil.which("uv") is None else "uv pip" + cmd = f'{bin} install -v --compile --target "{version_target_dir}" "{tmp_src_dir}"' subprocess.run( - f'pip3 install -v --compile -t "{version_target_dir}" "{tmp_src_dir}"', + cmd, check=True, shell=True, env=os.environ, @@ -524,8 +530,10 @@ def install_py() -> None: "pgai-*.dist-info" ): # delete package info if exists shutil.rmtree(d) + bin = "pip3" if shutil.which("uv") is None else "uv pip" + cmd = f'{bin} install -v --no-deps --compile --target "{version_target_dir}" "{ext_dir()}"' subprocess.run( - f'pip3 install -v --no-deps --compile -t "{version_target_dir}" "{ext_dir()}"', + cmd, check=True, shell=True, env=os.environ, @@ -533,8 +541,12 @@ def install_py() -> None: ) else: version_target_dir.mkdir(exist_ok=True) + bin = "pip3" if shutil.which("uv") is None else "uv pip" + cmd = ( + f'{bin} install -v --compile --target "{version_target_dir}" "{ext_dir()}"' + ) subprocess.run( - f'pip3 install -v --compile -t "{version_target_dir}" 
"{ext_dir()}"', + cmd, check=True, shell=True, env=os.environ, diff --git a/projects/extension/requirements-test.txt b/projects/extension/requirements-test.txt index de44c0f1..d6ec9c94 100644 --- a/projects/extension/requirements-test.txt +++ b/projects/extension/requirements-test.txt @@ -7,3 +7,4 @@ python-dotenv==1.0.1 fastapi==0.112.0 fastapi-cli==0.0.5 psycopg[binary]==3.2.1 +uv==0.5.9 \ No newline at end of file From f6af7f0a9b306e8efbdd613e7e779999fb030b69 Mon Sep 17 00:00:00 2001 From: John Pruitt Date: Mon, 2 Dec 2024 08:32:21 -0600 Subject: [PATCH 06/27] feat: add a semantic catalog for db objs and example sql --- .../idempotent/900-semantic-catalog-init.sql | 61 ++++ .../idempotent/901-semantic-catalog-set.sql | 137 ++++++++ .../902-semantic-catalog-event-triggers.sql | 256 ++++++++++++++ .../sql/idempotent/903-post-restore.sql | 117 +++++++ .../sql/idempotent/999-privileges.sql | 10 +- .../sql/incremental/900-semantic-catalog.sql | 31 ++ .../extension/tests/text_to_sql/.gitignore | 4 + .../extension/tests/text_to_sql/0.expected | 13 + .../extension/tests/text_to_sql/1.expected | 13 + .../extension/tests/text_to_sql/10.expected | 9 + .../extension/tests/text_to_sql/11.expected | 8 + .../extension/tests/text_to_sql/12.expected | 5 + .../extension/tests/text_to_sql/13.expected | 4 + .../extension/tests/text_to_sql/2.expected | 13 + .../extension/tests/text_to_sql/3.expected | 13 + .../extension/tests/text_to_sql/4.expected | 13 + .../extension/tests/text_to_sql/5.expected | 9 + .../extension/tests/text_to_sql/6.expected | 8 + .../extension/tests/text_to_sql/7.expected | 8 + .../extension/tests/text_to_sql/8.expected | 8 + .../extension/tests/text_to_sql/9.expected | 9 + .../extension/tests/text_to_sql/__init__.py | 0 .../extension/tests/text_to_sql/extra.sql | 67 ++++ projects/extension/tests/text_to_sql/init.sql | 77 +++++ .../text_to_sql/snapshot-catalog.expected | 291 ++++++++++++++++ .../tests/text_to_sql/snapshot-catalog.sql | 20 ++ .../extension/tests/text_to_sql/snapshot.sql | 73 ++++ .../tests/text_to_sql/test_dump_restore.py | 182 ++++++++++ .../tests/text_to_sql/test_text_to_sql.py | 326 ++++++++++++++++++ 29 files changed, 1784 insertions(+), 1 deletion(-) create mode 100644 projects/extension/sql/idempotent/900-semantic-catalog-init.sql create mode 100644 projects/extension/sql/idempotent/901-semantic-catalog-set.sql create mode 100644 projects/extension/sql/idempotent/902-semantic-catalog-event-triggers.sql create mode 100644 projects/extension/sql/idempotent/903-post-restore.sql create mode 100644 projects/extension/sql/incremental/900-semantic-catalog.sql create mode 100644 projects/extension/tests/text_to_sql/.gitignore create mode 100644 projects/extension/tests/text_to_sql/0.expected create mode 100644 projects/extension/tests/text_to_sql/1.expected create mode 100644 projects/extension/tests/text_to_sql/10.expected create mode 100644 projects/extension/tests/text_to_sql/11.expected create mode 100644 projects/extension/tests/text_to_sql/12.expected create mode 100644 projects/extension/tests/text_to_sql/13.expected create mode 100644 projects/extension/tests/text_to_sql/2.expected create mode 100644 projects/extension/tests/text_to_sql/3.expected create mode 100644 projects/extension/tests/text_to_sql/4.expected create mode 100644 projects/extension/tests/text_to_sql/5.expected create mode 100644 projects/extension/tests/text_to_sql/6.expected create mode 100644 projects/extension/tests/text_to_sql/7.expected create mode 100644 
projects/extension/tests/text_to_sql/8.expected create mode 100644 projects/extension/tests/text_to_sql/9.expected create mode 100644 projects/extension/tests/text_to_sql/__init__.py create mode 100644 projects/extension/tests/text_to_sql/extra.sql create mode 100644 projects/extension/tests/text_to_sql/init.sql create mode 100644 projects/extension/tests/text_to_sql/snapshot-catalog.expected create mode 100644 projects/extension/tests/text_to_sql/snapshot-catalog.sql create mode 100644 projects/extension/tests/text_to_sql/snapshot.sql create mode 100644 projects/extension/tests/text_to_sql/test_dump_restore.py create mode 100644 projects/extension/tests/text_to_sql/test_text_to_sql.py diff --git a/projects/extension/sql/idempotent/900-semantic-catalog-init.sql b/projects/extension/sql/idempotent/900-semantic-catalog-init.sql new file mode 100644 index 00000000..10a6936d --- /dev/null +++ b/projects/extension/sql/idempotent/900-semantic-catalog-init.sql @@ -0,0 +1,61 @@ +--FEATURE-FLAG: text_to_sql + +------------------------------------------------------------------------------- +-- initialize_semantic_catalog +create or replace function ai.initialize_semantic_catalog +( "name" pg_catalog.name default 'default' +, embedding pg_catalog.jsonb default null +, indexing pg_catalog.jsonb default ai.indexing_default() +, scheduling pg_catalog.jsonb default ai.scheduling_default() +, processing pg_catalog.jsonb default ai.processing_default() +, grant_to pg_catalog.name[] default ai.grant_to() +) returns pg_catalog.int4 +as $func$ +declare + _catalog_id pg_catalog.int4; + _obj_vec_id pg_catalog.int4; + _sql_vec_id pg_catalog.int4; +begin + insert into ai.semantic_catalog("name") + values (initialize_semantic_catalog."name") + returning id + into strict _catalog_id + ; + + select ai.create_vectorizer + ( 'ai.semantic_catalog_obj'::pg_catalog.regclass + , destination=>pg_catalog.format('semantic_catalog_obj_%s', _catalog_id) + , embedding=>embedding + , indexing=>indexing + , scheduling=>scheduling + , processing=>processing + , grant_to=>grant_to + , formatting=>ai.formatting_python_template() -- TODO: this ain't gonna work + , chunking=>ai.chunking_recursive_character_text_splitter('description') -- TODO + ) into strict _obj_vec_id + ; + + select ai.create_vectorizer + ( 'ai.semantic_catalog_sql'::pg_catalog.regclass + , destination=>pg_catalog.format('semantic_catalog_sql_%s', _catalog_id) + , embedding=>embedding + , indexing=>indexing + , scheduling=>scheduling + , processing=>processing + , grant_to=>grant_to + , formatting=>ai.formatting_python_template() -- TODO: this ain't gonna work + , chunking=>ai.chunking_recursive_character_text_splitter('description') -- TODO + ) into strict _sql_vec_id + ; + + update ai.semantic_catalog set + obj_vectorizer_id = _obj_vec_id + , sql_vectorizer_id = _sql_vec_id + where id operator(pg_catalog.=) _catalog_id + ; + + return _catalog_id; +end; +$func$ language plpgsql volatile security definer -- definer on purpose! 
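+-- note: search_path is pinned below so this security definer function cannot be hijacked via a malicious search_path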
+set search_path to pg_catalog, pg_temp +; diff --git a/projects/extension/sql/idempotent/901-semantic-catalog-set.sql b/projects/extension/sql/idempotent/901-semantic-catalog-set.sql new file mode 100644 index 00000000..e20cbad6 --- /dev/null +++ b/projects/extension/sql/idempotent/901-semantic-catalog-set.sql @@ -0,0 +1,137 @@ +--FEATURE-FLAG: text_to_sql + +------------------------------------------------------------------------------- +-- add_sql_example +create or replace function ai.add_sql_example +( sql pg_catalog.text +, description pg_catalog.text +) returns int +as $func$ + insert into ai.semantic_catalog_sql (sql, description) + values (trim(sql), trim(description)) + returning id +$func$ language sql volatile security invoker +set search_path to pg_catalog, pg_temp; + +------------------------------------------------------------------------------- +-- _set_description +create or replace function ai._set_description +( classid pg_catalog.oid +, objid pg_catalog.oid +, objsubid pg_catalog.int4 +, description pg_catalog.text +) returns void +as $func$ + insert into ai.semantic_catalog_obj + ( objtype + , objnames + , objargs + , classid + , objid + , objsubid + , description + ) + select + x."type" + , x.object_names + , x.object_args + , classid + , objid + , objsubid + , _set_description.description + from pg_catalog.pg_identify_object_as_address + ( classid + , objid + , objsubid + ) x + on conflict (objtype, objnames, objargs) + do update set description = _set_description.description + ; +$func$ language sql volatile security invoker +set search_path to pg_catalog, pg_temp; + +------------------------------------------------------------------------------- +-- set_description +create or replace function ai.set_description +( relation pg_catalog.regclass +, description pg_catalog.text +) returns void +as $func$ +declare + _classid pg_catalog.oid; + _objid pg_catalog.oid; + _objsubid pg_catalog.int4; + _relkind pg_catalog."char"; +begin + _classid = 'pg_catalog.pg_class'::pg_catalog.regclass::pg_catalog.oid; + _objid = relation::pg_catalog.oid; + _objsubid = 0; + + select k.relkind into strict _relkind + from pg_catalog.pg_class k + where k.oid operator(pg_catalog.=) _objid + ; + if _relkind not in ('r', 'f', 'p', 'v', 'm') then + raise exception 'relkind % not supported', _relkind; + end if; + + perform ai._set_description(_classid, _objid, _objsubid, description); +end +$func$ language plpgsql volatile security invoker +set search_path to pg_catalog, pg_temp; + +------------------------------------------------------------------------------- +-- set_column_description +create or replace function ai.set_column_description +( relation pg_catalog.regclass +, column_name pg_catalog.name +, description pg_catalog.text +) returns void +as $func$ +declare + _classid pg_catalog.oid; + _objid pg_catalog.oid; + _objsubid pg_catalog.int4; +begin + _classid = 'pg_catalog.pg_class'::pg_catalog.regclass::pg_catalog.oid; + _objid = relation::pg_catalog.oid; + _objsubid = 0; + + select a.attnum into _objsubid + from pg_catalog.pg_class k + inner join pg_catalog.pg_attribute a on (k.oid operator(pg_catalog.=) a.attrelid) + where k.oid operator(pg_catalog.=) _objid + and k.relkind in ('r', 'f', 'p', 'v', 'm') + and a.attnum operator(pg_catalog.>) 0 + and a.attname operator(pg_catalog.=) column_name + and not a.attisdropped + ; + if not found then + raise exception '% column not found', column_name; + end if; + + perform ai._set_description(_classid, _objid, _objsubid, description); +end; 
+$func$ language plpgsql volatile security invoker +set search_path to pg_catalog, pg_temp; + +------------------------------------------------------------------------------- +-- set_function_description +create or replace function ai.set_function_description +( fn regprocedure +, description text +) returns void +as $func$ +declare + _classid pg_catalog.oid; + _objid pg_catalog.oid; + _objsubid pg_catalog.int4; +begin + _classid = 'pg_catalog.pg_proc'::pg_catalog.regclass::pg_catalog.oid; + _objid = fn::pg_catalog.oid; + _objsubid = 0; + + perform ai._set_description(_classid, _objid, _objsubid, description); +end; +$func$ language plpgsql volatile security invoker +set search_path to pg_catalog, pg_temp; diff --git a/projects/extension/sql/idempotent/902-semantic-catalog-event-triggers.sql b/projects/extension/sql/idempotent/902-semantic-catalog-event-triggers.sql new file mode 100644 index 00000000..fcfb883e --- /dev/null +++ b/projects/extension/sql/idempotent/902-semantic-catalog-event-triggers.sql @@ -0,0 +1,256 @@ +--FEATURE-FLAG: text_to_sql + +------------------------------------------------------------------------------- +-- _semantic_catalog_obj_handle_drop +create or replace function ai._semantic_catalog_obj_handle_drop() +returns event_trigger as +$func$ +declare + _rec record; +begin + -- this function is security definer + -- fully-qualify everything and be careful of security holes + for _rec in + ( + select + d.classid + , d.objid + , d.objsubid + --, d.original + --, d.normal + --, d.is_temporary + , d.object_type + --, d.schema_name + --, d.object_name + --, d.object_identity + , d.address_names + , d.address_args + from pg_catalog.pg_event_trigger_dropped_objects() d + ) + loop + delete from ai.semantic_catalog_obj + where objtype operator(pg_catalog.=) _rec.object_type + and objnames operator(pg_catalog.=) _rec.address_names + and objargs operator(pg_catalog.=) _rec.address_args + ; + if _rec.object_type in ('table', 'view') then + -- delete the columns too + delete from ai.semantic_catalog_obj + where classid operator(pg_catalog.=) _rec.classid + and objid operator(pg_catalog.=) _rec.objid + ; + end if; + end loop; +end; +$func$ +language plpgsql volatile security definer -- definer on purpose! 
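+-- note: definer rights let the event trigger clean up ai.semantic_catalog_obj rows even when
+-- the role issuing the DROP has no direct privileges on that catalog table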
+set search_path to pg_catalog, pg_temp +; + +-- install the event trigger if not exists +do language plpgsql $block$ +begin + -- if the event trigger already exists, noop + perform + from pg_catalog.pg_event_trigger g + where g.evtname operator(pg_catalog.=) '_semantic_catalog_obj_handle_drop' + and g.evtfoid operator(pg_catalog.=) pg_catalog.to_regproc('ai._semantic_catalog_obj_handle_drop') + ; + if found then + return; + end if; + + create event trigger _semantic_catalog_obj_handle_drop + on sql_drop + execute function ai._semantic_catalog_obj_handle_drop(); +end +$block$; + +------------------------------------------------------------------------------- +-- _semantic_catalog_obj_handle_ddl +create or replace function ai._semantic_catalog_obj_handle_ddl() +returns event_trigger as +$func$ +declare + _rec record; + _objtype pg_catalog.text; + _objnames pg_catalog.text[]; + _objargs pg_catalog.text[]; +begin + -- this function is security definer + -- fully-qualify everything and be careful of security holes + for _rec in + ( + select + d.classid + , d.objid + , d.objsubid + , d.command_tag + --, d.object_type + --, d.schema_name + --, d.object_identity + --, d.in_extension + --, d.command + from pg_catalog.pg_event_trigger_ddl_commands() d + ) + loop + select + x."type" + , x.object_names + , x.object_args + into strict + _objtype + , _objnames + , _objargs + from pg_catalog.pg_identify_object_as_address + ( _rec.classid + , _rec.objid + , _rec.objsubid + ) x; + + -- alter schema rename to + if _objtype operator(pg_catalog.=) 'schema' then + -- tables/views/columns + with x as + ( + select + d.classid + , d.objid + , d.objsubid + , x."type" as objtype + , x.object_names as objnames + , x.object_args as objargs + from ai.semantic_catalog_obj d + inner join pg_catalog.pg_class k on (d.objid operator(pg_catalog.=) k.oid) + cross join lateral pg_catalog.pg_identify_object_as_address + ( d.classid + , d.objid + , d.objsubid + ) x + where k.relnamespace operator(pg_catalog.=) _rec.objid + ) + update ai.semantic_catalog_obj as d set + objtype = x.objtype + , objnames = x.objnames + , objargs = x.objargs + from x + where d.classid operator(pg_catalog.=) x.classid + and d.objid operator(pg_catalog.=) x.objid + and d.objsubid operator(pg_catalog.=) x.objsubid + and (d.objtype, d.objnames, d.objargs) operator(pg_catalog.!=) (x.objtype, x.objnames, x.objargs) -- only if changed + ; + + -- functions + with x as + ( + select + d.classid + , d.objid + , d.objsubid + , x."type" as objtype + , x.object_names as objnames + , x.object_args as objargs + from ai.semantic_catalog_obj d + inner join pg_catalog.pg_proc f on (d.objid operator(pg_catalog.=) f.oid) + cross join lateral pg_catalog.pg_identify_object_as_address + ( d.classid + , d.objid + , d.objsubid + ) x + where f.pronamespace operator(pg_catalog.=) _rec.objid + ) + update ai.semantic_catalog_obj as d set + objtype = x.objtype + , objnames = x.objnames + , objargs = x.objargs + from x + where d.classid operator(pg_catalog.=) x.classid + and d.objid operator(pg_catalog.=) x.objid + and d.objsubid operator(pg_catalog.=) x.objsubid + and (d.objtype, d.objnames, d.objargs) operator(pg_catalog.!=) (x.objtype, x.objnames, x.objargs) -- only if changed + ; + + return; -- done + end if; + + -- alter table rename to + -- alter view rename to + -- alter function rename to + -- alter table set schema + -- alter view set schema + -- alter function set schema + update ai.semantic_catalog_obj set + objtype = _objtype + , objnames = _objnames + , objargs = 
_objargs + where classid operator(pg_catalog.=) _rec.classid + and objid operator(pg_catalog.=) _rec.objid + and objsubid operator(pg_catalog.=) _rec.objsubid + and (objtype, objnames, objargs) operator(pg_catalog.!=) (_objtype, _objnames, _objargs) -- only if changed + ; + if found and _objtype in ('table', 'view') then + -- if table or view renamed or schema changed + -- we need to update the columns too + with attr as + ( + select + _rec.classid + , a.attrelid as objid + , a.attnum as objsubid + from pg_catalog.pg_attribute a + where a.attrelid operator(pg_catalog.=) _rec.objid + and a.attnum operator(pg_catalog.>) 0 + and not a.attisdropped + ) + , xref as + ( + select + attr.classid + , attr.objid + , attr.objsubid + , x."type" as objtype + , x.object_names as objnames + , x.object_args as objargs + from attr + cross join lateral pg_catalog.pg_identify_object_as_address + ( attr.classid + , attr.objid + , attr.objsubid + ) x + ) + update ai.semantic_catalog_obj d set + objtype = xref.objtype + , objnames = xref.objnames + , objargs = xref.objargs + from xref + where d.classid operator(pg_catalog.=) xref.classid + and d.objid operator(pg_catalog.=) xref.objid + and d.objsubid operator(pg_catalog.=) xref.objsubid + and (d.objtype, d.objnames, d.objargs) operator(pg_catalog.!=) (xref.objtype, xref.objnames, xref.objargs) -- only if changed + ; + end if; + end loop; +end +$func$ +language plpgsql volatile security definer -- definer on purpose! +set search_path to pg_catalog, pg_temp +; + +-- install the event trigger if not exists +do language plpgsql $block$ +begin + -- if the event trigger already exists, noop + perform + from pg_catalog.pg_event_trigger g + where g.evtname operator(pg_catalog.=) '_semantic_catalog_obj_handle_ddl' + and g.evtfoid operator(pg_catalog.=) pg_catalog.to_regproc('ai._semantic_catalog_obj_handle_ddl') + ; + if found then + return; + end if; + + create event trigger _semantic_catalog_obj_handle_ddl + on ddl_command_end + execute function ai._semantic_catalog_obj_handle_ddl(); +end +$block$; diff --git a/projects/extension/sql/idempotent/903-post-restore.sql b/projects/extension/sql/idempotent/903-post-restore.sql new file mode 100644 index 00000000..0f4dfff1 --- /dev/null +++ b/projects/extension/sql/idempotent/903-post-restore.sql @@ -0,0 +1,117 @@ +--FEATURE-FLAG: text_to_sql + +------------------------------------------------------------------------------- +-- post_restore +create or replace function ai.post_restore() returns void +as $func$ +declare + _sql text; +begin + -- disable vectorizer triggers on the ai.semantic_catalog_obj table + for _sql in + ( + select pg_catalog.format + ( $sql$alter table ai.semantic_catalog_obj disable trigger %I$sql$ + , g.tgname + ) + from pg_catalog.pg_trigger g + where g.tgrelid operator(pg_catalog.=) 'ai.semantic_catalog_obj'::pg_catalog.regclass::pg_catalog.oid + and g.tgname like '_vectorizer_src_trg_%' + ) + loop + execute _sql; + end loop; + + -- oids are likely invalid after a dump/restore + -- look up the new oids and true up + with x as + ( + select + d.objtype + , d.objnames + , d.objargs + , x.classid + , x.objid + , x.objsubid + from + ( + -- despite what the docs say, pg_get_object_address does NOT support everything that pg_identify_object_as_address does + -- view columns and materialized view columns will throw an error + -- https://github.com/postgres/postgres/blob/master/src/backend/catalog/objectaddress.c#L695 + select * + from ai.semantic_catalog_obj d + where d.objtype not in ('view column', 
'materialized view column') + ) d + cross join lateral pg_catalog.pg_get_object_address + ( d.objtype + , d.objnames + , d.objargs + ) x + ) + update ai.semantic_catalog_obj as d set + classid = x.classid + , objid = x.objid + , objsubid = x.objsubid + from x + where d.objtype operator(pg_catalog.=) x.objtype + and d.objnames operator(pg_catalog.=) x.objnames + and d.objargs operator(pg_catalog.=) x.objargs + and (d.classid, d.objid, d.objsubid) operator(pg_catalog.!=) (x.classid, x.objid, x.objsubid) -- noop if nothing to change + ; + + -- deal with view columns and materialized view columns + with x as + ( + select + pg_catalog.to_regclass(pg_catalog.array_to_string(pg_catalog.trim_array(d.objnames, 1), '.')) as attrelid + , d.objnames[3] as attname + , d.objtype + , d.objnames + , d.objargs + from ai.semantic_catalog_obj d + where d.objtype in ('view column', 'materialized view column') + and pg_catalog.array_length(d.objnames, 1) operator(pg_catalog.=) 3 + ) + , y as + ( + select + 'pg_catalog.pg_class'::pg_catalog.regclass::pg_catalog.oid as classid + , a.attrelid as objid + , a.attnum as objsubid + , x.objtype + , x.objnames + , x.objargs + from x + inner join pg_catalog.pg_attribute a + on (x.attrelid::pg_catalog.oid operator(pg_catalog.=) a.attrelid and x.attname operator(pg_catalog.=) a.attname) + where x.attrelid is not null + ) + update ai.semantic_catalog_obj as d set + classid = y.classid + , objid = y.objid + , objsubid = y.objsubid + from y + where d.objtype operator(pg_catalog.=) y.objtype + and d.objnames operator(pg_catalog.=) y.objnames + and d.objargs operator(pg_catalog.=) y.objargs + and (d.classid, d.objid, d.objsubid) operator(pg_catalog.!=) (y.classid, y.objid, y.objsubid) -- noop if nothing to change + ; + + -- re-enable vectorizer triggers on the ai.semantic_catalog_obj table + for _sql in + ( + select pg_catalog.format + ( $sql$alter table ai.semantic_catalog_obj enable trigger %I$sql$ + , g.tgname + ) + from pg_catalog.pg_trigger g + where g.tgrelid operator(pg_catalog.=) 'ai.semantic_catalog_obj'::pg_catalog.regclass::pg_catalog.oid + and g.tgname like '_vectorizer_src_trg_%' + ) + loop + execute _sql; + end loop; +end; +$func$ language plpgsql volatile security definer -- definer on purpose +set search_path to pg_catalog, pg_temp +; diff --git a/projects/extension/sql/idempotent/999-privileges.sql b/projects/extension/sql/idempotent/999-privileges.sql index 3f6ea7a1..9ac21fbe 100644 --- a/projects/extension/sql/idempotent/999-privileges.sql +++ b/projects/extension/sql/idempotent/999-privileges.sql @@ -26,6 +26,7 @@ begin when admin then 'all privileges' else case + when k.relname operator(pg_catalog.=) 'semantic_catalog' then 'select' when k.relkind in ('r', 'p') then 'select, insert, update, delete' when k.relkind in ('S') then 'usage, select, update' when k.relkind in ('v') then 'select' @@ -86,7 +87,14 @@ begin and e.extname operator(pg_catalog.=) 'ai' and k.prokind in ('f', 'p') and case - when k.proname in ('grant_ai_usage', 'grant_secret', 'revoke_secret') then admin -- only admins get these function + when k.proname in + ( 'grant_ai_usage' + , 'grant_secret' + , 'revoke_secret' + , 'post_restore' + , 'initialize_semantic_catalog' + ) + then admin -- only admins get these function else true end ) diff --git a/projects/extension/sql/incremental/900-semantic-catalog.sql b/projects/extension/sql/incremental/900-semantic-catalog.sql new file mode 100644 index 00000000..bb6275b5 --- /dev/null +++ b/projects/extension/sql/incremental/900-semantic-catalog.sql 
@@ -0,0 +1,31 @@ +--FEATURE-FLAG: text_to_sql + +create table ai.semantic_catalog_obj +( objtype pg_catalog.text not null -- required for dump/restore to function +, objnames pg_catalog.text[] not null -- required for dump/restore to function +, objargs pg_catalog.text[] not null -- required for dump/restore to function +, classid pg_catalog.oid not null -- required for event triggers to function +, objid pg_catalog.oid not null -- required for event triggers to function +, objsubid pg_catalog.int4 not null -- required for event triggers to function +, description pg_catalog.text not null -- the description +, primary key (objtype, objnames, objargs) +); +create index on ai.semantic_catalog_obj (classid, objid, objsubid); +perform pg_catalog.pg_extension_config_dump('ai.semantic_catalog_obj'::pg_catalog.regclass, ''); + +create table ai.semantic_catalog_sql +( id pg_catalog.int4 not null primary key generated by default as identity +, sql pg_catalog.text not null +, description pg_catalog.text not null +); +perform pg_catalog.pg_extension_config_dump('ai.semantic_catalog_sql'::pg_catalog.regclass, ''); +perform pg_catalog.pg_extension_config_dump('ai.semantic_catalog_sql_id_seq'::pg_catalog.regclass, ''); + +create table ai.semantic_catalog +( id pg_catalog.int4 not null primary key generated by default as identity +, "name" pg_catalog.text not null unique +, obj_vectorizer_id pg_catalog.int4 -- TODO: foreign key constraint to vectorizer table??? +, sql_vectorizer_id pg_catalog.int4 -- TODO: foreign key constraint to vectorizer table??? +); +perform pg_catalog.pg_extension_config_dump('ai.semantic_catalog'::pg_catalog.regclass, ''); +perform pg_catalog.pg_extension_config_dump('ai.semantic_catalog_id_seq'::pg_catalog.regclass, ''); diff --git a/projects/extension/tests/text_to_sql/.gitignore b/projects/extension/tests/text_to_sql/.gitignore new file mode 100644 index 00000000..0859a578 --- /dev/null +++ b/projects/extension/tests/text_to_sql/.gitignore @@ -0,0 +1,4 @@ +describe_objects.sql +describe_schemas.sql +dump.sql +*.snapshot \ No newline at end of file diff --git a/projects/extension/tests/text_to_sql/0.expected b/projects/extension/tests/text_to_sql/0.expected new file mode 100644 index 00000000..2e8b8de5 --- /dev/null +++ b/projects/extension/tests/text_to_sql/0.expected @@ -0,0 +1,13 @@ + objtype | objnames | objargs | description +--------------+--------------------+-----------+------------------------------------------- + function | {public,life} | {integer} | this is a comment about the life function + table | {public,bob} | {} | this is a comment about the bob table + table column | {public,bob,bar} | {} | this is a comment about the bar column + table column | {public,bob,foo} | {} | this is a comment about the foo column + table column | {public,bob,id} | {} | this is a comment about the id column + view | {public,bobby} | {} | this is a comment about the bob table + view column | {public,bobby,bar} | {} | this is a comment about the bar column + view column | {public,bobby,foo} | {} | this is a comment about the foo column + view column | {public,bobby,id} | {} | this is a comment about the id column +(9 rows) + diff --git a/projects/extension/tests/text_to_sql/1.expected b/projects/extension/tests/text_to_sql/1.expected new file mode 100644 index 00000000..375970e4 --- /dev/null +++ b/projects/extension/tests/text_to_sql/1.expected @@ -0,0 +1,13 @@ + objtype | objnames | objargs | description 
+--------------+--------------------+-----------+-------------------------------------------------- + function | {public,life} | {integer} | this is a BETTER comment about the life function + table | {public,bob} | {} | this is a BETTER comment about the bob table + table column | {public,bob,bar} | {} | this is a comment about the bar column + table column | {public,bob,foo} | {} | this is a comment about the foo column + table column | {public,bob,id} | {} | this is a comment about the id column + view | {public,bobby} | {} | this is a comment about the bob table + view column | {public,bobby,bar} | {} | this is a comment about the bar column + view column | {public,bobby,foo} | {} | this is a comment about the foo column + view column | {public,bobby,id} | {} | this is a BETTER comment about the id column +(9 rows) + diff --git a/projects/extension/tests/text_to_sql/10.expected b/projects/extension/tests/text_to_sql/10.expected new file mode 100644 index 00000000..a6210e77 --- /dev/null +++ b/projects/extension/tests/text_to_sql/10.expected @@ -0,0 +1,9 @@ + objtype | objnames | objargs | description +--------------+-------------------+-------------------+-------------------------------------------------- + function | {lucinda,death} | {integer} | this is a BETTER comment about the life function + function | {lucinda,death} | {integer,integer} | overloaded + table | {lucinda,bob} | {} | this is a BETTER comment about the bob table + table column | {lucinda,bob,bar} | {} | this is a comment about the bar column + table column | {lucinda,bob,id} | {} | this is a comment about the id column +(5 rows) + diff --git a/projects/extension/tests/text_to_sql/11.expected b/projects/extension/tests/text_to_sql/11.expected new file mode 100644 index 00000000..beed1df8 --- /dev/null +++ b/projects/extension/tests/text_to_sql/11.expected @@ -0,0 +1,8 @@ + objtype | objnames | objargs | description +--------------+-------------------+-------------------+---------------------------------------------- + function | {lucinda,death} | {integer,integer} | overloaded + table | {lucinda,bob} | {} | this is a BETTER comment about the bob table + table column | {lucinda,bob,bar} | {} | this is a comment about the bar column + table column | {lucinda,bob,id} | {} | this is a comment about the id column +(4 rows) + diff --git a/projects/extension/tests/text_to_sql/12.expected b/projects/extension/tests/text_to_sql/12.expected new file mode 100644 index 00000000..3cc5b686 --- /dev/null +++ b/projects/extension/tests/text_to_sql/12.expected @@ -0,0 +1,5 @@ + objtype | objnames | objargs | description +----------+-----------------+-------------------+------------- + function | {lucinda,death} | {integer,integer} | overloaded +(1 row) + diff --git a/projects/extension/tests/text_to_sql/13.expected b/projects/extension/tests/text_to_sql/13.expected new file mode 100644 index 00000000..5eb50209 --- /dev/null +++ b/projects/extension/tests/text_to_sql/13.expected @@ -0,0 +1,4 @@ + objtype | objnames | objargs | description +---------+----------+---------+------------- +(0 rows) + diff --git a/projects/extension/tests/text_to_sql/2.expected b/projects/extension/tests/text_to_sql/2.expected new file mode 100644 index 00000000..39b3712b --- /dev/null +++ b/projects/extension/tests/text_to_sql/2.expected @@ -0,0 +1,13 @@ + objtype | objnames | objargs | description +--------------+--------------------+-----------+-------------------------------------------------- + function | {public,death} | {integer} | this is a BETTER 
comment about the life function + table | {public,bob} | {} | this is a BETTER comment about the bob table + table column | {public,bob,bar} | {} | this is a comment about the bar column + table column | {public,bob,foo} | {} | this is a comment about the foo column + table column | {public,bob,id} | {} | this is a comment about the id column + view | {public,bobby} | {} | this is a comment about the bob table + view column | {public,bobby,bar} | {} | this is a comment about the bar column + view column | {public,bobby,foo} | {} | this is a comment about the foo column + view column | {public,bobby,id} | {} | this is a BETTER comment about the id column +(9 rows) + diff --git a/projects/extension/tests/text_to_sql/3.expected b/projects/extension/tests/text_to_sql/3.expected new file mode 100644 index 00000000..103c69c0 --- /dev/null +++ b/projects/extension/tests/text_to_sql/3.expected @@ -0,0 +1,13 @@ + objtype | objnames | objargs | description +--------------+--------------------+-----------+-------------------------------------------------- + function | {public,death} | {integer} | this is a BETTER comment about the life function + table | {public,bob} | {} | this is a BETTER comment about the bob table + table column | {public,bob,bar} | {} | this is a comment about the bar column + table column | {public,bob,baz} | {} | this is a comment about the foo column + table column | {public,bob,id} | {} | this is a comment about the id column + view | {public,bobby} | {} | this is a comment about the bob table + view column | {public,bobby,bar} | {} | this is a comment about the bar column + view column | {public,bobby,foo} | {} | this is a comment about the foo column + view column | {public,bobby,id} | {} | this is a BETTER comment about the id column +(9 rows) + diff --git a/projects/extension/tests/text_to_sql/4.expected b/projects/extension/tests/text_to_sql/4.expected new file mode 100644 index 00000000..c6d76aa8 --- /dev/null +++ b/projects/extension/tests/text_to_sql/4.expected @@ -0,0 +1,13 @@ + objtype | objnames | objargs | description +--------------+------------------------+-----------+-------------------------------------------------- + function | {public,death} | {integer} | this is a BETTER comment about the life function + table | {public,bob} | {} | this is a BETTER comment about the bob table + table column | {public,bob,bar} | {} | this is a comment about the bar column + table column | {public,bob,baz} | {} | this is a comment about the foo column + table column | {public,bob,id} | {} | this is a comment about the id column + view | {public,frederick} | {} | this is a comment about the bob table + view column | {public,frederick,bar} | {} | this is a comment about the bar column + view column | {public,frederick,foo} | {} | this is a comment about the foo column + view column | {public,frederick,id} | {} | this is a BETTER comment about the id column +(9 rows) + diff --git a/projects/extension/tests/text_to_sql/5.expected b/projects/extension/tests/text_to_sql/5.expected new file mode 100644 index 00000000..0d054749 --- /dev/null +++ b/projects/extension/tests/text_to_sql/5.expected @@ -0,0 +1,9 @@ + objtype | objnames | objargs | description +--------------+------------------+-----------+-------------------------------------------------- + function | {public,death} | {integer} | this is a BETTER comment about the life function + table | {public,bob} | {} | this is a BETTER comment about the bob table + table column | {public,bob,bar} | {} | this is a comment about the 
bar column + table column | {public,bob,baz} | {} | this is a comment about the foo column + table column | {public,bob,id} | {} | this is a comment about the id column +(5 rows) + diff --git a/projects/extension/tests/text_to_sql/6.expected b/projects/extension/tests/text_to_sql/6.expected new file mode 100644 index 00000000..abc20cb9 --- /dev/null +++ b/projects/extension/tests/text_to_sql/6.expected @@ -0,0 +1,8 @@ + objtype | objnames | objargs | description +--------------+------------------+-----------+-------------------------------------------------- + function | {public,death} | {integer} | this is a BETTER comment about the life function + table | {public,bob} | {} | this is a BETTER comment about the bob table + table column | {public,bob,bar} | {} | this is a comment about the bar column + table column | {public,bob,id} | {} | this is a comment about the id column +(4 rows) + diff --git a/projects/extension/tests/text_to_sql/7.expected b/projects/extension/tests/text_to_sql/7.expected new file mode 100644 index 00000000..e7d2ba6d --- /dev/null +++ b/projects/extension/tests/text_to_sql/7.expected @@ -0,0 +1,8 @@ + objtype | objnames | objargs | description +--------------+------------------+-----------+-------------------------------------------------- + function | {maria,death} | {integer} | this is a BETTER comment about the life function + table | {public,bob} | {} | this is a BETTER comment about the bob table + table column | {public,bob,bar} | {} | this is a comment about the bar column + table column | {public,bob,id} | {} | this is a comment about the id column +(4 rows) + diff --git a/projects/extension/tests/text_to_sql/8.expected b/projects/extension/tests/text_to_sql/8.expected new file mode 100644 index 00000000..2317f7e2 --- /dev/null +++ b/projects/extension/tests/text_to_sql/8.expected @@ -0,0 +1,8 @@ + objtype | objnames | objargs | description +--------------+-----------------+-----------+-------------------------------------------------- + function | {maria,death} | {integer} | this is a BETTER comment about the life function + table | {maria,bob} | {} | this is a BETTER comment about the bob table + table column | {maria,bob,bar} | {} | this is a comment about the bar column + table column | {maria,bob,id} | {} | this is a comment about the id column +(4 rows) + diff --git a/projects/extension/tests/text_to_sql/9.expected b/projects/extension/tests/text_to_sql/9.expected new file mode 100644 index 00000000..d4c3a72f --- /dev/null +++ b/projects/extension/tests/text_to_sql/9.expected @@ -0,0 +1,9 @@ + objtype | objnames | objargs | description +--------------+-----------------+-------------------+-------------------------------------------------- + function | {maria,death} | {integer} | this is a BETTER comment about the life function + function | {maria,death} | {integer,integer} | overloaded + table | {maria,bob} | {} | this is a BETTER comment about the bob table + table column | {maria,bob,bar} | {} | this is a comment about the bar column + table column | {maria,bob,id} | {} | this is a comment about the id column +(5 rows) + diff --git a/projects/extension/tests/text_to_sql/__init__.py b/projects/extension/tests/text_to_sql/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/projects/extension/tests/text_to_sql/extra.sql b/projects/extension/tests/text_to_sql/extra.sql new file mode 100644 index 00000000..cdfd59c7 --- /dev/null +++ b/projects/extension/tests/text_to_sql/extra.sql @@ -0,0 +1,67 @@ + +-- TODO: remove feature flag 
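+-- note: passing false as the third argument to set_config makes the setting session-wide
+-- (is_local => false), so the feature-flagged objects are created by the create extension below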
+select set_config('ai.enable_feature_flag_text_to_sql', 'true', false); +create extension if not exists ai cascade; + +create schema bishop; + +create table public.xenomorph +( id int not null primary key +, foo text not null +, bar timestamptz not null default now() +); +select ai.set_description ('public.xenomorph', 'description for xenomorph'); +select ai.set_column_description('public.xenomorph', 'id', 'description for xenomorph.id'); +select ai.set_column_description('public.xenomorph', 'foo', 'description for xenomorph.foo'); +select ai.set_column_description('public.xenomorph', 'bar', 'description for xenomorph.bar'); + +create table bishop.ripley +( id int not null primary key +, foo text not null +, bar timestamptz not null default now() +); +select ai.set_description('bishop.ripley', 'description for bishop.ripley'); +select ai.set_column_description('bishop.ripley', 'id', 'description for bishop.ripley.id'); +select ai.set_column_description('bishop.ripley', 'foo', 'description for bishop.ripley.foo'); +select ai.set_column_description('bishop.ripley', 'bar', 'description for bishop.ripley.bar'); + +create view public.hicks as +select * from bishop.ripley; +select ai.set_description('public.hicks', 'description for hicks'); +select ai.set_column_description('public.hicks', 'id', 'description for hicks.id'); +select ai.set_column_description('public.hicks', 'foo', 'description for hicks.foo'); +select ai.set_column_description('public.hicks', 'bar', 'description for hicks.bar'); + +create function public.hudson(x int) returns int +as $func$ + select 42 +$func$ language sql +; +select ai.set_function_description('public.hudson'::regproc, 'description for hudson(int)'); + +create function public.hudson(x int, y int) returns int +as $func$ + select 42 +$func$ language sql +; +select ai.set_function_description('public.hudson(int, int)', 'description for hudson(int, int)'); + +create table bishop.burke +( id int not null primary key +, foo text +, bar text +, baz bool +); +select ai.set_description('bishop.burke', 'description for bishop.burke'); +select ai.set_column_description('bishop.burke', 'id', 'description for bishop.burke.id'); +select ai.set_column_description('bishop.burke', 'foo', 'description for bishop.burke.foo'); +select ai.set_column_description('bishop.burke', 'bar', 'description for bishop.burke.bar'); + +create function bishop.gorman(z bool) returns int +as $func$ + select 42 +$func$ language sql +; +select ai.set_function_description('bishop.gorman(bool)', 'description for bishop.gorman(bool)'); + + diff --git a/projects/extension/tests/text_to_sql/init.sql b/projects/extension/tests/text_to_sql/init.sql new file mode 100644 index 00000000..38f300ab --- /dev/null +++ b/projects/extension/tests/text_to_sql/init.sql @@ -0,0 +1,77 @@ + +-- TODO: remove feature flag +select set_config('ai.enable_feature_flag_text_to_sql', 'true', false); +create extension if not exists ai cascade; + +create schema billy; + +create table public.predator +( id int not null primary key +, foo text not null +, bar timestamptz not null default now() +); +select ai.set_description ('public.predator', 'description for predator'); +select ai.set_column_description('public.predator', 'id', 'description for predator.id'); +select ai.set_column_description('public.predator', 'foo', 'description for predator.foo'); +select ai.set_column_description('public.predator', 'bar', 'description for predator.bar'); + +create table billy.dillon +( id int not null primary key +, foo text not null +, bar 
timestamptz not null default now() +); +select ai.set_description('billy.dillon', 'description for billy.dillon'); +select ai.set_column_description('billy.dillon', 'id', 'description for billy.dillon.id'); +select ai.set_column_description('billy.dillon', 'foo', 'description for billy.dillon.foo'); +select ai.set_column_description('billy.dillon', 'bar', 'description for billy.dillon.bar'); + +create view public.hawkins as +select * from billy.dillon; +select ai.set_description('public.hawkins', 'description for hawkins'); +select ai.set_column_description('public.hawkins', 'id', 'description for hawkins.id'); +select ai.set_column_description('public.hawkins', 'foo', 'description for hawkins.foo'); +select ai.set_column_description('public.hawkins', 'bar', 'description for hawkins.bar'); + +create function public.dutch(x int) returns int +as $func$ + select 42 +$func$ language sql +; +select ai.set_function_description('public.dutch'::regproc, 'description for dutch(int)'); + +create function public.dutch(x int, y int) returns int +as $func$ + select 42 +$func$ language sql +; +select ai.set_function_description('public.dutch(int, int)', 'description for dutch(int, int)'); + +create table billy.poncho +( id int not null primary key +, foo text +, bar text +, baz bool +); +select ai.set_description('billy.poncho', 'description for billy.poncho'); +select ai.set_column_description('billy.poncho', 'id', 'description for billy.poncho.id'); +select ai.set_column_description('billy.poncho', 'foo', 'description for billy.poncho.foo'); +select ai.set_column_description('billy.poncho', 'bar', 'description for billy.poncho.bar'); + +create function billy.mac(z bool) returns int +as $func$ + select 42 +$func$ language sql +; +select ai.set_function_description('billy.mac(bool)', 'description for billy.mac(bool)'); + +select ai.add_sql_example +( $sql$ +select id, concat(foo, bar, baz) +from billy.poncho +where id % 2 = 0 +$sql$ +, $description$ +This query concatenates foo, bar, and baz for even ids of billy.poncho. 
+$description$ +); + diff --git a/projects/extension/tests/text_to_sql/snapshot-catalog.expected b/projects/extension/tests/text_to_sql/snapshot-catalog.expected new file mode 100644 index 00000000..df1ad573 --- /dev/null +++ b/projects/extension/tests/text_to_sql/snapshot-catalog.expected @@ -0,0 +1,291 @@ + Table "ai._vectorizer_q_1" + Column | Type | Collation | Nullable | Default | Storage | Compression | Stats target | Description +-----------+--------------------------+-----------+----------+---------+----------+-------------+--------------+------------- + objtype | text | | not null | | extended | | | + objnames | text[] | | not null | | extended | | | + objargs | text[] | | not null | | extended | | | + queued_at | timestamp with time zone | | not null | now() | plain | | | +Indexes: + "_vectorizer_q_1_objtype_objnames_objargs_idx" btree (objtype, objnames, objargs) +Access method: heap + + Table "ai._vectorizer_q_2" + Column | Type | Collation | Nullable | Default | Storage | Compression | Stats target | Description +-----------+--------------------------+-----------+----------+---------+---------+-------------+--------------+------------- + id | integer | | not null | | plain | | | + queued_at | timestamp with time zone | | not null | now() | plain | | | +Indexes: + "_vectorizer_q_2_id_idx" btree (id) +Access method: heap + + Table "ai.semantic_catalog_obj_1_store" + Column | Type | Collation | Nullable | Default | Storage | Compression | Stats target | Description +----------------+-------------+-----------+----------+-------------------+----------+-------------+--------------+------------- + embedding_uuid | uuid | | not null | gen_random_uuid() | plain | | | + objtype | text | | not null | | extended | | | + objnames | text[] | | not null | | extended | | | + objargs | text[] | | not null | | extended | | | + chunk_seq | integer | | not null | | plain | | | + chunk | text | | not null | | extended | | | + embedding | vector(128) | | not null | | main | | | +Indexes: + "semantic_catalog_obj_1_store_pkey" PRIMARY KEY, btree (embedding_uuid) + "semantic_catalog_obj_1_store_objtype_objnames_objargs_chunk_key" UNIQUE CONSTRAINT, btree (objtype, objnames, objargs, chunk_seq) +Foreign-key constraints: + "semantic_catalog_obj_1_store_objtype_objnames_objargs_fkey" FOREIGN KEY (objtype, objnames, objargs) REFERENCES ai.semantic_catalog_obj(objtype, objnames, objargs) ON DELETE CASCADE +Access method: heap + + Table "ai.semantic_catalog_sql_1_store" + Column | Type | Collation | Nullable | Default | Storage | Compression | Stats target | Description +----------------+-------------+-----------+----------+-------------------+----------+-------------+--------------+------------- + embedding_uuid | uuid | | not null | gen_random_uuid() | plain | | | + id | integer | | not null | | plain | | | + chunk_seq | integer | | not null | | plain | | | + chunk | text | | not null | | extended | | | + embedding | vector(128) | | not null | | main | | | +Indexes: + "semantic_catalog_sql_1_store_pkey" PRIMARY KEY, btree (embedding_uuid) + "semantic_catalog_sql_1_store_id_chunk_seq_key" UNIQUE CONSTRAINT, btree (id, chunk_seq) +Foreign-key constraints: + "semantic_catalog_sql_1_store_id_fkey" FOREIGN KEY (id) REFERENCES ai.semantic_catalog_sql(id) ON DELETE CASCADE +Access method: heap + + View "ai.semantic_catalog_obj_1" + Column | Type | Collation | Nullable | Default | Storage | Description +----------------+-------------+-----------+----------+---------+----------+------------- + embedding_uuid | 
uuid | | | | plain | + chunk_seq | integer | | | | plain | + chunk | text | | | | extended | + embedding | vector(128) | | | | external | + objtype | text | | | | extended | + objnames | text[] | | | | extended | + objargs | text[] | | | | extended | + classid | oid | | | | plain | + objid | oid | | | | plain | + objsubid | integer | | | | plain | + description | text | | | | extended | +View definition: + SELECT t.embedding_uuid, + t.chunk_seq, + t.chunk, + t.embedding, + t.objtype, + t.objnames, + t.objargs, + s.classid, + s.objid, + s.objsubid, + s.description + FROM ai.semantic_catalog_obj_1_store t + LEFT JOIN ai.semantic_catalog_obj s ON t.objtype = s.objtype AND t.objnames = s.objnames AND t.objargs = s.objargs; + + View "ai.semantic_catalog_sql_1" + Column | Type | Collation | Nullable | Default | Storage | Description +----------------+-------------+-----------+----------+---------+----------+------------- + embedding_uuid | uuid | | | | plain | + chunk_seq | integer | | | | plain | + chunk | text | | | | extended | + embedding | vector(128) | | | | external | + id | integer | | | | plain | + sql | text | | | | extended | + description | text | | | | extended | +View definition: + SELECT t.embedding_uuid, + t.chunk_seq, + t.chunk, + t.embedding, + t.id, + s.sql, + s.description + FROM ai.semantic_catalog_sql_1_store t + LEFT JOIN ai.semantic_catalog_sql s ON t.id = s.id; + + id | name | obj_vectorizer_id | sql_vectorizer_id +----+---------+-------------------+------------------- + 1 | default | 1 | 2 +(1 row) + + objtype | objnames | objargs | description +--------------+--------------------+-----------+------------------------------------------- + function | {public,life} | {integer} | this is a comment about the life function + table | {public,bob} | {} | this is a comment about the bob table + table column | {public,bob,bar} | {} | this is a comment about the bar column + table column | {public,bob,foo} | {} | this is a comment about the foo column + table column | {public,bob,id} | {} | this is a comment about the id column + view | {public,bobby} | {} | this is a comment about the bob table + view column | {public,bobby,bar} | {} | this is a comment about the bar column + view column | {public,bobby,foo} | {} | this is a comment about the foo column + view column | {public,bobby,id} | {} | this is a comment about the id column +(9 rows) + + id | sql | description +----+-----------------------------------------+-------------------------------------------------------------- + 1 | select * from bobby where id = life(id) | a bogus query against the bobby view using the life function +(1 row) + + objtype | objnames | objargs +--------------+--------------------+----------- + function | {public,life} | {integer} + table | {public,bob} | {} + table column | {public,bob,bar} | {} + table column | {public,bob,foo} | {} + table column | {public,bob,id} | {} + view | {public,bobby} | {} + view column | {public,bobby,bar} | {} + view column | {public,bobby,foo} | {} + view column | {public,bobby,id} | {} +(9 rows) + + id +---- + 1 +(1 row) + + id | source_table | target_table | view | pending_items +----+-------------------------+---------------------------------+---------------------------+--------------- + 1 | ai.semantic_catalog_obj | ai.semantic_catalog_obj_1_store | ai.semantic_catalog_obj_1 | 9 + 2 | ai.semantic_catalog_sql | ai.semantic_catalog_sql_1_store | ai.semantic_catalog_sql_1 | 1 +(2 rows) + + jsonb_pretty 
+-------------------------------------------------------------------- + { + + "id": 1, + + "config": { + + "chunking": { + + "chunk_size": 800, + + "separators": [ + + "\n\n", + + "\n", + + ".", + + "?", + + "!", + + " ", + + "" + + ], + + "config_type": "chunking", + + "chunk_column": "description", + + "chunk_overlap": 400, + + "implementation": "recursive_character_text_splitter",+ + "is_separator_regex": false + + }, + + "indexing": { + + "config_type": "indexing", + + "implementation": "none" + + }, + + "embedding": { + + "model": "text-embedding-3-small", + + "dimensions": 128, + + "config_type": "embedding", + + "api_key_name": "OPENAI_API_KEY", + + "implementation": "openai" + + }, + + "formatting": { + + "template": "$chunk", + + "config_type": "formatting", + + "implementation": "python_template" + + }, + + "processing": { + + "config_type": "processing", + + "implementation": "default" + + }, + + "scheduling": { + + "config_type": "scheduling", + + "implementation": "none" + + } + + }, + + "source_pk": [ + + { + + "pknum": 1, + + "attnum": 1, + + "attname": "objtype", + + "typname": "text" + + }, + + { + + "pknum": 2, + + "attnum": 2, + + "attname": "objnames", + + "typname": "_text" + + }, + + { + + "pknum": 3, + + "attnum": 3, + + "attname": "objargs", + + "typname": "_text" + + } + + ], + + "view_name": "semantic_catalog_obj_1", + + "queue_table": "_vectorizer_q_1", + + "view_schema": "ai", + + "queue_schema": "ai", + + "source_table": "semantic_catalog_obj", + + "target_table": "semantic_catalog_obj_1_store", + + "trigger_name": "_vectorizer_src_trg_1", + + "source_schema": "ai", + + "target_schema": "ai" + + } + { + + "id": 2, + + "config": { + + "chunking": { + + "chunk_size": 800, + + "separators": [ + + "\n\n", + + "\n", + + ".", + + "?", + + "!", + + " ", + + "" + + ], + + "config_type": "chunking", + + "chunk_column": "description", + + "chunk_overlap": 400, + + "implementation": "recursive_character_text_splitter",+ + "is_separator_regex": false + + }, + + "indexing": { + + "config_type": "indexing", + + "implementation": "none" + + }, + + "embedding": { + + "model": "text-embedding-3-small", + + "dimensions": 128, + + "config_type": "embedding", + + "api_key_name": "OPENAI_API_KEY", + + "implementation": "openai" + + }, + + "formatting": { + + "template": "$chunk", + + "config_type": "formatting", + + "implementation": "python_template" + + }, + + "processing": { + + "config_type": "processing", + + "implementation": "default" + + }, + + "scheduling": { + + "config_type": "scheduling", + + "implementation": "none" + + } + + }, + + "source_pk": [ + + { + + "pknum": 1, + + "attnum": 1, + + "attname": "id", + + "typname": "int4" + + } + + ], + + "view_name": "semantic_catalog_sql_1", + + "queue_table": "_vectorizer_q_2", + + "view_schema": "ai", + + "queue_schema": "ai", + + "source_table": "semantic_catalog_sql", + + "target_table": "semantic_catalog_sql_1_store", + + "trigger_name": "_vectorizer_src_trg_2", + + "source_schema": "ai", + + "target_schema": "ai" + + } +(2 rows) + diff --git a/projects/extension/tests/text_to_sql/snapshot-catalog.sql b/projects/extension/tests/text_to_sql/snapshot-catalog.sql new file mode 100644 index 00000000..b419aa7a --- /dev/null +++ b/projects/extension/tests/text_to_sql/snapshot-catalog.sql @@ -0,0 +1,20 @@ +\pset pager off + +\d+ ai._vectorizer_q_1 +\d+ ai._vectorizer_q_2 +\d+ ai.semantic_catalog_obj_1_store +\d+ ai.semantic_catalog_sql_1_store +\d+ ai.semantic_catalog_obj_1 +\d+ ai.semantic_catalog_sql_1 + +select * from 
ai.semantic_catalog order by id; +select objtype, objnames, objargs, description from ai.semantic_catalog_obj order by objtype, objnames, objargs; +select * from ai.semantic_catalog_sql order by id; +select objtype, objnames, objargs from ai._vectorizer_q_1 order by objtype, objnames, objargs; +select id from ai._vectorizer_q_2 order by id; +select * from ai.vectorizer_status order by id; + +select jsonb_pretty(to_jsonb(x) #- array['config', 'version']) +from ai.vectorizer x +order by id +; diff --git a/projects/extension/tests/text_to_sql/snapshot.sql b/projects/extension/tests/text_to_sql/snapshot.sql new file mode 100644 index 00000000..983ac471 --- /dev/null +++ b/projects/extension/tests/text_to_sql/snapshot.sql @@ -0,0 +1,73 @@ +\pset pager off + +select version(); + +-- Lists schemas +\dn+ +-- Lists installed extensions. +\dx +-- all the objects belonging to each matching extension are listed. +\dx+ ai +-- Lists default access privilege settings. +\ddp + +-- dynamically generate meta commands to describe schemas +\! rm -f describe_schemas.sql +select format('%s %s', c.c, s.s) +from unnest(array +[ 'public' +, 'ai' +, 'bishop' +, 'billy' +]) s(s) +cross join unnest(array +[ '\dp+' -- Lists tables, views and sequences with their associated access privileges +, '\ddp' -- Lists default access privilege settings. An entry is shown for each role (and schema, if applicable) for which the default privilege settings have been changed from the built-in defaults. +]) c(c) +order by c.c, s.s +\g (tuples_only=on format=csv) describe_schemas.sql +\i describe_schemas.sql + +-- dynamically generate meta commands to describe objects in the schemas +\! rm -f describe_objects.sql +select format('%s %s', c.c, s.s) +from unnest(array +[ 'public.*' +, 'ai.*' +, 'bishop.*' +, 'billy.*' +]) s(s) +cross join unnest(array +[ '\d+' -- Describe each relation +, '\df+' -- Describe functions +, '\dp+' -- Lists tables, views and sequences with their associated access privileges. +, '\di' -- Describe indexes +, '\do' -- Lists operators with their operand and result types +, '\dT' -- Lists data types. 
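+-- note: each command/pattern pair produced above is one psql meta-command line;
+-- \g below writes those lines to describe_objects.sql and \i replays that file,
+-- so the describe output lands in the snapshot for comparison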
+]) c(c)
+order by c.c, s.s
+\g (tuples_only=on format=csv) describe_objects.sql
+\i describe_objects.sql
+
+-- snapshot the data from all the tables and views
+select
+    format($$select '%I.%I' as table_snapshot;$$, n.nspname, k.relname),
+    case
+        -- we don't care about comparing the applied_at_version and applied_at columns of the migration table
+        when n.nspname = 'ai'::name and k.relname = 'migration'::name
+        then 'select name, body from ai.migration order by name, body;'
+        when n.nspname = 'ai'::name and k.relname = 'feature_flag'::name
+        then 'select name, applied_at_version from ai.feature_flag order by name;'
+        when n.nspname = 'ai'::name and k.relname = 'semantic_catalog_obj'::name
+        then 'select objtype, objnames, objargs, description from ai.semantic_catalog_obj order by 1, 2, 3;'
+        else format('select * from %I.%I tbl order by tbl;', n.nspname, k.relname)
+    end
+from pg_namespace n
+inner join pg_class k on (n.oid = k.relnamespace)
+where k.relkind in ('r', 'p', 'v')
+and n.nspname in
+( 'public'
+, 'ai'
+)
+order by n.nspname, k.relname
+\gexec
diff --git a/projects/extension/tests/text_to_sql/test_dump_restore.py b/projects/extension/tests/text_to_sql/test_dump_restore.py
new file mode 100644
index 00000000..afb40afb
--- /dev/null
+++ b/projects/extension/tests/text_to_sql/test_dump_restore.py
@@ -0,0 +1,182 @@
+import os
+import subprocess
+from pathlib import Path
+
+import psycopg
+import pytest
+
+# skip tests in this module if disabled
+enable_text_to_sql_tests = os.getenv("ENABLE_TEXT_TO_SQL_TESTS")
+if enable_text_to_sql_tests == "0":
+    pytest.skip(allow_module_level=True)
+
+
+USER = "billy"  # NOT a superuser
+SRC_DB = "text_to_sql_src"
+DST_DB = "text_to_sql_dst"
+
+
+def db_url(user: str, dbname: str) -> str:
+    return f"postgres://{user}@127.0.0.1:5432/{dbname}"
+
+
+def where_am_i() -> str:
+    if "WHERE_AM_I" in os.environ and os.environ["WHERE_AM_I"] == "docker":
+        return "docker"
+    return "host"
+
+
+def docker_dir() -> str:
+    return "/pgai/tests/text_to_sql"
+
+
+def host_dir() -> Path:
+    return Path(__file__).parent.absolute()
+
+
+def create_user(user: str) -> None:
+    with psycopg.connect(
+        db_url(user="postgres", dbname="postgres"), autocommit=True
+    ) as con:
+        with con.cursor() as cur:
+            cur.execute(
+                """
+                select count(*) > 0
+                from pg_catalog.pg_roles
+                where rolname = %s
+                """,
+                (user,),
+            )
+            exists: bool = cur.fetchone()[0]
+            if not exists:
+                cur.execute(f"create user {user}")  # NOT a superuser
+
+
+def create_database(dbname: str) -> None:
+    with psycopg.connect(
+        db_url(user="postgres", dbname="postgres"), autocommit=True
+    ) as con:
+        with con.cursor() as cur:
+            cur.execute(f"drop database if exists {dbname} with (force)")
+            cur.execute(f"create database {dbname} with owner {USER}")
+
+
+def dump_db() -> None:
+    host_dir().joinpath("dump.sql").unlink(missing_ok=True)
+    cmd = " ".join(
+        [
+            "pg_dump -Fp --no-comments",
+            f'''-d "{db_url(USER, SRC_DB)}"''',
+            f"""-f {docker_dir()}/dump.sql""",
+        ]
+    )
+    if where_am_i() != "docker":
+        cmd = f"docker exec -w {docker_dir()} pgai-ext {cmd}"
+    subprocess.run(cmd, check=True, shell=True, env=os.environ, cwd=str(host_dir()))
+
+
+def restore_db() -> None:
+    cmd = " ".join(
+        [
+            "psql",
+            f'''-d "{db_url(USER, DST_DB)}"''',
+            "-v VERBOSITY=verbose",
+            f"-f {docker_dir()}/dump.sql",
+        ]
+    )
+    if where_am_i() != "docker":
+        cmd = f"docker exec -w {docker_dir()} pgai-ext {cmd}"
+    subprocess.run(cmd, check=True, shell=True, env=os.environ, cwd=str(host_dir()))
+
+
+def snapshot_db(dbname: str) -> None:
+    
host_dir().joinpath(f"{dbname}.snapshot").unlink(missing_ok=True) + cmd = " ".join( + [ + "psql", + f'''-d "{db_url("postgres", dbname)}"''', + "-v ON_ERROR_STOP=1", + "-X", + f"-o {docker_dir()}/{dbname}.snapshot", + f"-f {docker_dir()}/snapshot.sql", + ] + ) + if where_am_i() != "docker": + cmd = f"docker exec -w {docker_dir()} pgai-ext {cmd}" + subprocess.run(cmd, check=True, shell=True, env=os.environ, cwd=str(host_dir())) + + +def init_src() -> None: + cmd = " ".join( + [ + "psql", + f'''-d "{db_url(USER, SRC_DB)}"''', + "-v ON_ERROR_STOP=1", + f"-f {docker_dir()}/init.sql", + ] + ) + if where_am_i() != "docker": + cmd = f"docker exec -w {docker_dir()} pgai-ext {cmd}" + subprocess.run(cmd, check=True, shell=True, env=os.environ, cwd=str(host_dir())) + + +def read_file(filename: str) -> str: + with open(filename, "r") as f: + return f.read() + + +def extra(dbname: str) -> None: + cmd = " ".join( + [ + "psql", + f'''-d "{db_url(USER, dbname)}"''', + "-v ON_ERROR_STOP=1", + f"-f {docker_dir()}/extra.sql", + ] + ) + if where_am_i() != "docker": + cmd = f"docker exec -w {docker_dir()} pgai-ext {cmd}" + subprocess.run(cmd, check=True, shell=True, env=os.environ, cwd=str(host_dir())) + + +def post_restore() -> None: + with psycopg.connect(db_url(user=USER, dbname=DST_DB)) as con: + with con.cursor() as cur: + cur.execute("select ai.post_restore()") + + +def check_mapping() -> bool: + with psycopg.connect(db_url(user=USER, dbname=DST_DB)) as con: + with con.cursor() as cur: + cur.execute(""" + select + count(*) = + count(*) filter (where (d.objtype, d.objnames, d.objargs) = (x."type", x.object_names, x.object_args)) + from ai.semantic_catalog_obj d + cross join lateral pg_catalog.pg_identify_object_as_address + ( d.classid + , d.objid + , d.objsubid + ) x + """) + return cur.fetchone()[0] + + +def test_dump_restore(): + create_user(USER) + create_database(SRC_DB) + create_database(DST_DB) + init_src() + dump_db() + extra(SRC_DB) # add extra descriptions AFTER dump so snapshots match + snapshot_db(SRC_DB) + extra(DST_DB) # add extra descriptions BEFORE restore to make sure that works + restore_db() + post_restore() + snapshot_db(DST_DB) + src = read_file(str(host_dir().joinpath(f"{SRC_DB}.snapshot"))) + dst = read_file(str(host_dir().joinpath(f"{DST_DB}.snapshot"))) + assert dst == src + assert ( + check_mapping() is True + ) # ensure that all the oids match the names after the restore diff --git a/projects/extension/tests/text_to_sql/test_text_to_sql.py b/projects/extension/tests/text_to_sql/test_text_to_sql.py new file mode 100644 index 00000000..750b3a67 --- /dev/null +++ b/projects/extension/tests/text_to_sql/test_text_to_sql.py @@ -0,0 +1,326 @@ +import os +import subprocess +from pathlib import Path + +import pytest +import psycopg + + +# skip tests in this module if disabled +enable_text_to_sql_tests = os.getenv("ENABLE_TEXT_TO_SQL_TESTS") +if enable_text_to_sql_tests == "0": + pytest.skip(allow_module_level=True) + + +def db_url(user: str, dbname: str) -> str: + return f"postgres://{user}@127.0.0.1:5432/{dbname}" + + +def where_am_i() -> str: + if "WHERE_AM_I" in os.environ and os.environ["WHERE_AM_I"] == "docker": + return "docker" + return "host" + + +def docker_dir() -> str: + return "/pgai/tests/text_to_sql" + + +def host_dir() -> Path: + return Path(__file__).parent.absolute() + + +def does_test_user_exist(cur: psycopg.Cursor) -> bool: + cur.execute(""" + select count(*) > 0 + from pg_catalog.pg_roles + where rolname = 'test' + """) + return cur.fetchone()[0] + + +def 
create_test_user(cur: psycopg.Cursor) -> None: + if not does_test_user_exist(cur): + cur.execute("create user test password 'test'") + + +def set_up_test_db(dbname: str) -> None: + # create a test user and test database owned by the test user + with psycopg.connect( + "postgres://postgres@127.0.0.1:5432/postgres", autocommit=True + ) as con: + with con.cursor() as cur: + create_test_user(cur) + cur.execute(f"drop database if exists {dbname} with (force)") + cur.execute(f"create database {dbname} owner test") + # use the test user to create the extension in the text_to_sql database + with psycopg.connect(f"postgres://test@127.0.0.1:5432/{dbname}") as con: + with con.cursor() as cur: + # turn on the feature flag for text_to_sql + cur.execute( + "select set_config('ai.enable_feature_flag_text_to_sql', 'true', false)" + ) + cur.execute("create extension ai cascade") + + +def snapshot_descriptions(dbname: str, name: str) -> None: + host_dir().joinpath(f"{name}.actual").unlink(missing_ok=True) + cmd = " ".join( + [ + "psql", + f'''-d "{db_url("test", dbname)}"''', + "-v ON_ERROR_STOP=1", + "-X", + f"-o {docker_dir()}/{name}.actual", + '-c "select objtype, objnames, objargs, description from ai.semantic_catalog_obj order by 1,2,3"', + ] + ) + if where_am_i() != "docker": + cmd = f"docker exec -w {docker_dir()} pgai-ext {cmd}" + subprocess.run(cmd, check=True, shell=True, env=os.environ, cwd=str(host_dir())) + + +def file_contents(name: str) -> str: + return host_dir().joinpath(f"{name}").read_text() + + +def test_event_triggers(): + set_up_test_db("text_to_sql_1") + with psycopg.connect(db_url("test", "text_to_sql_1")) as con: + with con.cursor() as cur: + cur.execute(""" + create table bob + ( id int not null primary key + , foo text not null + , bar timestamptz not null default now() + ); + + create view bobby as + select * from bob; + + create function life(x int) returns int + as $func$select 42$func$ language sql; + """) + + # 0 set descriptions + cur.execute(""" + -- table + select ai.set_description('bob', 'this is a comment about the bob table'); + select ai.set_column_description('bob', 'id', 'this is a comment about the id column'); + select ai.set_column_description('bob', 'foo', 'this is a comment about the foo column'); + select ai.set_column_description('bob', 'bar', 'this is a comment about the bar column'); + + -- view + select ai.set_description('bobby', 'this is a comment about the bob table'); + select ai.set_column_description('bobby', 'id', 'this is a comment about the id column'); + select ai.set_column_description('bobby', 'foo', 'this is a comment about the foo column'); + select ai.set_column_description('bobby', 'bar', 'this is a comment about the bar column'); + + -- function + select ai.set_function_description('life'::regproc, 'this is a comment about the life function'); + """) + con.commit() + snapshot_descriptions("text_to_sql_1", "0") + actual = file_contents("0.actual") + expected = file_contents("0.expected") + assert actual == expected + + # 1 change descriptions + cur.execute( + "select ai.set_description('bob', 'this is a BETTER comment about the bob table')" + ) + cur.execute( + "select ai.set_column_description('bobby', 'id', 'this is a BETTER comment about the id column')" + ) + cur.execute( + """ + select ai.set_function_description + ( 'life'::regproc + , 'this is a BETTER comment about the life function' + ) + """ + ) + con.commit() + snapshot_descriptions("text_to_sql_1", "1") + actual = file_contents("1.actual") + expected = 
file_contents("1.expected") + assert actual == expected + + # 2 rename function + cur.execute("alter function life(int) rename to death") + con.commit() + snapshot_descriptions("text_to_sql_1", "2") + actual = file_contents("2.actual") + expected = file_contents("2.expected") + assert actual == expected + + # 3 rename table column + cur.execute("alter table bob rename column foo to baz") + con.commit() + snapshot_descriptions("text_to_sql_1", "3") + actual = file_contents("3.actual") + expected = file_contents("3.expected") + assert actual == expected + + # 4 rename view + cur.execute("alter view bobby rename to frederick") + con.commit() + snapshot_descriptions("text_to_sql_1", "4") + actual = file_contents("4.actual") + expected = file_contents("4.expected") + assert actual == expected + + # 5 drop view + cur.execute("drop view frederick") + con.commit() + snapshot_descriptions("text_to_sql_1", "5") + actual = file_contents("5.actual") + expected = file_contents("5.expected") + assert actual == expected + + # 6 drop table column + cur.execute("alter table bob drop column baz") + con.commit() + snapshot_descriptions("text_to_sql_1", "6") + actual = file_contents("6.actual") + expected = file_contents("6.expected") + assert actual == expected + + # 7 alter function set schema + cur.execute("create schema maria") + cur.execute("alter function death set schema maria") + con.commit() + snapshot_descriptions("text_to_sql_1", "7") + actual = file_contents("7.actual") + expected = file_contents("7.expected") + assert actual == expected + + # 8 alter table set schema + cur.execute("alter table bob set schema maria") + con.commit() + snapshot_descriptions("text_to_sql_1", "8") + actual = file_contents("8.actual") + expected = file_contents("8.expected") + assert actual == expected + + # 9 test overloaded function names + cur.execute(""" + create function maria.death(x int, y int) returns int + as $func$select 42$func$ language sql; + """) + cur.execute( + "select ai.set_function_description('maria.death(int, int)', 'overloaded')" + ) + con.commit() + snapshot_descriptions("text_to_sql_1", "9") + actual = file_contents("9.actual") + expected = file_contents("9.expected") + assert actual == expected + + # 10 alter schema rename + cur.execute("alter schema maria rename to lucinda") + con.commit() + snapshot_descriptions("text_to_sql_1", "10") + actual = file_contents("10.actual") + expected = file_contents("10.expected") + assert actual == expected + + # 11 drop function + cur.execute("drop function lucinda.death(int)") + con.commit() + snapshot_descriptions("text_to_sql_1", "11") + actual = file_contents("11.actual") + expected = file_contents("11.expected") + assert actual == expected + + # 12 drop table + cur.execute("drop table lucinda.bob") + con.commit() + snapshot_descriptions("text_to_sql_1", "12") + actual = file_contents("12.actual") + expected = file_contents("12.expected") + assert actual == expected + + # 13 drop schema cascade + cur.execute("drop schema lucinda cascade") + con.commit() + snapshot_descriptions("text_to_sql_1", "13") + actual = file_contents("13.actual") + expected = file_contents("13.expected") + assert actual == expected + + cur.execute("delete from ai.semantic_catalog_obj") + + +def snapshot_catalog(dbname: str) -> None: + host_dir().joinpath("snapshot-catalog.actual").unlink(missing_ok=True) + cmd = " ".join( + [ + "psql", + f'''-d "{db_url("postgres", dbname)}"''', + "-v ON_ERROR_STOP=1", + "-X", + "--echo-errors", + f"-o {docker_dir()}/snapshot-catalog.actual", + 
f"-f {docker_dir()}/snapshot-catalog.sql", + ] + ) + if where_am_i() != "docker": + cmd = f"docker exec -w {docker_dir()} pgai-ext {cmd}" + subprocess.run(cmd, check=True, shell=True, env=os.environ, cwd=str(host_dir())) + + +def test_vectorizer_setup(): + set_up_test_db("text_to_sql_2") + with psycopg.connect(db_url("test", "text_to_sql_2")) as con: + with con.cursor() as cur: + cur.execute("""select ai.grant_ai_usage('test', true)""") + cur.execute(""" + create table bob + ( id int not null primary key + , foo text not null + , bar timestamptz not null default now() + ); + + create view bobby as + select * from bob; + + create function life(x int) returns int + as $func$select 42$func$ language sql; + """) + + cur.execute(""" + select ai.initialize_semantic_catalog + ( embedding=>ai.embedding_openai('text-embedding-3-small', 128) + ) + """) + con.commit() + + cur.execute(""" + -- table + select ai.set_description('bob', 'this is a comment about the bob table'); + select ai.set_column_description('bob', 'id', 'this is a comment about the id column'); + select ai.set_column_description('bob', 'foo', 'this is a comment about the foo column'); + select ai.set_column_description('bob', 'bar', 'this is a comment about the bar column'); + + -- view + select ai.set_description('bobby', 'this is a comment about the bob table'); + select ai.set_column_description('bobby', 'id', 'this is a comment about the id column'); + select ai.set_column_description('bobby', 'foo', 'this is a comment about the foo column'); + select ai.set_column_description('bobby', 'bar', 'this is a comment about the bar column'); + + -- function + select ai.set_function_description('life'::regproc, 'this is a comment about the life function'); + + -- example query + select ai.add_sql_example + ( $sql$select * from bobby where id = life(id)$sql$ + , 'a bogus query against the bobby view using the life function' + ); + """) + con.commit() + + snapshot_catalog("text_to_sql_2") + actual = file_contents("snapshot-catalog.actual") + expected = file_contents("snapshot-catalog.expected") + assert actual == expected From 606f4bda8c70f5984d2c8b2bb1cfa84ff00e7fa4 Mon Sep 17 00:00:00 2001 From: John Pruitt Date: Wed, 11 Dec 2024 12:23:34 -0600 Subject: [PATCH 07/27] feat: add ai.find_relevant_sql semantic catalog function --- .../idempotent/900-semantic-catalog-init.sql | 8 ++ .../904-semantic-catalog-search.sql | 111 ++++++++++++++++++ .../text_to_sql/snapshot-catalog.expected | 29 +++-- .../tests/text_to_sql/test_text_to_sql.py | 36 +++++- 4 files changed, 164 insertions(+), 20 deletions(-) create mode 100644 projects/extension/sql/idempotent/904-semantic-catalog-search.sql diff --git a/projects/extension/sql/idempotent/900-semantic-catalog-init.sql b/projects/extension/sql/idempotent/900-semantic-catalog-init.sql index 10a6936d..cbcba1c9 100644 --- a/projects/extension/sql/idempotent/900-semantic-catalog-init.sql +++ b/projects/extension/sql/idempotent/900-semantic-catalog-init.sql @@ -16,6 +16,14 @@ declare _obj_vec_id pg_catalog.int4; _sql_vec_id pg_catalog.int4; begin + grant_to = pg_catalog.array_cat + ( grant_to + , array + [ pg_catalog."session_user"() + , 'pg_database_owner'::name + ] + ); + insert into ai.semantic_catalog("name") values (initialize_semantic_catalog."name") returning id diff --git a/projects/extension/sql/idempotent/904-semantic-catalog-search.sql b/projects/extension/sql/idempotent/904-semantic-catalog-search.sql new file mode 100644 index 00000000..8c35e8ec --- /dev/null +++ 
b/projects/extension/sql/idempotent/904-semantic-catalog-search.sql @@ -0,0 +1,111 @@ +--FEATURE-FLAG: text_to_sql + +------------------------------------------------------------------------------- +-- _semantic_catalog_embed +create or replace function ai._semantic_catalog_embed +( catalog_id pg_catalog.int4 +, prompt pg_catalog.text +) returns @extschema:vector@.vector +as $func$ +declare + _vectorizer_id pg_catalog.int4; + _config pg_catalog.jsonb; + _emb @extschema:vector@.vector; +begin + select x.obj_vectorizer_id -- TODO: assumes the embedding settings are the same for obj and sql + into strict _vectorizer_id + from ai.semantic_catalog x + where x.id operator(pg_catalog.=) catalog_id + ; + + select v.config operator(pg_catalog.->) 'embedding' + into strict _config + from ai.vectorizer v + where v.id operator(pg_catalog.=) _vectorizer_id + ; + + case _config operator(pg_catalog.->>) 'implementation' + when 'openai' then + _emb = ai.openai_embed + ( _config operator(pg_catalog.->>) 'model' + , prompt + , api_key_name=>(_config operator(pg_catalog.->>) 'api_key_name') + , dimensions=>(_config operator(pg_catalog.->>) 'dimensions')::pg_catalog.int4 + , openai_user=>(_config operator(pg_catalog.->>) 'user') + ); + when 'ollama' then + _emb = ai.ollama_embed + ( _config operator(pg_catalog.->>) 'model' + , prompt + , host=>(_config operator(pg_catalog.->>) 'base_url') + , keep_alive=>(_config operator(pg_catalog.->>) 'keep_alive') + , embedding_options=>(_config operator(pg_catalog.->) 'options') + -- TODO: ai.ollama_embed doesn't have a dimensions parameter??? + ); + when 'voyageai' then + _emb = ai.voyageai_embed + ( _config operator(pg_catalog.->>) 'model' + , prompt + , input_type=>'query' + , api_key_name=>(_config operator(pg_catalog.->>) 'api_key_name') + -- TODO: ai.voyageai_embed doesn't have a dimensions parameter + ); + else + raise exception 'unsupported embedding implementation'; + end case; + + return _emb; +end +$func$ language plpgsql stable security invoker +set search_path to pg_catalog, pg_temp +; + +------------------------------------------------------------------------------- +-- find_relevant_sql +create or replace function ai.find_relevant_sql +( prompt pg_catalog.text +, catalog_name pg_catalog.name default 'default' +, "limit" pg_catalog.int8 default 5 +) returns table +( id pg_catalog.int4 +, sql pg_catalog.text +, description pg_catalog.text +) +as $func$ +declare + _catalog_id pg_catalog.int4; + _emb @extschema:vector@.vector; + _sql pg_catalog.text; +begin + select x.id into strict _catalog_id + from ai.semantic_catalog x + where x."name" operator(pg_catalog.=) catalog_name + ; + + _emb = ai._semantic_catalog_embed(_catalog_id, prompt); + + _sql = pg_catalog.format + ( $sql$ + select distinct x.id, x.sql, x.description + from + ( + select + x.id + , x.sql + , x.description + , x.embedding operator(@extschema:vector@.<=>) ($1::@extschema:vector@.vector(%s)) as dist + from ai.semantic_catalog_sql_%s x + order by dist + limit %L + ) x + $sql$ + , @extschema:vector@.vector_dims(_emb) + , _catalog_id + , "limit" + ); + + return query execute _sql using _emb; +end; +$func$ language plpgsql stable security invoker +set search_path to pg_catalog, pg_temp +; diff --git a/projects/extension/tests/text_to_sql/snapshot-catalog.expected b/projects/extension/tests/text_to_sql/snapshot-catalog.expected index df1ad573..f94989db 100644 --- a/projects/extension/tests/text_to_sql/snapshot-catalog.expected +++ b/projects/extension/tests/text_to_sql/snapshot-catalog.expected @@ 
-27,7 +27,7 @@ Access method: heap objargs | text[] | | not null | | extended | | | chunk_seq | integer | | not null | | plain | | | chunk | text | | not null | | extended | | | - embedding | vector(128) | | not null | | main | | | + embedding | vector(576) | | not null | | main | | | Indexes: "semantic_catalog_obj_1_store_pkey" PRIMARY KEY, btree (embedding_uuid) "semantic_catalog_obj_1_store_objtype_objnames_objargs_chunk_key" UNIQUE CONSTRAINT, btree (objtype, objnames, objargs, chunk_seq) @@ -42,7 +42,7 @@ Access method: heap id | integer | | not null | | plain | | | chunk_seq | integer | | not null | | plain | | | chunk | text | | not null | | extended | | | - embedding | vector(128) | | not null | | main | | | + embedding | vector(576) | | not null | | main | | | Indexes: "semantic_catalog_sql_1_store_pkey" PRIMARY KEY, btree (embedding_uuid) "semantic_catalog_sql_1_store_id_chunk_seq_key" UNIQUE CONSTRAINT, btree (id, chunk_seq) @@ -56,7 +56,7 @@ Access method: heap embedding_uuid | uuid | | | | plain | chunk_seq | integer | | | | plain | chunk | text | | | | extended | - embedding | vector(128) | | | | external | + embedding | vector(576) | | | | external | objtype | text | | | | extended | objnames | text[] | | | | extended | objargs | text[] | | | | extended | @@ -85,7 +85,7 @@ View definition: embedding_uuid | uuid | | | | plain | chunk_seq | integer | | | | plain | chunk | text | | | | extended | - embedding | vector(128) | | | | external | + embedding | vector(576) | | | | external | id | integer | | | | plain | sql | text | | | | extended | description | text | | | | extended | @@ -138,13 +138,12 @@ View definition: id ---- - 1 -(1 row) +(0 rows) id | source_table | target_table | view | pending_items ----+-------------------------+---------------------------------+---------------------------+--------------- 1 | ai.semantic_catalog_obj | ai.semantic_catalog_obj_1_store | ai.semantic_catalog_obj_1 | 9 - 2 | ai.semantic_catalog_sql | ai.semantic_catalog_sql_1_store | ai.semantic_catalog_sql_1 | 1 + 2 | ai.semantic_catalog_sql | ai.semantic_catalog_sql_1_store | ai.semantic_catalog_sql_1 | 0 (2 rows) jsonb_pretty @@ -174,11 +173,11 @@ View definition: "implementation": "none" + }, + "embedding": { + - "model": "text-embedding-3-small", + - "dimensions": 128, + + "model": "smollm:135m", + + "base_url": "http://host.docker.internal:11434", + + "dimensions": 576, + "config_type": "embedding", + - "api_key_name": "OPENAI_API_KEY", + - "implementation": "openai" + + "implementation": "ollama" + }, + "formatting": { + "template": "$chunk", + @@ -249,11 +248,11 @@ View definition: "implementation": "none" + }, + "embedding": { + - "model": "text-embedding-3-small", + - "dimensions": 128, + + "model": "smollm:135m", + + "base_url": "http://host.docker.internal:11434", + + "dimensions": 576, + "config_type": "embedding", + - "api_key_name": "OPENAI_API_KEY", + - "implementation": "openai" + + "implementation": "ollama" + }, + "formatting": { + "template": "$chunk", + diff --git a/projects/extension/tests/text_to_sql/test_text_to_sql.py b/projects/extension/tests/text_to_sql/test_text_to_sql.py index 750b3a67..0f65c6a6 100644 --- a/projects/extension/tests/text_to_sql/test_text_to_sql.py +++ b/projects/extension/tests/text_to_sql/test_text_to_sql.py @@ -4,7 +4,7 @@ import pytest import psycopg - +from psycopg.rows import namedtuple_row # skip tests in this module if disabled enable_text_to_sql_tests = os.getenv("ENABLE_TEXT_TO_SQL_TESTS") @@ -270,9 +270,12 @@ def snapshot_catalog(dbname: 
str) -> None: subprocess.run(cmd, check=True, shell=True, env=os.environ, cwd=str(host_dir())) -def test_vectorizer_setup(): +def test_text_to_sql() -> None: + ollama_host = os.environ["OLLAMA_HOST"] + assert ollama_host is not None + set_up_test_db("text_to_sql_2") - with psycopg.connect(db_url("test", "text_to_sql_2")) as con: + with psycopg.connect(db_url("test", "text_to_sql_2"), row_factory=namedtuple_row) as con: with con.cursor() as cur: cur.execute("""select ai.grant_ai_usage('test', true)""") cur.execute(""" @@ -291,9 +294,13 @@ def test_vectorizer_setup(): cur.execute(""" select ai.initialize_semantic_catalog - ( embedding=>ai.embedding_openai('text-embedding-3-small', 128) + ( embedding=>ai.embedding_ollama + ( 'smollm:135m' + , 576 + , base_url=>%s + ) ) - """) + """, (ollama_host,)) con.commit() cur.execute(""" @@ -320,6 +327,25 @@ def test_vectorizer_setup(): """) con.commit() + # generate embeddings + cur.execute(""" + insert into ai.semantic_catalog_sql_1_store(embedding_uuid, id, chunk_seq, chunk, embedding) + select + gen_random_uuid() + , id + , 0 + , description + , ai.ollama_embed('smollm:135m', description, host=>%s) + from ai.semantic_catalog_sql + """, (ollama_host,)) + cur.execute("delete from ai._vectorizer_q_2") + + cur.execute("""select * from ai.find_relevant_sql('i need a query to tell me about bobby''s life')""") + for row in cur.fetchall(): + assert row.id == 1 + assert row.sql == "select * from bobby where id = life(id)" + assert row.description == "a bogus query against the bobby view using the life function" + snapshot_catalog("text_to_sql_2") actual = file_contents("snapshot-catalog.actual") expected = file_contents("snapshot-catalog.expected") From fb8d1c283cf5754ac76b2d32c57ee5dc0b6f0e22 Mon Sep 17 00:00:00 2001 From: John Pruitt Date: Wed, 11 Dec 2024 14:15:38 -0600 Subject: [PATCH 08/27] feat: add ai.find_relevant_obj() functions --- .../904-semantic-catalog-search.sql | 139 ++++++++++++++++-- .../text_to_sql/snapshot-catalog.expected | 17 +-- .../tests/text_to_sql/test_text_to_sql.py | 55 ++++++- 3 files changed, 176 insertions(+), 35 deletions(-) diff --git a/projects/extension/sql/idempotent/904-semantic-catalog-search.sql b/projects/extension/sql/idempotent/904-semantic-catalog-search.sql index 8c35e8ec..ef25e5fa 100644 --- a/projects/extension/sql/idempotent/904-semantic-catalog-search.sql +++ b/projects/extension/sql/idempotent/904-semantic-catalog-search.sql @@ -40,7 +40,6 @@ begin , host=>(_config operator(pg_catalog.->>) 'base_url') , keep_alive=>(_config operator(pg_catalog.->>) 'keep_alive') , embedding_options=>(_config operator(pg_catalog.->) 'options') - -- TODO: ai.ollama_embed doesn't have a dimensions parameter??? 
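+                -- ai.ollama_embed takes no dimensions argument; the embedding size comes from the model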
); when 'voyageai' then _emb = ai.voyageai_embed @@ -48,7 +47,6 @@ begin , prompt , input_type=>'query' , api_key_name=>(_config operator(pg_catalog.->>) 'api_key_name') - -- TODO: ai.voyageai_embed doesn't have a dimensions parameter ); else raise exception 'unsupported embedding implementation'; @@ -60,6 +58,43 @@ $func$ language plpgsql stable security invoker set search_path to pg_catalog, pg_temp ; +------------------------------------------------------------------------------- +-- find_relevant_sql +create or replace function ai.find_relevant_sql +( catalog_id pg_catalog.int4 +, embedding @extschema:vector@.vector +, "limit" pg_catalog.int8 default 5 +) returns table +( id pg_catalog.int4 +, sql pg_catalog.text +, description pg_catalog.text +) +as $func$ +begin + return query execute pg_catalog.format + ( $sql$ + select distinct x.id, x.sql, x.description + from + ( + select + x.id + , x.sql + , x.description + , x.embedding operator(@extschema:vector@.<=>) ($1::@extschema:vector@.vector(%s)) as dist + from ai.semantic_catalog_sql_%s x + order by dist + limit %L + ) x + $sql$ + , @extschema:vector@.vector_dims(embedding) + , catalog_id + , "limit" + ) using embedding; +end; +$func$ language plpgsql stable security invoker +set search_path to pg_catalog, pg_temp +; + ------------------------------------------------------------------------------- -- find_relevant_sql create or replace function ai.find_relevant_sql @@ -74,37 +109,113 @@ create or replace function ai.find_relevant_sql as $func$ declare _catalog_id pg_catalog.int4; - _emb @extschema:vector@.vector; - _sql pg_catalog.text; + _embedding @extschema:vector@.vector; begin select x.id into strict _catalog_id from ai.semantic_catalog x where x."name" operator(pg_catalog.=) catalog_name ; - _emb = ai._semantic_catalog_embed(_catalog_id, prompt); + _embedding = ai._semantic_catalog_embed(_catalog_id, prompt); + + return query + select * + from ai.find_relevant_sql + ( _catalog_id + , _embedding + , "limit" + ); +end; +$func$ language plpgsql stable security invoker +set search_path to pg_catalog, pg_temp +; - _sql = pg_catalog.format +------------------------------------------------------------------------------- +-- find_relevant_obj +create or replace function ai.find_relevant_obj +( catalog_id pg_catalog.int4 +, embedding @extschema:vector@.vector +, "limit" pg_catalog.int8 default 5 +) returns table +( objtype pg_catalog.text +, objnames pg_catalog.text[] +, objargs pg_catalog.text[] +, classid pg_catalog.oid +, objid pg_catalog.oid +, objsubid pg_catalog.int4 +, description pg_catalog.text +) +as $func$ +begin + return query execute pg_catalog.format ( $sql$ - select distinct x.id, x.sql, x.description + select distinct + x.objtype + , x.objnames + , x.objargs + , x.classid + , x.objid + , x.objsubid + , x.description from ( select - x.id - , x.sql + x.objtype + , x.objnames + , x.objargs + , x.classid + , x.objid + , x.objsubid , x.description , x.embedding operator(@extschema:vector@.<=>) ($1::@extschema:vector@.vector(%s)) as dist - from ai.semantic_catalog_sql_%s x + from ai.semantic_catalog_obj_%s x order by dist limit %L ) x $sql$ - , @extschema:vector@.vector_dims(_emb) - , _catalog_id + , @extschema:vector@.vector_dims(embedding) + , catalog_id , "limit" - ); + ) using embedding; +end; +$func$ language plpgsql stable security invoker +set search_path to pg_catalog, pg_temp +; - return query execute _sql using _emb; +------------------------------------------------------------------------------- +-- find_relevant_obj 
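+-- note: like find_relevant_sql above, this comes in two flavors: this one
+-- takes a catalog id and a precomputed query embedding (so a single embedding
+-- can be reused across searches), while the overload below embeds a text
+-- prompt and resolves the catalog by name before delegating here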
+create or replace function ai.find_relevant_obj +( prompt pg_catalog.text +, catalog_name pg_catalog.name default 'default' +, "limit" pg_catalog.int8 default 5 +) returns table +( objtype pg_catalog.text +, objnames pg_catalog.text[] +, objargs pg_catalog.text[] +, classid pg_catalog.oid +, objid pg_catalog.oid +, objsubid pg_catalog.int4 +, description pg_catalog.text +) +as $func$ +declare + _catalog_id pg_catalog.int4; + _embedding @extschema:vector@.vector; +begin + select x.id into strict _catalog_id + from ai.semantic_catalog x + where x."name" operator(pg_catalog.=) catalog_name + ; + + _embedding = ai._semantic_catalog_embed(_catalog_id, prompt); + + return query + select * + from ai.find_relevant_obj + ( _catalog_id + , _embedding + , "limit" + ); end; $func$ language plpgsql stable security invoker set search_path to pg_catalog, pg_temp diff --git a/projects/extension/tests/text_to_sql/snapshot-catalog.expected b/projects/extension/tests/text_to_sql/snapshot-catalog.expected index f94989db..5ab4faee 100644 --- a/projects/extension/tests/text_to_sql/snapshot-catalog.expected +++ b/projects/extension/tests/text_to_sql/snapshot-catalog.expected @@ -123,18 +123,9 @@ View definition: 1 | select * from bobby where id = life(id) | a bogus query against the bobby view using the life function (1 row) - objtype | objnames | objargs ---------------+--------------------+----------- - function | {public,life} | {integer} - table | {public,bob} | {} - table column | {public,bob,bar} | {} - table column | {public,bob,foo} | {} - table column | {public,bob,id} | {} - view | {public,bobby} | {} - view column | {public,bobby,bar} | {} - view column | {public,bobby,foo} | {} - view column | {public,bobby,id} | {} -(9 rows) + objtype | objnames | objargs +---------+----------+--------- +(0 rows) id ---- @@ -142,7 +133,7 @@ View definition: id | source_table | target_table | view | pending_items ----+-------------------------+---------------------------------+---------------------------+--------------- - 1 | ai.semantic_catalog_obj | ai.semantic_catalog_obj_1_store | ai.semantic_catalog_obj_1 | 9 + 1 | ai.semantic_catalog_obj | ai.semantic_catalog_obj_1_store | ai.semantic_catalog_obj_1 | 0 2 | ai.semantic_catalog_sql | ai.semantic_catalog_sql_1_store | ai.semantic_catalog_sql_1 | 0 (2 rows) diff --git a/projects/extension/tests/text_to_sql/test_text_to_sql.py b/projects/extension/tests/text_to_sql/test_text_to_sql.py index 0f65c6a6..e26e9c47 100644 --- a/projects/extension/tests/text_to_sql/test_text_to_sql.py +++ b/projects/extension/tests/text_to_sql/test_text_to_sql.py @@ -275,7 +275,9 @@ def test_text_to_sql() -> None: assert ollama_host is not None set_up_test_db("text_to_sql_2") - with psycopg.connect(db_url("test", "text_to_sql_2"), row_factory=namedtuple_row) as con: + with psycopg.connect( + db_url("test", "text_to_sql_2"), row_factory=namedtuple_row + ) as con: with con.cursor() as cur: cur.execute("""select ai.grant_ai_usage('test', true)""") cur.execute(""" @@ -292,7 +294,8 @@ def test_text_to_sql() -> None: as $func$select 42$func$ language sql; """) - cur.execute(""" + cur.execute( + """ select ai.initialize_semantic_catalog ( embedding=>ai.embedding_ollama ( 'smollm:135m' @@ -300,7 +303,9 @@ def test_text_to_sql() -> None: , base_url=>%s ) ) - """, (ollama_host,)) + """, + (ollama_host,), + ) con.commit() cur.execute(""" @@ -327,8 +332,25 @@ def test_text_to_sql() -> None: """) con.commit() - # generate embeddings - cur.execute(""" + # generate obj embeddings + cur.execute( + """ + 
insert into ai.semantic_catalog_obj_1_store(embedding_uuid, objtype, objnames, objargs, chunk_seq, chunk, embedding) + select + gen_random_uuid() + , objtype, objnames, objargs + , 0 + , description + , ai.ollama_embed('smollm:135m', description, host=>%s) + from ai.semantic_catalog_obj + """, + (ollama_host,), + ) + cur.execute("delete from ai._vectorizer_q_1") + + # generate sql embeddings + cur.execute( + """ insert into ai.semantic_catalog_sql_1_store(embedding_uuid, id, chunk_seq, chunk, embedding) select gen_random_uuid() @@ -337,14 +359,31 @@ def test_text_to_sql() -> None: , description , ai.ollama_embed('smollm:135m', description, host=>%s) from ai.semantic_catalog_sql - """, (ollama_host,)) + """, + (ollama_host,), + ) cur.execute("delete from ai._vectorizer_q_2") - cur.execute("""select * from ai.find_relevant_sql('i need a query to tell me about bobby''s life')""") + cur.execute( + """select * from ai.find_relevant_obj('i need a function about life')""" + ) + for row in cur.fetchall(): + assert row.objtype == "function" + assert row.objnames == ["public", "life"] + assert row.objargs == ["integer"] + assert row.description == "this is a comment about the life function" + break + + cur.execute( + """select * from ai.find_relevant_sql('i need a query to tell me about bobby''s life')""" + ) for row in cur.fetchall(): assert row.id == 1 assert row.sql == "select * from bobby where id = life(id)" - assert row.description == "a bogus query against the bobby view using the life function" + assert ( + row.description + == "a bogus query against the bobby view using the life function" + ) snapshot_catalog("text_to_sql_2") actual = file_contents("snapshot-catalog.actual") From fb73597a375ad49d3af098b4c8ce8c37f287a7ec Mon Sep 17 00:00:00 2001 From: John Pruitt Date: Wed, 11 Dec 2024 14:21:20 -0600 Subject: [PATCH 09/27] ci: pgspot chokes on valid code. 
disabling for now --- projects/extension/build.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/projects/extension/build.py b/projects/extension/build.py index 0a0dbd10..7768e175 100755 --- a/projects/extension/build.py +++ b/projects/extension/build.py @@ -645,7 +645,7 @@ def lint_py() -> None: def lint() -> None: lint_py() - lint_sql() + # lint_sql() # TODO: enable this when pgspot is fixed def format_py() -> None: From 29f52b8a0a04540101433c402266375d60baea55 Mon Sep 17 00:00:00 2001 From: John Pruitt Date: Wed, 11 Dec 2024 14:37:46 -0600 Subject: [PATCH 10/27] fix: ignore ollama.base_url in test --- .../extension/tests/text_to_sql/snapshot-catalog.expected | 2 -- projects/extension/tests/text_to_sql/snapshot-catalog.sql | 6 +++++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/projects/extension/tests/text_to_sql/snapshot-catalog.expected b/projects/extension/tests/text_to_sql/snapshot-catalog.expected index 5ab4faee..c5bcac1d 100644 --- a/projects/extension/tests/text_to_sql/snapshot-catalog.expected +++ b/projects/extension/tests/text_to_sql/snapshot-catalog.expected @@ -165,7 +165,6 @@ View definition: }, + "embedding": { + "model": "smollm:135m", + - "base_url": "http://host.docker.internal:11434", + "dimensions": 576, + "config_type": "embedding", + "implementation": "ollama" + @@ -240,7 +239,6 @@ View definition: }, + "embedding": { + "model": "smollm:135m", + - "base_url": "http://host.docker.internal:11434", + "dimensions": 576, + "config_type": "embedding", + "implementation": "ollama" + diff --git a/projects/extension/tests/text_to_sql/snapshot-catalog.sql b/projects/extension/tests/text_to_sql/snapshot-catalog.sql index b419aa7a..568ed801 100644 --- a/projects/extension/tests/text_to_sql/snapshot-catalog.sql +++ b/projects/extension/tests/text_to_sql/snapshot-catalog.sql @@ -14,7 +14,11 @@ select objtype, objnames, objargs from ai._vectorizer_q_1 order by objtype, objn select id from ai._vectorizer_q_2 order by id; select * from ai.vectorizer_status order by id; -select jsonb_pretty(to_jsonb(x) #- array['config', 'version']) +select jsonb_pretty +( to_jsonb(x) + #- array['config', 'version'] + #- array['config', 'embedding', 'base_url'] +) from ai.vectorizer x order by id ; From 28f500999c7404534092088036259761e0129a7b Mon Sep 17 00:00:00 2001 From: John Pruitt Date: Thu, 12 Dec 2024 10:53:13 -0600 Subject: [PATCH 11/27] feat: only find relevant db objs that user has privs to --- .../sql/idempotent/904-semantic-catalog-search.sql | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/projects/extension/sql/idempotent/904-semantic-catalog-search.sql b/projects/extension/sql/idempotent/904-semantic-catalog-search.sql index ef25e5fa..021bfe98 100644 --- a/projects/extension/sql/idempotent/904-semantic-catalog-search.sql +++ b/projects/extension/sql/idempotent/904-semantic-catalog-search.sql @@ -169,6 +169,14 @@ begin , x.description , x.embedding operator(@extschema:vector@.<=>) ($1::@extschema:vector@.vector(%s)) as dist from ai.semantic_catalog_obj_%s x + where pg_catalog.has_schema_privilege($2, x.objnames[1], 'usage') and + case x.objtype + when 'table' then pg_catalog.has_table_privilege($2, x.objid, 'select') + when 'view' then pg_catalog.has_table_privilege($2, x.objid, 'select') + when 'table column' then pg_catalog.has_column_privilege($2, x.objid, x.objsubid::pg_catalog.int2, 'select') + when 'view column' then pg_catalog.has_column_privilege($2, x.objid, x.objsubid::pg_catalog.int2, 'select') + when 
'function' then pg_catalog.has_function_privilege($2, x.objid, 'execute') + end order by dist limit %L ) x @@ -176,7 +184,10 @@ begin , @extschema:vector@.vector_dims(embedding) , catalog_id , "limit" - ) using embedding; + ) using + embedding + , pg_catalog."current_user"() + ; end; $func$ language plpgsql stable security invoker set search_path to pg_catalog, pg_temp From 7464d74b4a6a32af967819439a00386075c2745c Mon Sep 17 00:00:00 2001 From: John Pruitt Date: Thu, 12 Dec 2024 12:23:47 -0600 Subject: [PATCH 12/27] feat: allow find_relevant_obj to be restricted to a given obj type --- .../904-semantic-catalog-search.sql | 57 +++++++++++++++++-- .../tests/text_to_sql/test_text_to_sql.py | 6 ++ 2 files changed, 59 insertions(+), 4 deletions(-) diff --git a/projects/extension/sql/idempotent/904-semantic-catalog-search.sql b/projects/extension/sql/idempotent/904-semantic-catalog-search.sql index 021bfe98..b0adf6b6 100644 --- a/projects/extension/sql/idempotent/904-semantic-catalog-search.sql +++ b/projects/extension/sql/idempotent/904-semantic-catalog-search.sql @@ -60,7 +60,7 @@ set search_path to pg_catalog, pg_temp ------------------------------------------------------------------------------- -- find_relevant_sql -create or replace function ai.find_relevant_sql +create or replace function ai._find_relevant_sql ( catalog_id pg_catalog.int4 , embedding @extschema:vector@.vector , "limit" pg_catalog.int8 default 5 @@ -120,7 +120,7 @@ begin return query select * - from ai.find_relevant_sql + from ai._find_relevant_sql ( _catalog_id , _embedding , "limit" @@ -132,10 +132,11 @@ set search_path to pg_catalog, pg_temp ------------------------------------------------------------------------------- -- find_relevant_obj -create or replace function ai.find_relevant_obj +create or replace function ai._find_relevant_obj ( catalog_id pg_catalog.int4 , embedding @extschema:vector@.vector , "limit" pg_catalog.int8 default 5 +, only_objtype pg_catalog.text default null ) returns table ( objtype pg_catalog.text , objnames pg_catalog.text[] @@ -177,12 +178,17 @@ begin when 'view column' then pg_catalog.has_column_privilege($2, x.objid, x.objsubid::pg_catalog.int2, 'select') when 'function' then pg_catalog.has_function_privilege($2, x.objid, 'execute') end + %s order by dist limit %L ) x $sql$ , @extschema:vector@.vector_dims(embedding) , catalog_id + , case + when only_objtype is null then '' + else pg_catalog.format('and x.objtype operator(pg_catalog.=) %L', only_objtype) + end , "limit" ) using embedding @@ -199,6 +205,48 @@ create or replace function ai.find_relevant_obj ( prompt pg_catalog.text , catalog_name pg_catalog.name default 'default' , "limit" pg_catalog.int8 default 5 +, only_objtype pg_catalog.text default null +) returns table +( objtype pg_catalog.text +, objnames pg_catalog.text[] +, objargs pg_catalog.text[] +, classid pg_catalog.oid +, objid pg_catalog.oid +, objsubid pg_catalog.int4 +, description pg_catalog.text +) +as $func$ +declare + _catalog_id pg_catalog.int4; + _embedding @extschema:vector@.vector; +begin + select x.id into strict _catalog_id + from ai.semantic_catalog x + where x."name" operator(pg_catalog.=) catalog_name + ; + + _embedding = ai._semantic_catalog_embed(_catalog_id, prompt); + + return query + select * + from ai._find_relevant_obj + ( _catalog_id + , _embedding + , "limit" + , only_objtype + ); +end; +$func$ language plpgsql stable security invoker +set search_path to pg_catalog, pg_temp +; + 
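+
+-- usage sketch (illustrative prompt): only_objtype narrows results to a single
+-- object type, e.g.
+--   select * from ai.find_relevant_obj('a function about life', only_objtype=>'function');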
+------------------------------------------------------------------------------- +-- describe_relevant_obj +create or replace function ai.describe_relevant_obj +( prompt pg_catalog.text +, catalog_name pg_catalog.name default 'default' +, "limit" pg_catalog.int8 default 5 +, only_objtype pg_catalog.text default null ) returns table ( objtype pg_catalog.text , objnames pg_catalog.text[] @@ -222,10 +270,11 @@ begin return query select * - from ai.find_relevant_obj + from ai._find_relevant_obj ( _catalog_id , _embedding , "limit" + , only_objtype ); end; $func$ language plpgsql stable security invoker diff --git a/projects/extension/tests/text_to_sql/test_text_to_sql.py b/projects/extension/tests/text_to_sql/test_text_to_sql.py index e26e9c47..f87ea28c 100644 --- a/projects/extension/tests/text_to_sql/test_text_to_sql.py +++ b/projects/extension/tests/text_to_sql/test_text_to_sql.py @@ -374,6 +374,12 @@ def test_text_to_sql() -> None: assert row.description == "this is a comment about the life function" break + cur.execute( + """select * from ai.find_relevant_obj('i need a function about life', only_objtype=>'table column')""" + ) + for row in cur.fetchall(): + assert row.objtype == "table column" + cur.execute( """select * from ai.find_relevant_sql('i need a query to tell me about bobby''s life')""" ) From ae3d52917ff9969e8a6528dd16c58983d6f7a390 Mon Sep 17 00:00:00 2001 From: John Pruitt Date: Fri, 13 Dec 2024 09:11:17 -0600 Subject: [PATCH 13/27] feat: return dist and add max_dist filter to semantic_catalog funcs --- .../904-semantic-catalog-search.sql | 111 ++++++++++-------- .../text_to_sql/snapshot-catalog.expected | 30 +++++ .../tests/text_to_sql/snapshot-catalog.sql | 12 ++ .../tests/text_to_sql/test_text_to_sql.py | 8 +- 4 files changed, 109 insertions(+), 52 deletions(-) diff --git a/projects/extension/sql/idempotent/904-semantic-catalog-search.sql b/projects/extension/sql/idempotent/904-semantic-catalog-search.sql index b0adf6b6..02a380c3 100644 --- a/projects/extension/sql/idempotent/904-semantic-catalog-search.sql +++ b/projects/extension/sql/idempotent/904-semantic-catalog-search.sql @@ -64,16 +64,23 @@ create or replace function ai._find_relevant_sql ( catalog_id pg_catalog.int4 , embedding @extschema:vector@.vector , "limit" pg_catalog.int8 default 5 +, max_dist pg_catalog.float8 default null ) returns table ( id pg_catalog.int4 , sql pg_catalog.text , description pg_catalog.text +, dist pg_catalog.float8 ) as $func$ +declare + _dimensions pg_catalog.int4; + _sql pg_catalog.text; begin - return query execute pg_catalog.format + _dimensions = @extschema:vector@.vector_dims(embedding); + + _sql = pg_catalog.format ( $sql$ - select distinct x.id, x.sql, x.description + select x.id, x.sql, x.description, min(x.dist) as dist from ( select @@ -82,14 +89,28 @@ begin , x.description , x.embedding operator(@extschema:vector@.<=>) ($1::@extschema:vector@.vector(%s)) as dist from ai.semantic_catalog_sql_%s x + %s order by dist limit %L ) x + group by x.id, x.sql, x.description + order by min(x.dist) $sql$ - , @extschema:vector@.vector_dims(embedding) + , _dimensions , catalog_id + , case + when max_dist is null then '' + else pg_catalog.format + ( $sql$where (x.embedding operator(@extschema:vector@.<=>) ($1::@extschema:vector@.vector(%s))) <= %s$sql$ + , _dimensions + , max_dist + ) + end , "limit" - ) using embedding; + ); + -- raise log '%', _sql; + + return query execute _sql using embedding; end; $func$ language plpgsql stable security invoker set search_path to pg_catalog, pg_temp 
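+-- usage sketch (illustrative prompt): max_dist drops matches whose cosine
+-- distance exceeds the cutoff, e.g.
+--   select * from ai.find_relevant_sql('a query about bobby''s life', max_dist=>0.4);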
@@ -101,10 +122,12 @@ create or replace function ai.find_relevant_sql ( prompt pg_catalog.text , catalog_name pg_catalog.name default 'default' , "limit" pg_catalog.int8 default 5 +, max_dist pg_catalog.float8 default null ) returns table ( id pg_catalog.int4 , sql pg_catalog.text , description pg_catalog.text +, dist pg_catalog.float8 ) as $func$ declare @@ -124,6 +147,7 @@ begin ( _catalog_id , _embedding , "limit" + , max_dist ); end; $func$ language plpgsql stable security invoker @@ -137,6 +161,7 @@ create or replace function ai._find_relevant_obj , embedding @extschema:vector@.vector , "limit" pg_catalog.int8 default 5 , only_objtype pg_catalog.text default null +, max_dist pg_catalog.float8 default null ) returns table ( objtype pg_catalog.text , objnames pg_catalog.text[] @@ -145,12 +170,18 @@ create or replace function ai._find_relevant_obj , objid pg_catalog.oid , objsubid pg_catalog.int4 , description pg_catalog.text +, dist pg_catalog.float8 ) as $func$ +declare + _dimensions pg_catalog.int4; + _sql pg_catalog.text; begin - return query execute pg_catalog.format + _dimensions = @extschema:vector@.vector_dims(embedding); + + _sql = pg_catalog.format ( $sql$ - select distinct + select x.objtype , x.objnames , x.objargs @@ -158,6 +189,7 @@ begin , x.objid , x.objsubid , x.description + , min(x.dist) as dist from ( select @@ -179,21 +211,35 @@ begin when 'function' then pg_catalog.has_function_privilege($2, x.objid, 'execute') end %s + %s order by dist limit %L ) x + group by + x.objtype + , x.objnames + , x.objargs + , x.classid + , x.objid + , x.objsubid + , x.description + order by min(x.dist) $sql$ - , @extschema:vector@.vector_dims(embedding) + , _dimensions , catalog_id , case when only_objtype is null then '' else pg_catalog.format('and x.objtype operator(pg_catalog.=) %L', only_objtype) end + , case + when max_dist is null then '' + else pg_catalog.format('and (x.embedding operator(@extschema:vector@.<=>) ($1::@extschema:vector@.vector(%s))) <= %s', _dimensions, max_dist) + end , "limit" - ) using - embedding - , pg_catalog."current_user"() - ; + ); + -- raise log '%', _sql; + + return query execute _sql using embedding, pg_catalog."current_user"(); end; $func$ language plpgsql stable security invoker set search_path to pg_catalog, pg_temp @@ -206,6 +252,7 @@ create or replace function ai.find_relevant_obj , catalog_name pg_catalog.name default 'default' , "limit" pg_catalog.int8 default 5 , only_objtype pg_catalog.text default null +, max_dist pg_catalog.float8 default null ) returns table ( objtype pg_catalog.text , objnames pg_catalog.text[] @@ -214,6 +261,7 @@ create or replace function ai.find_relevant_obj , objid pg_catalog.oid , objsubid pg_catalog.int4 , description pg_catalog.text +, dist pg_catalog.float8 ) as $func$ declare @@ -234,49 +282,10 @@ begin , _embedding , "limit" , only_objtype + , max_dist ); end; $func$ language plpgsql stable security invoker set search_path to pg_catalog, pg_temp ; -------------------------------------------------------------------------------- --- describe_relevant_obj -create or replace function ai.describe_relevant_obj -( prompt pg_catalog.text -, catalog_name pg_catalog.name default 'default' -, "limit" pg_catalog.int8 default 5 -, only_objtype pg_catalog.text default null -) returns table -( objtype pg_catalog.text -, objnames pg_catalog.text[] -, objargs pg_catalog.text[] -, classid pg_catalog.oid -, objid pg_catalog.oid -, objsubid pg_catalog.int4 -, description pg_catalog.text -) -as $func$ -declare - _catalog_id 
pg_catalog.int4; - _embedding @extschema:vector@.vector; -begin - select x.id into strict _catalog_id - from ai.semantic_catalog x - where x."name" operator(pg_catalog.=) catalog_name - ; - - _embedding = ai._semantic_catalog_embed(_catalog_id, prompt); - - return query - select * - from ai._find_relevant_obj - ( _catalog_id - , _embedding - , "limit" - , only_objtype - ); -end; -$func$ language plpgsql stable security invoker -set search_path to pg_catalog, pg_temp -; diff --git a/projects/extension/tests/text_to_sql/snapshot-catalog.expected b/projects/extension/tests/text_to_sql/snapshot-catalog.expected index c5bcac1d..9e97ec76 100644 --- a/projects/extension/tests/text_to_sql/snapshot-catalog.expected +++ b/projects/extension/tests/text_to_sql/snapshot-catalog.expected @@ -277,3 +277,33 @@ View definition: } (2 rows) + i need a function about life + objtype | objnames | objargs | description +--------------+--------------------+-----------+------------------------------------------- + table column | {public,bob,foo} | {} | this is a comment about the foo column + view column | {public,bobby,foo} | {} | this is a comment about the foo column + function | {public,life} | {integer} | this is a comment about the life function + table column | {public,bob,id} | {} | this is a comment about the id column + view column | {public,bobby,id} | {} | this is a comment about the id column +(5 rows) + + i need a function about life only_objtype=>function + objtype | objnames | objargs | description +----------+---------------+-----------+------------------------------------------- + function | {public,life} | {integer} | this is a comment about the life function +(1 row) + + i need a function about life max_dist=>0.4 + objtype | objnames | objargs | description +--------------+--------------------+-----------+------------------------------------------- + table column | {public,bob,foo} | {} | this is a comment about the foo column + view column | {public,bobby,foo} | {} | this is a comment about the foo column + function | {public,life} | {integer} | this is a comment about the life function +(3 rows) + + i need a query to tell me about bobbys life + sql | description +-----------------------------------------+-------------------------------------------------------------- + select * from bobby where id = life(id) | a bogus query against the bobby view using the life function +(1 row) + diff --git a/projects/extension/tests/text_to_sql/snapshot-catalog.sql b/projects/extension/tests/text_to_sql/snapshot-catalog.sql index 568ed801..9b7d27e1 100644 --- a/projects/extension/tests/text_to_sql/snapshot-catalog.sql +++ b/projects/extension/tests/text_to_sql/snapshot-catalog.sql @@ -22,3 +22,15 @@ select jsonb_pretty from ai.vectorizer x order by id ; + +\pset title 'i need a function about life' +select objtype, objnames, objargs, description from ai.find_relevant_obj('i need a function about life'); + +\pset title 'i need a function about life only_objtype=>function' +select objtype, objnames, objargs, description from ai.find_relevant_obj('i need a function about life', only_objtype=>'function'); + +\pset title 'i need a function about life max_dist=>0.4' +select objtype, objnames, objargs, description from ai.find_relevant_obj('i need a function about life', max_dist=>0.4); + +\pset title 'i need a query to tell me about bobbys life' +select sql, description from ai.find_relevant_sql('i need a query to tell me about bobby''s life'); diff --git a/projects/extension/tests/text_to_sql/test_text_to_sql.py 
b/projects/extension/tests/text_to_sql/test_text_to_sql.py index f87ea28c..40eed9bc 100644 --- a/projects/extension/tests/text_to_sql/test_text_to_sql.py +++ b/projects/extension/tests/text_to_sql/test_text_to_sql.py @@ -365,7 +365,7 @@ def test_text_to_sql() -> None: cur.execute("delete from ai._vectorizer_q_2") cur.execute( - """select * from ai.find_relevant_obj('i need a function about life')""" + """select * from ai.find_relevant_obj('i need a function about life', only_objtype=>'function')""" ) for row in cur.fetchall(): assert row.objtype == "function" @@ -380,6 +380,12 @@ def test_text_to_sql() -> None: for row in cur.fetchall(): assert row.objtype == "table column" + cur.execute( + """select * from ai.find_relevant_obj('i need a function about life', max_dist=>0.4)""" + ) + for row in cur.fetchall(): + assert row.dist <= 0.4 + cur.execute( """select * from ai.find_relevant_sql('i need a query to tell me about bobby''s life')""" ) From e8f42763a6fde78ccbd091fac4e27cae9db5c4ff Mon Sep 17 00:00:00 2001 From: John Pruitt Date: Fri, 13 Dec 2024 12:43:49 -0600 Subject: [PATCH 14/27] chore: clean up event triggers to only update columns strictly required --- .../902-semantic-catalog-event-triggers.sql | 22 +++++++------------ 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/projects/extension/sql/idempotent/902-semantic-catalog-event-triggers.sql b/projects/extension/sql/idempotent/902-semantic-catalog-event-triggers.sql index fcfb883e..97e1c9ba 100644 --- a/projects/extension/sql/idempotent/902-semantic-catalog-event-triggers.sql +++ b/projects/extension/sql/idempotent/902-semantic-catalog-event-triggers.sql @@ -130,14 +130,12 @@ begin where k.relnamespace operator(pg_catalog.=) _rec.objid ) update ai.semantic_catalog_obj as d set - objtype = x.objtype - , objnames = x.objnames - , objargs = x.objargs + objnames = x.objnames from x where d.classid operator(pg_catalog.=) x.classid and d.objid operator(pg_catalog.=) x.objid and d.objsubid operator(pg_catalog.=) x.objsubid - and (d.objtype, d.objnames, d.objargs) operator(pg_catalog.!=) (x.objtype, x.objnames, x.objargs) -- only if changed + and d.objnames operator(pg_catalog.!=) x.objnames -- only if changed ; -- functions @@ -160,14 +158,12 @@ begin where f.pronamespace operator(pg_catalog.=) _rec.objid ) update ai.semantic_catalog_obj as d set - objtype = x.objtype - , objnames = x.objnames - , objargs = x.objargs + objnames = x.objnames from x where d.classid operator(pg_catalog.=) x.classid and d.objid operator(pg_catalog.=) x.objid and d.objsubid operator(pg_catalog.=) x.objsubid - and (d.objtype, d.objnames, d.objargs) operator(pg_catalog.!=) (x.objtype, x.objnames, x.objargs) -- only if changed + and d.objnames operator(pg_catalog.!=) x.objnames -- only if changed ; return; -- done @@ -180,13 +176,12 @@ begin -- alter view set schema -- alter function set schema update ai.semantic_catalog_obj set - objtype = _objtype - , objnames = _objnames + objnames = _objnames , objargs = _objargs where classid operator(pg_catalog.=) _rec.classid and objid operator(pg_catalog.=) _rec.objid and objsubid operator(pg_catalog.=) _rec.objsubid - and (objtype, objnames, objargs) operator(pg_catalog.!=) (_objtype, _objnames, _objargs) -- only if changed + and (objnames, objargs) operator(pg_catalog.!=) (_objnames, _objargs) -- only if changed ; if found and _objtype in ('table', 'view') then -- if table or view renamed or schema changed @@ -219,14 +214,13 @@ begin ) x ) update ai.semantic_catalog_obj d set - objtype = xref.objtype - , 
objnames = xref.objnames + objnames = xref.objnames , objargs = xref.objargs from xref where d.classid operator(pg_catalog.=) xref.classid and d.objid operator(pg_catalog.=) xref.objid and d.objsubid operator(pg_catalog.=) xref.objsubid - and (d.objtype, d.objnames, d.objargs) operator(pg_catalog.!=) (xref.objtype, xref.objnames, xref.objargs) -- only if changed + and (d.objnames, d.objargs) operator(pg_catalog.!=) (xref.objnames, xref.objargs) -- only if changed ; end if; end loop; From 3eca351c9b736f4604667f7269671ef3e72c236b Mon Sep 17 00:00:00 2001 From: John Pruitt Date: Fri, 13 Dec 2024 14:17:21 -0600 Subject: [PATCH 15/27] chore: add foreign key constraints to semantic catalog on vectorizer --- .../idempotent/900-semantic-catalog-init.sql | 24 ++++++++++++------- .../sql/incremental/900-semantic-catalog.sql | 4 ++-- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/projects/extension/sql/idempotent/900-semantic-catalog-init.sql b/projects/extension/sql/idempotent/900-semantic-catalog-init.sql index cbcba1c9..64ceaba9 100644 --- a/projects/extension/sql/idempotent/900-semantic-catalog-init.sql +++ b/projects/extension/sql/idempotent/900-semantic-catalog-init.sql @@ -24,11 +24,7 @@ begin ] ); - insert into ai.semantic_catalog("name") - values (initialize_semantic_catalog."name") - returning id - into strict _catalog_id - ; + _catalog_id = pg_catalog.nextval('ai.semantic_catalog_id_seq'::pg_catalog.regclass); select ai.create_vectorizer ( 'ai.semantic_catalog_obj'::pg_catalog.regclass @@ -56,10 +52,20 @@ begin ) into strict _sql_vec_id ; - update ai.semantic_catalog set - obj_vectorizer_id = _obj_vec_id - , sql_vectorizer_id = _sql_vec_id - where id operator(pg_catalog.=) _catalog_id + insert into ai.semantic_catalog + ( id + , "name" + , obj_vectorizer_id + , sql_vectorizer_id + ) + values + ( _catalog_id + , initialize_semantic_catalog."name" + , _obj_vec_id + , _sql_vec_id + ) + returning id + into strict _catalog_id ; return _catalog_id; diff --git a/projects/extension/sql/incremental/900-semantic-catalog.sql b/projects/extension/sql/incremental/900-semantic-catalog.sql index bb6275b5..c1369553 100644 --- a/projects/extension/sql/incremental/900-semantic-catalog.sql +++ b/projects/extension/sql/incremental/900-semantic-catalog.sql @@ -24,8 +24,8 @@ perform pg_catalog.pg_extension_config_dump('ai.semantic_catalog_sql_id_seq'::pg create table ai.semantic_catalog ( id pg_catalog.int4 not null primary key generated by default as identity , "name" pg_catalog.text not null unique -, obj_vectorizer_id pg_catalog.int4 -- TODO: foreign key constraint to vectorizer table??? -, sql_vectorizer_id pg_catalog.int4 -- TODO: foreign key constraint to vectorizer table??? 
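+-- each semantic catalog pairs two vectorizers (wired up by
+-- ai.initialize_semantic_catalog): one embeds database object descriptions,
+-- the other embeds example SQL, hence the two ai.vectorizer references below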
+, obj_vectorizer_id pg_catalog.int4 not null references ai.vectorizer(id) +, sql_vectorizer_id pg_catalog.int4 not null references ai.vectorizer(id) ); perform pg_catalog.pg_extension_config_dump('ai.semantic_catalog'::pg_catalog.regclass, ''); perform pg_catalog.pg_extension_config_dump('ai.semantic_catalog_id_seq'::pg_catalog.regclass, ''); From 1f34e48a4b8459568b71a2ba4ad4a889cc8eee97 Mon Sep 17 00:00:00 2001 From: John Pruitt Date: Fri, 13 Dec 2024 14:24:43 -0600 Subject: [PATCH 16/27] chore: reorder arguments for semantic catalog functions --- .../extension/sql/idempotent/900-semantic-catalog-init.sql | 4 ++-- .../extension/sql/idempotent/904-semantic-catalog-search.sql | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/projects/extension/sql/idempotent/900-semantic-catalog-init.sql b/projects/extension/sql/idempotent/900-semantic-catalog-init.sql index 64ceaba9..d9c065e5 100644 --- a/projects/extension/sql/idempotent/900-semantic-catalog-init.sql +++ b/projects/extension/sql/idempotent/900-semantic-catalog-init.sql @@ -3,12 +3,12 @@ ------------------------------------------------------------------------------- -- initialize_semantic_catalog create or replace function ai.initialize_semantic_catalog -( "name" pg_catalog.name default 'default' -, embedding pg_catalog.jsonb default null +( embedding pg_catalog.jsonb default null , indexing pg_catalog.jsonb default ai.indexing_default() , scheduling pg_catalog.jsonb default ai.scheduling_default() , processing pg_catalog.jsonb default ai.processing_default() , grant_to pg_catalog.name[] default ai.grant_to() +, "name" pg_catalog.name default 'default' ) returns pg_catalog.int4 as $func$ declare diff --git a/projects/extension/sql/idempotent/904-semantic-catalog-search.sql b/projects/extension/sql/idempotent/904-semantic-catalog-search.sql index 02a380c3..716b0f56 100644 --- a/projects/extension/sql/idempotent/904-semantic-catalog-search.sql +++ b/projects/extension/sql/idempotent/904-semantic-catalog-search.sql @@ -120,9 +120,9 @@ set search_path to pg_catalog, pg_temp -- find_relevant_sql create or replace function ai.find_relevant_sql ( prompt pg_catalog.text -, catalog_name pg_catalog.name default 'default' , "limit" pg_catalog.int8 default 5 , max_dist pg_catalog.float8 default null +, catalog_name pg_catalog.name default 'default' ) returns table ( id pg_catalog.int4 , sql pg_catalog.text @@ -249,10 +249,10 @@ set search_path to pg_catalog, pg_temp -- find_relevant_obj create or replace function ai.find_relevant_obj ( prompt pg_catalog.text -, catalog_name pg_catalog.name default 'default' , "limit" pg_catalog.int8 default 5 , only_objtype pg_catalog.text default null , max_dist pg_catalog.float8 default null +, catalog_name pg_catalog.name default 'default' ) returns table ( objtype pg_catalog.text , objnames pg_catalog.text[] From 937710aee0244638cf83279c5162904da801d391 Mon Sep 17 00:00:00 2001 From: John Pruitt Date: Fri, 13 Dec 2024 14:31:31 -0600 Subject: [PATCH 17/27] chore: support multiple objtype filters in find_relevant_obj() --- .../sql/idempotent/904-semantic-catalog-search.sql | 10 +++++----- .../extension/tests/text_to_sql/snapshot-catalog.sql | 2 +- .../extension/tests/text_to_sql/test_text_to_sql.py | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/projects/extension/sql/idempotent/904-semantic-catalog-search.sql b/projects/extension/sql/idempotent/904-semantic-catalog-search.sql index 716b0f56..d28ca755 100644 --- 
a/projects/extension/sql/idempotent/904-semantic-catalog-search.sql +++ b/projects/extension/sql/idempotent/904-semantic-catalog-search.sql @@ -160,7 +160,7 @@ create or replace function ai._find_relevant_obj ( catalog_id pg_catalog.int4 , embedding @extschema:vector@.vector , "limit" pg_catalog.int8 default 5 -, only_objtype pg_catalog.text default null +, objtypes pg_catalog.text[] default null , max_dist pg_catalog.float8 default null ) returns table ( objtype pg_catalog.text @@ -228,8 +228,8 @@ begin , _dimensions , catalog_id , case - when only_objtype is null then '' - else pg_catalog.format('and x.objtype operator(pg_catalog.=) %L', only_objtype) + when objtypes is null then '' + else pg_catalog.format('and x.objtype operator(pg_catalog.=) any(%L::pg_catalog.text[])', objtypes) end , case when max_dist is null then '' @@ -250,7 +250,7 @@ set search_path to pg_catalog, pg_temp create or replace function ai.find_relevant_obj ( prompt pg_catalog.text , "limit" pg_catalog.int8 default 5 -, only_objtype pg_catalog.text default null +, objtypes pg_catalog.text[] default null , max_dist pg_catalog.float8 default null , catalog_name pg_catalog.name default 'default' ) returns table @@ -281,7 +281,7 @@ begin ( _catalog_id , _embedding , "limit" - , only_objtype + , objtypes , max_dist ); end; diff --git a/projects/extension/tests/text_to_sql/snapshot-catalog.sql b/projects/extension/tests/text_to_sql/snapshot-catalog.sql index 9b7d27e1..b1f9a118 100644 --- a/projects/extension/tests/text_to_sql/snapshot-catalog.sql +++ b/projects/extension/tests/text_to_sql/snapshot-catalog.sql @@ -27,7 +27,7 @@ order by id select objtype, objnames, objargs, description from ai.find_relevant_obj('i need a function about life'); \pset title 'i need a function about life only_objtype=>function' -select objtype, objnames, objargs, description from ai.find_relevant_obj('i need a function about life', only_objtype=>'function'); +select objtype, objnames, objargs, description from ai.find_relevant_obj('i need a function about life', objtypes=>array['function']); \pset title 'i need a function about life max_dist=>0.4' select objtype, objnames, objargs, description from ai.find_relevant_obj('i need a function about life', max_dist=>0.4); diff --git a/projects/extension/tests/text_to_sql/test_text_to_sql.py b/projects/extension/tests/text_to_sql/test_text_to_sql.py index 40eed9bc..cabe1be2 100644 --- a/projects/extension/tests/text_to_sql/test_text_to_sql.py +++ b/projects/extension/tests/text_to_sql/test_text_to_sql.py @@ -365,7 +365,7 @@ def test_text_to_sql() -> None: cur.execute("delete from ai._vectorizer_q_2") cur.execute( - """select * from ai.find_relevant_obj('i need a function about life', only_objtype=>'function')""" + """select * from ai.find_relevant_obj('i need a function about life', objtypes=>array['function'])""" ) for row in cur.fetchall(): assert row.objtype == "function" @@ -375,7 +375,7 @@ def test_text_to_sql() -> None: break cur.execute( - """select * from ai.find_relevant_obj('i need a function about life', only_objtype=>'table column')""" + """select * from ai.find_relevant_obj('i need a function about life', objtypes=>array['table column'])""" ) for row in cur.fetchall(): assert row.objtype == "table column" From 3ffdee5f35ef9dcaa2993902557b482754069fd8 Mon Sep 17 00:00:00 2001 From: John Pruitt Date: Fri, 13 Dec 2024 14:54:50 -0600 Subject: [PATCH 18/27] feat: add vectorizer_embed convenience function --- .../sql/idempotent/013-vectorizer-api.sql | 52 +++++++++++++++++++ 
.../904-semantic-catalog-search.sql | 22 ++++++-- .../tests/contents/output16.expected | 3 +- .../tests/contents/output17.expected | 3 +- .../tests/privileges/function.expected | 6 ++- 5 files changed, 79 insertions(+), 7 deletions(-) diff --git a/projects/extension/sql/idempotent/013-vectorizer-api.sql b/projects/extension/sql/idempotent/013-vectorizer-api.sql index b1fa34dd..f7eda341 100644 --- a/projects/extension/sql/idempotent/013-vectorizer-api.sql +++ b/projects/extension/sql/idempotent/013-vectorizer-api.sql @@ -596,3 +596,55 @@ select end as pending_items from ai.vectorizer v ; + +------------------------------------------------------------------------------- +-- vectorizer_embed +create or replace function ai.vectorizer_embed +( vectorizer_id pg_catalog.int4 +, input_text pg_catalog.text +, input_type pg_catalog.text default null +) returns @extschema:vector@.vector +as $func$ +declare + _config pg_catalog.jsonb; + _emb @extschema:vector@.vector; +begin + select v.config operator(pg_catalog.->) 'embedding' + into strict _config + from ai.vectorizer v + where v.id operator(pg_catalog.=) vectorizer_id + ; + + case _config operator(pg_catalog.->>) 'implementation' + when 'openai' then + _emb = ai.openai_embed + ( _config operator(pg_catalog.->>) 'model' + , input_text + , api_key_name=>(_config operator(pg_catalog.->>) 'api_key_name') + , dimensions=>(_config operator(pg_catalog.->>) 'dimensions')::pg_catalog.int4 + , openai_user=>(_config operator(pg_catalog.->>) 'user') + ); + when 'ollama' then + _emb = ai.ollama_embed + ( _config operator(pg_catalog.->>) 'model' + , input_text + , host=>(_config operator(pg_catalog.->>) 'base_url') + , keep_alive=>(_config operator(pg_catalog.->>) 'keep_alive') + , embedding_options=>(_config operator(pg_catalog.->) 'options') + ); + when 'voyageai' then + _emb = ai.voyageai_embed + ( _config operator(pg_catalog.->>) 'model' + , input_text + , input_type=>coalesce(input_type, 'query') + , api_key_name=>(_config operator(pg_catalog.->>) 'api_key_name') + ); + else + raise exception 'unsupported embedding implementation'; + end case; + + return _emb; +end +$func$ language plpgsql stable security invoker +set search_path to pg_catalog, pg_temp +; diff --git a/projects/extension/sql/idempotent/904-semantic-catalog-search.sql b/projects/extension/sql/idempotent/904-semantic-catalog-search.sql index d28ca755..8c13feac 100644 --- a/projects/extension/sql/idempotent/904-semantic-catalog-search.sql +++ b/projects/extension/sql/idempotent/904-semantic-catalog-search.sql @@ -1,5 +1,6 @@ --FEATURE-FLAG: text_to_sql +/* ------------------------------------------------------------------------------- -- _semantic_catalog_embed create or replace function ai._semantic_catalog_embed @@ -57,6 +58,7 @@ end $func$ language plpgsql stable security invoker set search_path to pg_catalog, pg_temp ; +*/ ------------------------------------------------------------------------------- -- find_relevant_sql @@ -132,14 +134,20 @@ create or replace function ai.find_relevant_sql as $func$ declare _catalog_id pg_catalog.int4; + _vectorizer_id pg_catalog.int4; _embedding @extschema:vector@.vector; begin - select x.id into strict _catalog_id + select + x.id + , x.obj_vectorizer_id + into strict + _catalog_id + , _vectorizer_id from ai.semantic_catalog x where x."name" operator(pg_catalog.=) catalog_name ; - _embedding = ai._semantic_catalog_embed(_catalog_id, prompt); + _embedding = ai.vectorizer_embed(_vectorizer_id, prompt); return query select * @@ -266,14 +274,20 @@ create or 
replace function ai.find_relevant_obj as $func$ declare _catalog_id pg_catalog.int4; + _vectorizer_id pg_catalog.int4; _embedding @extschema:vector@.vector; begin - select x.id into strict _catalog_id + select + x.id + , x.obj_vectorizer_id + into strict + _catalog_id + , _vectorizer_id from ai.semantic_catalog x where x."name" operator(pg_catalog.=) catalog_name ; - _embedding = ai._semantic_catalog_embed(_catalog_id, prompt); + _embedding = ai.vectorizer_embed(_vectorizer_id, prompt); return query select * diff --git a/projects/extension/tests/contents/output16.expected b/projects/extension/tests/contents/output16.expected index 4ea8e82c..a25c8115 100644 --- a/projects/extension/tests/contents/output16.expected +++ b/projects/extension/tests/contents/output16.expected @@ -73,6 +73,7 @@ CREATE EXTENSION function ai._vectorizer_create_target_table(name,name,jsonb,name,name,integer,name[]) function ai._vectorizer_create_vector_index(name,name,jsonb) function ai._vectorizer_create_view(name,name,name,name,jsonb,name,name,name[]) + function ai.vectorizer_embed(integer,text,text) function ai._vectorizer_grant_to_source(name,name,name[]) function ai._vectorizer_grant_to_vectorizer(name[]) function ai._vectorizer_handle_drops() @@ -92,7 +93,7 @@ CREATE EXTENSION table ai.vectorizer_errors view ai.secret_permissions view ai.vectorizer_status -(88 rows) +(89 rows) Table "ai._secret_permissions" Column | Type | Collation | Nullable | Default | Storage | Compression | Stats target | Description diff --git a/projects/extension/tests/contents/output17.expected b/projects/extension/tests/contents/output17.expected index 8fdcd4d5..4f284c37 100644 --- a/projects/extension/tests/contents/output17.expected +++ b/projects/extension/tests/contents/output17.expected @@ -73,6 +73,7 @@ CREATE EXTENSION function ai._vectorizer_create_target_table(name,name,jsonb,name,name,integer,name[]) function ai._vectorizer_create_vector_index(name,name,jsonb) function ai._vectorizer_create_view(name,name,name,name,jsonb,name,name,name[]) + function ai.vectorizer_embed(integer,text,text) function ai._vectorizer_grant_to_source(name,name,name[]) function ai._vectorizer_grant_to_vectorizer(name[]) function ai._vectorizer_handle_drops() @@ -106,7 +107,7 @@ CREATE EXTENSION type ai.vectorizer_status[] view ai.secret_permissions view ai.vectorizer_status -(102 rows) +(103 rows) Table "ai._secret_permissions" Column | Type | Collation | Nullable | Default | Storage | Compression | Stats target | Description diff --git a/projects/extension/tests/privileges/function.expected b/projects/extension/tests/privileges/function.expected index 97e05366..b2544d86 100644 --- a/projects/extension/tests/privileges/function.expected +++ b/projects/extension/tests/privileges/function.expected @@ -308,6 +308,10 @@ f | bob | execute | no | ai | scheduling_timescaledb(schedule_interval interval, initial_start timestamp with time zone, fixed_schedule boolean, timezone text) f | fred | execute | no | ai | scheduling_timescaledb(schedule_interval interval, initial_start timestamp with time zone, fixed_schedule boolean, timezone text) f | jill | execute | YES | ai | scheduling_timescaledb(schedule_interval interval, initial_start timestamp with time zone, fixed_schedule boolean, timezone text) + f | alice | execute | YES | ai | vectorizer_embed(vectorizer_id integer, input_text text, input_type text) + f | bob | execute | no | ai | vectorizer_embed(vectorizer_id integer, input_text text, input_type text) + f | fred | execute | no | ai | 
vectorizer_embed(vectorizer_id integer, input_text text, input_type text) + f | jill | execute | YES | ai | vectorizer_embed(vectorizer_id integer, input_text text, input_type text) f | alice | execute | YES | ai | vectorizer_queue_pending(vectorizer_id integer, exact_count boolean) f | bob | execute | no | ai | vectorizer_queue_pending(vectorizer_id integer, exact_count boolean) f | fred | execute | no | ai | vectorizer_queue_pending(vectorizer_id integer, exact_count boolean) @@ -320,5 +324,5 @@ f | bob | execute | no | ai | voyageai_embed(model text, input_texts text[], input_type text, api_key text, api_key_name text) f | fred | execute | no | ai | voyageai_embed(model text, input_texts text[], input_type text, api_key text, api_key_name text) f | jill | execute | YES | ai | voyageai_embed(model text, input_texts text[], input_type text, api_key text, api_key_name text) -(320 rows) +(324 rows) From 65321e4bb0734d9e06254cd1105195ae407240a7 Mon Sep 17 00:00:00 2001 From: John Pruitt Date: Mon, 16 Dec 2024 08:45:36 -0600 Subject: [PATCH 19/27] chore: make an immutable version of vectorizer_embed --- .../sql/idempotent/013-vectorizer-api.sql | 53 ++++++++++++------- .../tests/contents/output16.expected | 3 +- .../tests/contents/output17.expected | 3 +- .../tests/privileges/function.expected | 6 ++- 4 files changed, 42 insertions(+), 23 deletions(-) diff --git a/projects/extension/sql/idempotent/013-vectorizer-api.sql b/projects/extension/sql/idempotent/013-vectorizer-api.sql index f7eda341..66fff628 100644 --- a/projects/extension/sql/idempotent/013-vectorizer-api.sql +++ b/projects/extension/sql/idempotent/013-vectorizer-api.sql @@ -600,44 +600,37 @@ from ai.vectorizer v ------------------------------------------------------------------------------- -- vectorizer_embed create or replace function ai.vectorizer_embed -( vectorizer_id pg_catalog.int4 +( embedding_config pg_catalog.jsonb , input_text pg_catalog.text , input_type pg_catalog.text default null ) returns @extschema:vector@.vector as $func$ declare - _config pg_catalog.jsonb; _emb @extschema:vector@.vector; begin - select v.config operator(pg_catalog.->) 'embedding' - into strict _config - from ai.vectorizer v - where v.id operator(pg_catalog.=) vectorizer_id - ; - - case _config operator(pg_catalog.->>) 'implementation' + case embedding_config operator(pg_catalog.->>) 'implementation' when 'openai' then _emb = ai.openai_embed - ( _config operator(pg_catalog.->>) 'model' + ( embedding_config operator(pg_catalog.->>) 'model' , input_text - , api_key_name=>(_config operator(pg_catalog.->>) 'api_key_name') - , dimensions=>(_config operator(pg_catalog.->>) 'dimensions')::pg_catalog.int4 - , openai_user=>(_config operator(pg_catalog.->>) 'user') + , api_key_name=>(embedding_config operator(pg_catalog.->>) 'api_key_name') + , dimensions=>(embedding_config operator(pg_catalog.->>) 'dimensions')::pg_catalog.int4 + , openai_user=>(embedding_config operator(pg_catalog.->>) 'user') ); when 'ollama' then _emb = ai.ollama_embed - ( _config operator(pg_catalog.->>) 'model' + ( embedding_config operator(pg_catalog.->>) 'model' , input_text - , host=>(_config operator(pg_catalog.->>) 'base_url') - , keep_alive=>(_config operator(pg_catalog.->>) 'keep_alive') - , embedding_options=>(_config operator(pg_catalog.->) 'options') + , host=>(embedding_config operator(pg_catalog.->>) 'base_url') + , keep_alive=>(embedding_config operator(pg_catalog.->>) 'keep_alive') + , embedding_options=>(embedding_config operator(pg_catalog.->) 'options') ); when 
'voyageai' then _emb = ai.voyageai_embed - ( _config operator(pg_catalog.->>) 'model' + ( embedding_config operator(pg_catalog.->>) 'model' , input_text , input_type=>coalesce(input_type, 'query') - , api_key_name=>(_config operator(pg_catalog.->>) 'api_key_name') + , api_key_name=>(embedding_config operator(pg_catalog.->>) 'api_key_name') ); else raise exception 'unsupported embedding implementation'; @@ -645,6 +638,26 @@ begin return _emb; end -$func$ language plpgsql stable security invoker +$func$ language plpgsql immutable security invoker +set search_path to pg_catalog, pg_temp +; + +------------------------------------------------------------------------------- +-- vectorizer_embed +create or replace function ai.vectorizer_embed +( vectorizer_id pg_catalog.int4 +, input_text pg_catalog.text +, input_type pg_catalog.text default null +) returns @extschema:vector@.vector +as $func$ + select ai.vectorizer_embed + ( v.config operator(pg_catalog.->) 'embedding' + , input_text + , input_type + ) + from ai.vectorizer v + where v.id operator(pg_catalog.=) vectorizer_id + ; +$func$ language sql stable security invoker set search_path to pg_catalog, pg_temp ; diff --git a/projects/extension/tests/contents/output16.expected b/projects/extension/tests/contents/output16.expected index a25c8115..ef557f1b 100644 --- a/projects/extension/tests/contents/output16.expected +++ b/projects/extension/tests/contents/output16.expected @@ -74,6 +74,7 @@ CREATE EXTENSION function ai._vectorizer_create_vector_index(name,name,jsonb) function ai._vectorizer_create_view(name,name,name,name,jsonb,name,name,name[]) function ai.vectorizer_embed(integer,text,text) + function ai.vectorizer_embed(jsonb,text,text) function ai._vectorizer_grant_to_source(name,name,name[]) function ai._vectorizer_grant_to_vectorizer(name[]) function ai._vectorizer_handle_drops() @@ -93,7 +94,7 @@ CREATE EXTENSION table ai.vectorizer_errors view ai.secret_permissions view ai.vectorizer_status -(89 rows) +(90 rows) Table "ai._secret_permissions" Column | Type | Collation | Nullable | Default | Storage | Compression | Stats target | Description diff --git a/projects/extension/tests/contents/output17.expected b/projects/extension/tests/contents/output17.expected index 4f284c37..35c69d53 100644 --- a/projects/extension/tests/contents/output17.expected +++ b/projects/extension/tests/contents/output17.expected @@ -74,6 +74,7 @@ CREATE EXTENSION function ai._vectorizer_create_vector_index(name,name,jsonb) function ai._vectorizer_create_view(name,name,name,name,jsonb,name,name,name[]) function ai.vectorizer_embed(integer,text,text) + function ai.vectorizer_embed(jsonb,text,text) function ai._vectorizer_grant_to_source(name,name,name[]) function ai._vectorizer_grant_to_vectorizer(name[]) function ai._vectorizer_handle_drops() @@ -107,7 +108,7 @@ CREATE EXTENSION type ai.vectorizer_status[] view ai.secret_permissions view ai.vectorizer_status -(103 rows) +(104 rows) Table "ai._secret_permissions" Column | Type | Collation | Nullable | Default | Storage | Compression | Stats target | Description diff --git a/projects/extension/tests/privileges/function.expected b/projects/extension/tests/privileges/function.expected index b2544d86..3b19b8a8 100644 --- a/projects/extension/tests/privileges/function.expected +++ b/projects/extension/tests/privileges/function.expected @@ -308,6 +308,10 @@ f | bob | execute | no | ai | scheduling_timescaledb(schedule_interval interval, initial_start timestamp with time zone, fixed_schedule boolean, timezone text) f | 
fred | execute | no | ai | scheduling_timescaledb(schedule_interval interval, initial_start timestamp with time zone, fixed_schedule boolean, timezone text) f | jill | execute | YES | ai | scheduling_timescaledb(schedule_interval interval, initial_start timestamp with time zone, fixed_schedule boolean, timezone text) + f | alice | execute | YES | ai | vectorizer_embed(embedding_config jsonb, input_text text, input_type text) + f | bob | execute | no | ai | vectorizer_embed(embedding_config jsonb, input_text text, input_type text) + f | fred | execute | no | ai | vectorizer_embed(embedding_config jsonb, input_text text, input_type text) + f | jill | execute | YES | ai | vectorizer_embed(embedding_config jsonb, input_text text, input_type text) f | alice | execute | YES | ai | vectorizer_embed(vectorizer_id integer, input_text text, input_type text) f | bob | execute | no | ai | vectorizer_embed(vectorizer_id integer, input_text text, input_type text) f | fred | execute | no | ai | vectorizer_embed(vectorizer_id integer, input_text text, input_type text) @@ -324,5 +328,5 @@ f | bob | execute | no | ai | voyageai_embed(model text, input_texts text[], input_type text, api_key text, api_key_name text) f | fred | execute | no | ai | voyageai_embed(model text, input_texts text[], input_type text, api_key text, api_key_name text) f | jill | execute | YES | ai | voyageai_embed(model text, input_texts text[], input_type text, api_key text, api_key_name text) -(324 rows) +(328 rows) From 0b699f84221ae6b099e30b0af0d7ca50f5cacca0 Mon Sep 17 00:00:00 2001 From: John Pruitt Date: Mon, 16 Dec 2024 08:53:55 -0600 Subject: [PATCH 20/27] chore: rename semantic_catalog.name to semantic_catalog.catalog_name --- .../idempotent/900-semantic-catalog-init.sql | 6 +- .../904-semantic-catalog-search.sql | 59 ++++--------------- .../sql/incremental/900-semantic-catalog.sql | 2 +- .../text_to_sql/snapshot-catalog.expected | 6 +- 4 files changed, 17 insertions(+), 56 deletions(-) diff --git a/projects/extension/sql/idempotent/900-semantic-catalog-init.sql b/projects/extension/sql/idempotent/900-semantic-catalog-init.sql index d9c065e5..946ae1c3 100644 --- a/projects/extension/sql/idempotent/900-semantic-catalog-init.sql +++ b/projects/extension/sql/idempotent/900-semantic-catalog-init.sql @@ -8,7 +8,7 @@ create or replace function ai.initialize_semantic_catalog , scheduling pg_catalog.jsonb default ai.scheduling_default() , processing pg_catalog.jsonb default ai.processing_default() , grant_to pg_catalog.name[] default ai.grant_to() -, "name" pg_catalog.name default 'default' +, catalog_name pg_catalog.name default 'default' ) returns pg_catalog.int4 as $func$ declare @@ -54,13 +54,13 @@ begin insert into ai.semantic_catalog ( id - , "name" + , catalog_name , obj_vectorizer_id , sql_vectorizer_id ) values ( _catalog_id - , initialize_semantic_catalog."name" + , initialize_semantic_catalog.catalog_name , _obj_vec_id , _sql_vec_id ) diff --git a/projects/extension/sql/idempotent/904-semantic-catalog-search.sql b/projects/extension/sql/idempotent/904-semantic-catalog-search.sql index 8c13feac..af14e129 100644 --- a/projects/extension/sql/idempotent/904-semantic-catalog-search.sql +++ b/projects/extension/sql/idempotent/904-semantic-catalog-search.sql @@ -1,6 +1,6 @@ --FEATURE-FLAG: text_to_sql -/* + ------------------------------------------------------------------------------- -- _semantic_catalog_embed create or replace function ai._semantic_catalog_embed @@ -8,57 +8,18 @@ create or replace function 
ai._semantic_catalog_embed , prompt pg_catalog.text ) returns @extschema:vector@.vector as $func$ -declare - _vectorizer_id pg_catalog.int4; - _config pg_catalog.jsonb; - _emb @extschema:vector@.vector; -begin - select x.obj_vectorizer_id -- TODO: assumes the embedding settings are the same for obj and sql - into strict _vectorizer_id + select ai.vectorizer_embed + ( v.config operator(pg_catalog.->) 'embedding' + , prompt + ) from ai.semantic_catalog x + inner join ai.vectorizer v + on (x.obj_vectorizer_id operator(pg_catalog.=) v.id) where x.id operator(pg_catalog.=) catalog_id ; - - select v.config operator(pg_catalog.->) 'embedding' - into strict _config - from ai.vectorizer v - where v.id operator(pg_catalog.=) _vectorizer_id - ; - - case _config operator(pg_catalog.->>) 'implementation' - when 'openai' then - _emb = ai.openai_embed - ( _config operator(pg_catalog.->>) 'model' - , prompt - , api_key_name=>(_config operator(pg_catalog.->>) 'api_key_name') - , dimensions=>(_config operator(pg_catalog.->>) 'dimensions')::pg_catalog.int4 - , openai_user=>(_config operator(pg_catalog.->>) 'user') - ); - when 'ollama' then - _emb = ai.ollama_embed - ( _config operator(pg_catalog.->>) 'model' - , prompt - , host=>(_config operator(pg_catalog.->>) 'base_url') - , keep_alive=>(_config operator(pg_catalog.->>) 'keep_alive') - , embedding_options=>(_config operator(pg_catalog.->) 'options') - ); - when 'voyageai' then - _emb = ai.voyageai_embed - ( _config operator(pg_catalog.->>) 'model' - , prompt - , input_type=>'query' - , api_key_name=>(_config operator(pg_catalog.->>) 'api_key_name') - ); - else - raise exception 'unsupported embedding implementation'; - end case; - - return _emb; -end -$func$ language plpgsql stable security invoker +$func$ language sql stable security invoker set search_path to pg_catalog, pg_temp ; -*/ ------------------------------------------------------------------------------- -- find_relevant_sql @@ -144,7 +105,7 @@ begin _catalog_id , _vectorizer_id from ai.semantic_catalog x - where x."name" operator(pg_catalog.=) catalog_name + where x.catalog_name operator(pg_catalog.=) find_relevant_sql.catalog_name ; _embedding = ai.vectorizer_embed(_vectorizer_id, prompt); @@ -284,7 +245,7 @@ begin _catalog_id , _vectorizer_id from ai.semantic_catalog x - where x."name" operator(pg_catalog.=) catalog_name + where x.catalog_name operator(pg_catalog.=) find_relevant_obj.catalog_name ; _embedding = ai.vectorizer_embed(_vectorizer_id, prompt); diff --git a/projects/extension/sql/incremental/900-semantic-catalog.sql b/projects/extension/sql/incremental/900-semantic-catalog.sql index c1369553..923d2b0f 100644 --- a/projects/extension/sql/incremental/900-semantic-catalog.sql +++ b/projects/extension/sql/incremental/900-semantic-catalog.sql @@ -23,7 +23,7 @@ perform pg_catalog.pg_extension_config_dump('ai.semantic_catalog_sql_id_seq'::pg create table ai.semantic_catalog ( id pg_catalog.int4 not null primary key generated by default as identity -, "name" pg_catalog.text not null unique +, catalog_name pg_catalog.text not null unique , obj_vectorizer_id pg_catalog.int4 not null references ai.vectorizer(id) , sql_vectorizer_id pg_catalog.int4 not null references ai.vectorizer(id) ); diff --git a/projects/extension/tests/text_to_sql/snapshot-catalog.expected b/projects/extension/tests/text_to_sql/snapshot-catalog.expected index 9e97ec76..f0bbd3fd 100644 --- a/projects/extension/tests/text_to_sql/snapshot-catalog.expected +++ 
b/projects/extension/tests/text_to_sql/snapshot-catalog.expected @@ -100,9 +100,9 @@ View definition: FROM ai.semantic_catalog_sql_1_store t LEFT JOIN ai.semantic_catalog_sql s ON t.id = s.id; - id | name | obj_vectorizer_id | sql_vectorizer_id -----+---------+-------------------+------------------- - 1 | default | 1 | 2 + id | catalog_name | obj_vectorizer_id | sql_vectorizer_id +----+--------------+-------------------+------------------- + 1 | default | 1 | 2 (1 row) objtype | objnames | objargs | description From 6906c5174b188346a606f35c6500dca5704b236b Mon Sep 17 00:00:00 2001 From: James Guthrie Date: Tue, 17 Dec 2024 10:17:56 +0100 Subject: [PATCH 21/27] fix: exclude python system packages for versioned extension (#310) We install the dependencies for an extension version in a specific path for that version alone. When one of our plpython3u functions is called, we set up the python path to include these dependencies. For the very first extension releases, we didn't do it this way. Instead, the dependencies were installed into the system python's path. As a result, we were inadvertently using _both_ packages installed for the system python, and the packages installed for the specific extension version. When the same package is installed in both, the system python version is chosen first. This breaks Python's assumptions about dependency resolution, and can break installed packages. This change removes the system python dependencies from the python path, so only the actual dependencies we wanted are available. --- projects/extension/build.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/projects/extension/build.py b/projects/extension/build.py index 7768e175..7a7bf53d 100755 --- a/projects/extension/build.py +++ b/projects/extension/build.py @@ -283,6 +283,13 @@ def build_idempotent_sql_file(input_file: Path) -> str: r = plpy.execute("select coalesce(pg_catalog.current_setting('ai.python_lib_dir', true), '{python_install_dir()}') as python_lib_dir") python_lib_dir = r[0]["python_lib_dir"] from pathlib import Path + import sys + import sysconfig + # Note: the "old" (pre-0.4.0) packages are installed as system-level python packages + # and take precedence over our extension-version specific packages. + # By removing the whole thing from the path we won't run into package conflicts. 
+ if "purelib" in sysconfig.get_path_names() and sysconfig.get_path("purelib") in sys.path: + sys.path.remove(sysconfig.get_path("purelib")) python_lib_dir = Path(python_lib_dir).joinpath("{this_version()}") import site site.addsitedir(str(python_lib_dir)) From 586cdc7d42f29dd4193c159caa3bf68101fdaac2 Mon Sep 17 00:00:00 2001 From: Matthew Peveler Date: Mon, 16 Dec 2024 23:51:45 -0700 Subject: [PATCH 22/27] fix: deprecation warning on re.split Signed-off-by: Matthew Peveler --- projects/extension/build.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/projects/extension/build.py b/projects/extension/build.py index 7a7bf53d..50d03aa7 100755 --- a/projects/extension/build.py +++ b/projects/extension/build.py @@ -83,7 +83,7 @@ def check_versions(): def parse_version(version: str) -> tuple[int, int, int, str | None]: - parts = re.split(r"[.-]", version, 4) + parts = re.split(r"[.-]", version, maxsplit=4) return ( int(parts[0]), int(parts[1]), From 4cd8a6102daca4ea0a2a6e25e97a110139c2d972 Mon Sep 17 00:00:00 2001 From: James Guthrie Date: Tue, 17 Dec 2024 10:13:33 +0100 Subject: [PATCH 23/27] chore: remove pip caching Bizarrely if you configure but don't use the pip cache, the setup-python action fails in its "post" step. See: https://github.com/actions/setup-python/issues/436 --- .github/workflows/ci.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7381db7d..b73315fa 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -27,7 +27,6 @@ jobs: uses: actions/setup-python@v5 with: python-version: "3.10" - cache: "pip" # caching pip dependencies - name: Verify Docker installation run: | From b62fb493dc3189c6c04471835afdeef4a62dea98 Mon Sep 17 00:00:00 2001 From: Jascha Beste Date: Tue, 17 Dec 2024 17:46:35 +0100 Subject: [PATCH 24/27] docs: remove openai mention from quickstart, fix opclasses in hnsw index docs (#317) --- docs/vectorizer-api-reference.md | 8 ++++---- docs/vectorizer-quick-start.md | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/vectorizer-api-reference.md b/docs/vectorizer-api-reference.md index 9306b91e..7c34ec80 100644 --- a/docs/vectorizer-api-reference.md +++ b/docs/vectorizer-api-reference.md @@ -602,7 +602,7 @@ HNSW is suitable for in-memory datasets and scenarios where query speed is cruci ```sql SELECT ai.create_vectorizer( 'blog_posts'::regclass, - indexing => ai.indexing_hnsw(min_rows => 50000, opclass => 'vector_l2_ops'), + indexing => ai.indexing_hnsw(min_rows => 50000, opclass => 'vector_l1_ops'), -- other parameters... ); ``` @@ -614,10 +614,10 @@ HNSW is suitable for in-memory datasets and scenarios where query speed is cruci | Name | Type | Default | Required | Description | |------|------|---------------------|-|----------------------------------------------------------------------------------------------------------------| |min_rows| int | 100000 |✖| The minimum number of rows before creating the index | -|opclass| text | `vector_cosine_ops` |✖| The operator class for the index. Possible values are:`vector_cosine_ops`, `vector_l2_ops`, or `vector_ip_ops` | +|opclass| text | `vector_cosine_ops` |✖| The operator class for the index. 
Possible values are:`vector_cosine_ops`, `vector_l1_ops`, or `vector_ip_ops` | |m| int | - |✖| Advanced [HNSW parameters](https://en.wikipedia.org/wiki/Hierarchical_navigable_small_world) | -|ef_construction| int | - |✖| Advanced [HNSW parameters](https://en.wikipedia.org/wiki/Hierarchical_navigable_small_world) | -| create_when_queue_empty| boolean | true |✖| Create the index only after all of the embeddings have been generated. | +|ef_construction| int | - |✖| Advanced [HNSW parameters](https://en.wikipedia.org/wiki/Hierarchical_navigable_small_world) | +| create_when_queue_empty| boolean | true |✖| Create the index only after all of the embeddings have been generated. | #### Returns diff --git a/docs/vectorizer-quick-start.md b/docs/vectorizer-quick-start.md index 2ff61ce1..7c7385e1 100644 --- a/docs/vectorizer-quick-start.md +++ b/docs/vectorizer-quick-start.md @@ -5,7 +5,7 @@ If you prefer working with the OpenAI API instead of self-hosting models, you ca ## Setup a local development environment -To set up a development environment for OpenAI, use a docker compose file that includes a: +To set up a development environment, use a docker compose file that includes a: - Postgres deployment image with the TimescaleDB and pgai extensions installed - pgai vectorizer worker image - ollama image to host embedding and large language models From 438a3fb750bcf0287adf62af1b56340e25ce7faa Mon Sep 17 00:00:00 2001 From: John Pruitt Date: Mon, 16 Dec 2024 20:07:17 -0600 Subject: [PATCH 25/27] feat: construct a prompt for text-to-sql using relevant desc --- .../sql/idempotent/905-text-to-sql.sql | 297 ++++++++++++++++++ 1 file changed, 297 insertions(+) create mode 100644 projects/extension/sql/idempotent/905-text-to-sql.sql diff --git a/projects/extension/sql/idempotent/905-text-to-sql.sql b/projects/extension/sql/idempotent/905-text-to-sql.sql new file mode 100644 index 00000000..7e76d38e --- /dev/null +++ b/projects/extension/sql/idempotent/905-text-to-sql.sql @@ -0,0 +1,297 @@ +--FEATURE-FLAG: text_to_sql + +------------------------------------------------------------------------------- +-- _table_def +create or replace function ai._table_def(objid pg_catalog.oid) returns pg_catalog.text +as $func$ +declare + _nspname pg_catalog.name; + _relname pg_catalog.name; + _columns pg_catalog.text[]; + _constraints pg_catalog.text[]; + _indexes pg_catalog.text; + _ddl pg_catalog.text; +begin + -- names + select + n.nspname + , k.relname + into strict + _nspname + , _relname + from pg_catalog.pg_class k + inner join pg_catalog.pg_namespace n + on (k.relnamespace operator(pg_catalog.=) n.oid) + where k.oid operator(pg_catalog.=) objid + ; + + -- columns + select pg_catalog.array_agg(x.txt order by x.attnum) + into strict _columns + from + ( + select pg_catalog.concat_ws + ( ' ' + , a.attname + , pg_catalog.format_type(a.atttypid, a.atttypmod) + , case when a.attnotnull then 'NOT NULL' else '' end + , case + when a.atthasdef + then pg_catalog.pg_get_expr(d.adbin, d.adrelid) + when a.attidentity operator(pg_catalog.=) 'd' + then 'GENERATED BY DEFAULT AS IDENTITY' + when a.attidentity operator(pg_catalog.=) 'a' + then 'GENERATED ALWAYS AS IDENTITY' + when a.attgenerated operator(pg_catalog.=) 's' + then pg_catalog.format('GENERATED ALWAYS AS (%s) STORED', pg_catalog.pg_get_expr(d.adbin, d.adrelid)) + else '' + end + ) as txt + , a.attnum + from pg_catalog.pg_attribute a + left outer join pg_catalog.pg_attrdef d + on (a.attrelid operator(pg_catalog.=) d.adrelid and a.attnum operator(pg_catalog.=) d.adnum) + 
where a.attrelid operator(pg_catalog.=) objid + and a.attnum operator(pg_catalog.>) 0 + and not a.attisdropped + ) x; + + -- constraints + select pg_catalog.array_agg(pg_catalog.pg_get_constraintdef(k.oid, true) order by k.conname) + into _constraints + from pg_catalog.pg_constraint k + where k.conrelid operator(pg_catalog.=) objid + ; + + -- indexes + select pg_catalog.string_agg(pg_catalog.pg_get_indexdef(i.indexrelid, 0, true), E';\n') + into strict _indexes + from pg_catalog.pg_index i + where i.indrelid operator(pg_catalog.=) objid + ; + + -- ddl + select pg_catalog.format(E'CREATE TABLE %I.%I\n( ', _nspname, _relname) + operator(pg_catalog.||) + pg_catalog.string_agg(x.line, E'\n, ') + operator(pg_catalog.||) E'\n);\n' + operator(pg_catalog.||) _indexes + into strict _ddl + from + ( + select * from pg_catalog.unnest(_columns) line + union all + select * from pg_catalog.unnest(_constraints) line + ) x + ; + + return _ddl; +end +$func$ language plpgsql stable security invoker +set search_path to pg_catalog, pg_temp +; + +------------------------------------------------------------------------------- +-- _text_to_sql_prompt +create or replace function ai._text_to_sql_prompt +( prompt pg_catalog.text +, catalog_name pg_catalog.text default 'default' +) returns pg_catalog.text +as $func$ +declare + _catalog_id pg_catalog.int4; + _prompt_emb @extschema:vector@.vector; + _relevant_obj pg_catalog.jsonb; + _distinct_tables pg_catalog.oid[]; + _tbl_ctx pg_catalog.text; + _distinct_views pg_catalog.oid[]; + _view_ctx pg_catalog.text; + _func_ctx pg_catalog.text; + _relevant_sql pg_catalog.text; + _prompt pg_catalog.text; +begin + -- embed the user prompt + select + k.id + , ai._semantic_catalog_embed + ( k.id + , prompt + ) + into strict + _catalog_id + , _prompt_emb + from ai.semantic_catalog k + where k.catalog_name operator(pg_catalog.=) _text_to_sql_prompt.catalog_name + ; + + -- find relevant database objects + select pg_catalog.jsonb_agg(pg_catalog.to_jsonb(r)) + into strict _relevant_obj + from ai._find_relevant_obj + ( _catalog_id + , _prompt_emb + ) r + ; + + -- distinct tables + select pg_catalog.array_agg(objid) into _distinct_tables + from pg_catalog.jsonb_to_recordset(_relevant_obj) x + ( objtype pg_catalog.text + , objid pg_catalog.oid + ) + where x.objtype in ('table', 'table column') + ; + + -- construct table contexts + select pg_catalog.string_agg(x.ctx, E'\n') + into _tbl_ctx + from + ( + select pg_catalog.format + ( E'\n/*\n# %I.%I\n%s\n%s\n*/\n%s\n
' + , n.nspname + , k.relname + , td.description + , c.cols + , ai._table_def(k.oid) + ) as ctx + from pg_catalog.unnest(_distinct_tables) t + inner join pg_catalog.pg_class k on (t operator(pg_catalog.=) k.oid) + inner join pg_catalog.pg_namespace n on (k.relnamespace operator(pg_catalog.=) n.oid) + left outer join pg_catalog.jsonb_to_recordset(_relevant_obj) td + ( objtype pg_catalog.text + , objid pg_catalog.oid + , description pg_catalog.text + ) on (td.objtype operator(pg_catalog.=) 'table' and td.objid operator(pg_catalog.=) k.oid) + left outer join + ( + select + c.objid + , pg_catalog.string_agg + ( pg_catalog.format(E'## %s\n%s', c.objnames[3], c.description) + , E'\n' + ) as cols + from pg_catalog.jsonb_to_recordset(_relevant_obj) c + ( objtype pg_catalog.text + , objid pg_catalog.oid + , objsubid pg_catalog.int4 + , objnames pg_catalog.name[] + , description pg_catalog.text + ) + where c.objtype operator(pg_catalog.=) 'table column' + group by c.objid + ) c on (c.objid operator(pg_catalog.=) k.oid) + ) x + ; + + -- distinct views + select pg_catalog.array_agg(objid) into _distinct_views + from pg_catalog.jsonb_to_recordset(_relevant_obj) x + ( objtype pg_catalog.text + , objid pg_catalog.oid + ) + where x.objtype in ('view', 'view column') + ; + + -- construct view contexts + select pg_catalog.string_agg(x.ctx, E'\n') + into _view_ctx + from + ( + select pg_catalog.format + ( E'\n/*\n# %I.%I\n%s\n%s\n*/\n%s\n' + , n.nspname + , k.relname + , vd.description + , c.cols + , pg_catalog.format(E'CREATE VIEW %I.%I AS\n%s\n', n.nspname, k.relname, pg_catalog.pg_get_viewdef(k.oid, true)) + ) as ctx + from pg_catalog.unnest(_distinct_views) v + inner join pg_catalog.pg_class k on (v operator(pg_catalog.=) k.oid) + inner join pg_catalog.pg_namespace n on (k.relnamespace operator(pg_catalog.=) n.oid) + left outer join pg_catalog.jsonb_to_recordset(_relevant_obj) vd + ( objtype pg_catalog.text + , objid pg_catalog.oid + , description pg_catalog.text + ) on (vd.objtype operator(pg_catalog.=) 'view' and vd.objid operator(pg_catalog.=) k.oid) + left outer join + ( + select + c.objid + , pg_catalog.string_agg + ( pg_catalog.format(E'## %s\n%s', c.objnames[3], c.description) + , E'\n' + ) as cols + from pg_catalog.jsonb_to_recordset(_relevant_obj) c + ( objtype pg_catalog.text + , objid pg_catalog.oid + , objsubid pg_catalog.int4 + , objnames pg_catalog.name[] + , description pg_catalog.text + ) + where c.objtype operator(pg_catalog.=) 'view column' + group by c.objid + ) c on (c.objid operator(pg_catalog.=) k.oid) + ) x + ; + + -- construct function contexts + select pg_catalog.string_agg(x.fn, E'\n') + into _func_ctx + from + ( + select pg_catalog.format + ( E'\n/*\n# %I.%I\n%s\n%s*/\n' + , f.objnames[1] + , f.objnames[2] + , f.description + , pg_catalog.pg_get_functiondef(f.objid) + ) as fn + from pg_catalog.jsonb_to_recordset(_relevant_obj) f + ( objtype pg_catalog.text + , objid pg_catalog.oid + , objnames pg_catalog.name[] + , description pg_catalog.text + ) + where f.objtype operator(pg_catalog.=) 'function' + ) x + ; + + -- find relevant sql examples + select pg_catalog.string_agg + ( pg_catalog.format + ( E'\n/*\n%s\n*/\n%s\n' + , r.description + , r.sql + ) + , E'\n\n' + ) into _relevant_sql + from ai._find_relevant_sql + ( _catalog_id + , _prompt_emb + ) r + ; + + -- construct overall prompt + select pg_catalog.concat_ws + ( E'\n' + , 'Consider the following context when responding.' + , 'Any relevant table, view, and functions descriptions and DDL definitions will appear in
<table>, <view>, and <function> tags respectively.'
+    , 'Any relevant example SQL statements will appear in <sql> tags.'
+    , _tbl_ctx
+    , _view_ctx
+    , _func_ctx
+    , _relevant_sql
+    , 'Respond to the following question with a SQL statement only. Only use syntax and functions that work with PostgreSQL.'
+    , 'Q: ' operator(pg_catalog.||) prompt
+    , 'A: '
+    ) into strict _prompt
+    ;
+
+    return _prompt;
+end
+$func$ language plpgsql stable security invoker
+set search_path to pg_catalog, pg_temp
+;
+
+

From 1479ac5cc492269d827093088e44829ac6e7e4bb Mon Sep 17 00:00:00 2001
From: John Pruitt
Date: Tue, 17 Dec 2024 11:19:57 -0600
Subject: [PATCH 26/27] feat: add a text_to_sql function

---
 .../sql/idempotent/905-text-to-sql.sql        | 217 +++++++++++++++++-
 .../tests/text_to_sql/prompt.expected         |  45 ++++
 .../tests/text_to_sql/test_text_to_sql.py     |  26 +++
 3 files changed, 286 insertions(+), 2 deletions(-)
 create mode 100644 projects/extension/tests/text_to_sql/prompt.expected

diff --git a/projects/extension/sql/idempotent/905-text-to-sql.sql b/projects/extension/sql/idempotent/905-text-to-sql.sql
index 7e76d38e..d7acb04b 100644
--- a/projects/extension/sql/idempotent/905-text-to-sql.sql
+++ b/projects/extension/sql/idempotent/905-text-to-sql.sql
@@ -95,6 +95,9 @@ set search_path to pg_catalog, pg_temp
 -- _text_to_sql_prompt
 create or replace function ai._text_to_sql_prompt
 ( prompt pg_catalog.text
+, "limit" pg_catalog.int8 default 5
+, objtypes pg_catalog.text[] default null
+, max_dist pg_catalog.float8 default null
 , catalog_name pg_catalog.text default 'default'
 ) returns pg_catalog.text
 as $func$
@@ -130,11 +133,14 @@ begin
     from ai._find_relevant_obj
     ( _catalog_id
     , _prompt_emb
+    , "limit"=>"limit"
+    , objtypes=>objtypes
+    , max_dist=>max_dist
     ) r
     ;
 
     -- distinct tables
-    select pg_catalog.array_agg(objid) into _distinct_tables
+    select pg_catalog.array_agg(distinct objid) into _distinct_tables
     from pg_catalog.jsonb_to_recordset(_relevant_obj) x
     ( objtype pg_catalog.text
     , objid pg_catalog.oid
@@ -185,7 +191,7 @@ begin
     ;
 
     -- distinct views
-    select pg_catalog.array_agg(objid) into _distinct_views
+    select pg_catalog.array_agg(distinct objid) into _distinct_views
     from pg_catalog.jsonb_to_recordset(_relevant_obj) x
     ( objtype pg_catalog.text
     , objid pg_catalog.oid
@@ -269,6 +275,8 @@ begin
     from ai._find_relevant_sql
     ( _catalog_id
     , _prompt_emb
+    , "limit"=>"limit"
+    , max_dist=>max_dist
     ) r
     ;
 
@@ -294,4 +302,209 @@ set search_path to pg_catalog, pg_temp
 ;
 
+-------------------------------------------------------------------------------
+-- text_to_sql_openai
+create or replace function ai.text_to_sql_openai
+( model pg_catalog.text
+, api_key pg_catalog.text default null
+, api_key_name pg_catalog.text default null
+, base_url pg_catalog.text default null
+, frequency_penalty pg_catalog.float8 default null
+, logit_bias pg_catalog.jsonb default null
+, logprobs pg_catalog.bool default null
+, top_logprobs pg_catalog.int4 default null
+, max_tokens pg_catalog.int4 default null
+, n pg_catalog.int4 default null
+, presence_penalty pg_catalog.float8 default null
+, seed pg_catalog.int4 default null
+, stop pg_catalog.text default null
+, temperature pg_catalog.float8 default null
+, top_p pg_catalog.float8 default null
+, openai_user pg_catalog.text default null
+) returns pg_catalog.jsonb
+as $func$
+    select json_object
+    ( 'provider': 'openai'
+    , 'model': model
+    , 'api_key': api_key
+    , 'api_key_name': api_key_name
+    , 'base_url': base_url
+    , 'frequency_penalty': frequency_penalty
+    , 
'logit_bias': logit_bias + , 'logprobs': logprobs + , 'top_logprobs': top_logprobs + , 'max_tokens': max_tokens + , 'n': n + , 'presence_penalty': presence_penalty + , 'seed': seed + , 'stop': stop + , 'temperature': temperature + , 'top_p': top_p + , 'openai_user': openai_user + absent on null + ) +$func$ language sql immutable security invoker +set search_path to pg_catalog, pg_temp +; + +------------------------------------------------------------------------------- +-- text_to_sql_ollama +create or replace function ai.text_to_sql_ollama +( model pg_catalog.text +, host pg_catalog.text default null +, keep_alive pg_catalog.text default null +, chat_options pg_catalog.jsonb default null +) returns pg_catalog.jsonb +as $func$ + select json_object + ( 'provider': 'ollama' + , 'model': model + , 'host': host + , 'keep_alive': keep_alive + , 'chat_options': chat_options + absent on null + ) +$func$ language sql immutable security invoker +set search_path to pg_catalog, pg_temp +; + +------------------------------------------------------------------------------- +-- text_to_sql_anthropic +create or replace function ai.text_to_sql_anthropic +( model text +, max_tokens int default 1024 +, api_key text default null +, api_key_name text default null +, base_url text default null +, timeout float8 default null +, max_retries int default null +, user_id text default null +, stop_sequences text[] default null +, temperature float8 default null +, top_k int default null +, top_p float8 default null +) returns pg_catalog.jsonb +as $func$ + select json_object + ( 'provider': 'anthropic' + , 'model': model + , 'max_tokens': max_tokens + , 'api_key': api_key + , 'api_key_name': api_key_name + , 'base_url': base_url + , 'timeout': timeout + , 'max_retries': max_retries + , 'user_id': user_id + , 'stop_sequences': stop_sequences + , 'temperature': temperature + , 'top_k': top_k + , 'top_p': top_p + absent on null + ) +$func$ language sql immutable security invoker +set search_path to pg_catalog, pg_temp +; + +------------------------------------------------------------------------------- +-- text_to_sql +create or replace function ai.text_to_sql +( prompt pg_catalog.text +, config pg_catalog.jsonb +, "limit" pg_catalog.int8 default 5 +, objtypes pg_catalog.text[] default null +, max_dist pg_catalog.float8 default null +, catalog_name pg_catalog.text default 'default' +) returns pg_catalog.text +as $func$ +declare + _system_prompt pg_catalog.text; + _user_prompt pg_catalog.text; + _response pg_catalog.jsonb; + _sql pg_catalog.text; +begin + _system_prompt = trim +($txt$ +You are an expert database developer and DBA specializing in PostgreSQL. +You will be provided with context about a database model and a question to be answered. +You respond with nothing but a SQL statement that addresses the question posed. +The SQL statement must be valid syntax for PostgreSQL. +SQL features and functions that are built-in to PostgreSQL may be used. 
+$txt$);
+
+    _user_prompt = ai._text_to_sql_prompt
+    ( prompt
+    , "limit"=>"limit"
+    , objtypes=>objtypes
+    , max_dist=>max_dist
+    , catalog_name=>catalog_name
+    );
+    raise log 'prompt: %', _user_prompt;
+
+    case config operator(pg_catalog.->>) 'provider'
+        when 'openai' then
+            _response = ai.openai_chat_complete
+            ( config operator(pg_catalog.->>) 'model'
+            , pg_catalog.jsonb_build_array
+              ( jsonb_build_object('role', 'system', 'content', _system_prompt)
+              , jsonb_build_object('role', 'user', 'content', _user_prompt)
+              )
+            , api_key=>config operator(pg_catalog.->>) 'api_key'
+            , api_key_name=>config operator(pg_catalog.->>) 'api_key_name'
+            , base_url=>config operator(pg_catalog.->>) 'base_url'
+            , frequency_penalty=>(config operator(pg_catalog.->>) 'frequency_penalty')::pg_catalog.float8
+            , logit_bias=>(config operator(pg_catalog.->>) 'logit_bias')::pg_catalog.jsonb
+            , logprobs=>(config operator(pg_catalog.->>) 'logprobs')::pg_catalog.bool
+            , top_logprobs=>(config operator(pg_catalog.->>) 'top_logprobs')::pg_catalog.int4
+            , max_tokens=>(config operator(pg_catalog.->>) 'max_tokens')::pg_catalog.int4
+            , n=>(config operator(pg_catalog.->>) 'n')::pg_catalog.int4
+            , presence_penalty=>(config operator(pg_catalog.->>) 'presence_penalty')::pg_catalog.float8
+            , seed=>(config operator(pg_catalog.->>) 'seed')::pg_catalog.int4
+            , stop=>(config operator(pg_catalog.->>) 'stop')
+            , temperature=>(config operator(pg_catalog.->>) 'temperature')::pg_catalog.float8
+            , top_p=>(config operator(pg_catalog.->>) 'top_p')::pg_catalog.float8
+            , openai_user=>(config operator(pg_catalog.->>) 'openai_user')
+            );
+            raise log 'response: %', _response;
+            _sql = pg_catalog.jsonb_extract_path_text(_response, 'choices', '0', 'message', 'content');
+        when 'ollama' then
+            _response = ai.ollama_chat_complete
+            ( config operator(pg_catalog.->>) 'model'
+            , pg_catalog.jsonb_build_array
+              ( jsonb_build_object('role', 'system', 'content', _system_prompt)
+              , jsonb_build_object('role', 'user', 'content', _user_prompt)
+              )
+            , host=>(config operator(pg_catalog.->>) 'host')
+            , keep_alive=>(config operator(pg_catalog.->>) 'keep_alive')
+            , chat_options=>(config operator(pg_catalog.->) 'chat_options')
+            );
+            raise log 'response: %', _response;
+            -- ollama responses put the completion at message.content rather
+            -- than in an OpenAI-style choices array
+            _sql = pg_catalog.jsonb_extract_path_text(_response, 'message', 'content');
+        when 'anthropic' then
+            _response = ai.anthropic_generate
+            ( config operator(pg_catalog.->>) 'model'
+            , pg_catalog.jsonb_build_array
+              ( jsonb_build_object('role', 'user', 'content', _user_prompt)
+              )
+            , system_prompt=>_system_prompt
+            , max_tokens=>(config operator(pg_catalog.->>) 'max_tokens')::pg_catalog.int4
+            , api_key=>(config operator(pg_catalog.->>) 'api_key')
+            , api_key_name=>(config operator(pg_catalog.->>) 'api_key_name')
+            , base_url=>(config operator(pg_catalog.->>) 'base_url')
+            , timeout=>(config operator(pg_catalog.->>) 'timeout')::pg_catalog.float8
+            , max_retries=>(config operator(pg_catalog.->>) 'max_retries')::pg_catalog.int4
+            , user_id=>(config operator(pg_catalog.->>) 'user_id')
+            , temperature=>(config operator(pg_catalog.->>) 'temperature')::pg_catalog.float8
+            , top_k=>(config operator(pg_catalog.->>) 'top_k')::pg_catalog.int4
+            , top_p=>(config operator(pg_catalog.->>) 'top_p')::pg_catalog.float8
+            );
+            raise log 'response: %', _response;
+            _sql = pg_catalog.jsonb_extract_path_text(_response, 'content', '0', 'text');
+        else
+            raise exception 'unsupported provider';
+    end case;
+    return _sql;
+end
+$func$ language plpgsql stable security invoker
+set
search_path to pg_catalog, pg_temp +; diff --git a/projects/extension/tests/text_to_sql/prompt.expected b/projects/extension/tests/text_to_sql/prompt.expected new file mode 100644 index 00000000..4a24effe --- /dev/null +++ b/projects/extension/tests/text_to_sql/prompt.expected @@ -0,0 +1,45 @@ +Consider the following context when responding. +Any relevant table, view, and functions descriptions and DDL definitions will appear in
 <tables>, <views>, and <functions> tags respectively.
+Any relevant example SQL statements will appear in <examples> tags.
+
+/*
+# public.bob
+this is a comment about the bob table
+## id
+this is a comment about the id column
+## foo
+this is a comment about the foo column
+*/
+CREATE TABLE public.bob
+( id integer NOT NULL
+, foo text NOT NULL
+, bar timestamp with time zone NOT NULL now()
+, PRIMARY KEY (id)
+);
+CREATE UNIQUE INDEX bob_pkey ON public.bob USING btree (id)
+
+
+/*
+# public.bobby
+
+## id
+this is a comment about the id column
+## foo
+this is a comment about the foo column
+*/
+CREATE VIEW public.bobby AS
+ SELECT id,
+    foo,
+    bar
+   FROM public.bob;
+
+
+
+/*
+a bogus query against the bobby view using the life function
+*/
+select * from bobby where id = life(id)
+
+Respond to the following question with a SQL statement only. Only use syntax and functions that work with PostgreSQL.
+Q: Construct a query that gives me the distinct foo where the corresponding ids are evenly divisible by life.
+A: 
\ No newline at end of file
diff --git a/projects/extension/tests/text_to_sql/test_text_to_sql.py b/projects/extension/tests/text_to_sql/test_text_to_sql.py
index cabe1be2..c0aec129 100644
--- a/projects/extension/tests/text_to_sql/test_text_to_sql.py
+++ b/projects/extension/tests/text_to_sql/test_text_to_sql.py
@@ -397,6 +397,32 @@ def test_text_to_sql() -> None:
         == "a bogus query against the bobby view using the life function"
     )
 
+    cur.execute(
+        """select ai._text_to_sql_prompt('Construct a query that gives me the distinct foo where the corresponding ids are evenly divisible by life.')"""
+    )
+    actual = cur.fetchone()[0]
+    # host_dir().joinpath("prompt.expected").write_text(actual)
+    expected = file_contents("prompt.expected")
+    assert actual == expected
+
+    anthropic_api_key = os.environ["ANTHROPIC_API_KEY"]
+    assert anthropic_api_key is not None
+    cur.execute(
+        "select set_config('ai.anthropic_api_key', %s, false) is not null",
+        (anthropic_api_key,),
+    )
+    cur.execute(
+        """
+        select ai.text_to_sql
+        ( 'Construct a query that gives me the distinct foo where the corresponding ids are evenly divisible by life.'
+        , ai.text_to_sql_anthropic('claude-3-5-sonnet-20240620')
+        )
+        """
+    )
+    actual = cur.fetchone()[0]
+    assert actual is not None
+    cur.execute(f"explain {actual}")  # make sure it's valid sql
+
     snapshot_catalog("text_to_sql_2")
     actual = file_contents("snapshot-catalog.actual")
     expected = file_contents("snapshot-catalog.expected")

From 5d92f19082dd2eada57e148f6ed53b98cf11ecf9 Mon Sep 17 00:00:00 2001
From: Sergio Moya <1083296+smoya@users.noreply.github.com>
Date: Wed, 18 Dec 2024 10:53:49 +0100
Subject: [PATCH 27/27] chore: split embedders in individual files (#315)

---
 .../vectorizer-add-a-embedding-integration.md |  10 +-
 projects/pgai/justfile                        |   4 +
 .../pgai/vectorizer/embedders/__init__.py     |   3 +
 .../pgai/pgai/vectorizer/embedders/ollama.py  | 158 +++++++
 .../pgai/pgai/vectorizer/embedders/openai.py  | 202 +++++++++
 .../pgai/vectorizer/embedders/voyageai.py     |  72 ++++
 projects/pgai/pgai/vectorizer/embeddings.py   | 396 +-----------------
 projects/pgai/pgai/vectorizer/vectorizer.py   |   3 +-
 8 files changed, 452 insertions(+), 396 deletions(-)
 create mode 100644 projects/pgai/pgai/vectorizer/embedders/__init__.py
 create mode 100644 projects/pgai/pgai/vectorizer/embedders/ollama.py
 create mode 100644 projects/pgai/pgai/vectorizer/embedders/openai.py
 create mode 100644 projects/pgai/pgai/vectorizer/embedders/voyageai.py

diff --git a/docs/vectorizer-add-a-embedding-integration.md b/docs/vectorizer-add-a-embedding-integration.md
index ae0d1d6e..703cc2d1 100644
--- a/docs/vectorizer-add-a-embedding-integration.md
+++ b/docs/vectorizer-add-a-embedding-integration.md
@@ -31,13 +31,17 @@ integration. Update the tests to account for the new function.
 
 The vectorizer worker reads the database's vectorizer configuration at runtime
 and turns it into a `pgai.vectorizer.Config`.
 
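To make the new layout concrete before reading the diff below: under the new structure, each integration lives in its own module under `pgai/vectorizer/embedders/`. The following is a minimal, hypothetical sketch only. The `acme` provider name, its endpoint, the response shape, and the use of `httpx` are all invented for illustration; the authoritative patterns are the real `ollama.py`, `openai.py`, and `voyageai.py` modules added later in this patch.

# pgai/vectorizer/embedders/acme.py -- hypothetical example, not part of this patch
from collections.abc import Sequence
from functools import cached_property
from typing import Literal

import httpx  # direct HTTP requests, since "acme" has no first-party client
from pydantic import BaseModel
from typing_extensions import override

from ..embeddings import (
    ApiKeyMixin,
    BatchApiCaller,
    Embedder,
    EmbeddingResponse,
    EmbeddingVector,
    StringDocument,
    Usage,
    logger,
)


class Acme(ApiKeyMixin, BaseModel, Embedder):
    """Fields mirror the vectorizer's jsonb configuration for this embedder."""

    implementation: Literal["acme"]
    model: str

    @override
    async def embed(self, documents: list[str]) -> Sequence[EmbeddingVector]:
        # batch the chunks and fan them out to the embedding API
        await logger.adebug(f"Chunks produced: {len(documents)}")
        return await self._batcher.batch_chunks_and_embed(documents)

    @cached_property
    def _batcher(self) -> BatchApiCaller[StringDocument]:
        return BatchApiCaller(self._max_chunks_per_batch(), self.call_embed_api)

    @override
    def _max_chunks_per_batch(self) -> int:
        return 128  # whatever batch limit the provider's API imposes

    async def call_embed_api(self, documents: list[str]) -> EmbeddingResponse:
        # the endpoint and response shape are invented for this sketch
        async with httpx.AsyncClient() as client:
            r = await client.post(
                "https://api.acme.example/v1/embeddings",
                headers={"Authorization": f"Bearer {self._api_key}"},
                json={"model": self.model, "input": documents},
            )
            r.raise_for_status()
        body = r.json()
        usage = Usage(
            prompt_tokens=body["usage"]["tokens"],
            total_tokens=body["usage"]["tokens"],
        )
        return EmbeddingResponse(embeddings=body["embeddings"], usage=usage)

The new class would then be re-exported from `embedders/__init__.py` (for this sketch, `from .acme import Acme as Acme`), mirroring the three imports the patch adds there.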
-To add a new integration, add a new embedding class with fields corresponding
-to the database's jsonb configuration to `pgai/vectorizer/embeddings.py`. See
+To add a new integration, add a new file containing the embedding class
+with fields corresponding to the database's jsonb configuration to the
+[embedders directory]. See
 the existing implementations for examples of how to do this. Implement the
 `Embedder` class' abstract methods. Use first-party python libraries for the
 integration, if available. If no first-party python libraries are available,
 use direct HTTP requests.
 
+Remember to add an import for your new class to the
+[embedders \_\_init\_\_.py].
+
 Add tests which perform end-to-end testing of the new integration. There are
 two options for handling API calls to the integration API:
 
@@ -49,6 +53,8 @@ used conservatively. We will determine on a case-by-case basis what level of
 testing we would like.
 
 [vcr.py]:https://vcrpy.readthedocs.io/en/latest/
+[embedders directory]:/projects/pgai/pgai/vectorizer/embedders
+[embedders \_\_init\_\_.py]:/projects/pgai/pgai/vectorizer/embedders/__init__.py
 
 ## Documentation
 
diff --git a/projects/pgai/justfile b/projects/pgai/justfile
index 94b91a55..34ed41ba 100644
--- a/projects/pgai/justfile
+++ b/projects/pgai/justfile
@@ -39,6 +39,10 @@ test:
 lint:
     @uv run ruff check ./
 
+# Run ruff linter checks and fix all auto-fixable issues
+lint-fix:
+    @uv run ruff check ./ --fix
+
 # Run pyright type checking
 type-check:
     @uv run pyright ./
diff --git a/projects/pgai/pgai/vectorizer/embedders/__init__.py b/projects/pgai/pgai/vectorizer/embedders/__init__.py
new file mode 100644
index 00000000..8eda5c1b
--- /dev/null
+++ b/projects/pgai/pgai/vectorizer/embedders/__init__.py
@@ -0,0 +1,3 @@
+from .ollama import Ollama as Ollama
+from .openai import OpenAI as OpenAI
+from .voyageai import VoyageAI as VoyageAI
diff --git a/projects/pgai/pgai/vectorizer/embedders/ollama.py b/projects/pgai/pgai/vectorizer/embedders/ollama.py
new file mode 100644
index 00000000..0c0b4288
--- /dev/null
+++ b/projects/pgai/pgai/vectorizer/embedders/ollama.py
@@ -0,0 +1,158 @@
+import os
+from collections.abc import Mapping, Sequence
+from functools import cached_property
+from typing import (
+    Any,
+    Literal,
+)
+
+import ollama
+from pydantic import BaseModel
+from typing_extensions import TypedDict, override
+
+from ..embeddings import (
+    BatchApiCaller,
+    Embedder,
+    EmbeddingResponse,
+    EmbeddingVector,
+    StringDocument,
+    Usage,
+    logger,
+)
+
+
+# Note: this is a re-declaration of ollama.Options, which we are forced to do
+# otherwise pydantic complains (ollama.Options subclasses typing.TypedDict):
+# pydantic.errors.PydanticUserError: Please use `typing_extensions.TypedDict` instead of `typing.TypedDict` on Python < 3.12.
# noqa +class OllamaOptions(TypedDict, total=False): + # load time options + numa: bool + num_ctx: int + num_batch: int + num_gpu: int + main_gpu: int + low_vram: bool + f16_kv: bool + logits_all: bool + vocab_only: bool + use_mmap: bool + use_mlock: bool + embedding_only: bool + num_thread: int + + # runtime options + num_keep: int + seed: int + num_predict: int + top_k: int + top_p: float + tfs_z: float + typical_p: float + repeat_last_n: int + temperature: float + repeat_penalty: float + presence_penalty: float + frequency_penalty: float + mirostat: int + mirostat_tau: float + mirostat_eta: float + penalize_newline: bool + stop: Sequence[str] + + +class Ollama(BaseModel, Embedder): + """ + Embedder that uses Ollama to embed documents into vector representations. + + Attributes: + implementation (Literal["ollama"]): The literal identifier for this + implementation. + model (str): The name of the Ollama model used for embeddings. + base_url (str): The base url used to access the Ollama API. + options (dict): Additional ollama-specific runtime options + keep_alive (str): How long to keep the model loaded after the request + """ + + implementation: Literal["ollama"] + model: str + base_url: str | None = None + options: OllamaOptions | None = None + keep_alive: str | None = None # this is only `str` because of the SQL API + + @override + async def embed(self, documents: list[str]) -> Sequence[EmbeddingVector]: + """ + Embeds a list of documents into vectors using Ollama's embeddings API. + + Args: + documents (list[str]): A list of documents to be embedded. + + Returns: + Sequence[EmbeddingVector | ChunkEmbeddingError]: The embeddings or + errors for each document. + """ + await logger.adebug(f"Chunks produced: {len(documents)}") + return await self._batcher.batch_chunks_and_embed(documents) + + @cached_property + def _batcher(self) -> BatchApiCaller[StringDocument]: + return BatchApiCaller(self._max_chunks_per_batch(), self.call_embed_api) + + @override + def _max_chunks_per_batch(self) -> int: + # Note: the chosen default is arbitrary - Ollama doesn't place a limit + return int( + os.getenv("PGAI_VECTORIZER_OLLAMA_MAX_CHUNKS_PER_BATCH", default="2048") + ) + + @override + async def setup(self): + client = ollama.AsyncClient(host=self.base_url) + try: + await client.show(self.model) + except ollama.ResponseError as e: + if f"model '{self.model}' not found" in e.error: + logger.warn( + f"pulling ollama model '{self.model}', this may take a while" + ) + await client.pull(self.model) + + async def call_embed_api(self, documents: str | list[str]) -> EmbeddingResponse: + response = await ollama.AsyncClient(host=self.base_url).embed( + model=self.model, + input=documents, + options=self.options, + keep_alive=self.keep_alive, + ) + usage = Usage( + prompt_tokens=response["prompt_eval_count"], + total_tokens=response["prompt_eval_count"], + ) + return EmbeddingResponse(embeddings=response["embeddings"], usage=usage) + + async def _model(self) -> Mapping[str, Any]: + """ + Gets the model details from the Ollama API + :return: + """ + return await ollama.AsyncClient(host=self.base_url).show(self.model) + + async def _context_length(self) -> int | None: + """ + Gets the context_length of the configured model, if available + """ + model = await self._model() + architecture = model["model_info"].get("general.architecture", None) + if architecture is None: + logger.warn(f"unable to determine architecture for model '{self.model}'") + return None + context_key = f"{architecture}.context_length" + # see 
https://github.com/ollama/ollama/blob/712d63c3f06f297e22b1ae32678349187dccd2e4/llm/ggml.go#L116-L118 # noqa + model_context_length = model["model_info"][context_key] + # the context window can be configured, so pull the value from the config + num_ctx = ( + float("inf") + if self.options is None + else self.options.get("num_ctx", float("inf")) + ) + return min(model_context_length, num_ctx) diff --git a/projects/pgai/pgai/vectorizer/embedders/openai.py b/projects/pgai/pgai/vectorizer/embedders/openai.py new file mode 100644 index 00000000..7c4e9aa5 --- /dev/null +++ b/projects/pgai/pgai/vectorizer/embedders/openai.py @@ -0,0 +1,202 @@ +import re +from collections.abc import Sequence +from functools import cached_property +from typing import Any, Literal + +import openai +import tiktoken +from openai import resources +from pydantic import BaseModel +from typing_extensions import override + +from ..embeddings import ( + ApiKeyMixin, + BatchApiCaller, + ChunkEmbeddingError, + Embedder, + EmbeddingResponse, + EmbeddingVector, + StringDocument, + TokenDocument, + Usage, + logger, +) + +TOKEN_CONTEXT_LENGTH_ERROR = "chunk exceeds model context length" + +openai_token_length_regex = re.compile( + r"This model's maximum context length is (\d+) tokens" +) + + +class OpenAI(ApiKeyMixin, BaseModel, Embedder): + """ + Embedder that uses OpenAI's API to embed documents into vector representations. + + Attributes: + implementation (Literal["openai"]): The literal identifier for this + implementation. + model (str): The name of the OpenAI model used for embeddings. + dimensions (int | None): Optional dimensions for the embeddings. + user (str | None): Optional user identifier for OpenAI API usage. + """ + + implementation: Literal["openai"] + model: str + dimensions: int | None = None + user: str | None = None + + @cached_property + def _openai_dimensions(self) -> int | openai.NotGiven: + if self.model == "text-embedding-ada-002": + if self.dimensions != 1536: + raise ValueError("dimensions must be 1536 for text-embedding-ada-002") + return openai.NOT_GIVEN + return self.dimensions if self.dimensions is not None else openai.NOT_GIVEN + + @cached_property + def _openai_user(self) -> str | openai.NotGiven: + return self.user if self.user is not None else openai.NOT_GIVEN + + @cached_property + def _embedder(self) -> resources.AsyncEmbeddings: + return openai.AsyncOpenAI(api_key=self._api_key, max_retries=3).embeddings + + @override + def _max_chunks_per_batch(self) -> int: + return 2048 + + async def call_embed_api(self, documents: list[TokenDocument]) -> EmbeddingResponse: + response = await self._embedder.create( + input=documents, + model=self.model, + dimensions=self._openai_dimensions, + user=self._openai_user, + encoding_format="float", + ) + usage = Usage( + prompt_tokens=response.usage.prompt_tokens, + total_tokens=response.usage.total_tokens, + ) + return EmbeddingResponse( + embeddings=[r.embedding for r in response.data], usage=usage + ) + + @cached_property + def _batcher(self) -> BatchApiCaller[TokenDocument]: + return BatchApiCaller(self._max_chunks_per_batch(), self.call_embed_api) + + @override + async def embed( + self, documents: list[StringDocument] + ) -> Sequence[EmbeddingVector | ChunkEmbeddingError]: + """ + Embeds a list of documents into vectors using OpenAI's embeddings API. + The documents are first encoded into tokens before being embedded. 
+
+        If a request to generate embeddings fails because one or more chunks
+        exceed the model's token limit, the offending chunks are filtered out
+        and the request is retried. The returned result will contain a
+        ChunkEmbeddingError in place of an EmbeddingVector for the chunks that
+        exceeded the model's token limit.
+
+        Args:
+            documents (list[str]): A list of documents to be embedded.
+
+        Returns:
+            Sequence[EmbeddingVector | ChunkEmbeddingError]: The embeddings or
+            errors for each document.
+        """
+        encoded_documents = await self._encode(documents)
+        await logger.adebug(f"Chunks produced: {len(documents)}")
+        try:
+            return await self._batcher.batch_chunks_and_embed(encoded_documents)
+        except openai.BadRequestError as e:
+            body = e.body
+            if not isinstance(body, dict):
+                raise e
+            if "message" not in body:
+                raise e
+            msg: Any = body["message"]
+            if not isinstance(msg, str):
+                raise e
+
+            m = openai_token_length_regex.match(msg)
+            if not m:
+                raise e
+            model_token_length = int(m.group(1))
+            return await self._filter_by_length_and_embed(
+                model_token_length, encoded_documents
+            )
+
+    async def _filter_by_length_and_embed(
+        self, model_token_length: int, encoded_documents: list[list[int]]
+    ) -> Sequence[EmbeddingVector | ChunkEmbeddingError]:
+        """
+        Filters out documents that exceed the model's token limit and embeds
+        the valid ones. Chunks that exceed the limit are replaced in the
+        response with a ChunkEmbeddingError instead of an EmbeddingVector.
+
+        Args:
+            model_token_length (int): The token length limit for the model.
+            encoded_documents (list[list[int]]): A list of encoded documents.
+
+        Returns:
+            Sequence[EmbeddingVector | ChunkEmbeddingError]: EmbeddingVector
+            for the chunks that were successfully embedded, ChunkEmbeddingError
+            for the chunks that exceeded the model's token limit.
+        """
+        valid_documents: list[list[int]] = []
+        invalid_documents_idxs: list[int] = []
+        for i, doc in enumerate(encoded_documents):
+            if len(doc) > model_token_length:
+                invalid_documents_idxs.append(i)
+            else:
+                valid_documents.append(doc)
+
+        assert len(valid_documents) + len(invalid_documents_idxs) == len(
+            encoded_documents
+        )
+
+        response = await self._batcher.batch_chunks_and_embed(valid_documents)
+
+        embeddings: list[ChunkEmbeddingError | list[float]] = []
+        for i in range(len(encoded_documents)):
+            if i in invalid_documents_idxs:
+                embedding = ChunkEmbeddingError(
+                    error=TOKEN_CONTEXT_LENGTH_ERROR,
+                    error_details=f"chunk exceeds the {self.model} model context length of {model_token_length} tokens",  # noqa
+                )
+            else:
+                embedding = response.pop(0)
+            embeddings.append(embedding)
+
+        return embeddings
+
+    async def _encode(self, documents: list[str]) -> list[list[int]]:
+        """
+        Encodes a list of documents into a list of tokenized documents, using
+        the corresponding encoder for the model.
+
+        Args:
+            documents (list[str]): A list of text documents to be tokenized.
+
+        Returns:
+            list[list[int]]: A list of tokenized documents.
+        """
+        total_tokens = 0
+        encoded_documents: list[list[int]] = []
+        for document in documents:
+            if self.model.endswith("001"):
+                # See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
+                # replace newlines, which can negatively affect performance.
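+                # note: this newline workaround is (reportedly) only needed
+                # for the first-generation "-001" models; newer embedding
+                # models tolerate newlines, hence the suffix check above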
+                document = document.replace("\n", " ")
+            tokenized = self._encoder.encode_ordinary(document)
+            total_tokens += len(tokenized)
+            encoded_documents.append(tokenized)
+        await logger.adebug(f"Total tokens in batch: {total_tokens}")
+        return encoded_documents
+
+    @cached_property
+    def _encoder(self) -> tiktoken.Encoding:
+        return tiktoken.encoding_for_model(self.model)
diff --git a/projects/pgai/pgai/vectorizer/embedders/voyageai.py b/projects/pgai/pgai/vectorizer/embedders/voyageai.py
new file mode 100644
index 00000000..17db8323
--- /dev/null
+++ b/projects/pgai/pgai/vectorizer/embedders/voyageai.py
@@ -0,0 +1,72 @@
+from collections.abc import Sequence
+from functools import cached_property
+from typing import Literal
+
+import voyageai
+import voyageai.error
+from pydantic import BaseModel
+from typing_extensions import override
+
+from ..embeddings import (
+    ApiKeyMixin,
+    BatchApiCaller,
+    Embedder,
+    EmbeddingResponse,
+    EmbeddingVector,
+    StringDocument,
+    Usage,
+    logger,
+)
+
+
+class VoyageAI(ApiKeyMixin, BaseModel, Embedder):
+    """
+    Embedder that uses Voyage AI to embed documents into vector representations.
+
+    Attributes:
+        implementation (Literal["voyageai"]): The literal identifier for this
+            implementation.
+        model (str): The name of the Voyage AI model used for embeddings.
+        input_type ("document" | "query" | None): Set the input type of the
+            items to be embedded. If set, improves retrieval quality.
+
+    """
+
+    implementation: Literal["voyageai"]
+    model: str
+    input_type: Literal["document"] | Literal["query"] | None = None
+
+    @override
+    async def embed(self, documents: list[str]) -> Sequence[EmbeddingVector]:
+        """
+        Embeds a list of documents into vectors using the VoyageAI embeddings API.
+
+        Args:
+            documents (list[str]): A list of documents to be embedded.
+
+        Returns:
+            Sequence[EmbeddingVector | ChunkEmbeddingError]: The embeddings or
+            errors for each document.
+ """ + await logger.adebug(f"Chunks produced: {len(documents)}") + return await self._batcher.batch_chunks_and_embed(documents) + + @cached_property + def _batcher(self) -> BatchApiCaller[StringDocument]: + return BatchApiCaller(self._max_chunks_per_batch(), self.call_embed_api) + + @override + def _max_chunks_per_batch(self) -> int: + return 128 + + async def call_embed_api(self, documents: list[str]) -> EmbeddingResponse: + response = await voyageai.AsyncClient(api_key=self._api_key).embed( + documents, + model=self.model, + input_type=self.input_type, + ) + usage = Usage( + prompt_tokens=response.total_tokens, + total_tokens=response.total_tokens, + ) + return EmbeddingResponse(embeddings=response.embeddings, usage=usage) diff --git a/projects/pgai/pgai/vectorizer/embeddings.py b/projects/pgai/pgai/vectorizer/embeddings.py index c7ba5aee..a40f91f0 100644 --- a/projects/pgai/pgai/vectorizer/embeddings.py +++ b/projects/pgai/pgai/vectorizer/embeddings.py @@ -1,37 +1,12 @@ import math -import os -import re import time from abc import ABC, abstractmethod -from collections.abc import Awaitable, Callable, Mapping, Sequence +from collections.abc import Awaitable, Callable, Sequence from dataclasses import dataclass -from functools import cached_property -from typing import ( - Any, - Generic, - Literal, - TypeAlias, - TypeVar, -) - -import ollama -import openai +from typing import Generic, TypeAlias, TypeVar + import structlog -import tiktoken -import voyageai -import voyageai.error from ddtrace import tracer -from openai import resources -from pydantic import BaseModel -from typing_extensions import TypedDict, override - -MAX_RETRIES = 3 - -TOKEN_CONTEXT_LENGTH_ERROR = "chunk exceeds model context length" - -openai_token_length_regex = re.compile( - r"This model's maximum context length is (\d+) tokens" -) logger = structlog.get_logger() @@ -301,368 +276,3 @@ async def print_stats(self): total_chunks=self.total_chunks, chunks_per_second=self.chunks_per_second(), ) - - -class OpenAI(ApiKeyMixin, BaseModel, Embedder): - """ - Embedder that uses OpenAI's API to embed documents into vector representations. - - Attributes: - implementation (Literal["openai"]): The literal identifier for this - implementation. - model (str): The name of the OpenAI model used for embeddings. - dimensions (int | None): Optional dimensions for the embeddings. - user (str | None): Optional user identifier for OpenAI API usage. 
- """ - - implementation: Literal["openai"] - model: str - dimensions: int | None = None - user: str | None = None - - @cached_property - def _openai_dimensions(self) -> int | openai.NotGiven: - if self.model == "text-embedding-ada-002": - if self.dimensions != 1536: - raise ValueError("dimensions must be 1536 for text-embedding-ada-002") - return openai.NOT_GIVEN - return self.dimensions if self.dimensions is not None else openai.NOT_GIVEN - - @cached_property - def _openai_user(self) -> str | openai.NotGiven: - return self.user if self.user is not None else openai.NOT_GIVEN - - @cached_property - def _embedder(self) -> resources.AsyncEmbeddings: - return openai.AsyncOpenAI( - api_key=self._api_key, max_retries=MAX_RETRIES - ).embeddings - - @override - def _max_chunks_per_batch(self) -> int: - return 2048 - - async def call_embed_api(self, documents: list[TokenDocument]) -> EmbeddingResponse: - response = await self._embedder.create( - input=documents, - model=self.model, - dimensions=self._openai_dimensions, - user=self._openai_user, - encoding_format="float", - ) - usage = Usage( - prompt_tokens=response.usage.prompt_tokens, - total_tokens=response.usage.total_tokens, - ) - return EmbeddingResponse( - embeddings=[r.embedding for r in response.data], usage=usage - ) - - @cached_property - def _batcher(self) -> BatchApiCaller[TokenDocument]: - return BatchApiCaller(self._max_chunks_per_batch(), self.call_embed_api) - - @override - async def embed( - self, documents: list[StringDocument] - ) -> Sequence[EmbeddingVector | ChunkEmbeddingError]: - """ - Embeds a list of documents into vectors using OpenAI's embeddings API. - The documents are first encoded into tokens before being embedded. - - If a request to generate embeddings fails because one or more chunks - exceed the model's token limit, the offending chunks are filtered out - and the request is retried. The returned result will contain a - ChunkEmbeddingError in place of an EmbeddingVector for the chunks that - exceeded the model's token limit. - - Args: - documents (list[str]): A list of documents to be embedded. - - Returns: - Sequence[EmbeddingVector | ChunkEmbeddingError]: The embeddings or - errors for each document. - """ - encoded_documents = await self._encode(documents) - await logger.adebug(f"Chunks produced: {len(documents)}") - try: - return await self._batcher.batch_chunks_and_embed(encoded_documents) - except openai.BadRequestError as e: - body = e.body - if not isinstance(body, dict): - raise e - if "message" not in body: - raise e - msg: Any = body["message"] - if not isinstance(msg, str): - raise e - - m = openai_token_length_regex.match(msg) - if not m: - raise e - model_token_length = int(m.group(1)) - return await self._filter_by_length_and_embed( - model_token_length, encoded_documents - ) - - async def _filter_by_length_and_embed( - self, model_token_length: int, encoded_documents: list[list[int]] - ) -> Sequence[EmbeddingVector | ChunkEmbeddingError]: - """ - Filters out documents that exceed the model's token limit and embeds - the valid ones. Chunks that exceed the limit are replaced in the - response with an ChunkEmbeddingError instead of an EmbeddingVector. - - Args: - model_token_length (int): The token length limit for the model. - encoded_documents (list[list[int]]): A list of encoded documents. - - Returns: - Sequence[EmbeddingVector | ChunkEmbeddingError]: EmbeddingVector - for the chunks that were successfully embedded, ChunkEmbeddingError - for the chunks that exceeded the model's token limit. 
- """ - valid_documents: list[list[int]] = [] - invalid_documents_idxs: list[int] = [] - for i, doc in enumerate(encoded_documents): - if len(doc) > model_token_length: - invalid_documents_idxs.append(i) - else: - valid_documents.append(doc) - - assert len(valid_documents) + len(invalid_documents_idxs) == len( - encoded_documents - ) - - response = await self._batcher.batch_chunks_and_embed(valid_documents) - - embeddings: list[ChunkEmbeddingError | list[float]] = [] - for i in range(len(encoded_documents)): - if i in invalid_documents_idxs: - embedding = ChunkEmbeddingError( - error=TOKEN_CONTEXT_LENGTH_ERROR, - error_details=f"chunk exceeds the {self.model} model context length of {model_token_length} tokens", # noqa - ) - else: - embedding = response.pop(0) - embeddings.append(embedding) - - return embeddings - - async def _encode(self, documents: list[str]) -> list[list[int]]: - """ - Encodes a list of documents into a list of tokenized documents, using - the corresponding encoder for the model. - - Args: - documents (list[str]): A list of text documents to be tokenized. - - Returns: - list[list[int]]: A list of tokenized documents. - """ - total_tokens = 0 - encoded_documents: list[list[int]] = [] - for document in documents: - if self.model.endswith("001"): - # See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500 - # replace newlines, which can negatively affect performance. - document = document.replace("\n", " ") - tokenized = self._encoder.encode_ordinary(document) - total_tokens += len(tokenized) - encoded_documents.append(tokenized) - await logger.adebug(f"Total tokens in batch: {total_tokens}") - return encoded_documents - - @cached_property - def _encoder(self) -> tiktoken.Encoding: - return tiktoken.encoding_for_model(self.model) - - -# Note: this is a re-declaration of ollama.Options, which we are forced to do -# otherwise pydantic complains (ollama.Options subclasses typing.TypedDict): -# pydantic.errors.PydanticUserError: Please use `typing_extensions.TypedDict` instead of `typing.TypedDict` on Python < 3.12. # noqa -class OllamaOptions(TypedDict, total=False): - # load time options - numa: bool - num_ctx: int - num_batch: int - num_gpu: int - main_gpu: int - low_vram: bool - f16_kv: bool - logits_all: bool - vocab_only: bool - use_mmap: bool - use_mlock: bool - embedding_only: bool - num_thread: int - - # runtime options - num_keep: int - seed: int - num_predict: int - top_k: int - top_p: float - tfs_z: float - typical_p: float - repeat_last_n: int - temperature: float - repeat_penalty: float - presence_penalty: float - frequency_penalty: float - mirostat: int - mirostat_tau: float - mirostat_eta: float - penalize_newline: bool - stop: Sequence[str] - - -class Ollama(BaseModel, Embedder): - """ - Embedder that uses Ollama to embed documents into vector representations. - - Attributes: - implementation (Literal["ollama"]): The literal identifier for this - implementation. - model (str): The name of the Ollama model used for embeddings. - base_url (str): The base url used to access the Ollama API. 
- options (dict): Additional ollama-specific runtime options - keep_alive (str): How long to keep the model loaded after the request - """ - - implementation: Literal["ollama"] - model: str - base_url: str | None = None - options: OllamaOptions | None = None - keep_alive: str | None = None # this is only `str` because of the SQL API - - @override - async def embed(self, documents: list[str]) -> Sequence[EmbeddingVector]: - """ - Embeds a list of documents into vectors using Ollama's embeddings API. - - Args: - documents (list[str]): A list of documents to be embedded. - - Returns: - Sequence[EmbeddingVector | ChunkEmbeddingError]: The embeddings or - errors for each document. - """ - await logger.adebug(f"Chunks produced: {len(documents)}") - return await self._batcher.batch_chunks_and_embed(documents) - - @cached_property - def _batcher(self) -> BatchApiCaller[StringDocument]: - return BatchApiCaller(self._max_chunks_per_batch(), self.call_embed_api) - - @override - def _max_chunks_per_batch(self) -> int: - # Note: the chosen default is arbitrary - Ollama doesn't place a limit - return int( - os.getenv("PGAI_VECTORIZER_OLLAMA_MAX_CHUNKS_PER_BATCH", default="2048") - ) - - @override - async def setup(self): - client = ollama.AsyncClient(host=self.base_url) - try: - await client.show(self.model) - except ollama.ResponseError as e: - if f"model '{self.model}' not found" in e.error: - logger.warn( - f"pulling ollama model '{self.model}', this may take a while" - ) - await client.pull(self.model) - - async def call_embed_api(self, documents: str | list[str]) -> EmbeddingResponse: - response = await ollama.AsyncClient(host=self.base_url).embed( - model=self.model, - input=documents, - options=self.options, - keep_alive=self.keep_alive, - ) - usage = Usage( - prompt_tokens=response["prompt_eval_count"], - total_tokens=response["prompt_eval_count"], - ) - return EmbeddingResponse(embeddings=response["embeddings"], usage=usage) - - async def _model(self) -> Mapping[str, Any]: - """ - Gets the model details from the Ollama API - :return: - """ - return await ollama.AsyncClient(host=self.base_url).show(self.model) - - async def _context_length(self) -> int | None: - """ - Gets the context_length of the configured model, if available - """ - model = await self._model() - architecture = model["model_info"].get("general.architecture", None) - if architecture is None: - logger.warn(f"unable to determine architecture for model '{self.model}'") - return None - context_key = f"{architecture}.context_length" - # see https://github.com/ollama/ollama/blob/712d63c3f06f297e22b1ae32678349187dccd2e4/llm/ggml.go#L116-L118 # noqa - model_context_length = model["model_info"][context_key] - # the context window can be configured, so pull the value from the config - num_ctx = ( - float("inf") - if self.options is None - else self.options.get("num_ctx", float("inf")) - ) - return min(model_context_length, num_ctx) - - -class VoyageAI(ApiKeyMixin, BaseModel, Embedder): - """ - Embedder that uses Voyage AI to embed documents into vector representations. - - Attributes: - implementation (Literal["voyageai"]): The literal identifier for this - implementation. - model (str): The name of the Voyage AU model used for embeddings. - input_type ("document" | "query" | None): Set the input type of the - items to be embedded. If set, improves retrieval quality. 
- - """ - - implementation: Literal["voyageai"] - model: str - input_type: Literal["document"] | Literal["query"] | None = None - - @override - async def embed(self, documents: list[str]) -> Sequence[EmbeddingVector]: - """ - Embeds a list of documents into vectors using the VoyageAI embeddings API. - - Args: - documents (list[str]): A list of documents to be embedded. - - Returns: - Sequence[EmbeddingVector | ChunkEmbeddingError]: The embeddings or - errors for each document. - """ - await logger.adebug(f"Chunks produced: {len(documents)}") - return await self._batcher.batch_chunks_and_embed(documents) - - @cached_property - def _batcher(self) -> BatchApiCaller[StringDocument]: - return BatchApiCaller(self._max_chunks_per_batch(), self.call_embed_api) - - @override - def _max_chunks_per_batch(self) -> int: - return 128 - - async def call_embed_api(self, documents: list[str]) -> EmbeddingResponse: - response = await voyageai.AsyncClient(api_key=self._api_key).embed( - documents, - model=self.model, - input_type=self.input_type, - ) - usage = Usage( - prompt_tokens=response.total_tokens, - total_tokens=response.total_tokens, - ) - return EmbeddingResponse(embeddings=response.embeddings, usage=usage) diff --git a/projects/pgai/pgai/vectorizer/vectorizer.py b/projects/pgai/pgai/vectorizer/vectorizer.py index d7ea0e5e..d57c249f 100644 --- a/projects/pgai/pgai/vectorizer/vectorizer.py +++ b/projects/pgai/pgai/vectorizer/vectorizer.py @@ -22,7 +22,8 @@ LangChainCharacterTextSplitter, LangChainRecursiveCharacterTextSplitter, ) -from .embeddings import ChunkEmbeddingError, Ollama, OpenAI, VoyageAI +from .embedders import Ollama, OpenAI, VoyageAI +from .embeddings import ChunkEmbeddingError from .formatting import ChunkValue, PythonTemplate from .processing import ProcessingDefault
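As a closing sanity check on the split, the following hypothetical snippet (it assumes the pgai package from this branch is installed, and the model name is only an example) shows that the embedder classes are re-exported from `pgai.vectorizer.embedders` and still build from the vectorizer's jsonb configuration exactly as before:

# hypothetical smoke test for the refactor; run in an environment where
# this branch of the pgai package is installed
from pgai.vectorizer.embedders import Ollama

# build an embedder the way the vectorizer worker would, from the
# vectorizer's jsonb configuration (the model name is just an example)
config = {"implementation": "ollama", "model": "nomic-embed-text"}
embedder = Ollama(**config)

assert embedder.implementation == "ollama"
assert embedder.keep_alive is None  # optional fields still default to None
print(type(embedder).__mro__)  # BaseModel and Embedder, same as before the split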