From 79a1667753b3d30fe96bd949a4650cbcdfd736c4 Mon Sep 17 00:00:00 2001 From: Cory Grinstead Date: Tue, 25 Jun 2024 11:10:08 -0500 Subject: [PATCH] feat(nodejs): feature parity [6/N] - make public interface work with multiple arrow versions (#1392) previously we didnt have great compatibility with other versions of apache arrow. This should bridge that gap a bit. depends on https://github.com/lancedb/lancedb/pull/1391 see actual diff here https://github.com/universalmind303/lancedb/compare/query-filter...universalmind303:arrow-compatibility --- nodejs/__test__/table.test.ts | 6 +- nodejs/lancedb/arrow.ts | 93 ++++++++++++++++++++++------- nodejs/lancedb/connection.ts | 14 ++--- nodejs/lancedb/query.ts | 4 +- nodejs/lancedb/remote/connection.ts | 9 ++- nodejs/lancedb/sanitize.ts | 74 ++++++++++++++++++++++- nodejs/lancedb/table.ts | 10 ++-- rust/lancedb/src/table.rs | 1 + 8 files changed, 174 insertions(+), 37 deletions(-) diff --git a/nodejs/__test__/table.test.ts b/nodejs/__test__/table.test.ts index 7ca86de08..1c3fb4ab4 100644 --- a/nodejs/__test__/table.test.ts +++ b/nodejs/__test__/table.test.ts @@ -39,7 +39,9 @@ describe.each([arrow, arrowOld])("Given a table", (arrow: any) => { let tmpDir: tmp.DirResult; let table: Table; - const schema = new arrow.Schema([ + const schema: + | import("apache-arrow").Schema + | import("apache-arrow-old").Schema = new arrow.Schema([ new arrow.Field("id", new arrow.Float64(), true), ]); @@ -315,7 +317,7 @@ describe("When creating an index", () => { .query() .limit(2) .nearestTo(queryVec) - .distanceType("DoT") + .distanceType("dot") .toArrow(); expect(rst.numRows).toBe(2); diff --git a/nodejs/lancedb/arrow.ts b/nodejs/lancedb/arrow.ts index 8309c1617..3a1a26a54 100644 --- a/nodejs/lancedb/arrow.ts +++ b/nodejs/lancedb/arrow.ts @@ -15,6 +15,7 @@ import { Table as ArrowTable, Binary, + BufferType, DataType, Field, FixedSizeBinary, @@ -37,14 +38,68 @@ import { type makeTable, vectorFromArray, } from "apache-arrow"; +import { Buffers } from "apache-arrow/data"; import { type EmbeddingFunction } from "./embedding/embedding_function"; import { EmbeddingFunctionConfig, getRegistry } from "./embedding/registry"; -import { sanitizeField, sanitizeSchema, sanitizeType } from "./sanitize"; +import { + sanitizeField, + sanitizeSchema, + sanitizeTable, + sanitizeType, +} from "./sanitize"; export * from "apache-arrow"; +export type SchemaLike = + | Schema + | { + fields: FieldLike[]; + metadata: Map; + get names(): unknown[]; + }; +export type FieldLike = + | Field + | { + type: string; + name: string; + nullable?: boolean; + metadata?: Map; + }; + +export type DataLike = + // biome-ignore lint/suspicious/noExplicitAny: + | import("apache-arrow").Data> + | { + // biome-ignore lint/suspicious/noExplicitAny: + type: any; + length: number; + offset: number; + stride: number; + nullable: boolean; + children: DataLike[]; + get nullCount(): number; + // biome-ignore lint/suspicious/noExplicitAny: + values: Buffers[BufferType.DATA]; + // biome-ignore lint/suspicious/noExplicitAny: + typeIds: Buffers[BufferType.TYPE]; + // biome-ignore lint/suspicious/noExplicitAny: + nullBitmap: Buffers[BufferType.VALIDITY]; + // biome-ignore lint/suspicious/noExplicitAny: + valueOffsets: Buffers[BufferType.OFFSET]; + }; + +export type RecordBatchLike = + | RecordBatch + | { + schema: SchemaLike; + data: DataLike; + }; + +export type TableLike = + | ArrowTable + | { schema: SchemaLike; batches: RecordBatchLike[] }; export type IntoVector = Float32Array | Float64Array | number[]; -export function isArrowTable(value: object): value is ArrowTable { +export function isArrowTable(value: object): value is TableLike { if (value instanceof ArrowTable) return true; return "schema" in value && "batches" in value; } @@ -135,7 +190,7 @@ export function isFixedSizeList(value: unknown): value is FixedSizeList { } /** Data type accepted by NodeJS SDK */ -export type Data = Record[] | ArrowTable; +export type Data = Record[] | TableLike; /* * Options to control how a column should be converted to a vector array @@ -162,7 +217,7 @@ export class MakeArrowTableOptions { * The schema must be specified if there are no records (e.g. to make * an empty table) */ - schema?: Schema; + schema?: SchemaLike; /* * Mapping from vector column name to expected type @@ -310,7 +365,7 @@ export function makeArrowTable( if (opt.schema !== undefined && opt.schema !== null) { opt.schema = sanitizeSchema(opt.schema); opt.schema = validateSchemaEmbeddings( - opt.schema, + opt.schema as Schema, data, options?.embeddingFunction, ); @@ -394,7 +449,7 @@ export function makeArrowTable( // `new ArrowTable(schema, batches)` which does not do any schema inference const firstTable = new ArrowTable(columns); const batchesFixed = firstTable.batches.map( - (batch) => new RecordBatch(opt.schema!, batch.data), + (batch) => new RecordBatch(opt.schema as Schema, batch.data), ); let schema: Schema; if (metadata !== undefined) { @@ -407,9 +462,9 @@ export function makeArrowTable( } } - schema = new Schema(opt.schema.fields, schemaMetadata); + schema = new Schema(opt.schema.fields as Field[], schemaMetadata); } else { - schema = opt.schema; + schema = opt.schema as Schema; } return new ArrowTable(schema, batchesFixed); } @@ -425,7 +480,7 @@ export function makeArrowTable( * Create an empty Arrow table with the provided schema */ export function makeEmptyTable( - schema: Schema, + schema: SchemaLike, metadata?: Map, ): ArrowTable { return makeArrowTable([], { schema }, metadata); @@ -563,18 +618,17 @@ async function applyEmbeddingsFromMetadata( async function applyEmbeddings( table: ArrowTable, embeddings?: EmbeddingFunctionConfig, - schema?: Schema, + schema?: SchemaLike, ): Promise { + if (schema !== undefined && schema !== null) { + schema = sanitizeSchema(schema); + } if (schema?.metadata.has("embedding_functions")) { - return applyEmbeddingsFromMetadata(table, schema!); + return applyEmbeddingsFromMetadata(table, schema! as Schema); } else if (embeddings == null || embeddings === undefined) { return table; } - if (schema !== undefined && schema !== null) { - schema = sanitizeSchema(schema); - } - // Convert from ArrowTable to Record const colEntries = [...Array(table.numCols).keys()].map((_, idx) => { const name = table.schema.fields[idx].name; @@ -650,7 +704,7 @@ async function applyEmbeddings( `When using embedding functions and specifying a schema the schema should include the embedding column but the column ${destColumn} was missing`, ); } - return alignTable(newTable, schema); + return alignTable(newTable, schema as Schema); } return newTable; } @@ -744,7 +798,7 @@ export async function fromRecordsToStreamBuffer( export async function fromTableToBuffer( table: ArrowTable, embeddings?: EmbeddingFunctionConfig, - schema?: Schema, + schema?: SchemaLike, ): Promise { if (schema !== undefined && schema !== null) { schema = sanitizeSchema(schema); @@ -771,7 +825,7 @@ export async function fromDataToBuffer( schema = sanitizeSchema(schema); } if (isArrowTable(data)) { - return fromTableToBuffer(data, embeddings, schema); + return fromTableToBuffer(sanitizeTable(data), embeddings, schema); } else { const table = await convertToTable(data, embeddings, { schema }); return fromTableToBuffer(table); @@ -789,7 +843,7 @@ export async function fromDataToBuffer( export async function fromTableToStreamBuffer( table: ArrowTable, embeddings?: EmbeddingFunctionConfig, - schema?: Schema, + schema?: SchemaLike, ): Promise { const tableWithEmbeddings = await applyEmbeddings(table, embeddings, schema); const writer = RecordBatchStreamWriter.writeAll(tableWithEmbeddings); @@ -854,7 +908,6 @@ function validateSchemaEmbeddings( for (let field of schema.fields) { if (isFixedSizeList(field.type)) { field = sanitizeField(field); - if (data.length !== 0 && data?.[0]?.[field.name] === undefined) { if (schema.metadata.has("embedding_functions")) { const embeddings = JSON.parse( diff --git a/nodejs/lancedb/connection.ts b/nodejs/lancedb/connection.ts index 4e3aa6991..431015396 100644 --- a/nodejs/lancedb/connection.ts +++ b/nodejs/lancedb/connection.ts @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -import { Table as ArrowTable, Data, Schema } from "./arrow"; +import { Data, Schema, SchemaLike, TableLike } from "./arrow"; import { fromTableToBuffer, makeEmptyTable } from "./arrow"; import { EmbeddingFunctionConfig, getRegistry } from "./embedding/registry"; import { Connection as LanceDbConnection } from "./native"; @@ -50,7 +50,7 @@ export interface CreateTableOptions { * The default is true while the new format is in beta */ useLegacyFormat?: boolean; - schema?: Schema; + schema?: SchemaLike; embeddingFunction?: EmbeddingFunctionConfig; } @@ -167,12 +167,12 @@ export abstract class Connection { /** * Creates a new Table and initialize it with new data. * @param {string} name - The name of the table. - * @param {Record[] | ArrowTable} data - Non-empty Array of Records + * @param {Record[] | TableLike} data - Non-empty Array of Records * to be inserted into the table */ abstract createTable( name: string, - data: Record[] | ArrowTable, + data: Record[] | TableLike, options?: Partial, ): Promise; @@ -183,7 +183,7 @@ export abstract class Connection { */ abstract createEmptyTable( name: string, - schema: Schema, + schema: import("./arrow").SchemaLike, options?: Partial, ): Promise
; @@ -235,7 +235,7 @@ export class LocalConnection extends Connection { nameOrOptions: | string | ({ name: string; data: Data } & Partial), - data?: Record[] | ArrowTable, + data?: Record[] | TableLike, options?: Partial, ): Promise
{ if (typeof nameOrOptions !== "string" && "name" in nameOrOptions) { @@ -259,7 +259,7 @@ export class LocalConnection extends Connection { async createEmptyTable( name: string, - schema: Schema, + schema: import("./arrow").SchemaLike, options?: Partial, ): Promise
{ let mode: string = options?.mode ?? "create"; diff --git a/nodejs/lancedb/query.ts b/nodejs/lancedb/query.ts index ec00d6e40..97829cda9 100644 --- a/nodejs/lancedb/query.ts +++ b/nodejs/lancedb/query.ts @@ -300,7 +300,9 @@ export class VectorQuery extends QueryBase { * * By default "l2" is used. */ - distanceType(distanceType: string): VectorQuery { + distanceType( + distanceType: Required["distanceType"], + ): VectorQuery { this.inner.distanceType(distanceType); return this; } diff --git a/nodejs/lancedb/remote/connection.ts b/nodejs/lancedb/remote/connection.ts index 59d8d61ad..20bc04584 100644 --- a/nodejs/lancedb/remote/connection.ts +++ b/nodejs/lancedb/remote/connection.ts @@ -1,5 +1,10 @@ import { Schema } from "apache-arrow"; -import { Data, fromTableToStreamBuffer, makeEmptyTable } from "../arrow"; +import { + Data, + SchemaLike, + fromTableToStreamBuffer, + makeEmptyTable, +} from "../arrow"; import { Connection, CreateTableOptions, @@ -156,7 +161,7 @@ export class RemoteConnection extends Connection { async createEmptyTable( name: string, - schema: Schema, + schema: SchemaLike, options?: Partial | undefined, ): Promise
{ if (options?.mode) { diff --git a/nodejs/lancedb/sanitize.ts b/nodejs/lancedb/sanitize.ts index cebbc9e3e..35c08c8da 100644 --- a/nodejs/lancedb/sanitize.ts +++ b/nodejs/lancedb/sanitize.ts @@ -20,10 +20,12 @@ // comes from the exact same library instance. This is not always the case // and so we must sanitize the input to ensure that it is compatible. +import { BufferType, Data } from "apache-arrow"; import type { IntBitWidth, TKeys, TimeBitWidth } from "apache-arrow/type"; import { Binary, Bool, + DataLike, DataType, DateDay, DateMillisecond, @@ -56,9 +58,14 @@ import { Map_, Null, type Precision, + RecordBatch, + RecordBatchLike, Schema, + SchemaLike, SparseUnion, Struct, + Table, + TableLike, Time, TimeMicrosecond, TimeMillisecond, @@ -488,7 +495,7 @@ export function sanitizeField(fieldLike: unknown): Field { * instance because they might be using a different instance of apache-arrow * than lancedb is using. */ -export function sanitizeSchema(schemaLike: unknown): Schema { +export function sanitizeSchema(schemaLike: SchemaLike): Schema { if (schemaLike instanceof Schema) { return schemaLike; } @@ -514,3 +521,68 @@ export function sanitizeSchema(schemaLike: unknown): Schema { ); return new Schema(sanitizedFields, metadata); } + +export function sanitizeTable(tableLike: TableLike): Table { + if (tableLike instanceof Table) { + return tableLike; + } + if (typeof tableLike !== "object" || tableLike === null) { + throw Error("Expected a Table but object was null/undefined"); + } + if (!("schema" in tableLike)) { + throw Error( + "The table passed in does not appear to be a table (no 'schema' property)", + ); + } + if (!("batches" in tableLike)) { + throw Error( + "The table passed in does not appear to be a table (no 'columns' property)", + ); + } + const schema = sanitizeSchema(tableLike.schema); + + const batches = tableLike.batches.map(sanitizeRecordBatch); + return new Table(schema, batches); +} + +function sanitizeRecordBatch(batchLike: RecordBatchLike): RecordBatch { + if (batchLike instanceof RecordBatch) { + return batchLike; + } + if (typeof batchLike !== "object" || batchLike === null) { + throw Error("Expected a RecordBatch but object was null/undefined"); + } + if (!("schema" in batchLike)) { + throw Error( + "The record batch passed in does not appear to be a record batch (no 'schema' property)", + ); + } + if (!("data" in batchLike)) { + throw Error( + "The record batch passed in does not appear to be a record batch (no 'data' property)", + ); + } + const schema = sanitizeSchema(batchLike.schema); + const data = sanitizeData(batchLike.data); + return new RecordBatch(schema, data); +} +function sanitizeData( + dataLike: DataLike, + // biome-ignore lint/suspicious/noExplicitAny: +): import("apache-arrow").Data> { + if (dataLike instanceof Data) { + return dataLike; + } + return new Data( + dataLike.type, + dataLike.offset, + dataLike.length, + dataLike.nullCount, + { + [BufferType.OFFSET]: dataLike.valueOffsets, + [BufferType.DATA]: dataLike.values, + [BufferType.VALIDITY]: dataLike.nullBitmap, + [BufferType.TYPE]: dataLike.typeIds, + }, + ); +} diff --git a/nodejs/lancedb/table.ts b/nodejs/lancedb/table.ts index 1ad5249af..588a0d078 100644 --- a/nodejs/lancedb/table.ts +++ b/nodejs/lancedb/table.ts @@ -17,6 +17,7 @@ import { Data, IntoVector, Schema, + TableLike, fromDataToBuffer, fromTableToBuffer, fromTableToStreamBuffer, @@ -38,6 +39,8 @@ import { Table as _NativeTable, } from "./native"; import { Query, VectorQuery } from "./query"; +import { sanitizeTable } from "./sanitize"; +export { IndexConfig } from "./native"; /** * Options for adding data to a table. @@ -381,8 +384,7 @@ export abstract class Table { abstract indexStats(name: string): Promise; static async parseTableData( - // biome-ignore lint/suspicious/noExplicitAny: - data: Record[] | ArrowTable, + data: Record[] | TableLike, options?: Partial, streaming = false, ) { @@ -395,9 +397,9 @@ export abstract class Table { let table: ArrowTable; if (isArrowTable(data)) { - table = data; + table = sanitizeTable(data); } else { - table = makeArrowTable(data, options); + table = makeArrowTable(data as Record[], options); } if (streaming) { const buf = await fromTableToStreamBuffer( diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index c2a89bbbe..18c4e5922 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -1889,6 +1889,7 @@ impl TableInternal for NativeTable { } columns.push(field.name.clone()); } + let index_type = if is_vector { crate::index::IndexType::IvfPq } else {