diff --git a/examples/llms/providers/customChatProvider.ts b/examples/llms/providers/customChatProvider.ts
index e31fb059..9e5ae543 100644
--- a/examples/llms/providers/customChatProvider.ts
+++ b/examples/llms/providers/customChatProvider.ts
@@ -5,6 +5,7 @@ import {
   GenerateOptions,
   LLMCache,
   LLMMeta,
+  StreamGenerateOptions,
 } from "bee-agent-framework/llms/base";
 import { shallowCopy } from "bee-agent-framework/serializer/utils";
 import type { GetRunContext } from "bee-agent-framework/context";
@@ -86,7 +87,7 @@ export class CustomChatLLM extends ChatLLM {
   protected async _generate(
     input: BaseMessage[],
-    options: GenerateOptions,
+    options: Partial<GenerateOptions>,
     run: GetRunContext<typeof this>,
   ): Promise {
     // this method should do non-stream request to the API
@@ -101,7 +102,7 @@ export class CustomChatLLM extends ChatLLM {
   protected async *_stream(
     input: BaseMessage[],
-    options: GenerateOptions,
+    options: Partial<StreamGenerateOptions>,
     run: GetRunContext<typeof this>,
   ): AsyncStream {
     // this method should do stream request to the API
diff --git a/examples/llms/providers/customProvider.ts b/examples/llms/providers/customProvider.ts
index a8b5ea0e..7aff13d7 100644
--- a/examples/llms/providers/customProvider.ts
+++ b/examples/llms/providers/customProvider.ts
@@ -86,7 +86,7 @@ export class CustomLLM extends LLM {
   protected async _generate(
     input: LLMInput,
-    options: CustomGenerateOptions,
+    options: Partial<CustomGenerateOptions>,
     run: GetRunContext<typeof this>,
   ): Promise {
     // this method should do non-stream request to the API
@@ -101,7 +101,7 @@ export class CustomLLM extends LLM {
   protected async *_stream(
     input: LLMInput,
-    options: CustomGenerateOptions,
+    options: Partial<CustomGenerateOptions>,
     run: GetRunContext<typeof this>,
   ): AsyncStream {
     // this method should do stream request to the API
diff --git a/src/adapters/bedrock/chat.ts b/src/adapters/bedrock/chat.ts
index 38518215..baddd98e 100644
--- a/src/adapters/bedrock/chat.ts
+++ b/src/adapters/bedrock/chat.ts
@@ -208,7 +208,7 @@ export class BedrockChatLLM extends ChatLLM {
   protected async _generate(
     input: BaseMessage[],
-    _options: GenerateOptions | undefined,
+    _options: Partial<GenerateOptions>,
     run: GetRunContext<typeof this>,
   ): Promise {
     const { conversation, systemMessage } = this.convertToConverseMessages(input);
diff --git a/src/adapters/groq/chat.ts b/src/adapters/groq/chat.ts
index 5521cd88..2d43ac46 100644
--- a/src/adapters/groq/chat.ts
+++ b/src/adapters/groq/chat.ts
@@ -209,7 +209,7 @@ export class GroqChatLLM extends ChatLLM {
   protected async *_stream(
     input: BaseMessage[],
-    options: StreamGenerateOptions,
+    options: Partial<StreamGenerateOptions>,
     run: GetRunContext<typeof this>,
   ): AsyncStream {
     for await (const chunk of await this.client.chat.completions.create(
diff --git a/src/adapters/langchain/llms/llm.ts b/src/adapters/langchain/llms/llm.ts
index 45676913..2ee3520b 100644
--- a/src/adapters/langchain/llms/llm.ts
+++ b/src/adapters/langchain/llms/llm.ts
@@ -110,7 +110,7 @@ export class LangChainLLM extends LLM {
   protected async _generate(
     input: LLMInput,
-    _options: GenerateOptions | undefined,
+    _options: Partial<GenerateOptions>,
     run: GetRunContext<typeof this>,
   ): Promise {
     const { generations } = await this.lcLLM.generate([input], {
diff --git a/src/adapters/ollama/chat.ts b/src/adapters/ollama/chat.ts
index 61687385..dd4e3395 100644
--- a/src/adapters/ollama/chat.ts
+++ b/src/adapters/ollama/chat.ts
@@ -187,7 +187,7 @@ export class OllamaChatLLM extends ChatLLM {
   protected async *_stream(
     input: BaseMessage[],
-    options: StreamGenerateOptions,
+    options: Partial<StreamGenerateOptions>,
     run: GetRunContext<typeof this>,
   ): AsyncStream {
     for await (const chunk of await this.client.chat({
diff --git a/src/adapters/ollama/llm.ts b/src/adapters/ollama/llm.ts
index c3afb438..3b490ce0 100644
--- a/src/adapters/ollama/llm.ts
+++ b/src/adapters/ollama/llm.ts
@@ -149,7 +149,7 @@ export class OllamaLLM extends LLM {
   protected async *_stream(
     input: LLMInput,
-    options: StreamGenerateOptions,
+    options: Partial<StreamGenerateOptions>,
     run: GetRunContext<typeof this>,
   ): AsyncStream {
     for await (const chunk of await this.client.generate({
diff --git a/src/adapters/openai/chat.ts b/src/adapters/openai/chat.ts
index 89ba298e..3d9363d7 100644
--- a/src/adapters/openai/chat.ts
+++ b/src/adapters/openai/chat.ts
@@ -224,7 +224,7 @@ export class OpenAIChatLLM extends ChatLLM {
   protected async _generate(
     input: BaseMessage[],
-    options: GenerateOptions | undefined,
+    options: Partial<GenerateOptions>,
     run: GetRunContext<typeof this>,
   ): Promise {
     const response = await this.client.chat.completions.create(
diff --git a/src/adapters/vertexai/chat.ts b/src/adapters/vertexai/chat.ts
index ab978800..428fcb75 100644
--- a/src/adapters/vertexai/chat.ts
+++ b/src/adapters/vertexai/chat.ts
@@ -21,6 +21,7 @@ import {
   GenerateOptions,
   LLMCache,
   LLMMeta,
+  StreamGenerateOptions,
 } from "@/llms/base.js";
 import { shallowCopy } from "@/serializer/utils.js";
 import type { GetRunContext } from "@/context.js";
@@ -137,7 +138,7 @@ export class VertexAIChatLLM extends ChatLLM {
   protected async *_stream(
     input: BaseMessage[],
-    options: GenerateOptions | undefined,
+    options: Partial<StreamGenerateOptions>,
     run: GetRunContext<typeof this>,
   ): AsyncStream {
     const generativeModel = createModel(
diff --git a/src/adapters/vertexai/llm.ts b/src/adapters/vertexai/llm.ts
index 3f8d842a..679be7f0 100644
--- a/src/adapters/vertexai/llm.ts
+++ b/src/adapters/vertexai/llm.ts
@@ -23,6 +23,7 @@ import {
   GenerateOptions,
   LLMCache,
   LLMMeta,
+  StreamGenerateOptions,
 } from "@/llms/base.js";
 import { shallowCopy } from "@/serializer/utils.js";
 import type { GetRunContext } from "@/context.js";
@@ -134,7 +135,7 @@ export class VertexAILLM extends LLM {
   protected async *_stream(
     input: LLMInput,
-    options: GenerateOptions | undefined,
+    options: Partial<StreamGenerateOptions>,
     run: GetRunContext<typeof this>,
   ): AsyncStream {
     const generativeModel = createModel(
diff --git a/src/llms/base.test.ts b/src/llms/base.test.ts
index 7fcdcabd..1d1bbbae 100644
--- a/src/llms/base.test.ts
+++ b/src/llms/base.test.ts
@@ -16,12 +16,12 @@ import {
   BaseLLMTokenizeOutput,
-  StreamGenerateOptions,
   AsyncStream,
   BaseLLMOutput,
   GenerateOptions,
   BaseLLM,
   BaseLLMEvents,
+  StreamGenerateOptions,
 } from "./base.js";
 import { Emitter } from "@/emitter/emitter.js";
 import { UnconstrainedCache } from "@/cache/unconstrainedCache.js";
@@ -74,7 +74,7 @@ describe("BaseLLM", () => {
     protected async _generate(
       input: string,
-      options: GenerateOptions | undefined,
+      options: Partial<GenerateOptions>,
     ): Promise {
       options?.signal?.throwIfAborted();
       await setTimeout(200);
@@ -87,7 +87,7 @@ describe("BaseLLM", () => {
     protected async *_stream(
       input: string,
-      options: StreamGenerateOptions | undefined,
+      options: Partial<StreamGenerateOptions>,
     ): AsyncStream {
       for (const chunk of input.split(",")) {
         if (options?.signal?.aborted) {
diff --git a/src/llms/base.ts b/src/llms/base.ts
index 4561e122..f4c1322c 100644
--- a/src/llms/base.ts
+++ b/src/llms/base.ts
@@ -144,7 +144,10 @@ export abstract class BaseLLM<
 
   abstract tokenize(input: TInput): Promise<BaseLLMTokenizeOutput>;
 
-  generate(input: TInput, options?: TGenerateOptions) {
+  generate(input: TInput, options: Partial<TGenerateOptions> = {}) {
+    input = shallowCopy(input);
+    options = shallowCopy(options);
+
     return RunContext.enter(
       this,
       { params: [input, options] as const, signal: options?.signal },
@@ -187,8 +190,7 @@ export abstract class BaseLLM<
         const result: TOutput =
           cacheEntry?.value?.at(0) ||
-          // @ts-expect-error types
-          (await pRetry(() => this._generate(input, options ?? {}, run), {
+          (await pRetry(() => this._generate(input, options, run), {
            retries: this.executionOptions.maxRetries || 0,
            ...options,
            signal: run.signal,
@@ -211,7 +213,10 @@ export abstract class BaseLLM<
     ).middleware(INSTRUMENTATION_ENABLED ? createTelemetryMiddleware() : doNothing());
   }
 
-  async *stream(input: TInput, options?: StreamGenerateOptions): AsyncStream<TOutput, void> {
+  async *stream(input: TInput, options: Partial<StreamGenerateOptions> = {}): AsyncStream<TOutput, void> {
+    input = shallowCopy(input);
+    options = shallowCopy(options);
+
     return yield* emitterToGenerator(async ({ emit }) => {
       return RunContext.enter(
         this,
@@ -232,13 +237,13 @@ export abstract class BaseLLM<
 
   protected abstract _generate(
     input: TInput,
-    options: TGenerateOptions,
+    options: Partial<TGenerateOptions>,
     run: GetRunContext<typeof this>,
   ): Promise<TOutput>;
 
   protected abstract _stream(
     input: TInput,
-    options: StreamGenerateOptions,
+    options: Partial<StreamGenerateOptions>,
     run: GetRunContext<typeof this>,
   ): AsyncStream<TOutput, void>;
 
@@ -279,7 +284,7 @@ export abstract class BaseLLM<
 
   protected async createCacheAccessor(
     input: TInput,
-    options: GenerateOptions | StreamGenerateOptions | undefined,
+    options: Partial<GenerateOptions> | Partial<StreamGenerateOptions>,
     ...extra: any[]
   ) {
     const key = ObjectHashKeyFn(input, omit(options ?? {}, ["signal"]), ...extra);
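Note (not part of the diff): a minimal standalone sketch of what the `Partial<...>` change buys callers and subclass authors. The `GenerateOptions`/`CustomGenerateOptions` interfaces below are local stand-ins, not the framework's real exports; they only illustrate that every option field becomes optional, so `generate(input)` and `generate(input, {})` both type-check while implementations read fields defensively. The added `shallowCopy(input)` / `shallowCopy(options)` calls in `base.ts` additionally keep a caller-held options object from being mutated once the run has started.

```ts
// Local stand-ins for illustration only (assumed shapes, not the framework types).
interface GenerateOptions {
  signal?: AbortSignal;
  stream?: boolean;
}

interface CustomGenerateOptions extends GenerateOptions {
  temperature: number; // a provider-specific field that used to be effectively required
}

// With Partial<...> and a `{}` default, callers may omit options entirely,
// and the implementation guards every field it reads.
function generate(input: string, options: Partial<CustomGenerateOptions> = {}) {
  options.signal?.throwIfAborted();
  return `echo(${input}) temperature=${options.temperature ?? 1}`;
}

console.log(generate("hello"));                       // no options at all
console.log(generate("hello", { temperature: 0.2 })); // only the fields you care about
```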