-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor(agent): refactor completion postprocess and caching. (#576)
- Loading branch information
Showing
23 changed files
with
418 additions
and
339 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,144 +1,204 @@ | ||
import { LRUCache } from "lru-cache"; | ||
import hashObject from "object-hash"; | ||
import sizeOfObject from "object-sizeof"; | ||
import { CompletionRequest, CompletionResponse } from "./Agent"; | ||
import { CompletionContext, CompletionResponse } from "./Agent"; | ||
import { rootLogger } from "./logger"; | ||
import { splitLines, splitWords } from "./utils"; | ||
import { splitLines, autoClosingPairOpenings, autoClosingPairClosings, findUnpairedAutoClosingChars } from "./utils"; | ||
|
||
// Cache keys are completion contexts; the cache indexes entries by `key.hash`.
type CompletionCacheKey = CompletionContext;
// Cached payload: the completion response handed back to callers on a hit.
type CompletionCacheValue = CompletionResponse;
|
||
/** One cache record: the context it is keyed by, the response to serve, and the rebuild marker. */
type CompletionCacheEntry = {
  key: CompletionCacheKey;
  value: CompletionCacheValue;
  rebuildFlag: boolean;
};

/**
 * LRU cache of completion responses, indexed by `CompletionContext.hash`.
 *
 * Besides the entry for the exact request, `buildCache` pre-builds entries for
 * contexts in which part of the completion text has already been inserted at
 * the cursor (per line and per character), so those follow-up requests can be
 * answered from the cache. Entries marked with `rebuildFlag` are re-expanded
 * through `buildCache` the next time they are read via `get`.
 */
export class CompletionCache {
  private readonly logger = rootLogger.child({ component: "CompletionCache" });
  private cache: LRUCache<string, { value: CompletionCacheValue; rebuildFlag: boolean }>;
  private options = {
    // Passed to LRUCache as `max`.
    maxCount: 10000,
    prebuildCache: {
      enabled: true,
      perCharacter: {
        // Number of leading completion lines that get per-character entries.
        lines: 1,
        // Upper bound on the number of per-character positions.
        max: 50,
      },
      perLine: {
        // Upper bound on the number of completion lines that get per-line entries.
        max: 10,
      },
      autoClosingPairCheck: {
        // Upper bound on how many unpaired opening characters are auto-closed.
        max: 3,
      },
    },
  };

  constructor() {
    this.cache = new LRUCache<string, { value: CompletionCacheValue; rebuildFlag: boolean }>({
      max: this.options.maxCount,
    });
  }

  /** Returns whether an entry is stored under `key.hash`. */
  has(key: CompletionCacheKey): boolean {
    return this.cache.has(key.hash);
  }

  /**
   * Stores `value` for `key` together with every pre-built entry derived from it.
   *
   * @param key - context the response was produced for
   * @param value - response to cache; it is not modified by this call
   */
  buildCache(key: CompletionCacheKey, value: CompletionCacheValue): void {
    this.logger.debug({ key, value }, "Starting to build cache");
    const entries = this.createCacheEntries(key, value);
    for (const entry of entries) {
      this.cache.set(entry.key.hash, { value: entry.value, rebuildFlag: entry.rebuildFlag });
    }
    this.logger.debug({ newEntries: entries.length, cacheSize: this.cache.size }, "Cache updated");
  }

  /**
   * Looks up the response cached for `key`.
   *
   * When the stored entry carries `rebuildFlag`, the cache is expanded again
   * from that entry before returning.
   *
   * @returns the cached response, or `undefined` on a miss
   */
  get(key: CompletionCacheKey): CompletionCacheValue | undefined {
    const entry = this.cache.get(key.hash);
    if (entry?.rebuildFlag) {
      this.buildCache(key, entry.value);
    }
    return entry?.value;
  }

  /**
   * Produces the root entry for `key` plus, when pre-building is enabled, the
   * per-line and per-character entries for each choice. Entries whose keys
   * share a hash are merged: their choices are concatenated and their rebuild
   * flags OR-ed. The result keeps first-seen order.
   */
  private createCacheEntries(key: CompletionCacheKey, value: CompletionCacheValue): CompletionCacheEntry[] {
    // The root entry gets its own `choices` array: the merge step below appends
    // to `choices` in place, which must not leak into the caller's response.
    const list: CompletionCacheEntry[] = [
      { key, value: { ...value, choices: [...value.choices] }, rebuildFlag: false },
    ];

    if (this.options.prebuildCache.enabled) {
      for (const choice of value.choices) {
        // Part of the choice text that lies at or after the cursor.
        const completionText = choice.text.slice(key.position - choice.replaceRange.start);

        const perLinePositions = this.getPerLinePositions(completionText);
        this.logger.trace({ completionText, perLinePositions }, "Calculate per-line cache positions");
        for (const position of perLinePositions) {
          for (const prefix of this.getPrefixVariants(completionText.slice(0, position))) {
            // Per-line entries serve only the remaining text and are rebuilt on read.
            const entry = this.createPrebuiltEntry(key, value, choice, {
              completionText,
              prefix,
              position,
              textStart: position,
              rebuildFlag: true,
            });
            this.logger.trace({ prefix, entry }, "Build per-line cache entry");
            list.push(entry);
          }
        }

        const perCharacterPositions = this.getPerCharacterPositions(completionText);
        this.logger.trace({ completionText, perCharacterPositions }, "Calculate per-character cache positions");
        for (const position of perCharacterPositions) {
          // Walk back to the start of the completion line containing `position`.
          let lineStart = position;
          while (lineStart > 0 && completionText[lineStart - 1] !== "\n") {
            lineStart--;
          }
          for (const prefix of this.getPrefixVariants(completionText.slice(0, position))) {
            // Per-character entries serve the text from the start of the current
            // completion line and replace the range from there up to the cursor.
            const entry = this.createPrebuiltEntry(key, value, choice, {
              completionText,
              prefix,
              position,
              textStart: lineStart,
              rebuildFlag: false,
            });
            this.logger.trace({ prefix, entry }, "Build per-character cache entry");
            list.push(entry);
          }
        }
      }
    }

    // Merge entries that resolve to the same hash, preserving first-seen order.
    const merged = new Map<string, CompletionCacheEntry>();
    for (const entry of list) {
      const hash = entry.key.hash;
      const found = merged.get(hash);
      if (found) {
        found.value.choices.push(...entry.value.choices);
        found.rebuildFlag = found.rebuildFlag || entry.rebuildFlag;
      } else {
        merged.set(hash, entry);
      }
    }
    return Array.from(merged.values());
  }

  /** Returns `prefix` followed by its variants with unpaired opening characters auto-closed. */
  private getPrefixVariants(prefix: string): string[] {
    return [prefix, ...this.generateAutoClosedPrefixes(prefix)];
  }

  /**
   * Builds a single pre-built entry for a context in which `prefix` has been
   * inserted at the cursor of `key`.
   *
   * @param params.completionText - choice text at or after the original cursor
   * @param params.prefix - text inserted at the cursor (may include auto-closed characters)
   * @param params.position - offset into `completionText` that the new cursor corresponds to
   * @param params.textStart - offset into `completionText` where the served text and the replace range begin
   * @param params.rebuildFlag - whether reading this entry should trigger a rebuild
   */
  private createPrebuiltEntry(
    key: CompletionCacheKey,
    value: CompletionCacheValue,
    choice: CompletionCacheValue["choices"][number],
    params: {
      completionText: string;
      prefix: string;
      position: number;
      textStart: number;
      rebuildFlag: boolean;
    },
  ): CompletionCacheEntry {
    const { completionText, prefix, position, textStart, rebuildFlag } = params;
    return {
      key: new CompletionContext({
        ...key,
        text: key.text.slice(0, key.position) + prefix + key.text.slice(key.position),
        position: key.position + position,
      }),
      value: {
        ...value,
        choices: [
          {
            index: choice.index,
            text: completionText.slice(textStart),
            replaceRange: {
              start: key.position + textStart,
              end: key.position + position,
            },
          },
        ],
      },
      rebuildFlag,
    };
  }

  // positions for every line end (before newline character) and line begin (after indent)
  private getPerLinePositions(completion: string): number[] {
    const result: number[] = [];
    const option = this.options.prebuildCache;
    const lines = splitLines(completion);
    const whitespace = /\s/;
    let offset = 0;
    // `index < lines.length - 1` to exclude the last line
    for (let index = 0; index < lines.length - 1 && index < option.perLine.max; index++) {
      offset += lines[index].length;
      result.push(offset - 1); // cache at the end of the line (before newline character)
      result.push(offset); // cache at the beginning of the next line (after newline character)
      let offsetNextLineBegin = offset;
      while (offsetNextLineBegin < completion.length && whitespace.test(completion[offsetNextLineBegin])) {
        offsetNextLineBegin++;
      }
      result.push(offsetNextLineBegin); // cache at the beginning of the next line (after indent)
    }
    return result;
  }

  // positions for every character in the leading lines
  private getPerCharacterPositions(completion: string): number[] {
    const option = this.options.prebuildCache;
    const text = splitLines(completion).slice(0, option.perCharacter.lines).join("");
    // Offsets 0 .. min(text.length, max) - 1, in ascending order.
    const count = Math.min(text.length, option.perCharacter.max);
    return Array.from({ length: count }, (_, offset) => offset);
  }

  // "function(" => ["function()"]
  // "call([" => ["call([]", "call([])" ]
  // "function(arg" => ["function(arg)" ]
  private generateAutoClosedPrefixes(prefix: string): string[] {
    const result: string[] = [];
    const unpaired = findUnpairedAutoClosingChars(prefix);
    const limit = Math.min(unpaired.length, this.options.prebuildCache.autoClosingPairCheck.max);
    let autoClosing = "";
    // Close the unpaired characters innermost-first, emitting one variant per closed character.
    for (let checkIndex = 0; checkIndex < limit; checkIndex++) {
      const found = autoClosingPairOpenings.indexOf(unpaired[unpaired.length - 1 - checkIndex]);
      if (found < 0) {
        // Not an opening character: stop extending the auto-closed suffix.
        break;
      }
      autoClosing += autoClosingPairClosings[found];
      result.push(prefix + autoClosing);
    }
    return result;
  }
}
Oops, something went wrong.