Skip to content

Commit

Permalink
refactor: move node specific code into separate file
Browse files Browse the repository at this point in the history
  • Loading branch information
Loïc Mangeonjean committed Apr 15, 2024
1 parent bb49b83 commit 1be1a2b
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 36 deletions.
48 changes: 48 additions & 0 deletions lib/index.node.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import { ModelOperations, ModelOperationsOptions } from './index';
export * from './index'

type NodeModelOperationsOptions = Omit<ModelOperationsOptions, 'modelJsonLoaderFunc' | 'weightsLoaderFunc'> & Partial<ModelOperationsOptions>

class NodeModelOperations extends ModelOperations {
private static NODE_MODEL_JSON_FUNC: () => Promise<{ [key:string]: any }> = async () => {
const fs = await import('fs');
const path = await import('path');

return new Promise<any>((resolve, reject) => {
fs.readFile(path.join(__dirname, '..', '..', 'model', 'model.json'), (err, data) => {
if(err) {
reject(err);
return;
}
resolve(JSON.parse(data.toString()));
});
});
}

private static NODE_WEIGHTS_FUNC: () => Promise<ArrayBuffer> = async () => {
const fs = await import('fs');
const path = await import('path');

return new Promise<ArrayBuffer>((resolve, reject) => {
fs.readFile(path.join(__dirname, '..', '..', 'model', 'group1-shard1of1.bin'), (err, data) => {
if(err) {
reject(err);
return;
}
resolve(data.buffer);
});
});
}

constructor(modelOptions?: NodeModelOperationsOptions) {
super({
modelJsonLoaderFunc: modelOptions?.modelJsonLoaderFunc ?? NodeModelOperations.NODE_MODEL_JSON_FUNC,
weightsLoaderFunc: modelOptions?.weightsLoaderFunc ?? NodeModelOperations.NODE_WEIGHTS_FUNC,
...modelOptions
})
}
}

export {
NodeModelOperations as ModelOperations
}
40 changes: 5 additions & 35 deletions lib/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ class InMemoryIOHandler implements io.IOHandler {
}

export interface ModelOperationsOptions {
modelJsonLoaderFunc?: () => Promise<{ [key:string]: any }>;
weightsLoaderFunc?: () => Promise<ArrayBuffer>;
modelJsonLoaderFunc: () => Promise<{ [key:string]: any }>;
weightsLoaderFunc: () => Promise<ArrayBuffer>;
minContentSize?: number;
maxContentSize?: number;
normalizeNewline?: boolean;
Expand All @@ -82,36 +82,6 @@ export class ModelOperations {
private static DEFAULT_MAX_CONTENT_SIZE = 100000;
private static DEFAULT_MIN_CONTENT_SIZE = 20;

private static NODE_MODEL_JSON_FUNC: () => Promise<{ [key:string]: any }> = async () => {
const fs = await import('fs');
const path = await import('path');

return new Promise<any>((resolve, reject) => {
fs.readFile(path.join(__dirname, '..', '..', 'model', 'model.json'), (err, data) => {
if(err) {
reject(err);
return;
}
resolve(JSON.parse(data.toString()));
});
});
}

private static NODE_WEIGHTS_FUNC: () => Promise<ArrayBuffer> = async () => {
const fs = await import('fs');
const path = await import('path');

return new Promise<ArrayBuffer>((resolve, reject) => {
fs.readFile(path.join(__dirname, '..', '..', 'model', 'group1-shard1of1.bin'), (err, data) => {
if(err) {
reject(err);
return;
}
resolve(data.buffer);
});
});
}

private _model: GraphModel | undefined;
private _modelJson: io.ModelJSON | undefined;
private _weights: ArrayBuffer | undefined;
Expand All @@ -121,9 +91,9 @@ export class ModelOperations {
private readonly _weightsLoaderFunc: () => Promise<ArrayBuffer>;
private readonly _normalizeNewline: boolean;

constructor(modelOptions?: ModelOperationsOptions) {
this._modelJsonLoaderFunc = modelOptions?.modelJsonLoaderFunc ?? ModelOperations.NODE_MODEL_JSON_FUNC;
this._weightsLoaderFunc = modelOptions?.weightsLoaderFunc ?? ModelOperations.NODE_WEIGHTS_FUNC;
constructor(modelOptions: ModelOperationsOptions) {
this._modelJsonLoaderFunc = modelOptions?.modelJsonLoaderFunc;
this._weightsLoaderFunc = modelOptions?.weightsLoaderFunc;
this._minContentSize = modelOptions?.minContentSize ?? ModelOperations.DEFAULT_MIN_CONTENT_SIZE;
this._maxContentSize = modelOptions?.maxContentSize ?? ModelOperations.DEFAULT_MAX_CONTENT_SIZE;
this._normalizeNewline = modelOptions?.normalizeNewline ?? true;
Expand Down
2 changes: 1 addition & 1 deletion rollup.config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ const isProduction = process.env.NODE_ENV === 'production'

export default rollup.defineConfig([{
input: {
index: './lib/index.ts'
index: './lib/index.node.ts'
},
output: [{
sourcemap: true,
Expand Down

0 comments on commit 1be1a2b

Please sign in to comment.