From 5635923037a35ad3f9561f30a0e7d274a5985012 Mon Sep 17 00:00:00 2001
From: Vyacheslav Matyukhin
Date: Sat, 28 Sep 2024 02:27:55 -0400
Subject: [PATCH 01/11] serializer code for workflows

---
 packages/ai/package.json                      |   1 +
 packages/ai/src/Artifact.ts                   |  48 ++-
 packages/ai/src/Code.ts                       |   8 +-
 packages/ai/src/LLMClient.ts                  |   4 +-
 packages/ai/src/LLMStep.ts                    | 318 --------
 packages/ai/src/LLMStepInstance.ts            | 389 ++++++++++++++++++
 packages/ai/src/LLMStepTemplate.ts            |  77 ++++
 packages/ai/src/Logger.ts                     |  14 +-
 packages/ai/src/generateSummary.ts            |   2 +-
 packages/ai/src/index.ts                      |  14 +-
 packages/ai/src/serialization.ts              |  54 +++
 packages/ai/src/steps/adjustToFeedbackStep.ts |   2 +-
 .../ai/src/steps/fixCodeUntilItRunsStep.ts    |   2 +-
 packages/ai/src/steps/generateCodeStep.ts     |   2 +-
 packages/ai/src/steps/registry.ts             |  20 +
 packages/ai/src/steps/runAndFormatCodeStep.ts |   2 +-
 packages/ai/src/types.ts                      |  36 +-
 .../ai/src/workflows/ControlledWorkflow.ts    |   6 +-
 packages/ai/src/workflows/Workflow.ts         | 108 +++--
 packages/ai/src/workflows/streaming.ts        |  18 +-
 packages/hub/src/app/ai/AiDashboard.tsx       |   6 +-
 packages/hub/src/app/ai/Sidebar.tsx           |   6 +-
 packages/hub/src/app/ai/StepStatusIcon.tsx    |   4 +-
 .../hub/src/app/ai/WorkflowStatusIcon.tsx     |   4 +-
 .../hub/src/app/ai/WorkflowSummaryItem.tsx    |   4 +-
 .../hub/src/app/ai/WorkflowSummaryList.tsx    |   6 +-
 .../app/ai/WorkflowViewer/ArtifactDisplay.tsx |   8 +-
 .../ai/WorkflowViewer/ArtifactMessages.tsx    |   6 +-
 .../hub/src/app/ai/WorkflowViewer/Header.tsx  |   4 +-
 .../WorkflowViewer/SelectedNodeSideView.tsx   |   8 +-
 .../src/app/ai/WorkflowViewer/StepNode.tsx    |   4 +-
 .../app/ai/WorkflowViewer/WorkflowActions.tsx |   6 +-
 .../hub/src/app/ai/WorkflowViewer/index.tsx   |   6 +-
 packages/hub/src/app/ai/api/create/route.ts   |   4 +-
 packages/hub/src/app/ai/page.tsx              |   9 +-
 .../hub/src/app/ai/useSquiggleWorkflows.tsx   |  13 +-
 pnpm-lock.yaml                                |   3 +
 37 files changed, 769 insertions(+), 457 deletions(-)
 delete mode 100644 packages/ai/src/LLMStep.ts
 create mode 100644 packages/ai/src/LLMStepInstance.ts
 create mode 100644 packages/ai/src/LLMStepTemplate.ts
 create mode 100644 packages/ai/src/serialization.ts
 create mode 100644 packages/ai/src/steps/registry.ts

diff --git a/packages/ai/package.json b/packages/ai/package.json
index aa4e277c0b..7864746d39 100644
--- a/packages/ai/package.json
+++ b/packages/ai/package.json
@@ -24,6 +24,7 @@
     "@quri/prettier-plugin-squiggle": "workspace:*",
     "@quri/squiggle-lang": "workspace:*",
     "@quri/versioned-squiggle-components": "workspace:*",
+    "@quri/serializer": "workspace:*",
     "axios": "^1.7.2",
     "chalk": "^5.3.0",
     "clsx": "^2.1.1",
diff --git a/packages/ai/src/Artifact.ts b/packages/ai/src/Artifact.ts
index 10dced6ad7..926408f7dd 100644
--- a/packages/ai/src/Artifact.ts
+++ b/packages/ai/src/Artifact.ts
@@ -1,5 +1,5 @@
 import { Code } from "./Code.js";
-import { LLMStepInstance } from "./LLMStep.js";
+import { LLMStepInstance } from "./LLMStepInstance.js";
 
 export class BaseArtifact {
   public readonly id: string;
@@ -78,3 +78,49 @@ export function makeArtifact(
     throw kind satisfies never;
   }
 }
+
+type ArtifactKindToSerialized<K extends ArtifactKind> =
+  K extends Artifact["kind"]
+    ? {
+        id: string;
+        kind: K;
+        value: ArtifactValue<K>;
+      }
+    : never;
+
+export type SerializedArtifact = ArtifactKindToSerialized<ArtifactKind>;
+
+export function serializeArtifact(artifact: Artifact): SerializedArtifact {
+  switch (artifact.kind) {
+    case "prompt":
+    case "source":
+      return {
+        id: artifact.id,
+        kind: artifact.kind,
+        value: artifact.value,
+      };
+    case "code":
+      // copy-pasted but type-safe
+      return {
+        id: artifact.id,
+        kind: artifact.kind,
+        value: artifact.value,
+      };
+    default:
+      throw artifact satisfies never;
+  }
+}
+
+export function deserializeArtifact(serialized: SerializedArtifact): Artifact {
+  switch (serialized.kind) {
+    case "prompt":
+      // TODO - pass in createdBy somehow
+      return new PromptArtifact(serialized.value, undefined);
+    case "source":
+      return new SourceArtifact(serialized.value, undefined);
+    case "code":
+      return new CodeArtifact(serialized.value, undefined);
+    default:
+      throw serialized satisfies never;
+  }
+}
diff --git a/packages/ai/src/Code.ts b/packages/ai/src/Code.ts
index 2a9a7001ce..0d05ed676b 100644
--- a/packages/ai/src/Code.ts
+++ b/packages/ai/src/Code.ts
@@ -5,7 +5,6 @@ import {
   result,
   simpleValueFromAny,
   simpleValueToCompactString,
-  SqError,
   SqErrorList,
   SqProject,
 } from "@quri/squiggle-lang";
@@ -27,7 +26,7 @@ export type Code =
       error: string;
       source: string;
     }
-  | { type: "runFailed"; source: string; error: SqError; project: SqProject }
+  | { type: "runFailed"; source: string; error: string }
   | {
       type: "success";
       source: string;
@@ -41,7 +40,7 @@ export function codeErrorString(code: Code): string {
   if (code.type === "formattingFailed") {
     return code.error;
   } else if (code.type === "runFailed") {
-    return code.error.toStringWithDetails();
+    return code.error;
   }
   return "";
 }
@@ -99,8 +98,7 @@ export async function codeStringToCode(code: string): Promise<Code> {
     return {
       type: "runFailed",
       source: runningCode,
-      error: run.value.error.errors[0],
-      project: run.value.project,
+      error: run.value.error.errors[0].toStringWithDetails(),
     };
   }
diff --git a/packages/ai/src/LLMClient.ts b/packages/ai/src/LLMClient.ts
index ffbc7093ef..36bf8860c0 100644
--- a/packages/ai/src/LLMClient.ts
+++ b/packages/ai/src/LLMClient.ts
@@ -85,12 +85,12 @@ function convertOpenAIToStandardFormat(
   };
 }
 
-export interface LlmMetrics {
+export type LlmMetrics = {
   apiCalls: number;
   inputTokens: number;
   outputTokens: number;
   llmId: LlmId;
-}
+};
 
 export function calculatePriceMultipleCalls(
   metrics: Partial<Record<LlmId, LlmMetrics>>
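For orientation, a rough sketch of how the new Artifact.ts helpers are meant to round-trip a value. The constructor call mirrors the one in deserializeArtifact above; the trailing undefined stands in for the createdBy step, which is not serialized yet (see the TODO in that diff). The prompt text is invented for illustration:

    import {
      deserializeArtifact,
      PromptArtifact,
      serializeArtifact,
    } from "./Artifact.js";

    // Serialize into a plain, JSON-compatible object...
    const prompt = new PromptArtifact("Estimate the cost of a project", undefined);
    const serialized = serializeArtifact(prompt); // { id, kind: "prompt", value }

    // ...and restore it later; createdBy is dropped for now.
    const restored = deserializeArtifact(serialized);
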
diff --git a/packages/ai/src/LLMStep.ts b/packages/ai/src/LLMStep.ts
deleted file mode 100644
index 05137b91ff..0000000000
--- a/packages/ai/src/LLMStep.ts
+++ /dev/null
@@ -1,318 +0,0 @@
-import {
-  Artifact,
-  ArtifactKind,
-  BaseArtifact,
-  makeArtifact,
-} from "./Artifact.js";
-import {
-  calculatePriceMultipleCalls,
-  LlmMetrics,
-  Message,
-} from "./LLMClient.js";
-import { LogEntry, Logger, TimestampedLogEntry } from "./Logger.js";
-import { LlmId } from "./modelConfigs.js";
-import { PromptPair } from "./prompts.js";
-import { Workflow } from "./workflows/Workflow.js";
-
-export type ErrorType = "CRITICAL" | "MINOR";
-
-export type StepState =
-  | {
-      kind: "PENDING";
-    }
-  | {
-      kind: "DONE";
-      durationMs: number;
-    }
-  | {
-      kind: "FAILED";
-      errorType: ErrorType;
-      durationMs: number;
-      message: string;
-    };
-
-export type StepShape<
-  I extends Record<string, ArtifactKind> = Record<string, ArtifactKind>,
-  O extends Record<string, ArtifactKind> = Record<string, ArtifactKind>,
-> = {
-  inputs: I;
-  outputs: O;
-};
-
-type ExecuteContext<Shape extends StepShape> = {
-  setOutput<K extends keyof Outputs<Shape>>(
-    key: K,
-    value: Outputs<Shape>[K] | Outputs<Shape>[K]["value"] // can be either the artifact or the value inside the artifact
-  ): void;
-  queryLLM(promptPair: PromptPair): Promise<string | null>;
-  log(log: LogEntry): void;
-  fail(errorType: ErrorType, message: string): void;
-  // workflow: Workflow; // intentionally not exposed, but if you need it, add it here
-};
-
-export type Inputs<Shape extends StepShape> = {
-  [K in keyof Shape["inputs"]]: Extract<Artifact, { kind: Shape["inputs"][K] }>;
-};
-
-type Outputs<Shape extends StepShape> = {
-  [K in keyof Shape["outputs"]]: Extract<
-    Artifact,
-    { kind: Shape["outputs"][K] }
-  >;
-};
-
-export class LLMStepTemplate<Shape extends StepShape = StepShape> {
-  constructor(
-    public readonly name: string,
-    public readonly shape: Shape,
-    public readonly execute: (
-      context: ExecuteContext<Shape>,
-      inputs: Inputs<Shape>
-    ) => Promise<void>
-  ) {}
-
-  instantiate(
-    workflow: Workflow,
-    inputs: Inputs<Shape>,
-    retryingStep?: LLMStepInstance<Shape> | undefined
-  ): LLMStepInstance<Shape> {
-    return new LLMStepInstance(this, workflow, inputs, retryingStep);
-  }
-}
-
-export class LLMStepInstance<Shape extends StepShape = StepShape> {
-  public id: string;
-  private logger: Logger;
-  private conversationMessages: Message[] = [];
-  public llmMetricsList: LlmMetrics[] = [];
-  private startTime: number;
-  private state: StepState = { kind: "PENDING" };
-  private outputs: Partial<Outputs<Shape>> = {};
-
-  constructor(
-    public readonly template: LLMStepTemplate<Shape>,
-    public readonly workflow: Workflow,
-    public readonly inputs: Inputs<Shape>,
-    public retryingStep?: LLMStepInstance<Shape> | undefined
-  ) {
-    this.startTime = Date.now();
-    this.id = crypto.randomUUID();
-    this.logger = new Logger(this.workflow.id, this.workflow.getStepCount());
-    this.inputs = inputs;
-  }
-
-  getLogs(): TimestampedLogEntry[] {
-    return this.logger.logs;
-  }
-
-  isRetrying(): boolean {
-    return !!this.retryingStep;
-  }
-
-  getConversationMessages(): Message[] {
-    return this.conversationMessages;
-  }
-
-  async _run() {
-    if (this.state.kind !== "PENDING") {
-      return;
-    }
-
-    const limits = this.workflow.checkResourceLimits();
-    if (limits) {
-      this.fail("CRITICAL", limits);
-      return;
-    }
-
-    const executeContext: ExecuteContext<Shape> = {
-      setOutput: (key, value) => this.setOutput(key, value),
-      log: (log) => this.log(log),
-      queryLLM: (promptPair) => this.queryLLM(promptPair),
-      fail: (errorType, message) => this.fail(errorType, message),
-    };
-
-    try {
-      await this.template.execute(executeContext, this.inputs);
-    } catch (error) {
-      this.fail(
-        "MINOR",
-        error instanceof Error ? error.message : String(error)
-      );
-      return;
-    }
-
-    const hasFailed = (this.state as StepState).kind === "FAILED";
-
-    if (!hasFailed) {
-      this.state = { kind: "DONE", durationMs: this.calculateDuration() };
-    }
-  }
-
-  async run() {
-    this.log({
-      type: "info",
-      message: `Step "${this.template.name}" started`,
-    });
-
-    await this._run();
-
-    const completionMessage = `Step "${this.template.name}" completed with status: ${this.state.kind}${
-      this.state.kind !== "PENDING" && `, in ${this.state.durationMs / 1000}s`
-    }`;
-
-    this.log({
-      type: "info",
-      message: completionMessage,
-    });
-  }
-
-  getState() {
-    return this.state;
-  }
-
-  getDuration() {
-    return this.state.kind === "PENDING" ? 0 : this.state.durationMs;
-  }
-
-  getOutputs() {
-    return this.outputs;
-  }
-
-  getInputs() {
-    return this.inputs;
-  }
-
-  getTotalCost() {
-    const totalCost = calculatePriceMultipleCalls(
-      this.llmMetricsList.reduce(
-        (acc, metrics) => {
-          acc[metrics.llmId] = metrics;
-          return acc;
-        },
-        {} as Record<LlmId, LlmMetrics>
-      )
-    );
-
-    return totalCost;
-  }
-
-  isGenerationStep() {
-    const stepHasCodeInput = Object.values(this.template.shape.inputs).some(
-      (kind) => kind === "source" || kind === "code"
-    );
-
-    const stepHasCodeOutput = Object.values(this.template.shape.outputs).some(
-      (kind) => kind === "source" || kind === "code"
-    );
-
-    return !stepHasCodeInput && stepHasCodeOutput;
-  }
-
-  isDone() {
-    return this.state.kind === "DONE";
-  }
-
-  // private methods
-
-  private setOutput<K extends keyof Outputs<Shape>>(
-    key: K,
-    value: Outputs<Shape>[K] | Outputs<Shape>[K]["value"]
-  ): void {
-    if (key in this.outputs) {
-      this.fail(
-        "CRITICAL",
-        `Output ${key} is already set. This is a bug with the workflow code.`
-      );
-      return;
-    }
-
-    if (value instanceof BaseArtifact) {
-      // already existing artifact - probably passed through from another step
-      this.outputs[key] = value;
-    } else {
-      const kind = this.template.shape.outputs[
-        key
-      ] as Outputs<Shape>[K]["kind"];
-      this.outputs[key] = makeArtifact(kind, value as any, this) as any;
-    }
-  }
-
-  private log(log: LogEntry): void {
-    this.logger.log(log);
-  }
-
-  private fail(errorType: ErrorType, message: string) {
-    this.log({ type: "error", message });
-    this.state = {
-      kind: "FAILED",
-      durationMs: this.calculateDuration(),
-      errorType,
-      message,
-    };
-  }
-
-  private calculateDuration() {
-    return Date.now() - this.startTime;
-  }
-
-  private addConversationMessage(message: Message): void {
-    this.conversationMessages.push(message);
-  }
-
-  private async queryLLM(promptPair: PromptPair): Promise<string | null> {
-    try {
-      const workflow = this.workflow;
-      const messagesToSend: Message[] = [
-        ...workflow.getRelevantPreviousConversationMessages(
-          workflow.llmConfig.messagesInHistoryToKeep
-        ),
-        {
-          role: "user",
-          content: promptPair.fullPrompt,
-        },
-      ];
-      const completion = await workflow.llmClient.run(messagesToSend);
-
-      this.log({
-        type: "llmResponse",
-        response: completion,
-        content: completion.content,
-        messages: messagesToSend,
-        prompt: promptPair.fullPrompt,
-      });
-
-      this.llmMetricsList.push({
-        apiCalls: 1,
-        inputTokens: completion?.usage?.prompt_tokens ?? 0,
-        outputTokens: completion?.usage?.completion_tokens ?? 0,
-        llmId: workflow.llmConfig.llmId,
-      });
-
-      if (!completion?.content) {
-        this.log({
-          type: "error",
-          message: "Received an empty response from the API",
-        });
-        return null;
-      } else {
-        this.addConversationMessage({
-          role: "user",
-          content: promptPair.summarizedPrompt,
-        });
-
-        this.addConversationMessage({
-          role: "assistant",
-          content: completion?.content ?? "no response",
-        });
-      }
-
-      return completion.content;
-    } catch (error) {
-      this.fail(
-        "MINOR",
-        `Error in queryLLM: ${error instanceof Error ? error.message : error}`
-      );
-      return null;
-    }
-  }
-}
diff --git a/packages/ai/src/LLMStepInstance.ts b/packages/ai/src/LLMStepInstance.ts
new file mode 100644
index 0000000000..90901f61fe
--- /dev/null
+++ b/packages/ai/src/LLMStepInstance.ts
@@ -0,0 +1,389 @@
+import { BaseArtifact, makeArtifact } from "./Artifact.js";
+import {
+  calculatePriceMultipleCalls,
+  LlmMetrics,
+  Message,
+} from "./LLMClient.js";
+import {
+  ErrorType,
+  ExecuteContext,
+  Inputs,
+  LLMStepTemplate,
+  Outputs,
+  StepShape,
+  StepState,
+} from "./LLMStepTemplate.js";
+import { LogEntry, Logger, TimestampedLogEntry } from "./Logger.js";
+import { LlmId } from "./modelConfigs.js";
+import { PromptPair } from "./prompts.js";
+import {
+  AiDeserializationVisitor,
+  AiSerializationVisitor,
+} from "./serialization.js";
+import { getStepTemplateByName } from "./steps/registry.js";
+import { Workflow } from "./workflows/Workflow.js";
+
+interface Params<Shape extends StepShape> {
+  id: string;
+  sequentialId: number;
+  template: LLMStepTemplate<Shape>;
+  state: StepState;
+  inputs: Inputs<Shape>;
+  outputs: Partial<Outputs<Shape>>;
+  retryingStep?: LLMStepInstance<Shape>;
+  startTime: number;
+  conversationMessages: Message[];
+  llmMetricsList: LlmMetrics[];
+}
+
+export class LLMStepInstance<Shape extends StepShape = StepShape> {
+  public id: Params<Shape>["id"];
+  public sequentialId: number;
+  public readonly template: Params<Shape>["template"];
+
+  private state: Params<Shape>["state"];
+  private outputs: Params<Shape>["outputs"];
+  public readonly inputs: Params<Shape>["inputs"];
+
+  public retryingStep?: Params<Shape>["retryingStep"];
+
+  private startTime: Params<Shape>["startTime"];
+  private logger: Logger;
+  private conversationMessages: Params<Shape>["conversationMessages"];
+  public llmMetricsList: Params<Shape>["llmMetricsList"];
+
+  private constructor(params: Params<Shape>) {
+    this.id = params.id;
+    this.sequentialId = params.sequentialId;
+
+    this.llmMetricsList = params.llmMetricsList;
+    this.conversationMessages = params.conversationMessages;
+
+    this.startTime = params.startTime;
+    this.state = params.state;
+    this.outputs = params.outputs;
+
+    this.template = params.template;
+    this.inputs = params.inputs;
+    this.retryingStep = params.retryingStep;
+
+    this.logger = new Logger();
+  }
+
+  // Create a new, PENDING step instance
+  static create<Shape extends StepShape>(params: {
+    template: LLMStepInstance<Shape>["template"];
+    inputs: LLMStepInstance<Shape>["inputs"];
+    retryingStep: LLMStepInstance<Shape>["retryingStep"];
+    workflow: Workflow;
+  }): LLMStepInstance<Shape> {
+    return new LLMStepInstance({
+      id: crypto.randomUUID(),
+      sequentialId: params.workflow.getStepCount(),
+      conversationMessages: [],
+      llmMetricsList: [],
+      startTime: Date.now(),
+      state: { kind: "PENDING" },
+      outputs: {},
+      ...params,
+    });
+  }
+
+  getLogs(): TimestampedLogEntry[] {
+    return this.logger.logs;
+  }
+
+  isRetrying(): boolean {
+    return !!this.retryingStep;
+  }
+
+  getConversationMessages(): Message[] {
+    return this.conversationMessages;
+  }
+
+  async _run(workflow: Workflow) {
+    if (this.state.kind !== "PENDING") {
+      return;
+    }
+
+    const limits = workflow.checkResourceLimits();
+    if (limits) {
+      this.fail("CRITICAL", limits, workflow);
+      return;
+    }
+
+    const executeContext: ExecuteContext<Shape> = {
+      setOutput: (key, value) => this.setOutput(key, value, workflow),
+      log: (log) => this.log(log, workflow),
+      queryLLM: (promptPair) => this.queryLLM(promptPair, workflow),
+      fail: (errorType, message) => this.fail(errorType, message, workflow),
+    };
+
+    try {
+      await this.template.execute(executeContext, this.inputs);
+    } catch (error) {
+      this.fail(
+        "MINOR",
+        error instanceof Error ? error.message : String(error),
+        workflow
+      );
+      return;
+    }
+
+    const hasFailed = (this.state as StepState).kind === "FAILED";
+
+    if (!hasFailed) {
+      this.state = { kind: "DONE", durationMs: this.calculateDuration() };
+    }
+  }
+
+  async run(workflow: Workflow) {
+    this.log(
+      {
+        type: "info",
+        message: `Step "${this.template.name}" started`,
+      },
+      workflow
+    );
+
+    await this._run(workflow);
+
+    const completionMessage = `Step "${this.template.name}" completed with status: ${this.state.kind}${
+      this.state.kind !== "PENDING" && `, in ${this.state.durationMs / 1000}s`
+    }`;
+
+    this.log(
+      {
+        type: "info",
+        message: completionMessage,
+      },
+      workflow
+    );
+  }
+
+  getState() {
+    return this.state;
+  }
+
+  getDuration() {
+    return this.state.kind === "PENDING" ? 0 : this.state.durationMs;
+  }
+
+  getOutputs() {
+    return this.outputs;
+  }
+
+  getInputs() {
+    return this.inputs;
+  }
+
+  getTotalCost() {
+    const totalCost = calculatePriceMultipleCalls(
+      this.llmMetricsList.reduce(
+        (acc, metrics) => {
+          acc[metrics.llmId] = metrics;
+          return acc;
+        },
+        {} as Record<LlmId, LlmMetrics>
+      )
+    );
+
+    return totalCost;
+  }
+
+  isGenerationStep() {
+    const stepHasCodeInput = Object.values(this.template.shape.inputs).some(
+      (kind) => kind === "source" || kind === "code"
+    );
+
+    const stepHasCodeOutput = Object.values(this.template.shape.outputs).some(
+      (kind) => kind === "source" || kind === "code"
+    );
+
+    return !stepHasCodeInput && stepHasCodeOutput;
+  }
+
+  isDone() {
+    return this.state.kind === "DONE";
+  }
+
+  // private methods
+
+  private setOutput<K extends keyof Outputs<Shape>>(
+    key: K,
+    value: Outputs<Shape>[K] | Outputs<Shape>[K]["value"],
+    workflow: Workflow
+  ): void {
+    if (key in this.outputs) {
+      this.fail(
+        "CRITICAL",
+        `Output ${key} is already set. This is a bug with the workflow code.`,
+        workflow
+      );
+      return;
+    }
+
+    if (value instanceof BaseArtifact) {
+      // already existing artifact - probably passed through from another step
+      this.outputs[key] = value;
+    } else {
+      const kind = this.template.shape.outputs[
+        key
+      ] as Outputs<Shape>[K]["kind"];
+      this.outputs[key] = makeArtifact(kind, value as any, this) as any;
+    }
+  }
+
+  private log(log: LogEntry, workflow: Workflow): void {
+    this.logger.log(log, {
+      workflowId: workflow.id,
+      stepIndex: this.sequentialId,
+    });
+  }
+
+  private fail(errorType: ErrorType, message: string, workflow: Workflow) {
+    this.log({ type: "error", message }, workflow);
+    this.state = {
+      kind: "FAILED",
+      durationMs: this.calculateDuration(),
+      errorType,
+      message,
+    };
+  }
+
+  private calculateDuration() {
+    return Date.now() - this.startTime;
+  }
+
+  private addConversationMessage(message: Message): void {
+    this.conversationMessages.push(message);
+  }
+
+  private async queryLLM(
+    promptPair: PromptPair,
+    workflow: Workflow
+  ): Promise<string | null> {
+    try {
+      const messagesToSend: Message[] = [
+        ...workflow.getRelevantPreviousConversationMessages(
+          workflow.llmConfig.messagesInHistoryToKeep
+        ),
+        {
+          role: "user",
+          content: promptPair.fullPrompt,
+        },
+      ];
+      const completion = await workflow.llmClient.run(messagesToSend);
+
+      this.log(
+        {
+          type: "llmResponse",
+          response: completion,
+          content: completion.content,
+          messages: messagesToSend,
+          prompt: promptPair.fullPrompt,
+        },
+        workflow
+      );
+
+      this.llmMetricsList.push({
+        apiCalls: 1,
+        inputTokens: completion?.usage?.prompt_tokens ?? 0,
+        outputTokens: completion?.usage?.completion_tokens ?? 0,
+        llmId: workflow.llmConfig.llmId,
+      });
+
+      if (!completion?.content) {
+        this.log(
+          {
+            type: "error",
+            message: "Received an empty response from the API",
+          },
+          workflow
+        );
+        return null;
+      } else {
+        this.addConversationMessage({
+          role: "user",
+          content: promptPair.summarizedPrompt,
+        });
+
+        this.addConversationMessage({
+          role: "assistant",
+          content: completion?.content ?? "no response",
+        });
+      }
+
+      return completion.content;
+    } catch (error) {
+      this.fail(
+        "MINOR",
+        `Error in queryLLM: ${error instanceof Error ? error.message : error}`,
+        workflow
+      );
+      return null;
+    }
+  }
+
+  // Serialization/deserialization
+
+  serialize(visitor: AiSerializationVisitor): SerializedStep {
+    return {
+      id: this.id,
+      sequentialId: this.sequentialId,
+      templateName: this.template.name,
+      state: this.state,
+      startTime: this.startTime,
+      conversationMessages: this.conversationMessages,
+      llmMetricsList: this.llmMetricsList,
+      inputIds: Object.fromEntries(
+        Object.entries(this.inputs).map(([key, input]) => [
+          key,
+          visitor.artifact(input),
+        ])
+      ),
+      outputIds: Object.fromEntries(
+        Object.entries(this.outputs).map(([key, output]) => [
+          key,
+          visitor.artifact(output),
+        ])
+      ),
+    };
+  }
+
+  static deserialize(
+    { templateName, inputIds, outputIds, ...params }: SerializedStep,
+    visitor: AiDeserializationVisitor
+  ): LLMStepInstance {
+    const template: LLMStepTemplate = getStepTemplateByName(templateName);
+    const inputs = Object.fromEntries(
+      Object.entries(inputIds).map(([name, inputId]) => [
+        name,
+        visitor.artifact(inputId),
+      ])
+    );
+    const outputs = Object.fromEntries(
+      Object.entries(outputIds).map(([name, outputId]) => [
+        name,
+        visitor.artifact(outputId),
+      ])
+    );
+
+    return new LLMStepInstance({
+      ...params,
+      template,
+      inputs,
+      outputs,
+    });
+  }
+}
+
+export type SerializedStep = Omit<
+  Params<StepShape>,
+  // TODO - serialize retryingStep reference
+  "inputs" | "outputs" | "template" | "retryingStep"
+> & {
+  templateName: string;
+  inputIds: Record<string, number>;
+  outputIds: Record<string, number>;
+};
diff --git a/packages/ai/src/LLMStepTemplate.ts b/packages/ai/src/LLMStepTemplate.ts
new file mode 100644
index 0000000000..44af958cb7
--- /dev/null
+++ b/packages/ai/src/LLMStepTemplate.ts
@@ -0,0 +1,77 @@
+import { Artifact, ArtifactKind } from "./Artifact.js";
+import { LLMStepInstance } from "./LLMStepInstance.js";
+import { LogEntry } from "./Logger.js";
+import { PromptPair } from "./prompts.js";
+import { Workflow } from "./workflows/Workflow.js";
+
+export type ErrorType = "CRITICAL" | "MINOR";
+
+export type StepState =
+  | {
+      kind: "PENDING";
+    }
+  | {
+      kind: "DONE";
+      durationMs: number;
+    }
+  | {
+      kind: "FAILED";
+      errorType: ErrorType;
+      durationMs: number;
+      message: string;
+    };
+
+export type StepShape<
+  I extends Record<string, ArtifactKind> = Record<string, ArtifactKind>,
+  O extends Record<string, ArtifactKind> = Record<string, ArtifactKind>,
+> = {
+  inputs: I;
+  outputs: O;
+};
+
+export type Inputs<Shape extends StepShape> = {
+  [K in keyof Shape["inputs"]]: Extract<Artifact, { kind: Shape["inputs"][K] }>;
+};
+
+export type Outputs<Shape extends StepShape> = {
+  [K in keyof Shape["outputs"]]: Extract<
+    Artifact,
+    { kind: Shape["outputs"][K] }
+  >;
+};
+
+// ExecuteContext is the context that's available to the step implementation.
+// We intentionally don't pass the workflow reference to the step implementation, so that steps won't mess with its internal state.
+export type ExecuteContext<Shape extends StepShape> = {
+  setOutput<K extends keyof Outputs<Shape>>(
+    key: K,
+    value: Outputs<Shape>[K] | Outputs<Shape>[K]["value"] // can be either the artifact or the value inside the artifact
+  ): void;
+  queryLLM(promptPair: PromptPair): Promise<string | null>;
+  log(log: LogEntry): void;
+  fail(errorType: ErrorType, message: string): void;
+};
+
+export class LLMStepTemplate<Shape extends StepShape = StepShape> {
+  constructor(
+    public readonly name: string,
+    public readonly shape: Shape,
+    public readonly execute: (
+      context: ExecuteContext<Shape>,
+      inputs: Inputs<Shape>
+    ) => Promise<void>
+  ) {}
+
+  instantiate(
+    workflow: Workflow,
+    inputs: Inputs<Shape>,
+    retryingStep?: LLMStepInstance<Shape> | undefined
+  ): LLMStepInstance<Shape> {
+    return LLMStepInstance.create({
+      template: this,
+      inputs,
+      retryingStep,
+      workflow,
+    });
+  }
+}
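To make the new split concrete: a template only declares its shape and execute function, while LLMStepInstance.create (called by instantiate above) owns all runtime state. A minimal, hypothetical template might look like this; the step name, shape, and prompt text are invented for illustration:

    import { LLMStepTemplate } from "./LLMStepTemplate.js";

    // Hypothetical step: turns a "prompt" artifact into a "source" artifact.
    const draftCodeStep = new LLMStepTemplate(
      "DraftCode",
      { inputs: { prompt: "prompt" }, outputs: { source: "source" } },
      async (context, inputs) => {
        const response = await context.queryLLM({
          fullPrompt: `Write Squiggle code for: ${inputs.prompt.value}`,
          summarizedPrompt: "Write Squiggle code for the user's prompt.",
        });
        if (response) {
          // setOutput accepts either a raw value or an existing artifact
          context.setOutput("source", response);
        } else {
          context.fail("MINOR", "Empty LLM response");
        }
      }
    );
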
diff --git a/packages/ai/src/Logger.ts b/packages/ai/src/Logger.ts
index 491377629f..f593d5ecb3 100644
--- a/packages/ai/src/Logger.ts
+++ b/packages/ai/src/Logger.ts
@@ -10,6 +10,11 @@ export type LogEntry =
   | HighlightLogEntry
   | LlmResponseLogEntry;
 
+export type LoggerContext = {
+  workflowId: string;
+  stepIndex: number;
+};
+
 export function getLogEntryFullName(entry: LogEntry): string {
   switch (entry.type) {
     case "info":
@@ -125,13 +130,8 @@ type LlmResponseLogEntry = {
 export class Logger {
   logs: TimestampedLogEntry[] = [];
 
-  constructor(
-    private workflowId: string,
-    private stepIndex: number
-  ) {}
-
-  log(log: LogEntry): void {
+  log(log: LogEntry, context: LoggerContext): void {
     this.logs.push({ timestamp: new Date(), entry: log });
-    displayLog(log, this.workflowId, this.stepIndex);
+    displayLog(log, context.workflowId, context.stepIndex);
   }
 }
diff --git a/packages/ai/src/generateSummary.ts b/packages/ai/src/generateSummary.ts
index f4a19a5341..7fc4ad2414 100644
--- a/packages/ai/src/generateSummary.ts
+++ b/packages/ai/src/generateSummary.ts
@@ -206,7 +206,7 @@ ${code.source}
 
 **Error:**
 ```
-${code.error.toStringWithDetails()}
+${code.error}
 ```
 
 **Code:**
diff --git a/packages/ai/src/index.ts b/packages/ai/src/index.ts
index c77860f0fa..1f9ca82d97 100644
--- a/packages/ai/src/index.ts
+++ b/packages/ai/src/index.ts
@@ -1,12 +1,12 @@
 export { type LlmConfig } from "./workflows/Workflow.js";
 export {
-  type SerializedArtifact,
-  type SerializedMessage,
-  type SerializedStep,
-  type SerializedWorkflow,
-  serializedWorkflowSchema,
-  workflowMessageSchema,
-  type WorkflowResult,
+  type ClientArtifact,
+  type ClientMessage,
+  type ClientStep,
+  type ClientWorkflow,
+  type ClientWorkflowResult,
+  clientWorkflowSchema,
+  streamingMessageSchema,
 } from "./types.js";
 export { llmLinker } from "./Code.js";
diff --git a/packages/ai/src/serialization.ts b/packages/ai/src/serialization.ts
new file mode 100644
index 0000000000..018801188d
--- /dev/null
+++ b/packages/ai/src/serialization.ts
@@ -0,0 +1,54 @@
+import { DeserializationVisitor, makeCodec } from "@quri/serializer";
+
+import { SerializationVisitor } from "../../serializer/dist/serialization.js";
+import {
+  Artifact,
+  deserializeArtifact,
+  serializeArtifact,
+  SerializedArtifact,
+} from "./Artifact.js";
+import { LLMStepInstance, SerializedStep } from "./LLMStepInstance.js";
+import {
+  LlmConfig,
+  SerializedWorkflow,
+  Workflow,
+} from "./workflows/Workflow.js";
+
+type AiShape = {
+  workflow: [Workflow, SerializedWorkflow];
+  step: [LLMStepInstance, SerializedStep];
+  artifact: [Artifact, SerializedArtifact];
+};
+
+export type AiSerializationVisitor = SerializationVisitor<AiShape>;
+
+export type AiDeserializationVisitor = DeserializationVisitor<AiShape>;
+
+export function codecFactory(
+  llmConfig: LlmConfig,
+  openaiApiKey?: string,
+  anthropicApiKey?: string
+) {
+  return makeCodec<AiShape>({
+    workflow: {
+      serialize: (node, visitor) => node.serialize(visitor),
+      deserialize: (node, visitor) =>
+        Workflow.deserialize({
+          node,
+          visitor,
+          llmConfig,
+          openaiApiKey,
+          anthropicApiKey,
+        }),
+    },
+    step: {
+      serialize: (node, visitor) => node.serialize(visitor),
+      deserialize: (node, visitor) =>
+        LLMStepInstance.deserialize(node, visitor),
+    },
+    artifact: {
+      serialize: serializeArtifact,
+      deserialize: deserializeArtifact,
+    },
+  });
+}
diff --git a/packages/ai/src/steps/adjustToFeedbackStep.ts b/packages/ai/src/steps/adjustToFeedbackStep.ts
index 67b35ba087..a916eec14c 100644
--- a/packages/ai/src/steps/adjustToFeedbackStep.ts
+++ b/packages/ai/src/steps/adjustToFeedbackStep.ts
@@ -1,5 +1,5 @@
 import { Code, codeStringToCode } from "../Code.js";
-import { LLMStepTemplate } from "../LLMStep.js";
+import { LLMStepTemplate } from "../LLMStepTemplate.js";
 import { changeFormatPrompt, PromptPair } from "../prompts.js";
 import { diffToNewCode } from "../squiggle/processSquiggleCode.js";
diff --git a/packages/ai/src/steps/fixCodeUntilItRunsStep.ts b/packages/ai/src/steps/fixCodeUntilItRunsStep.ts
index 15e952e5f0..3637eb459c 100644
--- a/packages/ai/src/steps/fixCodeUntilItRunsStep.ts
+++ b/packages/ai/src/steps/fixCodeUntilItRunsStep.ts
@@ -1,5 +1,5 @@
 import { Code, codeErrorString } from "../Code.js";
-import { LLMStepTemplate } from "../LLMStep.js";
+import { LLMStepTemplate } from "../LLMStepTemplate.js";
 import { changeFormatPrompt, PromptPair } from "../prompts.js";
 import { diffCompletionContentToCode } from "../squiggle/processSquiggleCode.js";
 import { addLineNumbers } from "../squiggle/searchReplace.js";
diff --git a/packages/ai/src/steps/generateCodeStep.ts b/packages/ai/src/steps/generateCodeStep.ts
index 8e31e9b98f..5e046b462e 100644
--- a/packages/ai/src/steps/generateCodeStep.ts
+++ b/packages/ai/src/steps/generateCodeStep.ts
@@ -1,4 +1,4 @@
-import { LLMStepTemplate } from "../LLMStep.js";
+import { LLMStepTemplate } from "../LLMStepTemplate.js";
 import { PromptPair } from "../prompts.js";
 import { generationCompletionContentToCode } from "../squiggle/processSquiggleCode.js";
diff --git a/packages/ai/src/steps/registry.ts b/packages/ai/src/steps/registry.ts
new file mode 100644
index 0000000000..813c134de9
--- /dev/null
+++ b/packages/ai/src/steps/registry.ts
@@ -0,0 +1,20 @@
+import { adjustToFeedbackStep } from "./adjustToFeedbackStep.js";
+import { fixCodeUntilItRunsStep } from "./fixCodeUntilItRunsStep.js";
+import { generateCodeStep } from "./generateCodeStep.js";
+import { runAndFormatCodeStep } from "./runAndFormatCodeStep.js";
+
+const templates = Object.fromEntries(
+  [
+    adjustToFeedbackStep,
+    generateCodeStep,
+    fixCodeUntilItRunsStep,
+    runAndFormatCodeStep,
+  ].map((step) => [step.name, step])
+);
+
+export function getStepTemplateByName(name: string) {
+  if (!(name in templates)) {
+    throw new Error(`Step ${name} not found`);
+  }
+  return templates[name];
+}
diff --git a/packages/ai/src/steps/runAndFormatCodeStep.ts b/packages/ai/src/steps/runAndFormatCodeStep.ts
index 68a88e242b..deae894a65 100644
--- a/packages/ai/src/steps/runAndFormatCodeStep.ts
+++ b/packages/ai/src/steps/runAndFormatCodeStep.ts
@@ -1,5 +1,5 @@
 import { codeStringToCode } from "../Code.js";
-import { LLMStepTemplate } from "../LLMStep.js";
+import { LLMStepTemplate } from "../LLMStepTemplate.js";
 
 export const runAndFormatCodeStep = new LLMStepTemplate(
   "RunAndFormatCode",
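Since the registry keys templates by name, a serialized step can refer to its template with a plain string. As for driving the codec itself, here is a rough sketch assuming the object returned by makeCodec follows @quri/serializer's usual serializer/deserializer conventions (makeSerializer, getBundle and makeDeserializer are my reading of that package's API, not taken from this diff):

    import { codecFactory } from "./serialization.js";
    import { llmConfigDefault, Workflow } from "./workflows/Workflow.js";

    const codec = codecFactory(llmConfigDefault, process.env["OPENAI_API_KEY"]);
    const workflow = Workflow.create(llmConfigDefault);

    // Serialize a workflow into a flat, JSON-friendly bundle of nodes...
    const serializer = codec.makeSerializer();
    const entrypoint = serializer.serialize("workflow", workflow);
    const bundle = serializer.getBundle();

    // ...and reconstruct it later, e.g. after a database round-trip.
    const restored = codec.makeDeserializer(bundle).deserialize(entrypoint);
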
diff --git a/packages/ai/src/types.ts b/packages/ai/src/types.ts
index 3b31219ca8..8cd5a2392d 100644
--- a/packages/ai/src/types.ts
+++ b/packages/ai/src/types.ts
@@ -15,7 +15,7 @@ export const squiggleWorkflowInputSchema = z.discriminatedUnion("type", [
 
 // Protocol for streaming workflow changes between server and client.
 
-// SerializedArtifact type
+// ClientArtifact type
 
 const commonArtifactFields = {
   id: z.string(),
@@ -41,9 +41,9 @@ const artifactSchema = z.discriminatedUnion("kind", [
   }),
 ]);
 
-export type SerializedArtifact = z.infer<typeof artifactSchema>;
+export type ClientArtifact = z.infer<typeof artifactSchema>;
 
-// SerializedStep type
+// ClientStep type
 
 const stepStateSchema = z.enum(["PENDING", "DONE", "FAILED"]);
 
@@ -52,7 +52,7 @@ const messageSchema = z.object({
   content: z.string(),
 });
 
-export type SerializedMessage = z.infer<typeof messageSchema>;
+export type ClientMessage = z.infer<typeof messageSchema>;
 
 const stepSchema = z.object({
   id: z.string(),
@@ -63,9 +63,9 @@ const stepSchema = z.object({
   messages: z.array(messageSchema),
 });
 
-export type SerializedStep = z.infer<typeof stepSchema>;
+export type ClientStep = z.infer<typeof stepSchema>;
 
-// Messages that incrementally update the SerializedWorkflow.
+// Messages that incrementally update the ClientWorkflow.
 // They are using for streaming updates from the server to the client.
 // They are similar to Workflow events, but not exactly the same. They must be JSON-serializable.
 // See `addStreamingListeners` in workflows/streaming.ts for how they are used.
@@ -89,7 +89,7 @@ const stepUpdatedSchema = stepSchema.partial().required({
 
 // WorkflowResult type
 
-export const workflowResultSchema = z.object({
+export const clientWorkflowResultSchema = z.object({
   code: z.string().describe("Squiggle code snippet"),
   isValid: z.boolean(),
   totalPrice: z.number(),
@@ -98,16 +98,16 @@ export const clientWorkflowResultSchema = z.object({
   logSummary: z.string(), // markdown
 });
 
-export type WorkflowResult = z.infer<typeof workflowResultSchema>;
+export type ClientWorkflowResult = z.infer<typeof clientWorkflowResultSchema>;
 
-export const workflowMessageSchema = z.discriminatedUnion("kind", [
+export const streamingMessageSchema = z.discriminatedUnion("kind", [
   z.object({
     kind: z.literal("workflowStarted"),
     content: workflowStartedSchema,
   }),
   z.object({
     kind: z.literal("finalResult"),
-    content: workflowResultSchema,
+    content: clientWorkflowResultSchema,
   }),
   z.object({
     kind: z.literal("stepAdded"),
@@ -119,11 +119,11 @@ export const streamingMessageSchema = z.discriminatedUnion("kind", [
   }),
 ]);
 
-export type WorkflowMessage = z.infer<typeof workflowMessageSchema>;
+export type StreamingMessage = z.infer<typeof streamingMessageSchema>;
 
 // Client-side representation of a workflow
 
-const commonWorkflowFields = {
+const commonClientWorkflowFields = {
   id: z.string(),
   timestamp: z.number(), // milliseconds since epoch
   input: squiggleWorkflowInputSchema, // FIXME - SquiggleWorkflow-specific
@@ -131,22 +131,22 @@ const commonClientWorkflowFields = {
   currentStep: z.string().optional(),
 };
 
-export const serializedWorkflowSchema = z.discriminatedUnion("status", [
+export const clientWorkflowSchema = z.discriminatedUnion("status", [
   z.object({
-    ...commonWorkflowFields,
+    ...commonClientWorkflowFields,
     status: z.literal("loading"),
     result: z.undefined(),
   }),
   z.object({
-    ...commonWorkflowFields,
+    ...commonClientWorkflowFields,
     status: z.literal("finished"),
-    result: workflowResultSchema,
+    result: clientWorkflowResultSchema,
   }),
   z.object({
-    ...commonWorkflowFields,
+    ...commonClientWorkflowFields,
     status: z.literal("error"),
     result: z.string(),
   }),
 ]);
 
-export type SerializedWorkflow = z.infer<typeof serializedWorkflowSchema>;
+export type ClientWorkflow = z.infer<typeof clientWorkflowSchema>;
diff --git a/packages/ai/src/workflows/ControlledWorkflow.ts b/packages/ai/src/workflows/ControlledWorkflow.ts
index 5bb834ed59..9b52890c84 100644
--- a/packages/ai/src/workflows/ControlledWorkflow.ts
+++ b/packages/ai/src/workflows/ControlledWorkflow.ts
@@ -1,6 +1,6 @@
 import { ReadableStream } from "stream/web";
 
-import { WorkflowResult } from "../types.js";
+import { ClientWorkflowResult } from "../types.js";
 import { addStreamingListeners } from "./streaming.js";
 import { LlmConfig, Workflow } from "./Workflow.js";
 
@@ -29,7 +29,7 @@ export abstract class ControlledWorkflow {
     openaiApiKey?: string;
     anthropicApiKey?: string;
   }) {
-    this.workflow = new Workflow(
+    this.workflow = Workflow.create(
       params.llmConfig,
       params.openaiApiKey,
       params.anthropicApiKey
@@ -75,7 +75,7 @@ export abstract class ControlledWorkflow {
   }
 
   // Run workflow without streaming, only capture the final result
-  async runToResult(): Promise<WorkflowResult> {
+  async runToResult(): Promise<ClientWorkflowResult> {
     this.startOrThrow();
     this.configure();
diff --git a/packages/ai/src/workflows/Workflow.ts b/packages/ai/src/workflows/Workflow.ts
index 9e12b6e1c5..335db6857b 100644
--- a/packages/ai/src/workflows/Workflow.ts
+++ b/packages/ai/src/workflows/Workflow.ts
@@ -5,15 +5,15 @@ import {
   LlmMetrics,
   Message,
 } from "../LLMClient.js";
-import {
-  Inputs,
-  LLMStepInstance,
-  LLMStepTemplate,
-  StepShape,
-} from "../LLMStep.js";
+import { LLMStepInstance } from "../LLMStepInstance.js";
+import { Inputs, LLMStepTemplate, StepShape } from "../LLMStepTemplate.js";
 import { TimestampedLogEntry } from "../Logger.js";
 import { LlmId } from "../modelConfigs.js";
-import { WorkflowResult } from "../types.js";
+import {
+  AiDeserializationVisitor,
+  AiSerializationVisitor,
+} from "../serialization.js";
+import { ClientWorkflowResult } from "../types.js";
 
 export interface LlmConfig {
   llmId: LlmId;
@@ -85,32 +85,40 @@ export type WorkflowEventListener<T extends WorkflowEventType> = (
 const MAX_RETRIES = 5;
 
 export class Workflow {
-  private steps: LLMStepInstance[] = [];
-  private priceLimit: number;
-  private durationLimitMs: number;
+  private steps: LLMStepInstance[];
+
+  public llmConfig: LlmConfig;
 
   public llmClient: LLMClient;
   public id: string;
   public startTime: number;
 
-  constructor(
-    public llmConfig: LlmConfig = llmConfigDefault,
-    openaiApiKey?: string,
-    anthropicApiKey?: string
-  ) {
-    this.priceLimit = llmConfig.priceLimit;
-    this.durationLimitMs = llmConfig.durationLimitMinutes * 1000 * 60;
-
+  private constructor(params: {
+    id?: string;
+    steps?: LLMStepInstance[];
+    llmConfig?: LlmConfig;
+    openaiApiKey?: string;
+    anthropicApiKey?: string;
+  }) {
+    this.llmConfig = params.llmConfig ?? llmConfigDefault;
     this.startTime = Date.now();
-    this.id = crypto.randomUUID();
+    this.id = params.id ?? crypto.randomUUID();
+    this.steps = params.steps ?? [];
 
     this.llmClient = new LLMClient(
-      llmConfig.llmId,
-      openaiApiKey,
-      anthropicApiKey
+      this.llmConfig.llmId,
+      params.openaiApiKey,
+      params.anthropicApiKey
    );
  }
 
+  static create(
+    llmConfig: LlmConfig = llmConfigDefault,
+    openaiApiKey?: string,
+    anthropicApiKey?: string
+  ) {
+    return new Workflow({ llmConfig, openaiApiKey, anthropicApiKey });
+  }
+
   // This is a hook that ControlledWorkflow can use to prepare the workflow.
   // It's a bit of a hack; we need to dispatch this event after we configured the event handlers,
   // but before we add any steps.
@@ -168,7 +176,7 @@ export class Workflow {
       type: "stepStarted",
       payload: { step },
     });
-    await step.run();
+    await step.run(this);
 
     this.dispatchEvent({
       type: "stepFinished",
@@ -185,12 +193,15 @@ export class Workflow {
   }
 
   checkResourceLimits(): string | undefined {
-    if (Date.now() - this.startTime > this.durationLimitMs) {
-      return `Duration limit of ${this.durationLimitMs / 1000 / 60} minutes exceeded`;
+    if (
+      Date.now() - this.startTime >
+      this.llmConfig.durationLimitMinutes * 1000 * 60
+    ) {
+      return `Duration limit of ${this.llmConfig.durationLimitMinutes} minutes exceeded`;
     }
 
-    if (this.priceSoFar() > this.priceLimit) {
-      return `Price limit of $${this.priceLimit.toFixed(2)} exceeded`;
+    if (this.priceSoFar() > this.llmConfig.priceLimit) {
+      return `Price limit of $${this.llmConfig.priceLimit.toFixed(2)} exceeded`;
     }
     return undefined;
   }
@@ -215,7 +226,7 @@ export class Workflow {
     return this.getCurrentStep()?.getState().kind !== "PENDING";
   }
 
-  getFinalResult(): WorkflowResult {
+  getFinalResult(): ClientWorkflowResult {
     const finalStep = this.getCurrentStep();
     if (!finalStep) {
       throw new Error("No steps found");
@@ -242,13 +253,15 @@ export class Workflow {
     const runTimeMs = endTime - this.startTime;
     const { totalPrice, llmRunCount } = this.getLlmMetrics();
 
+    const logSummary = generateSummary(this);
+
     return {
       code,
       isValid,
       totalPrice,
       runTimeMs,
       llmRunCount,
-      logSummary: generateSummary(this),
+      logSummary,
     };
   }
@@ -304,7 +317,7 @@ export class Workflow {
       .filter((step) => step.isDone());
 
     // We always include the last generation step, and then at most `maxRecentSteps - 1` other steps.
-    let remainingNSteps = [
+    const remainingNSteps = [
       this.steps[lastGenerateCodeIndex],
       ...remainingSteps.slice(-(maxRecentSteps - 1)),
     ];
@@ -338,4 +351,39 @@ export class Workflow {
       listener as (event: Event) => void
     );
   }
+
+  // Serialization/deserialization
+  serialize(visitor: AiSerializationVisitor): SerializedWorkflow {
+    return {
+      id: this.id,
+      stepIds: this.steps.map(visitor.step),
+    };
+  }
+
+  static deserialize({
+    node,
+    visitor,
+    llmConfig,
+    openaiApiKey,
+    anthropicApiKey,
+  }: {
+    node: SerializedWorkflow;
+    visitor: AiDeserializationVisitor;
+    llmConfig: LlmConfig;
+    openaiApiKey?: string;
+    anthropicApiKey?: string;
+  }): Workflow {
+    return new Workflow({
+      id: node.id,
+      steps: node.stepIds.map(visitor.step),
+      llmConfig,
+      openaiApiKey,
+      anthropicApiKey,
+    });
+  }
 }
+
+export type SerializedWorkflow = {
+  id: string;
+  stepIds: number[];
+};
diff --git a/packages/ai/src/workflows/streaming.ts b/packages/ai/src/workflows/streaming.ts
index 85a7c3a4b2..553038e83f 100644
--- a/packages/ai/src/workflows/streaming.ts
+++ b/packages/ai/src/workflows/streaming.ts
@@ -5,15 +5,15 @@ import {
 
 import { Artifact } from "../Artifact.js";
 import {
-  SerializedArtifact,
-  SerializedWorkflow,
-  WorkflowMessage,
-  workflowMessageSchema,
+  ClientArtifact,
+  ClientWorkflow,
+  StreamingMessage,
+  streamingMessageSchema,
 } from "../types.js";
 import { type SquiggleWorkflowInput } from "./SquiggleWorkflow.js";
 import { Workflow } from "./Workflow.js";
 
-export function serializeArtifact(value: Artifact): SerializedArtifact {
+export function serializeArtifact(value: Artifact): ClientArtifact {
   const commonArtifactFields = {
     id: value.id,
     createdBy: value.createdBy?.id,
@@ -55,7 +55,7 @@ export function addStreamingListeners(
   workflow: Workflow,
   controller: ReadableStreamController<string>
 ) {
-  const send = (message: WorkflowMessage) => {
+  const send = (message: StreamingMessage) => {
     controller.enqueue(JSON.stringify(message) + "\n");
   };
 
@@ -134,11 +134,11 @@ export async function decodeWorkflowFromReader({
   // In the future, we should store input parameters in the Workflow object.
   input: SquiggleWorkflowInput;
   // This adds an initial version of the workflow.
-  addWorkflow: (workflow: SerializedWorkflow) => Promise<void>;
+  addWorkflow: (workflow: ClientWorkflow) => Promise<void>;
   // This signature might look complicated, but it matches the functional
   // version of `useState` state setter.
   setWorkflow: (
-    cb: (workflow: SerializedWorkflow) => SerializedWorkflow
+    cb: (workflow: ClientWorkflow) => ClientWorkflow
   ) => Promise<void>;
 }) {
   // eslint-disable-next-line no-constant-condition
@@ -151,7 +151,7 @@ export async function decodeWorkflowFromReader({
 
     // Note that these are streaming events.
     // They are easy to confuse with workflow events.
-    const event = workflowMessageSchema.parse(eventJson);
+    const event = streamingMessageSchema.parse(eventJson);
 
     switch (event.kind) {
       case "workflowStarted": {
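On the client side, each newline-delimited chunk produced by send above is validated with the renamed schema. A minimal consumer might look like the following sketch; the surrounding line-splitting and the exact handling per message kind are assumptions, not part of this diff:

    import { streamingMessageSchema } from "@quri/squiggle-ai";

    function handleStreamLine(line: string) {
      // Throws if the server sent something that doesn't match the protocol.
      const message = streamingMessageSchema.parse(JSON.parse(line));
      switch (message.kind) {
        case "stepAdded":
          console.log(`step added: ${message.content.id}`);
          break;
        case "finalResult":
          console.log(`done, total price: $${message.content.totalPrice.toFixed(2)}`);
          break;
      }
    }
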
diff --git a/packages/hub/src/app/ai/AiDashboard.tsx b/packages/hub/src/app/ai/AiDashboard.tsx
index 2a1cf83096..22597cde71 100644
--- a/packages/hub/src/app/ai/AiDashboard.tsx
+++ b/packages/hub/src/app/ai/AiDashboard.tsx
@@ -3,18 +3,18 @@
 import { clsx } from "clsx";
 import { FC, useRef, useState } from "react";
 
-import { SerializedWorkflow, WorkflowResult } from "@quri/squiggle-ai";
+import { ClientWorkflow, ClientWorkflowResult } from "@quri/squiggle-ai";
 
 import { Sidebar } from "./Sidebar";
 import { useSquiggleWorkflows } from "./useSquiggleWorkflows";
 import { WorkflowViewer } from "./WorkflowViewer";
 
 export type SquiggleResponse = {
-  result?: WorkflowResult;
+  result?: ClientWorkflowResult;
   currentStep?: string;
 };
 
-export const AiDashboard: FC<{ initialWorkflows: SerializedWorkflow[] }> = ({
+export const AiDashboard: FC<{ initialWorkflows: ClientWorkflow[] }> = ({
   initialWorkflows,
 }) => {
   const { workflows, submitWorkflow, selectedWorkflow, selectWorkflow } =
diff --git a/packages/hub/src/app/ai/Sidebar.tsx b/packages/hub/src/app/ai/Sidebar.tsx
index 746521b0ef..65aa6f94fc 100644
--- a/packages/hub/src/app/ai/Sidebar.tsx
+++ b/packages/hub/src/app/ai/Sidebar.tsx
@@ -9,7 +9,7 @@
 } from "react";
 import { FormProvider, useForm } from "react-hook-form";
 
-import { LlmId, MODEL_CONFIGS, SerializedWorkflow } from "@quri/squiggle-ai";
+import { ClientWorkflow, LlmId, MODEL_CONFIGS } from "@quri/squiggle-ai";
 import {
   Button,
   SelectStringFormField,
@@ -27,8 +27,8 @@ type Handle = {
 type Props = {
   submitWorkflow: (requestBody: CreateRequestBody) => void;
   selectWorkflow: (id: string) => void;
-  selectedWorkflow: SerializedWorkflow | undefined;
-  workflows: SerializedWorkflow[];
+  selectedWorkflow: ClientWorkflow | undefined;
+  workflows: ClientWorkflow[];
 };
 
 type FormShape = {
diff --git a/packages/hub/src/app/ai/StepStatusIcon.tsx b/packages/hub/src/app/ai/StepStatusIcon.tsx
index c6ccd56cbe..d2b4452bf2 100644
--- a/packages/hub/src/app/ai/StepStatusIcon.tsx
+++ b/packages/hub/src/app/ai/StepStatusIcon.tsx
@@ -1,9 +1,9 @@
 import { FC } from "react";
 
-import { SerializedStep } from "@quri/squiggle-ai";
+import { ClientStep } from "@quri/squiggle-ai";
 import { CheckCircleIcon, ErrorIcon, RefreshIcon } from "@quri/ui";
 
-export const StepStatusIcon: FC<{ step: SerializedStep }> = ({ step }) => {
+export const StepStatusIcon: FC<{ step: ClientStep }> = ({ step }) => {
   switch (step.state) {
     case "PENDING":
       return <RefreshIcon />;
diff --git a/packages/hub/src/app/ai/WorkflowStatusIcon.tsx b/packages/hub/src/app/ai/WorkflowStatusIcon.tsx
index 4c7ebf2709..9fa043511d 100644
--- a/packages/hub/src/app/ai/WorkflowStatusIcon.tsx
+++ b/packages/hub/src/app/ai/WorkflowStatusIcon.tsx
@@ -1,9 +1,9 @@
 import { FC } from "react";
 
-import { SerializedWorkflow } from "@quri/squiggle-ai";
+import { ClientWorkflow } from "@quri/squiggle-ai";
 import { CheckCircleIcon, ErrorIcon, RefreshIcon } from "@quri/ui";
 
-export const WorkflowStatusIcon: FC<{ workflow: SerializedWorkflow }> = ({
+export const WorkflowStatusIcon: FC<{ workflow: ClientWorkflow }> = ({
   workflow,
 }) => {
   switch (workflow.status) {
diff --git a/packages/hub/src/app/ai/WorkflowSummaryItem.tsx b/packages/hub/src/app/ai/WorkflowSummaryItem.tsx
index 2967fbb9cf..2884765029 100644
--- a/packages/hub/src/app/ai/WorkflowSummaryItem.tsx
+++ b/packages/hub/src/app/ai/WorkflowSummaryItem.tsx
@@ -3,12 +3,12 @@
 import clsx from "clsx";
 import { FC } from "react";
 
-import { SerializedWorkflow } from "@quri/squiggle-ai";
+import { ClientWorkflow } from "@quri/squiggle-ai";
 
 import { WorkflowStatusIcon } from "./WorkflowStatusIcon";
 
 export const WorkflowSummaryItem: FC<{
-  workflow: SerializedWorkflow;
+  workflow: ClientWorkflow;
   onSelect: () => void;
   isSelected: boolean;
 }> = ({ workflow, onSelect, isSelected }) => {
diff --git a/packages/hub/src/app/ai/WorkflowSummaryList.tsx b/packages/hub/src/app/ai/WorkflowSummaryList.tsx
index d5cc4ead05..d949b8c510 100644
--- a/packages/hub/src/app/ai/WorkflowSummaryList.tsx
+++ b/packages/hub/src/app/ai/WorkflowSummaryList.tsx
@@ -1,13 +1,13 @@
 import { orderBy } from "lodash";
 import { FC } from "react";
 
-import { SerializedWorkflow } from "@quri/squiggle-ai";
+import { ClientWorkflow } from "@quri/squiggle-ai";
 
 import { WorkflowSummaryItem } from "./WorkflowSummaryItem";
 
 export const WorkflowSummaryList: FC<{
-  workflows: SerializedWorkflow[];
-  selectedWorkflow: SerializedWorkflow | undefined;
+  workflows: ClientWorkflow[];
+  selectedWorkflow: ClientWorkflow | undefined;
   selectWorkflow: (id: string) => void;
 }> = ({ workflows, selectedWorkflow, selectWorkflow }) => {
   const sortedWorkflows = orderBy(workflows, ["timestamp"], ["desc"]);
diff --git a/packages/hub/src/app/ai/WorkflowViewer/ArtifactDisplay.tsx b/packages/hub/src/app/ai/WorkflowViewer/ArtifactDisplay.tsx
index ab0e7c54f8..82c0257920 100644
--- a/packages/hub/src/app/ai/WorkflowViewer/ArtifactDisplay.tsx
+++ b/packages/hub/src/app/ai/WorkflowViewer/ArtifactDisplay.tsx
@@ -1,13 +1,13 @@
 import clsx from "clsx";
 import { FC, ReactNode } from "react";
 
-import { SerializedArtifact } from "@quri/squiggle-ai";
+import { ClientArtifact } from "@quri/squiggle-ai";
 import { CodeBracketIcon, CommentIcon, Tooltip } from "@quri/ui";
 
 type ArtifactKind = "source" | "prompt" | "code";
 type ArtifactIconSize = number;
 
-function getArtifactColor(artifact: SerializedArtifact): string {
+function getArtifactColor(artifact: ClientArtifact): string {
   if (artifact.kind === "code" && !artifact.ok) {
     return "bg-red-300";
   }
@@ -32,7 +32,7 @@ function getArtifactIcon(kind: ArtifactKind, size: number): ReactNode {
 }
 
 export const ArtifactIcon: FC<{
-  artifact: SerializedArtifact;
+  artifact: ClientArtifact;
   size: ArtifactIconSize;
 }> = ({ artifact, size }) => {
   const bgColor = getArtifactColor(artifact);
@@ -46,7 +46,7 @@ export const ArtifactIcon: FC<{
 
 export const ArtifactDisplay: FC<{
   name: string;
-  artifact: SerializedArtifact;
+  artifact: ClientArtifact;
   size: ArtifactIconSize;
   showArtifactName?: boolean;
 }> = ({ name, artifact, size, showArtifactName = false }) => {
diff --git a/packages/hub/src/app/ai/WorkflowViewer/ArtifactMessages.tsx b/packages/hub/src/app/ai/WorkflowViewer/ArtifactMessages.tsx
index 780665a54e..8f85a5ade5 100644
--- a/packages/hub/src/app/ai/WorkflowViewer/ArtifactMessages.tsx
+++ b/packages/hub/src/app/ai/WorkflowViewer/ArtifactMessages.tsx
@@ -1,9 +1,9 @@
 import clsx from "clsx";
 import { FC } from "react";
 
-import { SerializedMessage } from "@quri/squiggle-ai";
+import { ClientMessage } from "@quri/squiggle-ai";
 
-const Message: FC<{ message: SerializedMessage }> = ({ message }) => {
+const Message: FC<{ message: ClientMessage }> = ({ message }) => {
   const isUser = message.role === "user";
 
   return (
     <div
@@ -22,7 +22,7 @@ const Message: FC<{ message: ClientMessage }> = ({ message }) => {
   );
 };
 
-export const ArtifactMessages: FC<{ messages: SerializedMessage[] }> = ({
+export const ArtifactMessages: FC<{ messages: ClientMessage[] }> = ({
   messages,
 }) => (
   <div
diff --git a/packages/hub/src/app/ai/WorkflowViewer/Header.tsx b/packages/hub/src/app/ai/WorkflowViewer/Header.tsx
index 57d02e32b5..4ca379b6c4 100644
--- a/packages/hub/src/app/ai/WorkflowViewer/Header.tsx
+++ b/packages/hub/src/app/ai/WorkflowViewer/Header.tsx
@@ -1,6 +1,6 @@
 import { FC, ReactNode } from "react";
 
-import { SerializedWorkflow } from "@quri/squiggle-ai";
+import { ClientWorkflow } from "@quri/squiggle-ai";
 import { Button } from "@quri/ui";
 
 import { WorkflowStatusIcon } from "../WorkflowStatusIcon";
@@ -9,7 +9,7 @@ import { WorkflowStatusIcon } from "../WorkflowStatusIcon";
 export const Header: FC<{
   renderLeft: () => ReactNode;
   renderRight: () => ReactNode;
-  workflow: SerializedWorkflow;
+  workflow: ClientWorkflow;
   expanded: boolean;
   setExpanded: (expanded: boolean) => void;
 }> = ({ renderLeft, renderRight, workflow, expanded, setExpanded }) => {
diff --git a/packages/hub/src/app/ai/WorkflowViewer/SelectedNodeSideView.tsx b/packages/hub/src/app/ai/WorkflowViewer/SelectedNodeSideView.tsx
index d4a9e46943..b8fdf8b356 100644
--- a/packages/hub/src/app/ai/WorkflowViewer/SelectedNodeSideView.tsx
+++ b/packages/hub/src/app/ai/WorkflowViewer/SelectedNodeSideView.tsx
@@ -1,7 +1,7 @@
 import clsx from "clsx";
 import { FC, useMemo } from "react";
 
-import { SerializedArtifact, SerializedStep } from "@quri/squiggle-ai";
+import { ClientArtifact, ClientStep } from "@quri/squiggle-ai";
 import { ChevronLeftIcon, ChevronRightIcon, XIcon } from "@quri/ui";
 
 import { useAvailableHeight } from "@/hooks/useAvailableHeight";
@@ -33,7 +33,7 @@ const NavButton: FC<{
 
 const ArtifactList: FC<{
   title: string;
-  artifacts: Record<string, SerializedArtifact>;
+  artifacts: Record<string, ClientArtifact>;
 }> = ({ title, artifacts }) => {
   return (
     <div
@@ -54,7 +54,7 @@ const ArtifactList: FC<{
 };
 
 export const SelectedNodeSideView: FC<{
-  selectedNode: SerializedStep;
+  selectedNode: ClientStep;
   onClose: () => void;
   onSelectPreviousNode?: () => void;
   onSelectNextNode?: () => void;
@@ -63,7 +63,7 @@ export const SelectedNodeSideView: FC<{
 
   const selectedNodeCodeOutput = useMemo(() => {
     return Object.values(selectedNode.outputs).find(
-      (output): output is SerializedArtifact & { kind: "code" } =>
+      (output): output is ClientArtifact & { kind: "code" } =>
         output.kind === "code"
     );
   }, [selectedNode]);
diff --git a/packages/hub/src/app/ai/WorkflowViewer/StepNode.tsx b/packages/hub/src/app/ai/WorkflowViewer/StepNode.tsx
index b3afd637b1..722f9502ac 100644
--- a/packages/hub/src/app/ai/WorkflowViewer/StepNode.tsx
+++ b/packages/hub/src/app/ai/WorkflowViewer/StepNode.tsx
@@ -1,13 +1,13 @@
 import clsx from "clsx";
 import { FC, MouseEvent } from "react";
 
-import { SerializedStep } from "@quri/squiggle-ai";
+import { ClientStep } from "@quri/squiggle-ai";
 
 import { StepStatusIcon } from "../StepStatusIcon";
 import { ArtifactDisplay } from "./ArtifactDisplay";
 
 type StepNodeProps = {
-  data: SerializedStep;
+  data: ClientStep;
   onClick?: (event: MouseEvent) => void;
   isSelected?: boolean;
   stepNumber: number;
diff --git a/packages/hub/src/app/ai/WorkflowViewer/WorkflowActions.tsx b/packages/hub/src/app/ai/WorkflowViewer/WorkflowActions.tsx
index 5a955271a0..bef73bafa9 100644
--- a/packages/hub/src/app/ai/WorkflowViewer/WorkflowActions.tsx
+++ b/packages/hub/src/app/ai/WorkflowViewer/WorkflowActions.tsx
@@ -1,14 +1,14 @@
 import { FC, useEffect, useRef, useState } from "react";
 
-import { SerializedStep, SerializedWorkflow } from "@quri/squiggle-ai";
+import { ClientStep, ClientWorkflow } from "@quri/squiggle-ai";
 
 import { SelectedNodeSideView } from "./SelectedNodeSideView";
 import { StepNode } from "./StepNode";
 
 export const WorkflowActions: FC<{
-  workflow: SerializedWorkflow;
+  workflow: ClientWorkflow;
   height: number;
-  onNodeClick?: (node: SerializedStep) => void;
+  onNodeClick?: (node: ClientStep) => void;
 }> = ({ workflow, height, onNodeClick }) => {
   const [selectedNodeIndex, setSelectedNodeIndex] = useState(
     workflow.steps.length - 1
diff --git a/packages/hub/src/app/ai/WorkflowViewer/index.tsx b/packages/hub/src/app/ai/WorkflowViewer/index.tsx
index cc195c3479..24c72a9cd9 100644
--- a/packages/hub/src/app/ai/WorkflowViewer/index.tsx
+++ b/packages/hub/src/app/ai/WorkflowViewer/index.tsx
@@ -1,7 +1,7 @@
 "use client";
 import { FC } from "react";
 
-import { SerializedWorkflow } from "@quri/squiggle-ai";
+import { ClientWorkflow } from "@quri/squiggle-ai";
 import { Button, StyledTab } from "@quri/ui";
 
 import { useAvailableHeight } from "@/hooks/useAvailableHeight";
@@ -13,9 +13,9 @@ import { Header } from "./Header";
 import { WorkflowActions } from "./WorkflowActions";
 
 type WorkflowViewerProps<
-  T extends SerializedWorkflow["status"] = SerializedWorkflow["status"],
+  T extends ClientWorkflow["status"] = ClientWorkflow["status"],
 > = {
-  workflow: Extract<SerializedWorkflow, { status: T }>;
+  workflow: Extract<ClientWorkflow, { status: T }>;
   onFix: (code: string) => void;
   expanded: boolean;
   setExpanded: (expanded: boolean) => void;
diff --git a/packages/hub/src/app/ai/api/create/route.ts b/packages/hub/src/app/ai/api/create/route.ts
index a03f0c40b1..371e05a316 100644
--- a/packages/hub/src/app/ai/api/create/route.ts
+++ b/packages/hub/src/app/ai/api/create/route.ts
@@ -1,9 +1,9 @@
 import { getServerSession } from "next-auth";
 
 import {
+  ClientWorkflow,
   decodeWorkflowFromReader,
   LlmConfig,
-  SerializedWorkflow,
   SquiggleWorkflowInput,
 } from "@quri/squiggle-ai";
 import { SquiggleWorkflow } from "@quri/squiggle-ai/server";
@@ -30,7 +30,7 @@ export async function POST(req: Request) {
 
   const user = await getSelf(session);
 
-  let workflow: SerializedWorkflow;
+  let workflow: ClientWorkflow;
 
   const streamToDatabase = (
     stream: ReadableStream,
diff --git a/packages/hub/src/app/ai/page.tsx b/packages/hub/src/app/ai/page.tsx
index e36960fca3..3d362eed59 100644
--- a/packages/hub/src/app/ai/page.tsx
+++ b/packages/hub/src/app/ai/page.tsx
@@ -1,7 +1,4 @@
-import {
-  SerializedWorkflow,
-  serializedWorkflowSchema,
-} from "@quri/squiggle-ai";
+import { ClientWorkflow, clientWorkflowSchema } from "@quri/squiggle-ai";
 
 import { prisma } from "@/prisma";
 import { getUserOrRedirect } from "@/server/helpers";
@@ -22,7 +19,7 @@ export default async function AiPage() {
 
   const workflows = rows.map((row) => {
     try {
-      return serializedWorkflowSchema.parse(row.workflow);
+      return clientWorkflowSchema.parse(row.workflow);
     } catch (e) {
       return {
         id: row.id,
@@ -34,7 +31,7 @@ export default async function AiPage() {
         },
         steps: [],
         result: "Invalid workflow format in the database",
-      } satisfies SerializedWorkflow;
+      } satisfies ClientWorkflow;
     }
   });
diff --git a/packages/hub/src/app/ai/useSquiggleWorkflows.tsx b/packages/hub/src/app/ai/useSquiggleWorkflows.tsx
index d337ef59c4..6b56b823be 100644
--- a/packages/hub/src/app/ai/useSquiggleWorkflows.tsx
+++ b/packages/hub/src/app/ai/useSquiggleWorkflows.tsx
@@ -1,23 +1,20 @@
 import { useCallback, useState } from "react";
 
 import {
+  ClientWorkflow,
   decodeWorkflowFromReader,
-  SerializedWorkflow,
   SquiggleWorkflowInput,
 } from "@quri/squiggle-ai";
 
 import { bodyToLineReader, CreateRequestBody, requestToInput } from "./utils";
 
-export function useSquiggleWorkflows(initialWorkflows: SerializedWorkflow[]) {
+export function useSquiggleWorkflows(initialWorkflows: ClientWorkflow[]) {
   const [workflows, setWorkflows] =
-    useState<SerializedWorkflow[]>(initialWorkflows);
+    useState<ClientWorkflow[]>(initialWorkflows);
   const [selected, setSelected] = useState<string | undefined>(undefined);
 
   const updateWorkflow = useCallback(
-    (
-      id: string,
-      update: (workflow: SerializedWorkflow) => SerializedWorkflow
-    ) => {
+    (id: string, update: (workflow: ClientWorkflow) => ClientWorkflow) => {
       setWorkflows((workflows) =>
         workflows.map((workflow) => {
           return workflow.id === id ? update(workflow) : workflow;
@@ -31,7 +28,7 @@ export function useSquiggleWorkflows(initialWorkflows: ClientWorkflow[]) {
     (input: SquiggleWorkflowInput) => {
       // This will be replaced with a real workflow once we receive the first message from the server.
       const id = `loading-${Date.now().toString()}`;
-      const workflow: SerializedWorkflow = {
+      const workflow: ClientWorkflow = {
         id,
         timestamp: new Date().getTime(),
         status: "loading",
diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml
index d575b60c38..103e6bfcef 100644
--- a/pnpm-lock.yaml
+++ b/pnpm-lock.yaml
@@ -35,6 +35,9 @@ importers:
       '@quri/prettier-plugin-squiggle':
         specifier: workspace:*
         version: link:../prettier-plugin
+      '@quri/serializer':
+        specifier: workspace:*
+        version: link:../serializer
       '@quri/squiggle-lang':
         specifier: workspace:*
         version: link:../squiggle-lang

From 5e4b5a64c390499df753d58e06d39afc5f9b9e38 Mon Sep 17 00:00:00 2001
From: Vyacheslav Matyukhin
Date: Sat, 28 Sep 2024 14:11:05 -0400
Subject: [PATCH 02/11] new db format

---
 packages/ai/src/LLMStepTemplate.ts            |  15 ---
 packages/ai/src/serialization.ts              |  21 ++--
 packages/ai/src/server.ts                     |   4 +
 packages/ai/src/steps/registry.ts             |  18 +--
 packages/ai/src/workflows/Workflow.ts         |  63 +++++++---
 packages/ai/src/workflows/streaming.ts        |  35 +++++-
 .../migration.sql                             |   5 +
 packages/hub/prisma/schema.prisma             |   5 +-
 packages/hub/src/app/ai/api/create/route.ts   | 111 ++++++++----------
 packages/hub/src/app/ai/page.tsx              |  21 +++-
 packages/hub/src/app/ai/serverUtils.ts        |  26 ++++
 11 files changed, 202 insertions(+), 122 deletions(-)
 create mode 100644 packages/hub/prisma/migrations/20240928161949_workflow_format/migration.sql
 create mode 100644 packages/hub/src/app/ai/serverUtils.ts

diff --git a/packages/ai/src/LLMStepTemplate.ts b/packages/ai/src/LLMStepTemplate.ts
index 44af958cb7..a372734e43 100644
--- a/packages/ai/src/LLMStepTemplate.ts
+++ b/packages/ai/src/LLMStepTemplate.ts
@@ -1,8 +1,6 @@
 import { Artifact, ArtifactKind } from "./Artifact.js";
-import { LLMStepInstance } from "./LLMStepInstance.js";
 import { LogEntry } from "./Logger.js";
 import { PromptPair } from "./prompts.js";
-import { Workflow } from "./workflows/Workflow.js";
 
 export type ErrorType = "CRITICAL" | "MINOR";
 
@@ -61,17 +59,4 @@ export class LLMStepTemplate<Shape extends StepShape = StepShape> {
       inputs: Inputs<Shape>
     ) => Promise<void>
   ) {}
-
-  instantiate(
-    workflow: Workflow,
-    inputs: Inputs<Shape>,
-    retryingStep?: LLMStepInstance<Shape> | undefined
-  ): LLMStepInstance<Shape> {
-    return LLMStepInstance.create({
-      template: this,
-      inputs,
-      retryingStep,
-      workflow,
-    });
-  }
 }
diff --git a/packages/ai/src/serialization.ts b/packages/ai/src/serialization.ts
index 018801188d..b862602cfd 100644
--- a/packages/ai/src/serialization.ts
+++ b/packages/ai/src/serialization.ts
@@ -8,11 +8,7 @@ import {
   SerializedArtifact,
 } from "./Artifact.js";
 import { LLMStepInstance, SerializedStep } from "./LLMStepInstance.js";
-import {
-  LlmConfig,
-  SerializedWorkflow,
-  Workflow,
-} from "./workflows/Workflow.js";
+import { SerializedWorkflow, Workflow } from "./workflows/Workflow.js";
 
 type AiShape = {
   workflow: [Workflow, SerializedWorkflow];
@@ -24,11 +20,11 @@ export type AiSerializationVisitor = SerializationVisitor<AiShape>;
 
 export type AiDeserializationVisitor = DeserializationVisitor<AiShape>;
 
-export function codecFactory(
-  llmConfig: LlmConfig,
-  openaiApiKey?: string,
-  anthropicApiKey?: string
-) {
+export function makeAiCodec(params: {
+  // we don't want to serialize these secrets, so we parameterize the codec with them
+  openaiApiKey?: string;
+  anthropicApiKey?: string;
+}) {
   return makeCodec<AiShape>({
     workflow: {
       serialize: (node, visitor) => node.serialize(visitor),
       deserialize: (node, visitor) =>
         Workflow.deserialize({
           node,
           visitor,
-          llmConfig,
-          openaiApiKey,
-          anthropicApiKey,
+          openaiApiKey: params.openaiApiKey,
+          anthropicApiKey: params.anthropicApiKey,
         }),
     },
     step: {
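With llmConfig now traveling inside the serialized workflow (see the Workflow.ts diff below), the reworked factory only takes the secrets that must never be persisted. A sketch of the intended server-side usage; the serializer method names are assumed from @quri/serializer's conventions, and workflow stands for any live Workflow instance:

    import { makeAiCodec, Workflow } from "@quri/squiggle-ai/server";

    declare const workflow: Workflow;

    const codec = makeAiCodec({
      openaiApiKey: process.env["OPENAI_API_KEY"],
      anthropicApiKey: process.env["ANTHROPIC_API_KEY"],
    });

    // API keys are re-injected at deserialization time instead of being
    // written into the bundle next to the workflow data.
    const serializer = codec.makeSerializer();
    const entrypoint = serializer.serialize("workflow", workflow);
    const bundle = serializer.getBundle(); // JSON-safe, OK to store in the db
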
params.anthropicApiKey, }), }, step: { diff --git a/packages/ai/src/server.ts b/packages/ai/src/server.ts index b21655a85f..b62e1aa6ef 100644 --- a/packages/ai/src/server.ts +++ b/packages/ai/src/server.ts @@ -3,3 +3,7 @@ export { SquiggleWorkflow, SquiggleWorkflowInput, } from "./workflows/SquiggleWorkflow.js"; + +export { Workflow } from "./workflows/Workflow.js"; + +export { makeAiCodec } from "./serialization.js"; diff --git a/packages/ai/src/steps/registry.ts b/packages/ai/src/steps/registry.ts index 813c134de9..1859669333 100644 --- a/packages/ai/src/steps/registry.ts +++ b/packages/ai/src/steps/registry.ts @@ -3,16 +3,16 @@ import { fixCodeUntilItRunsStep } from "./fixCodeUntilItRunsStep.js"; import { generateCodeStep } from "./generateCodeStep.js"; import { runAndFormatCodeStep } from "./runAndFormatCodeStep.js"; -const templates = Object.fromEntries( - [ - adjustToFeedbackStep, - generateCodeStep, - fixCodeUntilItRunsStep, - runAndFormatCodeStep, - ].map((step) => [step.name, step]) -); - export function getStepTemplateByName(name: string) { + const templates = Object.fromEntries( + [ + adjustToFeedbackStep, + generateCodeStep, + fixCodeUntilItRunsStep, + runAndFormatCodeStep, + ].map((step) => [step.name, step]) + ); + if (!(name in templates)) { throw new Error(`Step ${name} not found`); } diff --git a/packages/ai/src/workflows/Workflow.ts b/packages/ai/src/workflows/Workflow.ts index 335db6857b..e43b006ae7 100644 --- a/packages/ai/src/workflows/Workflow.ts +++ b/packages/ai/src/workflows/Workflow.ts @@ -13,14 +13,15 @@ import { AiDeserializationVisitor, AiSerializationVisitor, } from "../serialization.js"; -import { ClientWorkflowResult } from "../types.js"; +import { ClientWorkflow, ClientWorkflowResult } from "../types.js"; +import { stepToClientStep } from "./streaming.js"; -export interface LlmConfig { +export type LlmConfig = { llmId: LlmId; priceLimit: number; durationLimitMinutes: number; messagesInHistoryToKeep: number; -} +}; export const llmConfigDefault: LlmConfig = { llmId: "Claude-Sonnet", @@ -85,21 +86,22 @@ export type WorkflowEventListener = ( const MAX_RETRIES = 5; export class Workflow { + public id: string; + public llmConfig: LlmConfig; + public startTime: number; + private steps: LLMStepInstance[]; - public llmConfig: LlmConfig; public llmClient: LLMClient; - public id: string; - public startTime: number; private constructor(params: { - id?: string; - steps?: LLMStepInstance[]; - llmConfig?: LlmConfig; + id: string; + steps: LLMStepInstance[]; + llmConfig: LlmConfig; openaiApiKey?: string; anthropicApiKey?: string; }) { - this.llmConfig = params.llmConfig ?? llmConfigDefault; + this.llmConfig = params.llmConfig; this.startTime = Date.now(); this.id = params.id ?? crypto.randomUUID(); this.steps = params.steps ?? []; @@ -116,7 +118,13 @@ export class Workflow { openaiApiKey?: string, anthropicApiKey?: string ) { - return new Workflow({ llmConfig, openaiApiKey, anthropicApiKey }); + return new Workflow({ + id: crypto.randomUUID(), + steps: [], + llmConfig, + openaiApiKey, + anthropicApiKey, + }); } // This is a hook that ControlledWorkflow can use to prepare the workflow. 
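To make the new persistence flow concrete, here is a minimal round-trip sketch. It is an illustration, not part of the patch: it assumes a `workflow` instance is in scope, and it uses only the codec API exercised elsewhere in this series (`makeSerializer`/`serialize`/`getBundle` in `upsertWorkflow`, `makeDeserializer`/`deserialize` in the dashboard page):

import { makeAiCodec } from "@quri/squiggle-ai/server";

const codec = makeAiCodec({
  anthropicApiKey: process.env["ANTHROPIC_API_KEY"],
});

// Serialize: the bundle is JSON-safe; `entrypoint` is a pointer into it.
const serializer = codec.makeSerializer();
const entrypoint = serializer.serialize("workflow", workflow);
const bundle = serializer.getBundle();

// Deserialize later (e.g. after reading the row back from the database),
// injecting fresh secrets that were never written to the bundle.
const deserializer = codec.makeDeserializer(bundle);
const restored = deserializer.deserialize(entrypoint);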
@@ -133,11 +141,13 @@ export class Workflow { options?: { retryingStep?: LLMStepInstance } ): LLMStepInstance { // sorry for "any"; countervariance issues - const step: LLMStepInstance = template.instantiate( - this, + const step: LLMStepInstance = LLMStepInstance.create({ + template, inputs, - options?.retryingStep - ); + retryingStep: options?.retryingStep, + workflow: this, + }); + this.steps.push(step); this.dispatchEvent({ type: "stepAdded", @@ -357,33 +367,50 @@ export class Workflow { return { id: this.id, stepIds: this.steps.map(visitor.step), + llmConfig: this.llmConfig, }; } static deserialize({ node, visitor, - llmConfig, openaiApiKey, anthropicApiKey, }: { node: SerializedWorkflow; visitor: AiDeserializationVisitor; - llmConfig: LlmConfig; openaiApiKey?: string; anthropicApiKey?: string; }): Workflow { return new Workflow({ id: node.id, + llmConfig: node.llmConfig, steps: node.stepIds.map(visitor.step), - llmConfig, openaiApiKey, anthropicApiKey, }); } + + // Client-side representation + asClientWorkflow(): ClientWorkflow { + return { + id: this.id, + timestamp: this.startTime, + steps: this.steps.map(stepToClientStep), + currentStep: this.getCurrentStep()?.id, + ...(this.isProcessComplete() + ? { + status: "finished", + result: this.getFinalResult(), + } + : { status: "loading" }), + input: { type: "Create", prompt: "FIXME - not serialized" }, + }; + } } export type SerializedWorkflow = { id: string; + llmConfig: LlmConfig; stepIds: number[]; }; diff --git a/packages/ai/src/workflows/streaming.ts b/packages/ai/src/workflows/streaming.ts index 553038e83f..c92cf26685 100644 --- a/packages/ai/src/workflows/streaming.ts +++ b/packages/ai/src/workflows/streaming.ts @@ -4,16 +4,18 @@ import { } from "stream/web"; import { Artifact } from "../Artifact.js"; +import { type LLMStepInstance } from "../LLMStepInstance.js"; import { ClientArtifact, + ClientStep, ClientWorkflow, StreamingMessage, streamingMessageSchema, } from "../types.js"; import { type SquiggleWorkflowInput } from "./SquiggleWorkflow.js"; -import { Workflow } from "./Workflow.js"; +import { type Workflow } from "./Workflow.js"; -export function serializeArtifact(value: Artifact): ClientArtifact { +function artifactToClientArtifact(value: Artifact): ClientArtifact { const commonArtifactFields = { id: value.id, createdBy: value.createdBy?.id, @@ -43,6 +45,29 @@ export function serializeArtifact(value: Artifact): ClientArtifact { } } +export function stepToClientStep(step: LLMStepInstance): ClientStep { + return { + id: step.id, + name: step.template.name ?? "unknown", + state: step.getState().kind, + inputs: Object.fromEntries( + Object.entries(step.getInputs()).map(([key, value]) => [ + key, + artifactToClientArtifact(value), + ]) + ), + outputs: Object.fromEntries( + Object.entries(step.getOutputs()) + .filter( + (pair): pair is [string, NonNullable<(typeof pair)[1]>] => + pair[1] !== undefined + ) + .map(([key, value]) => [key, artifactToClientArtifact(value)]) + ), + messages: step.getConversationMessages(), + }; +} + /** * Add listeners to a workflow to stream the results as a ReadableStream. 
* @@ -78,7 +103,7 @@ export function addStreamingListeners( inputs: Object.fromEntries( Object.entries(event.data.step.getInputs()).map(([key, value]) => [ key, - serializeArtifact(value), + artifactToClientArtifact(value), ]) ), }, @@ -96,7 +121,7 @@ export function addStreamingListeners( (pair): pair is [string, NonNullable<(typeof pair)[1]>] => pair[1] !== undefined ) - .map(([key, value]) => [key, serializeArtifact(value)]) + .map(([key, value]) => [key, artifactToClientArtifact(value)]) ), messages: event.data.step.getConversationMessages(), }, @@ -151,6 +176,8 @@ export async function decodeWorkflowFromReader({ // Note that these are streaming events. // They are easy to confuse with workflow events. + // The difference is that streaming events are sent over the wire, and so they contain JSON data. + // Workflow events are internal to the server and so they contain non-JSON data (such as LLMStepInstance references). const event = streamingMessageSchema.parse(eventJson); switch (event.kind) { diff --git a/packages/hub/prisma/migrations/20240928161949_workflow_format/migration.sql b/packages/hub/prisma/migrations/20240928161949_workflow_format/migration.sql new file mode 100644 index 0000000000..dfedaddc1e --- /dev/null +++ b/packages/hub/prisma/migrations/20240928161949_workflow_format/migration.sql @@ -0,0 +1,5 @@ +-- AlterTable +ALTER TABLE "AiWorkflow" ADD COLUMN "format" INTEGER NOT NULL DEFAULT 2; + +-- old workflows +UPDATE "AiWorkflow" SET "format" = 1; diff --git a/packages/hub/prisma/schema.prisma b/packages/hub/prisma/schema.prisma index 9f2de58779..4b1b2bf3b6 100644 --- a/packages/hub/prisma/schema.prisma +++ b/packages/hub/prisma/schema.prisma @@ -394,6 +394,9 @@ model AiWorkflow { user User @relation(fields: [userId], references: [id], onDelete: Cascade) userId String - // SerializedWorkflow + // v1: SerializedWorkflow + // v2: normalized bundle + format Int @default(2) + workflow Json } diff --git a/packages/hub/src/app/ai/api/create/route.ts b/packages/hub/src/app/ai/api/create/route.ts index 371e05a316..4663de77c1 100644 --- a/packages/hub/src/app/ai/api/create/route.ts +++ b/packages/hub/src/app/ai/api/create/route.ts @@ -1,26 +1,51 @@ import { getServerSession } from "next-auth"; -import { - ClientWorkflow, - decodeWorkflowFromReader, - LlmConfig, - SquiggleWorkflowInput, -} from "@quri/squiggle-ai"; -import { SquiggleWorkflow } from "@quri/squiggle-ai/server"; +import { LlmConfig } from "@quri/squiggle-ai"; +import { SquiggleWorkflow, Workflow } from "@quri/squiggle-ai/server"; import { authOptions } from "@/app/api/auth/[...nextauth]/authOptions"; import { getSelf, isSignedIn } from "@/graphql/helpers/userHelpers"; import { prisma } from "@/prisma"; -import { - bodyToLineReader, - createRequestBodySchema, - requestToInput, -} from "../../utils"; +import { getAiCodec, V2WorkflowData } from "../../serverUtils"; +import { createRequestBodySchema, requestToInput } from "../../utils"; // https://nextjs.org/docs/app/api-reference/file-conventions/route-segment-config#maxduration export const maxDuration = 300; +async function upsertWorkflow( + user: Awaited>, + workflow: Workflow +) { + const codec = getAiCodec(); + const serializer = codec.makeSerializer(); + const entrypoint = serializer.serialize("workflow", workflow); + const bundle = serializer.getBundle(); + + const v2Workflow: V2WorkflowData = { + entrypoint, + bundle, + }; + + await prisma.aiWorkflow.upsert({ + where: { + id: workflow.id, + }, + update: { + format: 2, + workflow: v2Workflow, + }, + create: { + 
id: workflow.id, + user: { + connect: { id: user.id }, + }, + format: 2, + workflow: v2Workflow, + }, + }); +} + export async function POST(req: Request) { const session = await getServerSession(authOptions); @@ -30,46 +55,6 @@ export async function POST(req: Request) { const user = await getSelf(session); - let workflow: ClientWorkflow; - - const streamToDatabase = ( - stream: ReadableStream, - input: SquiggleWorkflowInput - ) => { - decodeWorkflowFromReader({ - reader: bodyToLineReader(stream) as ReadableStreamDefaultReader, - input, - addWorkflow: async (newWorkflow) => { - await prisma.aiWorkflow.create({ - data: { - id: newWorkflow.id, - user: { - connect: { - id: user.id, - }, - }, - workflow: newWorkflow, - }, - }); - workflow = newWorkflow; - }, - setWorkflow: async (update) => { - if (!workflow) { - throw new Error( - "Internal error: setWorkflow called before addWorkflow" - ); - } - workflow = update(workflow); - await prisma.aiWorkflow.update({ - where: { - id: workflow.id, - }, - data: { workflow }, - }); - }, - }); - }; - try { const body = await req.json(); const request = createRequestBodySchema.parse(body); @@ -78,7 +63,7 @@ export async function POST(req: Request) { throw new Error("Prompt or Squiggle code is required"); } - // Create a SquiggleGenerator instance + // Create a SquiggleWorkflow instance const llmConfig: LlmConfig = { llmId: request.model ?? "Claude-Sonnet", priceLimit: 0.15, @@ -87,19 +72,25 @@ export async function POST(req: Request) { }; const input = requestToInput(request); - const stream = new SquiggleWorkflow({ + const openaiApiKey = process.env["OPENAI_API_KEY"]; + const anthropicApiKey = process.env["ANTHROPIC_API_KEY"]; + + const squiggleWorkflow = new SquiggleWorkflow({ llmConfig, input, abortSignal: req.signal, - openaiApiKey: process.env["OPENAI_API_KEY"], - anthropicApiKey: process.env["ANTHROPIC_API_KEY"], - }).runAsStream(); + openaiApiKey, + anthropicApiKey, + }); - const [responseStream, dbStream] = stream.tee(); + // save workflow to the database on each update + squiggleWorkflow.workflow.addEventListener("stepFinished", ({ workflow }) => + upsertWorkflow(user, workflow) + ); - streamToDatabase(dbStream as ReadableStream, input); + const stream = squiggleWorkflow.runAsStream(); - return new Response(responseStream as ReadableStream, { + return new Response(stream as ReadableStream, { headers: { "Content-Type": "text/event-stream", "Cache-Control": "no-cache", diff --git a/packages/hub/src/app/ai/page.tsx b/packages/hub/src/app/ai/page.tsx index 3d362eed59..9390c7b49c 100644 --- a/packages/hub/src/app/ai/page.tsx +++ b/packages/hub/src/app/ai/page.tsx @@ -4,6 +4,7 @@ import { prisma } from "@/prisma"; import { getUserOrRedirect } from "@/server/helpers"; import { AiDashboard } from "./AiDashboard"; +import { getAiCodec, v2WorkflowDataSchema } from "./serverUtils"; export default async function AiPage() { const user = await getUserOrRedirect(); @@ -17,9 +18,25 @@ export default async function AiPage() { }, }); - const workflows = rows.map((row) => { + const workflows: ClientWorkflow[] = rows.map((row) => { try { - return clientWorkflowSchema.parse(row.workflow); + switch (row.format) { + case 1: + return clientWorkflowSchema.parse(row.workflow); + case 2: { + const { bundle, entrypoint } = v2WorkflowDataSchema.parse( + row.workflow + ); + const codec = getAiCodec(); + const deserializer = codec.makeDeserializer(bundle); + const workflow = deserializer.deserialize(entrypoint); + console.log(workflow.asClientWorkflow().steps); + + return 
workflow.asClientWorkflow(); + } + default: + throw new Error(`Unknown workflow format: ${row.format}`); + } } catch (e) { return { id: row.id, diff --git a/packages/hub/src/app/ai/serverUtils.ts b/packages/hub/src/app/ai/serverUtils.ts new file mode 100644 index 0000000000..a9deadb700 --- /dev/null +++ b/packages/hub/src/app/ai/serverUtils.ts @@ -0,0 +1,26 @@ +import "server-only"; + +import { z } from "zod"; + +import { makeAiCodec } from "@quri/squiggle-ai/server"; + +export function getAiCodec() { + const openaiApiKey = process.env["OPENROUTER_API_KEY"]; + const anthropicApiKey = process.env["ANTHROPIC_API_KEY"]; + return makeAiCodec({ + openaiApiKey, + anthropicApiKey, + }); +} + +// schema for serialized workflow format in the database +// this type is not precise but it's better than nothing +export const v2WorkflowDataSchema = z.object({ + entrypoint: z.object({ + entityType: z.literal("workflow"), + pos: z.number(), + }), + bundle: z.any(), +}); + +export type V2WorkflowData = z.infer; From a6d101f2ec56e201ab7ab4f8030c585c6c8cd73c Mon Sep 17 00:00:00 2001 From: Vyacheslav Matyukhin Date: Thu, 10 Oct 2024 14:05:37 -0300 Subject: [PATCH 03/11] WorkflowTemplate --- packages/ai/src/LLMStepInstance.ts | 10 +- packages/ai/src/LLMStepTemplate.ts | 10 +- packages/ai/src/scripts/tests/create.ts | 17 ++- packages/ai/src/scripts/tests/edit.ts | 15 ++- packages/ai/src/serialization.ts | 15 +++ packages/ai/src/server.ts | 6 +- .../ai/src/workflows/ControlledWorkflow.ts | 120 ++++++++++++++---- packages/ai/src/workflows/SquiggleWorkflow.ts | 117 ++++++++--------- packages/ai/src/workflows/Workflow.ts | 4 +- packages/ai/src/workflows/registry.ts | 17 +++ packages/hub/src/app/ai/api/create/route.ts | 35 +++-- packages/hub/src/app/ai/page.tsx | 1 - 12 files changed, 250 insertions(+), 117 deletions(-) create mode 100644 packages/ai/src/workflows/registry.ts diff --git a/packages/ai/src/LLMStepInstance.ts b/packages/ai/src/LLMStepInstance.ts index 90901f61fe..6c21a2967d 100644 --- a/packages/ai/src/LLMStepInstance.ts +++ b/packages/ai/src/LLMStepInstance.ts @@ -8,9 +8,9 @@ import { ErrorType, ExecuteContext, Inputs, + IOShape, LLMStepTemplate, Outputs, - StepShape, StepState, } from "./LLMStepTemplate.js"; import { LogEntry, Logger, TimestampedLogEntry } from "./Logger.js"; @@ -23,7 +23,7 @@ import { import { getStepTemplateByName } from "./steps/registry.js"; import { Workflow } from "./workflows/Workflow.js"; -interface Params { +interface Params { id: string; sequentialId: number; template: LLMStepTemplate; @@ -36,7 +36,7 @@ interface Params { llmMetricsList: LlmMetrics[]; } -export class LLMStepInstance { +export class LLMStepInstance { public id: Params["id"]; public sequentialId: number; public readonly template: Params["template"]; @@ -71,7 +71,7 @@ export class LLMStepInstance { } // Create a new, PENDING step instance - static create(params: { + static create(params: { template: LLMStepInstance["template"]; inputs: LLMStepInstance["inputs"]; retryingStep: LLMStepInstance["retryingStep"]; @@ -379,7 +379,7 @@ export class LLMStepInstance { } export type SerializedStep = Omit< - Params, + Params, // TODO - serialize retryingStep reference "inputs" | "outputs" | "template" | "retryingStep" > & { diff --git a/packages/ai/src/LLMStepTemplate.ts b/packages/ai/src/LLMStepTemplate.ts index a372734e43..f34105160c 100644 --- a/packages/ai/src/LLMStepTemplate.ts +++ b/packages/ai/src/LLMStepTemplate.ts @@ -19,7 +19,7 @@ export type StepState = message: string; }; -export type StepShape< +export type 
IOShape< I extends Record = Record, O extends Record = Record, > = { @@ -27,11 +27,11 @@ export type StepShape< outputs: O; }; -export type Inputs> = { +export type Inputs> = { [K in keyof Shape["inputs"]]: Extract; }; -export type Outputs> = { +export type Outputs> = { [K in keyof Shape["outputs"]]: Extract< Artifact, { kind: Shape["outputs"][K] } @@ -40,7 +40,7 @@ export type Outputs> = { // ExecuteContext is the context that's available to the step implementation. // We intentionally don't pass the reference to the step implementation, so that steps won't mess with their internal state. -export type ExecuteContext = { +export type ExecuteContext = { setOutput>( key: K, value: Outputs[K] | Outputs[K]["value"] // can be either the artifact or the value inside the artifact @@ -50,7 +50,7 @@ export type ExecuteContext = { fail(errorType: ErrorType, message: string): void; }; -export class LLMStepTemplate { +export class LLMStepTemplate { constructor( public readonly name: string, public readonly shape: Shape, diff --git a/packages/ai/src/scripts/tests/create.ts b/packages/ai/src/scripts/tests/create.ts index 2a07e71fcf..afd02eacf7 100644 --- a/packages/ai/src/scripts/tests/create.ts +++ b/packages/ai/src/scripts/tests/create.ts @@ -1,6 +1,7 @@ import { config } from "dotenv"; -import { SquiggleWorkflow } from "../../workflows/SquiggleWorkflow.js"; +import { PromptArtifact } from "../../Artifact.js"; +import { createSquiggleWorkflowTemplate } from "../../workflows/SquiggleWorkflow.js"; config(); @@ -9,11 +10,15 @@ async function main() { "Generate a function that takes a list of numbers and returns the sum of the numbers"; const { totalPrice, runTimeMs, llmRunCount, code, isValid, logSummary } = - await new SquiggleWorkflow({ - input: { type: "Create", prompt }, - openaiApiKey: process.env["OPENAI_API_KEY"], - anthropicApiKey: process.env["ANTHROPIC_API_KEY"], - }).runToResult(); + await createSquiggleWorkflowTemplate + .instantiate({ + inputs: { + prompt: new PromptArtifact(prompt), + }, + openaiApiKey: process.env["OPENAI_API_KEY"], + anthropicApiKey: process.env["ANTHROPIC_API_KEY"], + }) + .runToResult(); const response = { code: typeof code === "string" ? code : "", diff --git a/packages/ai/src/scripts/tests/edit.ts b/packages/ai/src/scripts/tests/edit.ts index 5fc30cf9cf..f0681b1149 100644 --- a/packages/ai/src/scripts/tests/edit.ts +++ b/packages/ai/src/scripts/tests/edit.ts @@ -1,6 +1,7 @@ import { config } from "dotenv"; -import { SquiggleWorkflow } from "../../workflows/SquiggleWorkflow.js"; +import { SourceArtifact } from "../../Artifact.js"; +import { fixSquiggleWorkflowTemplate } from "../../workflows/SquiggleWorkflow.js"; config(); @@ -138,11 +139,13 @@ sTest.describe( ) `; const { totalPrice, runTimeMs, llmRunCount, code, isValid, logSummary } = - await new SquiggleWorkflow({ - input: { type: "Edit", source: initialCode }, - openaiApiKey: process.env["OPENAI_API_KEY"], - anthropicApiKey: process.env["ANTHROPIC_API_KEY"], - }).runToResult(); + await fixSquiggleWorkflowTemplate + .instantiate({ + inputs: { source: new SourceArtifact(initialCode) }, + openaiApiKey: process.env["OPENAI_API_KEY"], + anthropicApiKey: process.env["ANTHROPIC_API_KEY"], + }) + .runToResult(); const response = { code: typeof code === "string" ? 
code : "", diff --git a/packages/ai/src/serialization.ts b/packages/ai/src/serialization.ts index b862602cfd..4b8619dec7 100644 --- a/packages/ai/src/serialization.ts +++ b/packages/ai/src/serialization.ts @@ -8,10 +8,15 @@ import { SerializedArtifact, } from "./Artifact.js"; import { LLMStepInstance, SerializedStep } from "./LLMStepInstance.js"; +import { + SerializedWorkflowInstance, + WorkflowInstance, +} from "./workflows/ControlledWorkflow.js"; import { SerializedWorkflow, Workflow } from "./workflows/Workflow.js"; type AiShape = { workflow: [Workflow, SerializedWorkflow]; + workflowInstance: [WorkflowInstance, SerializedWorkflowInstance]; step: [LLMStepInstance, SerializedStep]; artifact: [Artifact, SerializedArtifact]; }; @@ -36,6 +41,16 @@ export function makeAiCodec(params: { anthropicApiKey: params.anthropicApiKey, }), }, + workflowInstance: { + serialize: (node, visitor) => node.serialize(visitor), + deserialize: (node, visitor) => + WorkflowInstance.deserialize({ + node, + visitor, + openaiApiKey: params.openaiApiKey, + anthropicApiKey: params.anthropicApiKey, + }), + }, step: { serialize: (node, visitor) => node.serialize(visitor), deserialize: (node, visitor) => diff --git a/packages/ai/src/server.ts b/packages/ai/src/server.ts index b62e1aa6ef..e06eecbf46 100644 --- a/packages/ai/src/server.ts +++ b/packages/ai/src/server.ts @@ -1,9 +1,11 @@ // This file is server-only; it shouldn't be imported into client-side React components. export { - SquiggleWorkflow, - SquiggleWorkflowInput, + createSquiggleWorkflowTemplate, + fixSquiggleWorkflowTemplate, } from "./workflows/SquiggleWorkflow.js"; export { Workflow } from "./workflows/Workflow.js"; export { makeAiCodec } from "./serialization.js"; + +export { CodeArtifact, PromptArtifact, SourceArtifact } from "./Artifact.js"; diff --git a/packages/ai/src/workflows/ControlledWorkflow.ts b/packages/ai/src/workflows/ControlledWorkflow.ts index 9b52890c84..71a4fc5e93 100644 --- a/packages/ai/src/workflows/ControlledWorkflow.ts +++ b/packages/ai/src/workflows/ControlledWorkflow.ts @@ -1,45 +1,76 @@ import { ReadableStream } from "stream/web"; +import { Inputs, IOShape } from "../LLMStepTemplate.js"; +import { + AiDeserializationVisitor, + AiSerializationVisitor, +} from "../serialization.js"; import { ClientWorkflowResult } from "../types.js"; +import { getWorkflowTemplateByName } from "./registry.js"; import { addStreamingListeners } from "./streaming.js"; import { LlmConfig, Workflow } from "./Workflow.js"; +type WorkflowInstanceParams = { + template: WorkflowTemplate; + inputs: Inputs; + abortSignal?: AbortSignal; // TODO + llmConfig?: LlmConfig; + openaiApiKey?: string; + anthropicApiKey?: string; +}; + /** - * This is a base class for other, specific workflows. - * - * It assumes that the `Workflow` is controlled by injecting new steps based on - * workflow events. + * This is a base class for workflow descriptions. * - * Specific workflows should override `configureControllerLoop` and - * `configureInitialSteps` to set up the workflow. They should also commonly - * override the constructor to capture the parameters. - * - * Note: it might be good to avoid inheritance and go with function composition, - * but `runAsStream` implementation and the way it interacts with the underlying - * `Workflow` through events makes it tricky. If you can think of a better way, - * please refactor! + * It works similarly to LLMStepTemplate, but for workflows. 
*/ -export abstract class ControlledWorkflow { +export class WorkflowTemplate { + public readonly name: string; + + public configureControllerLoop: ( + workflow: Workflow, + inputs: Inputs + ) => void; + public configureInitialSteps: ( + workflow: Workflow, + inputs: Inputs + ) => void; + + // TODO - shape parameter + constructor(params: { + name: string; + // TODO - do we need two separate functions? we always call them together + configureControllerLoop: WorkflowTemplate["configureControllerLoop"]; + configureInitialSteps: WorkflowTemplate["configureInitialSteps"]; + }) { + this.name = params.name; + this.configureInitialSteps = params.configureInitialSteps; + this.configureControllerLoop = params.configureControllerLoop; + } + + instantiate( + params: Omit, "template"> + ): WorkflowInstance { + return new WorkflowInstance({ ...params, template: this }); + } +} + +export class WorkflowInstance { public workflow: Workflow; + public readonly template: WorkflowTemplate; + public readonly inputs: Inputs; private started: boolean = false; - constructor(params: { - abortSignal?: AbortSignal; - llmConfig?: LlmConfig; - openaiApiKey?: string; - anthropicApiKey?: string; - }) { + constructor(params: WorkflowInstanceParams) { this.workflow = Workflow.create( params.llmConfig, params.openaiApiKey, params.anthropicApiKey ); + this.template = params.template; + this.inputs = params.inputs; } - protected abstract configureControllerLoop(): void; - - protected abstract configureInitialSteps(): void; - private startOrThrow() { if (this.started) { throw new Error("Workflow already started"); @@ -48,8 +79,9 @@ export abstract class ControlledWorkflow { } private configure() { - this.configureControllerLoop(); - this.configureInitialSteps(); + // we configure the controller loop first, so it has a chance to react to its initial step + this.template.configureControllerLoop(this.workflow, this.inputs); + this.template.configureInitialSteps(this.workflow, this.inputs); } // Run workflow to the ReadableStream, appropriate for streaming in Next.js routes @@ -84,4 +116,42 @@ export abstract class ControlledWorkflow { // saveSummaryToFile(generateSummary(workflow)); return this.workflow.getFinalResult(); } + + serialize(visitor: AiSerializationVisitor): SerializedWorkflowInstance { + return { + inputIds: Object.fromEntries( + Object.entries(this.inputs).map(([key, input]) => [ + key, + visitor.artifact(input), + ]) + ), + workflowId: visitor.workflow(this.workflow), + templateName: this.template.name, + }; + } + + static deserialize(params: { + node: SerializedWorkflowInstance; + visitor: AiDeserializationVisitor; + openaiApiKey?: string; + anthropicApiKey?: string; + }) { + return new WorkflowInstance({ + inputs: Object.fromEntries( + Object.entries(params.node.inputIds).map(([key, id]) => [ + key, + params.visitor.artifact(id), + ]) + ), + template: getWorkflowTemplateByName(params.node.templateName), + openaiApiKey: params.openaiApiKey, + anthropicApiKey: params.anthropicApiKey, + }); + } } + +export type SerializedWorkflowInstance = { + inputIds: Record; + workflowId: number; + templateName: string; +}; diff --git a/packages/ai/src/workflows/SquiggleWorkflow.ts b/packages/ai/src/workflows/SquiggleWorkflow.ts index e46572ecc0..9d8b36ccfa 100644 --- a/packages/ai/src/workflows/SquiggleWorkflow.ts +++ b/packages/ai/src/workflows/SquiggleWorkflow.ts @@ -1,75 +1,78 @@ import { z } from "zod"; -import { PromptArtifact, SourceArtifact } from "../Artifact.js"; +import { PromptArtifact } from "../Artifact.js"; import 
{ adjustToFeedbackStep } from "../steps/adjustToFeedbackStep.js"; import { fixCodeUntilItRunsStep } from "../steps/fixCodeUntilItRunsStep.js"; import { generateCodeStep } from "../steps/generateCodeStep.js"; import { runAndFormatCodeStep } from "../steps/runAndFormatCodeStep.js"; import { squiggleWorkflowInputSchema } from "../types.js"; -import { ControlledWorkflow } from "./ControlledWorkflow.js"; -import { LlmConfig } from "./Workflow.js"; +import { WorkflowTemplate } from "./ControlledWorkflow.js"; +import { Workflow } from "./Workflow.js"; export type SquiggleWorkflowInput = z.infer; -/** - * This is a basic workflow for generating Squiggle code. - * - * It generates code based on a prompt, fixes it if necessary, and tries to - * improve it based on feedback. - */ -export class SquiggleWorkflow extends ControlledWorkflow { - public readonly input: SquiggleWorkflowInput; - public readonly prompt: PromptArtifact; +// Shared between create and edit workflows +function fixAdjustRetryLoop(workflow: Workflow, prompt: PromptArtifact) { + workflow.addEventListener("stepFinished", ({ data: { step } }) => { + const code = step.getOutputs()["code"]; + const state = step.getState(); - constructor(params: { - input: SquiggleWorkflowInput; - abortSignal?: AbortSignal; - llmConfig?: LlmConfig; - openaiApiKey?: string; - anthropicApiKey?: string; - }) { - super(params); - - this.input = params.input; - this.prompt = new PromptArtifact( - this.input.type === "Create" ? this.input.prompt : "" - ); - } - - protected configureControllerLoop(): void { - this.workflow.addEventListener("stepFinished", ({ data: { step } }) => { - const code = step.getOutputs()["code"]; - const state = step.getState(); - - if (state.kind === "FAILED") { - if (state.errorType === "MINOR") { - this.workflow.addRetryOfPreviousStep(); - } - return true; + if (state.kind === "FAILED") { + if (state.errorType === "MINOR") { + workflow.addRetryOfPreviousStep(); } + return true; + } - if (code === undefined || code.kind !== "code") return; + if (code === undefined || code.kind !== "code") return; - if (code.value.type === "success") { - this.workflow.addStep(adjustToFeedbackStep, { - prompt: this.prompt, - code, - }); - } else { - this.workflow.addStep(fixCodeUntilItRunsStep, { - code, - }); - } - }); - } - - protected configureInitialSteps(): void { - if (this.input.type === "Create") { - this.workflow.addStep(generateCodeStep, { prompt: this.prompt }); + if (code.value.type === "success") { + workflow.addStep(adjustToFeedbackStep, { + prompt, + code, + }); } else { - this.workflow.addStep(runAndFormatCodeStep, { - source: new SourceArtifact(this.input.source), + workflow.addStep(fixCodeUntilItRunsStep, { + code, }); } - } + }); } + +/** + * This is a basic workflow for generating Squiggle code. + * + * It generates code based on a prompt, fixes it if necessary, and tries to + * improve it based on feedback. + */ +export const createSquiggleWorkflowTemplate = new WorkflowTemplate<{ + inputs: { + prompt: "prompt"; + }; + outputs: Record; +}>({ + name: "CreateSquiggle", + configureControllerLoop(workflow, inputs) { + fixAdjustRetryLoop(workflow, inputs.prompt); + }, + configureInitialSteps(workflow, inputs) { + workflow.addStep(generateCodeStep, { prompt: inputs.prompt }); + }, +}); + +export const fixSquiggleWorkflowTemplate = new WorkflowTemplate<{ + inputs: { + source: "source"; + }; + outputs: Record; +}>({ + name: "FixSquiggle", + configureControllerLoop(workflow) { + // TODO - cache the prompt artifact once? 
maybe even as a global variable + // (but it's better to just refactor steps to make the prompt optional, somehow) + fixAdjustRetryLoop(workflow, new PromptArtifact("")); + }, + configureInitialSteps(workflow, inputs) { + workflow.addStep(runAndFormatCodeStep, { source: inputs.source }); + }, +}); diff --git a/packages/ai/src/workflows/Workflow.ts b/packages/ai/src/workflows/Workflow.ts index e43b006ae7..818d1779d1 100644 --- a/packages/ai/src/workflows/Workflow.ts +++ b/packages/ai/src/workflows/Workflow.ts @@ -6,7 +6,7 @@ import { Message, } from "../LLMClient.js"; import { LLMStepInstance } from "../LLMStepInstance.js"; -import { Inputs, LLMStepTemplate, StepShape } from "../LLMStepTemplate.js"; +import { Inputs, IOShape, LLMStepTemplate } from "../LLMStepTemplate.js"; import { TimestampedLogEntry } from "../Logger.js"; import { LlmId } from "../modelConfigs.js"; import { @@ -135,7 +135,7 @@ export class Workflow { this.dispatchEvent({ type: "workflowStarted" }); } - addStep( + addStep( template: LLMStepTemplate, inputs: Inputs, options?: { retryingStep?: LLMStepInstance } diff --git a/packages/ai/src/workflows/registry.ts b/packages/ai/src/workflows/registry.ts new file mode 100644 index 0000000000..2b96189693 --- /dev/null +++ b/packages/ai/src/workflows/registry.ts @@ -0,0 +1,17 @@ +import { WorkflowTemplate } from "./ControlledWorkflow.js"; +import { + createSquiggleWorkflowTemplate, + fixSquiggleWorkflowTemplate, +} from "./SquiggleWorkflow.js"; + +export function getWorkflowTemplateByName(name: string) { + const workflows: Record> = { + createSquiggle: createSquiggleWorkflowTemplate, + fixSquiggle: fixSquiggleWorkflowTemplate, + }; + + if (!(name in workflows)) { + throw new Error(`Workflow ${name} not found`); + } + return workflows[name]; +} diff --git a/packages/hub/src/app/ai/api/create/route.ts b/packages/hub/src/app/ai/api/create/route.ts index 4663de77c1..71fcacc236 100644 --- a/packages/hub/src/app/ai/api/create/route.ts +++ b/packages/hub/src/app/ai/api/create/route.ts @@ -1,7 +1,13 @@ import { getServerSession } from "next-auth"; import { LlmConfig } from "@quri/squiggle-ai"; -import { SquiggleWorkflow, Workflow } from "@quri/squiggle-ai/server"; +import { + createSquiggleWorkflowTemplate, + fixSquiggleWorkflowTemplate, + PromptArtifact, + SourceArtifact, + Workflow, +} from "@quri/squiggle-ai/server"; import { authOptions } from "@/app/api/auth/[...nextauth]/authOptions"; import { getSelf, isSignedIn } from "@/graphql/helpers/userHelpers"; @@ -75,13 +81,26 @@ export async function POST(req: Request) { const openaiApiKey = process.env["OPENAI_API_KEY"]; const anthropicApiKey = process.env["ANTHROPIC_API_KEY"]; - const squiggleWorkflow = new SquiggleWorkflow({ - llmConfig, - input, - abortSignal: req.signal, - openaiApiKey, - anthropicApiKey, - }); + const squiggleWorkflow = + input.type === "Create" + ? 
createSquiggleWorkflowTemplate.instantiate({ + llmConfig, + inputs: { + prompt: new PromptArtifact(input.prompt), + }, + abortSignal: req.signal, + openaiApiKey, + anthropicApiKey, + }) + : fixSquiggleWorkflowTemplate.instantiate({ + llmConfig, + inputs: { + source: new SourceArtifact(input.source), + }, + abortSignal: req.signal, + openaiApiKey, + anthropicApiKey, + }); // save workflow to the database on each update squiggleWorkflow.workflow.addEventListener("stepFinished", ({ workflow }) => diff --git a/packages/hub/src/app/ai/page.tsx b/packages/hub/src/app/ai/page.tsx index 9390c7b49c..58c5e73067 100644 --- a/packages/hub/src/app/ai/page.tsx +++ b/packages/hub/src/app/ai/page.tsx @@ -30,7 +30,6 @@ export default async function AiPage() { const codec = getAiCodec(); const deserializer = codec.makeDeserializer(bundle); const workflow = deserializer.deserialize(entrypoint); - console.log(workflow.asClientWorkflow().steps); return workflow.asClientWorkflow(); } From 6dcf849b266b4fbf59397d24c5d7b3f2e5289da8 Mon Sep 17 00:00:00 2001 From: Vyacheslav Matyukhin Date: Tue, 8 Oct 2024 15:24:01 -0300 Subject: [PATCH 04/11] unify Workflow and WorkflowInstance; retire squiggle-specific input --- packages/ai/src/LLMStepInstance.ts | 28 ++-- packages/ai/src/generateSummary.ts | 17 +- packages/ai/src/index.ts | 2 + packages/ai/src/serialization.ts | 17 +- packages/ai/src/types.ts | 3 +- .../ai/src/workflows/ControlledWorkflow.ts | 128 ++------------- packages/ai/src/workflows/SquiggleWorkflow.ts | 6 +- packages/ai/src/workflows/Workflow.ts | 147 +++++++++++++----- packages/ai/src/workflows/registry.ts | 9 +- packages/ai/src/workflows/streaming.ts | 22 +-- packages/hub/src/app/ai/WorkflowName.tsx | 12 ++ .../hub/src/app/ai/WorkflowSummaryItem.tsx | 5 +- .../app/ai/WorkflowViewer/ArtifactList.tsx | 29 ++++ .../hub/src/app/ai/WorkflowViewer/Header.tsx | 5 +- .../WorkflowViewer/SelectedNodeSideView.tsx | 24 +-- packages/hub/src/app/ai/api/create/route.ts | 4 +- packages/hub/src/app/ai/page.tsx | 7 +- .../hub/src/app/ai/useSquiggleWorkflows.tsx | 9 +- 18 files changed, 237 insertions(+), 237 deletions(-) create mode 100644 packages/hub/src/app/ai/WorkflowName.tsx create mode 100644 packages/hub/src/app/ai/WorkflowViewer/ArtifactList.tsx diff --git a/packages/ai/src/LLMStepInstance.ts b/packages/ai/src/LLMStepInstance.ts index 6c21a2967d..211b61992a 100644 --- a/packages/ai/src/LLMStepInstance.ts +++ b/packages/ai/src/LLMStepInstance.ts @@ -75,7 +75,7 @@ export class LLMStepInstance { template: LLMStepInstance["template"]; inputs: LLMStepInstance["inputs"]; retryingStep: LLMStepInstance["retryingStep"]; - workflow: Workflow; + workflow: Workflow; }): LLMStepInstance { return new LLMStepInstance({ id: crypto.randomUUID(), @@ -101,7 +101,7 @@ export class LLMStepInstance { return this.conversationMessages; } - async _run(workflow: Workflow) { + async _run(workflow: Workflow) { if (this.state.kind !== "PENDING") { return; } @@ -137,7 +137,7 @@ export class LLMStepInstance { } } - async run(workflow: Workflow) { + async run(workflow: Workflow) { this.log( { type: "info", @@ -209,10 +209,13 @@ export class LLMStepInstance { // private methods - private setOutput>( + private setOutput< + K extends Extract, + WorkflowShape extends IOShape, + >( key: K, value: Outputs[K] | Outputs[K]["value"], - workflow: Workflow + workflow: Workflow ): void { if (key in this.outputs) { this.fail( @@ -234,14 +237,21 @@ export class LLMStepInstance { } } - private log(log: LogEntry, workflow: Workflow): void { + private log( + 
log: LogEntry, + workflow: Workflow + ): void { this.logger.log(log, { workflowId: workflow.id, stepIndex: this.sequentialId, }); } - private fail(errorType: ErrorType, message: string, workflow: Workflow) { + private fail( + errorType: ErrorType, + message: string, + workflow: Workflow + ) { this.log({ type: "error", message }, workflow); this.state = { kind: "FAILED", @@ -259,9 +269,9 @@ export class LLMStepInstance { this.conversationMessages.push(message); } - private async queryLLM( + private async queryLLM( promptPair: PromptPair, - workflow: Workflow + workflow: Workflow ): Promise { try { const messagesToSend: Message[] = [ diff --git a/packages/ai/src/generateSummary.ts b/packages/ai/src/generateSummary.ts index 7fc4ad2414..02465bd98b 100644 --- a/packages/ai/src/generateSummary.ts +++ b/packages/ai/src/generateSummary.ts @@ -4,10 +4,13 @@ import path from "path"; import { Artifact, ArtifactKind } from "./Artifact.js"; import { Code } from "./Code.js"; import { calculatePriceMultipleCalls } from "./LLMClient.js"; +import { IOShape } from "./LLMStepTemplate.js"; import { getLogEntryFullName, TimestampedLogEntry } from "./Logger.js"; import { Workflow } from "./workflows/Workflow.js"; -export function generateSummary(workflow: Workflow): string { +export function generateSummary( + workflow: Workflow +): string { let summary = ""; // Overview @@ -25,7 +28,9 @@ export function generateSummary(workflow: Workflow): string { return summary; } -function generateOverview(workflow: Workflow): string { +function generateOverview( + workflow: Workflow +): string { const steps = workflow.getSteps(); const metricsByLLM = workflow.llmMetricSummary(); @@ -47,7 +52,9 @@ function generateOverview(workflow: Workflow): string { return overview; } -function generateErrorSummary(workflow: Workflow): string { +function generateErrorSummary( + workflow: Workflow +): string { const steps = workflow.getSteps(); let errorSummary = ""; @@ -66,7 +73,9 @@ function generateErrorSummary(workflow: Workflow): string { return errorSummary || "✅ No errors encountered.\n"; } -function generateDetailedStepLogs(workflow: Workflow): string { +function generateDetailedStepLogs( + workflow: Workflow +): string { let detailedLogs = ""; const steps = workflow.getSteps(); steps.forEach((step, index) => { diff --git a/packages/ai/src/index.ts b/packages/ai/src/index.ts index 1f9ca82d97..e6e9d1cd5e 100644 --- a/packages/ai/src/index.ts +++ b/packages/ai/src/index.ts @@ -9,6 +9,8 @@ export { streamingMessageSchema, } from "./types.js"; +export type { IOShape } from "./LLMStepTemplate.js"; + export { llmLinker } from "./Code.js"; export { type LlmId, type LlmName, MODEL_CONFIGS } from "./modelConfigs.js"; diff --git a/packages/ai/src/serialization.ts b/packages/ai/src/serialization.ts index 4b8619dec7..5bb6f68c77 100644 --- a/packages/ai/src/serialization.ts +++ b/packages/ai/src/serialization.ts @@ -8,15 +8,10 @@ import { SerializedArtifact, } from "./Artifact.js"; import { LLMStepInstance, SerializedStep } from "./LLMStepInstance.js"; -import { - SerializedWorkflowInstance, - WorkflowInstance, -} from "./workflows/ControlledWorkflow.js"; import { SerializedWorkflow, Workflow } from "./workflows/Workflow.js"; type AiShape = { - workflow: [Workflow, SerializedWorkflow]; - workflowInstance: [WorkflowInstance, SerializedWorkflowInstance]; + workflow: [Workflow, SerializedWorkflow]; step: [LLMStepInstance, SerializedStep]; artifact: [Artifact, SerializedArtifact]; }; @@ -41,16 +36,6 @@ export function makeAiCodec(params: { 
anthropicApiKey: params.anthropicApiKey, }), }, - workflowInstance: { - serialize: (node, visitor) => node.serialize(visitor), - deserialize: (node, visitor) => - WorkflowInstance.deserialize({ - node, - visitor, - openaiApiKey: params.openaiApiKey, - anthropicApiKey: params.anthropicApiKey, - }), - }, step: { serialize: (node, visitor) => node.serialize(visitor), deserialize: (node, visitor) => diff --git a/packages/ai/src/types.ts b/packages/ai/src/types.ts index 8cd5a2392d..de9d64ab8d 100644 --- a/packages/ai/src/types.ts +++ b/packages/ai/src/types.ts @@ -73,6 +73,7 @@ export type ClientStep = z.infer; const workflowStartedSchema = z.object({ id: z.string(), timestamp: z.number(), + inputs: z.record(z.string(), artifactSchema), }); const stepAddedSchema = stepSchema.omit({ @@ -126,7 +127,7 @@ export type StreamingMessage = z.infer; const commonClientWorkflowFields = { id: z.string(), timestamp: z.number(), // milliseconds since epoch - input: squiggleWorkflowInputSchema, // FIXME - SquiggleWorkflow-specific + inputs: z.record(z.string(), artifactSchema), steps: z.array(stepSchema), currentStep: z.string().optional(), }; diff --git a/packages/ai/src/workflows/ControlledWorkflow.ts b/packages/ai/src/workflows/ControlledWorkflow.ts index 71a4fc5e93..d24a9fbc59 100644 --- a/packages/ai/src/workflows/ControlledWorkflow.ts +++ b/packages/ai/src/workflows/ControlledWorkflow.ts @@ -1,18 +1,12 @@ -import { ReadableStream } from "stream/web"; - +import { LLMStepInstance } from "../LLMStepInstance.js"; import { Inputs, IOShape } from "../LLMStepTemplate.js"; -import { - AiDeserializationVisitor, - AiSerializationVisitor, -} from "../serialization.js"; -import { ClientWorkflowResult } from "../types.js"; -import { getWorkflowTemplateByName } from "./registry.js"; -import { addStreamingListeners } from "./streaming.js"; import { LlmConfig, Workflow } from "./Workflow.js"; -type WorkflowInstanceParams = { +export type WorkflowInstanceParams = { + id: string; template: WorkflowTemplate; inputs: Inputs; + steps: LLMStepInstance[]; abortSignal?: AbortSignal; // TODO llmConfig?: LlmConfig; openaiApiKey?: string; @@ -28,11 +22,11 @@ export class WorkflowTemplate { public readonly name: string; public configureControllerLoop: ( - workflow: Workflow, + workflow: Workflow, inputs: Inputs ) => void; public configureInitialSteps: ( - workflow: Workflow, + workflow: Workflow, inputs: Inputs ) => void; @@ -49,109 +43,13 @@ export class WorkflowTemplate { } instantiate( - params: Omit, "template"> - ): WorkflowInstance { - return new WorkflowInstance({ ...params, template: this }); - } -} - -export class WorkflowInstance { - public workflow: Workflow; - public readonly template: WorkflowTemplate; - public readonly inputs: Inputs; - private started: boolean = false; - - constructor(params: WorkflowInstanceParams) { - this.workflow = Workflow.create( - params.llmConfig, - params.openaiApiKey, - params.anthropicApiKey - ); - this.template = params.template; - this.inputs = params.inputs; - } - - private startOrThrow() { - if (this.started) { - throw new Error("Workflow already started"); - } - this.started = true; - } - - private configure() { - // we configure the controller loop first, so it has a chance to react to its initial step - this.template.configureControllerLoop(this.workflow, this.inputs); - this.template.configureInitialSteps(this.workflow, this.inputs); - } - - // Run workflow to the ReadableStream, appropriate for streaming in Next.js routes - runAsStream(): ReadableStream { - this.startOrThrow(); - 
- const stream = new ReadableStream({ - start: async (controller) => { - addStreamingListeners(this.workflow, controller); - - this.workflow.prepareToStart(); - - // Important! `configure` should be called after all event listeners are set up. - // We want to capture `stepAdded` events. - this.configure(); - - await this.workflow.runUntilComplete(); - controller.close(); - }, - }); - - return stream; - } - - // Run workflow without streaming, only capture the final result - async runToResult(): Promise { - this.startOrThrow(); - this.configure(); - - await this.workflow.runUntilComplete(); - - // saveSummaryToFile(generateSummary(workflow)); - return this.workflow.getFinalResult(); - } - - serialize(visitor: AiSerializationVisitor): SerializedWorkflowInstance { - return { - inputIds: Object.fromEntries( - Object.entries(this.inputs).map(([key, input]) => [ - key, - visitor.artifact(input), - ]) - ), - workflowId: visitor.workflow(this.workflow), - templateName: this.template.name, - }; - } - - static deserialize(params: { - node: SerializedWorkflowInstance; - visitor: AiDeserializationVisitor; - openaiApiKey?: string; - anthropicApiKey?: string; - }) { - return new WorkflowInstance({ - inputs: Object.fromEntries( - Object.entries(params.node.inputIds).map(([key, id]) => [ - key, - params.visitor.artifact(id), - ]) - ), - template: getWorkflowTemplateByName(params.node.templateName), - openaiApiKey: params.openaiApiKey, - anthropicApiKey: params.anthropicApiKey, + params: Omit, "id" | "template" | "steps"> + ): Workflow { + return new Workflow({ + ...params, + id: crypto.randomUUID(), + template: this, + steps: [], }); } } - -export type SerializedWorkflowInstance = { - inputIds: Record; - workflowId: number; - templateName: string; -}; diff --git a/packages/ai/src/workflows/SquiggleWorkflow.ts b/packages/ai/src/workflows/SquiggleWorkflow.ts index 9d8b36ccfa..e2d23d6608 100644 --- a/packages/ai/src/workflows/SquiggleWorkflow.ts +++ b/packages/ai/src/workflows/SquiggleWorkflow.ts @@ -1,6 +1,7 @@ import { z } from "zod"; import { PromptArtifact } from "../Artifact.js"; +import { IOShape } from "../LLMStepTemplate.js"; import { adjustToFeedbackStep } from "../steps/adjustToFeedbackStep.js"; import { fixCodeUntilItRunsStep } from "../steps/fixCodeUntilItRunsStep.js"; import { generateCodeStep } from "../steps/generateCodeStep.js"; @@ -12,7 +13,10 @@ import { Workflow } from "./Workflow.js"; export type SquiggleWorkflowInput = z.infer; // Shared between create and edit workflows -function fixAdjustRetryLoop(workflow: Workflow, prompt: PromptArtifact) { +function fixAdjustRetryLoop( + workflow: Workflow, + prompt: PromptArtifact +) { workflow.addEventListener("stepFinished", ({ data: { step } }) => { const code = step.getOutputs()["code"]; const state = step.getState(); diff --git a/packages/ai/src/workflows/Workflow.ts b/packages/ai/src/workflows/Workflow.ts index 818d1779d1..c87701b756 100644 --- a/packages/ai/src/workflows/Workflow.ts +++ b/packages/ai/src/workflows/Workflow.ts @@ -1,3 +1,5 @@ +import { ReadableStream } from "stream/web"; + import { generateSummary } from "../generateSummary.js"; import { calculatePriceMultipleCalls, @@ -14,7 +16,16 @@ import { AiSerializationVisitor, } from "../serialization.js"; import { ClientWorkflow, ClientWorkflowResult } from "../types.js"; -import { stepToClientStep } from "./streaming.js"; +import { + WorkflowInstanceParams, + WorkflowTemplate, +} from "./ControlledWorkflow.js"; +import { getWorkflowTemplateByName } from "./registry.js"; +import { + 
addStreamingListeners, + artifactToClientArtifact, + stepToClientStep, +} from "./streaming.js"; export type LlmConfig = { llmId: LlmId; @@ -38,19 +49,19 @@ export type WorkflowEventShape = | { type: "stepAdded"; payload: { - step: LLMStepInstance; + step: LLMStepInstance; }; } | { type: "stepStarted"; payload: { - step: LLMStepInstance; + step: LLMStepInstance; }; } | { type: "stepFinished"; payload: { - step: LLMStepInstance; + step: LLMStepInstance; }; } | { @@ -60,19 +71,23 @@ export type WorkflowEventShape = export type WorkflowEventType = WorkflowEventShape["type"]; -export class WorkflowEvent extends Event { +export class WorkflowEvent< + T extends WorkflowEventType, + Shape extends IOShape, +> extends Event { constructor( type: T, - public workflow: Workflow, + public workflow: Workflow, public data: Extract["payload"] ) { super(type); } } -export type WorkflowEventListener = ( - event: WorkflowEvent -) => void; +export type WorkflowEventListener< + T extends WorkflowEventType, + Shape extends IOShape, +> = (event: WorkflowEvent) => void; /** * This class is responsible for managing the steps in a workflow. @@ -85,25 +100,26 @@ export type WorkflowEventListener = ( */ const MAX_RETRIES = 5; -export class Workflow { +export class Workflow { public id: string; + public readonly template: WorkflowTemplate; + public readonly inputs: Inputs; + private started: boolean = false; + public llmConfig: LlmConfig; public startTime: number; - private steps: LLMStepInstance[]; + private steps: LLMStepInstance[]; public llmClient: LLMClient; - private constructor(params: { - id: string; - steps: LLMStepInstance[]; - llmConfig: LlmConfig; - openaiApiKey?: string; - anthropicApiKey?: string; - }) { - this.llmConfig = params.llmConfig; - this.startTime = Date.now(); + constructor(params: WorkflowInstanceParams) { this.id = params.id ?? crypto.randomUUID(); + this.template = params.template; + this.inputs = params.inputs; + + this.llmConfig = params.llmConfig ?? llmConfigDefault; + this.startTime = Date.now(); this.steps = params.steps ?? []; this.llmClient = new LLMClient( @@ -113,26 +129,52 @@ export class Workflow { ); } - static create( - llmConfig: LlmConfig = llmConfigDefault, - openaiApiKey?: string, - anthropicApiKey?: string - ) { - return new Workflow({ - id: crypto.randomUUID(), - steps: [], - llmConfig, - openaiApiKey, - anthropicApiKey, + private startOrThrow() { + if (this.started) { + throw new Error("Workflow already started"); + } + this.started = true; + } + + private configure() { + // we configure the controller loop first, so it has a chance to react to its initial step + this.template.configureControllerLoop(this, this.inputs); + this.template.configureInitialSteps(this, this.inputs); + } + + // Run workflow to the ReadableStream, appropriate for streaming in Next.js routes + runAsStream(): ReadableStream { + this.startOrThrow(); + + const stream = new ReadableStream({ + start: async (controller) => { + addStreamingListeners(this, controller); + + // We need to dispatch this event after we configured the event + // handlers, but before we add any steps. + this.dispatchEvent({ type: "workflowStarted" }); + + // Important! `configure` should be called after all event listeners are + // set up. We want to capture `stepAdded` events. + this.configure(); + + await this.runUntilComplete(); + controller.close(); + }, }); + + return stream; } - // This is a hook that ControlledWorkflow can use to prepare the workflow. 
- // It's a bit of a hack; we need to dispatch this event after we configured the event handlers, - // but before we add any steps. - // So we can't do this neither in the constructor nor in `runUntilComplete`. - prepareToStart() { - this.dispatchEvent({ type: "workflowStarted" }); + // Run workflow without streaming, only capture the final result + async runToResult(): Promise { + this.startOrThrow(); + this.configure(); + + await this.runUntilComplete(); + + // saveSummaryToFile(generateSummary(workflow)); + return this.getFinalResult(); } addStep( @@ -140,7 +182,7 @@ export class Workflow { inputs: Inputs, options?: { retryingStep?: LLMStepInstance } ): LLMStepInstance { - // sorry for "any"; countervariance issues + // sorry for "any"; contravariance issues const step: LLMStepInstance = LLMStepInstance.create({ template, inputs, @@ -347,14 +389,14 @@ export class Workflow { addEventListener( type: T, - listener: WorkflowEventListener + listener: WorkflowEventListener ) { this.eventTarget.addEventListener(type, listener as (event: Event) => void); } removeEventListener( type: T, - listener: WorkflowEventListener + listener: WorkflowEventListener ) { this.eventTarget.removeEventListener( type, @@ -366,6 +408,13 @@ export class Workflow { serialize(visitor: AiSerializationVisitor): SerializedWorkflow { return { id: this.id, + templateName: this.template.name, + inputIds: Object.fromEntries( + Object.entries(this.inputs).map(([key, input]) => [ + key, + visitor.artifact(input), + ]) + ), stepIds: this.steps.map(visitor.step), llmConfig: this.llmConfig, }; @@ -381,9 +430,16 @@ export class Workflow { visitor: AiDeserializationVisitor; openaiApiKey?: string; anthropicApiKey?: string; - }): Workflow { + }): Workflow { return new Workflow({ id: node.id, + template: getWorkflowTemplateByName(node.templateName), + inputs: Object.fromEntries( + Object.entries(node.inputIds).map(([key, id]) => [ + key, + visitor.artifact(id), + ]) + ), llmConfig: node.llmConfig, steps: node.stepIds.map(visitor.step), openaiApiKey, @@ -404,13 +460,20 @@ export class Workflow { result: this.getFinalResult(), } : { status: "loading" }), - input: { type: "Create", prompt: "FIXME - not serialized" }, + inputs: Object.fromEntries( + Object.entries(this.inputs).map(([key, value]) => [ + key, + artifactToClientArtifact(value), + ]) + ), }; } } export type SerializedWorkflow = { id: string; + templateName: string; + inputIds: Record; llmConfig: LlmConfig; stepIds: number[]; }; diff --git a/packages/ai/src/workflows/registry.ts b/packages/ai/src/workflows/registry.ts index 2b96189693..633d856e94 100644 --- a/packages/ai/src/workflows/registry.ts +++ b/packages/ai/src/workflows/registry.ts @@ -5,10 +5,11 @@ import { } from "./SquiggleWorkflow.js"; export function getWorkflowTemplateByName(name: string) { - const workflows: Record> = { - createSquiggle: createSquiggleWorkflowTemplate, - fixSquiggle: fixSquiggleWorkflowTemplate, - }; + const workflows: Record> = Object.fromEntries( + [createSquiggleWorkflowTemplate, fixSquiggleWorkflowTemplate].map( + (workflow) => [workflow.name, workflow] + ) + ); if (!(name in workflows)) { throw new Error(`Workflow ${name} not found`); diff --git a/packages/ai/src/workflows/streaming.ts b/packages/ai/src/workflows/streaming.ts index c92cf26685..9ff9e8232f 100644 --- a/packages/ai/src/workflows/streaming.ts +++ b/packages/ai/src/workflows/streaming.ts @@ -5,6 +5,7 @@ import { import { Artifact } from "../Artifact.js"; import { type LLMStepInstance } from "../LLMStepInstance.js"; +import { 
IOShape } from "../LLMStepTemplate.js"; import { ClientArtifact, ClientStep, @@ -12,10 +13,9 @@ import { StreamingMessage, streamingMessageSchema, } from "../types.js"; -import { type SquiggleWorkflowInput } from "./SquiggleWorkflow.js"; import { type Workflow } from "./Workflow.js"; -function artifactToClientArtifact(value: Artifact): ClientArtifact { +export function artifactToClientArtifact(value: Artifact): ClientArtifact { const commonArtifactFields = { id: value.id, createdBy: value.createdBy?.id, @@ -76,8 +76,8 @@ export function stepToClientStep(step: LLMStepInstance): ClientStep { * `ControlledWorkflow.runAsStream()` relies on this function; see its * implementation for more details. */ -export function addStreamingListeners( - workflow: Workflow, +export function addStreamingListeners( + workflow: Workflow, controller: ReadableStreamController ) { const send = (message: StreamingMessage) => { @@ -90,6 +90,12 @@ export function addStreamingListeners( content: { id: event.workflow.id, timestamp: event.workflow.startTime, + inputs: Object.fromEntries( + Object.entries(event.workflow.inputs).map(([key, value]) => [ + key, + artifactToClientArtifact(value), + ]) + ), }, }); }); @@ -148,16 +154,10 @@ export function addStreamingListeners( */ export async function decodeWorkflowFromReader({ reader, - input, addWorkflow, setWorkflow, }: { reader: ReadableStreamDefaultReader; - // FIXME - this shouldn't be necessary, but we need to inject the input to - // SerializedWorkflow, and it's not stored on the original Workflow yet, so - // it's not present in the stream data. - // In the future, we should store input parameters in the Workflow object. - input: SquiggleWorkflowInput; // This adds an initial version of the workflow. addWorkflow: (workflow: ClientWorkflow) => Promise; // This signature might look complicated, but it matches the functional @@ -185,7 +185,7 @@ export async function decodeWorkflowFromReader({ await addWorkflow({ id: event.content.id, timestamp: event.content.timestamp, - input, + inputs: event.content.inputs, steps: [], status: "loading", }); diff --git a/packages/hub/src/app/ai/WorkflowName.tsx b/packages/hub/src/app/ai/WorkflowName.tsx new file mode 100644 index 0000000000..5f0aedacde --- /dev/null +++ b/packages/hub/src/app/ai/WorkflowName.tsx @@ -0,0 +1,12 @@ +import { FC } from "react"; + +import { ClientWorkflow } from "@quri/squiggle-ai"; + +export const WorkflowName: FC<{ workflow: ClientWorkflow }> = ({ + workflow, +}) => { + return "prompt" in workflow.inputs && + workflow.inputs["prompt"].kind === "prompt" + ? workflow.inputs["prompt"].value + : workflow.id; +}; diff --git a/packages/hub/src/app/ai/WorkflowSummaryItem.tsx b/packages/hub/src/app/ai/WorkflowSummaryItem.tsx index 2884765029..f92774bc0b 100644 --- a/packages/hub/src/app/ai/WorkflowSummaryItem.tsx +++ b/packages/hub/src/app/ai/WorkflowSummaryItem.tsx @@ -5,6 +5,7 @@ import { FC } from "react"; import { ClientWorkflow } from "@quri/squiggle-ai"; +import { WorkflowName } from "./WorkflowName"; import { WorkflowStatusIcon } from "./WorkflowStatusIcon"; export const WorkflowSummaryItem: FC<{ @@ -24,7 +25,9 @@ export const WorkflowSummaryItem: FC<{
-          {workflow.input.prompt}
+          <WorkflowName workflow={workflow} />
         {workflow.status === "loading" && (
diff --git a/packages/hub/src/app/ai/WorkflowViewer/ArtifactList.tsx b/packages/hub/src/app/ai/WorkflowViewer/ArtifactList.tsx
new file mode 100644
index 0000000000..55e2d541ee
--- /dev/null
+++ b/packages/hub/src/app/ai/WorkflowViewer/ArtifactList.tsx
@@ -0,0 +1,29 @@
+import { FC } from "react";
+
+import { ClientArtifact } from "@quri/squiggle-ai";
+
+import { ArtifactDisplay } from "./ArtifactDisplay";
+
+export const ArtifactList: FC<{
+  title?: string;
+  artifacts: Record<string, ClientArtifact>;
+}> = ({ title, artifacts }) => {
+  return (
+    <div>
+      {title && (
+        <div>{title}</div>
+      )}
+      <div>
+        {Object.entries(artifacts).map(([key, value]) => (
+          <ArtifactDisplay key={key} artifact={value} />
+        ))}
+      </div>
+    </div>
+  );
+};
diff --git a/packages/hub/src/app/ai/WorkflowViewer/Header.tsx b/packages/hub/src/app/ai/WorkflowViewer/Header.tsx
index 4ca379b6c4..8721ebbb02 100644
--- a/packages/hub/src/app/ai/WorkflowViewer/Header.tsx
+++ b/packages/hub/src/app/ai/WorkflowViewer/Header.tsx
@@ -3,7 +3,9 @@ import { FC, ReactNode } from "react";
 
 import { ClientWorkflow } from "@quri/squiggle-ai";
 import { Button } from "@quri/ui";
 
+import { WorkflowName } from "../WorkflowName";
 import { WorkflowStatusIcon } from "../WorkflowStatusIcon";
+import { ArtifactList } from "./ArtifactList";
 
 // Common header for all workflow states
 export const Header: FC<{
@@ -17,10 +19,11 @@
-        {workflow.input.prompt}
+        <WorkflowName workflow={workflow} />
       {renderLeft()}
+      <ArtifactList artifacts={workflow.inputs} />
{renderRight()} diff --git a/packages/hub/src/app/ai/WorkflowViewer/SelectedNodeSideView.tsx b/packages/hub/src/app/ai/WorkflowViewer/SelectedNodeSideView.tsx index b8fdf8b356..f79f7e6492 100644 --- a/packages/hub/src/app/ai/WorkflowViewer/SelectedNodeSideView.tsx +++ b/packages/hub/src/app/ai/WorkflowViewer/SelectedNodeSideView.tsx @@ -7,7 +7,7 @@ import { ChevronLeftIcon, ChevronRightIcon, XIcon } from "@quri/ui"; import { useAvailableHeight } from "@/hooks/useAvailableHeight"; import { SquigglePlaygroundForWorkflow } from "../SquigglePlaygroundForWorkflow"; -import { ArtifactDisplay } from "./ArtifactDisplay"; +import { ArtifactList } from "./ArtifactList"; import { ArtifactMessages } from "./ArtifactMessages"; const NavButton: FC<{ @@ -31,28 +31,6 @@ const NavButton: FC<{ ); }; -const ArtifactList: FC<{ - title: string; - artifacts: Record; -}> = ({ title, artifacts }) => { - return ( -
-      <div>{title}</div>
-      <div>
-        {Object.entries(artifacts).map(([key, value]) => (
-          <ArtifactDisplay key={key} artifact={value} />
-        ))}
-      </div>
-    </div>
- ); -}; - export const SelectedNodeSideView: FC<{ selectedNode: ClientStep; onClose: () => void; diff --git a/packages/hub/src/app/ai/api/create/route.ts b/packages/hub/src/app/ai/api/create/route.ts index 71fcacc236..fe96958bad 100644 --- a/packages/hub/src/app/ai/api/create/route.ts +++ b/packages/hub/src/app/ai/api/create/route.ts @@ -21,7 +21,7 @@ export const maxDuration = 300; async function upsertWorkflow( user: Awaited>, - workflow: Workflow + workflow: Workflow ) { const codec = getAiCodec(); const serializer = codec.makeSerializer(); @@ -103,7 +103,7 @@ export async function POST(req: Request) { }); // save workflow to the database on each update - squiggleWorkflow.workflow.addEventListener("stepFinished", ({ workflow }) => + squiggleWorkflow.addEventListener("stepFinished", ({ workflow }) => upsertWorkflow(user, workflow) ); diff --git a/packages/hub/src/app/ai/page.tsx b/packages/hub/src/app/ai/page.tsx index 58c5e73067..5f2e515ae2 100644 --- a/packages/hub/src/app/ai/page.tsx +++ b/packages/hub/src/app/ai/page.tsx @@ -41,12 +41,9 @@ export default async function AiPage() { id: row.id, timestamp: row.createdAt.getTime(), status: "error", - input: { - type: "Create", - prompt: "[unknown workflow]", - }, + inputs: {}, steps: [], - result: "Invalid workflow format in the database", + result: `Invalid workflow format in the database: ${e}`, } satisfies ClientWorkflow; } }); diff --git a/packages/hub/src/app/ai/useSquiggleWorkflows.tsx b/packages/hub/src/app/ai/useSquiggleWorkflows.tsx index 6b56b823be..fef1aacdc1 100644 --- a/packages/hub/src/app/ai/useSquiggleWorkflows.tsx +++ b/packages/hub/src/app/ai/useSquiggleWorkflows.tsx @@ -32,7 +32,13 @@ export function useSquiggleWorkflows(initialWorkflows: ClientWorkflow[]) { id, timestamp: new Date().getTime(), status: "loading", - input, + inputs: { + prompt: { + id: "prompt", + kind: "prompt", + value: input.prompt ?? "[FIX]", + }, + }, steps: [], }; setWorkflows((workflows) => [...workflows, workflow]); @@ -66,7 +72,6 @@ export function useSquiggleWorkflows(initialWorkflows: ClientWorkflow[]) { await decodeWorkflowFromReader({ reader: reader as ReadableStreamDefaultReader, // frontend types don't precisely match Node.js types - input: requestToInput(request), addWorkflow: async (workflow) => { // Replace the mock workflow with the real workflow. 
setWorkflows((workflows) => From 8668f954b8b571daa27bc1788d053daeda89090e Mon Sep 17 00:00:00 2001 From: Vyacheslav Matyukhin Date: Thu, 10 Oct 2024 13:28:25 -0300 Subject: [PATCH 05/11] change ai request shape; refactor workflow files --- packages/ai/src/index.ts | 3 - packages/ai/src/scripts/tests/create.ts | 2 +- packages/ai/src/scripts/tests/edit.ts | 2 +- packages/ai/src/server.ts | 6 +- packages/ai/src/types.ts | 13 --- packages/ai/src/workflows/SquiggleWorkflow.ts | 82 ------------------- packages/ai/src/workflows/controllers.ts | 36 ++++++++ .../createSquiggleWorkflowTemplate.ts | 25 ++++++ .../workflows/fixSquiggleWorkflowTemplate.ts | 21 +++++ packages/ai/src/workflows/registry.ts | 6 +- packages/hub/src/app/ai/Sidebar.tsx | 23 ++++-- packages/hub/src/app/ai/api/create/route.ts | 15 ++-- packages/hub/src/app/ai/page.tsx | 7 ++ .../hub/src/app/ai/useSquiggleWorkflows.tsx | 18 ++-- packages/hub/src/app/ai/utils.ts | 33 ++++---- 15 files changed, 137 insertions(+), 155 deletions(-) delete mode 100644 packages/ai/src/workflows/SquiggleWorkflow.ts create mode 100644 packages/ai/src/workflows/controllers.ts create mode 100644 packages/ai/src/workflows/createSquiggleWorkflowTemplate.ts create mode 100644 packages/ai/src/workflows/fixSquiggleWorkflowTemplate.ts diff --git a/packages/ai/src/index.ts b/packages/ai/src/index.ts index e6e9d1cd5e..140d1b93ca 100644 --- a/packages/ai/src/index.ts +++ b/packages/ai/src/index.ts @@ -15,7 +15,4 @@ export { llmLinker } from "./Code.js"; export { type LlmId, type LlmName, MODEL_CONFIGS } from "./modelConfigs.js"; -// Export type only! We can't import SquiggleWorkflow.js because it depends on Node.js modules such as "fs". -export { type SquiggleWorkflowInput } from "./workflows/SquiggleWorkflow.js"; - export { decodeWorkflowFromReader } from "./workflows/streaming.js"; diff --git a/packages/ai/src/scripts/tests/create.ts b/packages/ai/src/scripts/tests/create.ts index afd02eacf7..2b9be45d69 100644 --- a/packages/ai/src/scripts/tests/create.ts +++ b/packages/ai/src/scripts/tests/create.ts @@ -1,7 +1,7 @@ import { config } from "dotenv"; import { PromptArtifact } from "../../Artifact.js"; -import { createSquiggleWorkflowTemplate } from "../../workflows/SquiggleWorkflow.js"; +import { createSquiggleWorkflowTemplate } from "../../workflows/createSquiggleWorkflowTemplate.js"; config(); diff --git a/packages/ai/src/scripts/tests/edit.ts b/packages/ai/src/scripts/tests/edit.ts index f0681b1149..729a618694 100644 --- a/packages/ai/src/scripts/tests/edit.ts +++ b/packages/ai/src/scripts/tests/edit.ts @@ -1,7 +1,7 @@ import { config } from "dotenv"; import { SourceArtifact } from "../../Artifact.js"; -import { fixSquiggleWorkflowTemplate } from "../../workflows/SquiggleWorkflow.js"; +import { fixSquiggleWorkflowTemplate } from "../../workflows/fixSquiggleWorkflowTemplate.js"; config(); diff --git a/packages/ai/src/server.ts b/packages/ai/src/server.ts index e06eecbf46..a4c17c691c 100644 --- a/packages/ai/src/server.ts +++ b/packages/ai/src/server.ts @@ -1,8 +1,6 @@ // This file is server-only; it shouldn't be imported into client-side React components. 
-export { - createSquiggleWorkflowTemplate, - fixSquiggleWorkflowTemplate, -} from "./workflows/SquiggleWorkflow.js"; +export { createSquiggleWorkflowTemplate } from "./workflows/createSquiggleWorkflowTemplate.js"; +export { fixSquiggleWorkflowTemplate } from "./workflows/fixSquiggleWorkflowTemplate.js"; export { Workflow } from "./workflows/Workflow.js"; diff --git a/packages/ai/src/types.ts b/packages/ai/src/types.ts index de9d64ab8d..e92ad454b7 100644 --- a/packages/ai/src/types.ts +++ b/packages/ai/src/types.ts @@ -1,18 +1,5 @@ import { z } from "zod"; -// This could be defined in SquiggleWorkflow.ts, but it would cause a dependency on server-only modules. -export const squiggleWorkflowInputSchema = z.discriminatedUnion("type", [ - z.object({ - type: z.literal("Create"), - prompt: z.string(), - }), - z.object({ - type: z.literal("Edit"), - source: z.string(), - prompt: z.string().optional(), - }), -]); - // Protocol for streaming workflow changes between server and client. // ClientArtifact type diff --git a/packages/ai/src/workflows/SquiggleWorkflow.ts b/packages/ai/src/workflows/SquiggleWorkflow.ts deleted file mode 100644 index e2d23d6608..0000000000 --- a/packages/ai/src/workflows/SquiggleWorkflow.ts +++ /dev/null @@ -1,82 +0,0 @@ -import { z } from "zod"; - -import { PromptArtifact } from "../Artifact.js"; -import { IOShape } from "../LLMStepTemplate.js"; -import { adjustToFeedbackStep } from "../steps/adjustToFeedbackStep.js"; -import { fixCodeUntilItRunsStep } from "../steps/fixCodeUntilItRunsStep.js"; -import { generateCodeStep } from "../steps/generateCodeStep.js"; -import { runAndFormatCodeStep } from "../steps/runAndFormatCodeStep.js"; -import { squiggleWorkflowInputSchema } from "../types.js"; -import { WorkflowTemplate } from "./ControlledWorkflow.js"; -import { Workflow } from "./Workflow.js"; - -export type SquiggleWorkflowInput = z.infer; - -// Shared between create and edit workflows -function fixAdjustRetryLoop( - workflow: Workflow, - prompt: PromptArtifact -) { - workflow.addEventListener("stepFinished", ({ data: { step } }) => { - const code = step.getOutputs()["code"]; - const state = step.getState(); - - if (state.kind === "FAILED") { - if (state.errorType === "MINOR") { - workflow.addRetryOfPreviousStep(); - } - return true; - } - - if (code === undefined || code.kind !== "code") return; - - if (code.value.type === "success") { - workflow.addStep(adjustToFeedbackStep, { - prompt, - code, - }); - } else { - workflow.addStep(fixCodeUntilItRunsStep, { - code, - }); - } - }); -} - -/** - * This is a basic workflow for generating Squiggle code. - * - * It generates code based on a prompt, fixes it if necessary, and tries to - * improve it based on feedback. - */ -export const createSquiggleWorkflowTemplate = new WorkflowTemplate<{ - inputs: { - prompt: "prompt"; - }; - outputs: Record; -}>({ - name: "CreateSquiggle", - configureControllerLoop(workflow, inputs) { - fixAdjustRetryLoop(workflow, inputs.prompt); - }, - configureInitialSteps(workflow, inputs) { - workflow.addStep(generateCodeStep, { prompt: inputs.prompt }); - }, -}); - -export const fixSquiggleWorkflowTemplate = new WorkflowTemplate<{ - inputs: { - source: "source"; - }; - outputs: Record; -}>({ - name: "FixSquiggle", - configureControllerLoop(workflow) { - // TODO - cache the prompt artifact once? 
maybe even as a global variable - // (but it's better to just refactor steps to make the prompt optional, somehow) - fixAdjustRetryLoop(workflow, new PromptArtifact("")); - }, - configureInitialSteps(workflow, inputs) { - workflow.addStep(runAndFormatCodeStep, { source: inputs.source }); - }, -}); diff --git a/packages/ai/src/workflows/controllers.ts b/packages/ai/src/workflows/controllers.ts new file mode 100644 index 0000000000..a957203053 --- /dev/null +++ b/packages/ai/src/workflows/controllers.ts @@ -0,0 +1,36 @@ +import { PromptArtifact } from "../Artifact.js"; +import { IOShape } from "../LLMStepTemplate.js"; +import { adjustToFeedbackStep } from "../steps/adjustToFeedbackStep.js"; +import { fixCodeUntilItRunsStep } from "../steps/fixCodeUntilItRunsStep.js"; +import { Workflow } from "./Workflow.js"; + +// Shared between create and edit workflows +export function fixAdjustRetryLoop( + workflow: Workflow, + prompt: PromptArtifact +) { + workflow.addEventListener("stepFinished", ({ data: { step } }) => { + const code = step.getOutputs()["code"]; + const state = step.getState(); + + if (state.kind === "FAILED") { + if (state.errorType === "MINOR") { + workflow.addRetryOfPreviousStep(); + } + return true; + } + + if (code === undefined || code.kind !== "code") return; + + if (code.value.type === "success") { + workflow.addStep(adjustToFeedbackStep, { + prompt, + code, + }); + } else { + workflow.addStep(fixCodeUntilItRunsStep, { + code, + }); + } + }); +} diff --git a/packages/ai/src/workflows/createSquiggleWorkflowTemplate.ts b/packages/ai/src/workflows/createSquiggleWorkflowTemplate.ts new file mode 100644 index 0000000000..ba0ec2dbb9 --- /dev/null +++ b/packages/ai/src/workflows/createSquiggleWorkflowTemplate.ts @@ -0,0 +1,25 @@ +import { generateCodeStep } from "../steps/generateCodeStep.js"; +import { WorkflowTemplate } from "./ControlledWorkflow.js"; +import { fixAdjustRetryLoop } from "./controllers.js"; + +/** + * This is a basic workflow for generating Squiggle code. + * + * It generates code based on a prompt, fixes it if necessary, and tries to + * improve it based on feedback. + */ + +export const createSquiggleWorkflowTemplate = new WorkflowTemplate<{ + inputs: { + prompt: "prompt"; + }; + outputs: Record; +}>({ + name: "CreateSquiggle", + configureControllerLoop(workflow, inputs) { + fixAdjustRetryLoop(workflow, inputs.prompt); + }, + configureInitialSteps(workflow, inputs) { + workflow.addStep(generateCodeStep, { prompt: inputs.prompt }); + }, +}); diff --git a/packages/ai/src/workflows/fixSquiggleWorkflowTemplate.ts b/packages/ai/src/workflows/fixSquiggleWorkflowTemplate.ts new file mode 100644 index 0000000000..c703e76738 --- /dev/null +++ b/packages/ai/src/workflows/fixSquiggleWorkflowTemplate.ts @@ -0,0 +1,21 @@ +import { PromptArtifact } from "../Artifact.js"; +import { runAndFormatCodeStep } from "../steps/runAndFormatCodeStep.js"; +import { WorkflowTemplate } from "./ControlledWorkflow.js"; +import { fixAdjustRetryLoop } from "./controllers.js"; + +export const fixSquiggleWorkflowTemplate = new WorkflowTemplate<{ + inputs: { + source: "source"; + }; + outputs: Record; +}>({ + name: "FixSquiggle", + configureControllerLoop(workflow) { + // TODO - cache the prompt artifact once? 
maybe even as a global variable + // (but it's better to just refactor steps to make the prompt optional, somehow) + fixAdjustRetryLoop(workflow, new PromptArtifact("")); + }, + configureInitialSteps(workflow, inputs) { + workflow.addStep(runAndFormatCodeStep, { source: inputs.source }); + }, +}); diff --git a/packages/ai/src/workflows/registry.ts b/packages/ai/src/workflows/registry.ts index 633d856e94..ddc90888eb 100644 --- a/packages/ai/src/workflows/registry.ts +++ b/packages/ai/src/workflows/registry.ts @@ -1,8 +1,6 @@ import { WorkflowTemplate } from "./ControlledWorkflow.js"; -import { - createSquiggleWorkflowTemplate, - fixSquiggleWorkflowTemplate, -} from "./SquiggleWorkflow.js"; +import { createSquiggleWorkflowTemplate } from "./createSquiggleWorkflowTemplate.js"; +import { fixSquiggleWorkflowTemplate } from "./fixSquiggleWorkflowTemplate.js"; export function getWorkflowTemplateByName(name: string) { const workflows: Record> = Object.fromEntries( diff --git a/packages/hub/src/app/ai/Sidebar.tsx b/packages/hub/src/app/ai/Sidebar.tsx index 65aa6f94fc..499bb0f469 100644 --- a/packages/hub/src/app/ai/Sidebar.tsx +++ b/packages/hub/src/app/ai/Sidebar.tsx @@ -17,7 +17,7 @@ import { TextAreaFormField, } from "@quri/ui"; -import { CreateRequestBody } from "./utils"; +import { AiRequestBody } from "./utils"; import { WorkflowSummaryList } from "./WorkflowSummaryList"; type Handle = { @@ -25,7 +25,7 @@ type Handle = { }; type Props = { - submitWorkflow: (requestBody: CreateRequestBody) => void; + submitWorkflow: (requestBody: AiRequestBody) => void; selectWorkflow: (id: string) => void; selectedWorkflow: ClientWorkflow | undefined; workflows: ClientWorkflow[]; @@ -80,12 +80,19 @@ Outputs: })); const handleSubmit = form.handleSubmit( - async ({ prompt, squiggleCode, model }, event) => { - const requestBody: CreateRequestBody = { - prompt: mode === "create" ? prompt : undefined, - squiggleCode: mode === "edit" ? squiggleCode : undefined, - model: model as LlmId, - }; + async ({ prompt, squiggleCode, model }) => { + const requestBody: AiRequestBody = + mode === "create" + ? 
{ + kind: "create", + prompt, + model: model as LlmId, + } + : { + kind: "edit", + squiggleCode, + model: model as LlmId, + }; submitWorkflow(requestBody); form.setValue("prompt", ""); diff --git a/packages/hub/src/app/ai/api/create/route.ts b/packages/hub/src/app/ai/api/create/route.ts index fe96958bad..68e23b603f 100644 --- a/packages/hub/src/app/ai/api/create/route.ts +++ b/packages/hub/src/app/ai/api/create/route.ts @@ -14,7 +14,7 @@ import { getSelf, isSignedIn } from "@/graphql/helpers/userHelpers"; import { prisma } from "@/prisma"; import { getAiCodec, V2WorkflowData } from "../../serverUtils"; -import { createRequestBodySchema, requestToInput } from "../../utils"; +import { aiRequestBodySchema } from "../../utils"; // https://nextjs.org/docs/app/api-reference/file-conventions/route-segment-config#maxduration export const maxDuration = 300; @@ -63,11 +63,7 @@ export async function POST(req: Request) { try { const body = await req.json(); - const request = createRequestBodySchema.parse(body); - - if (!request.prompt && !request.squiggleCode) { - throw new Error("Prompt or Squiggle code is required"); - } + const request = aiRequestBodySchema.parse(body); // Create a SquiggleWorkflow instance const llmConfig: LlmConfig = { @@ -77,16 +73,15 @@ export async function POST(req: Request) { messagesInHistoryToKeep: 4, }; - const input = requestToInput(request); const openaiApiKey = process.env["OPENAI_API_KEY"]; const anthropicApiKey = process.env["ANTHROPIC_API_KEY"]; const squiggleWorkflow = - input.type === "Create" + request.kind === "create" ? createSquiggleWorkflowTemplate.instantiate({ llmConfig, inputs: { - prompt: new PromptArtifact(input.prompt), + prompt: new PromptArtifact(request.prompt), }, abortSignal: req.signal, openaiApiKey, @@ -95,7 +90,7 @@ export async function POST(req: Request) { : fixSquiggleWorkflowTemplate.instantiate({ llmConfig, inputs: { - source: new SourceArtifact(input.source), + source: new SourceArtifact(request.squiggleCode), }, abortSignal: req.signal, openaiApiKey, diff --git a/packages/hub/src/app/ai/page.tsx b/packages/hub/src/app/ai/page.tsx index 5f2e515ae2..c333c8d99b 100644 --- a/packages/hub/src/app/ai/page.tsx +++ b/packages/hub/src/app/ai/page.tsx @@ -24,6 +24,13 @@ export default async function AiPage() { case 1: return clientWorkflowSchema.parse(row.workflow); case 2: { + /* + * Here we go from SerializedWorkflow to Workflow to ClientWorkflow. + * + * TODO: Instead, we could go directly from SerializedWorkflow to + * ClientWorkflow (useful especially if workflow implementation is + * deprecated, so we can't resume it but still want to show it). 
+ */ const { bundle, entrypoint } = v2WorkflowDataSchema.parse( row.workflow ); diff --git a/packages/hub/src/app/ai/useSquiggleWorkflows.tsx b/packages/hub/src/app/ai/useSquiggleWorkflows.tsx index fef1aacdc1..20d8a31327 100644 --- a/packages/hub/src/app/ai/useSquiggleWorkflows.tsx +++ b/packages/hub/src/app/ai/useSquiggleWorkflows.tsx @@ -1,12 +1,8 @@ import { useCallback, useState } from "react"; -import { - ClientWorkflow, - decodeWorkflowFromReader, - SquiggleWorkflowInput, -} from "@quri/squiggle-ai"; +import { ClientWorkflow, decodeWorkflowFromReader } from "@quri/squiggle-ai"; -import { bodyToLineReader, CreateRequestBody, requestToInput } from "./utils"; +import { AiRequestBody, bodyToLineReader } from "./utils"; export function useSquiggleWorkflows(initialWorkflows: ClientWorkflow[]) { const [workflows, setWorkflows] = @@ -25,7 +21,7 @@ export function useSquiggleWorkflows(initialWorkflows: ClientWorkflow[]) { ); const addMockWorkflow = useCallback( - (input: SquiggleWorkflowInput) => { + (request: AiRequestBody) => { // This will be replaced with a real workflow once we receive the first message from the server. const id = `loading-${Date.now().toString()}`; const workflow: ClientWorkflow = { @@ -36,7 +32,7 @@ export function useSquiggleWorkflows(initialWorkflows: ClientWorkflow[]) { prompt: { id: "prompt", kind: "prompt", - value: input.prompt ?? "[FIX]", + value: request.kind === "create" ? request.prompt : "[FIX]", }, }, steps: [], @@ -49,12 +45,10 @@ export function useSquiggleWorkflows(initialWorkflows: ClientWorkflow[]) { ); const submitWorkflow = useCallback( - async (request: CreateRequestBody) => { - const input = requestToInput(request); - + async (request: AiRequestBody) => { // Add a mock workflow to show loading state while we wait for the server to respond. // It will be replaced by the real workflow once we receive the first message from the server. - let id = addMockWorkflow(requestToInput(request)).id; + let id = addMockWorkflow(request).id; try { const response = await fetch("/ai/api/create", { diff --git a/packages/hub/src/app/ai/utils.ts b/packages/hub/src/app/ai/utils.ts index a7242e6855..35704d322d 100644 --- a/packages/hub/src/app/ai/utils.ts +++ b/packages/hub/src/app/ai/utils.ts @@ -1,10 +1,6 @@ import { z } from "zod"; -import { - LlmId, - MODEL_CONFIGS, - type SquiggleWorkflowInput, -} from "@quri/squiggle-ai"; +import { LlmId, MODEL_CONFIGS } from "@quri/squiggle-ai"; // SquiggleWorkflow input @@ -24,21 +20,24 @@ type UnionToTuple = type ModelKeys = UnionToTuple; -export const createRequestBodySchema = z.object({ - prompt: z.string().optional(), - squiggleCode: z.string().optional(), +const commonRequestFields = { model: z.enum(MODEL_CONFIGS.map((model) => model.id) as ModelKeys).optional(), -}); +}; -export type CreateRequestBody = z.infer; +export const aiRequestBodySchema = z.discriminatedUnion("kind", [ + z.object({ + kind: z.literal("create"), + ...commonRequestFields, + prompt: z.string(), + }), + z.object({ + kind: z.literal("edit"), + ...commonRequestFields, + squiggleCode: z.string(), + }), +]); -export function requestToInput( - request: CreateRequestBody -): SquiggleWorkflowInput { - return request.squiggleCode - ? { type: "Edit", source: request.squiggleCode } - : { type: "Create", prompt: request.prompt ?? "" }; -} +export type AiRequestBody = z.infer; // Convert a ReadableStream (`response.body` from `fetch()`) to a line-by-line reader. 
export function bodyToLineReader(stream: ReadableStream) { From e5a532b36d965cbb93f89a20962d8efc265a45b8 Mon Sep 17 00:00:00 2001 From: Vyacheslav Matyukhin Date: Thu, 10 Oct 2024 15:54:01 -0300 Subject: [PATCH 06/11] step -> workflow reference is back --- packages/ai/README.md | 2 - packages/ai/src/Artifact.ts | 2 +- packages/ai/src/LLMStepInstance.ts | 225 +++++++++--------- packages/ai/src/serialization.ts | 14 +- packages/ai/src/workflows/Workflow.ts | 64 +++-- ...trolledWorkflow.ts => WorkflowTemplate.ts} | 6 +- .../createSquiggleWorkflowTemplate.ts | 2 +- .../workflows/fixSquiggleWorkflowTemplate.ts | 2 +- packages/ai/src/workflows/registry.ts | 2 +- packages/ai/src/workflows/streaming.ts | 2 +- 10 files changed, 174 insertions(+), 147 deletions(-) rename packages/ai/src/workflows/{ControlledWorkflow.ts => WorkflowTemplate.ts} (89%) diff --git a/packages/ai/README.md b/packages/ai/README.md index f6ac686b23..eeb5eb8b2a 100644 --- a/packages/ai/README.md +++ b/packages/ai/README.md @@ -6,8 +6,6 @@ The example frontend that uses it is implemented in [Squiggle Hub](https://squig Note that it can take 20s-2min to run a workflow and get a response from the LLM. -The key file is 'src/workflows/SquiggleWorkflow.ts'. This file contains the definition of a common workflow. It can be used with or without streaming to run the workflow with the given parameters; see `ControlledWorkflow` class API for more details. - After runs are complete, the results are saved to the 'logs' folder. These are saved as Markdown files. It's recommended to use the VSCode Markdown Preview plugin or similar to view the results. In these logs, note the "expand" arrows on the left of some items - these can be clicked to expand the item and see more details. ## Use diff --git a/packages/ai/src/Artifact.ts b/packages/ai/src/Artifact.ts index 926408f7dd..cd261001e0 100644 --- a/packages/ai/src/Artifact.ts +++ b/packages/ai/src/Artifact.ts @@ -55,7 +55,7 @@ type ArtifactValue = Extract< export function makeArtifact( kind: T, value: ArtifactValue, - createdBy: LLMStepInstance + createdBy: LLMStepInstance ): Extract { // sorry for the type casting, TypeScript is not smart enough to infer the type switch (kind) { diff --git a/packages/ai/src/LLMStepInstance.ts b/packages/ai/src/LLMStepInstance.ts index 211b61992a..6617236bff 100644 --- a/packages/ai/src/LLMStepInstance.ts +++ b/packages/ai/src/LLMStepInstance.ts @@ -23,7 +23,7 @@ import { import { getStepTemplateByName } from "./steps/registry.js"; import { Workflow } from "./workflows/Workflow.js"; -interface Params { +export type StepParams = { id: string; sequentialId: number; template: LLMStepTemplate; @@ -34,25 +34,35 @@ interface Params { startTime: number; conversationMessages: Message[]; llmMetricsList: LlmMetrics[]; -} +}; -export class LLMStepInstance { - public id: Params["id"]; - public sequentialId: number; - public readonly template: Params["template"]; +export class LLMStepInstance< + const Shape extends IOShape = IOShape, + const WorkflowShape extends IOShape = IOShape, +> { + public id: StepParams["id"]; + public sequentialId: StepParams["sequentialId"]; + public readonly template: StepParams["template"]; - private state: Params["state"]; - private outputs: Params["outputs"]; - public readonly inputs: Params["inputs"]; + private state: StepParams["state"]; + private outputs: StepParams["outputs"]; + public readonly inputs: StepParams["inputs"]; - public retryingStep?: Params["retryingStep"]; + public retryingStep?: StepParams["retryingStep"]; - private 
startTime: Params["startTime"]; + private startTime: StepParams["startTime"]; + private conversationMessages: StepParams["conversationMessages"]; + public llmMetricsList: StepParams["llmMetricsList"]; + + // These two fields are not serialized private logger: Logger; - private conversationMessages: Params["conversationMessages"]; - public llmMetricsList: Params["llmMetricsList"]; + private workflow: Workflow; - private constructor(params: Params) { + private constructor( + params: StepParams & { + workflow: Workflow; + } + ) { this.id = params.id; this.sequentialId = params.sequentialId; @@ -67,17 +77,18 @@ export class LLMStepInstance { this.inputs = params.inputs; this.retryingStep = params.retryingStep; + this.workflow = params.workflow; this.logger = new Logger(); } // Create a new, PENDING step instance - static create(params: { + static create(params: { template: LLMStepInstance["template"]; inputs: LLMStepInstance["inputs"]; retryingStep: LLMStepInstance["retryingStep"]; - workflow: Workflow; - }): LLMStepInstance { - return new LLMStepInstance({ + workflow: Workflow; + }): LLMStepInstance { + return new LLMStepInstance({ id: crypto.randomUUID(), sequentialId: params.workflow.getStepCount(), conversationMessages: [], @@ -101,22 +112,22 @@ export class LLMStepInstance { return this.conversationMessages; } - async _run(workflow: Workflow) { + async _run() { if (this.state.kind !== "PENDING") { return; } - const limits = workflow.checkResourceLimits(); + const limits = this.workflow.checkResourceLimits(); if (limits) { - this.fail("CRITICAL", limits, workflow); + this.fail("CRITICAL", limits); return; } const executeContext: ExecuteContext = { - setOutput: (key, value) => this.setOutput(key, value, workflow), - log: (log) => this.log(log, workflow), - queryLLM: (promptPair) => this.queryLLM(promptPair, workflow), - fail: (errorType, message) => this.fail(errorType, message, workflow), + setOutput: (key, value) => this.setOutput(key, value), + log: (log) => this.log(log), + queryLLM: (promptPair) => this.queryLLM(promptPair), + fail: (errorType, message) => this.fail(errorType, message), }; try { @@ -124,8 +135,7 @@ export class LLMStepInstance { } catch (error) { this.fail( "MINOR", - error instanceof Error ? error.message : String(error), - workflow + error instanceof Error ? error.message : String(error) ); return; } @@ -137,28 +147,22 @@ export class LLMStepInstance { } } - async run(workflow: Workflow) { - this.log( - { - type: "info", - message: `Step "${this.template.name}" started`, - }, - workflow - ); + async run() { + this.log({ + type: "info", + message: `Step "${this.template.name}" started`, + }); - await this._run(workflow); + await this._run(); const completionMessage = `Step "${this.template.name}" completed with status: ${this.state.kind}${ this.state.kind !== "PENDING" && `, in ${this.state.durationMs / 1000}s` }`; - this.log( - { - type: "info", - message: completionMessage, - }, - workflow - ); + this.log({ + type: "info", + message: completionMessage, + }); } getState() { @@ -209,19 +213,14 @@ export class LLMStepInstance { // private methods - private setOutput< - K extends Extract, - WorkflowShape extends IOShape, - >( + private setOutput>( key: K, - value: Outputs[K] | Outputs[K]["value"], - workflow: Workflow + value: Outputs[K] | Outputs[K]["value"] ): void { if (key in this.outputs) { this.fail( "CRITICAL", - `Output ${key} is already set. This is a bug with the workflow code.`, - workflow + `Output ${key} is already set. 
This is a bug with the workflow code.` ); return; } @@ -237,22 +236,15 @@ export class LLMStepInstance { } } - private log( - log: LogEntry, - workflow: Workflow - ): void { + private log(log: LogEntry): void { this.logger.log(log, { - workflowId: workflow.id, + workflowId: this.workflow.id, stepIndex: this.sequentialId, }); } - private fail( - errorType: ErrorType, - message: string, - workflow: Workflow - ) { - this.log({ type: "error", message }, workflow); + private fail(errorType: ErrorType, message: string) { + this.log({ type: "error", message }); this.state = { kind: "FAILED", durationMs: this.calculateDuration(), @@ -269,48 +261,39 @@ export class LLMStepInstance { this.conversationMessages.push(message); } - private async queryLLM( - promptPair: PromptPair, - workflow: Workflow - ): Promise { + private async queryLLM(promptPair: PromptPair): Promise { try { const messagesToSend: Message[] = [ - ...workflow.getRelevantPreviousConversationMessages( - workflow.llmConfig.messagesInHistoryToKeep + ...this.workflow.getRelevantPreviousConversationMessages( + this.workflow.llmConfig.messagesInHistoryToKeep ), { role: "user", content: promptPair.fullPrompt, }, ]; - const completion = await workflow.llmClient.run(messagesToSend); - - this.log( - { - type: "llmResponse", - response: completion, - content: completion.content, - messages: messagesToSend, - prompt: promptPair.fullPrompt, - }, - workflow - ); + const completion = await this.workflow.llmClient.run(messagesToSend); + + this.log({ + type: "llmResponse", + response: completion, + content: completion.content, + messages: messagesToSend, + prompt: promptPair.fullPrompt, + }); this.llmMetricsList.push({ apiCalls: 1, inputTokens: completion?.usage?.prompt_tokens ?? 0, outputTokens: completion?.usage?.completion_tokens ?? 0, - llmId: workflow.llmConfig.llmId, + llmId: this.workflow.llmConfig.llmId, }); if (!completion?.content) { - this.log( - { - type: "error", - message: "Received an empty response from the API", - }, - workflow - ); + this.log({ + type: "error", + message: "Received an empty response from the API", + }); return null; } else { this.addConversationMessage({ @@ -328,8 +311,7 @@ export class LLMStepInstance { } catch (error) { this.fail( "MINOR", - `Error in queryLLM: ${error instanceof Error ? error.message : error}`, - workflow + `Error in queryLLM: ${error instanceof Error ? 
error.message : error}`
       );
       return null;
     }
   }
 
   // Serialization/deserialization
 
-  serialize(visitor: AiSerializationVisitor): SerializedStep {
+  // StepParams don't contain the workflow reference, to avoid circular dependencies
+  toParams(): StepParams {
     return {
       id: this.id,
       sequentialId: this.sequentialId,
-      templateName: this.template.name,
+      template: this.template,
       state: this.state,
+      inputs: this.inputs,
+      outputs: this.outputs,
+      retryingStep: this.retryingStep,
       startTime: this.startTime,
       conversationMessages: this.conversationMessages,
       llmMetricsList: this.llmMetricsList,
-      inputIds: Object.fromEntries(
-        Object.entries(this.inputs).map(([key, input]) => [
-          key,
-          visitor.artifact(input),
-        ])
-      ),
-      outputIds: Object.fromEntries(
-        Object.entries(this.outputs).map(([key, output]) => [
-          key,
-          visitor.artifact(output),
-        ])
-      ),
     };
   }
 
+  static fromParams(
+    params: StepParams,
+    workflow: Workflow
+  ): LLMStepInstance {
+    return new LLMStepInstance({ ...params, workflow });
+  }
+
   static deserialize(
     { templateName, inputIds, outputIds, ...params }: SerializedStep,
     visitor: AiDeserializationVisitor
-  ): LLMStepInstance {
+  ): StepParams {
     const template: LLMStepTemplate = getStepTemplateByName(templateName);
     const inputs = Object.fromEntries(
       Object.entries(inputIds).map(([name, inputId]) => [
         name,
         visitor.artifact(inputId),
       ])
     );
@@ -379,17 +360,45 @@ export class LLMStepInstance {
       ])
     );
 
-    return new LLMStepInstance({
+    return {
       ...params,
       template,
       inputs,
       outputs,
-    });
+    };
   }
 }
 
+export function serializeStepParams(
+  params: StepParams,
+  visitor: AiSerializationVisitor
+) {
+  return {
+    id: params.id,
+    sequentialId: params.sequentialId,
+    templateName: params.template.name,
+    state: params.state,
+    startTime: params.startTime,
+    conversationMessages: params.conversationMessages,
+    llmMetricsList: params.llmMetricsList,
+    inputIds: Object.fromEntries(
+      Object.entries(params.inputs).map(([key, input]) => [
+        key,
+        visitor.artifact(input),
+      ])
+    ),
+    outputIds: Object.fromEntries(
+      Object.entries(params.outputs)
+        .map(([key, output]) =>
+          output ?
[key, visitor.artifact(output)] : undefined + ) + .filter((x) => x !== undefined) + ), + }; +} + export type SerializedStep = Omit< - Params, + StepParams, // TODO - serialize retryingStep reference "inputs" | "outputs" | "template" | "retryingStep" > & { diff --git a/packages/ai/src/serialization.ts b/packages/ai/src/serialization.ts index 5bb6f68c77..7723eaf3c9 100644 --- a/packages/ai/src/serialization.ts +++ b/packages/ai/src/serialization.ts @@ -7,12 +7,19 @@ import { serializeArtifact, SerializedArtifact, } from "./Artifact.js"; -import { LLMStepInstance, SerializedStep } from "./LLMStepInstance.js"; +import { + LLMStepInstance, + SerializedStep, + serializeStepParams, + StepParams, +} from "./LLMStepInstance.js"; +import { getWorkflowTemplateByName } from "./workflows/registry.js"; import { SerializedWorkflow, Workflow } from "./workflows/Workflow.js"; type AiShape = { workflow: [Workflow, SerializedWorkflow]; - step: [LLMStepInstance, SerializedStep]; + // we serialize StepParams instead of LLMStepInstance to avoid circular dependencies; steps reference workflows + step: [StepParams, SerializedStep]; artifact: [Artifact, SerializedArtifact]; }; @@ -34,10 +41,11 @@ export function makeAiCodec(params: { visitor, openaiApiKey: params.openaiApiKey, anthropicApiKey: params.anthropicApiKey, + getWorkflowTemplateByName, }), }, step: { - serialize: (node, visitor) => node.serialize(visitor), + serialize: (node, visitor) => serializeStepParams(node, visitor), deserialize: (node, visitor) => LLMStepInstance.deserialize(node, visitor), }, diff --git a/packages/ai/src/workflows/Workflow.ts b/packages/ai/src/workflows/Workflow.ts index c87701b756..2caf3fc7cf 100644 --- a/packages/ai/src/workflows/Workflow.ts +++ b/packages/ai/src/workflows/Workflow.ts @@ -16,16 +16,15 @@ import { AiSerializationVisitor, } from "../serialization.js"; import { ClientWorkflow, ClientWorkflowResult } from "../types.js"; -import { - WorkflowInstanceParams, - WorkflowTemplate, -} from "./ControlledWorkflow.js"; -import { getWorkflowTemplateByName } from "./registry.js"; import { addStreamingListeners, artifactToClientArtifact, stepToClientStep, } from "./streaming.js"; +import { + type WorkflowInstanceParams, + type WorkflowTemplate, +} from "./WorkflowTemplate.js"; export type LlmConfig = { llmId: LlmId; @@ -49,19 +48,19 @@ export type WorkflowEventShape = | { type: "stepAdded"; payload: { - step: LLMStepInstance; + step: LLMStepInstance; }; } | { type: "stepStarted"; payload: { - step: LLMStepInstance; + step: LLMStepInstance; }; } | { type: "stepFinished"; payload: { - step: LLMStepInstance; + step: LLMStepInstance; }; } | { @@ -94,9 +93,6 @@ export type WorkflowEventListener< * * It does not make any assumptions about the steps themselves, it just * provides a way to add them and interact with them. - * - * See `ControlledWorkflow` for a common base class that controls the workflow - * by injecting new steps based on events. 
*/ const MAX_RETRIES = 5; @@ -109,7 +105,7 @@ export class Workflow { public llmConfig: LlmConfig; public startTime: number; - private steps: LLMStepInstance[]; + private steps: LLMStepInstance[]; public llmClient: LLMClient; @@ -181,9 +177,9 @@ export class Workflow { template: LLMStepTemplate, inputs: Inputs, options?: { retryingStep?: LLMStepInstance } - ): LLMStepInstance { + ): LLMStepInstance { // sorry for "any"; contravariance issues - const step: LLMStepInstance = LLMStepInstance.create({ + const step: LLMStepInstance = LLMStepInstance.create({ template, inputs, retryingStep: options?.retryingStep, @@ -193,7 +189,7 @@ export class Workflow { this.steps.push(step); this.dispatchEvent({ type: "stepAdded", - payload: { step }, + payload: { step: step as LLMStepInstance }, }); return step; } @@ -209,7 +205,9 @@ export class Workflow { return; } - this.addStep(retryingStep.template, retryingStep.inputs, { retryingStep }); + this.addStep(retryingStep.template, retryingStep.inputs, { + retryingStep: retryingStep as LLMStepInstance, + }); } public getCurrentRetryAttempts(stepId: string): number { @@ -226,13 +224,13 @@ export class Workflow { this.dispatchEvent({ // should we fire this after `run()` is called? type: "stepStarted", - payload: { step }, + payload: { step: step as LLMStepInstance }, }); - await step.run(this); + await step.run(); this.dispatchEvent({ type: "stepFinished", - payload: { step }, + payload: { step: step as LLMStepInstance }, }); } @@ -258,7 +256,7 @@ export class Workflow { return undefined; } - getSteps(): LLMStepInstance[] { + getSteps(): LLMStepInstance[] { return this.steps; } @@ -266,7 +264,7 @@ export class Workflow { return this.steps.length; } - private getCurrentStep(): LLMStepInstance | undefined { + private getCurrentStep(): LLMStepInstance | undefined { return this.steps.at(-1); } @@ -345,7 +343,9 @@ export class Workflow { return { totalPrice, llmRunCount }; } - private getMessagesFromSteps(steps: LLMStepInstance[]): Message[] { + private getMessagesFromSteps( + steps: LLMStepInstance[] + ): Message[] { return steps.flatMap((step) => step.getConversationMessages()); } @@ -415,7 +415,7 @@ export class Workflow { visitor.artifact(input), ]) ), - stepIds: this.steps.map(visitor.step), + stepIds: this.steps.map((step) => visitor.step(step.toParams())), llmConfig: this.llmConfig, }; } @@ -425,13 +425,16 @@ export class Workflow { visitor, openaiApiKey, anthropicApiKey, + getWorkflowTemplateByName, }: { node: SerializedWorkflow; visitor: AiDeserializationVisitor; openaiApiKey?: string; anthropicApiKey?: string; + // can't be imported from workflow registry because of circular dependency + getWorkflowTemplateByName: (name: string) => WorkflowTemplate; }): Workflow { - return new Workflow({ + const workflow = new Workflow({ id: node.id, template: getWorkflowTemplateByName(node.templateName), inputs: Object.fromEntries( @@ -441,10 +444,17 @@ export class Workflow { ]) ), llmConfig: node.llmConfig, - steps: node.stepIds.map(visitor.step), + steps: [], openaiApiKey, anthropicApiKey, }); + + // restore steps and create back references from steps to workflow + workflow.steps = node.stepIds + .map(visitor.step) + .map((params) => LLMStepInstance.fromParams(params, workflow)); + + return workflow; } // Client-side representation @@ -452,7 +462,9 @@ export class Workflow { return { id: this.id, timestamp: this.startTime, - steps: this.steps.map(stepToClientStep), + steps: this.steps.map((step) => + stepToClientStep(step as LLMStepInstance) + ), currentStep: 
this.getCurrentStep()?.id, ...(this.isProcessComplete() ? { diff --git a/packages/ai/src/workflows/ControlledWorkflow.ts b/packages/ai/src/workflows/WorkflowTemplate.ts similarity index 89% rename from packages/ai/src/workflows/ControlledWorkflow.ts rename to packages/ai/src/workflows/WorkflowTemplate.ts index d24a9fbc59..50c8be8ea4 100644 --- a/packages/ai/src/workflows/ControlledWorkflow.ts +++ b/packages/ai/src/workflows/WorkflowTemplate.ts @@ -1,12 +1,12 @@ import { LLMStepInstance } from "../LLMStepInstance.js"; import { Inputs, IOShape } from "../LLMStepTemplate.js"; -import { LlmConfig, Workflow } from "./Workflow.js"; +import { type LlmConfig, Workflow } from "./Workflow.js"; export type WorkflowInstanceParams = { id: string; template: WorkflowTemplate; inputs: Inputs; - steps: LLMStepInstance[]; + steps: LLMStepInstance[]; abortSignal?: AbortSignal; // TODO llmConfig?: LlmConfig; openaiApiKey?: string; @@ -18,7 +18,7 @@ export type WorkflowInstanceParams = { * * It works similarly to LLMStepTemplate, but for workflows. */ -export class WorkflowTemplate { +export class WorkflowTemplate { public readonly name: string; public configureControllerLoop: ( diff --git a/packages/ai/src/workflows/createSquiggleWorkflowTemplate.ts b/packages/ai/src/workflows/createSquiggleWorkflowTemplate.ts index ba0ec2dbb9..89cbbf6c81 100644 --- a/packages/ai/src/workflows/createSquiggleWorkflowTemplate.ts +++ b/packages/ai/src/workflows/createSquiggleWorkflowTemplate.ts @@ -1,6 +1,6 @@ import { generateCodeStep } from "../steps/generateCodeStep.js"; -import { WorkflowTemplate } from "./ControlledWorkflow.js"; import { fixAdjustRetryLoop } from "./controllers.js"; +import { WorkflowTemplate } from "./WorkflowTemplate.js"; /** * This is a basic workflow for generating Squiggle code. diff --git a/packages/ai/src/workflows/fixSquiggleWorkflowTemplate.ts b/packages/ai/src/workflows/fixSquiggleWorkflowTemplate.ts index c703e76738..1f4a4022b3 100644 --- a/packages/ai/src/workflows/fixSquiggleWorkflowTemplate.ts +++ b/packages/ai/src/workflows/fixSquiggleWorkflowTemplate.ts @@ -1,7 +1,7 @@ import { PromptArtifact } from "../Artifact.js"; import { runAndFormatCodeStep } from "../steps/runAndFormatCodeStep.js"; -import { WorkflowTemplate } from "./ControlledWorkflow.js"; import { fixAdjustRetryLoop } from "./controllers.js"; +import { WorkflowTemplate } from "./WorkflowTemplate.js"; export const fixSquiggleWorkflowTemplate = new WorkflowTemplate<{ inputs: { diff --git a/packages/ai/src/workflows/registry.ts b/packages/ai/src/workflows/registry.ts index ddc90888eb..b1a54d24e0 100644 --- a/packages/ai/src/workflows/registry.ts +++ b/packages/ai/src/workflows/registry.ts @@ -1,6 +1,6 @@ -import { WorkflowTemplate } from "./ControlledWorkflow.js"; import { createSquiggleWorkflowTemplate } from "./createSquiggleWorkflowTemplate.js"; import { fixSquiggleWorkflowTemplate } from "./fixSquiggleWorkflowTemplate.js"; +import { type WorkflowTemplate } from "./WorkflowTemplate.js"; export function getWorkflowTemplateByName(name: string) { const workflows: Record> = Object.fromEntries( diff --git a/packages/ai/src/workflows/streaming.ts b/packages/ai/src/workflows/streaming.ts index 9ff9e8232f..230780be0b 100644 --- a/packages/ai/src/workflows/streaming.ts +++ b/packages/ai/src/workflows/streaming.ts @@ -73,7 +73,7 @@ export function stepToClientStep(step: LLMStepInstance): ClientStep { * * Results are streamed as JSON-encoded lines. 
* - * `ControlledWorkflow.runAsStream()` relies on this function; see its + * `Workflow.runAsStream()` relies on this function; see its * implementation for more details. */ export function addStreamingListeners( From e9420a87610addfa1bb7d1fe9a539aae6a21f096 Mon Sep 17 00:00:00 2001 From: Vyacheslav Matyukhin Date: Fri, 11 Oct 2024 13:40:08 -0300 Subject: [PATCH 07/11] update comment --- packages/hub/prisma/schema.prisma | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/hub/prisma/schema.prisma b/packages/hub/prisma/schema.prisma index 4b1b2bf3b6..a45439d92d 100644 --- a/packages/hub/prisma/schema.prisma +++ b/packages/hub/prisma/schema.prisma @@ -394,7 +394,7 @@ model AiWorkflow { user User @relation(fields: [userId], references: [id], onDelete: Cascade) userId String - // v1: SerializedWorkflow + // v1: ClientWorkflow // v2: normalized bundle format Int @default(2) From d6dfe2b388ab4c0abf718b3df0ee849eb55a7a27 Mon Sep 17 00:00:00 2001 From: Vyacheslav Matyukhin Date: Sat, 12 Oct 2024 14:00:47 -0300 Subject: [PATCH 08/11] change workflow format column to enum; migration script --- .../migration.sql | 5 - .../migration.sql | 6 + packages/hub/prisma/schema.prisma | 13 +- packages/hub/src/app/ai/api/create/route.ts | 7 +- packages/hub/src/app/ai/page.tsx | 47 +------ packages/hub/src/app/ai/serverUtils.ts | 26 ---- .../20241012155427_workflow_format.ts | 24 ++++ packages/hub/src/migrations/README.md | 9 ++ packages/hub/src/prisma.ts | 2 + packages/hub/src/server/ai/storage.ts | 34 +++++ packages/hub/src/server/ai/utils.ts | 12 ++ packages/hub/src/server/ai/v1_0.ts | 122 ++++++++++++++++++ packages/hub/src/server/ai/v2_0.ts | 38 ++++++ 13 files changed, 267 insertions(+), 78 deletions(-) delete mode 100644 packages/hub/prisma/migrations/20240928161949_workflow_format/migration.sql create mode 100644 packages/hub/prisma/migrations/20241012155427_workflow_format/migration.sql delete mode 100644 packages/hub/src/app/ai/serverUtils.ts create mode 100644 packages/hub/src/migrations/20241012155427_workflow_format.ts create mode 100644 packages/hub/src/migrations/README.md create mode 100644 packages/hub/src/server/ai/storage.ts create mode 100644 packages/hub/src/server/ai/utils.ts create mode 100644 packages/hub/src/server/ai/v1_0.ts create mode 100644 packages/hub/src/server/ai/v2_0.ts diff --git a/packages/hub/prisma/migrations/20240928161949_workflow_format/migration.sql b/packages/hub/prisma/migrations/20240928161949_workflow_format/migration.sql deleted file mode 100644 index dfedaddc1e..0000000000 --- a/packages/hub/prisma/migrations/20240928161949_workflow_format/migration.sql +++ /dev/null @@ -1,5 +0,0 @@ --- AlterTable -ALTER TABLE "AiWorkflow" ADD COLUMN "format" INTEGER NOT NULL DEFAULT 2; - --- old workflows -UPDATE "AiWorkflow" SET "format" = 1; diff --git a/packages/hub/prisma/migrations/20241012155427_workflow_format/migration.sql b/packages/hub/prisma/migrations/20241012155427_workflow_format/migration.sql new file mode 100644 index 0000000000..3ac026d157 --- /dev/null +++ b/packages/hub/prisma/migrations/20241012155427_workflow_format/migration.sql @@ -0,0 +1,6 @@ +-- CreateEnum +CREATE TYPE "AiWorkflowFormat" AS ENUM ('V1_0', 'V2_0'); + +-- AlterTable +ALTER TABLE "AiWorkflow" ADD COLUMN "format" "AiWorkflowFormat" NOT NULL DEFAULT 'V1_0', +ADD COLUMN "markdown" TEXT NOT NULL DEFAULT ''; diff --git a/packages/hub/prisma/schema.prisma b/packages/hub/prisma/schema.prisma index a45439d92d..b45c0ee196 100644 --- a/packages/hub/prisma/schema.prisma +++ 
b/packages/hub/prisma/schema.prisma @@ -385,6 +385,11 @@ model Searchable { groupId String? @unique } +enum AiWorkflowFormat { + V1_0 // ClientWorkflow JSON, legacy format - can't be deserialized back to Workflow + V2_0 // SerializedWorkflow JSON, can be deserialized back to Workflow and resumed +} + model AiWorkflow { id String @id @default(cuid()) @@ -394,9 +399,11 @@ model AiWorkflow { user User @relation(fields: [userId], references: [id], onDelete: Cascade) userId String - // v1: ClientWorkflow - // v2: normalized bundle - format Int @default(2) + // TODO - upgrade to V2_0 + format AiWorkflowFormat @default(V1_0) workflow Json + + // TODO - remove default + markdown String @default("") } diff --git a/packages/hub/src/app/ai/api/create/route.ts b/packages/hub/src/app/ai/api/create/route.ts index 68e23b603f..d85cabf0ca 100644 --- a/packages/hub/src/app/ai/api/create/route.ts +++ b/packages/hub/src/app/ai/api/create/route.ts @@ -12,8 +12,9 @@ import { import { authOptions } from "@/app/api/auth/[...nextauth]/authOptions"; import { getSelf, isSignedIn } from "@/graphql/helpers/userHelpers"; import { prisma } from "@/prisma"; +import { getAiCodec } from "@/server/ai/utils"; +import { V2WorkflowData } from "@/server/ai/v2_0"; -import { getAiCodec, V2WorkflowData } from "../../serverUtils"; import { aiRequestBodySchema } from "../../utils"; // https://nextjs.org/docs/app/api-reference/file-conventions/route-segment-config#maxduration @@ -38,7 +39,7 @@ async function upsertWorkflow( id: workflow.id, }, update: { - format: 2, + format: "V2_0", workflow: v2Workflow, }, create: { @@ -46,7 +47,7 @@ async function upsertWorkflow( user: { connect: { id: user.id }, }, - format: 2, + format: "V2_0", workflow: v2Workflow, }, }); diff --git a/packages/hub/src/app/ai/page.tsx b/packages/hub/src/app/ai/page.tsx index c333c8d99b..d20d78847e 100644 --- a/packages/hub/src/app/ai/page.tsx +++ b/packages/hub/src/app/ai/page.tsx @@ -1,59 +1,24 @@ -import { ClientWorkflow, clientWorkflowSchema } from "@quri/squiggle-ai"; +import { ClientWorkflow } from "@quri/squiggle-ai"; import { prisma } from "@/prisma"; import { getUserOrRedirect } from "@/server/helpers"; +import { decodeDbWorkflowToClientWorkflow } from "../../server/ai/storage"; import { AiDashboard } from "./AiDashboard"; -import { getAiCodec, v2WorkflowDataSchema } from "./serverUtils"; export default async function AiPage() { const user = await getUserOrRedirect(); const rows = await prisma.aiWorkflow.findMany({ - orderBy: { - createdAt: "desc", - }, + orderBy: { createdAt: "desc" }, where: { user: { email: user.email }, }, }); - const workflows: ClientWorkflow[] = rows.map((row) => { - try { - switch (row.format) { - case 1: - return clientWorkflowSchema.parse(row.workflow); - case 2: { - /* - * Here we go from SerializedWorkflow to Workflow to ClientWorkflow. - * - * TODO: Instead, we could go directly from SerializedWorkflow to - * ClientWorkflow (useful especially if workflow implementation is - * deprecated, so we can't resume it but still want to show it). 
- */ - const { bundle, entrypoint } = v2WorkflowDataSchema.parse( - row.workflow - ); - const codec = getAiCodec(); - const deserializer = codec.makeDeserializer(bundle); - const workflow = deserializer.deserialize(entrypoint); - - return workflow.asClientWorkflow(); - } - default: - throw new Error(`Unknown workflow format: ${row.format}`); - } - } catch (e) { - return { - id: row.id, - timestamp: row.createdAt.getTime(), - status: "error", - inputs: {}, - steps: [], - result: `Invalid workflow format in the database: ${e}`, - } satisfies ClientWorkflow; - } - }); + const workflows: ClientWorkflow[] = rows.map((row) => + decodeDbWorkflowToClientWorkflow(row) + ); return ; } diff --git a/packages/hub/src/app/ai/serverUtils.ts b/packages/hub/src/app/ai/serverUtils.ts deleted file mode 100644 index a9deadb700..0000000000 --- a/packages/hub/src/app/ai/serverUtils.ts +++ /dev/null @@ -1,26 +0,0 @@ -import "server-only"; - -import { z } from "zod"; - -import { makeAiCodec } from "@quri/squiggle-ai/server"; - -export function getAiCodec() { - const openaiApiKey = process.env["OPENROUTER_API_KEY"]; - const anthropicApiKey = process.env["ANTHROPIC_API_KEY"]; - return makeAiCodec({ - openaiApiKey, - anthropicApiKey, - }); -} - -// schema for serialized workflow format in the database -// this type is not precise but it's better than nothing -export const v2WorkflowDataSchema = z.object({ - entrypoint: z.object({ - entityType: z.literal("workflow"), - pos: z.number(), - }), - bundle: z.any(), -}); - -export type V2WorkflowData = z.infer; diff --git a/packages/hub/src/migrations/20241012155427_workflow_format.ts b/packages/hub/src/migrations/20241012155427_workflow_format.ts new file mode 100644 index 0000000000..b3e2daf6c7 --- /dev/null +++ b/packages/hub/src/migrations/20241012155427_workflow_format.ts @@ -0,0 +1,24 @@ +import { prisma } from "@/prisma"; + +export async function migrate() { + const v1Workflows = await prisma.aiWorkflow.findMany({ + where: { + format: "V1_0", + workflow: { + path: ["status"], + equals: "finished", + }, + }, + }); + + for (const workflow of v1Workflows) { + const markdown = + (workflow.workflow as any)?.["result"]?.["logSummary"] ?? ""; + await prisma.aiWorkflow.update({ + where: { id: workflow.id }, + data: { markdown: String(markdown) }, + }); + } +} + +migrate(); diff --git a/packages/hub/src/migrations/README.md b/packages/hub/src/migrations/README.md new file mode 100644 index 0000000000..5c42377a92 --- /dev/null +++ b/packages/hub/src/migrations/README.md @@ -0,0 +1,9 @@ +Prisma doesn't support migrations in JavaScript, so we have to run custom scripts outside of the Prisma migration system, when updating the database with SQL queries gets too complicated. 
+
+# TODO
+
+We could implement a simple custom migration system, eventually; we'll need to:
+
+- store a flag that a JS migration was applied in a custom table
+- write a small wrapper that calls all JS migrations that exist and are not applied yet
+- integrate with GitHub Actions: call a wrapper script that invokes both `prisma migrate` and a JS migration
diff --git a/packages/hub/src/prisma.ts b/packages/hub/src/prisma.ts
index 95291bb2bc..43a64248ea 100644
--- a/packages/hub/src/prisma.ts
+++ b/packages/hub/src/prisma.ts
@@ -1,3 +1,5 @@
+import "server-only";
+
 import { PrismaClient } from "@prisma/client";
 
 // This config helps with connection leaks during hot reload
diff --git a/packages/hub/src/server/ai/storage.ts b/packages/hub/src/server/ai/storage.ts
new file mode 100644
index 0000000000..de7e6cb71f
--- /dev/null
+++ b/packages/hub/src/server/ai/storage.ts
@@ -0,0 +1,34 @@
+import "server-only";
+
+import { AiWorkflow as PrismaAiWorkflow } from "@prisma/client";
+
+import { ClientWorkflow } from "@quri/squiggle-ai";
+
+import { decodeV1_0JsonToClientWorkflow } from "@/server/ai/v1_0";
+import { decodeV2_0JsonToClientWorkflow } from "@/server/ai/v2_0";
+
+export function decodeDbWorkflowToClientWorkflow(
+  row: PrismaAiWorkflow
+): ClientWorkflow {
+  try {
+    switch (row.format) {
+      case "V1_0":
+        return decodeV1_0JsonToClientWorkflow(row.workflow);
+      case "V2_0":
+        return decodeV2_0JsonToClientWorkflow(row.workflow);
+      default:
+        throw new Error(
+          `Unknown workflow format: ${row.format satisfies never}`
+        );
+    }
+  } catch (e) {
+    return {
+      id: row.id,
+      timestamp: row.createdAt.getTime(),
+      status: "error",
+      inputs: {},
+      steps: [],
+      result: `Invalid workflow format in the database: ${e}`,
+    } satisfies ClientWorkflow;
+  }
+}
diff --git a/packages/hub/src/server/ai/utils.ts b/packages/hub/src/server/ai/utils.ts
new file mode 100644
index 0000000000..9d6583f6f6
--- /dev/null
+++ b/packages/hub/src/server/ai/utils.ts
@@ -0,0 +1,12 @@
+import "server-only";
+
+import { makeAiCodec } from "@quri/squiggle-ai/server";
+
+export function getAiCodec() {
+  const openaiApiKey = process.env["OPENROUTER_API_KEY"];
+  const anthropicApiKey = process.env["ANTHROPIC_API_KEY"];
+  return makeAiCodec({
+    openaiApiKey,
+    anthropicApiKey,
+  });
+}
diff --git a/packages/hub/src/server/ai/v1_0.ts b/packages/hub/src/server/ai/v1_0.ts
new file mode 100644
index 0000000000..159d2078f9
--- /dev/null
+++ b/packages/hub/src/server/ai/v1_0.ts
@@ -0,0 +1,122 @@
+import "server-only";
+
+import { Prisma } from "@prisma/client";
+import { z } from "zod";
+
+import { ClientWorkflow } from "@quri/squiggle-ai";
+
+// Snapshot of the ClientWorkflow schemas as they were at the time we upgraded to V2
+// These don't include inputs, which were added in V2.
+ +export const v1InputSchema = z.discriminatedUnion("type", [ + z.object({ + type: z.literal("Create"), + prompt: z.string(), + }), + z.object({ + type: z.literal("Edit"), + source: z.string(), + prompt: z.string().optional(), + }), +]); + +const v1CommonArtifactFields = { + id: z.string(), + createdBy: z.string().optional(), +}; + +const v1ArtifactSchema = z.discriminatedUnion("kind", [ + z.object({ + ...v1CommonArtifactFields, + kind: z.literal("prompt"), + value: z.string(), + }), + z.object({ + ...v1CommonArtifactFields, + kind: z.literal("source"), + value: z.string(), + }), + z.object({ + ...v1CommonArtifactFields, + kind: z.literal("code"), + value: z.string(), + ok: z.boolean(), + }), +]); + +const v1StepStateSchema = z.enum(["PENDING", "DONE", "FAILED"]); + +const v1MessageSchema = z.object({ + role: z.enum(["system", "user", "assistant"]), + content: z.string(), +}); + +const v1StepSchema = z.object({ + id: z.string(), + name: z.string(), + state: v1StepStateSchema, + inputs: z.record(z.string(), v1ArtifactSchema), + outputs: z.record(z.string(), v1ArtifactSchema), + messages: z.array(v1MessageSchema), +}); + +const commonV1WorkflowFields = { + id: z.string(), + timestamp: z.number(), // milliseconds since epoch + input: v1InputSchema, + steps: z.array(v1StepSchema), + currentStep: z.string().optional(), +}; + +const v1WorkflowResultSchema = z.object({ + code: z.string().describe("Squiggle code snippet"), + isValid: z.boolean(), + totalPrice: z.number(), + runTimeMs: z.number(), + llmRunCount: z.number(), + logSummary: z.string(), // markdown +}); + +export const v1WorkflowSchema = z.discriminatedUnion("status", [ + z.object({ + ...commonV1WorkflowFields, + status: z.literal("loading"), + result: z.undefined(), + }), + z.object({ + ...commonV1WorkflowFields, + status: z.literal("finished"), + result: v1WorkflowResultSchema, + }), + z.object({ + ...commonV1WorkflowFields, + status: z.literal("error"), + result: z.string(), + }), +]); + +export function decodeV1_0JsonToClientWorkflow( + json: Prisma.JsonValue +): ClientWorkflow { + const v1Workflow = v1WorkflowSchema.parse(json); + + return { + ...v1Workflow, + inputs: + v1Workflow.input.type === "Create" + ? { + prompt: { + value: v1Workflow.input.prompt, + kind: "prompt", + id: `${v1Workflow.id}-prompt`, + }, + } + : { + source: { + value: v1Workflow.input.source, + kind: "source", + id: `${v1Workflow.id}-source`, + }, + }, + }; +} diff --git a/packages/hub/src/server/ai/v2_0.ts b/packages/hub/src/server/ai/v2_0.ts new file mode 100644 index 0000000000..b21278adba --- /dev/null +++ b/packages/hub/src/server/ai/v2_0.ts @@ -0,0 +1,38 @@ +import "server-only"; + +import { Prisma } from "@prisma/client"; +import { z } from "zod"; + +import { ClientWorkflow } from "@quri/squiggle-ai"; + +import { getAiCodec } from "./utils"; + +// schema for serialized workflow format in the database +// this type is not precise but it's better than nothing +export const v2WorkflowDataSchema = z.object({ + entrypoint: z.object({ + entityType: z.literal("workflow"), + pos: z.number(), + }), + bundle: z.any(), +}); + +export type V2WorkflowData = z.infer; + +export function decodeV2_0JsonToClientWorkflow( + json: Prisma.JsonValue +): ClientWorkflow { + /* + * Here we go from SerializedWorkflow to Workflow to ClientWorkflow. + * + * TODO: Instead, we could go directly from SerializedWorkflow to + * ClientWorkflow (useful especially if workflow implementation is + * deprecated, so we can't resume it but still want to show it). 
+ */ + const { bundle, entrypoint } = v2WorkflowDataSchema.parse(json); + const codec = getAiCodec(); + const deserializer = codec.makeDeserializer(bundle); + const workflow = deserializer.deserialize(entrypoint); + + return workflow.asClientWorkflow(); +} From 2097f15ff2a6e10e449a2dcf6668d237fa6a68bb Mon Sep 17 00:00:00 2001 From: Vyacheslav Matyukhin Date: Sat, 12 Oct 2024 14:07:37 -0300 Subject: [PATCH 09/11] remove server-only import --- packages/hub/src/prisma.ts | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/packages/hub/src/prisma.ts b/packages/hub/src/prisma.ts index 43a64248ea..d2e0956117 100644 --- a/packages/hub/src/prisma.ts +++ b/packages/hub/src/prisma.ts @@ -1,5 +1,7 @@ -import "server-only"; - +/* + * TODO - it would be good to `import "server-only"` here, as a precaution, but + * this interferes with `tsx ./src/graphql/print-schema.ts`. + */ import { PrismaClient } from "@prisma/client"; // This config helps with connection leaks during hot reload From fcfa928aa5f6aeeb01de66fb0fe977ad12f49d4c Mon Sep 17 00:00:00 2001 From: Vyacheslav Matyukhin Date: Sat, 12 Oct 2024 14:16:37 -0300 Subject: [PATCH 10/11] update markdown in db after workflow is finished --- packages/hub/src/app/ai/api/create/route.ts | 23 ++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/packages/hub/src/app/ai/api/create/route.ts b/packages/hub/src/app/ai/api/create/route.ts index d85cabf0ca..858ca2f8ae 100644 --- a/packages/hub/src/app/ai/api/create/route.ts +++ b/packages/hub/src/app/ai/api/create/route.ts @@ -53,6 +53,14 @@ async function upsertWorkflow( }); } +async function updateWorkflowLog(workflow: Workflow) { + const result = workflow.getFinalResult(); + await prisma.aiWorkflow.update({ + where: { id: workflow.id }, + data: { markdown: result.logSummary }, + }); +} + export async function POST(req: Request) { const session = await getServerSession(authOptions); @@ -98,11 +106,24 @@ export async function POST(req: Request) { anthropicApiKey, }); - // save workflow to the database on each update + // Save workflow to the database on each update. squiggleWorkflow.addEventListener("stepFinished", ({ workflow }) => upsertWorkflow(user, workflow) ); + /* + * We save the markdown log after all steps are finished. This means that if + * the workflow fails or this route dies, there'd be no log summary. Should + * we save the log summary after each step? It'd be more expensive but more + * robust. 
+   * (this is important only if we decide to roll back our fully
+   * deserializable workflows; if deserialization works well, then this
+   * doesn't matter and the log is redundant)
+   */
+  squiggleWorkflow.addEventListener("allStepsFinished", ({ workflow }) =>
+    updateWorkflowLog(workflow)
+  );
+
   const stream = squiggleWorkflow.runAsStream();
 
   return new Response(stream as ReadableStream, {

From a1824b65d1fe137260b3d8e659644aa3b1db2cec Mon Sep 17 00:00:00 2001
From: Vyacheslav Matyukhin
Date: Sat, 12 Oct 2024 14:53:31 -0300
Subject: [PATCH 11/11] AI internals readme section

---
 packages/ai/README.md | 72 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 72 insertions(+)

diff --git a/packages/ai/README.md b/packages/ai/README.md
index eeb5eb8b2a..9bab127747 100644
--- a/packages/ai/README.md
+++ b/packages/ai/README.md
@@ -59,3 +59,75 @@ When using `createSquiggle` and `editSquiggle` scripts, you should define the fo
 ANTHROPIC_API_KEY= // needed for using Claude models
 OPENAI_API_KEY= // needed for using OpenAI models
 ```
+
+## Internals
+
+### Concepts
+
+#### WorkflowTemplate
+
+A description of a multi-step **workflow** that transforms its **inputs** into its **outputs** by going through several **steps**.
+
+Each workflow template has a name, and all templates should be registered in `src/workflows/registry.ts`, so that we can deserialize them by name.
+
+Workflows can have inputs, which are **artifacts**.
+
+#### Workflow
+
+An instance of a `WorkflowTemplate` that describes a single living workflow.
+
+Workflows incrementally run their steps and inject new steps into themselves based on the outcomes of previous steps.
+
+#### Controller loop
+
+Each `WorkflowTemplate` configures the workflow with a specific "controller loop": one or more event handlers that add new workflow steps based on events (usually `stepFinished` events) that have happened.
+
+Controller loops don't exist as objects; the term is just a handle for the part of the codebase that configures the loop.
+
+The configuration happens in the `configureControllerLoop` function in `WorkflowTemplate` definitions.
+
+#### Artifacts
+
+Artifacts are objects that are passed between steps. Both workflows and steps have artifacts as inputs or outputs.
+
+Each artifact has a type, which determines its "shape" (prompt, code, etc.).
+
+Artifacts have unique identifiers, so that we can detect when one step is using the output of another step without explicitly connecting them.
+
+#### LLMStepTemplate
+
+Step templates describe the behavior of a single step in a workflow.
+
+Similar to `WorkflowTemplate`, each step template has a name, and all step templates should be registered in `src/steps/registry.ts`.
+
+For now, all steps are "LLM" steps. This might change in the future.
+
+#### LLMStepInstance
+
+An instance of an `LLMStepTemplate`. Step instances have specific inputs and outputs, a state ("PENDING", "DONE", or "FAILED"), and a conversation history.
+
+### Serialization formats
+
+Workflows have two different serialization formats:
+
+#### SerializedWorkflow
+
+`SerializedWorkflow` is used for storing workflows in the database. A `SerializedWorkflow` can be fully reconstructed into a `Workflow` object.
+
+The `SerializedWorkflow` format is based on `@quri/serializer`, which normalizes the data. The format is not optimized to be human-readable: all object references are transformed into IDs.
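+For illustration, the stored V2 envelope looks roughly like the sketch below. The field contents are hypothetical; the authoritative shape is whatever `@quri/serializer` produces (Squiggle Hub validates it with `v2WorkflowDataSchema`).
+
+```ts
+// A minimal sketch of the stored V2 format, assuming the shape that
+// Squiggle Hub's `v2WorkflowDataSchema` validates.
+const serialized = {
+  // points at the workflow entity inside the bundle
+  entrypoint: { entityType: "workflow", pos: 0 },
+  // normalized entity tables; object references are replaced by positions
+  bundle: {
+    /* ... */
+  },
+};
+```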
+
+#### ClientWorkflow
+
+`ClientWorkflow` is used for representing workflows in the frontend.
+
+`Workflow` objects include server-only code, so we can't use them on the frontend directly; instead, we send `ClientWorkflow` objects to the frontend.
+
+The advantage of this format is that it's simpler, and it can be incrementally updated by streaming messages as the workflow runs.
+
+### Streaming
+
+You can convert a `Workflow` to a stream of JSON-encoded messages by using `workflow.runAsStream()`.
+
+Then you can decode the stream into a `ClientWorkflow` by using `decodeWorkflowFromReader`.
+
+See the Squiggle Hub code for details; a rough sketch follows below.
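+Below is a minimal sketch of that wiring. `runAsStream` and `decodeWorkflowFromReader` come from this package, but the route URL, the argument shape of `decodeWorkflowFromReader`, and the surrounding setup are assumptions for illustration:
+
+```ts
+// Server side (inside a route handler), assuming `workflow` is a Workflow:
+const stream = workflow.runAsStream();
+// ...return `new Response(stream as ReadableStream)` from the handler.
+
+// Client side: incrementally decode the stream into a ClientWorkflow.
+// The exact argument shape of `decodeWorkflowFromReader` is assumed here.
+const response = await fetch("/ai/api/create", {
+  method: "POST",
+  body: JSON.stringify({ prompt: "..." }),
+});
+await decodeWorkflowFromReader({
+  reader: response.body!.getReader(),
+  setWorkflow: (next) => {
+    // update UI state with the partial ClientWorkflow
+  },
+});
+```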