Skip to content

Commit

Permalink
refactor(agent): refactor completion postprocess and caching. (#576)
Browse files Browse the repository at this point in the history
  • Loading branch information
icycodes authored Oct 17, 2023
1 parent 2060d47 commit be5e766
Show file tree
Hide file tree
Showing 23 changed files with 418 additions and 339 deletions.
2 changes: 1 addition & 1 deletion clients/intellij/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@
"devDependencies": {
"cpy-cli": "^4.2.0",
"rimraf": "^5.0.1",
"tabby-agent": "0.3.1"
"tabby-agent": "0.4.0-dev"
}
}
3 changes: 1 addition & 2 deletions clients/tabby-agent/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "tabby-agent",
"version": "0.3.1",
"version": "0.4.0-dev",
"description": "Generic client agent for Tabby AI coding assistant IDE extensions.",
"repository": "https://github.com/TabbyML/tabby",
"main": "./dist/index.js",
Expand Down Expand Up @@ -41,7 +41,6 @@
"jwt-decode": "^3.1.2",
"lru-cache": "^9.1.1",
"object-hash": "^3.0.0",
"object-sizeof": "^2.6.1",
"openapi-fetch": "^0.7.6",
"pino": "^8.14.1",
"rotating-file-stream": "^3.1.0",
Expand Down
13 changes: 3 additions & 10 deletions clients/tabby-agent/src/Agent.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import type { components as ApiComponents } from "./types/tabbyApi";
import { AgentConfig, PartialAgentConfig } from "./AgentConfig";
import { CompletionRequest, CompletionResponse, CompletionContext } from "./CompletionContext";

export { CompletionRequest, CompletionResponse, CompletionContext };

export type ClientProperties = Partial<{
user: Record<string, any>;
Expand All @@ -13,16 +16,6 @@ export type AgentInitOptions = Partial<{

export type ServerHealthState = ApiComponents["schemas"]["HealthState"];

export type CompletionRequest = {
filepath: string;
language: string;
text: string;
position: number;
manually?: boolean;
};

export type CompletionResponse = ApiComponents["schemas"]["CompletionResponse"];

export type LogEventRequest = ApiComponents["schemas"]["LogEventRequest"] & {
select_kind?: "line";
};
Expand Down
242 changes: 151 additions & 91 deletions clients/tabby-agent/src/CompletionCache.ts
Original file line number Diff line number Diff line change
@@ -1,144 +1,204 @@
import { LRUCache } from "lru-cache";
import hashObject from "object-hash";
import sizeOfObject from "object-sizeof";
import { CompletionRequest, CompletionResponse } from "./Agent";
import { CompletionContext, CompletionResponse } from "./Agent";
import { rootLogger } from "./logger";
import { splitLines, splitWords } from "./utils";
import { splitLines, autoClosingPairOpenings, autoClosingPairClosings, findUnpairedAutoClosingChars } from "./utils";

type CompletionCacheKey = CompletionRequest;
type CompletionCacheKey = CompletionContext;
type CompletionCacheValue = CompletionResponse;

export class CompletionCache {
private readonly logger = rootLogger.child({ component: "CompletionCache" });
private cache: LRUCache<string, CompletionCacheValue>;
private cache: LRUCache<string, { value: CompletionCacheValue; rebuildFlag: boolean }>;
private options = {
maxSize: 1 * 1024 * 1024, // 1MB
partiallyAcceptedCacheGeneration: {
maxCount: 10000,
prebuildCache: {
enabled: true,
perCharacter: {
lines: 1,
words: 10,
max: 30,
},
perWord: {
lines: 1,
max: 20,
max: 50,
},
perLine: {
max: 10,
},
autoClosingPairCheck: {
max: 3,
},
},
};

constructor() {
this.cache = new LRUCache<string, CompletionCacheValue>({
maxSize: this.options.maxSize,
sizeCalculation: sizeOfObject,
this.cache = new LRUCache<string, { value: CompletionCacheValue; rebuildFlag: boolean }>({
max: this.options.maxCount,
});
}

has(key: CompletionCacheKey): boolean {
return this.cache.has(this.hash(key));
return this.cache.has(key.hash);
}

set(key: CompletionCacheKey, value: CompletionCacheValue): void {
for (const entry of this.createCacheEntries(key, value)) {
this.logger.debug({ entry }, "Setting cache entry");
this.cache.set(this.hash(entry.key), entry.value);
}
this.logger.debug({ size: this.cache.calculatedSize }, "Cache size");
buildCache(key: CompletionCacheKey, value: CompletionCacheValue): void {
this.logger.debug({ key, value }, "Starting to build cache");
const entries = this.createCacheEntries(key, value);
entries.forEach((entry) => {
this.cache.set(entry.key.hash, { value: entry.value, rebuildFlag: entry.rebuildFlag });
});
this.logger.debug({ newEntries: entries.length, cacheSize: this.cache.size }, "Cache updated");
}

get(key: CompletionCacheKey): CompletionCacheValue | undefined {
return this.cache.get(this.hash(key));
}

private hash(key: CompletionCacheKey): string {
return hashObject(key);
const entry = this.cache.get(key.hash);
if (entry?.rebuildFlag) {
this.buildCache(key, entry?.value);
}
return entry?.value;
}

private createCacheEntries(
key: CompletionCacheKey,
value: CompletionCacheValue,
): { key: CompletionCacheKey; value: CompletionCacheValue }[] {
const list = [{ key, value }];
if (this.options.partiallyAcceptedCacheGeneration.enabled) {
const entries = value.choices
.map((choice) => {
return this.calculatePartiallyAcceptedPositions(choice.text).map((position) => {
return {
prefix: choice.text.slice(0, position),
suffix: choice.text.slice(position),
choiceIndex: choice.index,
): { key: CompletionCacheKey; value: CompletionCacheValue; rebuildFlag: boolean }[] {
const list = [{ key, value, rebuildFlag: false }];
if (this.options.prebuildCache.enabled) {
for (const choice of value.choices) {
const completionText = choice.text.slice(key.position - choice.replaceRange.start);
const perLinePositions = this.getPerLinePositions(completionText);
this.logger.trace({ completionText, perLinePositions }, "Calculate per-line cache positions");
for (const position of perLinePositions) {
const completionTextPrefix = completionText.slice(0, position);
const completionTextPrefixWithAutoClosedChars = this.generateAutoClosedPrefixes(completionTextPrefix);
for (const prefix of [completionTextPrefix, ...completionTextPrefixWithAutoClosedChars]) {
const entry = {
key: new CompletionContext({
...key,
text: key.text.slice(0, key.position) + prefix + key.text.slice(key.position),
position: key.position + position,
}),
value: {
...value,
choices: [
{
index: choice.index,
text: completionText.slice(position),
replaceRange: {
start: key.position + position,
end: key.position + position,
},
},
],
},
rebuildFlag: true,
};
});
})
.flat()
.reduce((grouped: { [key: string]: { suffix: string; choiceIndex: number }[] }, entry) => {
grouped[entry.prefix] = grouped[entry.prefix] || [];
grouped[entry.prefix].push({ suffix: entry.suffix, choiceIndex: entry.choiceIndex });
return grouped;
}, {});
for (const prefix in entries) {
const cacheKey = {
...key,
text: key.text.slice(0, key.position) + prefix + key.text.slice(key.position),
position: key.position + prefix.length,
};
const cacheValue = {
...value,
choices: entries[prefix].map((choice) => {
return {
index: choice.choiceIndex,
text: choice.suffix,
this.logger.trace({ prefix, entry }, "Build per-line cache entry");
list.push(entry);
}
}
const perCharacterPositions = this.getPerCharacterPositions(completionText);
this.logger.trace({ completionText, perCharacterPositions }, "Calculate per-character cache positions");
for (const position of perCharacterPositions) {
let lineStart = position;
while (lineStart > 0 && completionText[lineStart - 1] !== "\n") {
lineStart--;
}
const completionTextPrefix = completionText.slice(0, position);
const completionTextPrefixWithAutoClosedChars = this.generateAutoClosedPrefixes(completionTextPrefix);
for (const prefix of [completionTextPrefix, ...completionTextPrefixWithAutoClosedChars]) {
const entry = {
key: new CompletionContext({
...key,
text: key.text.slice(0, key.position) + prefix + key.text.slice(key.position),
position: key.position + position,
}),
value: {
...value,
choices: [
{
index: choice.index,
text: completionText.slice(lineStart),
replaceRange: {
start: key.position + lineStart,
end: key.position + position,
},
},
],
},
rebuildFlag: false,
};
}),
};
list.push({
key: cacheKey,
value: cacheValue,
});
this.logger.trace({ prefix, entry }, "Build per-character cache entry");
list.push(entry);
}
}
}
}
return list;
const result = list.reduce((prev, curr) => {
const found = prev.find((entry) => entry.key.hash === curr.key.hash);
if (found) {
found.value.choices.push(...curr.value.choices);
found.rebuildFlag = found.rebuildFlag || curr.rebuildFlag;
} else {
prev.push(curr);
}
return prev;
}, []);
return result;
}

private calculatePartiallyAcceptedPositions(completion: string): number[] {
const positions: number[] = [];
const option = this.options.partiallyAcceptedCacheGeneration;

// positions for every line end (before newline character) and line begin (after indent)
private getPerLinePositions(completion: string): number[] {
const result: number[] = [];
const option = this.options.prebuildCache;
const lines = splitLines(completion);
let index = 0;
let offset = 0;
// `index < lines.length - 1` to exclude the last line
while (index < lines.length - 1 && index < option.perLine.max) {
offset += lines[index].length;
positions.push(offset);
index++;
}

const words = lines.slice(0, option.perWord.lines).map(splitWords).flat();
index = 0;
offset = 0;
while (index < words.length && index < option.perWord.max) {
offset += words[index].length;
positions.push(offset);
result.push(offset - 1); // cache at the end of the line (before newline character)
result.push(offset); // cache at the beginning of the next line (after newline character)
let offsetNextLineBegin = offset;
while (offsetNextLineBegin < completion.length && completion[offsetNextLineBegin].match(/\s/)) {
offsetNextLineBegin++;
}
result.push(offsetNextLineBegin); // cache at the beginning of the next line (after indent)
index++;
}
return result;
}

const characters = lines
.slice(0, option.perCharacter.lines)
.map(splitWords)
.flat()
.slice(0, option.perCharacter.words)
.join("");
offset = 1;
while (offset < characters.length && offset < option.perCharacter.max) {
positions.push(offset);
// positions for every character in the leading lines
private getPerCharacterPositions(completion: string): number[] {
const result: number[] = [];
const option = this.options.prebuildCache;
const text = splitLines(completion).slice(0, option.perCharacter.lines).join("");
let offset = 0;
while (offset < text.length && offset < option.perCharacter.max) {
result.push(offset);
offset++;
}
return result;
}

// distinct and sort ascending
return positions.filter((v, i, arr) => arr.indexOf(v) === i).sort((a, b) => a - b);
// "function(" => ["function()"]
// "call([" => ["call([]", "call([])" ]
// "function(arg" => ["function(arg)" ]
private generateAutoClosedPrefixes(prefix: string): string[] {
const result: string[] = [];
const unpaired = findUnpairedAutoClosingChars(prefix);
for (
let checkIndex = 0, autoClosing = "";
checkIndex < this.options.prebuildCache.autoClosingPairCheck.max;
checkIndex++
) {
if (unpaired.length > checkIndex) {
const found = autoClosingPairOpenings.indexOf(unpaired[unpaired.length - 1 - checkIndex]);
if (found < 0) {
break;
}
autoClosing = autoClosing + autoClosingPairClosings[found];
result.push(prefix + autoClosing);
} else {
break;
}
}
return result;
}
}
Loading

0 comments on commit be5e766

Please sign in to comment.