From 549edac72d2090c96d90eaca500410a2883ac334 Mon Sep 17 00:00:00 2001 From: Vitor Carvalho Date: Fri, 26 Jul 2024 11:32:13 +0100 Subject: [PATCH] (fix): serializations and restore --- package.json | 2 +- src/classifiers/bayes-classifier.ts | 70 +++++++++++++++++------------ src/index.test.ts | 28 ++++++++++++ src/normalizer.ts | 2 +- src/tokenizers/regexp-tokenizer.ts | 2 +- src/tokenizers/tokenizer.ts | 4 +- src/utils/is-object.ts | 3 ++ 7 files changed, 78 insertions(+), 33 deletions(-) create mode 100644 src/utils/is-object.ts diff --git a/package.json b/package.json index f437ade..d70a6bb 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@moveyourdigital/nlp", - "version": "0.1.3", + "version": "0.1.4", "description": "NLP (natural language processing) for server and the browser in TypeScript. All lightweight and super-fast.", "type": "module", "source": "src/index.ts", diff --git a/src/classifiers/bayes-classifier.ts b/src/classifiers/bayes-classifier.ts index 7f8a721..ee7c13c 100644 --- a/src/classifiers/bayes-classifier.ts +++ b/src/classifiers/bayes-classifier.ts @@ -20,6 +20,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +import { isObject } from "../utils/is-object.js"; import { Classification, Classifier, Document, Label, ModelAsJSON, Observation, Stats } from "./classifier.js"; /** @@ -34,34 +35,40 @@ type Properties = { smoothing: number } +/** + * Serialized format + */ +type Plain = { + features: Record + matrix: Partial> + corpus?: Document[] + properties: Properties + stats: Stats +} + /** * Class representing a Naive Bayes classifier. * @template T - Type extending Observation * @template K - Type extending Label - * @extends {Classifier} */ export class BayesClassifier extends Classifier { /** * @private - * @type {Partial>} */ private features: Map = new Map() /** * @private - * @type {Partial>} */ private matrix: Partial> = {} /** * @private - * @type {Document[]} */ private corpus: Document[] = [] /** * @public - * @type {Properties} */ private properties: Properties = { smoothing: 1.0 @@ -69,7 +76,6 @@ export class BayesClassifier extends Cla /** * @private - * @type {Stats} */ readonly stats: Stats = { labels: {}, @@ -78,9 +84,6 @@ export class BayesClassifier extends Cla /** * Sets a property of the classifier. - * @param {keyof Properties} prop - The property to set. - * @param {Properties[K]} value - The value to set. - * @returns {this} The classifier instance. */ set(prop: K, value: Properties[K]): this { this.properties[prop] = value; @@ -89,8 +92,6 @@ export class BayesClassifier extends Cla /** * Gets a property of the classifier. - * @param {keyof Properties} prop - The property to get. - * @returns {Properties[K]} The value of property. */ get(prop: K): Properties[K] { return this.properties[prop]; @@ -126,14 +127,14 @@ export class BayesClassifier extends Cla label, observation, }); - + observation.forEach( (token) => this.features.set( - token, + token, (this.features.get(token) || 0) + 1 ) ); - + return this; } @@ -150,9 +151,9 @@ export class BayesClassifier extends Cla if (label in this.matrix) { vector.forEach((value, index) => { this.matrix[label]![index] = - this.matrix[label]![index] + value; + this.matrix[label]![index] + value; }) - + } else { this.matrix[label] = vector.map((v) => v + 1 + this.properties.smoothing); } @@ -173,15 +174,22 @@ export class BayesClassifier extends Cla * @param {ModelAsJSON} data - The model in JSON format. * @returns {this} The classifier instance. */ - restore(data: ModelAsJSON | Record): this { + restore(data: ModelAsJSON | Plain): this { try { - const model = typeof data === "string" ? JSON.parse(data) : data; + const model = typeof data === "string" ? JSON.parse(data) as Partial> : data; - Object.getOwnPropertyNames(this).forEach((key) => { - if (key in model && typeof model[key] === 'object' && model[key] !== null) { - this[key] = model[key] + if ('features' in model && isObject(model.features)) { + Object + .entries(model.features) + .forEach(([key, value]) => this.features.set(key as T, value as number)) + } + + ['matrix', 'corpus', 'properties', 'stats'].forEach((prop) => { + if (prop in model && isObject(model[prop])) { + this[prop] = model[prop] } }) + } catch { /* empty */ } return this @@ -199,14 +207,20 @@ export class BayesClassifier extends Cla serializer, }: { compact?: boolean - serializer?: (x: Record) => ModelAsJSON + serializer?: (x: Plain) => ModelAsJSON } = {}): ModelAsJSON { - return (serializer || JSON.stringify)({ - features: Object.fromEntries(this.features), + const model: Plain = { + features: Object.fromEntries(this.features) as Record, matrix: this.matrix, - corpus: compact ? [] : this.corpus, + properties: this.properties, stats: this.stats, - }) + } + + if (!compact) { + model.corpus = this.corpus + } + + return (serializer || JSON.stringify)(model) } /** @@ -217,7 +231,7 @@ export class BayesClassifier extends Cla static of(model: ModelAsJSON): BayesClassifier { return new this().restore(model) } - + /** * Converts an observation to a feature vector. * @private @@ -232,7 +246,7 @@ export class BayesClassifier extends Cla observation.includes(feature as T) ? 1 : 0 ) } - + return vector } diff --git a/src/index.test.ts b/src/index.test.ts index c591b88..6ecbf54 100644 --- a/src/index.test.ts +++ b/src/index.test.ts @@ -53,4 +53,32 @@ describe('bayes classification', () => { assert(classification[0].score > classification[1].score) }) }) + + describe('serializing / restoring', () => { + const model = new BayesClassifier() + .set('smoothing', 0.1) + .addDocument({ + observation: ['one', 'two', 'three', 'four'], + label: 'numbers', + }) + .addDocument({ + observation: ['water', 'earth', 'fire', 'wind'], + label: 'elements', + }) + .train() + .toJSON({ compact: true }) + + test('should match good label', () => { + const classifier = new BayesClassifier() + .restore(model) + + const classification = classifier.classify( + ['one', 'water', 'one', 'fire', 'earth'] + ) + + assert.equal(classifier.get('smoothing'), 0.1) + assert.deepStrictEqual(classifier.stats, { labels: { numbers: 1, elements: 1 }, corpus: 2 }) + assert.equal(classification[0].label, 'elements') + }) + }) }) diff --git a/src/normalizer.ts b/src/normalizer.ts index 88f4f74..4d2e97e 100644 --- a/src/normalizer.ts +++ b/src/normalizer.ts @@ -61,7 +61,7 @@ export class Normalizer { * @param {Stemmer} stemmer - The stemmer to use for reducing tokens to their root form. * @param {StopWords} [stopwords=[]] - Optional array of stop words to remove from the tokens. */ - constructor(private tokenizer: ITokenizer, private stemmer: IStemmer, private stopwords: StopWords = []) {} + constructor(private tokenizer: ITokenizer, private stemmer: IStemmer, private stopwords: StopWords = []) { } /** * Normalizes the given text by tokenizing, removing stop words, and stemming each token. diff --git a/src/tokenizers/regexp-tokenizer.ts b/src/tokenizers/regexp-tokenizer.ts index 34bde6c..ebec634 100644 --- a/src/tokenizers/regexp-tokenizer.ts +++ b/src/tokenizers/regexp-tokenizer.ts @@ -31,7 +31,7 @@ export class RegexpTokenizer extends Tokenizer { * Creates an instance of RegexpTokenizer. * @param {RegExp} pattern - The regular expression pattern used for tokenization. */ - constructor (private pattern: RegExp) { + constructor(private pattern: RegExp) { super() } diff --git a/src/tokenizers/tokenizer.ts b/src/tokenizers/tokenizer.ts index 09b3b1a..ac3729b 100644 --- a/src/tokenizers/tokenizer.ts +++ b/src/tokenizers/tokenizer.ts @@ -32,13 +32,13 @@ export abstract class Tokenizer { * @returns {string[]} An array of tokenized strings. */ abstract tokenize(text: string): string[] - + /** * Trims the tokens by filtering out any tokens that are `undefined`, `null`, or whitespace-only strings. * @param {(string | undefined | null)[]} tokens - An array of tokens to trim. * @returns {string[]} An array of trimmed tokens. */ - trim (tokens: (string | undefined | null)[]): string[] { + trim(tokens: (string | undefined | null)[]): string[] { return tokens.filter((token): token is string => !(token && token.trim())) } } diff --git a/src/utils/is-object.ts b/src/utils/is-object.ts new file mode 100644 index 0000000..6e41ccb --- /dev/null +++ b/src/utils/is-object.ts @@ -0,0 +1,3 @@ +export function isObject(subject: unknown): subject is Record { + return !!(subject && typeof subject === "object") +}