From e330fb246409296af41dce9f1bb56cde72cac556 Mon Sep 17 00:00:00 2001 From: Tomas Dvorak Date: Mon, 7 Oct 2024 21:59:15 +0200 Subject: [PATCH] feat(llm): add caching support Ref: #69 Signed-off-by: Tomas Dvorak --- docs/cache.md | 34 +++++++ examples/cache/llmCache.ts | 21 ++++ src/adapters/bam/chat.ts | 21 ++-- src/adapters/bam/llm.ts | 6 +- src/adapters/groq/chat.ts | 5 +- src/adapters/ibm-vllm/chat.ts | 6 +- src/adapters/ibm-vllm/llm.ts | 6 +- src/adapters/langchain/llms/chat.ts | 4 +- src/adapters/langchain/llms/llm.ts | 10 +- src/adapters/ollama/chat.ts | 6 +- src/adapters/ollama/llm.ts | 6 +- src/adapters/openai/chat.ts | 18 +++- src/adapters/watsonx/chat.ts | 11 ++- src/adapters/watsonx/llm.ts | 8 +- src/llms/base.test.ts | 148 +++++++++++++++++++++++++++- src/llms/base.ts | 82 ++++++++++++--- 16 files changed, 336 insertions(+), 56 deletions(-) create mode 100644 examples/cache/llmCache.ts diff --git a/docs/cache.md b/docs/cache.md index 331f128b..e8eb34cf 100644 --- a/docs/cache.md +++ b/docs/cache.md @@ -99,6 +99,40 @@ _Source: [examples/cache/toolCache.ts](/examples/cache/toolCache.ts)_ > > Cache key is created by serializing function parameters (the order of keys in the object does not matter). +### Usage with LLMs + + + +```ts +import { SlidingCache } from "bee-agent-framework/cache/slidingCache"; +import { OllamaChatLLM } from "bee-agent-framework/adapters/ollama/chat"; +import { BaseMessage } from "bee-agent-framework/llms/primitives/message"; + +const llm = new OllamaChatLLM({ + modelId: "llama3.1", + parameters: { + temperature: 0, + num_predict: 50, + }, + cache: new SlidingCache({ + size: 50, + }), +}); + +console.info(await llm.cache.size()); // 0 +const first = await llm.generate([BaseMessage.of({ role: "user", text: "Who was Alan Turing?" })]); +// subsequent requests with the exact same input will be served from the cache +console.info(await llm.cache.size()); // 1 +const second = await llm.generate([BaseMessage.of({ role: "user", text: "Who was Alan Turing?" })]); +console.info(first === second); // true +``` + +_Source: [examples/cache/llmCache.ts](/examples/cache/llmCache.ts)_ + +> [!TIP] +> +> Caching for non-chat LLMs works exactly the same way. + ## Cache types The framework provides multiple out-of-the-box cache implementations. diff --git a/examples/cache/llmCache.ts b/examples/cache/llmCache.ts new file mode 100644 index 00000000..312c560d --- /dev/null +++ b/examples/cache/llmCache.ts @@ -0,0 +1,21 @@ +import { SlidingCache } from "bee-agent-framework/cache/slidingCache"; +import { OllamaChatLLM } from "bee-agent-framework/adapters/ollama/chat"; +import { BaseMessage } from "bee-agent-framework/llms/primitives/message"; + +const llm = new OllamaChatLLM({ + modelId: "llama3.1", + parameters: { + temperature: 0, + num_predict: 50, + }, + cache: new SlidingCache({ + size: 50, + }), +}); + +console.info(await llm.cache.size()); // 0 +const first = await llm.generate([BaseMessage.of({ role: "user", text: "Who was Alan Turing?" })]); +// subsequent requests with the exact same input will be served from the cache +console.info(await llm.cache.size()); // 1 +const second = await llm.generate([BaseMessage.of({ role: "user", text: "Who was Alan Turing?" })]); +console.info(first === second); // true diff --git a/src/adapters/bam/chat.ts b/src/adapters/bam/chat.ts index b15f32da..395ba974 100644 --- a/src/adapters/bam/chat.ts +++ b/src/adapters/bam/chat.ts @@ -14,7 +14,13 @@ * limitations under the License. 
*/ -import { AsyncStream, GenerateCallbacks, LLMError } from "@/llms/base.js"; +import { + AsyncStream, + GenerateCallbacks, + LLMCache, + LLMError, + StreamGenerateOptions, +} from "@/llms/base.js"; import { isFunction, isObjectType } from "remeda"; import { BAMLLM, @@ -88,6 +94,7 @@ export interface BAMChatLLMInputConfig { export interface BAMChatLLMInput { llm: BAMLLM; config: BAMChatLLMInputConfig; + cache?: LLMCache; } export class BAMChatLLM extends ChatLLM { @@ -99,8 +106,8 @@ export class BAMChatLLM extends ChatLLM { public readonly llm: BAMLLM; protected readonly config: BAMChatLLMInputConfig; - constructor({ llm, config }: BAMChatLLMInput) { - super(llm.modelId, llm.executionOptions); + constructor({ llm, config, cache }: BAMChatLLMInput) { + super(llm.modelId, llm.executionOptions, cache); this.llm = llm; this.config = config; } @@ -130,8 +137,8 @@ export class BAMChatLLM extends ChatLLM { protected async _generate( messages: BaseMessage[], - options: BAMLLMGenerateOptions, - run: GetRunContext, + options: BAMLLMGenerateOptions | undefined, + run: GetRunContext, ): Promise { const prompt = this.messagesToPrompt(messages); // @ts-expect-error protected property @@ -141,8 +148,8 @@ export class BAMChatLLM extends ChatLLM { protected async *_stream( messages: BaseMessage[], - options: BAMLLMGenerateOptions, - run: GetRunContext, + options: StreamGenerateOptions | undefined, + run: GetRunContext, ): AsyncStream { const prompt = this.messagesToPrompt(messages); // @ts-expect-error protected property diff --git a/src/adapters/bam/llm.ts b/src/adapters/bam/llm.ts index e5a23b14..0f3d169e 100644 --- a/src/adapters/bam/llm.ts +++ b/src/adapters/bam/llm.ts @@ -22,6 +22,7 @@ import { ExecutionOptions, GenerateCallbacks, GenerateOptions, + LLMCache, LLMError, LLMMeta, LLMOutputError, @@ -163,6 +164,7 @@ export interface BAMLLMInput { modelId: string; parameters?: BAMLLMParameters; executionOptions?: ExecutionOptions; + cache?: LLMCache; } export class BAMLLM extends LLM { @@ -174,8 +176,8 @@ export class BAMLLM extends LLM { public readonly client: Client; public readonly parameters: Partial; - constructor({ client, parameters, modelId, executionOptions = {} }: BAMLLMInput) { - super(modelId, executionOptions); + constructor({ client, parameters, modelId, cache, executionOptions = {} }: BAMLLMInput) { + super(modelId, executionOptions, cache); this.client = client ?? new Client(); this.parameters = parameters ?? {}; } diff --git a/src/adapters/groq/chat.ts b/src/adapters/groq/chat.ts index 84ddea47..f0ef6b08 100644 --- a/src/adapters/groq/chat.ts +++ b/src/adapters/groq/chat.ts @@ -20,6 +20,7 @@ import { ExecutionOptions, GenerateCallbacks, GenerateOptions, + LLMCache, LLMMeta, StreamGenerateOptions, } from "@/llms/base.js"; @@ -87,6 +88,7 @@ interface Input { client?: Client; parameters?: Parameters; executionOptions?: ExecutionOptions; + cache?: LLMCache; } export class GroqChatLLM extends ChatLLM { @@ -105,8 +107,9 @@ export class GroqChatLLM extends ChatLLM { temperature: 0, }, executionOptions = {}, + cache, }: Input = {}) { - super(modelId, executionOptions); + super(modelId, executionOptions, cache); this.client = client ?? new Client(); this.parameters = parameters ?? 
{}; } diff --git a/src/adapters/ibm-vllm/chat.ts b/src/adapters/ibm-vllm/chat.ts index cec38521..a9a6229f 100644 --- a/src/adapters/ibm-vllm/chat.ts +++ b/src/adapters/ibm-vllm/chat.ts @@ -26,6 +26,7 @@ import { AsyncStream, BaseLLMTokenizeOutput, GenerateCallbacks, + LLMCache, LLMError, LLMMeta, } from "@/llms/base.js"; @@ -87,6 +88,7 @@ export interface IBMVllmInputConfig { export interface GrpcChatLLMInput { llm: IBMvLLM; config: IBMVllmInputConfig; + cache?: LLMCache; } export class IBMVllmChatLLM extends ChatLLM { @@ -98,8 +100,8 @@ export class IBMVllmChatLLM extends ChatLLM { public readonly llm: IBMvLLM; protected readonly config: IBMVllmInputConfig; - constructor({ llm, config }: GrpcChatLLMInput) { - super(llm.modelId, llm.executionOptions); + constructor({ llm, config, cache }: GrpcChatLLMInput) { + super(llm.modelId, llm.executionOptions, cache); this.llm = llm; this.config = config; } diff --git a/src/adapters/ibm-vllm/llm.ts b/src/adapters/ibm-vllm/llm.ts index 9f3c5485..58d53dd3 100644 --- a/src/adapters/ibm-vllm/llm.ts +++ b/src/adapters/ibm-vllm/llm.ts @@ -21,6 +21,7 @@ import { ExecutionOptions, GenerateCallbacks, GenerateOptions, + LLMCache, LLMError, LLMMeta, } from "@/llms/base.js"; @@ -88,6 +89,7 @@ export interface IBMvLLMInput { modelId: string; parameters?: IBMvLLMParameters; executionOptions?: ExecutionOptions; + cache?: LLMCache; } export type IBMvLLMParameters = NonNullable< @@ -105,8 +107,8 @@ export class IBMvLLM extends LLM { public readonly client: Client; public readonly parameters: Partial; - constructor({ client, modelId, parameters = {}, executionOptions }: IBMvLLMInput) { - super(modelId, executionOptions); + constructor({ client, modelId, parameters = {}, executionOptions, cache }: IBMvLLMInput) { + super(modelId, executionOptions, cache); this.client = client ?? new Client(); this.parameters = parameters ?? 
{}; } diff --git a/src/adapters/langchain/llms/chat.ts b/src/adapters/langchain/llms/chat.ts index 4396fd1f..b908a9fd 100644 --- a/src/adapters/langchain/llms/chat.ts +++ b/src/adapters/langchain/llms/chat.ts @@ -20,6 +20,7 @@ import { ExecutionOptions, GenerateCallbacks, GenerateOptions, + LLMCache, LLMMeta, } from "@/llms/base.js"; import { shallowCopy } from "@/serializer/utils.js"; @@ -97,8 +98,9 @@ export class LangChainChatLLM< public readonly lcLLM: BaseChatModel, protected modelMeta?: LLMMeta, executionOptions?: ExecutionOptions, + cache?: LLMCache, ) { - super(lcLLM._modelType(), executionOptions); + super(lcLLM._modelType(), executionOptions, cache); this.parameters = lcLLM.invocationParams(); } diff --git a/src/adapters/langchain/llms/llm.ts b/src/adapters/langchain/llms/llm.ts index 8188a4e6..d90e6d41 100644 --- a/src/adapters/langchain/llms/llm.ts +++ b/src/adapters/langchain/llms/llm.ts @@ -22,7 +22,8 @@ import { BaseLLMTokenizeOutput, ExecutionOptions, GenerateCallbacks, - InternalGenerateOptions, + GenerateOptions, + LLMCache, LLMMeta, StreamGenerateOptions, } from "@/llms/base.js"; @@ -80,8 +81,9 @@ export class LangChainLLM extends LLM { public readonly lcLLM: LCBaseLLM, private modelMeta?: LLMMeta, executionOptions?: ExecutionOptions, + cache?: LLMCache, ) { - super(lcLLM._modelType(), executionOptions); + super(lcLLM._modelType(), executionOptions, cache); this.parameters = lcLLM.invocationParams(); } @@ -107,7 +109,7 @@ export class LangChainLLM extends LLM { protected async _generate( input: LLMInput, - options: InternalGenerateOptions, + _options: GenerateOptions | undefined, run: GetRunContext, ): Promise { const { generations } = await this.lcLLM.generate([input], { @@ -118,7 +120,7 @@ export class LangChainLLM extends LLM { protected async *_stream( input: string, - options: StreamGenerateOptions, + _options: StreamGenerateOptions | undefined, run: GetRunContext, ): AsyncStream { const response = this.lcLLM._streamResponseChunks(input, { diff --git a/src/adapters/ollama/chat.ts b/src/adapters/ollama/chat.ts index 5950fd35..eed77d1d 100644 --- a/src/adapters/ollama/chat.ts +++ b/src/adapters/ollama/chat.ts @@ -20,6 +20,7 @@ import { ExecutionOptions, GenerateCallbacks, GenerateOptions, + LLMCache, LLMOutputError, StreamGenerateOptions, } from "@/llms/base.js"; @@ -111,6 +112,7 @@ interface Input { client?: Client; parameters?: Partial; executionOptions?: ExecutionOptions; + cache?: LLMCache; } export class OllamaChatLLM extends ChatLLM { @@ -123,11 +125,11 @@ export class OllamaChatLLM extends ChatLLM { public readonly parameters: Partial; constructor( - { client, modelId, parameters, executionOptions = {} }: Input = { + { client, modelId, parameters, executionOptions = {}, cache }: Input = { modelId: "llama3.1", }, ) { - super(modelId, executionOptions); + super(modelId, executionOptions, cache); this.client = client ?? new Client({ fetch }); this.parameters = parameters ?? 
{ temperature: 0, diff --git a/src/adapters/ollama/llm.ts b/src/adapters/ollama/llm.ts index 3691f352..cfc97a61 100644 --- a/src/adapters/ollama/llm.ts +++ b/src/adapters/ollama/llm.ts @@ -23,6 +23,7 @@ import { ExecutionOptions, GenerateCallbacks, GenerateOptions, + LLMCache, LLMMeta, LLMOutputError, StreamGenerateOptions, @@ -41,6 +42,7 @@ interface Input { client?: Client; parameters?: Partial; executionOptions?: ExecutionOptions; + cache?: LLMCache; } export class OllamaLLMOutput extends BaseLLMOutput { @@ -115,8 +117,8 @@ export class OllamaLLM extends LLM { registerClient(); } - constructor({ client, modelId, parameters, executionOptions = {} }: Input) { - super(modelId, executionOptions); + constructor({ client, modelId, parameters, executionOptions = {}, cache }: Input) { + super(modelId, executionOptions, cache); this.client = client ?? new Client(); this.parameters = parameters ?? {}; } diff --git a/src/adapters/openai/chat.ts b/src/adapters/openai/chat.ts index 3f4aa763..cfcdc4cf 100644 --- a/src/adapters/openai/chat.ts +++ b/src/adapters/openai/chat.ts @@ -20,6 +20,7 @@ import { ExecutionOptions, GenerateCallbacks, GenerateOptions, + LLMCache, LLMMeta, StreamGenerateOptions, } from "@/llms/base.js"; @@ -88,6 +89,7 @@ interface Input { client?: Client; parameters?: Partial; executionOptions?: ExecutionOptions; + cache?: LLMCache; } export class OpenAIChatLLM extends ChatLLM { @@ -99,8 +101,14 @@ export class OpenAIChatLLM extends ChatLLM { public readonly client: Client; public readonly parameters: Partial; - constructor({ client, modelId = "gpt-4o", parameters, executionOptions = {} }: Input = {}) { - super(modelId, executionOptions); + constructor({ + client, + modelId = "gpt-4o", + parameters, + executionOptions = {}, + cache, + }: Input = {}) { + super(modelId, executionOptions, cache); this.client = client ?? new Client(); this.parameters = parameters ?? {}; } @@ -154,7 +162,7 @@ export class OpenAIChatLLM extends ChatLLM { protected _prepareRequest( input: BaseMessage[], - options: GenerateOptions, + options?: GenerateOptions, ): Client.Chat.ChatCompletionCreateParams { return { ...this.parameters, @@ -181,7 +189,7 @@ export class OpenAIChatLLM extends ChatLLM { protected async _generate( input: BaseMessage[], - options: GenerateOptions, + options: GenerateOptions | undefined, run: GetRunContext, ): Promise { const response = await this.client.chat.completions.create( @@ -215,7 +223,7 @@ export class OpenAIChatLLM extends ChatLLM { protected async *_stream( input: BaseMessage[], - options: StreamGenerateOptions, + options: StreamGenerateOptions | undefined, run: GetRunContext, ): AsyncStream { for await (const chunk of await this.client.chat.completions.create( diff --git a/src/adapters/watsonx/chat.ts b/src/adapters/watsonx/chat.ts index f6c92aa3..ee6cc43c 100644 --- a/src/adapters/watsonx/chat.ts +++ b/src/adapters/watsonx/chat.ts @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -import { AsyncStream, GenerateCallbacks, LLMError } from "@/llms/base.js"; +import { AsyncStream, GenerateCallbacks, LLMCache, LLMError } from "@/llms/base.js"; import { WatsonXLLM, WatsonXLLMGenerateOptions, @@ -84,6 +84,7 @@ export interface WatsonXChatLLMInputConfig { export interface WatsonXChatLLMInput { llm: WatsonXLLM; config: WatsonXChatLLMInputConfig; + cache?: LLMCache; } export class WatsonXChatLLM extends ChatLLM { @@ -96,8 +97,8 @@ export class WatsonXChatLLM extends ChatLLM { protected readonly config: WatsonXChatLLMInputConfig; public readonly parameters: WatsonXLLMParameters; - constructor({ llm, config }: WatsonXChatLLMInput) { - super(llm.modelId, llm.executionOptions); + constructor({ llm, config, cache }: WatsonXChatLLMInput) { + super(llm.modelId, llm.executionOptions, cache); this.parameters = llm.parameters ?? {}; this.llm = llm; this.config = config; @@ -133,7 +134,7 @@ export class WatsonXChatLLM extends ChatLLM { protected async _generate( messages: BaseMessage[], - options: WatsonXLLMGenerateOptions, + options: WatsonXLLMGenerateOptions | undefined, run: GetRunContext, ): Promise { const prompt = this.messagesToPrompt(messages); @@ -144,7 +145,7 @@ export class WatsonXChatLLM extends ChatLLM { protected async *_stream( messages: BaseMessage[], - options: WatsonXLLMGenerateOptions, + options: WatsonXLLMGenerateOptions | undefined, run: GetRunContext, ): AsyncStream { const prompt = this.messagesToPrompt(messages); diff --git a/src/adapters/watsonx/llm.ts b/src/adapters/watsonx/llm.ts index 5717b436..f51cdb70 100644 --- a/src/adapters/watsonx/llm.ts +++ b/src/adapters/watsonx/llm.ts @@ -22,6 +22,7 @@ import { ExecutionOptions, GenerateCallbacks, GenerateOptions, + LLMCache, LLMError, LLMFatalError, LLMMeta, @@ -179,6 +180,7 @@ export interface WatsonXLLMInput { moderations?: WatsonXLLMModerations; executionOptions?: ExecutionOptions; transform?: WatsonXLLMTransformFn; + cache?: LLMCache; } type WatsonXLLMTransformFn = (body: Record) => Record; @@ -283,7 +285,7 @@ export class WatsonXLLM extends LLM public readonly parameters: WatsonXLLMParameters; constructor(input: WatsonXLLMInput) { - super(input.modelId, input.executionOptions); + super(input.modelId, input.executionOptions, input.cache); this.projectId = input.projectId; this.spaceId = input.spaceId; this.deploymentId = input.deploymentId; @@ -380,7 +382,7 @@ export class WatsonXLLM extends LLM protected async _generate( input: LLMInput, - options: WatsonXLLMGenerateOptions, + options: WatsonXLLMGenerateOptions | undefined, run: GetRunContext, ): Promise { try { @@ -408,7 +410,7 @@ export class WatsonXLLM extends LLM protected async *_stream( input: LLMInput, - options: WatsonXLLMGenerateOptions, + options: WatsonXLLMGenerateOptions | undefined, run: GetRunContext, ): AsyncStream { try { diff --git a/src/llms/base.test.ts b/src/llms/base.test.ts index 9ada4d7e..758d6fe1 100644 --- a/src/llms/base.test.ts +++ b/src/llms/base.test.ts @@ -24,6 +24,9 @@ import { GenerateCallbacks, } from "./base.js"; import { Emitter } from "@/emitter/emitter.js"; +import { UnconstrainedCache } from "@/cache/unconstrainedCache.js"; +import { setTimeout } from "node:timers/promises"; +import { verifyDeserialization } from "@tests/e2e/utils.js"; describe("BaseLLM", () => { class DummyOutput extends BaseLLMOutput { @@ -53,6 +56,8 @@ describe("BaseLLM", () => { } class DummyLLM extends BaseLLM { + public throwErrorCount = 0; + public readonly emitter = Emitter.root.child({ namespace: ["dummy", "llm"], creator: this, @@ -67,20 
+72,33 @@ describe("BaseLLM", () => { throw new Error("Method not implemented."); } - // eslint-disable-next-line unused-imports/no-unused-vars - protected _generate(input: string, options?: GenerateOptions): Promise { - throw new Error("Method not implemented."); + protected async _generate( + input: string, + options: GenerateOptions | undefined, + ): Promise { + options?.signal?.throwIfAborted(); + await setTimeout(200); + if (this.throwErrorCount > 0) { + this.throwErrorCount--; + throw new Error("Error has occurred"); + } + return new DummyOutput(input); } + protected async *_stream( input: string, - options?: StreamGenerateOptions, + options: StreamGenerateOptions | undefined, ): AsyncStream { for (const chunk of input.split(",")) { options?.signal?.throwIfAborted(); - await new Promise((resolve) => setTimeout(resolve, 100)); + await setTimeout(100); yield new DummyOutput(chunk); } } + + createSnapshot() { + return { ...super.createSnapshot(), throwErrorCount: this.throwErrorCount }; + } } it("Stops generating", async () => { @@ -103,4 +121,124 @@ describe("BaseLLM", () => { ); expect(chunks.join(",")).toBe("1,2,3"); }); + + describe("Caching", () => { + let model: DummyLLM; + beforeEach(() => { + model = new DummyLLM("my-model", {}, new UnconstrainedCache()); + }); + + const generate = async (input: unknown[], options?: GenerateOptions) => { + const chunks: string[] = []; + const events: string[] = []; + + await model.generate(input.join(","), options).observe((emitter) => { + emitter.registerCallbacks({ + newToken: ({ value }) => { + chunks.push(value.getTextContent()); + }, + }); + emitter.match("*.*", (_, event) => { + events.push(event.path); + }); + }); + + return { chunks, events }; + }; + + it("Handles streaming", async () => { + const [a, b] = await Promise.all([ + generate([1, 2, 3], { + stream: true, + }), + generate([1, 2, 3], { + stream: true, + }), + ]); + expect(a).toEqual(b); + await expect(model.cache.size()).resolves.toBe(1); + }); + + it("Handles non-streaming", async () => { + const [c, d] = await Promise.all([ + generate([1, 2, 3], { + stream: false, + }), + generate([1, 2, 3], { + stream: false, + }), + ]); + expect(c).toEqual(d); + await expect(model.cache.size()).resolves.toBe(1); + }); + + it("Correctly generates cache keys", async () => { + await expect(model.cache.size()).resolves.toBe(0); + + await generate(["a"]); + await expect(model.cache.size()).resolves.toBe(1); + + await generate(["a"], {}); + await expect(model.cache.size()).resolves.toBe(1); + await generate(["a"], { signal: AbortSignal.timeout(1000) }); + await expect(model.cache.size()).resolves.toBe(1); + + await generate(["a"], { stream: false }); + await expect(model.cache.size()).resolves.toBe(2); + + await generate(["a"], { stream: true }); + await expect(model.cache.size()).resolves.toBe(3); + await generate(["a"], { signal: AbortSignal.timeout(1500), stream: true }); + await expect(model.cache.size()).resolves.toBe(3); + + await generate(["a"], { guided: { regex: /.+/.source } }); + await expect(model.cache.size()).resolves.toBe(4); + await generate(["a"], { guided: { regex: /.+/.source } }); + await expect(model.cache.size()).resolves.toBe(4); + }); + + it("Clears cache", async () => { + await generate(["a"]); + await expect(model.cache.size()).resolves.toBe(1); + await model.cache.clear(); + await expect(model.cache.size()).resolves.toBe(0); + }); + + it("Ignores rejected values", async () => { + vi.useRealTimers(); + + model.throwErrorCount = 1; + for (const promise of await 
Promise.allSettled([ + setTimeout(0, model.generate("Test")), + setTimeout(0, model.generate("Test")), + setTimeout(0, model.generate("Test")), + ])) { + expect(promise).property("status").to.eq("rejected"); + } + await expect(model.cache.size()).resolves.toBe(0); + + await expect(model.generate("Test")).resolves.toBeTruthy(); + await expect(model.cache.size()).resolves.toBe(1); + }); + + it("Serializes with non-empty cache", async () => { + await expect(model.generate("Test")).resolves.toBeTruthy(); + await expect(model.cache.size()).resolves.toBe(1); + + const serialized = model.serialize(); + const deserialized = DummyLLM.fromSerialized(serialized); + verifyDeserialization(model, deserialized); + + await expect(deserialized.cache.size()).resolves.toBe(1); + await expect(deserialized.generate("Test")).resolves.toBeTruthy(); + await expect(deserialized.cache.size()).resolves.toBe(1); + }); + }); + + it("Serializes", () => { + const model = new DummyLLM("my-model"); + const serialized = model.serialize(); + const deserialized = DummyLLM.fromSerialized(serialized); + verifyDeserialization(model, deserialized); + }); }); diff --git a/src/llms/base.ts b/src/llms/base.ts index 7b82154b..6f870467 100644 --- a/src/llms/base.ts +++ b/src/llms/base.ts @@ -24,6 +24,11 @@ import { Callback } from "@/emitter/types.js"; import { shallowCopy } from "@/serializer/utils.js"; import { pRetry } from "@/internals/helpers/retry.js"; import { emitterToGenerator } from "@/internals/helpers/promise.js"; +import { BaseCache } from "@/cache/base.js"; +import { NullCache } from "@/cache/nullCache.js"; +import { ObjectHashKeyFn } from "@/cache/decoratorCache.js"; +import { omit } from "remeda"; +import { Task } from "promise-based-task"; export interface GenerateCallbacks { newToken?: Callback<{ value: BaseLLMOutput; callbacks: { abort: () => void } }>; @@ -111,6 +116,8 @@ export interface LLMMeta { tokenLimit: number; } +export type LLMCache = BaseCache>; + export abstract class BaseLLM< TInput, TOutput extends BaseLLMOutput, @@ -121,6 +128,7 @@ export abstract class BaseLLM< constructor( public readonly modelId: string, public readonly executionOptions: ExecutionOptions = {}, + public readonly cache: LLMCache = new NullCache(), ) { super(); } @@ -134,6 +142,8 @@ export abstract class BaseLLM< this, { params: [input, options] as const, signal: options?.signal }, async (run) => { + const cacheEntry = await this.createCacheAccessor(input, options); + try { await run.emitter.emit("start", { input, options }); @@ -142,14 +152,15 @@ export abstract class BaseLLM< const controller = createAbortController(options?.signal); const tokenEmitter = run.emitter.child({ groupId: "tokens" }); - for await (const chunk of this._stream( - input, - { - ...options, - signal: controller.signal, - }, - run, - )) { + for await (const chunk of cacheEntry.value ?? + this._stream( + input, + { + ...options, + signal: controller.signal, + }, + run, + )) { chunks.push(chunk); await tokenEmitter.emit("newToken", { value: chunk, @@ -162,19 +173,24 @@ export abstract class BaseLLM< const result = this._mergeChunks(chunks); await run.emitter.emit("success", { value: result }); + cacheEntry.resolve(chunks); return result; } - // @ts-expect-error types - const result = await pRetry(() => this._generate(input, options ?? 
{}, run), { - retries: this.executionOptions.maxRetries || 0, - ...options, - signal: run.signal, - }); + const result: TOutput = + cacheEntry?.value?.at(0) || + // @ts-expect-error types + (await pRetry(() => this._generate(input, options ?? {}, run), { + retries: this.executionOptions.maxRetries || 0, + ...options, + signal: run.signal, + })); await run.emitter.emit("success", { value: result }); + cacheEntry.resolve([result]); return result; } catch (error) { await run.emitter.emit("error", { input, error, options }); + await cacheEntry.reject(error); if (error instanceof LLMError) { throw error; } else { @@ -193,9 +209,14 @@ export abstract class BaseLLM< this, { params: [input, options] as const, signal: options?.signal }, async (run) => { - for await (const token of this._stream(input, options ?? {}, run)) { + const cacheEntry = await this.createCacheAccessor(input, options); + + const tokens: TOutput[] = []; + for await (const token of cacheEntry.value || this._stream(input, options ?? {}, run)) { + tokens.push(token); emit(token); } + cacheEntry.resolve(tokens); }, ); }); @@ -240,12 +261,43 @@ export abstract class BaseLLM< modelId: this.modelId, executionOptions: shallowCopy(this.executionOptions), emitter: this.emitter, + cache: this.cache, }; } loadSnapshot(snapshot: ReturnType) { Object.assign(this, snapshot); } + + protected async createCacheAccessor( + input: TInput, + options: GenerateOptions | StreamGenerateOptions | undefined, + ...extra: any[] + ) { + const key = ObjectHashKeyFn(input, omit(options ?? {}, ["signal"]), ...extra); + const value = await this.cache.get(key); + const isNew = value === undefined; + + let task: Task | null = null; + if (isNew) { + task = new Task(); + await this.cache.set(key, task); + } + + return { + key, + value, + resolve: (value: T2 | T2[]) => { + task?.resolve?.(Array.isArray(value) ? value : [value]); + }, + reject: async (error: Error) => { + task?.reject?.(error); + if (isNew) { + await this.cache.delete(key); + } + }, + }; + } } export type AnyLLM = BaseLLM;
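
Below is a brief usage sketch (not part of the patch) backing up the docs tip that caching for non-chat LLMs behaves the same way. The `OllamaLLM` import path and the plain-string input to `generate` are assumptions inferred from the adapter and `LLMInput` signatures in this diff; the cache-key behaviour (the `signal` option being excluded) follows `createCacheAccessor` above.

```ts
import { SlidingCache } from "bee-agent-framework/cache/slidingCache";
// Assumed import path, mirroring the chat adapter path used in docs/cache.md.
import { OllamaLLM } from "bee-agent-framework/adapters/ollama/llm";

const llm = new OllamaLLM({
  modelId: "llama3.1",
  // Any LLMCache implementation can be supplied; omitting `cache` falls back to NullCache (no caching).
  cache: new SlidingCache({ size: 10 }),
});

console.info(await llm.cache.size()); // 0
const first = await llm.generate("Who was Alan Turing?");
console.info(await llm.cache.size()); // 1
// An identical input with identical options hits the cache (`signal` is excluded from the cache key).
const second = await llm.generate("Who was Alan Turing?");
console.info(first === second); // true
```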