Skip to content

Commit

Permalink
Merge pull request #10343 from keymanapp/change/web/lmworker-async-pr…
Browse files Browse the repository at this point in the history
…edict

change(web): prep for better asynchronous prediction handling 🕐
  • Loading branch information
jahorton authored Jun 13, 2024
2 parents b0ff102 + e5f2fbd commit 3cf0082
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 69 deletions.
21 changes: 11 additions & 10 deletions common/web/lm-worker/src/main/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -328,12 +328,12 @@ export default class LMLayerWorker {
switch(payload.message) {
case 'predict':
var {transform, context} = payload;
var suggestions = compositor.predict(transform, context);

// Now that the suggestions are ready, send them out!
this.cast('suggestions', {
token: payload.token,
suggestions: suggestions
compositor.predict(transform, context).then((suggestions) => {
// Now that the suggestions are ready, send them out!
this.cast('suggestions', {
token: payload.token,
suggestions: suggestions
});
});
break;
case 'wordbreak':
Expand All @@ -358,11 +358,12 @@ export default class LMLayerWorker {
break;
case 'revert':
var {reversion, context} = payload;
var suggestions: Suggestion[] = compositor.applyReversion(reversion, context);

this.cast('postrevert', {
token: payload.token,
suggestions: suggestions
compositor.applyReversion(reversion, context).then((suggestions) => {
this.cast('postrevert', {
token: payload.token,
suggestions: suggestions
});
});
break;
case 'reset-context':
Expand Down
9 changes: 5 additions & 4 deletions common/web/lm-worker/src/main/model-compositor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ export default class ModelCompositor {
return returnedPredictions;
}

predict(transformDistribution: Transform | Distribution<Transform>, context: Context): Suggestion[] {
async predict(transformDistribution: Transform | Distribution<Transform>, context: Context): Promise<Suggestion[]> {
let suggestionDistribution: Distribution<Suggestion> = [];
let lexicalModel = this.lexicalModel;
let punctuation = this.punctuation;
Expand Down Expand Up @@ -648,18 +648,19 @@ export default class ModelCompositor {
return reversion;
}

applyReversion(reversion: Reversion, context: Context): Suggestion[] {
async applyReversion(reversion: Reversion, context: Context): Promise<Suggestion[]> {
// If we are unable to track context (because the model does not support LexiconTraversal),
// we need a "fallback" strategy.
let compositor = this;
let fallbackSuggestions = function() {
let fallbackSuggestions = async function() {
let revertedContext = models.applyTransform(reversion.transform, context);
let suggestions = compositor.predict({insert: '', deleteLeft: 0}, revertedContext);
const suggestions = await compositor.predict({insert: '', deleteLeft: 0}, revertedContext);
suggestions.forEach(function(suggestion) {
// A reversion's transform ID is the additive inverse of its original suggestion;
// we revert to the state of said original suggestion.
suggestion.transformId = -reversion.transformId;
});

return suggestions;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import DummyModel from '#./models/dummy-model.js';
import ModelCompositor from '#./model-compositor.js';

describe('Custom Punctuation', function () {
it('appears in the keep suggestion', function () {
it('appears in the keep suggestion', async function () {
let dummySuggestions = [{
transform: {
insert: 'Hello',
Expand Down Expand Up @@ -37,7 +37,7 @@ describe('Custom Punctuation', function () {
// The model compositor is responsible for adding this to the display as
// string.
var composite = new ModelCompositor(model, true);
var suggestions = composite.predict([{ sample: { insert: 'o', deleteLeft: 0 }, p: 1.00 }], {
var suggestions = await composite.predict([{ sample: { insert: 'o', deleteLeft: 0 }, p: 1.00 }], {
left: 'Hrll', startOfBuffer: false, endOfBuffer: true
});
assert.lengthOf(suggestions, 3);
Expand All @@ -54,7 +54,7 @@ describe('Custom Punctuation', function () {
});

describe("insertAfterWord", function () {
it('appears after "word" suggestion', function () {
it('appears after "word" suggestion', async function () {
let dummySuggestions = [
{
transform: { insert: 'ᚈᚑᚋ', deleteLeft: 0, },
Expand Down Expand Up @@ -82,7 +82,7 @@ describe('Custom Punctuation', function () {
// The model compositor is responsible for adding this to the display as
// string.
var composite = new ModelCompositor(model, true);
var suggestions = composite.predict([{ sample: { insert: 'ᚋ', deleteLeft: 0 }, p: 1.00 }], {
var suggestions = await composite.predict([{ sample: { insert: 'ᚋ', deleteLeft: 0 }, p: 1.00 }], {
left: '᚛ᚈᚑ', startOfBuffer: false, endOfBuffer: true
});
assert.lengthOf(suggestions, dummySuggestions.length);
Expand Down
Loading

0 comments on commit 3cf0082

Please sign in to comment.