Skip to content

Commit

Permalink
feat(python): add custom id support
Browse files Browse the repository at this point in the history
  • Loading branch information
Tomas2D committed Aug 30, 2024
1 parent 7f2f466 commit d23d4f8
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 50 deletions.
2 changes: 1 addition & 1 deletion src/tools/python/output.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ export class PythonToolOutput extends ToolOutput {

getTextContent() {
const fileList = this.outputFiles
.map((file) => `- [${file.filename}](urn:${file.hash})`)
.map((file) => `- [${file.filename}](urn:${file.id})`)
.join("\n");
return `The code exited with code ${this.exitCode}.
stdout:
Expand Down
100 changes: 63 additions & 37 deletions src/tools/python/python.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,26 @@
* limitations under the License.
*/

import { BaseToolOptions, BaseToolRunOptions, Tool, ToolInput } from "@/tools/base.js";
import {
BaseToolOptions,
BaseToolRunOptions,
Tool,
ToolInput,
ToolInputValidationError,
} from "@/tools/base.js";
import { createGrpcTransport } from "@connectrpc/connect-node";
import { PromiseClient, createPromiseClient } from "@connectrpc/connect";
import { CodeInterpreterService } from "bee-proto/code_interpreter/v1/code_interpreter_service_connect";
import { z } from "zod";
import { BaseLLMOutput } from "@/llms/base.js";
import { LLM } from "@/llms/index.js";
import { PromptTemplate } from "@/template.js";
import { mapToObj } from "remeda";
import { PythonFile, PythonStorage } from "@/tools/python/storage.js";
import { differenceWith, isShallowEqual, isTruthy, mapToObj, unique } from "remeda";
import { PythonStorage } from "@/tools/python/storage.js";
import { PythonToolOutput } from "@/tools/python/output.js";
import { ValidationError } from "ajv";
import { ConnectionOptions } from "node:tls";
import { AnySchemaLike } from "@/internals/helpers/schema.js";

export interface PythonToolOptions extends BaseToolOptions {
codeInterpreter: {
Expand Down Expand Up @@ -58,7 +65,7 @@ export class PythonTool extends Tool<PythonToolOutput, PythonToolOptions> {
inputFiles: z
.object(
mapToObj(files, (value) => [
value.hash,
value.id,
z.literal(value.filename).describe("filename of a file"),
]),
)
Expand All @@ -69,13 +76,31 @@ export class PythonTool extends Tool<PythonToolOutput, PythonToolOptions> {
"To access an existing file, you must specify it; otherwise, the file will not be accessible. IMPORTANT: If the file is not provided in the input, it will not be accessible.",
"The key is the final segment of a file URN, and the value is the filename. ",
files.length > 0
? `Example: {"${files[0].hash}":"${files[0].filename}"} -- the files will be available to the Python code in the working directory.`
? `Example: {"${files[0].id}":"${files[0].filename}"} -- the files will be available to the Python code in the working directory.`
: `Example: {"e6979b7bec732b89a736fd19436ec295f6f64092c0c6c0c86a2a7f27c73519d6":"file.txt"} -- the files will be available to the Python code in the working directory.`,
].join(""),
),
});
}

protected validateInput(
schema: AnySchemaLike,
rawInput: unknown,
): asserts rawInput is ToolInput<this> {
super.validateInput(schema, rawInput);

const fileNames = Object.values(rawInput.inputFiles ?? {}).filter(Boolean) as string[];
const diff = differenceWith(fileNames, unique(fileNames), isShallowEqual);
if (diff.length > 0) {
throw new ToolInputValidationError(
[
`All 'inputFiles' must have a unique filenames.`,
`Duplicated filenames: ${diff.join(",")}`,
].join("\n"),
);
}
}

protected readonly client: PromiseClient<typeof CodeInterpreterService>;
protected readonly preprocess;

Expand Down Expand Up @@ -111,22 +136,16 @@ export class PythonTool extends Tool<PythonToolOutput, PythonToolOptions> {
}

protected async _run(input: ToolInput<this>, options?: BaseToolRunOptions) {
const inputFiles: PythonFile[] = Object.entries(input.inputFiles ?? {})
.filter(([k, v]) => k && v)
.map(([hash, filename]) => ({
hash,
filename: filename!,
}));

await this.storage.upload(inputFiles);

// replace relative paths in "files" with absolute paths by prepending "/workspace"
const filesInput = Object.fromEntries(
inputFiles
.filter((file) => file.filename)
.map((file) => [`/workspace/${file.filename}`, file.hash]),
const inputFiles = await this.storage.upload(
Object.entries(input.inputFiles ?? {})
.filter(([k, v]) => Boolean(k && v))
.map(([id, filename]) => ({
id,
filename: filename!,
})),
);

// replace relative paths in "files" with absolute paths by prepending "/workspace"
const getSourceCode = async () => {
if (this.preprocess) {
const { llm, promptTemplate } = this.preprocess;
Expand All @@ -139,35 +158,42 @@ export class PythonTool extends Tool<PythonToolOutput, PythonToolOptions> {
return input.code;
};

const prefix = "/workspace/";

const result = await this.client.execute(
{
sourceCode: await getSourceCode(),
executorId: this.options.executorId ?? "default",
files: filesInput,
files: Object.fromEntries(
inputFiles.map((file) => [`${prefix}${file.filename}`, file.hash]),
),
},
{ signal: options?.signal },
);

// replace absolute paths in "files" with relative paths by removing "/workspace/"
// skip files that are not in "/workspace"
// skip entries that are also entries in filesInput
const prefix = "/workspace/";
const filesOutput: PythonFile[] = Object.entries(result.files)
.map(([k, v]): PythonFile => ({ filename: k, hash: v }))
.filter(
(file) =>
file.filename.startsWith(prefix) &&
!(
file.filename.slice(prefix.length) in filesInput &&
filesInput[file.filename.slice(prefix.length)] === file.hash
),
)
.map((file) => ({
hash: file.hash,
filename: file.filename.slice(prefix.length),
}));

await this.storage.download(filesOutput);
const filesOutput = await this.storage.download(
Object.entries(result.files)
.map(([k, v]) => {
const file = { path: k, hash: v };
if (!file.path.startsWith(prefix)) {
return;
}

const filename = file.path.slice(prefix.length);
if (inputFiles.some((input) => input.filename === filename && input.hash === file.hash)) {
return;
}

return {
hash: file.hash,
filename,
};
})
.filter(isTruthy),
);
return new PythonToolOutput(result.stdout, result.stderr, result.exitCode, filesOutput);
}

Expand Down
52 changes: 40 additions & 12 deletions src/tools/python/storage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,22 @@ import { Serializable } from "@/internals/serializable.js";
import { shallowCopy } from "@/serializer/utils.js";

export interface PythonFile {
id: string;
hash: string;
filename: string;
}

export interface PythonUploadFile {
id: string;
filename: string;
}

export interface PythonDownloadFile {
id?: string;
filename: string;
hash: string;
}

export abstract class PythonStorage extends Serializable {
/**
* List all files that code interpreter can use.
Expand All @@ -38,12 +50,12 @@ export abstract class PythonStorage extends Serializable {
/**
* Prepare subset of available files to code interpreter.
*/
abstract upload(files: PythonFile[]): Promise<void>;
abstract upload(files: PythonUploadFile[]): Promise<PythonFile[]>;

/**
* Process updated/modified/deleted files from code interpreter response.
*/
abstract download(files: PythonFile[]): Promise<void>;
abstract download(files: PythonDownloadFile[]): Promise<PythonFile[]>;
}

export class TemporaryStorage extends PythonStorage {
Expand All @@ -53,13 +65,20 @@ export class TemporaryStorage extends PythonStorage {
return this.files.slice();
}

async upload() {}
async upload(files: PythonUploadFile[]): Promise<PythonFile[]> {
return files.map((file) => ({
id: file.id,
hash: file.id,
filename: file.filename,
}));
}

async download(files: PythonFile[]) {
async download(files: PythonDownloadFile[]) {
this.files = [
...this.files.filter((file) => files.every((f) => f.filename !== file.filename)),
...files,
...files.map((file) => ({ id: file.hash, ...file })),
];
return this.files.slice();
}

createSnapshot() {
Expand Down Expand Up @@ -104,27 +123,35 @@ export class LocalPythonStorage extends PythonStorage {
return Promise.all(
files
.filter((file) => file.isFile() && !this.input.ignoredFiles.has(file.name))
.map(async (file) => ({
filename: file.name,
hash: await this.computeHash(path.join(this.input.localWorkingDir.toString(), file.name)),
})),
.map(async (file) => {
const hash = await this.computeHash(
path.join(this.input.localWorkingDir.toString(), file.name),
);

return {
id: hash,
filename: file.name,
hash,
};
}),
);
}

async upload(files: PythonFile[]): Promise<void> {
async upload(files: PythonUploadFile[]): Promise<PythonFile[]> {
await this.init();

await Promise.all(
files.map((file) =>
copyFile(
path.join(this.input.localWorkingDir.toString(), file.filename),
path.join(this.input.interpreterWorkingDir.toString(), file.hash),
path.join(this.input.interpreterWorkingDir.toString(), file.id),
),
),
);
return files.map((file) => ({ ...file, hash: file.id }));
}

async download(files: PythonFile[]) {
async download(files: PythonDownloadFile[]) {
await this.init();

await Promise.all(
Expand All @@ -135,6 +162,7 @@ export class LocalPythonStorage extends PythonStorage {
),
),
);
return files.map((file) => ({ ...file, id: file.hash }));
}

protected async computeHash(file: PathLike) {
Expand Down

0 comments on commit d23d4f8

Please sign in to comment.