feat(llms)!: make "options" parameter for generate/stream always partial

The "options" parameter of the generate/stream methods now falls back to an empty object to allow overrides.

Signed-off-by: Tomas Dvorak <[email protected]>
Tomas2D committed Dec 3, 2024
1 parent ff65e0c commit 20fbe71
Showing 12 changed files with 30 additions and 22 deletions.
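
To make the breaking change concrete, here is a minimal sketch of what a provider subclass sees after this commit. MyGenerateOptions, MyOutput, and MyProvider are hypothetical stand-ins, not framework types; only the Partial<...> parameter shape comes from the diff below.

// Hypothetical, simplified stand-ins for the framework's GenerateOptions and
// output types (the real ones live in "bee-agent-framework/llms/base").
interface MyGenerateOptions {
  signal?: AbortSignal;
  temperature?: number;
}

interface MyOutput {
  text: string;
}

class MyProvider {
  // Before: options was the concrete options type (or `GenerateOptions | undefined`),
  // and the base class passed `options ?? {}` behind a @ts-expect-error.
  // After: options is always an object, but every field stays optional.
  protected async _generate(input: string, options: Partial<MyGenerateOptions>): Promise<MyOutput> {
    options.signal?.throwIfAborted(); // no `options?.` guard needed any more
    const temperature = options.temperature ?? 0.7; // unset fields fall back to defaults
    return { text: `[t=${temperature}] ${input}` };
  }
}
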
examples/llms/providers/customChatProvider.ts (3 additions, 2 deletions)

@@ -5,6 +5,7 @@ import {
   GenerateOptions,
   LLMCache,
   LLMMeta,
+  StreamGenerateOptions,
 } from "bee-agent-framework/llms/base";
 import { shallowCopy } from "bee-agent-framework/serializer/utils";
 import type { GetRunContext } from "bee-agent-framework/context";
@@ -86,7 +87,7 @@ export class CustomChatLLM extends ChatLLM<CustomChatLLMOutput, CustomGenerateOp

   protected async _generate(
     input: BaseMessage[],
-    options: CustomGenerateOptions,
+    options: Partial<CustomGenerateOptions>,
     run: GetRunContext<this>,
   ): Promise<CustomChatLLMOutput> {
     // this method should do non-stream request to the API
@@ -101,7 +102,7 @@ export class CustomChatLLM extends ChatLLM<CustomChatLLMOutput, CustomGenerateOp

   protected async *_stream(
     input: BaseMessage[],
-    options: CustomGenerateOptions,
+    options: Partial<StreamGenerateOptions>,
     run: GetRunContext<this>,
   ): AsyncStream<CustomChatLLMOutput, void> {
     // this method should do stream request to the API

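The same pattern applies to the streaming side. The sketch below is illustrative only (streamSketch and its chunking are made up); the optional-signal check mirrors the dummy LLM in src/llms/base.test.ts further down this diff.

// Illustrative only: a generator that accepts the new Partial options shape.
interface MyStreamOptions {
  signal?: AbortSignal;
}

async function* streamSketch(
  input: string,
  options: Partial<MyStreamOptions> = {},
): AsyncGenerator<string, void> {
  for (const chunk of input.split(",")) {
    if (options.signal?.aborted) {
      return; // stop early when the caller aborts
    }
    yield chunk;
  }
}
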
examples/llms/providers/customProvider.ts (2 additions, 2 deletions)

@@ -86,7 +86,7 @@ export class CustomLLM extends LLM<CustomLLMOutput, CustomGenerateOptions> {

   protected async _generate(
     input: LLMInput,
-    options: CustomGenerateOptions,
+    options: Partial<CustomGenerateOptions>,
     run: GetRunContext<this>,
   ): Promise<CustomLLMOutput> {
     // this method should do non-stream request to the API
@@ -101,7 +101,7 @@ export class CustomLLM extends LLM<CustomLLMOutput, CustomGenerateOptions> {

   protected async *_stream(
     input: LLMInput,
-    options: CustomGenerateOptions,
+    options: Partial<CustomGenerateOptions>,
     run: GetRunContext<this>,
   ): AsyncStream<CustomLLMOutput, void> {
     // this method should do stream request to the API

src/adapters/bedrock/chat.ts (1 addition, 1 deletion)

@@ -208,7 +208,7 @@ export class BedrockChatLLM extends ChatLLM<ChatBedrockOutput> {

   protected async _generate(
     input: BaseMessage[],
-    _options: GenerateOptions | undefined,
+    _options: Partial<GenerateOptions>,
     run: GetRunContext<typeof this>,
   ): Promise<ChatBedrockOutput> {
     const { conversation, systemMessage } = this.convertToConverseMessages(input);

src/adapters/groq/chat.ts (1 addition, 1 deletion)

@@ -209,7 +209,7 @@ export class GroqChatLLM extends ChatLLM<ChatGroqOutput> {

   protected async *_stream(
     input: BaseMessage[],
-    options: StreamGenerateOptions,
+    options: Partial<StreamGenerateOptions>,
     run: GetRunContext<typeof this>,
   ): AsyncStream<ChatGroqOutput> {
     for await (const chunk of await this.client.chat.completions.create(

src/adapters/langchain/llms/llm.ts (1 addition, 1 deletion)

@@ -110,7 +110,7 @@ export class LangChainLLM extends LLM<LangChainLLMOutput> {

   protected async _generate(
     input: LLMInput,
-    _options: GenerateOptions | undefined,
+    _options: Partial<GenerateOptions>,
     run: GetRunContext<this>,
   ): Promise<LangChainLLMOutput> {
     const { generations } = await this.lcLLM.generate([input], {

src/adapters/ollama/chat.ts (1 addition, 1 deletion)

@@ -187,7 +187,7 @@ export class OllamaChatLLM extends ChatLLM<OllamaChatLLMOutput> {

   protected async *_stream(
     input: BaseMessage[],
-    options: StreamGenerateOptions,
+    options: Partial<StreamGenerateOptions>,
     run: GetRunContext<typeof this>,
   ): AsyncStream<OllamaChatLLMOutput> {
     for await (const chunk of await this.client.chat({

src/adapters/ollama/llm.ts (1 addition, 1 deletion)

@@ -149,7 +149,7 @@ export class OllamaLLM extends LLM<OllamaLLMOutput> {

   protected async *_stream(
     input: LLMInput,
-    options: StreamGenerateOptions,
+    options: Partial<StreamGenerateOptions>,
     run: GetRunContext<typeof this>,
   ): AsyncStream<OllamaLLMOutput, void> {
     for await (const chunk of await this.client.generate({

src/adapters/openai/chat.ts (1 addition, 1 deletion)

@@ -224,7 +224,7 @@ export class OpenAIChatLLM extends ChatLLM<OpenAIChatLLMOutput> {

   protected async _generate(
     input: BaseMessage[],
-    options: GenerateOptions | undefined,
+    options: Partial<GenerateOptions>,
     run: GetRunContext<typeof this>,
   ): Promise<OpenAIChatLLMOutput> {
     const response = await this.client.chat.completions.create(

src/adapters/vertexai/chat.ts (2 additions, 1 deletion)

@@ -21,6 +21,7 @@ import {
   GenerateOptions,
   LLMCache,
   LLMMeta,
+  StreamGenerateOptions,
 } from "@/llms/base.js";
 import { shallowCopy } from "@/serializer/utils.js";
 import type { GetRunContext } from "@/context.js";
@@ -137,7 +138,7 @@ export class VertexAIChatLLM extends ChatLLM<VertexAIChatLLMOutput> {

   protected async *_stream(
     input: BaseMessage[],
-    options: GenerateOptions | undefined,
+    options: Partial<StreamGenerateOptions>,
     run: GetRunContext<this>,
   ): AsyncStream<VertexAIChatLLMOutput, void> {
     const generativeModel = createModel(

src/adapters/vertexai/llm.ts (2 additions, 1 deletion)

@@ -23,6 +23,7 @@ import {
   GenerateOptions,
   LLMCache,
   LLMMeta,
+  StreamGenerateOptions,
 } from "@/llms/base.js";
 import { shallowCopy } from "@/serializer/utils.js";
 import type { GetRunContext } from "@/context.js";
@@ -134,7 +135,7 @@ export class VertexAILLM extends LLM<VertexAILLMOutput> {

   protected async *_stream(
     input: LLMInput,
-    options: GenerateOptions | undefined,
+    options: Partial<StreamGenerateOptions>,
     run: GetRunContext<this>,
   ): AsyncStream<VertexAILLMOutput, void> {
     const generativeModel = createModel(

src/llms/base.test.ts (3 additions, 3 deletions)

@@ -16,12 +16,12 @@

 import {
   BaseLLMTokenizeOutput,
-  StreamGenerateOptions,
   AsyncStream,
   BaseLLMOutput,
   GenerateOptions,
   BaseLLM,
   BaseLLMEvents,
+  StreamGenerateOptions,
 } from "./base.js";
 import { Emitter } from "@/emitter/emitter.js";
 import { UnconstrainedCache } from "@/cache/unconstrainedCache.js";
@@ -74,7 +74,7 @@ describe("BaseLLM", () => {

     protected async _generate(
       input: string,
-      options: GenerateOptions | undefined,
+      options: Partial<GenerateOptions>,
     ): Promise<DummyOutput> {
       options?.signal?.throwIfAborted();
       await setTimeout(200);
@@ -87,7 +87,7 @@ describe("BaseLLM", () => {

     protected async *_stream(
       input: string,
-      options: StreamGenerateOptions | undefined,
+      options: Partial<StreamGenerateOptions>,
     ): AsyncStream<DummyOutput, void> {
       for (const chunk of input.split(",")) {
         if (options?.signal?.aborted) {

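From the caller's perspective, the point of "always partial" is that only the fields being overridden need to be supplied, even when a provider's full options type has required members. The snippet below is a sketch: `llm` stands in for any provider instance created elsewhere, and FullOptions is a made-up options type used only to show the Partial<...> call shape.

// Simplified stand-ins: a provider whose full options type has a required
// field, plus the new caller-facing signatures with Partial<...>.
interface FullOptions {
  temperature: number; // required in the full type
  signal?: AbortSignal;
}

declare const llm: {
  generate(input: string, options?: Partial<FullOptions>): Promise<unknown>;
  stream(input: string, options?: Partial<FullOptions>): AsyncIterable<unknown>;
};

async function demo() {
  // Omitting options entirely still works...
  console.log(await llm.generate("Hello!"));

  // ...and so does overriding just one field, even though FullOptions
  // declares `temperature` as required.
  console.log(await llm.generate("Hello!", { temperature: 0.1 }));

  // Streaming accepts the same partial overrides, e.g. an abort signal.
  const controller = new AbortController();
  setTimeout(() => controller.abort(), 5_000);
  for await (const chunk of llm.stream("1,2,3", { signal: controller.signal })) {
    console.log(chunk);
  }
}
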
src/llms/base.ts (12 additions, 7 deletions)

@@ -144,7 +144,10 @@ export abstract class BaseLLM<

   abstract tokenize(input: TInput): Promise<BaseLLMTokenizeOutput>;

-  generate(input: TInput, options?: TGenerateOptions) {
+  generate(input: TInput, options: Partial<TGenerateOptions> = {}) {
+    input = shallowCopy(input);
+    options = shallowCopy(options);
+
     return RunContext.enter(
       this,
       { params: [input, options] as const, signal: options?.signal },
@@ -187,8 +190,7 @@

         const result: TOutput =
           cacheEntry?.value?.at(0) ||
-          // @ts-expect-error types
-          (await pRetry(() => this._generate(input, options ?? {}, run), {
+          (await pRetry(() => this._generate(input, options, run), {
             retries: this.executionOptions.maxRetries || 0,
             ...options,
             signal: run.signal,
@@ -211,7 +213,10 @@
     ).middleware(INSTRUMENTATION_ENABLED ? createTelemetryMiddleware() : doNothing());
   }

-  async *stream(input: TInput, options?: StreamGenerateOptions): AsyncStream<TOutput> {
+  async *stream(input: TInput, options: Partial<StreamGenerateOptions> = {}): AsyncStream<TOutput> {
+    input = shallowCopy(input);
+    options = shallowCopy(options);
+
     return yield* emitterToGenerator(async ({ emit }) => {
       return RunContext.enter(
         this,
@@ -232,13 +237,13 @@

   protected abstract _generate(
     input: TInput,
-    options: TGenerateOptions,
+    options: Partial<TGenerateOptions>,
     run: GetRunContext<typeof this>,
   ): Promise<TOutput>;

   protected abstract _stream(
     input: TInput,
-    options: StreamGenerateOptions,
+    options: Partial<StreamGenerateOptions>,
     run: GetRunContext<typeof this>,
   ): AsyncStream<TOutput, void>;

@@ -279,7 +284,7 @@

   protected async createCacheAccessor(
     input: TInput,
-    options: GenerateOptions | StreamGenerateOptions | undefined,
+    options: Partial<GenerateOptions> | Partial<StreamGenerateOptions>,
     ...extra: any[]
   ) {
     const key = ObjectHashKeyFn(input, omit(options ?? {}, ["signal"]), ...extra);

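Two details of the base-class change are worth spelling out: options now defaults to an empty object (so the `options ?? {}` fallback and the @ts-expect-error disappear), and both arguments are shallow-copied before the run starts, so a caller that keeps mutating its options object cannot affect an in-flight call. The sketch below is a stand-alone illustration; its shallowCopy is a local stand-in for the helper in bee-agent-framework/serializer/utils, and generateSketch is not a framework API.

// Local stand-in for the framework's shallowCopy helper (the real one lives in
// "bee-agent-framework/serializer/utils" and covers more cases).
function shallowCopy<T>(value: T): T {
  if (Array.isArray(value)) {
    return value.slice() as T;
  }
  if (value !== null && typeof value === "object") {
    return { ...value } as T;
  }
  return value;
}

interface SketchOptions {
  signal?: AbortSignal;
  temperature?: number;
}

// Mirrors the new shape: `options: Partial<...> = {}` plus defensive copies.
async function generateSketch(input: string, options: Partial<SketchOptions> = {}) {
  input = shallowCopy(input);
  options = shallowCopy(options); // later caller mutations cannot leak in here
  await new Promise((resolve) => setTimeout(resolve, 100)); // pretend to call an API
  return { input, temperature: options.temperature ?? 1 };
}

// The caller may keep mutating its own object; the copied snapshot is used.
const shared: Partial<SketchOptions> = { temperature: 0.2 };
const pending = generateSketch("hello", shared);
shared.temperature = 0.9; // does not change the snapshot inside `pending`
void pending;
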
