Skip to content

Commit

Permalink
feat(nodejs): make tbl.search chainable (lancedb#1421)
Browse files Browse the repository at this point in the history
so this was annoying me when writing the docs. 

for a `search` query, one needed to chain `async` calls.

```ts
const res = await (await tbl.search("greetings")).toArray()
```

now the promise will be deferred until the query is collected, leading
to a more functional API

```ts
const res = await tbl.search("greetings").toArray()
```
  • Loading branch information
universalmind303 authored Jul 2, 2024
1 parent 46c6ff8 commit b8ccea9
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 57 deletions.
6 changes: 3 additions & 3 deletions nodejs/__test__/table.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -706,10 +706,10 @@ describe("table.search", () => {
const data = [{ text: "hello world" }, { text: "goodbye world" }];
const table = await db.createTable("test", data, { schema });

const results = await table.search("greetings").then((r) => r.toArray());
const results = await table.search("greetings").toArray();
expect(results[0].text).toBe(data[0].text);

const results2 = await table.search("farewell").then((r) => r.toArray());
const results2 = await table.search("farewell").toArray();
expect(results2[0].text).toBe(data[1].text);
});

Expand All @@ -721,7 +721,7 @@ describe("table.search", () => {
];
const table = await db.createTable("test", data);

expect(table.search("hello")).rejects.toThrow(
expect(table.search("hello").toArray()).rejects.toThrow(
"No embedding functions are defined in the table",
);
});
Expand Down
6 changes: 5 additions & 1 deletion nodejs/lancedb/arrow.ts
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,11 @@ export type TableLike =
| ArrowTable
| { schema: SchemaLike; batches: RecordBatchLike[] };

export type IntoVector = Float32Array | Float64Array | number[];
export type IntoVector =
| Float32Array
| Float64Array
| number[]
| Promise<Float32Array | Float64Array | number[]>;

export function isArrowTable(value: object): value is TableLike {
if (value instanceof ArrowTable) return true;
Expand Down
2 changes: 1 addition & 1 deletion nodejs/lancedb/embedding/embedding_function.ts
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ export abstract class EmbeddingFunction<
/**
Compute the embeddings for a single query
*/
async computeQueryEmbeddings(data: T): Promise<IntoVector> {
async computeQueryEmbeddings(data: T): Promise<Awaited<IntoVector>> {
return this.computeSourceEmbeddings([data]).then(
(embeddings) => embeddings[0],
);
Expand Down
110 changes: 82 additions & 28 deletions nodejs/lancedb/query.ts
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,26 @@ export interface QueryExecutionOptions {
}

/** Common methods supported by all query types */
export class QueryBase<
NativeQueryType extends NativeQuery | NativeVectorQuery,
QueryType,
> implements AsyncIterable<RecordBatch>
export class QueryBase<NativeQueryType extends NativeQuery | NativeVectorQuery>
implements AsyncIterable<RecordBatch>
{
protected constructor(protected inner: NativeQueryType) {
protected constructor(
protected inner: NativeQueryType | Promise<NativeQueryType>,
) {
// intentionally empty
}

// call a function on the inner (either a promise or the actual object)
protected doCall(fn: (inner: NativeQueryType) => void) {
if (this.inner instanceof Promise) {
this.inner = this.inner.then((inner) => {
fn(inner);
return inner;
});
} else {
fn(this.inner);
}
}
/**
* A filter statement to be applied to this query.
*
Expand All @@ -110,16 +121,16 @@ export class QueryBase<
* Filtering performance can often be improved by creating a scalar index
* on the filter column(s).
*/
where(predicate: string): QueryType {
this.inner.onlyIf(predicate);
return this as unknown as QueryType;
where(predicate: string): this {
this.doCall((inner: NativeQueryType) => inner.onlyIf(predicate));
return this;
}
/**
* A filter statement to be applied to this query.
* @alias where
* @deprecated Use `where` instead
*/
filter(predicate: string): QueryType {
filter(predicate: string): this {
return this.where(predicate);
}

Expand Down Expand Up @@ -155,7 +166,7 @@ export class QueryBase<
*/
select(
columns: string[] | Map<string, string> | Record<string, string> | string,
): QueryType {
): this {
let columnTuples: [string, string][];
if (typeof columns === "string") {
columns = [columns];
Expand All @@ -167,8 +178,10 @@ export class QueryBase<
} else {
columnTuples = Object.entries(columns);
}
this.inner.select(columnTuples);
return this as unknown as QueryType;
this.doCall((inner: NativeQueryType) => {
inner.select(columnTuples);
});
return this;
}

/**
Expand All @@ -177,15 +190,19 @@ export class QueryBase<
* By default, a plain search has no limit. If this method is not
* called then every valid row from the table will be returned.
*/
limit(limit: number): QueryType {
this.inner.limit(limit);
return this as unknown as QueryType;
limit(limit: number): this {
this.doCall((inner: NativeQueryType) => inner.limit(limit));
return this;
}

protected nativeExecute(
options?: Partial<QueryExecutionOptions>,
): Promise<NativeBatchIterator> {
return this.inner.execute(options?.maxBatchLength);
if (this.inner instanceof Promise) {
return this.inner.then((inner) => inner.execute(options?.maxBatchLength));
} else {
return this.inner.execute(options?.maxBatchLength);
}
}

/**
Expand Down Expand Up @@ -214,7 +231,13 @@ export class QueryBase<
/** Collect the results as an Arrow @see {@link ArrowTable}. */
async toArrow(options?: Partial<QueryExecutionOptions>): Promise<ArrowTable> {
const batches = [];
for await (const batch of new RecordBatchIterable(this.inner, options)) {
let inner;
if (this.inner instanceof Promise) {
inner = await this.inner;
} else {
inner = this.inner;
}
for await (const batch of new RecordBatchIterable(inner, options)) {
batches.push(batch);
}
return new ArrowTable(batches);
Expand Down Expand Up @@ -258,8 +281,8 @@ export interface ExecutableQuery {}
*
* This builder can be reused to execute the query many times.
*/
export class VectorQuery extends QueryBase<NativeVectorQuery, VectorQuery> {
constructor(inner: NativeVectorQuery) {
export class VectorQuery extends QueryBase<NativeVectorQuery> {
constructor(inner: NativeVectorQuery | Promise<NativeVectorQuery>) {
super(inner);
}

Expand All @@ -286,7 +309,8 @@ export class VectorQuery extends QueryBase<NativeVectorQuery, VectorQuery> {
* you the desired recall.
*/
nprobes(nprobes: number): VectorQuery {
this.inner.nprobes(nprobes);
super.doCall((inner) => inner.nprobes(nprobes));

return this;
}

Expand All @@ -300,7 +324,7 @@ export class VectorQuery extends QueryBase<NativeVectorQuery, VectorQuery> {
* whose data type is a fixed-size-list of floats.
*/
column(column: string): VectorQuery {
this.inner.column(column);
super.doCall((inner) => inner.column(column));
return this;
}

Expand All @@ -321,7 +345,7 @@ export class VectorQuery extends QueryBase<NativeVectorQuery, VectorQuery> {
distanceType(
distanceType: Required<IvfPqOptions>["distanceType"],
): VectorQuery {
this.inner.distanceType(distanceType);
super.doCall((inner) => inner.distanceType(distanceType));
return this;
}

Expand Down Expand Up @@ -355,7 +379,7 @@ export class VectorQuery extends QueryBase<NativeVectorQuery, VectorQuery> {
* distance between the query vector and the actual uncompressed vector.
*/
refineFactor(refineFactor: number): VectorQuery {
this.inner.refineFactor(refineFactor);
super.doCall((inner) => inner.refineFactor(refineFactor));
return this;
}

Expand All @@ -380,7 +404,7 @@ export class VectorQuery extends QueryBase<NativeVectorQuery, VectorQuery> {
* factor can often help restore some of the results lost by post filtering.
*/
postfilter(): VectorQuery {
this.inner.postfilter();
super.doCall((inner) => inner.postfilter());
return this;
}

Expand All @@ -394,13 +418,13 @@ export class VectorQuery extends QueryBase<NativeVectorQuery, VectorQuery> {
* calculate your recall to select an appropriate value for nprobes.
*/
bypassVectorIndex(): VectorQuery {
this.inner.bypassVectorIndex();
super.doCall((inner) => inner.bypassVectorIndex());
return this;
}
}

/** A builder for LanceDB queries. */
export class Query extends QueryBase<NativeQuery, Query> {
export class Query extends QueryBase<NativeQuery> {
constructor(tbl: NativeTable) {
super(tbl.query());
}
Expand Down Expand Up @@ -443,7 +467,37 @@ export class Query extends QueryBase<NativeQuery, Query> {
* a default `limit` of 10 will be used. @see {@link Query#limit}
*/
nearestTo(vector: IntoVector): VectorQuery {
const vectorQuery = this.inner.nearestTo(Float32Array.from(vector));
return new VectorQuery(vectorQuery);
if (this.inner instanceof Promise) {
const nativeQuery = this.inner.then(async (inner) => {
if (vector instanceof Promise) {
const arr = await vector.then((v) => Float32Array.from(v));
return inner.nearestTo(arr);
} else {
return inner.nearestTo(Float32Array.from(vector));
}
});
return new VectorQuery(nativeQuery);
}
if (vector instanceof Promise) {
const res = (async () => {
try {
const v = await vector;
const arr = Float32Array.from(v);
//
// biome-ignore lint/suspicious/noExplicitAny: we need to get the `inner`, but js has no package scoping
const value: any = this.nearestTo(arr);
const inner = value.inner as
| NativeVectorQuery
| Promise<NativeVectorQuery>;
return inner;
} catch (e) {
return Promise.reject(e);
}
})();
return new VectorQuery(res);
} else {
const vectorQuery = this.inner.nearestTo(Float32Array.from(vector));
return new VectorQuery(vectorQuery);
}
}
}
5 changes: 2 additions & 3 deletions nodejs/lancedb/remote/table.ts
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,8 @@ export class RemoteTable extends Table {
query(): import("..").Query {
throw new Error("query() is not yet supported on the LanceDB cloud");
}
search(query: IntoVector): VectorQuery;
search(query: string): Promise<VectorQuery>;
search(_query: string | IntoVector): VectorQuery | Promise<VectorQuery> {

search(_query: string | IntoVector): VectorQuery {
throw new Error("search() is not yet supported on the LanceDB cloud");
}
vectorSearch(_vector: unknown): import("..").VectorQuery {
Expand Down
40 changes: 19 additions & 21 deletions nodejs/lancedb/table.ts
Original file line number Diff line number Diff line change
Expand Up @@ -244,9 +244,9 @@ export abstract class Table {
* Create a search query to find the nearest neighbors
* of the given query vector
* @param {string} query - the query. This will be converted to a vector using the table's provided embedding function
* @rejects {Error} If no embedding functions are defined in the table
* @note If no embedding functions are defined in the table, this will error when collecting the results.
*/
abstract search(query: string): Promise<VectorQuery>;
abstract search(query: string): VectorQuery;
/**
* Create a search query to find the nearest neighbors
* of the given query vector
Expand Down Expand Up @@ -502,28 +502,26 @@ export class LocalTable extends Table {
query(): Query {
return new Query(this.inner);
}

search(query: string): Promise<VectorQuery>;

search(query: IntoVector): VectorQuery;
search(query: string | IntoVector): Promise<VectorQuery> | VectorQuery {
search(query: string | IntoVector): VectorQuery {
if (typeof query !== "string") {
return this.vectorSearch(query);
} else {
return this.getEmbeddingFunctions().then(async (functions) => {
// TODO: Support multiple embedding functions
const embeddingFunc: EmbeddingFunctionConfig | undefined = functions
.values()
.next().value;
if (!embeddingFunc) {
return Promise.reject(
new Error("No embedding functions are defined in the table"),
);
}
const embeddings =
await embeddingFunc.function.computeQueryEmbeddings(query);
return this.query().nearestTo(embeddings);
});
const queryPromise = this.getEmbeddingFunctions().then(
async (functions) => {
// TODO: Support multiple embedding functions
const embeddingFunc: EmbeddingFunctionConfig | undefined = functions
.values()
.next().value;
if (!embeddingFunc) {
return Promise.reject(
new Error("No embedding functions are defined in the table"),
);
}
return await embeddingFunc.function.computeQueryEmbeddings(query);
},
);

return this.query().nearestTo(queryPromise);
}
}

Expand Down

0 comments on commit b8ccea9

Please sign in to comment.