feat(llm): add caching support
Ref: #69
Signed-off-by: Tomas Dvorak <[email protected]>
Tomas2D committed Oct 7, 2024
1 parent e4ec120 commit e330fb2
Showing 16 changed files with 336 additions and 56 deletions.
34 changes: 34 additions & 0 deletions docs/cache.md
@@ -99,6 +99,40 @@ _Source: [examples/cache/toolCache.ts](/examples/cache/toolCache.ts)_
>
> The cache key is created by serializing the function's parameters (the order of keys in the object does not matter).
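
As a minimal, framework-independent sketch of that property: a key derived from sorted parameter entries cannot be affected by argument order. The `cacheKey` helper below is purely illustrative, not the framework's actual serializer.

```ts
// Illustrative only: sorting the entries makes { a: 1, b: 2 } and { b: 2, a: 1 }
// produce the same key. The framework's real serialization may differ in detail.
const cacheKey = (params: Record<string, unknown>): string =>
  JSON.stringify(Object.entries(params).sort(([a], [b]) => a.localeCompare(b)));

console.info(cacheKey({ a: 1, b: 2 }) === cacheKey({ b: 2, a: 1 })); // true
```
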
### Usage with LLMs

<!-- embedme examples/cache/llmCache.ts -->

```ts
import { SlidingCache } from "bee-agent-framework/cache/slidingCache";
import { OllamaChatLLM } from "bee-agent-framework/adapters/ollama/chat";
import { BaseMessage } from "bee-agent-framework/llms/primitives/message";

const llm = new OllamaChatLLM({
modelId: "llama3.1",
parameters: {
temperature: 0,
num_predict: 50,
},
cache: new SlidingCache({
size: 50,
}),
});

console.info(await llm.cache.size()); // 0
const first = await llm.generate([BaseMessage.of({ role: "user", text: "Who was Alan Turing?" })]);
// subsequent requests with exactly the same input will be served from the cache
console.info(await llm.cache.size()); // 1
const second = await llm.generate([BaseMessage.of({ role: "user", text: "Who was Alan Turing?" })]);
console.info(first === second); // true
```

_Source: [examples/cache/llmCache.ts](/examples/cache/llmCache.ts)_

> [!TIP]
>
> Caching for non-chat LLMs works exactly the same way.
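
For completeness, a minimal sketch of the non-chat variant; it assumes the completion adapter is exported from `bee-agent-framework/adapters/ollama/llm` and accepts a plain string prompt:

```ts
import { SlidingCache } from "bee-agent-framework/cache/slidingCache";
import { OllamaLLM } from "bee-agent-framework/adapters/ollama/llm";

const llm = new OllamaLLM({
  modelId: "llama3.1",
  cache: new SlidingCache({
    size: 50,
  }),
});

console.info(await llm.cache.size()); // 0
const first = await llm.generate("Who was Alan Turing?");
// the second call with exactly the same input is served from the cache
const second = await llm.generate("Who was Alan Turing?");
console.info(first === second); // true
```
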
## Cache types

The framework provides multiple out-of-the-box cache implementations.
21 changes: 21 additions & 0 deletions examples/cache/llmCache.ts
@@ -0,0 +1,21 @@
import { SlidingCache } from "bee-agent-framework/cache/slidingCache";
import { OllamaChatLLM } from "bee-agent-framework/adapters/ollama/chat";
import { BaseMessage } from "bee-agent-framework/llms/primitives/message";

const llm = new OllamaChatLLM({
modelId: "llama3.1",
parameters: {
temperature: 0,
num_predict: 50,
},
cache: new SlidingCache({
size: 50,
}),
});

console.info(await llm.cache.size()); // 0
const first = await llm.generate([BaseMessage.of({ role: "user", text: "Who was Alan Turing?" })]);
// subsequent requests with exactly the same input will be served from the cache
console.info(await llm.cache.size()); // 1
const second = await llm.generate([BaseMessage.of({ role: "user", text: "Who was Alan Turing?" })]);
console.info(first === second); // true
21 changes: 14 additions & 7 deletions src/adapters/bam/chat.ts
@@ -14,7 +14,13 @@
* limitations under the License.
*/

import { AsyncStream, GenerateCallbacks, LLMError } from "@/llms/base.js";
import {
AsyncStream,
GenerateCallbacks,
LLMCache,
LLMError,
StreamGenerateOptions,
} from "@/llms/base.js";
import { isFunction, isObjectType } from "remeda";
import {
BAMLLM,
@@ -88,6 +94,7 @@ export interface BAMChatLLMInputConfig {
export interface BAMChatLLMInput {
llm: BAMLLM;
config: BAMChatLLMInputConfig;
cache?: LLMCache<BAMChatLLMOutput>;
}

export class BAMChatLLM extends ChatLLM<BAMChatLLMOutput> {
@@ -99,8 +106,8 @@ export class BAMChatLLM extends ChatLLM<BAMChatLLMOutput> {
public readonly llm: BAMLLM;
protected readonly config: BAMChatLLMInputConfig;

constructor({ llm, config }: BAMChatLLMInput) {
super(llm.modelId, llm.executionOptions);
constructor({ llm, config, cache }: BAMChatLLMInput) {
super(llm.modelId, llm.executionOptions, cache);
this.llm = llm;
this.config = config;
}
@@ -130,8 +137,8 @@ export class BAMChatLLM extends ChatLLM<BAMChatLLMOutput> {

protected async _generate(
messages: BaseMessage[],
options: BAMLLMGenerateOptions,
run: GetRunContext<this>,
options: BAMLLMGenerateOptions | undefined,
run: GetRunContext<typeof this>,
): Promise<BAMChatLLMOutput> {
const prompt = this.messagesToPrompt(messages);
// @ts-expect-error protected property
@@ -141,8 +148,8 @@ export class BAMChatLLM extends ChatLLM<BAMChatLLMOutput> {

protected async *_stream(
messages: BaseMessage[],
options: BAMLLMGenerateOptions,
run: GetRunContext<this>,
options: StreamGenerateOptions | undefined,
run: GetRunContext<typeof this>,
): AsyncStream<BAMChatLLMOutput, void> {
const prompt = this.messagesToPrompt(messages);
// @ts-expect-error protected property
6 changes: 4 additions & 2 deletions src/adapters/bam/llm.ts
@@ -22,6 +22,7 @@ import {
ExecutionOptions,
GenerateCallbacks,
GenerateOptions,
LLMCache,
LLMError,
LLMMeta,
LLMOutputError,
@@ -163,6 +164,7 @@ export interface BAMLLMInput {
modelId: string;
parameters?: BAMLLMParameters;
executionOptions?: ExecutionOptions;
cache?: LLMCache<BAMLLMOutput>;
}

export class BAMLLM extends LLM<BAMLLMOutput, BAMLLMGenerateOptions> {
@@ -174,8 +176,8 @@ export class BAMLLM extends LLM<BAMLLMOutput, BAMLLMGenerateOptions> {
public readonly client: Client;
public readonly parameters: Partial<BAMLLMParameters>;

constructor({ client, parameters, modelId, executionOptions = {} }: BAMLLMInput) {
super(modelId, executionOptions);
constructor({ client, parameters, modelId, cache, executionOptions = {} }: BAMLLMInput) {
super(modelId, executionOptions, cache);
this.client = client ?? new Client();
this.parameters = parameters ?? {};
}
5 changes: 4 additions & 1 deletion src/adapters/groq/chat.ts
@@ -20,6 +20,7 @@ import {
ExecutionOptions,
GenerateCallbacks,
GenerateOptions,
LLMCache,
LLMMeta,
StreamGenerateOptions,
} from "@/llms/base.js";
@@ -87,6 +88,7 @@ interface Input {
client?: Client;
parameters?: Parameters;
executionOptions?: ExecutionOptions;
cache?: LLMCache<ChatGroqOutput>;
}

export class GroqChatLLM extends ChatLLM<ChatGroqOutput> {
@@ -105,8 +107,9 @@ export class GroqChatLLM extends ChatLLM<ChatGroqOutput> {
temperature: 0,
},
executionOptions = {},
cache,
}: Input = {}) {
super(modelId, executionOptions);
super(modelId, executionOptions, cache);
this.client = client ?? new Client();
this.parameters = parameters ?? {};
}
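
For illustration, a hedged sketch of enabling the new `cache` option on this adapter; the import path mirrors the Ollama example and the model id is an assumption (the underlying client is expected to read credentials such as GROQ_API_KEY from the environment):

```ts
import { GroqChatLLM } from "bee-agent-framework/adapters/groq/chat";
import { SlidingCache } from "bee-agent-framework/cache/slidingCache";

// Illustrative model id; construction alone does not call the API.
const llm = new GroqChatLLM({
  modelId: "llama-3.1-70b-versatile",
  cache: new SlidingCache({ size: 25 }),
});
```
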
6 changes: 4 additions & 2 deletions src/adapters/ibm-vllm/chat.ts
@@ -26,6 +26,7 @@ import {
AsyncStream,
BaseLLMTokenizeOutput,
GenerateCallbacks,
LLMCache,
LLMError,
LLMMeta,
} from "@/llms/base.js";
@@ -87,6 +88,7 @@ export interface IBMVllmInputConfig {
export interface GrpcChatLLMInput {
llm: IBMvLLM;
config: IBMVllmInputConfig;
cache?: LLMCache<GrpcChatLLMOutput>;
}

export class IBMVllmChatLLM extends ChatLLM<GrpcChatLLMOutput> {
@@ -98,8 +100,8 @@ export class IBMVllmChatLLM extends ChatLLM<GrpcChatLLMOutput> {
public readonly llm: IBMvLLM;
protected readonly config: IBMVllmInputConfig;

constructor({ llm, config }: GrpcChatLLMInput) {
super(llm.modelId, llm.executionOptions);
constructor({ llm, config, cache }: GrpcChatLLMInput) {
super(llm.modelId, llm.executionOptions, cache);
this.llm = llm;
this.config = config;
}
6 changes: 4 additions & 2 deletions src/adapters/ibm-vllm/llm.ts
@@ -21,6 +21,7 @@ import {
ExecutionOptions,
GenerateCallbacks,
GenerateOptions,
LLMCache,
LLMError,
LLMMeta,
} from "@/llms/base.js";
@@ -88,6 +89,7 @@ export interface IBMvLLMInput {
modelId: string;
parameters?: IBMvLLMParameters;
executionOptions?: ExecutionOptions;
cache?: LLMCache<IBMvLLMOutput>;
}

export type IBMvLLMParameters = NonNullable<
@@ -105,8 +107,8 @@ export class IBMvLLM extends LLM<IBMvLLMOutput, IBMvLLMGenerateOptions> {
public readonly client: Client;
public readonly parameters: Partial<IBMvLLMParameters>;

constructor({ client, modelId, parameters = {}, executionOptions }: IBMvLLMInput) {
super(modelId, executionOptions);
constructor({ client, modelId, parameters = {}, executionOptions, cache }: IBMvLLMInput) {
super(modelId, executionOptions, cache);
this.client = client ?? new Client();
this.parameters = parameters ?? {};
}
4 changes: 3 additions & 1 deletion src/adapters/langchain/llms/chat.ts
@@ -20,6 +20,7 @@ import {
ExecutionOptions,
GenerateCallbacks,
GenerateOptions,
LLMCache,
LLMMeta,
} from "@/llms/base.js";
import { shallowCopy } from "@/serializer/utils.js";
@@ -97,8 +98,9 @@ export class LangChainChatLLM<
public readonly lcLLM: BaseChatModel<CallOptions, OutputMessageType>,
protected modelMeta?: LLMMeta,
executionOptions?: ExecutionOptions,
cache?: LLMCache<LangChainChatLLMOutput>,
) {
super(lcLLM._modelType(), executionOptions);
super(lcLLM._modelType(), executionOptions, cache);
this.parameters = lcLLM.invocationParams();
}

10 changes: 6 additions & 4 deletions src/adapters/langchain/llms/llm.ts
@@ -22,7 +22,8 @@ import {
BaseLLMTokenizeOutput,
ExecutionOptions,
GenerateCallbacks,
InternalGenerateOptions,
GenerateOptions,
LLMCache,
LLMMeta,
StreamGenerateOptions,
} from "@/llms/base.js";
@@ -80,8 +81,9 @@ export class LangChainLLM extends LLM<LangChainLLMOutput> {
public readonly lcLLM: LCBaseLLM,
private modelMeta?: LLMMeta,
executionOptions?: ExecutionOptions,
cache?: LLMCache<LangChainLLMOutput>,
) {
super(lcLLM._modelType(), executionOptions);
super(lcLLM._modelType(), executionOptions, cache);
this.parameters = lcLLM.invocationParams();
}

@@ -107,7 +109,7 @@ export class LangChainLLM extends LLM<LangChainLLMOutput> {

protected async _generate(
input: LLMInput,
options: InternalGenerateOptions,
_options: GenerateOptions | undefined,
run: GetRunContext<this>,
): Promise<LangChainLLMOutput> {
const { generations } = await this.lcLLM.generate([input], {
@@ -118,7 +120,7 @@ export class LangChainLLM extends LLM<LangChainLLMOutput> {

protected async *_stream(
input: string,
options: StreamGenerateOptions,
_options: StreamGenerateOptions | undefined,
run: GetRunContext<this>,
): AsyncStream<LangChainLLMOutput> {
const response = this.lcLLM._streamResponseChunks(input, {
6 changes: 4 additions & 2 deletions src/adapters/ollama/chat.ts
@@ -20,6 +20,7 @@ import {
ExecutionOptions,
GenerateCallbacks,
GenerateOptions,
LLMCache,
LLMOutputError,
StreamGenerateOptions,
} from "@/llms/base.js";
@@ -111,6 +112,7 @@ interface Input {
client?: Client;
parameters?: Partial<Parameters>;
executionOptions?: ExecutionOptions;
cache?: LLMCache<OllamaChatLLMOutput>;
}

export class OllamaChatLLM extends ChatLLM<OllamaChatLLMOutput> {
@@ -123,11 +125,11 @@ export class OllamaChatLLM extends ChatLLM<OllamaChatLLMOutput> {
public readonly parameters: Partial<Parameters>;

constructor(
{ client, modelId, parameters, executionOptions = {} }: Input = {
{ client, modelId, parameters, executionOptions = {}, cache }: Input = {
modelId: "llama3.1",
},
) {
super(modelId, executionOptions);
super(modelId, executionOptions, cache);
this.client = client ?? new Client({ fetch });
this.parameters = parameters ?? {
temperature: 0,
6 changes: 4 additions & 2 deletions src/adapters/ollama/llm.ts
@@ -23,6 +23,7 @@ import {
ExecutionOptions,
GenerateCallbacks,
GenerateOptions,
LLMCache,
LLMMeta,
LLMOutputError,
StreamGenerateOptions,
@@ -41,6 +42,7 @@ interface Input {
client?: Client;
parameters?: Partial<Parameters>;
executionOptions?: ExecutionOptions;
cache?: LLMCache<OllamaLLMOutput>;
}

export class OllamaLLMOutput extends BaseLLMOutput {
@@ -115,8 +117,8 @@ export class OllamaLLM extends LLM<OllamaLLMOutput> {
registerClient();
}

constructor({ client, modelId, parameters, executionOptions = {} }: Input) {
super(modelId, executionOptions);
constructor({ client, modelId, parameters, executionOptions = {}, cache }: Input) {
super(modelId, executionOptions, cache);
this.client = client ?? new Client();
this.parameters = parameters ?? {};
}