Skip to content

Commit

Permalink
feat: create batch embedding tables in extension
Browse files Browse the repository at this point in the history
  • Loading branch information
kolaente committed Dec 16, 2024
1 parent f7f6d13 commit d49c352
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 30 deletions.
34 changes: 32 additions & 2 deletions projects/extension/sql/idempotent/013-vectorizer-api.sql
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@


-------------------------------------------------------------------------------
-- execute_vectorizer
create or replace function ai.execute_vectorizer(vectorizer_id pg_catalog.int4) returns void
Expand Down Expand Up @@ -31,6 +29,9 @@ create or replace function ai.create_vectorizer
, queue_table pg_catalog.name default null
, grant_to pg_catalog.name[] default ai.grant_to()
, enqueue_existing pg_catalog.bool default true
, embedding_batch_schema pg_catalog.name default null
, embedding_batch_table pg_catalog.name default null
, embedding_batch_chunks_table pg_catalog.name default null
) returns pg_catalog.int4
as $func$
declare
Expand All @@ -44,6 +45,7 @@ declare
_vectorizer_id pg_catalog.int4;
_sql pg_catalog.text;
_job_id pg_catalog.int8;
_implementation pg_catalog.text;
begin
-- make sure all the roles listed in grant_to exist
if grant_to is not null then
Expand Down Expand Up @@ -225,6 +227,31 @@ begin
scheduling = pg_catalog.jsonb_insert(scheduling, array['job_id'], pg_catalog.to_jsonb(_job_id));
end if;

-- default the batch table names: schema 'ai', names derived from the vectorizer id
embedding_batch_schema = coalesce(embedding_batch_schema, 'ai');
embedding_batch_table = coalesce(embedding_batch_table, pg_catalog.concat('_vectorizer_embedding_batches_', _vectorizer_id));
embedding_batch_chunks_table = coalesce(embedding_batch_chunks_table, pg_catalog.concat('_vectorizer_embedding_batch_chunks_', _vectorizer_id));

-- create batch embedding tables (only the openai implementation uses the batch api)
-- note: OPERATOR() takes only the operator; the right operand follows it
select (embedding operator(pg_catalog.->>) 'implementation')::text into _implementation;
if _implementation = 'openai' then
    -- make sure embedding batch table name is available
    if pg_catalog.to_regclass(pg_catalog.format('%I.%I', embedding_batch_schema, embedding_batch_table)) is not null then
        -- report the batch table name (not the queue table) so the hint matches the parameter to change
        raise exception 'an object named %.% already exists. specify an alternate embedding_batch_table explicitly', embedding_batch_schema, embedding_batch_table;
    end if;

    -- make sure embedding batch chunks table name is available
    if pg_catalog.to_regclass(pg_catalog.format('%I.%I', embedding_batch_schema, embedding_batch_chunks_table)) is not null then
        raise exception 'an object named %.% already exists. specify an alternate embedding_batch_chunks_table explicitly', embedding_batch_schema, embedding_batch_chunks_table;
    end if;

    perform ai._vectorizer_create_embedding_batches_table
    (embedding_batch_schema
    , embedding_batch_table
    , embedding_batch_chunks_table
    , grant_to
    );
end if;

insert into ai.vectorizer
( id
, source_schema
Expand Down Expand Up @@ -259,6 +286,9 @@ begin
, 'formatting', formatting
, 'scheduling', scheduling
, 'processing', processing
, 'embedding_batch_schema', embedding_batch_schema
, 'embedding_batch_table', embedding_batch_table
, 'embedding_batch_chunks_table', embedding_batch_chunks_table
)
);

Expand Down
96 changes: 96 additions & 0 deletions projects/extension/sql/idempotent/016-openai-batch-api.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
-------------------------------------------------------------------------------
-- _vectorizer_create_embedding_batches_table
-- creates the bookkeeping tables for openai batch-api embedding requests:
-- one row per submitted batch, plus one row per chunk belonging to a batch.
-- optionally grants usage on the schema and dml on both tables to the roles
-- listed in grant_to.
create or replace function ai._vectorizer_create_embedding_batches_table
( embedding_batch_schema name
, embedding_batch_table name
, embedding_batch_chunks_table name
, grant_to name[]
) returns void as
$func$
declare
    _sql text;
begin
    -- create the batches table
    select pg_catalog.format
    ( $sql$create table %I.%I(
      openai_batch_id VARCHAR(255) PRIMARY KEY,
      input_file_id VARCHAR(255) NOT NULL,
      output_file_id VARCHAR(255),
      status VARCHAR(255) NOT NULL,
      errors JSONB,
      created_at TIMESTAMP(0) NOT NULL DEFAULT NOW(),
      expires_at TIMESTAMP(0),
      completed_at TIMESTAMP(0),
      failed_at TIMESTAMP(0)
    )$sql$
    , embedding_batch_schema
    , embedding_batch_table
    ) into strict _sql
    ;
    execute _sql;

    -- index the status column (batches are polled by status)
    select pg_catalog.format
    ( $sql$create index on %I.%I (status)$sql$
    , embedding_batch_schema, embedding_batch_table
    ) into strict _sql
    ;
    execute _sql;

    -- create the batch chunks table
    select pg_catalog.format
    ( $sql$create table %I.%I(
      id VARCHAR(255) PRIMARY KEY,
      embedding_batch_id VARCHAR(255) REFERENCES %I.%I (openai_batch_id),
      text TEXT
    )$sql$
    , embedding_batch_schema
    , embedding_batch_chunks_table
    , embedding_batch_schema
    , embedding_batch_table
    ) into strict _sql
    ;
    execute _sql;

    -- index the foreign key column: chunks are fetched per batch, and the
    -- index also spares a seq scan when a parent batch row is deleted
    select pg_catalog.format
    ( $sql$create index on %I.%I (embedding_batch_id)$sql$
    , embedding_batch_schema, embedding_batch_chunks_table
    ) into strict _sql
    ;
    execute _sql;

    if grant_to is not null then
        -- grant usage on the embedding batch schema to grant_to roles
        select pg_catalog.format
        ( $sql$grant usage on schema %I to %s$sql$
        , embedding_batch_schema
        , (
            select pg_catalog.string_agg(pg_catalog.quote_ident(x), ', ')
            from pg_catalog.unnest(grant_to) x
        )
        ) into strict _sql;
        execute _sql;

        -- grant select, insert, update, delete on batches table to grant_to roles
        select pg_catalog.format
        ( $sql$grant select, insert, update, delete on %I.%I to %s$sql$
        , embedding_batch_schema
        , embedding_batch_table
        , (
            select pg_catalog.string_agg(pg_catalog.quote_ident(x), ', ')
            from pg_catalog.unnest(grant_to) x
        )
        ) into strict _sql;
        execute _sql;

        -- grant select, insert, update, delete on batch chunks table to grant_to roles
        select pg_catalog.format
        ( $sql$grant select, insert, update, delete on %I.%I to %s$sql$
        , embedding_batch_schema
        , embedding_batch_chunks_table
        , (
            select pg_catalog.string_agg(pg_catalog.quote_ident(x), ', ')
            from pg_catalog.unnest(grant_to) x
        )
        ) into strict _sql;
        execute _sql;
    end if;
end;
$func$
language plpgsql volatile security invoker
set search_path to pg_catalog, pg_temp
;

28 changes: 0 additions & 28 deletions projects/pgai/pgai/vectorizer/vectorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,34 +478,6 @@ async def run(self) -> int:
res += items_processed
loops += 1

async def _create_batch_table(self, conn: AsyncConnection):
    """Create the OpenAI batch bookkeeping tables if they do not exist.

    Creates ``ai.embedding_batches`` (one row per submitted OpenAI batch),
    an index on its ``status`` column, and ``ai.embedding_batch_chunks``
    (one row per chunk included in a batch, FK'd to its batch). Every
    statement uses IF NOT EXISTS, so repeated calls are no-ops.

    Args:
        conn: open async database connection the DDL is executed on.

    Returns:
        The result of the final ``execute`` call (the chunks-table DDL).
    """
    # TODO this does not feel like the way to go, is there a way to do these kind of migrations properly?
    await conn.execute("""
    CREATE TABLE IF NOT EXISTS ai.embedding_batches
    (
        openai_batch_id VARCHAR(255) PRIMARY KEY,
        input_file_id VARCHAR(255) NOT NULL,
        output_file_id VARCHAR(255),
        status VARCHAR(255) NOT NULL,
        errors JSONB,
        created_at TIMESTAMP(0) NOT NULL DEFAULT NOW(),
        expires_at TIMESTAMP(0),
        completed_at TIMESTAMP(0),
        failed_at TIMESTAMP(0)
    );
    """)
    # index supports polling batches by status
    await conn.execute("""
    CREATE INDEX IF NOT EXISTS embedding_batches_status_index ON ai.embedding_batches (status);
    """)
    return await conn.execute("""
    CREATE TABLE IF NOT EXISTS ai.embedding_batch_chunks
    (
        id VARCHAR(255) PRIMARY KEY,
        embedding_batch_id VARCHAR(255) REFERENCES ai.embedding_batches (openai_batch_id),
        text TEXT
    );
    """)

@tracer.wrap()
async def _do_openai_batch(self, conn: AsyncConnection) -> int:
"""
Expand Down

0 comments on commit d49c352

Please sign in to comment.