Skip to content

Commit

Permalink
(fix): serializations and restore
Browse files Browse the repository at this point in the history
  • Loading branch information
lightningspirit committed Jul 26, 2024
1 parent 5393e52 commit 549edac
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 33 deletions.
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@moveyourdigital/nlp",
"version": "0.1.3",
"version": "0.1.4",
"description": "NLP (natural language processing) for server and the browser in TypeScript. All lightweight and super-fast.",
"type": "module",
"source": "src/index.ts",
Expand Down
70 changes: 42 additions & 28 deletions src/classifiers/bayes-classifier.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
*/

import { isObject } from "../utils/is-object.js";
import { Classification, Classifier, Document, Label, ModelAsJSON, Observation, Stats } from "./classifier.js";

/**
Expand All @@ -34,42 +35,47 @@ type Properties = {
smoothing: number
}

/**
* Serialized format
*/
type Plain<T extends Observation, K extends Label> = {
features: Record<T, number>
matrix: Partial<Record<K, number[]>>
corpus?: Document<T, K>[]
properties: Properties
stats: Stats<K>
}

/**
* Class representing a Naive Bayes classifier.
* @template T - Type extending Observation
* @template K - Type extending Label
* @extends {Classifier<T, K>}
*/
export class BayesClassifier<T extends Observation, K extends Label> extends Classifier<T, K> {
/**
* @private
* @type {Partial<Record<T, number>>}
*/
private features: Map<T, number> = new Map()

/**
* @private
* @type {Partial<Record<K, number[]>>}
*/
private matrix: Partial<Record<K, number[]>> = {}

/**
* @private
* @type {Document<T, K>[]}
*/
private corpus: Document<T, K>[] = []

/**
* @public
* @type {Properties}
*/
private properties: Properties = {
smoothing: 1.0
}

/**
* @private
* @type {Stats<K>}
*/
readonly stats: Stats<K> = {
labels: {},
Expand All @@ -78,9 +84,6 @@ export class BayesClassifier<T extends Observation, K extends Label> extends Cla

/**
* Sets a property of the classifier.
* @param {keyof Properties} prop - The property to set.
* @param {Properties[K]} value - The value to set.
* @returns {this} The classifier instance.
*/
set<K extends keyof Properties>(prop: K, value: Properties[K]): this {
this.properties[prop] = value;
Expand All @@ -89,8 +92,6 @@ export class BayesClassifier<T extends Observation, K extends Label> extends Cla

/**
* Gets a property of the classifier.
* @param {keyof Properties} prop - The property to get.
* @returns {Properties[K]} The value of property.
*/
get<K extends keyof Properties>(prop: K): Properties[K] {
return this.properties[prop];
Expand Down Expand Up @@ -126,14 +127,14 @@ export class BayesClassifier<T extends Observation, K extends Label> extends Cla
label,
observation,
});

observation.forEach(
(token) => this.features.set(
token,
token,
(this.features.get(token) || 0) + 1
)
);

return this;
}

Expand All @@ -150,9 +151,9 @@ export class BayesClassifier<T extends Observation, K extends Label> extends Cla
if (label in this.matrix) {
vector.forEach((value, index) => {
this.matrix[label]![index] =
this.matrix[label]![index] + value;
this.matrix[label]![index] + value;
})

} else {
this.matrix[label] = vector.map((v) => v + 1 + this.properties.smoothing);
}
Expand All @@ -173,15 +174,22 @@ export class BayesClassifier<T extends Observation, K extends Label> extends Cla
* @param {ModelAsJSON} data - The model in JSON format.
* @returns {this} The classifier instance.
*/
restore(data: ModelAsJSON | Record<string, unknown>): this {
restore(data: ModelAsJSON | Plain<T, K>): this {
try {
const model = typeof data === "string" ? JSON.parse(data) : data;
const model = typeof data === "string" ? JSON.parse(data) as Partial<Plain<T, K>> : data;

Object.getOwnPropertyNames(this).forEach((key) => {
if (key in model && typeof model[key] === 'object' && model[key] !== null) {
this[key] = model[key]
if ('features' in model && isObject<T, number>(model.features)) {
Object
.entries(model.features)
.forEach(([key, value]) => this.features.set(key as T, value as number))
}

['matrix', 'corpus', 'properties', 'stats'].forEach((prop) => {
if (prop in model && isObject(model[prop])) {
this[prop] = model[prop]
}
})

} catch { /* empty */ }

return this
Expand All @@ -199,14 +207,20 @@ export class BayesClassifier<T extends Observation, K extends Label> extends Cla
serializer,
}: {
compact?: boolean
serializer?: (x: Record<string, unknown>) => ModelAsJSON
serializer?: (x: Plain<T, K>) => ModelAsJSON
} = {}): ModelAsJSON {
return (serializer || JSON.stringify)({
features: Object.fromEntries(this.features),
const model: Plain<T, K> = {
features: Object.fromEntries(this.features) as Record<T, number>,
matrix: this.matrix,
corpus: compact ? [] : this.corpus,
properties: this.properties,
stats: this.stats,
})
}

if (!compact) {
model.corpus = this.corpus
}

return (serializer || JSON.stringify)(model)
}

/**
Expand All @@ -217,7 +231,7 @@ export class BayesClassifier<T extends Observation, K extends Label> extends Cla
static of<T extends Observation, K extends Label>(model: ModelAsJSON): BayesClassifier<T, K> {
return new this<T, K>().restore(model)
}

/**
* Converts an observation to a feature vector.
* @private
Expand All @@ -232,7 +246,7 @@ export class BayesClassifier<T extends Observation, K extends Label> extends Cla
observation.includes(feature as T) ? 1 : 0
)
}

return vector
}

Expand Down
28 changes: 28 additions & 0 deletions src/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,32 @@ describe('bayes classification', () => {
assert(classification[0].score > classification[1].score)
})
})

describe('serializing / restoring', () => {
const model = new BayesClassifier<string, string>()
.set('smoothing', 0.1)
.addDocument({
observation: ['one', 'two', 'three', 'four'],
label: 'numbers',
})
.addDocument({
observation: ['water', 'earth', 'fire', 'wind'],
label: 'elements',
})
.train()
.toJSON({ compact: true })

test('should match good label', () => {
const classifier = new BayesClassifier<string, string>()
.restore(model)

const classification = classifier.classify(
['one', 'water', 'one', 'fire', 'earth']
)

assert.equal(classifier.get('smoothing'), 0.1)
assert.deepStrictEqual(classifier.stats, { labels: { numbers: 1, elements: 1 }, corpus: 2 })
assert.equal(classification[0].label, 'elements')
})
})
})
2 changes: 1 addition & 1 deletion src/normalizer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ export class Normalizer {
* @param {Stemmer} stemmer - The stemmer to use for reducing tokens to their root form.
* @param {StopWords} [stopwords=[]] - Optional array of stop words to remove from the tokens.
*/
constructor(private tokenizer: ITokenizer, private stemmer: IStemmer, private stopwords: StopWords = []) {}
constructor(private tokenizer: ITokenizer, private stemmer: IStemmer, private stopwords: StopWords = []) { }

/**
* Normalizes the given text by tokenizing, removing stop words, and stemming each token.
Expand Down
2 changes: 1 addition & 1 deletion src/tokenizers/regexp-tokenizer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ export class RegexpTokenizer extends Tokenizer {
* Creates an instance of RegexpTokenizer.
* @param {RegExp} pattern - The regular expression pattern used for tokenization.
*/
constructor (private pattern: RegExp) {
constructor(private pattern: RegExp) {
super()
}

Expand Down
4 changes: 2 additions & 2 deletions src/tokenizers/tokenizer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ export abstract class Tokenizer {
* @returns {string[]} An array of tokenized strings.
*/
abstract tokenize(text: string): string[]

/**
* Trims the tokens by filtering out any tokens that are `undefined`, `null`, or whitespace-only strings.
* @param {(string | undefined | null)[]} tokens - An array of tokens to trim.
* @returns {string[]} An array of trimmed tokens.
*/
trim (tokens: (string | undefined | null)[]): string[] {
trim(tokens: (string | undefined | null)[]): string[] {
return tokens.filter((token): token is string => !(token && token.trim()))
}
}
3 changes: 3 additions & 0 deletions src/utils/is-object.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
export function isObject<K extends string | number | symbol, T>(subject: unknown): subject is Record<K, T> {
return !!(subject && typeof subject === "object")
}

0 comments on commit 549edac

Please sign in to comment.