Improve GradioUI file upload system
aymeric-roucher committed Jan 13, 2025
1 parent 1f96560 commit 1d84607
Showing 6 changed files with 66 additions and 44 deletions.
2 changes: 1 addition & 1 deletion examples/gradio_upload.py
@@ -5,7 +5,7 @@
 )
 
 agent = CodeAgent(
-    tools=[], model=HfApiModel(), max_steps=4, verbose=True
+    tools=[], model=HfApiModel(), max_steps=4, verbosity_level=0
 )
 
 GradioUI(agent, file_upload_folder='./data').launch()
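
For reference, the whole example script after this change stays minimal (a sketch; the import line is assumed, since the diff only shows the body):

from smolagents import CodeAgent, HfApiModel, GradioUI

agent = CodeAgent(
    tools=[], model=HfApiModel(), max_steps=4, verbosity_level=0
)

# Passing file_upload_folder switches on the upload widget in the UI
GradioUI(agent, file_upload_folder="./data").launch()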
27 changes: 19 additions & 8 deletions src/smolagents/agents.py
@@ -396,7 +396,7 @@ def provide_final_answer(self, task) -> str:
             }
         ]
         try:
-            return self.model(self.input_messages)
+            return self.model(self.input_messages).content
         except Exception as e:
            return f"Error in generating final LLM output:\n{e}"
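
Model calls now return a chat-message object rather than a raw string, so every call site reads `.content` explicitly. A minimal illustration of the new contract, with a stand-in message type (hypothetical, not the library's class):

from dataclasses import dataclass

@dataclass
class StubChatMessage:
    content: str

def stub_model(messages):
    # Plays the role of self.model(...): returns a message object, not text
    return StubChatMessage(content="The final answer is 42.")

answer = stub_model([{"role": "user", "content": "..."}]).content
assert isinstance(answer, str)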

@@ -666,7 +666,9 @@ def planning_step(self, task, is_first_step: bool, step: int):
 Now begin!""",
             }
 
-            answer_facts = self.model([message_prompt_facts, message_prompt_task])
+            answer_facts = self.model(
+                [message_prompt_facts, message_prompt_task]
+            ).content
 
             message_system_prompt_plan = {
                 "role": MessageRole.SYSTEM,
@@ -688,7 +690,7 @@ def planning_step(self, task, is_first_step: bool, step: int):
             answer_plan = self.model(
                 [message_system_prompt_plan, message_user_prompt_plan],
                 stop_sequences=["<end_plan>"],
-            )
+            ).content
 
             final_plan_redaction = f"""Here is the plan of action that I will follow to solve the task:
             ```
@@ -722,7 +724,7 @@ def planning_step(self, task, is_first_step: bool, step: int):
             }
             facts_update = self.model(
                 [facts_update_system_prompt] + agent_memory + [facts_update_message]
-            )
+            ).content
 
             # Redact updated plan
             plan_update_message = {
@@ -807,17 +809,26 @@ def step(self, log_entry: ActionStep) -> Union[None, Any]:
                 tools_to_call_from=list(self.tools.values()),
                 stop_sequences=["Observation:"],
             )
 
             # Extract tool call from model output
-            if type(model_message.tool_calls) is list and len(model_message.tool_calls) > 0:
+            if (
+                type(model_message.tool_calls) is list
+                and len(model_message.tool_calls) > 0
+            ):
                 tool_calls = model_message.tool_calls[0]
                 tool_arguments = tool_calls.function.arguments
                 tool_name, tool_call_id = tool_calls.function.name, tool_calls.id
             else:
-                start, end = model_message.content.find('{'), model_message.content.rfind('}') + 1
+                start, end = (
+                    model_message.content.find("{"),
+                    model_message.content.rfind("}") + 1,
+                )
                 tool_calls = json.loads(model_message.content[start:end])
                 tool_arguments = tool_calls["tool_arguments"]
-                tool_name, tool_call_id = tool_calls["tool_name"], f"call_{len(self.logs)}"
+                tool_name, tool_call_id = (
+                    tool_calls["tool_name"],
+                    f"call_{len(self.logs)}",
+                )
 
         except Exception as e:
             raise AgentGenerationError(
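When the model emits no structured tool call, the fallback above recovers one from raw text by slicing from the first "{" to the last "}" and parsing the span as JSON. The same recovery path in isolation (the payload keys follow the diff; the sample output string is illustrative):

import json

raw = 'I will search.\n{"tool_name": "web_search", "tool_arguments": {"query": "smolagents"}}'

# Slice out the outermost JSON object embedded in the free-form output
start, end = raw.find("{"), raw.rfind("}") + 1
tool_call = json.loads(raw[start:end])

tool_name = tool_call["tool_name"]            # "web_search"
tool_arguments = tool_call["tool_arguments"]  # {"query": "smolagents"}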
49 changes: 34 additions & 15 deletions src/smolagents/gradio_ui.py
@@ -27,14 +27,15 @@ def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True):
     """Extract ChatMessage objects from agent steps"""
     if isinstance(step_log, ActionStep):
         yield gr.ChatMessage(role="assistant", content=step_log.llm_output or "")
-        if step_log.tool_call is not None:
-            used_code = step_log.tool_call.name == "code interpreter"
-            content = step_log.tool_call.arguments
+        if step_log.tool_calls is not None:
+            first_tool_call = step_log.tool_calls[0]
+            used_code = first_tool_call.name == "code interpreter"
+            content = first_tool_call.arguments
             if used_code:
                 content = f"```py\n{content}\n```"
             yield gr.ChatMessage(
                 role="assistant",
-                metadata={"title": f"🛠️ Used tool {step_log.tool_call.name}"},
+                metadata={"title": f"🛠️ Used tool {first_tool_call.name}"},
                 content=str(content),
             )
         if step_log.observations is not None:
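
Steps can now carry several tool calls (`tool_calls` is a list), and the UI renders only the first one. The consuming pattern, reduced to a standalone sketch with stand-in objects:

from types import SimpleNamespace

# Stand-in for an ActionStep produced by the agent
step_log = SimpleNamespace(
    tool_calls=[SimpleNamespace(name="code interpreter", arguments="print(1 + 1)")]
)

if step_log.tool_calls is not None:
    first_tool_call = step_log.tool_calls[0]
    content = first_tool_call.arguments
    if first_tool_call.name == "code interpreter":
        content = f"```py\n{content}\n```"  # wrap code arguments in a fenced block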
@@ -103,21 +104,20 @@ def interact_with_agent(self, prompt, messages):
     def upload_file(
         self,
         file,
+        file_uploads_log,
         allowed_file_types=[
             "application/pdf",
             "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
             "text/plain",
         ],
     ):
         """
-        Handle file uploads, default allowed types are pdf, docx, and .txt
+        Handle file uploads, default allowed types are .pdf, .docx, and .txt
         """
 
-        # Check if file is uploaded
         if file is None:
             return "No file uploaded"
 
-        # Check if file is in allowed filetypes
         try:
             mime_type, _ = mimetypes.guess_type(file.name)
         except Exception as e:
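
The type gate relies on `mimetypes.guess_type`, which maps a file extension to a MIME string that is then checked against `allowed_file_types`. In isolation (stdlib only):

import mimetypes

allowed_file_types = [
    "application/pdf",
    "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
    "text/plain",
]

mime_type, _ = mimetypes.guess_type("notes.txt")  # -> ("text/plain", None)
print(mime_type in allowed_file_types)            # True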
@@ -148,11 +148,23 @@ def upload_file(
             )
         shutil.copy(file.name, file_path)
 
-        return f"File uploaded successfully to {self.file_upload_folder}"
+        return gr.Textbox(
+            f"File uploaded: {file_path}", visible=True
+        ), file_uploads_log + [file_path]
+
+    def log_user_message(self, text_input, file_uploads_log):
+        return (
+            text_input
+            + f"\nYou have been provided with these files, which might be helpful or not: {file_uploads_log}"
+            if len(file_uploads_log) > 0
+            else "",
+            "",
+        )
 
     def launch(self):
         with gr.Blocks() as demo:
-            stored_message = gr.State([])
+            stored_messages = gr.State([])
+            file_uploads_log = gr.State([])
             chatbot = gr.Chatbot(
                 label="Agent",
                 type="messages",
@@ -163,14 +175,21 @@ def launch(self):
             )
             # If an upload folder is provided, enable the upload feature
             if self.file_upload_folder is not None:
-                upload_file = gr.File(label="Upload a file")
-                upload_status = gr.Textbox(label="Upload Status", interactive=False)
-
-                upload_file.change(self.upload_file, [upload_file], [upload_status])
+                upload_file = gr.File(label="Upload a file", height=1)
+                upload_status = gr.Textbox(
+                    label="Upload Status", interactive=False, visible=False
+                )
+                upload_file.change(
+                    self.upload_file,
+                    [upload_file, file_uploads_log],
+                    [upload_status, file_uploads_log],
+                )
             text_input = gr.Textbox(lines=1, label="Chat Message")
             text_input.submit(
-                lambda s: (s, ""), [text_input], [stored_message, text_input]
-            ).then(self.interact_with_agent, [stored_message, chatbot], [chatbot])
+                self.log_user_message,
+                [text_input, file_uploads_log],
+                [stored_messages, text_input],
+            ).then(self.interact_with_agent, [stored_messages, chatbot], [chatbot])
 
             demo.launch()

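The wiring pattern here — `gr.State` objects threaded through chained event handlers — can be reduced to a self-contained sketch (handler body and labels are illustrative, not the code above):

import gradio as gr

with gr.Blocks() as demo:
    file_uploads_log = gr.State([])  # accumulates uploaded file paths

    def record_upload(file, log):
        # Return a re-rendered status box plus the extended log
        return gr.Textbox(f"File uploaded: {file.name}", visible=True), log + [file.name]

    upload_file = gr.File(label="Upload a file")
    upload_status = gr.Textbox(label="Upload Status", interactive=False, visible=False)
    upload_file.change(
        record_upload, [upload_file, file_uploads_log], [upload_status, file_uploads_log]
    )

demo.launch()
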
28 changes: 9 additions & 19 deletions src/smolagents/models.py
@@ -36,6 +36,8 @@
     StoppingCriteriaList,
     is_torch_available,
 )
+from transformers.utils.import_utils import _is_package_available
+
 import openai
 
 from .tools import Tool
@@ -52,13 +54,9 @@
     "value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```<end_code>",
 }
 
-try:
+if _is_package_available("litellm"):
     import litellm
 
-    is_litellm_available = True
-except ImportError:
-    is_litellm_available = False
-
 
 class MessageRole(str, Enum):
     USER = "user"
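
`_is_package_available` is a private transformers helper (its import path could change between releases), but it replaces the try/except probe with a single guarded import. The pattern in isolation:

from transformers.utils.import_utils import _is_package_available

if _is_package_available("litellm"):
    import litellm  # bound only when the optional dependency is installed

# Later consumers re-run the same check instead of reading a module-level flag:
if not _is_package_available("litellm"):
    raise ImportError("litellm not found. Install it with `pip install litellm`")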
@@ -159,7 +157,7 @@ def __call__(
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
         max_tokens: int = 1500,
-    ) -> str:
+    ) -> ChatCompletionOutputMessage:
         """Process the input messages and return the model's response.
         Parameters:
@@ -174,15 +172,7 @@
         Returns:
             `str`: The text content of the model's response.
         """
-        if not isinstance(messages, List):
-            raise ValueError(
-                "Messages should be a list of dictionaries with 'role' and 'content' keys."
-            )
-        if stop_sequences is None:
-            stop_sequences = []
-        response = self.generate(messages, stop_sequences, grammar, max_tokens)
-
-        return remove_stop_sequences(response, stop_sequences)
+        pass  # To be implemented in child classes!
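
With the base `__call__` reduced to a stub, each backend is now responsible for returning the full message object itself. A minimal subclass under that contract (hypothetical `EchoModel`; the dataclass stands in for `ChatCompletionOutputMessage`, and `Model` is assumed to be the base class from this file, instantiable without arguments):

from dataclasses import dataclass

@dataclass
class StubMessage:
    role: str
    content: str

class EchoModel(Model):
    def __call__(self, messages, stop_sequences=None, grammar=None, max_tokens=1500):
        # A real backend would query an LLM here; this one echoes the last message
        return StubMessage(role="assistant", content=messages[-1]["content"])

message = EchoModel()([{"role": "user", "content": "hi"}])
assert message.content == "hi"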


class HfApiModel(Model):
@@ -238,7 +228,7 @@ def __call__(
         grammar: Optional[str] = None,
         max_tokens: int = 1500,
         tools_to_call_from: Optional[List[Tool]] = None,
-    ) -> str:
+    ) -> ChatCompletionOutputMessage:
         """
         Gets an LLM output message for the given list of input messages.
         If argument `tools_to_call_from` is passed, the model's tool calling options will be used to return a tool call.
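
Callers therefore unpack either text or tool calls from the returned message. A usage sketch (default inference endpoint assumed):

model = HfApiModel()
message = model([{"role": "user", "content": "Say hi"}])
print(message.content)     # plain-text reply
print(message.tool_calls)  # populated only when tools_to_call_from was passed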
@@ -407,7 +397,7 @@ def __init__(
         api_key=None,
         **kwargs,
     ):
-        if not is_litellm_available:
+        if not _is_package_available("litellm"):
             raise ImportError(
                 "litellm not found. Install it with `pip install litellm`"
             )
@@ -426,7 +416,7 @@ def __call__(
         grammar: Optional[str] = None,
         max_tokens: int = 1500,
         tools_to_call_from: Optional[List[Tool]] = None,
-    ) -> str:
+    ) -> ChatCompletionOutputMessage:
         messages = get_clean_message_list(
             messages, role_conversions=tool_role_conversions
         )
@@ -497,7 +487,7 @@ def __call__(
         grammar: Optional[str] = None,
         max_tokens: int = 1500,
         tools_to_call_from: Optional[List[Tool]] = None,
-    ) -> str:
+    ) -> ChatCompletionOutputMessage:
         messages = get_clean_message_list(
             messages, role_conversions=tool_role_conversions
         )
3 changes: 2 additions & 1 deletion tests/test_agents.py
@@ -367,9 +367,10 @@ def test_fails_max_steps(self):
             model=fake_code_model_no_return,  # use this callable because it never ends
             max_steps=5,
         )
-        agent.run("What is 2 multiplied by 3.6452?")
+        answer = agent.run("What is 2 multiplied by 3.6452?")
         assert len(agent.logs) == 8
         assert type(agent.logs[-1].error) is AgentMaxStepsError
+        assert isinstance(answer, str)
 
     def test_tool_descriptions_get_baked_in_system_prompt(self):
         tool = PythonInterpreterTool()
1 change: 1 addition & 0 deletions tests/test_python_interpreter.py
@@ -486,6 +486,7 @@ def test_additional_imports(self):
         code = "import numpy.random as rd"
         evaluate_python_code(code, authorized_imports=["numpy.random"], state={})
         evaluate_python_code(code, authorized_imports=["numpy"], state={})
+        evaluate_python_code(code, authorized_imports=["*"], state={})
         with pytest.raises(InterpreterError):
             evaluate_python_code(code, authorized_imports=["random"], state={})

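The added line exercises the `"*"` wildcard, which authorizes any import. Standalone (the module path for `evaluate_python_code` is assumed):

from smolagents.local_python_executor import evaluate_python_code

# "*" whitelists every module, so even submodule imports pass
evaluate_python_code("import numpy.random as rd", authorized_imports=["*"], state={})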
