From d23d4f8f0717434bc55478253d5a8f0c71a34151 Mon Sep 17 00:00:00 2001 From: Tomas Dvorak Date: Fri, 30 Aug 2024 14:03:55 +0200 Subject: [PATCH] feat(python): add custom id support --- src/tools/python/output.ts | 2 +- src/tools/python/python.ts | 100 +++++++++++++++++++++++------------- src/tools/python/storage.ts | 52 ++++++++++++++----- 3 files changed, 104 insertions(+), 50 deletions(-) diff --git a/src/tools/python/output.ts b/src/tools/python/output.ts index d03e5db2..5a195126 100644 --- a/src/tools/python/output.ts +++ b/src/tools/python/output.ts @@ -37,7 +37,7 @@ export class PythonToolOutput extends ToolOutput { getTextContent() { const fileList = this.outputFiles - .map((file) => `- [${file.filename}](urn:${file.hash})`) + .map((file) => `- [${file.filename}](urn:${file.id})`) .join("\n"); return `The code exited with code ${this.exitCode}. stdout: diff --git a/src/tools/python/python.ts b/src/tools/python/python.ts index c42d8fdf..f2c8533b 100644 --- a/src/tools/python/python.ts +++ b/src/tools/python/python.ts @@ -14,7 +14,13 @@ * limitations under the License. */ -import { BaseToolOptions, BaseToolRunOptions, Tool, ToolInput } from "@/tools/base.js"; +import { + BaseToolOptions, + BaseToolRunOptions, + Tool, + ToolInput, + ToolInputValidationError, +} from "@/tools/base.js"; import { createGrpcTransport } from "@connectrpc/connect-node"; import { PromiseClient, createPromiseClient } from "@connectrpc/connect"; import { CodeInterpreterService } from "bee-proto/code_interpreter/v1/code_interpreter_service_connect"; @@ -22,11 +28,12 @@ import { z } from "zod"; import { BaseLLMOutput } from "@/llms/base.js"; import { LLM } from "@/llms/index.js"; import { PromptTemplate } from "@/template.js"; -import { mapToObj } from "remeda"; -import { PythonFile, PythonStorage } from "@/tools/python/storage.js"; +import { differenceWith, isShallowEqual, isTruthy, mapToObj, unique } from "remeda"; +import { PythonStorage } from "@/tools/python/storage.js"; import { PythonToolOutput } from "@/tools/python/output.js"; import { ValidationError } from "ajv"; import { ConnectionOptions } from "node:tls"; +import { AnySchemaLike } from "@/internals/helpers/schema.js"; export interface PythonToolOptions extends BaseToolOptions { codeInterpreter: { @@ -58,7 +65,7 @@ export class PythonTool extends Tool { inputFiles: z .object( mapToObj(files, (value) => [ - value.hash, + value.id, z.literal(value.filename).describe("filename of a file"), ]), ) @@ -69,13 +76,31 @@ export class PythonTool extends Tool { "To access an existing file, you must specify it; otherwise, the file will not be accessible. IMPORTANT: If the file is not provided in the input, it will not be accessible.", "The key is the final segment of a file URN, and the value is the filename. ", files.length > 0 - ? `Example: {"${files[0].hash}":"${files[0].filename}"} -- the files will be available to the Python code in the working directory.` + ? `Example: {"${files[0].id}":"${files[0].filename}"} -- the files will be available to the Python code in the working directory.` : `Example: {"e6979b7bec732b89a736fd19436ec295f6f64092c0c6c0c86a2a7f27c73519d6":"file.txt"} -- the files will be available to the Python code in the working directory.`, ].join(""), ), }); } + protected validateInput( + schema: AnySchemaLike, + rawInput: unknown, + ): asserts rawInput is ToolInput { + super.validateInput(schema, rawInput); + + const fileNames = Object.values(rawInput.inputFiles ?? {}).filter(Boolean) as string[]; + const diff = differenceWith(fileNames, unique(fileNames), isShallowEqual); + if (diff.length > 0) { + throw new ToolInputValidationError( + [ + `All 'inputFiles' must have a unique filenames.`, + `Duplicated filenames: ${diff.join(",")}`, + ].join("\n"), + ); + } + } + protected readonly client: PromiseClient; protected readonly preprocess; @@ -111,22 +136,16 @@ export class PythonTool extends Tool { } protected async _run(input: ToolInput, options?: BaseToolRunOptions) { - const inputFiles: PythonFile[] = Object.entries(input.inputFiles ?? {}) - .filter(([k, v]) => k && v) - .map(([hash, filename]) => ({ - hash, - filename: filename!, - })); - - await this.storage.upload(inputFiles); - - // replace relative paths in "files" with absolute paths by prepending "/workspace" - const filesInput = Object.fromEntries( - inputFiles - .filter((file) => file.filename) - .map((file) => [`/workspace/${file.filename}`, file.hash]), + const inputFiles = await this.storage.upload( + Object.entries(input.inputFiles ?? {}) + .filter(([k, v]) => Boolean(k && v)) + .map(([id, filename]) => ({ + id, + filename: filename!, + })), ); + // replace relative paths in "files" with absolute paths by prepending "/workspace" const getSourceCode = async () => { if (this.preprocess) { const { llm, promptTemplate } = this.preprocess; @@ -139,11 +158,15 @@ export class PythonTool extends Tool { return input.code; }; + const prefix = "/workspace/"; + const result = await this.client.execute( { sourceCode: await getSourceCode(), executorId: this.options.executorId ?? "default", - files: filesInput, + files: Object.fromEntries( + inputFiles.map((file) => [`${prefix}${file.filename}`, file.hash]), + ), }, { signal: options?.signal }, ); @@ -151,23 +174,26 @@ export class PythonTool extends Tool { // replace absolute paths in "files" with relative paths by removing "/workspace/" // skip files that are not in "/workspace" // skip entries that are also entries in filesInput - const prefix = "/workspace/"; - const filesOutput: PythonFile[] = Object.entries(result.files) - .map(([k, v]): PythonFile => ({ filename: k, hash: v })) - .filter( - (file) => - file.filename.startsWith(prefix) && - !( - file.filename.slice(prefix.length) in filesInput && - filesInput[file.filename.slice(prefix.length)] === file.hash - ), - ) - .map((file) => ({ - hash: file.hash, - filename: file.filename.slice(prefix.length), - })); - - await this.storage.download(filesOutput); + const filesOutput = await this.storage.download( + Object.entries(result.files) + .map(([k, v]) => { + const file = { path: k, hash: v }; + if (!file.path.startsWith(prefix)) { + return; + } + + const filename = file.path.slice(prefix.length); + if (inputFiles.some((input) => input.filename === filename && input.hash === file.hash)) { + return; + } + + return { + hash: file.hash, + filename, + }; + }) + .filter(isTruthy), + ); return new PythonToolOutput(result.stdout, result.stderr, result.exitCode, filesOutput); } diff --git a/src/tools/python/storage.ts b/src/tools/python/storage.ts index 031f683e..1898cd5a 100644 --- a/src/tools/python/storage.ts +++ b/src/tools/python/storage.ts @@ -25,10 +25,22 @@ import { Serializable } from "@/internals/serializable.js"; import { shallowCopy } from "@/serializer/utils.js"; export interface PythonFile { + id: string; hash: string; filename: string; } +export interface PythonUploadFile { + id: string; + filename: string; +} + +export interface PythonDownloadFile { + id?: string; + filename: string; + hash: string; +} + export abstract class PythonStorage extends Serializable { /** * List all files that code interpreter can use. @@ -38,12 +50,12 @@ export abstract class PythonStorage extends Serializable { /** * Prepare subset of available files to code interpreter. */ - abstract upload(files: PythonFile[]): Promise; + abstract upload(files: PythonUploadFile[]): Promise; /** * Process updated/modified/deleted files from code interpreter response. */ - abstract download(files: PythonFile[]): Promise; + abstract download(files: PythonDownloadFile[]): Promise; } export class TemporaryStorage extends PythonStorage { @@ -53,13 +65,20 @@ export class TemporaryStorage extends PythonStorage { return this.files.slice(); } - async upload() {} + async upload(files: PythonUploadFile[]): Promise { + return files.map((file) => ({ + id: file.id, + hash: file.id, + filename: file.filename, + })); + } - async download(files: PythonFile[]) { + async download(files: PythonDownloadFile[]) { this.files = [ ...this.files.filter((file) => files.every((f) => f.filename !== file.filename)), - ...files, + ...files.map((file) => ({ id: file.hash, ...file })), ]; + return this.files.slice(); } createSnapshot() { @@ -104,27 +123,35 @@ export class LocalPythonStorage extends PythonStorage { return Promise.all( files .filter((file) => file.isFile() && !this.input.ignoredFiles.has(file.name)) - .map(async (file) => ({ - filename: file.name, - hash: await this.computeHash(path.join(this.input.localWorkingDir.toString(), file.name)), - })), + .map(async (file) => { + const hash = await this.computeHash( + path.join(this.input.localWorkingDir.toString(), file.name), + ); + + return { + id: hash, + filename: file.name, + hash, + }; + }), ); } - async upload(files: PythonFile[]): Promise { + async upload(files: PythonUploadFile[]): Promise { await this.init(); await Promise.all( files.map((file) => copyFile( path.join(this.input.localWorkingDir.toString(), file.filename), - path.join(this.input.interpreterWorkingDir.toString(), file.hash), + path.join(this.input.interpreterWorkingDir.toString(), file.id), ), ), ); + return files.map((file) => ({ ...file, hash: file.id })); } - async download(files: PythonFile[]) { + async download(files: PythonDownloadFile[]) { await this.init(); await Promise.all( @@ -135,6 +162,7 @@ export class LocalPythonStorage extends PythonStorage { ), ), ); + return files.map((file) => ({ ...file, id: file.hash })); } protected async computeHash(file: PathLike) {