Skip to content

Commit

Permalink
fix exp bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangxu0307 committed Jan 23, 2024
1 parent 8a91511 commit 55af3a4
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 38 deletions.
11 changes: 8 additions & 3 deletions taskweaver/code_interpreter/code_generator/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,15 @@ def compose_prompt(
plugins: List[PluginEntry],
selected_experiences: Optional[List[Experience]] = None,
) -> List[ChatMessageType]:
experiences = self.experience_generator.format_experience_in_prompt(
self.prompt_data["experience_instruction"],
selected_experiences,
experiences = (
self.experience_generator.format_experience_in_prompt(
self.prompt_data["experience_instruction"],
selected_experiences,
)
if self.config.use_experience
else ""
)

chat_history = [format_chat_message(role="system", message=f"{self.instruction}\n{experiences}")]

if self.examples is None:
Expand Down
5 changes: 2 additions & 3 deletions taskweaver/memory/experience.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import json
import os
import warnings
from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Optional, Tuple

Expand Down Expand Up @@ -147,7 +146,7 @@ def refresh(
exp_ids = raw_exp_ids + handcrafted_exp_ids

if len(exp_ids) == 0:
warnings.warn(
self.logger.warning(
"No raw experience found. "
"Please type /save in the chat window to save raw experience"
"or write handcrafted experience.",
Expand Down Expand Up @@ -219,7 +218,7 @@ def load_experience(
]
exp_ids = [os.path.splitext(os.path.basename(exp_file))[0].split("_")[2] for exp_file in original_exp_files]
if len(exp_ids) == 0:
warnings.warn(
self.logger.warning(
f"No experience found for {target_role}."
f"Please type /save in the chat window to save raw experience or write handcrafted experience."
+ self.exception_message_for_refresh,
Expand Down
10 changes: 7 additions & 3 deletions taskweaver/planner/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,13 @@ def compose_prompt(
rounds: List[Round],
selected_experiences: Optional[List[Experience]] = None,
) -> List[ChatMessageType]:
experiences = self.experience_generator.format_experience_in_prompt(
self.prompt_data["experience_instruction"],
selected_experiences,
experiences = (
self.experience_generator.format_experience_in_prompt(
self.prompt_data["experience_instruction"],
selected_experiences,
)
if self.config.use_experience
else ""
)
chat_history = [format_chat_message(role="system", message=f"{self.instruction}\n{experiences}")]

Expand Down
27 changes: 0 additions & 27 deletions taskweaver/planner/planner_exp_prompt.yaml

This file was deleted.

4 changes: 2 additions & 2 deletions taskweaver/session/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def __init__(
self.workspace = workspace.get_session_dir(self.session_id)
self.execution_cwd = os.path.join(self.workspace, "cwd")

self.init()

self.round_index = 0
self.memory = Memory(session_id=self.session_id)

Expand Down Expand Up @@ -79,8 +81,6 @@ def __init__(
self.max_internal_chat_round_num = self.config.max_internal_chat_round_num
self.internal_chat_num = 0

self.init()

self.logger.dump_log_file(
self,
file_path=os.path.join(self.workspace, f"{self.session_id}.json"),
Expand Down

0 comments on commit 55af3a4

Please sign in to comment.