Skip to content

Commit

Permalink
feat(nodejs): feature parity [6/N] - make public interface work with …
Browse files Browse the repository at this point in the history
…multiple arrow versions (lancedb#1392)

Previously we didn't have great compatibility with other versions of
Apache Arrow. This should bridge that gap a bit.


depends on lancedb#1391
see actual diff here
universalmind303/lancedb@query-filter...universalmind303:arrow-compatibility
  • Loading branch information
universalmind303 authored Jun 25, 2024
1 parent a866b78 commit 79a1667
Show file tree
Hide file tree
Showing 8 changed files with 174 additions and 37 deletions.
6 changes: 4 additions & 2 deletions nodejs/__test__/table.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ describe.each([arrow, arrowOld])("Given a table", (arrow: any) => {
let tmpDir: tmp.DirResult;
let table: Table;

const schema = new arrow.Schema([
const schema:
| import("apache-arrow").Schema
| import("apache-arrow-old").Schema = new arrow.Schema([
new arrow.Field("id", new arrow.Float64(), true),
]);

Expand Down Expand Up @@ -315,7 +317,7 @@ describe("When creating an index", () => {
.query()
.limit(2)
.nearestTo(queryVec)
.distanceType("DoT")
.distanceType("dot")
.toArrow();
expect(rst.numRows).toBe(2);

Expand Down
93 changes: 73 additions & 20 deletions nodejs/lancedb/arrow.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import {
Table as ArrowTable,
Binary,
BufferType,
DataType,
Field,
FixedSizeBinary,
Expand All @@ -37,14 +38,68 @@ import {
type makeTable,
vectorFromArray,
} from "apache-arrow";
import { Buffers } from "apache-arrow/data";
import { type EmbeddingFunction } from "./embedding/embedding_function";
import { EmbeddingFunctionConfig, getRegistry } from "./embedding/registry";
import { sanitizeField, sanitizeSchema, sanitizeType } from "./sanitize";
import {
sanitizeField,
sanitizeSchema,
sanitizeTable,
sanitizeType,
} from "./sanitize";
export * from "apache-arrow";
/**
 * A schema accepted by the public API: either a `Schema` from the
 * `apache-arrow` package imported by this module, or any object with the
 * same outward shape (`fields`, `metadata` and a `names` getter).
 *
 * NOTE(review): presumably the structural branch is what lets schemas built
 * with a different apache-arrow version type-check — confirm.
 */
export type SchemaLike =
| Schema
| {
// Each entry may itself be a real `Field` or a structural look-alike.
fields: FieldLike[];
metadata: Map<string, string>;
get names(): unknown[];
};
/**
 * A schema field accepted by the public API: either a `Field` from the
 * `apache-arrow` package imported by this module, or a plain object carrying
 * a `type`, a `name`, and optionally `nullable` and `metadata`.
 */
export type FieldLike =
| Field
| {
// NOTE(review): declared as `string` here, whereas `Field` is imported
// alongside `DataType` — confirm this is the intended shape.
type: string;
name: string;
nullable?: boolean;
metadata?: Map<string, string>;
};

/**
 * Column/struct data accepted by the public API: either an apache-arrow
 * `Data<Struct<any>>`, or an object exposing the same set of properties
 * (type, length, offset, stride, nullability, children, null count and the
 * four buffers selected by `BufferType`).
 */
export type DataLike =
// biome-ignore lint/suspicious/noExplicitAny: accepts struct data with any child layout
| import("apache-arrow").Data<Struct<any>>
| {
// biome-ignore lint/suspicious/noExplicitAny: the type object is deliberately left unconstrained
type: any;
length: number;
offset: number;
stride: number;
nullable: boolean;
// Nested data is described recursively with the same permissive shape.
children: DataLike[];
get nullCount(): number;
// biome-ignore lint/suspicious/noExplicitAny: buffer element type is left unconstrained
values: Buffers<any>[BufferType.DATA];
// biome-ignore lint/suspicious/noExplicitAny: buffer element type is left unconstrained
typeIds: Buffers<any>[BufferType.TYPE];
// biome-ignore lint/suspicious/noExplicitAny: buffer element type is left unconstrained
nullBitmap: Buffers<any>[BufferType.VALIDITY];
// biome-ignore lint/suspicious/noExplicitAny: buffer element type is left unconstrained
valueOffsets: Buffers<any>[BufferType.OFFSET];
};

/**
 * A record batch accepted by the public API: either an apache-arrow
 * `RecordBatch`, or any object that pairs a `SchemaLike` schema with
 * `DataLike` data.
 */
export type RecordBatchLike =
| RecordBatch
| {
schema: SchemaLike;
data: DataLike;
};

/**
 * A table accepted by the public API: either an apache-arrow `Table`
 * (imported here as `ArrowTable`), or any object that exposes a `schema`
 * and an array of `batches`.
 */
export type TableLike =
| ArrowTable
| { schema: SchemaLike; batches: RecordBatchLike[] };

/**
 * Array forms that can stand in for a vector: a `Float32Array`, a
 * `Float64Array`, or a plain array of numbers.
 */
export type IntoVector = Float32Array | Float64Array | number[];

export function isArrowTable(value: object): value is ArrowTable {
export function isArrowTable(value: object): value is TableLike {
if (value instanceof ArrowTable) return true;
return "schema" in value && "batches" in value;
}
Expand Down Expand Up @@ -135,7 +190,7 @@ export function isFixedSizeList(value: unknown): value is FixedSizeList {
}

/** Data type accepted by NodeJS SDK */
export type Data = Record<string, unknown>[] | ArrowTable;
export type Data = Record<string, unknown>[] | TableLike;

/*
* Options to control how a column should be converted to a vector array
Expand All @@ -162,7 +217,7 @@ export class MakeArrowTableOptions {
* The schema must be specified if there are no records (e.g. to make
* an empty table)
*/
schema?: Schema;
schema?: SchemaLike;

/*
* Mapping from vector column name to expected type
Expand Down Expand Up @@ -310,7 +365,7 @@ export function makeArrowTable(
if (opt.schema !== undefined && opt.schema !== null) {
opt.schema = sanitizeSchema(opt.schema);
opt.schema = validateSchemaEmbeddings(
opt.schema,
opt.schema as Schema,
data,
options?.embeddingFunction,
);
Expand Down Expand Up @@ -394,7 +449,7 @@ export function makeArrowTable(
// `new ArrowTable(schema, batches)` which does not do any schema inference
const firstTable = new ArrowTable(columns);
const batchesFixed = firstTable.batches.map(
(batch) => new RecordBatch(opt.schema!, batch.data),
(batch) => new RecordBatch(opt.schema as Schema, batch.data),
);
let schema: Schema;
if (metadata !== undefined) {
Expand All @@ -407,9 +462,9 @@ export function makeArrowTable(
}
}

schema = new Schema(opt.schema.fields, schemaMetadata);
schema = new Schema(opt.schema.fields as Field[], schemaMetadata);
} else {
schema = opt.schema;
schema = opt.schema as Schema;
}
return new ArrowTable(schema, batchesFixed);
}
Expand All @@ -425,7 +480,7 @@ export function makeArrowTable(
* Create an empty Arrow table with the provided schema
*/
export function makeEmptyTable(
schema: Schema,
schema: SchemaLike,
metadata?: Map<string, string>,
): ArrowTable {
return makeArrowTable([], { schema }, metadata);
Expand Down Expand Up @@ -563,18 +618,17 @@ async function applyEmbeddingsFromMetadata(
async function applyEmbeddings<T>(
table: ArrowTable,
embeddings?: EmbeddingFunctionConfig,
schema?: Schema,
schema?: SchemaLike,
): Promise<ArrowTable> {
if (schema !== undefined && schema !== null) {
schema = sanitizeSchema(schema);
}
if (schema?.metadata.has("embedding_functions")) {
return applyEmbeddingsFromMetadata(table, schema!);
return applyEmbeddingsFromMetadata(table, schema! as Schema);
} else if (embeddings == null || embeddings === undefined) {
return table;
}

if (schema !== undefined && schema !== null) {
schema = sanitizeSchema(schema);
}

// Convert from ArrowTable to Record<String, Vector>
const colEntries = [...Array(table.numCols).keys()].map((_, idx) => {
const name = table.schema.fields[idx].name;
Expand Down Expand Up @@ -650,7 +704,7 @@ async function applyEmbeddings<T>(
`When using embedding functions and specifying a schema the schema should include the embedding column but the column ${destColumn} was missing`,
);
}
return alignTable(newTable, schema);
return alignTable(newTable, schema as Schema);
}
return newTable;
}
Expand Down Expand Up @@ -744,7 +798,7 @@ export async function fromRecordsToStreamBuffer(
export async function fromTableToBuffer(
table: ArrowTable,
embeddings?: EmbeddingFunctionConfig,
schema?: Schema,
schema?: SchemaLike,
): Promise<Buffer> {
if (schema !== undefined && schema !== null) {
schema = sanitizeSchema(schema);
Expand All @@ -771,7 +825,7 @@ export async function fromDataToBuffer(
schema = sanitizeSchema(schema);
}
if (isArrowTable(data)) {
return fromTableToBuffer(data, embeddings, schema);
return fromTableToBuffer(sanitizeTable(data), embeddings, schema);
} else {
const table = await convertToTable(data, embeddings, { schema });
return fromTableToBuffer(table);
Expand All @@ -789,7 +843,7 @@ export async function fromDataToBuffer(
export async function fromTableToStreamBuffer(
table: ArrowTable,
embeddings?: EmbeddingFunctionConfig,
schema?: Schema,
schema?: SchemaLike,
): Promise<Buffer> {
const tableWithEmbeddings = await applyEmbeddings(table, embeddings, schema);
const writer = RecordBatchStreamWriter.writeAll(tableWithEmbeddings);
Expand Down Expand Up @@ -854,7 +908,6 @@ function validateSchemaEmbeddings(
for (let field of schema.fields) {
if (isFixedSizeList(field.type)) {
field = sanitizeField(field);

if (data.length !== 0 && data?.[0]?.[field.name] === undefined) {
if (schema.metadata.has("embedding_functions")) {
const embeddings = JSON.parse(
Expand Down
14 changes: 7 additions & 7 deletions nodejs/lancedb/connection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

import { Table as ArrowTable, Data, Schema } from "./arrow";
import { Data, Schema, SchemaLike, TableLike } from "./arrow";
import { fromTableToBuffer, makeEmptyTable } from "./arrow";
import { EmbeddingFunctionConfig, getRegistry } from "./embedding/registry";
import { Connection as LanceDbConnection } from "./native";
Expand Down Expand Up @@ -50,7 +50,7 @@ export interface CreateTableOptions {
* The default is true while the new format is in beta
*/
useLegacyFormat?: boolean;
schema?: Schema;
schema?: SchemaLike;
embeddingFunction?: EmbeddingFunctionConfig;
}

Expand Down Expand Up @@ -167,12 +167,12 @@ export abstract class Connection {
/**
* Creates a new Table and initialize it with new data.
* @param {string} name - The name of the table.
* @param {Record<string, unknown>[] | ArrowTable} data - Non-empty Array of Records
* @param {Record<string, unknown>[] | TableLike} data - Non-empty Array of Records
* to be inserted into the table
*/
abstract createTable(
name: string,
data: Record<string, unknown>[] | ArrowTable,
data: Record<string, unknown>[] | TableLike,
options?: Partial<CreateTableOptions>,
): Promise<Table>;

Expand All @@ -183,7 +183,7 @@ export abstract class Connection {
*/
abstract createEmptyTable(
name: string,
schema: Schema,
schema: import("./arrow").SchemaLike,
options?: Partial<CreateTableOptions>,
): Promise<Table>;

Expand Down Expand Up @@ -235,7 +235,7 @@ export class LocalConnection extends Connection {
nameOrOptions:
| string
| ({ name: string; data: Data } & Partial<CreateTableOptions>),
data?: Record<string, unknown>[] | ArrowTable,
data?: Record<string, unknown>[] | TableLike,
options?: Partial<CreateTableOptions>,
): Promise<Table> {
if (typeof nameOrOptions !== "string" && "name" in nameOrOptions) {
Expand All @@ -259,7 +259,7 @@ export class LocalConnection extends Connection {

async createEmptyTable(
name: string,
schema: Schema,
schema: import("./arrow").SchemaLike,
options?: Partial<CreateTableOptions>,
): Promise<Table> {
let mode: string = options?.mode ?? "create";
Expand Down
4 changes: 3 additions & 1 deletion nodejs/lancedb/query.ts
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,9 @@ export class VectorQuery extends QueryBase<NativeVectorQuery, VectorQuery> {
*
* By default "l2" is used.
*/
distanceType(distanceType: string): VectorQuery {
distanceType(
distanceType: Required<IvfPqOptions>["distanceType"],
): VectorQuery {
this.inner.distanceType(distanceType);
return this;
}
Expand Down
9 changes: 7 additions & 2 deletions nodejs/lancedb/remote/connection.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import { Schema } from "apache-arrow";
import { Data, fromTableToStreamBuffer, makeEmptyTable } from "../arrow";
import {
Data,
SchemaLike,
fromTableToStreamBuffer,
makeEmptyTable,
} from "../arrow";
import {
Connection,
CreateTableOptions,
Expand Down Expand Up @@ -156,7 +161,7 @@ export class RemoteConnection extends Connection {

async createEmptyTable(
name: string,
schema: Schema,
schema: SchemaLike,
options?: Partial<CreateTableOptions> | undefined,
): Promise<Table> {
if (options?.mode) {
Expand Down
Loading

0 comments on commit 79a1667

Please sign in to comment.