From b8ccea9f717a793d726939504e650950642f9533 Mon Sep 17 00:00:00 2001 From: Cory Grinstead Date: Tue, 2 Jul 2024 14:31:57 -0500 Subject: [PATCH] feat(nodejs): make tbl.search chainable (#1421) so this was annoying me when writing the docs. for a `search` query, one needed to chain `async` calls. ```ts const res = await (await tbl.search("greetings")).toArray() ``` now the promise will be deferred until the query is collected, leading to a more functional API ```ts const res = await tbl.search("greetings").toArray() ``` --- nodejs/__test__/table.test.ts | 6 +- nodejs/lancedb/arrow.ts | 6 +- .../lancedb/embedding/embedding_function.ts | 2 +- nodejs/lancedb/query.ts | 110 +++++++++++++----- nodejs/lancedb/remote/table.ts | 5 +- nodejs/lancedb/table.ts | 40 +++---- 6 files changed, 112 insertions(+), 57 deletions(-) diff --git a/nodejs/__test__/table.test.ts b/nodejs/__test__/table.test.ts index 8ed9ec8b5d..94d9d13862 100644 --- a/nodejs/__test__/table.test.ts +++ b/nodejs/__test__/table.test.ts @@ -706,10 +706,10 @@ describe("table.search", () => { const data = [{ text: "hello world" }, { text: "goodbye world" }]; const table = await db.createTable("test", data, { schema }); - const results = await table.search("greetings").then((r) => r.toArray()); + const results = await table.search("greetings").toArray(); expect(results[0].text).toBe(data[0].text); - const results2 = await table.search("farewell").then((r) => r.toArray()); + const results2 = await table.search("farewell").toArray(); expect(results2[0].text).toBe(data[1].text); }); @@ -721,7 +721,7 @@ describe("table.search", () => { ]; const table = await db.createTable("test", data); - expect(table.search("hello")).rejects.toThrow( + expect(table.search("hello").toArray()).rejects.toThrow( "No embedding functions are defined in the table", ); }); diff --git a/nodejs/lancedb/arrow.ts b/nodejs/lancedb/arrow.ts index 3a1a26a54a..7bc7951cab 100644 --- a/nodejs/lancedb/arrow.ts +++ b/nodejs/lancedb/arrow.ts @@ -97,7 +97,11 @@ export type TableLike = | ArrowTable | { schema: SchemaLike; batches: RecordBatchLike[] }; -export type IntoVector = Float32Array | Float64Array | number[]; +export type IntoVector = + | Float32Array + | Float64Array + | number[] + | Promise; export function isArrowTable(value: object): value is TableLike { if (value instanceof ArrowTable) return true; diff --git a/nodejs/lancedb/embedding/embedding_function.ts b/nodejs/lancedb/embedding/embedding_function.ts index 1f98b8c9db..ff8d119e4f 100644 --- a/nodejs/lancedb/embedding/embedding_function.ts +++ b/nodejs/lancedb/embedding/embedding_function.ts @@ -181,7 +181,7 @@ export abstract class EmbeddingFunction< /** Compute the embeddings for a single query */ - async computeQueryEmbeddings(data: T): Promise { + async computeQueryEmbeddings(data: T): Promise> { return this.computeSourceEmbeddings([data]).then( (embeddings) => embeddings[0], ); diff --git a/nodejs/lancedb/query.ts b/nodejs/lancedb/query.ts index 48cc769379..d63e75187a 100644 --- a/nodejs/lancedb/query.ts +++ b/nodejs/lancedb/query.ts @@ -89,15 +89,26 @@ export interface QueryExecutionOptions { } /** Common methods supported by all query types */ -export class QueryBase< - NativeQueryType extends NativeQuery | NativeVectorQuery, - QueryType, -> implements AsyncIterable +export class QueryBase + implements AsyncIterable { - protected constructor(protected inner: NativeQueryType) { + protected constructor( + protected inner: NativeQueryType | Promise, + ) { // intentionally empty } + // call a function on the inner (either a promise or the actual object) + protected doCall(fn: (inner: NativeQueryType) => void) { + if (this.inner instanceof Promise) { + this.inner = this.inner.then((inner) => { + fn(inner); + return inner; + }); + } else { + fn(this.inner); + } + } /** * A filter statement to be applied to this query. * @@ -110,16 +121,16 @@ export class QueryBase< * Filtering performance can often be improved by creating a scalar index * on the filter column(s). */ - where(predicate: string): QueryType { - this.inner.onlyIf(predicate); - return this as unknown as QueryType; + where(predicate: string): this { + this.doCall((inner: NativeQueryType) => inner.onlyIf(predicate)); + return this; } /** * A filter statement to be applied to this query. * @alias where * @deprecated Use `where` instead */ - filter(predicate: string): QueryType { + filter(predicate: string): this { return this.where(predicate); } @@ -155,7 +166,7 @@ export class QueryBase< */ select( columns: string[] | Map | Record | string, - ): QueryType { + ): this { let columnTuples: [string, string][]; if (typeof columns === "string") { columns = [columns]; @@ -167,8 +178,10 @@ export class QueryBase< } else { columnTuples = Object.entries(columns); } - this.inner.select(columnTuples); - return this as unknown as QueryType; + this.doCall((inner: NativeQueryType) => { + inner.select(columnTuples); + }); + return this; } /** @@ -177,15 +190,19 @@ export class QueryBase< * By default, a plain search has no limit. If this method is not * called then every valid row from the table will be returned. */ - limit(limit: number): QueryType { - this.inner.limit(limit); - return this as unknown as QueryType; + limit(limit: number): this { + this.doCall((inner: NativeQueryType) => inner.limit(limit)); + return this; } protected nativeExecute( options?: Partial, ): Promise { - return this.inner.execute(options?.maxBatchLength); + if (this.inner instanceof Promise) { + return this.inner.then((inner) => inner.execute(options?.maxBatchLength)); + } else { + return this.inner.execute(options?.maxBatchLength); + } } /** @@ -214,7 +231,13 @@ export class QueryBase< /** Collect the results as an Arrow @see {@link ArrowTable}. */ async toArrow(options?: Partial): Promise { const batches = []; - for await (const batch of new RecordBatchIterable(this.inner, options)) { + let inner; + if (this.inner instanceof Promise) { + inner = await this.inner; + } else { + inner = this.inner; + } + for await (const batch of new RecordBatchIterable(inner, options)) { batches.push(batch); } return new ArrowTable(batches); @@ -258,8 +281,8 @@ export interface ExecutableQuery {} * * This builder can be reused to execute the query many times. */ -export class VectorQuery extends QueryBase { - constructor(inner: NativeVectorQuery) { +export class VectorQuery extends QueryBase { + constructor(inner: NativeVectorQuery | Promise) { super(inner); } @@ -286,7 +309,8 @@ export class VectorQuery extends QueryBase { * you the desired recall. */ nprobes(nprobes: number): VectorQuery { - this.inner.nprobes(nprobes); + super.doCall((inner) => inner.nprobes(nprobes)); + return this; } @@ -300,7 +324,7 @@ export class VectorQuery extends QueryBase { * whose data type is a fixed-size-list of floats. */ column(column: string): VectorQuery { - this.inner.column(column); + super.doCall((inner) => inner.column(column)); return this; } @@ -321,7 +345,7 @@ export class VectorQuery extends QueryBase { distanceType( distanceType: Required["distanceType"], ): VectorQuery { - this.inner.distanceType(distanceType); + super.doCall((inner) => inner.distanceType(distanceType)); return this; } @@ -355,7 +379,7 @@ export class VectorQuery extends QueryBase { * distance between the query vector and the actual uncompressed vector. */ refineFactor(refineFactor: number): VectorQuery { - this.inner.refineFactor(refineFactor); + super.doCall((inner) => inner.refineFactor(refineFactor)); return this; } @@ -380,7 +404,7 @@ export class VectorQuery extends QueryBase { * factor can often help restore some of the results lost by post filtering. */ postfilter(): VectorQuery { - this.inner.postfilter(); + super.doCall((inner) => inner.postfilter()); return this; } @@ -394,13 +418,13 @@ export class VectorQuery extends QueryBase { * calculate your recall to select an appropriate value for nprobes. */ bypassVectorIndex(): VectorQuery { - this.inner.bypassVectorIndex(); + super.doCall((inner) => inner.bypassVectorIndex()); return this; } } /** A builder for LanceDB queries. */ -export class Query extends QueryBase { +export class Query extends QueryBase { constructor(tbl: NativeTable) { super(tbl.query()); } @@ -443,7 +467,37 @@ export class Query extends QueryBase { * a default `limit` of 10 will be used. @see {@link Query#limit} */ nearestTo(vector: IntoVector): VectorQuery { - const vectorQuery = this.inner.nearestTo(Float32Array.from(vector)); - return new VectorQuery(vectorQuery); + if (this.inner instanceof Promise) { + const nativeQuery = this.inner.then(async (inner) => { + if (vector instanceof Promise) { + const arr = await vector.then((v) => Float32Array.from(v)); + return inner.nearestTo(arr); + } else { + return inner.nearestTo(Float32Array.from(vector)); + } + }); + return new VectorQuery(nativeQuery); + } + if (vector instanceof Promise) { + const res = (async () => { + try { + const v = await vector; + const arr = Float32Array.from(v); + // + // biome-ignore lint/suspicious/noExplicitAny: we need to get the `inner`, but js has no package scoping + const value: any = this.nearestTo(arr); + const inner = value.inner as + | NativeVectorQuery + | Promise; + return inner; + } catch (e) { + return Promise.reject(e); + } + })(); + return new VectorQuery(res); + } else { + const vectorQuery = this.inner.nearestTo(Float32Array.from(vector)); + return new VectorQuery(vectorQuery); + } } } diff --git a/nodejs/lancedb/remote/table.ts b/nodejs/lancedb/remote/table.ts index 63def38f2b..f06b0c8435 100644 --- a/nodejs/lancedb/remote/table.ts +++ b/nodejs/lancedb/remote/table.ts @@ -122,9 +122,8 @@ export class RemoteTable extends Table { query(): import("..").Query { throw new Error("query() is not yet supported on the LanceDB cloud"); } - search(query: IntoVector): VectorQuery; - search(query: string): Promise; - search(_query: string | IntoVector): VectorQuery | Promise { + + search(_query: string | IntoVector): VectorQuery { throw new Error("search() is not yet supported on the LanceDB cloud"); } vectorSearch(_vector: unknown): import("..").VectorQuery { diff --git a/nodejs/lancedb/table.ts b/nodejs/lancedb/table.ts index 588a0d0787..de8f273350 100644 --- a/nodejs/lancedb/table.ts +++ b/nodejs/lancedb/table.ts @@ -244,9 +244,9 @@ export abstract class Table { * Create a search query to find the nearest neighbors * of the given query vector * @param {string} query - the query. This will be converted to a vector using the table's provided embedding function - * @rejects {Error} If no embedding functions are defined in the table + * @note If no embedding functions are defined in the table, this will error when collecting the results. */ - abstract search(query: string): Promise; + abstract search(query: string): VectorQuery; /** * Create a search query to find the nearest neighbors * of the given query vector @@ -502,28 +502,26 @@ export class LocalTable extends Table { query(): Query { return new Query(this.inner); } - - search(query: string): Promise; - - search(query: IntoVector): VectorQuery; - search(query: string | IntoVector): Promise | VectorQuery { + search(query: string | IntoVector): VectorQuery { if (typeof query !== "string") { return this.vectorSearch(query); } else { - return this.getEmbeddingFunctions().then(async (functions) => { - // TODO: Support multiple embedding functions - const embeddingFunc: EmbeddingFunctionConfig | undefined = functions - .values() - .next().value; - if (!embeddingFunc) { - return Promise.reject( - new Error("No embedding functions are defined in the table"), - ); - } - const embeddings = - await embeddingFunc.function.computeQueryEmbeddings(query); - return this.query().nearestTo(embeddings); - }); + const queryPromise = this.getEmbeddingFunctions().then( + async (functions) => { + // TODO: Support multiple embedding functions + const embeddingFunc: EmbeddingFunctionConfig | undefined = functions + .values() + .next().value; + if (!embeddingFunc) { + return Promise.reject( + new Error("No embedding functions are defined in the table"), + ); + } + return await embeddingFunc.function.computeQueryEmbeddings(query); + }, + ); + + return this.query().nearestTo(queryPromise); } }