# Add summary for agent #12

Merged: 7 commits, Aug 17, 2023
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
@@ -60,4 +60,4 @@ Note: You can run all of these commands at once using `rye run all`. GitHub Acti
### Type Checking

- We use pylance (strict) in VSCode and mypy.
-- We prefer repeating duplicated type checking code over using `Any` or `# type: ignore`.
+- We prefer repeating duplicated type checking code over using `Any` or `# type: ignore`.
44 changes: 40 additions & 4 deletions README.md
@@ -2,6 +2,10 @@

PromptTrail is a lightweight library to interact with LLMs.

+## Quickstart
+
+- If you want to just use a unified interface to various LLMs, see [examples/provider/](examples/provider/).
+- If you want to build complex LLM applications, see [src/prompttrail/agents](src/prompttrail/agents/).
- [Documentation (WIP)](https://combinatrix-ai.github.io/PromptTrail/)

## Installation
@@ -185,9 +189,9 @@ flow_template = LinearTemplate(
),
],
exit_condition=BooleanHook(
-condition=lambda flow_state: (
-flow_state.get_current_template().id == check_end.id
-and "END" in flow_state.get_last_message().content
+condition=lambda state: (
+state.get_current_template().id == check_end.id
+and "END" in state.get_last_message().content
)
),
),
@@ -265,4 +269,36 @@ runner.run()
## Environment Variables

- `OPENAI_API_KEY`: API key for OpenAI API
-- `GOOGLE_CLOUD_API_KEY`: API key for Google Cloud API
+- `GOOGLE_CLOUD_API_KEY`: API key for Google Cloud API

+## Module Architecture
+
+- core: Base classes such as message, session, etc.
+- provider: Unified interface to various LLMs
+  - openai: OpenAI API
+    - stream: OpenAI API with streaming output
+  - google: Google Cloud API
+  - mock: Mock of API for testing
+- agent
+  - runner: Runners execute agents in various media (CLI, API, etc.) based on Templates with Hooks
+  - template: Templates for agents, letting you write complex agents in a simple way
+  - hook: PyTorch Lightning-style hooks for agents, allowing you to customize agents to your needs
+    - core: Basic hooks
+    - code: Hooks for code-related tasks
+  - tool: Tooling for agents, including function calling
+  - user_interaction: Unified interface to user interaction
+    - console: Console-based user interaction
+    - mock: Mock of user interaction for testing
+
+Your typical workflow is as follows:
+
+- Create a template using control-flow templates (LoopTemplate, IfTemplate, etc.) and message templates
+- Run it in your CLI with CommandLineRunner and test it.
+- If you want to use it in your application, use APIRunner!
+  - See the examples for server-side usage.
+- Mock your agent with MockProvider and MockUserInteraction to test it automatically on your CI.
+
+## Real World Examples
+
+- I have created some services with PromptTrail!
+- Please let me know via an issue if you have created one! I'll add it here.
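The `flow_state` to `state` rename in this PR runs through every hook callback. As a rough standalone sketch of what such a callback does (`StubMessage` and `StubState` are stand-ins invented here for illustration, not PromptTrail classes), the exit condition is just a predicate over the current state:

```python
# Stand-ins for illustration only; PromptTrail's real State and Message
# classes carry more fields and methods than shown here.
class StubMessage:
    def __init__(self, content: str) -> None:
        self.content = content


class StubState:
    def __init__(self, messages: list, current_template_id: str) -> None:
        self.messages = messages
        self.current_template_id = current_template_id

    def get_last_message(self) -> StubMessage:
        return self.messages[-1]


CHECK_END_ID = "check_end"  # hypothetical template id for this sketch


def exit_condition(state: StubState) -> bool:
    # Same shape as the BooleanHook condition in the README: stop once the
    # end-check template has run and the model answered END.
    return (
        state.current_template_id == CHECK_END_ID
        and "END" in state.get_last_message().content
    )
```

The callback receives only the state object, so renaming the parameter is purely cosmetic; no behavior changes.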
2 changes: 1 addition & 1 deletion examples/agent/faq-bot/main.py
@@ -19,7 +19,7 @@
# templates=[
# UserInputTemplate(key="question", after_transform = )),
# MessageTemplate(
-# before_transform = VectorSearchHook(lambda flow_state: flow_state.data.get("question")),
+# before_transform = VectorSearchHook(lambda state: state.data.get("question")),
# content="""
# Additional Information:

8 changes: 4 additions & 4 deletions examples/agent/fermi_problem.py
@@ -97,7 +97,7 @@
# You can do this using IfJumpHook
# In this case, if no python code block is found, jump to first template and retry with another question given by user
IfJumpHook(
-condition=lambda flow_state: "answer" in flow_state.data,
+condition=lambda state: "answer" in state.data,
true_template="gather_feedback",
false_template=first.template_id,
)
@@ -134,10 +134,10 @@
# Check if the loop is finished, see exit_condition below.
],
exit_condition=BooleanHook(
-condition=lambda flow_state:
+condition=lambda state:
# Exit condition: if the last message given by API is END, then exit, else continue (in this case, go to top of loop)
-flow_state.get_current_template().template_id == check_end.template_id
-and "END" in flow_state.get_last_message().content
+state.get_current_template().template_id == check_end.template_id
+and "END" in state.get_last_message().content
),
),
],
3 changes: 2 additions & 1 deletion examples/agent/weather_forecast.py
@@ -4,6 +4,7 @@
import os
from typing import Any, Dict, Optional, Sequence

+from prompttrail.agent.core import State
from prompttrail.agent.runner import CommandLineRunner
from prompttrail.agent.template import (
LinearTemplate,
@@ -72,7 +73,7 @@ class WeatherForecastTool(Tool):
argument_types = [Place, TemperatureUnit]
result_type = WeatherForecastResult

-def _call(self, args: Sequence[ToolArgument]) -> ToolResult:
+def _call(self, args: Sequence[ToolArgument], state: State) -> ToolResult:
return WeatherForecastResult(temperature=0, weather="sunny")


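The signature change above gives every tool access to the running conversation state. A toy sketch of why that is useful (the base classes below are stand-ins invented for illustration; the real `Tool`, `ToolArgument`, `ToolResult`, and `State` live in PromptTrail):

```python
from typing import Any, Dict, Sequence


# Stand-ins for PromptTrail's tool machinery, for illustration only.
class ToolArgument:
    def __init__(self, value: Any) -> None:
        self.value = value


class ToolResult:
    def __init__(self, **data: Any) -> None:
        self.data: Dict[str, Any] = data


class State:
    def __init__(self, data: Dict[str, Any]) -> None:
        self.data = data


class WeatherForecastTool:
    # After this PR, _call also receives the current State, so a tool can
    # read values accumulated earlier in the conversation.
    def _call(self, args: Sequence[ToolArgument], state: State) -> ToolResult:
        unit = state.data.get("unit", "celsius")  # hypothetical use of state
        return ToolResult(temperature=0, weather="sunny", unit=unit)
```

Before the change a tool could only see its own arguments; threading the state through lets hooks and tools share data without globals.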
8 changes: 4 additions & 4 deletions examples/dogfooding/create_unit_test.py
@@ -7,7 +7,7 @@

import click

-from prompttrail.agent.core import FlowState
+from prompttrail.agent.core import State
from prompttrail.agent.runner import CommandLineRunner
from prompttrail.agent.template import LinearTemplate
from prompttrail.agent.template import OpenAIGenerateTemplate as GenerateTemplate
@@ -118,15 +118,15 @@ def main(

load_file_content = open(load_file, "r")
context_file_contents = {x: open(x, mode="r").read() for x in context_files}
-initial_state = FlowState(
+initial_state = State(
data={
"code": load_file_content.read(),
"context_files": context_file_contents,
"description": description,
}
)
-flow_state = runner.run(flow_state=initial_state)
-last_message = flow_state.get_last_message()
+state = runner.run(state=initial_state)
+last_message = state.get_last_message()
print(last_message.content)
if len(sys.argv) > 2:
save_file_io = open(save_file, "w")
8 changes: 4 additions & 4 deletions examples/dogfooding/fix_comment.py
@@ -4,7 +4,7 @@

import click

-from prompttrail.agent.core import FlowState
+from prompttrail.agent.core import State
from prompttrail.agent.runner import CommandLineRunner
from prompttrail.agent.template import LinearTemplate
from prompttrail.agent.template import OpenAIGenerateTemplate as GenerateTemplate
@@ -60,13 +60,13 @@ def main(
logging.basicConfig(level=logging.DEBUG)

load_file_content = open(load_file, "r")
-initial_state = FlowState(
+initial_state = State(
data={
"content": load_file_content.read(),
}
)
-flow_state = runner.run(flow_state=initial_state)
-last_message = flow_state.get_last_message()
+state = runner.run(state=initial_state)
+last_message = state.get_last_message()
message = last_message.content
print(message)
# add EOF if not exists
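The "add EOF if not exists" step at the end of fix_comment.py is a small recurring pattern in these dogfooding scripts. As a standalone helper it might look like this (the function name is invented for illustration):

```python
def ensure_trailing_newline(text: str) -> str:
    """Append a final newline if the text does not already end with one.

    Mirrors the guard in fix_comment.py before the result is written back
    to disk, so files stay POSIX-friendly.
    """
    if text and text[-1] != "\n":
        return text + "\n"
    return text
```

Guarding on the empty string avoids an `IndexError` when the model returns no content at all.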
56 changes: 35 additions & 21 deletions examples/dogfooding/fix_markdown.py
@@ -2,9 +2,8 @@
# See https://github.com/combinatrix-ai/PromptTrail/pull/3 for what it does.

import os
-import sys

-from prompttrail.agent.core import FlowState
+from prompttrail.agent.core import State
from prompttrail.agent.runner import CommandLineRunner
from prompttrail.agent.template import LinearTemplate
from prompttrail.agent.template import OpenAIGenerateTemplate as GenerateTemplate
@@ -35,9 +34,12 @@
],
)

+MAX_TOKENS = 8000
+MODEL_NAME = "gpt-3.5-turbo-16k"

configuration = OpenAIModelConfiguration(api_key=os.environ.get("OPENAI_API_KEY", ""))
parameter = OpenAIModelParameters(
-model_name="gpt-3.5-turbo-16k", temperature=0.0, max_tokens=8000
+model_name=MODEL_NAME, temperature=0.0, max_tokens=MAX_TOKENS
)
model = OpenAIChatCompletionModel(configuration=configuration)

@@ -58,24 +60,36 @@ def main(
logging.basicConfig(level=logging.DEBUG)

load_file_content = open(load_file, "r")
-initial_state = FlowState(
-    data={
-        "content": load_file_content.read(),
-    }
-)
-flow_state = runner.run(flow_state=initial_state)
-last_message = flow_state.get_last_message()
-message = last_message.content
-print(message)
-if len(sys.argv) > 2:
-    # add EOF if not exists
-    if message[-1] != "\n":
-        message += "\n"
-    save_file_io = open(load_file, "w")
-    save_file_io.write(last_message.content)
-    save_file_io.close()
+splits: list[list[str]] = []
+# try splitting by ##
+
+chunk: list[str] = []
+for line in load_file_content.readlines():
+    if line.startswith("## "):
+        splits.append(chunk)
+        chunk = []
+    chunk.append(line)
+if len(chunk) > 0:
+    splits.append(chunk)
+
+corrected_splits: list[str] = []
+for split in splits:
+    content = "\n".join(split)
+    initial_state = State(data={"content": content})
+    state = runner.run(state=initial_state)
+    last_message = state.get_last_message()
+    content = last_message.content
+    print(content)
+    corrected_splits.append(content)
+corrected_all = "\n".join(corrected_splits)
+if corrected_all[-1] != "\n":
+    corrected_all += "\n"
+save_file_io = open(load_file, "w")
+save_file_io.write(corrected_all)
+save_file_io.close()


if __name__ == "__main__":
-main(load_file="README.md")
-main(load_file="CONTRIBUTING.md")
+# main(load_file="README.md")
+# main(load_file="CONTRIBUTING.md")
+main(load_file="src/prompttrail/agent/README.md")
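The new loop in fix_markdown.py chunks a Markdown file on second-level headings so each section fits within the model's token limit. Isolated from the runner, the chunking logic can be sketched like this (the function name is invented for illustration):

```python
from typing import List


def split_by_h2(lines: List[str]) -> List[List[str]]:
    """Group lines into chunks, starting a new chunk at each '## ' heading.

    Mirrors the loop added in fix_markdown.py: everything before the first
    '## ' heading lands in the first chunk, and a file that opens with a
    heading yields an empty leading chunk, as the original loop does.
    """
    splits: List[List[str]] = []
    chunk: List[str] = []
    for line in lines:
        if line.startswith("## "):
            splits.append(chunk)
            chunk = []
        chunk.append(line)
    if chunk:
        splits.append(chunk)
    return splits
```

Each chunk is then sent through the runner independently and the corrected pieces are joined back together before being written to disk.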
1 change: 1 addition & 0 deletions pyproject.toml
@@ -55,6 +55,7 @@ doc = { chain = ["doc:clean_build", "doc:clean_autodoc", "doc:apidoc", "doc:buil
"doc:apidoc" = "sphinx-apidoc --implicit-namespaces -F -o docs/sphinx src/prompttrail"
"doc:build" = "sphinx-build -b html docs/sphinx docs/"
pyreverse = "pyreverse src.prompttrail -o png"
+"dogfooding:fix_markdown" = "python examples/dogfooding/fix_markdown.py"

[tool.hatch.metadata]
allow-direct-references = true